Transformer Decoders

A transformer decoder is a neural network block that maps a prefix sequence to a sequence of next-token representations. It is used when the model must generate output one step at a time.

Decoder-only transformers are the core architecture behind GPT-style language models. Encoder-decoder transformers also use decoder blocks, but those decoders include an additional cross-attention sublayer that reads encoder outputs.

The Decoder Problem

Suppose we have a sequence of tokens

x_1, x_2, \ldots, x_T.

A decoder learns to predict each next token from the previous tokens:

p(x_t \mid x_1, x_2, \ldots, x_{t-1}).

For a full sequence, the model factorizes the probability as

p(x_1, \ldots, x_T) = \prod_{t=1}^{T} p(x_t \mid x_{<t}).

This is called autoregressive modeling.
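The chain-rule factorization can be checked on toy numbers. The conditional probabilities below are made up purely for illustration:

```python
import torch

# Assumed per-step conditionals p(x_t | x_<t) for a 3-token sequence.
conds = torch.tensor([0.5, 0.25, 0.1])

joint = conds.prod()            # product of the conditionals
log_joint = conds.log().sum()   # equivalent sum of log-probabilities

print(joint)                                  # tensor(0.0125)
print(torch.isclose(log_joint, joint.log()))  # tensor(True)
```

In practice the model works with sums of log-probabilities rather than products, since products of many small probabilities underflow.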

The input to the decoder is usually a batch of token IDs:

\text{tokens} \in \mathbb{N}^{B \times T}.

After embedding, the decoder processes

X \in \mathbb{R}^{B \times T \times D}.

The output is

H \in \mathbb{R}^{B \times T \times D}.

Each output position is then projected to vocabulary logits:

Z = H W_{\text{vocab}} + b,

where

Z \in \mathbb{R}^{B \times T \times V}.

Here V is the vocabulary size.
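The shape bookkeeping can be verified directly; the sizes below are arbitrary small values:

```python
import torch

B, T, D, V = 2, 4, 8, 100      # arbitrary small sizes

H = torch.randn(B, T, D)       # decoder output
W_vocab = torch.randn(D, V)    # vocabulary projection
b = torch.randn(V)

Z = H @ W_vocab + b            # matmul broadcasts over batch and time
print(Z.shape)  # torch.Size([2, 4, 100])
```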

Causal Self-Attention

The defining feature of a transformer decoder is causal self-attention. Position t may attend only to positions 1, \ldots, t. It may not attend to future positions.

This restriction prevents information leakage during training.

For example, when predicting token x_5, the model may use

x_1, x_2, x_3, x_4,

but it may not use

x_5, x_6, \ldots, x_T.

The causal mask is usually a lower-triangular matrix:

M = \begin{bmatrix} 1 & 0 & 0 & 0 \\ 1 & 1 & 0 & 0 \\ 1 & 1 & 1 & 0 \\ 1 & 1 & 1 & 1 \end{bmatrix}.

A value of 1 means attention is allowed. A value of 0 means attention is blocked.

In attention score form, blocked positions receive a very negative value before softmax:

S'_{ij} = \begin{cases} S_{ij}, & j \le i, \\ -\infty, & j > i. \end{cases}

After softmax, masked future positions receive probability zero.

Decoder Layer Structure

A decoder-only transformer layer usually contains two sublayers:

  1. Causal multi-head self-attention.
  2. Feedforward network.

A pre-norm decoder layer is

Y = X + \text{CausalSelfAttention}(\text{LayerNorm}(X)),
H = Y + \text{FeedForward}(\text{LayerNorm}(Y)).

This is almost the same as an encoder layer. The main difference is the causal mask.

An encoder-decoder transformer decoder usually contains three sublayers:

  1. Causal self-attention over the generated prefix.
  2. Cross-attention over encoder outputs.
  3. Feedforward network.

The corresponding pre-norm form is

Y = X + \text{CausalSelfAttention}(\text{LayerNorm}(X)),
Z = Y + \text{CrossAttention}(\text{LayerNorm}(Y), H_{\text{enc}}),
H = Z + \text{FeedForward}(\text{LayerNorm}(Z)).

Decoder-only models omit the cross-attention sublayer.

Shifted Inputs and Training Targets

During training, a decoder predicts the next token at every position in parallel.

Suppose the original sequence is

Deep learning uses tensors

A tokenizer may produce token IDs:

[x1, x2, x3, x4]

The decoder input is shifted right:

[x1, x2, x3]

The target is shifted left:

[x2, x3, x4]

The model receives each prefix and learns to predict the following token.

For a batch tensor:

\text{tokens} \in \mathbb{N}^{B \times T},

we create

\text{inputs} = \text{tokens}_{:,\,0:T-1}, \qquad \text{targets} = \text{tokens}_{:,\,1:T}.

In PyTorch:

tokens = torch.tensor([
    [10, 25, 83, 91, 2],
    [10, 77, 19, 34, 2],
])

inputs = tokens[:, :-1]
targets = tokens[:, 1:]

print(inputs.shape)   # torch.Size([2, 4])
print(targets.shape)  # torch.Size([2, 4])

The model outputs logits with shape

[B, T - 1, vocab_size]

The loss compares these logits against the target token IDs.

Vocabulary Projection and Cross-Entropy

The decoder output at each position is a vector in \mathbb{R}^D. To predict tokens, the model maps this vector to vocabulary logits:

z_t = W_{\text{vocab}} h_t + b.

The vector z_t \in \mathbb{R}^V contains one score for each token in the vocabulary.

The softmax converts logits into probabilities:

p(x_{t+1} = k \mid x_{\le t}) = \frac{\exp(z_{t,k})}{\sum_{j=1}^{V} \exp(z_{t,j})}.

Training usually minimizes cross-entropy:

L_t = -\log p(x_{t+1} = y_t \mid x_{\le t}).

For all positions and examples, the average loss is

L = \frac{1}{BT} \sum_{b=1}^{B} \sum_{t=1}^{T} L_{b,t}.

In PyTorch:

import torch
from torch import nn

B, T, V = 4, 16, 30_000

logits = torch.randn(B, T, V)
targets = torch.randint(0, V, (B, T))

loss = nn.functional.cross_entropy(
    logits.reshape(B * T, V),
    targets.reshape(B * T),
)

print(loss)

Cross-entropy expects class scores with shape [N, V] and labels with shape [N], so the batch and time axes are flattened.
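cross_entropy also accepts the class axis in position 1, so an equivalent call transposes the logits instead of flattening them. A quick check with arbitrary sizes:

```python
import torch
from torch import nn

B, T, V = 2, 3, 5  # arbitrary small sizes
logits = torch.randn(B, T, V)
targets = torch.randint(0, V, (B, T))

# Flattened form: class scores [N, V], labels [N].
flat = nn.functional.cross_entropy(
    logits.reshape(B * T, V),
    targets.reshape(B * T),
)

# Channel form: class axis moved to position 1, labels stay [B, T].
chan = nn.functional.cross_entropy(logits.transpose(1, 2), targets)

print(torch.isclose(flat, chan))  # tensor(True)
```

Both layouts average over all B * T positions, so the results match.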

A Minimal Decoder Layer in PyTorch

A decoder layer can be built from multi-head attention, layer normalization, residual connections, and a feedforward network.

import torch
from torch import nn

class DecoderLayer(nn.Module):
    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()

        self.norm1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=n_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.drop1 = nn.Dropout(dropout)

        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )
        self.drop2 = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
        # x: [B, T, D]
        # attn_mask: [T, T], True means "block attention"

        y = self.norm1(x)

        attn_out, _ = self.attn(
            y, y, y,
            attn_mask=attn_mask,
            need_weights=False,
        )

        x = x + self.drop1(attn_out)

        y = self.norm2(x)
        ffn_out = self.ffn(y)

        x = x + self.drop2(ffn_out)
        return x

The causal mask can be created as:

def causal_mask(T: int, device=None):
    return torch.triu(
        torch.ones(T, T, dtype=torch.bool, device=device),
        diagonal=1,
    )

Example:

B, T, D = 4, 16, 256

layer = DecoderLayer(d_model=D, n_heads=8, d_ff=1024)

x = torch.randn(B, T, D)
mask = causal_mask(T, x.device)

out = layer(x, mask)
print(out.shape)  # torch.Size([4, 16, 256])

The output shape matches the input shape, so decoder layers can be stacked.

A Minimal Decoder-Only Language Model

A decoder-only language model adds token embeddings, positional embeddings, stacked decoder layers, final normalization, and vocabulary projection.

class DecoderOnlyLM(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        max_len: int,
        d_model: int,
        n_heads: int,
        d_ff: int,
        n_layers: int,
        dropout: float = 0.1,
    ):
        super().__init__()

        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)

        self.layers = nn.ModuleList([
            DecoderLayer(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])

        self.norm = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, tokens: torch.Tensor):
        # tokens: [B, T]
        B, T = tokens.shape

        positions = torch.arange(T, device=tokens.device)
        positions = positions.unsqueeze(0).expand(B, T)

        x = self.token_emb(tokens) + self.pos_emb(positions)

        mask = causal_mask(T, tokens.device)

        for layer in self.layers:
            x = layer(x, mask)

        x = self.norm(x)
        logits = self.lm_head(x)

        return logits

Example:

model = DecoderOnlyLM(
    vocab_size=30_000,
    max_len=512,
    d_model=256,
    n_heads=8,
    d_ff=1024,
    n_layers=6,
)

tokens = torch.randint(0, 30_000, (4, 128))
logits = model(tokens)

print(logits.shape)  # torch.Size([4, 128, 30000])

This model is structurally similar to small GPT-style systems.

Padding Masks and Causal Masks

A real decoder often needs two masks.

The causal mask blocks future positions. The padding mask blocks padding tokens.

For a batch:

tokens = torch.tensor([
    [10, 25, 83, 91, 2],
    [10, 77, 2,  0, 0],
])

where 0 is padding, the padding mask is

key_padding_mask = tokens.eq(0)

This has shape [B, T].

The causal mask has shape [T, T].

In PyTorch nn.MultiheadAttention, these masks are passed separately:

attn_out, _ = self.attn(
    y, y, y,
    attn_mask=causal,
    key_padding_mask=key_padding_mask,
    need_weights=False,
)

The causal mask prevents looking ahead. The padding mask prevents attention to meaningless padded positions.
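nn.MultiheadAttention combines the two masks internally. Doing the combination by hand shows the broadcasting; this is a sketch using the batch above:

```python
import torch

tokens = torch.tensor([
    [10, 25, 83, 91, 2],
    [10, 77, 2,  0, 0],
])
B, T = tokens.shape

key_padding_mask = tokens.eq(0)  # [B, T], True at padding
causal = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)  # [T, T]

# combined[b, i, j] is True when position i may NOT attend to position j.
combined = causal.unsqueeze(0) | key_padding_mask.unsqueeze(1)  # [B, T, T]
print(combined.shape)  # torch.Size([2, 5, 5])
```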

Autoregressive Generation

During inference, a decoder generates tokens one at a time.

Given a prompt

Deep learning is

the tokenizer produces a prefix. The model predicts a distribution for the next token. A decoding rule selects one token. The selected token is appended to the prefix. The process repeats.

Basic greedy decoding:

@torch.no_grad()
def generate_greedy(model, tokens, max_new_tokens: int, eos_id: int | None = None):
    model.eval()

    for _ in range(max_new_tokens):
        logits = model(tokens)

        next_logits = logits[:, -1, :]
        next_token = next_logits.argmax(dim=-1, keepdim=True)

        tokens = torch.cat([tokens, next_token], dim=1)

        if eos_id is not None and (next_token == eos_id).all():
            break

    return tokens

This recomputes attention over the whole prefix at every step. It is simple, but inefficient for long generation.

Key-Value Caching

During generation, previous tokens do not change. A decoder can cache their key and value tensors.

Without caching, generating N new tokens requires repeatedly processing the full, growing sequence.

With key-value caching, each new step only computes query, key, and value for the new token, then attends to cached keys and values from previous tokens.

For each layer, the cache stores

K_{\text{cache}} \in \mathbb{R}^{B \times h \times T \times d_h}, \qquad V_{\text{cache}} \in \mathbb{R}^{B \times h \times T \times d_h}.

At generation step t, the new key and value are appended to the cache. The query for the new token attends to all cached keys.

KV caching is essential for efficient language model serving. It reduces repeated computation and improves latency.
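A minimal single-head sketch of the cache update; attend_with_cache is a hypothetical helper, not the nn.MultiheadAttention API:

```python
import torch

def attend_with_cache(q_new, k_new, v_new, cache):
    """One generation step: q_new, k_new, v_new are [B, 1, d] for the newest token."""
    k_cache, v_cache = cache
    k_cache = torch.cat([k_cache, k_new], dim=1)  # append key   -> [B, t, d]
    v_cache = torch.cat([v_cache, v_new], dim=1)  # append value -> [B, t, d]

    d = q_new.shape[-1]
    scores = q_new @ k_cache.transpose(1, 2) / d ** 0.5  # [B, 1, t]
    out = scores.softmax(dim=-1) @ v_cache               # [B, 1, d]
    return out, (k_cache, v_cache)

B, d = 1, 8
cache = (torch.empty(B, 0, d), torch.empty(B, 0, d))  # start empty

for _ in range(3):
    q = torch.randn(B, 1, d)
    out, cache = attend_with_cache(q, torch.randn(B, 1, d), torch.randn(B, 1, d), cache)

print(cache[0].shape)  # torch.Size([1, 3, 8])
```

No causal mask is needed here: the cache only ever contains past tokens, so causality holds by construction.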

Decoding Strategies

A decoder produces a probability distribution. The next token can be selected in several ways.

| Strategy | Description |
| --- | --- |
| Greedy decoding | Select the highest-probability token |
| Beam search | Keep several high-scoring partial sequences |
| Temperature sampling | Rescale logits before sampling |
| Top-k sampling | Sample only from the top k tokens |
| Nucleus sampling | Sample from the smallest set whose probability mass exceeds p |
| Contrastive decoding | Balance probability and representation diversity |

Temperature modifies logits as

z'_i = \frac{z_i}{\tau}.

When \tau < 1, the distribution becomes sharper. When \tau > 1, the distribution becomes flatter.

Greedy decoding is deterministic but may produce repetitive text. Sampling improves diversity but may reduce reliability. Beam search is useful in translation but often too rigid for open-ended generation.
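Temperature and top-k can be combined in a few lines. This is a sketch; sample_top_k is a hypothetical helper, not a library function:

```python
import torch

def sample_top_k(logits, k: int = 50, temperature: float = 1.0):
    # logits: [B, V] scores for the next token.
    logits = logits / temperature                 # sharpen or flatten
    topk_vals, topk_idx = logits.topk(k, dim=-1)  # keep only the k best tokens
    probs = topk_vals.softmax(dim=-1)             # renormalize over the top k
    choice = torch.multinomial(probs, num_samples=1)
    return topk_idx.gather(-1, choice)            # map back to vocabulary IDs

logits = torch.randn(2, 100)
next_token = sample_top_k(logits, k=5, temperature=0.8)
print(next_token.shape)  # torch.Size([2, 1])
```

Because temperature rescales all logits by the same factor, it never changes which tokens are in the top k, only how sharply the sampler prefers them.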

Decoder-Only Versus Encoder-Decoder Models

Decoder-only models and encoder-decoder models use different conditioning patterns.

A decoder-only model represents the prompt and the generated output in one sequence:

[prompt tokens][generated tokens]

The causal mask ensures each position only sees earlier positions.

An encoder-decoder model first encodes the input sequence, then decodes the output sequence while using cross-attention to the encoder output.

| Architecture | Input handling | Output generation | Typical use |
| --- | --- | --- | --- |
| Decoder-only | Prompt and output in one stream | Causal next-token prediction | Chat, completion, code generation |
| Encoder-decoder | Separate source and target streams | Causal generation conditioned on encoder states | Translation, summarization, structured generation |
| Encoder-only | Full bidirectional context | Usually no autoregressive generation | Classification, tagging, embeddings |

Modern large language models are often decoder-only because the architecture is simple, scalable, and flexible. Encoder-decoder models remain useful when the task has a clear source-target structure.

Cross-Attention in Encoder-Decoder Decoders

In encoder-decoder models, the decoder contains cross-attention. In cross-attention, queries come from the decoder hidden states, while keys and values come from the encoder output.

Let

Y \in \mathbb{R}^{B \times T_y \times D}

be the decoder states, and

H_{\text{enc}} \in \mathbb{R}^{B \times T_x \times D}

be the encoder outputs.

Then

Q = Y W_Q, \qquad K = H_{\text{enc}} W_K, \qquad V = H_{\text{enc}} W_V.

The decoder attends from output positions to input positions:

\text{CrossAttention}(Y, H_{\text{enc}}) = \text{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V.

Cross-attention lets the decoder generate text while selectively reading source information.
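With nn.MultiheadAttention, cross-attention is just a call where the query input differs from the key/value input. The sizes below are arbitrary:

```python
import torch
from torch import nn

B, Ty, Tx, D = 2, 5, 7, 16  # arbitrary sizes; note Ty != Tx

attn = nn.MultiheadAttention(embed_dim=D, num_heads=4, batch_first=True)

y = torch.randn(B, Ty, D)      # decoder states (queries)
h_enc = torch.randn(B, Tx, D)  # encoder outputs (keys and values)

out, weights = attn(y, h_enc, h_enc)
print(out.shape)      # torch.Size([2, 5, 16])
print(weights.shape)  # torch.Size([2, 5, 7])
```

The attention weights have shape [B, T_y, T_x]: one row of source weights per output position.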

Common Decoder Hyperparameters

A transformer decoder is controlled by the same core hyperparameters as an encoder.

| Hyperparameter | Meaning |
| --- | --- |
| D | Model dimension |
| L | Number of decoder layers |
| h | Number of attention heads |
| D_ff | Feedforward hidden dimension |
| T_max | Maximum context length |
| V | Vocabulary size |
| Dropout | Regularization rate |
| Positional encoding | Position representation method |

Decoder models are sensitive to context length because the cost of self-attention grows quadratically with sequence length during training. During inference, KV caching reduces repeated computation, but cache memory still grows linearly with the number of generated tokens.

Summary

A transformer decoder is designed for autoregressive generation. It predicts the next token from previous tokens using causal self-attention.

A decoder-only model stacks causal decoder layers and projects hidden states to vocabulary logits. An encoder-decoder decoder adds cross-attention so generated tokens can attend to an encoded source sequence.

The key differences from an encoder are causal masking, shifted training targets, autoregressive inference, and KV caching. These mechanisms turn the transformer block into a practical sequence generator.