Transformer Decoders

A transformer decoder maps a partial output sequence to predictions for the next token or next output step. Unlike an encoder, a decoder usually cannot see future positions. It must produce each representation using only the current and previous tokens.

Decoder-only transformers are the core architecture behind modern autoregressive language models. Encoder-decoder transformers also use decoder blocks, but those decoders include cross-attention to read from encoder outputs.

Given token embeddings

X \in \mathbb{R}^{B \times T \times D},

a decoder returns contextual states

H \in \mathbb{R}^{B \times T \times D}.

For language modeling, these states are projected into vocabulary logits:

Z = H W_{\text{vocab}} + b.

The output shape is

Z \in \mathbb{R}^{B \times T \times V},

where V is the vocabulary size.
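
To make the shapes concrete, here is a quick check with arbitrary sizes (the dimensions below are illustrative, not from any particular model):

import torch
from torch import nn

B, T, D, V = 2, 8, 16, 100        # arbitrary sizes for illustration

H = torch.randn(B, T, D)          # contextual states from a decoder
vocab_proj = nn.Linear(D, V)      # plays the role of W_vocab and b

Z = vocab_proj(H)
print(Z.shape)  # torch.Size([2, 8, 100])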

Causal Self-Attention

The defining feature of a decoder is causal self-attention. Position t may attend only to positions 0, …, t. It may not attend to positions after t.

For a sequence

the cat sat

the model trains by predicting the next token at each position:

Input context    Target
the              cat
the cat          sat
the cat sat      next token

This requires a triangular mask over the attention matrix.

M_{ij} =
\begin{cases}
0, & j \leq i, \\
-\infty, & j > i.
\end{cases}

The mask is added before softmax:

A = \operatorname{softmax}\left( \frac{Q K^\top}{\sqrt{d_h}} + M \right).

After softmax, future positions receive zero attention weight.
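
A tiny numeric check makes this concrete. The scores below are random; only the masking pattern matters:

import torch

T = 4
scores = torch.randn(T, T)   # stands in for QK^T / sqrt(d_h) for one head
mask = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)

A = torch.softmax(scores + mask, dim=-1)
print(A)              # entries above the diagonal are exactly zero
print(A.sum(dim=-1))  # each row still sums to 1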

Decoder Layer Structure

A decoder-only layer has two main sublayers:

Sublayer                           Purpose
Causal multi-head self-attention   Reads previous tokens
Feedforward network                Applies nonlinear token-wise computation

Using pre-normalization, the layer is

X_1 = X + \operatorname{CausalSelfAttention}(\operatorname{LayerNorm}(X)),
Y = X_1 + \operatorname{FFN}(\operatorname{LayerNorm}(X_1)).

This is the common decoder-only block used in GPT-style models.

Encoder-Decoder Decoder Layer

In an encoder-decoder transformer, the decoder has three sublayers:

Sublayer                Purpose
Causal self-attention   Reads previous target tokens
Cross-attention         Reads encoder outputs
Feedforward network     Applies nonlinear token-wise computation

The layer can be written as

X_1 = X + \operatorname{CausalSelfAttention}(\operatorname{LayerNorm}(X)),
X_2 = X_1 + \operatorname{CrossAttention}(\operatorname{LayerNorm}(X_1), E),
Y = X_2 + \operatorname{FFN}(\operatorname{LayerNorm}(X_2)).

Here E is the encoder output.

Encoder-decoder decoders are common in translation, summarization, speech recognition, and structured generation tasks where a source sequence is first encoded and then decoded.
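
As a sketch, the three sublayers can be wired together with nn.MultiheadAttention. The class name and the tgt_mask argument are illustrative; the causal mask itself can be built with the causal_mask helper shown later in this article:

import torch
from torch import nn

class EncoderDecoderDecoderLayer(nn.Module):
    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            d_model, num_heads, dropout=dropout, batch_first=True
        )
        self.cross_attn = nn.MultiheadAttention(
            d_model, num_heads, dropout=dropout, batch_first=True
        )
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, enc_out, tgt_mask=None):
        # Causal self-attention over the target sequence.
        h = self.norm1(x)
        attn, _ = self.self_attn(h, h, h, attn_mask=tgt_mask, need_weights=False)
        x = x + attn

        # Cross-attention: queries from the target, keys and values from E.
        h = self.norm2(x)
        attn, _ = self.cross_attn(h, enc_out, enc_out, need_weights=False)
        x = x + attn

        # Position-wise feedforward with a residual connection.
        return x + self.ffn(self.norm3(x))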

A Minimal Decoder-Only Layer in PyTorch

import torch
from torch import nn

class FeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

class TransformerDecoderLayer(nn.Module):
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        d_ff: int,
        dropout: float = 0.1,
    ):
        super().__init__()

        self.self_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )

        self.ffn = FeedForward(d_model, d_ff, dropout)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, key_padding_mask=None) -> torch.Tensor:
        # x: [B, T, D]
        h = self.norm1(x)

        # Note: nn.MultiheadAttention treats is_causal as a hint; some
        # PyTorch versions still require an explicit attn_mask (see below).
        attn_out, _ = self.self_attn(
            query=h,
            key=h,
            value=h,
            key_padding_mask=key_padding_mask,
            need_weights=False,
            is_causal=True,
        )

        x = x + self.dropout1(attn_out)

        h = self.norm2(x)
        ffn_out = self.ffn(h)

        x = x + self.dropout2(ffn_out)
        return x

nn.MultiheadAttention documents is_causal as a performance hint, and some PyTorch versions require an explicit attention mask when is_causal=True. A manual mask is straightforward:

def causal_mask(T: int, device=None) -> torch.Tensor:
    return torch.triu(
        torch.ones(T, T, device=device, dtype=torch.bool),
        diagonal=1,
    )

Then pass it as attn_mask:

mask = causal_mask(T, device=x.device)

attn_out, _ = self.self_attn(
    query=h,
    key=h,
    value=h,
    attn_mask=mask,
    key_padding_mask=key_padding_mask,
    need_weights=False,
)

Stacking Decoder Layers

A decoder is a stack of decoder layers:

H^{(0)} = X,
H^{(\ell+1)} = \operatorname{DecoderLayer}^{(\ell)}(H^{(\ell)}).

In PyTorch:

class TransformerDecoder(nn.Module):
    def __init__(
        self,
        num_layers: int,
        d_model: int,
        num_heads: int,
        d_ff: int,
        dropout: float = 0.1,
    ):
        super().__init__()

        self.layers = nn.ModuleList([
            TransformerDecoderLayer(
                d_model=d_model,
                num_heads=num_heads,
                d_ff=d_ff,
                dropout=dropout,
            )
            for _ in range(num_layers)
        ])

        self.final_norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor, key_padding_mask=None) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x, key_padding_mask=key_padding_mask)

        return self.final_norm(x)

Example:

decoder = TransformerDecoder(
    num_layers=6,
    d_model=256,
    num_heads=8,
    d_ff=1024,
)

x = torch.randn(2, 32, 256)
y = decoder(x)

print(y.shape)  # torch.Size([2, 32, 256])

A Small Decoder-Only Language Model

A decoder-only language model combines token embeddings, positional embeddings, decoder layers, and a vocabulary projection.

class DecoderOnlyLanguageModel(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        max_length: int,
        num_layers: int,
        d_model: int,
        num_heads: int,
        d_ff: int,
        dropout: float = 0.1,
        pad_token_id: int = 0,
    ):
        super().__init__()

        self.pad_token_id = pad_token_id
        self.vocab_size = vocab_size

        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(max_length, d_model)

        self.dropout = nn.Dropout(dropout)

        self.decoder = TransformerDecoder(
            num_layers=num_layers,
            d_model=d_model,
            num_heads=num_heads,
            d_ff=d_ff,
            dropout=dropout,
        )

        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        # token_ids: [B, T]
        B, T = token_ids.shape

        positions = torch.arange(T, device=token_ids.device)
        positions = positions.unsqueeze(0).expand(B, T)

        x = self.token_embedding(token_ids)
        p = self.position_embedding(positions)

        x = self.dropout(x + p)

        key_padding_mask = token_ids.eq(self.pad_token_id)

        h = self.decoder(x, key_padding_mask=key_padding_mask)
        logits = self.lm_head(h)

        # logits: [B, T, vocab_size]
        return logits

Usage:

model = DecoderOnlyLanguageModel(
    vocab_size=50_000,
    max_length=512,
    num_layers=6,
    d_model=256,
    num_heads=8,
    d_ff=1024,
)

token_ids = torch.randint(0, 50_000, (4, 128))
logits = model(token_ids)

print(logits.shape)  # torch.Size([4, 128, 50000])

Next-Token Training Objective

A decoder-only language model is trained to predict the next token. Given a token sequence

[x_0, x_1, \ldots, x_{T-1}],

the model predicts

[x_1, x_2, \ldots, x_{T-1}].

In code, we shift inputs and targets:

import torch.nn.functional as F

token_ids = torch.randint(0, 50_000, (4, 128))

input_ids = token_ids[:, :-1]
target_ids = token_ids[:, 1:]

logits = model(input_ids)

loss = F.cross_entropy(
    logits.reshape(-1, logits.size(-1)),
    target_ids.reshape(-1),
)

The logits have shape [B, T - 1, V]. The targets have shape [B, T - 1]. Cross-entropy compares each position’s vocabulary distribution with the next token.
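
If the batch contains padding, the padded target positions are usually excluded from the loss; F.cross_entropy accepts an ignore_index argument (for example ignore_index=model.pad_token_id) for exactly this purpose.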

Autoregressive Generation

During generation, the model repeatedly predicts one token and appends it to the context.

A simple greedy decoding loop:

@torch.no_grad()
def generate_greedy(
    model: nn.Module,
    input_ids: torch.Tensor,
    max_new_tokens: int,
    eos_token_id: int | None = None,
):
    model.eval()

    for _ in range(max_new_tokens):
        logits = model(input_ids)

        next_logits = logits[:, -1, :]
        next_token = next_logits.argmax(dim=-1, keepdim=True)

        input_ids = torch.cat([input_ids, next_token], dim=1)

        # Stop only once every sequence in the batch has emitted EOS.
        if eos_token_id is not None:
            if torch.all(next_token.squeeze(-1) == eos_token_id):
                break

    return input_ids

Greedy decoding chooses the most likely token at each step. Other methods include temperature sampling, top-k sampling, nucleus sampling, beam search, and contrastive decoding.
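
As one illustration, a temperature-plus-top-k sampling step could replace the argmax line in the loop above. This is a minimal sketch; the function name and defaults are arbitrary:

import torch

def sample_next_token(
    next_logits: torch.Tensor,   # [B, V] logits for the last position
    temperature: float = 1.0,
    top_k: int = 50,
) -> torch.Tensor:
    logits = next_logits / temperature

    # Keep only the top-k logits per row; everything else becomes -inf.
    top_k = min(top_k, logits.size(-1))
    topk_vals, _ = torch.topk(logits, top_k, dim=-1)
    cutoff = topk_vals[:, -1].unsqueeze(-1)
    logits = logits.masked_fill(logits < cutoff, float("-inf"))

    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)  # [B, 1]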

KV Caching

Naive generation recomputes attention over the full context at every step. If the current sequence length is T, producing one more token re-runs the attention computation for all T existing tokens, and this wasted work grows with every step.

KV caching stores past keys and values for each layer. At generation step t, the model computes the new query, key, and value only for the latest token. It reuses previous keys and values.

Without caching, decoding is simple but inefficient. With caching, decoding is much faster for long outputs.

Conceptually, for each layer, the cache stores:

past_keys:   [B, H, T_past, d_head]
past_values: [B, H, T_past, d_head]

At the next step, new keys and values are appended:

keys = torch.cat([past_keys, new_keys], dim=2)
values = torch.cat([past_values, new_values], dim=2)

KV caching changes inference code more than training code. During training, the full sequence is processed in parallel. During generation, tokens are processed incrementally.
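
A minimal single-step sketch of cached attention, assuming PyTorch 2.x for F.scaled_dot_product_attention; the projection modules w_q, w_k, and w_v are stand-ins for a real layer's weights:

import torch
import torch.nn.functional as F

def cached_attention_step(
    x_new: torch.Tensor,        # [B, 1, D] hidden state of the newest token
    w_q, w_k, w_v,              # projections, e.g. nn.Linear(D, D)
    past_keys: torch.Tensor,    # [B, H, T_past, d_head]
    past_values: torch.Tensor,  # [B, H, T_past, d_head]
    num_heads: int,
):
    B, _, D = x_new.shape
    d_head = D // num_heads

    def split_heads(t: torch.Tensor) -> torch.Tensor:
        # [B, 1, D] -> [B, H, 1, d_head]
        return t.view(B, 1, num_heads, d_head).transpose(1, 2)

    q = split_heads(w_q(x_new))
    keys = torch.cat([past_keys, split_heads(w_k(x_new))], dim=2)
    values = torch.cat([past_values, split_heads(w_v(x_new))], dim=2)

    # The single query is the latest token, which may attend to every
    # cached position, so no causal mask is needed here.
    out = F.scaled_dot_product_attention(q, keys, values)  # [B, H, 1, d_head]
    out = out.transpose(1, 2).reshape(B, 1, D)

    # Return the updated cache for the next decoding step.
    return out, keys, values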

Decoder-Only Versus Encoder-Decoder Models

Decoder-only models are trained to continue text. They are natural for open-ended generation, chat, code completion, and instruction following.

Encoder-decoder models separate input understanding from output generation. They are natural for translation, summarization, speech-to-text, and tasks where the source and target have different structures.

Property           Decoder-only                              Encoder-decoder
Input processing   Causal self-attention over one sequence   Encoder reads source, decoder generates target
Cross-attention    Usually absent                            Present
Common tasks       Language modeling, chat, code             Translation, summarization, speech
Generation         Autoregressive                            Autoregressive
Prompt format      Source and target in one sequence         Source and target separated

Many modern systems use decoder-only models because one next-token objective can cover many tasks. Encoder-decoder models remain strong when the task has a clear input-output structure.

Summary

A transformer decoder uses causal self-attention to produce representations that cannot depend on future tokens. This makes it suitable for autoregressive generation.

A decoder-only language model combines token embeddings, position embeddings, stacked decoder layers, and a vocabulary projection. It is trained by next-token prediction using shifted input and target sequences.

During inference, the model generates tokens one at a time. KV caching avoids recomputing past keys and values, making long generation practical.