
Beam Search

Beam search is a decoding algorithm for autoregressive sequence models. It is used when a model must generate a sequence, but greedy decoding is too narrow.

In greedy decoding, the model chooses the most likely token at every step:

\hat{y}_t = \arg\max_k \, p(y_t = k \mid \hat{y}_{<t}, x).

This can fail because a locally best token may lead to a poor full sequence. Beam search keeps several partial sequences at once. These partial sequences are called beams.

Sequence Probability

A sequence-to-sequence model assigns probability to a target sequence as

p(y \mid x) = \prod_{t=1}^{T} p(y_t \mid y_{<t}, x).

Because products of many probabilities become very small, decoding usually works with log probabilities:

\log p(y \mid x) = \sum_{t=1}^{T} \log p(y_t \mid y_{<t}, x).

Beam search tries to find a high-scoring sequence under this log-probability objective.
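As a concrete illustration, the sequence score is just a sum of per-step log probabilities. A minimal sketch, where log_probs and targets are hypothetical stand-ins for the model's per-step log-softmax outputs and the generated token ids:

import torch

# log_probs: [T, V] log-softmax outputs, one row per generated position
# targets:   [T]    the token ids actually generated
log_probs = torch.log_softmax(torch.randn(5, 100), dim=-1)
targets = torch.randint(0, 100, (5,))

# sum of per-token log probabilities = log p(y | x)
seq_log_prob = log_probs.gather(-1, targets[:, None]).squeeze(-1).sum()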

Greedy Search

Greedy search keeps only one hypothesis. At each step, it chooses the best next token.

# take the highest-probability token at the current step
next_token = logits[:, -1, :].argmax(dim=-1)

This is simple and fast. It has beam size 1.
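A full greedy decode simply repeats this step until <eos> or a length limit. A minimal sketch, assuming the same model(src_tokens, tgt_tokens) interface used by the beam search code later in this article:

import torch

@torch.no_grad()
def greedy_decode(model, src_tokens, bos_id, eos_id, max_len=64):
    tokens = torch.tensor([[bos_id]], device=src_tokens.device)
    for _ in range(max_len):
        logits = model(src_tokens, tokens)
        # beam size 1: keep only the single most likely continuation
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=1)
        if next_token.item() == eos_id:
            break
    return tokens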

Greedy search can miss better sequences. Suppose the best first token has probability 0.40, but it leads to weak continuations. Another token with probability 0.35 may lead to a much better final sequence. Greedy search discards the second token immediately.

Beam search avoids this by keeping the top K partial sequences.

Beam Width

The beam width K controls how many partial sequences are kept.

If K = 1, beam search reduces to greedy search. If K is large, the search becomes broader but more expensive.

At each step, every current beam is expanded by possible next tokens. If the vocabulary has size V, then K beams produce K × V candidate continuations. Beam search keeps only the best K of them.

For example, with K = 3, the decoder keeps the three best partial sequences after every generation step.

Beam  Partial sequence  Score
1     <bos> I           -0.2
2     <bos> The         -0.4
3     <bos> A           -0.9

After expansion, each beam proposes many next tokens. The algorithm ranks all continuations and keeps the top three.
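In code, one such step can be sketched as follows, where beam_scores holds the three current scores from the table and step_log_probs is a hypothetical [K, V] tensor of next-token log probabilities for each beam:

import torch

K, V = 3, 100
beam_scores = torch.tensor([-0.2, -0.4, -0.9])                # current beam scores
step_log_probs = torch.log_softmax(torch.randn(K, V), dim=-1) # [K, V] per beam

candidates = beam_scores[:, None] + step_log_probs            # all K * V scores
top_scores, flat_idx = torch.topk(candidates.view(-1), K)     # rank every continuation
parent_beam = flat_idx // V                                   # which beam to extend
next_token = flat_idx % V                                     # which token to append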

Beam Search Algorithm

Beam search maintains a set of active hypotheses. Each hypothesis contains a token sequence and a score.

The basic algorithm is:

  1. Start with one hypothesis containing <bos> and score 0.
  2. Expand each active hypothesis with possible next tokens.
  3. Add the log probability of each next token to the hypothesis score.
  4. Keep the best K hypotheses.
  5. Move hypotheses ending in <eos> to the completed set.
  6. Stop when enough completed hypotheses exist or the maximum length is reached.

The score for extending a hypothesis is

\text{score}(y_{1:t}) = \sum_{i=1}^{t} \log p(y_i \mid y_{<i}, x).

A completed sequence is usually one that ends with <eos>.

Length Bias

Raw log probability favors short sequences.

Each log probability is less than or equal to zero. Adding more tokens usually makes the total score smaller. As a result, beam search may prefer short outputs.

For example:

\log p(\texttt{<eos>} \mid x) = -0.5

may beat

\log p(\texttt{the cat sat <eos>} \mid x) = -2.4.

The second sequence may be better, but it accumulates more negative terms.

A common fix is length normalization:

\text{score}(y) = \frac{\sum_{t=1}^{T} \log p(y_t \mid y_{<t}, x)}{T^{\alpha}}.

The hyperparameter \alpha controls how strongly length affects the score. If \alpha = 0, there is no normalization. If \alpha = 1, the score is the average log probability per token.
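As a sketch, the normalization is a one-line function:

def length_normalized(log_prob_sum, length, alpha=0.7):
    # alpha = 0 disables normalization; alpha = 1 averages per token
    return log_prob_sum / max(length, 1) ** alpha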

Coverage and Repetition Problems

Beam search can produce repeated phrases, especially in summarization and translation.

Examples:

the company said the company said the company said

This happens because locally probable continuations can reinforce themselves.

Some systems add penalties for repeated n-grams. An n-gram is a sequence of n consecutive tokens. A no-repeat trigram constraint prevents the model from generating the same three-token sequence twice.
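A minimal sketch of enforcing this constraint: before each step, collect the tokens that would complete an already-seen n-gram, then set their log probabilities to float('-inf') before the top-k selection (the helper name is hypothetical):

def banned_next_tokens(tokens, n=3):
    # tokens: list of generated token ids for one hypothesis
    if len(tokens) < n - 1:
        return set()
    prefix = tuple(tokens[-(n - 1):])   # the last n - 1 tokens
    banned = set()
    for i in range(len(tokens) - n + 1):
        # this earlier n-gram starts with the current prefix,
        # so its final token would create a repeat
        if tuple(tokens[i:i + n - 1]) == prefix:
            banned.add(tokens[i + n - 1])
    return banned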

Another problem is incomplete coverage. In translation, the model may ignore part of the source sentence. Attention-based systems sometimes add coverage penalties to encourage the decoder to attend to all important source tokens.

These penalties are decoding heuristics. They modify search behavior without changing the trained model.

Beam Search in PyTorch

A minimal single-example beam search implementation can be written as follows.

import torch
import torch.nn.functional as F

@torch.no_grad()
def beam_search(
    model,
    src_tokens,
    bos_id,
    eos_id,
    beam_size=4,
    max_len=64,
    length_alpha=0.7,
):
    device = src_tokens.device

    def normalized_score(candidate):
        # Length excludes the <bos> token; avoid division by zero.
        length = max(candidate["tokens"].size(1) - 1, 1)
        return candidate["score"] / (length ** length_alpha)

    beams = [
        {
            "tokens": torch.tensor([[bos_id]], device=device),
            "score": 0.0,
        }
    ]

    completed = []

    for _ in range(max_len):
        candidates = []

        for beam in beams:
            # One decoder step conditioned on the source and the current prefix.
            logits = model(src_tokens, beam["tokens"])
            log_probs = F.log_softmax(logits[:, -1, :], dim=-1)

            # Expanding each beam by its top beam_size tokens is enough,
            # because no beam can contribute more than beam_size survivors.
            top_log_probs, top_tokens = torch.topk(log_probs, beam_size, dim=-1)

            for log_prob, token in zip(top_log_probs[0], top_tokens[0]):
                candidate = {
                    "tokens": torch.cat([beam["tokens"], token.view(1, 1)], dim=1),
                    "score": beam["score"] + log_prob.item(),
                }

                if token.item() == eos_id:
                    # Finished hypotheses leave the active set for good,
                    # so they cannot be counted twice.
                    completed.append(candidate)
                else:
                    candidates.append(candidate)

        if not candidates:
            break

        candidates.sort(key=normalized_score, reverse=True)
        beams = candidates[:beam_size]

        if len(completed) >= beam_size:
            break

    final_candidates = completed if completed else beams
    return max(final_candidates, key=normalized_score)["tokens"]

This version assumes src_tokens contains one source example. Batched beam search is more complex because each example maintains its own beam set.

Batched Beam Search

Production systems usually use batched beam search. Instead of decoding one example at a time, the system expands beams for many examples in parallel.

For an input batch of size B and beam size K, the source representations are usually repeated K times. The decoder then processes B × K active hypotheses.

The token tensor may have shape

[B * K, T]

The score tensor may have shape

[B, K]

At each step, the model returns logits of shape

[B * K, V]

These logits are reshaped to

[B, K, V]

Then the algorithm selects the best K candidates from the flattened dimension of size K × V.

# beam_scores: [B, K]; log_probs reshaped to [B, K, V]
scores = beam_scores[:, :, None] + log_probs   # [B, K, V] candidate scores
scores = scores.reshape(B, K * V)              # flatten beams and vocabulary

top_scores, top_indices = torch.topk(scores, K, dim=-1)

next_beam = top_indices // V    # which beam each winner extends
next_token = top_indices % V    # which token it appends

This is the core of vectorized beam search.
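To complete the step, the surviving prefixes are reordered to match their parent beams and the chosen tokens are appended. A minimal sketch, continuing the names above and assuming the token tensor tokens has shape [B * K, T]:

# map per-example beam indices into the flat [B * K] batch dimension
offsets = torch.arange(B, device=tokens.device)[:, None] * K   # [B, 1]
parent = (next_beam + offsets).view(-1)                        # [B * K]

tokens = tokens[parent]                                        # reorder prefixes
tokens = torch.cat([tokens, next_token.view(-1, 1)], dim=1)    # append new tokens
beam_scores = top_scores                                       # updated [B, K] scores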

Beam Search and Attention Models

In encoder-decoder transformers, the encoder output is computed once. During beam search, each beam uses the same source representation but has a different target prefix.

For efficiency, the encoder output is expanded across beams:

encoder_out = encoder_out.repeat_interleave(beam_size, dim=0)

The decoder then runs on all beams.

In large transformer decoders, caching is important. Instead of recomputing all decoder states at every step, the model stores the key-value tensors produced at earlier steps in each layer. This avoids repeatedly reprocessing the full prefix.

With caching, each step processes only the new token, while attention can still use previous cached states.
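When beams are reordered after a step, any cached states must be reordered the same way so each hypothesis keeps the history of its parent. A minimal sketch, assuming a hypothetical cache layout of per-layer (key, value) tensors with the beam dimension first:

def reorder_cache(past_kv, parent):
    # past_kv: list of (key, value) pairs, each [B * K, heads, T, head_dim]
    # parent:  [B * K] index of each surviving beam's parent hypothesis
    return [
        (k.index_select(0, parent), v.index_select(0, parent))
        for k, v in past_kv
    ]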

Beam Search Versus Sampling

Beam search is deterministic when model scores are deterministic. It tries to find a high-probability output.

Sampling methods are stochastic. They draw tokens from a probability distribution. Common sampling methods include temperature sampling, top-k sampling, and nucleus sampling.
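For contrast, here is a minimal sketch of one nucleus (top-p) sampling step, reusing the torch and F imports from the earlier example and assuming logits comes from a decoder and top_p is a threshold such as 0.9:

probs = F.softmax(logits[:, -1, :], dim=-1)
sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
cumulative = sorted_probs.cumsum(dim=-1)

# drop tokens whose preceding cumulative mass already exceeds top_p
sorted_probs[cumulative - sorted_probs > top_p] = 0.0
sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)

# draw one token from the remaining probability mass
choice = torch.multinomial(sorted_probs, num_samples=1)
next_token = sorted_idx.gather(-1, choice)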

Beam search is often used for tasks where accuracy and faithfulness matter, such as translation and speech recognition. Sampling is often used for open-ended generation, such as dialogue, story generation, and creative writing.

Method                Behavior                         Common use
Greedy search         Fast, narrow, deterministic      Simple generation
Beam search           Broader, deterministic           Translation, summarization, ASR
Top-k sampling        Random within top candidates     Open-ended generation
Nucleus sampling      Random within probability mass   Chat and creative text
Temperature sampling  Controls randomness              Style and diversity

Choosing Beam Size

A larger beam size does not always give better outputs.

For machine translation, small values such as 4 or 5 are often effective. Very large beams may produce dull or overly short translations unless length penalties are tuned.

For summarization, larger beams may increase repetition or generic phrasing. For dialogue, beam search often performs poorly because the highest-probability response may be bland.

A practical starting point is:

Task                 Common beam size
Translation          4 to 8
Summarization        4 to 6
Speech recognition   8 to 16
Code generation      1 to 5
Open-ended dialogue  Usually sampling instead

The best value depends on the model, task, length penalty, and evaluation metric.

Failure Modes

Beam search can fail in predictable ways.

It may prefer short outputs because sequence log probabilities decrease with length. Length normalization or length penalties address this.

It may produce generic outputs because high-probability text is often safe and common.

It may repeat phrases when local token probabilities reinforce loops.

It may reduce diversity because beams often share the same prefix. Diverse beam search modifies the score to encourage different beams.

It may amplify model miscalibration. If the model assigns too much probability to bad continuations, search can make the bad preference more visible.

Summary

Beam search is a practical decoding algorithm for sequence-to-sequence models. It keeps the best K partial hypotheses instead of only one. This allows the decoder to recover from locally attractive but globally poor choices.

The core score is the sum of log probabilities over generated tokens. In practice, length normalization, repetition penalties, coverage penalties, and caching are often needed.

Beam search is useful for translation, summarization, and speech recognition. For open-ended generation, sampling methods are often better because they preserve diversity.