A transformer decoder maps a partial output sequence to predictions for the next token or next output step. Unlike an encoder, a decoder usually cannot see future positions. It must produce each representation using only the current and previous tokens.
Decoder-only transformers are the core architecture behind modern autoregressive language models. Encoder-decoder transformers also use decoder blocks, but those decoders include cross-attention to read from encoder outputs.
Given token embeddings $x_1, \dots, x_T \in \mathbb{R}^{d_{\text{model}}}$, a decoder returns contextual states $h_1, \dots, h_T$, where each $h_t$ depends only on $x_1, \dots, x_t$. For language modeling, these states are projected into vocabulary logits:

$$z_t = W_{\text{vocab}}\, h_t$$

The output shape is $[B, T, V]$, where $V$ is the vocabulary size.
Causal Self-Attention
The defining feature of a decoder is causal self-attention. Position $t$ may attend only to positions $1, \dots, t$; it may not attend to any position after $t$.
For a sequence like "the cat sat", the model trains by predicting the next token at each position:
| Input context | Target |
|---|---|
| the | cat |
| the cat | sat |
| the cat sat | next token |
This requires a triangular mask over the attention matrix: a mask $M$ with $M_{ij} = 0$ where $j \le i$ and $M_{ij} = -\infty$ where $j > i$. The mask is added to the scores before softmax:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}} + M\right) V$$
After softmax, future positions receive zero attention weight.
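To make the effect of the mask concrete, here is a small standalone check; it is illustrative only and separate from the layer implementation later in this section:

```python
import torch
import torch.nn.functional as F

# Toy check: a 3x3 causal mask and its effect on attention weights.
T = 3
scores = torch.randn(T, T)                                          # raw attention scores
mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)   # True above the diagonal
masked_scores = scores.masked_fill(mask, float("-inf"))
weights = F.softmax(masked_scores, dim=-1)
print(weights)  # each row sums to 1, and entries above the diagonal are exactly 0
```

Row $i$ of the resulting weights is a distribution over positions $1, \dots, i$ only.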
Decoder Layer Structure
A decoder-only layer has two main sublayers:
| Sublayer | Purpose |
|---|---|
| Causal multi-head self-attention | Reads previous tokens |
| Feedforward network | Applies nonlinear token-wise computation |
Using pre-normalization, the layer is

$$h = x + \mathrm{SelfAttn}(\mathrm{LN}(x)), \qquad y = h + \mathrm{FFN}(\mathrm{LN}(h))$$
This is the common decoder-only block used in GPT-style models.
Encoder-Decoder Decoder Layer
In an encoder-decoder transformer, the decoder has three sublayers:
| Sublayer | Purpose |
|---|---|
| Causal self-attention | Reads previous target tokens |
| Cross-attention | Reads encoder outputs |
| Feedforward network | Applies nonlinear token-wise computation |
The layer can be written as

$$h_1 = x + \mathrm{SelfAttn}(\mathrm{LN}(x)), \qquad h_2 = h_1 + \mathrm{CrossAttn}(\mathrm{LN}(h_1), E), \qquad y = h_2 + \mathrm{FFN}(\mathrm{LN}(h_2))$$

Here $E$ is the encoder output.
Encoder-decoder decoders are common in translation, summarization, speech recognition, and structured generation tasks where a source sequence is first encoded and then decoded.
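The code in the rest of this section focuses on the decoder-only case. For reference, a minimal sketch of the three-sublayer layer above might look like the following; the class name CrossAttentionDecoderLayer and its argument names are illustrative rather than taken from a specific library (PyTorch also ships nn.TransformerDecoderLayer, which implements this structure):

```python
import torch
from torch import nn

class CrossAttentionDecoderLayer(nn.Module):
    """Illustrative pre-norm decoder layer: causal self-attention, cross-attention, FFN."""

    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(
        self,
        x: torch.Tensor,                        # target states [B, T_tgt, D]
        encoder_out: torch.Tensor,              # encoder output [B, T_src, D]
        tgt_mask: torch.Tensor | None = None,   # causal mask over target positions
    ) -> torch.Tensor:
        h = self.norm1(x)
        attn_out, _ = self.self_attn(h, h, h, attn_mask=tgt_mask, need_weights=False)
        x = x + attn_out
        h = self.norm2(x)
        cross_out, _ = self.cross_attn(h, encoder_out, encoder_out, need_weights=False)
        x = x + cross_out
        x = x + self.ffn(self.norm3(x))
        return x
```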
A Minimal Decoder-Only Layer in PyTorch
import torch
from torch import nn
class FeedForward(nn.Module):
def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
super().__init__()
self.net = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(d_ff, d_model),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.net(x)
class TransformerDecoderLayer(nn.Module):
def __init__(
self,
d_model: int,
num_heads: int,
d_ff: int,
dropout: float = 0.1,
):
super().__init__()
self.self_attn = nn.MultiheadAttention(
embed_dim=d_model,
num_heads=num_heads,
dropout=dropout,
batch_first=True,
)
self.ffn = FeedForward(d_model, d_ff, dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
def forward(self, x: torch.Tensor, key_padding_mask=None) -> torch.Tensor:
# x: [B, T, D]
h = self.norm1(x)
attn_out, _ = self.self_attn(
query=h,
key=h,
value=h,
key_padding_mask=key_padding_mask,
need_weights=False,
is_causal=True,
)
x = x + self.dropout1(attn_out)
h = self.norm2(x)
ffn_out = self.ffn(h)
x = x + self.dropout2(ffn_out)
        return x

Some PyTorch versions require an explicit attention mask when is_causal=True. A manual mask is straightforward:
def causal_mask(T: int, device=None) -> torch.Tensor:
return torch.triu(
torch.ones(T, T, device=device, dtype=torch.bool),
diagonal=1,
    )

Then pass it as attn_mask:
mask = causal_mask(T, device=x.device)
attn_out, _ = self.self_attn(
query=h,
key=h,
value=h,
attn_mask=mask,
key_padding_mask=key_padding_mask,
need_weights=False,
)

Stacking Decoder Layers
A decoder is a stack of decoder layers, followed by a final normalization:

$$h^{(0)} = x, \qquad h^{(\ell)} = \mathrm{DecoderLayer}_\ell\!\left(h^{(\ell-1)}\right), \quad \ell = 1, \dots, L, \qquad h_{\text{out}} = \mathrm{LN}\!\left(h^{(L)}\right)$$
In PyTorch:
class TransformerDecoder(nn.Module):
def __init__(
self,
num_layers: int,
d_model: int,
num_heads: int,
d_ff: int,
dropout: float = 0.1,
):
super().__init__()
self.layers = nn.ModuleList([
TransformerDecoderLayer(
d_model=d_model,
num_heads=num_heads,
d_ff=d_ff,
dropout=dropout,
)
for _ in range(num_layers)
])
self.final_norm = nn.LayerNorm(d_model)
def forward(self, x: torch.Tensor, key_padding_mask=None) -> torch.Tensor:
for layer in self.layers:
x = layer(x, key_padding_mask=key_padding_mask)
        return self.final_norm(x)

Example:
decoder = TransformerDecoder(
num_layers=6,
d_model=256,
num_heads=8,
d_ff=1024,
)
x = torch.randn(2, 32, 256)
y = decoder(x)
print(y.shape)  # torch.Size([2, 32, 256])

A Small Decoder-Only Language Model
A decoder-only language model combines token embeddings, positional embeddings, decoder layers, and a vocabulary projection.
class DecoderOnlyLanguageModel(nn.Module):
def __init__(
self,
vocab_size: int,
max_length: int,
num_layers: int,
d_model: int,
num_heads: int,
d_ff: int,
dropout: float = 0.1,
pad_token_id: int = 0,
):
super().__init__()
self.pad_token_id = pad_token_id
self.vocab_size = vocab_size
self.token_embedding = nn.Embedding(vocab_size, d_model)
self.position_embedding = nn.Embedding(max_length, d_model)
self.dropout = nn.Dropout(dropout)
self.decoder = TransformerDecoder(
num_layers=num_layers,
d_model=d_model,
num_heads=num_heads,
d_ff=d_ff,
dropout=dropout,
)
self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
# token_ids: [B, T]
B, T = token_ids.shape
positions = torch.arange(T, device=token_ids.device)
positions = positions.unsqueeze(0).expand(B, T)
x = self.token_embedding(token_ids)
p = self.position_embedding(positions)
x = self.dropout(x + p)
key_padding_mask = token_ids.eq(self.pad_token_id)
h = self.decoder(x, key_padding_mask=key_padding_mask)
logits = self.lm_head(h)
# logits: [B, T, vocab_size]
        return logits

Usage:
model = DecoderOnlyLanguageModel(
vocab_size=50_000,
max_length=512,
num_layers=6,
d_model=256,
num_heads=8,
d_ff=1024,
)
token_ids = torch.randint(0, 50_000, (4, 128))
logits = model(token_ids)
print(logits.shape)  # torch.Size([4, 128, 50000])

Next-Token Training Objective
A decoder-only language model is trained to predict the next token. Given a token sequence $w_1, \dots, w_T$, the model predicts $p(w_{t+1} \mid w_1, \dots, w_t)$ at each position $t$.
In code, we shift inputs and targets:
import torch.nn.functional as F
token_ids = torch.randint(0, 50_000, (4, 128))
input_ids = token_ids[:, :-1]
target_ids = token_ids[:, 1:]
logits = model(input_ids)
loss = F.cross_entropy(
logits.reshape(-1, logits.size(-1)),
target_ids.reshape(-1),
)

The logits have shape [B, T - 1, V]. The targets have shape [B, T - 1]. Cross-entropy compares each position's vocabulary distribution with the next token.
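If the batch contains padding, those positions are normally excluded from the loss. A minimal way to do that, assuming pad_token_id = 0 as in the model above, is cross_entropy's ignore_index argument:

```python
loss = F.cross_entropy(
    logits.reshape(-1, logits.size(-1)),
    target_ids.reshape(-1),
    ignore_index=0,  # assumption: padding uses token id 0, as in DecoderOnlyLanguageModel above
)
```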
Autoregressive Generation
During generation, the model repeatedly predicts one token and appends it to the context.
A simple greedy decoding loop:
@torch.no_grad()
def generate_greedy(
model: nn.Module,
input_ids: torch.Tensor,
max_new_tokens: int,
eos_token_id: int | None = None,
):
model.eval()
for _ in range(max_new_tokens):
logits = model(input_ids)
next_logits = logits[:, -1, :]
next_token = next_logits.argmax(dim=-1, keepdim=True)
input_ids = torch.cat([input_ids, next_token], dim=1)
if eos_token_id is not None:
if torch.all(next_token.squeeze(-1) == eos_token_id):
break
    return input_ids

Greedy decoding chooses the most likely token at each step. Other methods include temperature sampling, top-k sampling, nucleus sampling, beam search, and contrastive decoding.
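As one concrete example of the sampling-based alternatives, a temperature plus top-k step could look like the sketch below; the temperature and top_k defaults are arbitrary choices, not recommendations:

```python
import torch

@torch.no_grad()
def sample_next_token(next_logits: torch.Tensor, temperature: float = 1.0, top_k: int = 50) -> torch.Tensor:
    # next_logits: [B, V] logits for the last position
    next_logits = next_logits / temperature
    topk_vals, _ = torch.topk(next_logits, k=top_k, dim=-1)
    cutoff = topk_vals[:, -1, None]                 # smallest logit kept in each row
    next_logits = next_logits.masked_fill(next_logits < cutoff, float("-inf"))
    probs = torch.softmax(next_logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)  # [B, 1]
```

Replacing the argmax line in generate_greedy with a call to this function turns greedy decoding into stochastic sampling.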
KV Caching
Naive generation recomputes attention over the full context at every step. If the current sequence length is $t$, each new token repeats work for all $t$ previous tokens.
KV caching stores past keys and values for each layer. At generation step $t$, the model computes the new query, key, and value only for the latest token. It reuses the previous keys and values.
Without caching, decoding is simple but inefficient. With caching, decoding is much faster for long outputs.
Conceptually, for each layer, the cache stores:
past_keys: [B, H, T_past, d_head]
past_values: [B, H, T_past, d_head]

At the next step, new keys and values are appended:
keys = torch.cat([past_keys, new_keys], dim=2)
values = torch.cat([past_values, new_values], dim=2)

KV caching changes inference code more than training code. During training, the full sequence is processed in parallel. During generation, tokens are processed incrementally.
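A minimal sketch of how a per-layer cache might be used during one incremental step, written with the functional scaled_dot_product_attention available in PyTorch 2.0 and later; the KVCache class and function names here are illustrative and not part of the model above:

```python
import torch
import torch.nn.functional as F

class KVCache:
    """Illustrative per-layer cache holding keys/values of shape [B, H, T_past, d_head]."""

    def __init__(self):
        self.keys = None
        self.values = None

    def update(self, new_keys: torch.Tensor, new_values: torch.Tensor):
        if self.keys is None:
            self.keys, self.values = new_keys, new_values
        else:
            self.keys = torch.cat([self.keys, new_keys], dim=2)
            self.values = torch.cat([self.values, new_values], dim=2)
        return self.keys, self.values

def cached_attention_step(q, new_k, new_v, cache: KVCache) -> torch.Tensor:
    # q, new_k, new_v are for the newest token only, shape [B, H, 1, d_head].
    keys, values = cache.update(new_k, new_v)
    # No causal mask is needed: the single query is the latest position,
    # and the cache contains only positions up to and including it.
    return F.scaled_dot_product_attention(q, keys, values)
```

On older PyTorch versions, the same step can be written with an explicit softmax of q @ keys.transpose(-2, -1) scaled by the head dimension.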
Decoder-Only Versus Encoder-Decoder Models
Decoder-only models are trained to continue text. They are natural for open-ended generation, chat, code completion, and instruction following.
Encoder-decoder models separate input understanding from output generation. They are natural for translation, summarization, speech-to-text, and tasks where the source and target have different structures.
| Property | Decoder-only | Encoder-decoder |
|---|---|---|
| Input processing | Causal self-attention over one sequence | Encoder reads source, decoder generates target |
| Cross-attention | Usually absent | Present |
| Common tasks | Language modeling, chat, code | Translation, summarization, speech |
| Generation | Autoregressive | Autoregressive |
| Prompt format | Source and target in one sequence | Source and target separated |
Many modern systems use decoder-only models because one next-token objective can cover many tasks. Encoder-decoder models remain strong when the task has a clear input-output structure.
Summary
A transformer decoder uses causal self-attention to produce representations that cannot depend on future tokens. This makes it suitable for autoregressive generation.
A decoder-only language model combines token embeddings, position embeddings, stacked decoder layers, and a vocabulary projection. It is trained by next-token prediction using shifted input and target sequences.
During inference, the model generates tokens one at a time. KV caching avoids recomputing past keys and values, making long generation practical.