A transformer decoder is a neural network block that maps a prefix sequence to a sequence of next-token representations. It is used when the model must generate output one step at a time.
Decoder-only transformers are the core architecture behind GPT-style language models. Encoder-decoder transformers also use decoder blocks, but those decoders include an additional cross-attention sublayer that reads encoder outputs.
The Decoder Problem
Suppose we have a sequence of tokens $x_1, x_2, \dots, x_T$.

A decoder learns to predict each next token from the previous tokens:

$$p(x_t \mid x_1, \dots, x_{t-1})$$

For a full sequence, the model factorizes the probability as

$$p(x_1, \dots, x_T) = \prod_{t=1}^{T} p(x_t \mid x_1, \dots, x_{t-1})$$

This is called autoregressive modeling.

The input to the decoder is usually a batch of token IDs of shape $[B, T]$, where $B$ is the batch size and $T$ is the sequence length.

After embedding, the decoder processes a tensor $X \in \mathbb{R}^{B \times T \times d_{\text{model}}}$.

The output is a tensor $H \in \mathbb{R}^{B \times T \times d_{\text{model}}}$ of next-token representations.

Each output position is then projected to vocabulary logits:

$$z_t = W h_t,$$

where $W \in \mathbb{R}^{V \times d_{\text{model}}}$. Here $V$ is the vocabulary size.
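The chain-rule factorization can be checked numerically. The sketch below uses random logits as a hypothetical stand-in for real decoder outputs; summing per-step log-probabilities gives the sequence log-probability:

```python
import torch

torch.manual_seed(0)
T, V = 4, 10
logits = torch.randn(T, V)           # logits[t] scores the token at position t + 1
tokens = torch.tensor([3, 7, 1, 5])  # the "observed" next tokens

# Chain rule: log p(x_1 .. x_T) = sum_t log p(x_t | x_<t)
log_probs = torch.log_softmax(logits, dim=-1)
seq_log_prob = log_probs[torch.arange(T), tokens].sum()
print(seq_log_prob)  # a single negative scalar
```

Because each factor is a probability, every per-step term is non-positive, so the sequence log-probability is always at most zero.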
Causal Self-Attention
The defining feature of a transformer decoder is causal self-attention. Position $t$ may attend only to positions $1, \dots, t$. It may not attend to future positions.
This restriction prevents information leakage during training.
For example, when predicting token $x_{t+1}$, the model may use $x_1, \dots, x_t$, but it may not use $x_{t+1}, \dots, x_T$.
The causal mask is usually a lower-triangular matrix:

$$M = \begin{pmatrix} 1 & 0 & 0 & 0 \\ 1 & 1 & 0 & 0 \\ 1 & 1 & 1 & 0 \\ 1 & 1 & 1 & 1 \end{pmatrix}$$

A value of 1 means attention is allowed. A value of 0 means attention is blocked.

In attention score form, blocked positions receive a very negative value before softmax:

$$\text{score}_{ij} = \begin{cases} \dfrac{q_i \cdot k_j}{\sqrt{d_k}} & \text{if } j \le i \\ -\infty & \text{if } j > i \end{cases}$$
After softmax, masked future positions receive probability zero.
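The mask-then-softmax step can be sketched directly; the score matrix below is random and purely illustrative:

```python
import torch

T = 4
scores = torch.randn(T, T)  # raw attention scores, q_i . k_j / sqrt(d_k)

# Block future positions (j > i) with -inf before the softmax.
mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(mask, float("-inf"))
weights = torch.softmax(scores, dim=-1)

print(weights)  # rows sum to 1; entries above the diagonal are exactly 0
```

Row $i$ distributes its attention only over positions $j \le i$, which is exactly the causal constraint.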
Decoder Layer Structure
A decoder-only transformer layer usually contains two sublayers:
- Causal multi-head self-attention.
- Feedforward network.
A pre-norm decoder layer is

$$\begin{aligned} x &\leftarrow x + \text{SelfAttn}(\text{LN}(x)) \\ x &\leftarrow x + \text{FFN}(\text{LN}(x)) \end{aligned}$$
This is almost the same as an encoder layer. The main difference is the causal mask.
An encoder-decoder transformer decoder usually contains three sublayers:
- Causal self-attention over the generated prefix.
- Cross-attention over encoder outputs.
- Feedforward network.
The corresponding pre-norm form is

$$\begin{aligned} x &\leftarrow x + \text{SelfAttn}(\text{LN}(x)) \\ x &\leftarrow x + \text{CrossAttn}(\text{LN}(x), E) \\ x &\leftarrow x + \text{FFN}(\text{LN}(x)) \end{aligned}$$

where $E$ denotes the encoder outputs.
Decoder-only models omit the cross-attention sublayer.
Shifted Inputs and Training Targets
During training, a decoder predicts the next token at every position in parallel.
Suppose the original sequence is "Deep learning uses tensors". A tokenizer may produce token IDs:

```
[x1, x2, x3, x4]
```

The decoder input is shifted right:

```
[x1, x2, x3]
```

The target is shifted left:

```
[x2, x3, x4]
```

The model receives each prefix and learns to predict the following token.
For a batch tensor `tokens` of shape `[B, T]`, we create `inputs = tokens[:, :-1]` and `targets = tokens[:, 1:]`.
In PyTorch:

```python
tokens = torch.tensor([
    [10, 25, 83, 91, 2],
    [10, 77, 19, 34, 2],
])

inputs = tokens[:, :-1]
targets = tokens[:, 1:]

print(inputs.shape)   # torch.Size([2, 4])
print(targets.shape)  # torch.Size([2, 4])
```

The model outputs logits with shape `[B, T - 1, vocab_size]`. The loss compares these logits against the target token IDs.
Vocabulary Projection and Cross-Entropy
The decoder output at each position is a vector $h_t \in \mathbb{R}^{d_{\text{model}}}$. To predict tokens, the model maps this vector to vocabulary logits:

$$z_t = W h_t, \qquad W \in \mathbb{R}^{V \times d_{\text{model}}}$$

The vector $z_t$ contains one score for each token in the vocabulary.

The softmax converts logits into probabilities:

$$p(x_{t+1} = i \mid x_{\le t}) = \frac{\exp(z_{t,i})}{\sum_{j=1}^{V} \exp(z_{t,j})}$$

Training usually minimizes cross-entropy:

$$\mathcal{L}_t = -\log p(x_{t+1} \mid x_{\le t})$$

For all positions and examples, the average loss is

$$\mathcal{L} = -\frac{1}{B\,T} \sum_{b=1}^{B} \sum_{t=1}^{T} \log p\!\left(x_{t+1}^{(b)} \mid x_{\le t}^{(b)}\right)$$
In PyTorch:

```python
import torch
from torch import nn

B, T, V = 4, 16, 30_000
logits = torch.randn(B, T, V)
targets = torch.randint(0, V, (B, T))

loss = nn.functional.cross_entropy(
    logits.reshape(B * T, V),
    targets.reshape(B * T),
)
print(loss)
```

Cross-entropy expects class scores with shape `[N, V]` and labels with shape `[N]`, so the batch and time axes are flattened.
A Minimal Decoder Layer in PyTorch
A decoder layer can be built from multi-head attention, layer normalization, residual connections, and a feedforward network.
```python
import torch
from torch import nn


class DecoderLayer(nn.Module):
    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=n_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.drop1 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )
        self.drop2 = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
        # x: [B, T, D]
        # attn_mask: [T, T], True means "block attention"
        y = self.norm1(x)
        attn_out, _ = self.attn(
            y, y, y,
            attn_mask=attn_mask,
            need_weights=False,
        )
        x = x + self.drop1(attn_out)
        y = self.norm2(x)
        ffn_out = self.ffn(y)
        x = x + self.drop2(ffn_out)
        return x
```

The causal mask can be created as:

```python
def causal_mask(T: int, device=None):
    return torch.triu(
        torch.ones(T, T, dtype=torch.bool, device=device),
        diagonal=1,
    )
```

Example:

```python
B, T, D = 4, 16, 256
layer = DecoderLayer(d_model=D, n_heads=8, d_ff=1024)
x = torch.randn(B, T, D)
mask = causal_mask(T, x.device)

out = layer(x, mask)
print(out.shape)  # torch.Size([4, 16, 256])
```

The output shape matches the input shape, so decoder layers can be stacked.
A Minimal Decoder-Only Language Model
A decoder-only language model adds token embeddings, positional embeddings, stacked decoder layers, final normalization, and vocabulary projection.
```python
class DecoderOnlyLM(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        max_len: int,
        d_model: int,
        n_heads: int,
        d_ff: int,
        n_layers: int,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.layers = nn.ModuleList([
            DecoderLayer(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, tokens: torch.Tensor):
        # tokens: [B, T]
        B, T = tokens.shape
        positions = torch.arange(T, device=tokens.device)
        positions = positions.unsqueeze(0).expand(B, T)
        x = self.token_emb(tokens) + self.pos_emb(positions)
        mask = causal_mask(T, tokens.device)
        for layer in self.layers:
            x = layer(x, mask)
        x = self.norm(x)
        logits = self.lm_head(x)
        return logits
```

Example:

```python
model = DecoderOnlyLM(
    vocab_size=30_000,
    max_len=512,
    d_model=256,
    n_heads=8,
    d_ff=1024,
    n_layers=6,
)
tokens = torch.randint(0, 30_000, (4, 128))

logits = model(tokens)
print(logits.shape)  # torch.Size([4, 128, 30000])
```

This model is structurally similar to small GPT-style systems.
Padding Masks and Causal Masks
A real decoder often needs two masks.
The causal mask blocks future positions. The padding mask blocks padding tokens.
For a batch:

```python
tokens = torch.tensor([
    [10, 25, 83, 91, 2],
    [10, 77, 2, 0, 0],
])
```

where 0 is padding, the padding mask is

```python
key_padding_mask = tokens.eq(0)
```

This has shape `[B, T]`.

The causal mask has shape `[T, T]`.

In PyTorch `nn.MultiheadAttention`, these masks are passed separately:

```python
attn_out, _ = self.attn(
    y, y, y,
    attn_mask=causal,
    key_padding_mask=key_padding_mask,
    need_weights=False,
)
```

The causal mask prevents looking ahead. The padding mask prevents attention to meaningless padded positions.
Autoregressive Generation
During inference, a decoder generates tokens one at a time.
Given a prompt

"Deep learning is"

the tokenizer produces a prefix. The model predicts a distribution for the next token. A decoding rule selects one token. The selected token is appended to the prefix. The process repeats.

Basic greedy decoding:

```python
@torch.no_grad()
def generate_greedy(model, tokens, max_new_tokens: int, eos_id: int | None = None):
    model.eval()
    for _ in range(max_new_tokens):
        logits = model(tokens)
        next_logits = logits[:, -1, :]
        next_token = next_logits.argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=1)
        if eos_id is not None and (next_token == eos_id).all():
            break
    return tokens
```

This recomputes attention over the whole prefix at every step. It is simple, but inefficient for long generation.
Key-Value Caching
During generation, previous tokens do not change. A decoder can cache their key and value tensors.
Without caching, generating new tokens requires repeatedly processing the full growing sequence.
With key-value caching, each new step only computes query, key, and value for the new token, then attends to cached keys and values from previous tokens.
For each layer, the cache stores

$$K_{\text{cache}} \in \mathbb{R}^{B \times t \times d_{\text{model}}}, \qquad V_{\text{cache}} \in \mathbb{R}^{B \times t \times d_{\text{model}}}$$

for the $t$ tokens processed so far. At generation step $t + 1$, the new key and value are appended to the cache. The query for the new token attends to all cached keys.
KV caching is essential for efficient language model serving. It reduces repeated computation and improves latency.
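The caching loop can be sketched for a single attention head; the weight matrices and `step` helper below are hypothetical simplifications, not part of a real decoder layer:

```python
import torch

torch.manual_seed(0)
d = 8
Wq, Wk, Wv = (torch.randn(d, d) for _ in range(3))

k_cache, v_cache = [], []

def step(x_new):
    """Process one new token embedding x_new: [1, d]."""
    q = x_new @ Wq
    k_cache.append(x_new @ Wk)          # only the new token's key is computed
    v_cache.append(x_new @ Wv)          # ...and its value
    K = torch.cat(k_cache)              # [t, d] keys for all tokens so far
    V = torch.cat(v_cache)
    attn = torch.softmax(q @ K.T / d**0.5, dim=-1)  # [1, t]
    return attn @ V                     # [1, d]

for _ in range(5):
    out = step(torch.randn(1, d))
print(out.shape, len(k_cache))  # torch.Size([1, 8]) 5
```

Each step does work proportional to the current prefix length, instead of reprocessing the whole sequence through every layer.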
Decoding Strategies
A decoder produces a probability distribution. The next token can be selected in several ways.
| Strategy | Description |
|---|---|
| Greedy decoding | Select the highest-probability token |
| Beam search | Keep several high-scoring partial sequences |
| Temperature sampling | Rescale logits before sampling |
| Top-k sampling | Sample only from the top $k$ tokens |
| Nucleus sampling | Sample from the smallest set whose probability mass exceeds $p$ |
| Contrastive decoding | Balance probability and representation diversity |

Temperature modifies logits as

$$z_i \leftarrow \frac{z_i}{\tau}$$

When $\tau < 1$, the distribution becomes sharper. When $\tau > 1$, the distribution becomes flatter.
Greedy decoding is deterministic but may produce repetitive text. Sampling improves diversity but may reduce reliability. Beam search is useful in translation but often too rigid for open-ended generation.
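Temperature and top-k can be combined in one small helper. The `sample` function below is an illustrative sketch, not a library API:

```python
import torch

torch.manual_seed(0)
V = 10
logits = torch.randn(V)

def sample(logits, temperature=1.0, top_k=None):
    """Temperature rescaling followed by optional top-k truncation (a sketch)."""
    z = logits / temperature
    if top_k is not None:
        kth = torch.topk(z, top_k).values[-1]
        z = z.masked_fill(z < kth, float("-inf"))  # drop tokens outside the top k
    probs = torch.softmax(z, dim=-1)
    return torch.multinomial(probs, 1).item()

tok = sample(logits, temperature=0.8, top_k=3)
print(tok)  # an index in [0, V)
```

With `top_k=3`, only the three highest-scoring tokens keep nonzero probability, so the sampled index is always one of them.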
Decoder-Only Versus Encoder-Decoder Models
Decoder-only models and encoder-decoder models use different conditioning patterns.
A decoder-only model represents the prompt and the generated output in one sequence:
```
[prompt tokens][generated tokens]
```

The causal mask ensures each position only sees earlier positions.
An encoder-decoder model first encodes the input sequence, then decodes the output sequence while using cross-attention to the encoder output.
| Architecture | Input handling | Output generation | Typical use |
|---|---|---|---|
| Decoder-only | Prompt and output in one stream | Causal next-token prediction | Chat, completion, code generation |
| Encoder-decoder | Separate source and target streams | Causal generation conditioned on encoder states | Translation, summarization, structured generation |
| Encoder-only | Full bidirectional context | Usually no autoregressive generation | Classification, tagging, embeddings |
Modern large language models are often decoder-only because the architecture is simple, scalable, and flexible. Encoder-decoder models remain useful when the task has a clear source-target structure.
Cross-Attention in Encoder-Decoder Decoders
In encoder-decoder models, the decoder contains cross-attention. In cross-attention, queries come from the decoder hidden states, while keys and values come from the encoder output.
Let

$$D \in \mathbb{R}^{B \times T_{\text{out}} \times d_{\text{model}}}$$

be the decoder states, and

$$E \in \mathbb{R}^{B \times T_{\text{in}} \times d_{\text{model}}}$$

be the encoder outputs.

Then

$$Q = D W_Q, \qquad K = E W_K, \qquad V = E W_V$$

The decoder attends from output positions to input positions:

$$\text{CrossAttn}(D, E) = \text{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V$$
Cross-attention lets the decoder generate text while selectively reading source information.
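This pattern can be sketched with `nn.MultiheadAttention` by passing decoder states as the query and encoder outputs as both key and value; the shapes below are arbitrary illustrative choices:

```python
import torch
from torch import nn

B, T_out, T_in, D = 2, 5, 7, 32
decoder_states = torch.randn(B, T_out, D)   # queries come from the decoder
encoder_outputs = torch.randn(B, T_in, D)   # keys and values come from the encoder

cross_attn = nn.MultiheadAttention(embed_dim=D, num_heads=4, batch_first=True)
out, weights = cross_attn(decoder_states, encoder_outputs, encoder_outputs)

print(out.shape)      # torch.Size([2, 5, 32])
print(weights.shape)  # torch.Size([2, 5, 7])
```

The attention weight tensor has one row per output position and one column per input position, which is exactly the output-to-input attention pattern described above.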
Common Decoder Hyperparameters
A transformer decoder is controlled by the same core hyperparameters as an encoder.
| Hyperparameter | Meaning |
|---|---|
| Model dimension ($d_{\text{model}}$) | Width of the hidden representations |
| Number of decoder layers ($n_{\text{layers}}$) | Depth of the stacked decoder blocks |
| Number of attention heads ($n_{\text{heads}}$) | Parallel attention subspaces per layer |
| Feedforward hidden dimension ($d_{\text{ff}}$) | Width of the FFN sublayer |
| Maximum context length ($T_{\max}$) | Longest sequence the model can process |
| Vocabulary size ($V$) | Number of distinct token IDs |
| Dropout | Regularization rate |
| Positional encoding | Position representation method |
Decoder models are sensitive to context length because self-attention scales quadratically during training. During inference, KV caching reduces repeated computation, but memory still grows with the number of generated tokens.
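The memory growth can be made concrete with a back-of-the-envelope estimate. The numbers below are illustrative, assuming fp16 values (2 bytes each) and the small model sizes used earlier:

```python
# Rough KV-cache memory estimate for one sequence (illustrative assumptions).
n_layers, d_model, T = 6, 256, 4096
bytes_per_value = 2  # fp16

# Each layer caches one key vector and one value vector per token.
kv_bytes = n_layers * T * d_model * 2 * bytes_per_value
print(f"{kv_bytes / 1e6:.1f} MB per sequence")  # 25.2 MB per sequence
```

The estimate scales linearly in layers, context length, and model width, which is why long-context serving is dominated by cache memory rather than compute.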
Summary
A transformer decoder is designed for autoregressive generation. It predicts the next token from previous tokens using causal self-attention.
A decoder-only model stacks causal decoder layers and projects hidden states to vocabulary logits. An encoder-decoder decoder adds cross-attention so generated tokens can attend to an encoded source sequence.
The key differences from an encoder are causal masking, shifted training targets, autoregressive inference, and KV caching. These mechanisms turn the transformer block into a practical sequence generator.