# Transformer Decoders

A transformer decoder maps a partial output sequence to a prediction for the next token or output step. Unlike an encoder, a decoder cannot see future positions: it must produce each representation using only the current and preceding tokens.

Decoder-only transformers are the core architecture behind modern autoregressive language models. Encoder-decoder transformers also use decoder blocks, but those decoders include cross-attention to read from encoder outputs.

Given token embeddings

$$
X \in \mathbb{R}^{B \times T \times D},
$$

a decoder returns contextual states

$$
H \in \mathbb{R}^{B \times T \times D}.
$$

For language modeling, these states are projected into vocabulary logits:

$$
Z = HW_{\text{vocab}} + b.
$$

The output shape is

$$
Z \in \mathbb{R}^{B \times T \times V},
$$

where $V$ is vocabulary size.
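The projection can be checked with a quick shape sketch. The weight and bias here are random stand-ins, not trained parameters:

```python
import torch

B, T, D, V = 2, 5, 16, 100      # batch, length, model dim, vocab size
H = torch.randn(B, T, D)        # decoder states
W_vocab = torch.randn(D, V)     # vocabulary projection
b = torch.zeros(V)

Z = H @ W_vocab + b             # logits
print(Z.shape)  # torch.Size([2, 5, 100])
```

The matrix product contracts the model dimension $D$, leaving one score per vocabulary entry at every position.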

### Causal Self-Attention

The defining feature of a decoder is causal self-attention: position $t$ may attend only to positions $0,\ldots,t$, never to positions after $t$.

For a sequence

```text
the cat sat
```

the model trains by predicting the next token at each position:

| Input context | Target |
|---|---|
| `the` | `cat` |
| `the cat` | `sat` |
| `the cat sat` | next token |

This requires a triangular mask over the attention matrix.

$$
M_{ij} =
\begin{cases}
0, & j \leq i, \\
-\infty, & j > i.
\end{cases}
$$

The mask is added before softmax:

$$
A =
\operatorname{softmax}
\left(
\frac{QK^\top}{\sqrt{d_h}} + M
\right).
$$

After softmax, future positions receive zero attention weight.
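A small numeric check of the masked softmax, using random scores and $T = 4$:

```python
import torch

T = 4
scores = torch.randn(T, T)  # stands in for Q K^T / sqrt(d_h)

# -inf strictly above the diagonal, 0 on and below it
mask = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)

A = torch.softmax(scores + mask, dim=-1)

print(A.triu(diagonal=1).abs().sum().item())  # 0.0  (no weight on future positions)
print(A.sum(dim=-1))                          # each row still sums to 1
```

Since $e^{-\infty} = 0$, the masked entries receive exactly zero weight while each row remains a valid probability distribution.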

### Decoder Layer Structure

A decoder-only layer has two main sublayers:

| Sublayer | Purpose |
|---|---|
| Causal multi-head self-attention | Reads previous tokens |
| Feedforward network | Applies nonlinear token-wise computation |

Using pre-normalization, the layer is

$$
X_1 = X + \operatorname{CausalSelfAttention}(\operatorname{LayerNorm}(X)),
$$

$$
Y = X_1 + \operatorname{FFN}(\operatorname{LayerNorm}(X_1)).
$$

This is the common decoder-only block used in GPT-style models.

### Decoder Layers in Encoder-Decoder Models

In an encoder-decoder transformer, the decoder has three sublayers:

| Sublayer | Purpose |
|---|---|
| Causal self-attention | Reads previous target tokens |
| Cross-attention | Reads encoder outputs |
| Feedforward network | Applies nonlinear token-wise computation |

The layer can be written as

$$
X_1 = X + \operatorname{CausalSelfAttention}(\operatorname{LayerNorm}(X)),
$$

$$
X_2 = X_1 + \operatorname{CrossAttention}(\operatorname{LayerNorm}(X_1), E),
$$

$$
Y = X_2 + \operatorname{FFN}(\operatorname{LayerNorm}(X_2)).
$$

Here $E$ is the encoder output.

Encoder-decoder decoders are common in translation, summarization, speech recognition, and structured generation tasks where a source sequence is first encoded and then decoded.
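Since the code below focuses on the decoder-only case, here is a minimal sketch of the three-sublayer variant. The class and argument names are illustrative, not taken from any library, and the causal mask is built inline for brevity:

```python
import torch
from torch import nn

class EncoderDecoderDecoderLayer(nn.Module):
    """Pre-norm decoder layer with cross-attention (illustrative sketch)."""

    def __init__(self, d_model: int, num_heads: int, d_ff: int):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor, enc_out: torch.Tensor) -> torch.Tensor:
        # x: [B, T_tgt, D] target states, enc_out: [B, T_src, D] encoder output E
        T = x.size(1)
        causal = torch.triu(
            torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1
        )

        h = self.norm1(x)
        a, _ = self.self_attn(h, h, h, attn_mask=causal, need_weights=False)
        x = x + a

        # Queries come from the decoder; keys and values from the encoder output.
        h = self.norm2(x)
        c, _ = self.cross_attn(h, enc_out, enc_out, need_weights=False)
        x = x + c

        return x + self.ffn(self.norm3(x))
```

Note that the output keeps the target length: cross-attention reads a source of any length but returns one vector per target position.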

### A Minimal Decoder-Only Layer in PyTorch

```python
import torch
from torch import nn

class FeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

class TransformerDecoderLayer(nn.Module):
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        d_ff: int,
        dropout: float = 0.1,
    ):
        super().__init__()

        self.self_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )

        self.ffn = FeedForward(d_model, d_ff, dropout)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, key_padding_mask=None) -> torch.Tensor:
        # x: [B, T, D]
        T = x.size(1)

        # Boolean causal mask: True marks future positions to be hidden.
        causal = torch.triu(
            torch.ones(T, T, device=x.device, dtype=torch.bool),
            diagonal=1,
        )

        h = self.norm1(x)

        attn_out, _ = self.self_attn(
            query=h,
            key=h,
            value=h,
            attn_mask=causal,
            key_padding_mask=key_padding_mask,
            need_weights=False,
            is_causal=True,
        )

        x = x + self.dropout1(attn_out)

        h = self.norm2(x)
        ffn_out = self.ffn(h)

        x = x + self.dropout2(ffn_out)
        return x
```

In `nn.MultiheadAttention`, `is_causal` is only a performance hint: it tells the implementation that `attn_mask` is the causal mask, and some PyTorch versions silently skip masking when no explicit `attn_mask` is supplied alongside it. Building the mask manually is the portable approach:

```python
def causal_mask(T: int, device=None) -> torch.Tensor:
    return torch.triu(
        torch.ones(T, T, device=device, dtype=torch.bool),
        diagonal=1,
    )
```

Then pass it as `attn_mask`:

```python
mask = causal_mask(T, device=x.device)

attn_out, _ = self.self_attn(
    query=h,
    key=h,
    value=h,
    attn_mask=mask,
    key_padding_mask=key_padding_mask,
    need_weights=False,
)
```

### Stacking Decoder Layers

A decoder is a stack of decoder layers:

$$
H^{(0)} = X,
$$

$$
H^{(\ell+1)} =
\operatorname{DecoderLayer}^{(\ell)}(H^{(\ell)}).
$$

In PyTorch:

```python
class TransformerDecoder(nn.Module):
    def __init__(
        self,
        num_layers: int,
        d_model: int,
        num_heads: int,
        d_ff: int,
        dropout: float = 0.1,
    ):
        super().__init__()

        self.layers = nn.ModuleList([
            TransformerDecoderLayer(
                d_model=d_model,
                num_heads=num_heads,
                d_ff=d_ff,
                dropout=dropout,
            )
            for _ in range(num_layers)
        ])

        self.final_norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor, key_padding_mask=None) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x, key_padding_mask=key_padding_mask)

        return self.final_norm(x)
```

Example:

```python
decoder = TransformerDecoder(
    num_layers=6,
    d_model=256,
    num_heads=8,
    d_ff=1024,
)

x = torch.randn(2, 32, 256)
y = decoder(x)

print(y.shape)  # torch.Size([2, 32, 256])
```

### A Small Decoder-Only Language Model

A decoder-only language model combines token embeddings, positional embeddings, decoder layers, and a vocabulary projection.

```python
class DecoderOnlyLanguageModel(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        max_length: int,
        num_layers: int,
        d_model: int,
        num_heads: int,
        d_ff: int,
        dropout: float = 0.1,
        pad_token_id: int = 0,
    ):
        super().__init__()

        self.pad_token_id = pad_token_id
        self.vocab_size = vocab_size

        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(max_length, d_model)

        self.dropout = nn.Dropout(dropout)

        self.decoder = TransformerDecoder(
            num_layers=num_layers,
            d_model=d_model,
            num_heads=num_heads,
            d_ff=d_ff,
            dropout=dropout,
        )

        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        # token_ids: [B, T]
        B, T = token_ids.shape

        positions = torch.arange(T, device=token_ids.device)
        positions = positions.unsqueeze(0).expand(B, T)

        x = self.token_embedding(token_ids)
        p = self.position_embedding(positions)

        x = self.dropout(x + p)

        key_padding_mask = token_ids.eq(self.pad_token_id)

        h = self.decoder(x, key_padding_mask=key_padding_mask)
        logits = self.lm_head(h)

        # logits: [B, T, vocab_size]
        return logits
```

Usage:

```python
model = DecoderOnlyLanguageModel(
    vocab_size=50_000,
    max_length=512,
    num_layers=6,
    d_model=256,
    num_heads=8,
    d_ff=1024,
)

token_ids = torch.randint(0, 50_000, (4, 128))
logits = model(token_ids)

print(logits.shape)  # torch.Size([4, 128, 50000])
```

### Next-Token Training Objective

A decoder-only language model is trained to predict the next token. Given a token sequence

$$
[x_0,x_1,\ldots,x_{T-1}],
$$

the model is trained so that each position $t$ predicts $x_{t+1}$; the targets are the inputs shifted left by one position:

$$
[x_1,x_2,\ldots,x_{T-1}].
$$

In code, we shift inputs and targets:

```python
import torch.nn.functional as F

token_ids = torch.randint(0, 50_000, (4, 128))

input_ids = token_ids[:, :-1]
target_ids = token_ids[:, 1:]

logits = model(input_ids)

loss = F.cross_entropy(
    logits.reshape(-1, logits.size(-1)),
    target_ids.reshape(-1),
)
```

The logits have shape `[B, T - 1, V]`. The targets have shape `[B, T - 1]`. Cross-entropy compares each position’s vocabulary distribution with the next token.
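In practice, padded positions should not contribute to the loss; `cross_entropy`'s `ignore_index` argument handles this. A sketch, using a pad id of 0 to match the model above:

```python
import torch
import torch.nn.functional as F

pad_token_id = 0

logits = torch.randn(2, 7, 10)             # [B, T - 1, V]
target_ids = torch.randint(1, 10, (2, 7))  # real tokens
target_ids[:, -2:] = pad_token_id          # pretend the tail is padding

# Positions whose target equals ignore_index are excluded from the mean.
loss = F.cross_entropy(
    logits.reshape(-1, logits.size(-1)),
    target_ids.reshape(-1),
    ignore_index=pad_token_id,
)
```

Without `ignore_index`, the model would be penalized for its predictions at padded positions, which carry no information.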

### Autoregressive Generation

During generation, the model repeatedly predicts one token and appends it to the context.

A simple greedy decoding loop:

```python
@torch.no_grad()
def generate_greedy(
    model: nn.Module,
    input_ids: torch.Tensor,
    max_new_tokens: int,
    eos_token_id: int | None = None,
):
    model.eval()

    for _ in range(max_new_tokens):
        logits = model(input_ids)

        next_logits = logits[:, -1, :]
        next_token = next_logits.argmax(dim=-1, keepdim=True)

        input_ids = torch.cat([input_ids, next_token], dim=1)

        if eos_token_id is not None:
            if torch.all(next_token.squeeze(-1) == eos_token_id):
                break

    return input_ids
```

Greedy decoding chooses the most likely token at each step. Other methods include temperature sampling, top-k sampling, nucleus sampling, beam search, and contrastive decoding.

### KV Caching

Naive generation recomputes attention over the full context at every step. If the current sequence length is $T$, producing the next token recomputes queries, keys, and values for all $T$ positions, even though the keys and values of earlier tokens have not changed.

KV caching stores past keys and values for each layer. At generation step $t$, the model computes the new query, key, and value only for the latest token. It reuses previous keys and values.

Without caching, decoding is simple but inefficient. With caching, decoding is much faster for long outputs.

Conceptually, for each layer, the cache stores:

```text
past_keys:   [B, H, T_past, d_head]
past_values: [B, H, T_past, d_head]
```

At the next step, new keys and values are appended:

```python
keys = torch.cat([past_keys, new_keys], dim=2)
values = torch.cat([past_values, new_values], dim=2)
```

KV caching changes inference code more than training code. During training, the full sequence is processed in parallel. During generation, tokens are processed incrementally.
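A single-head sketch of one cached step makes the shapes concrete. Projections are replaced by the identity for brevity, and the layout is `[B, T, d]` rather than the per-head layout above:

```python
import torch
import torch.nn.functional as F

d = 8
past_keys = torch.randn(1, 4, d)    # cache for 4 past tokens
past_values = torch.randn(1, 4, d)

x_new = torch.randn(1, 1, d)        # only the latest token is processed
q = k_new = v_new = x_new           # identity projections, for brevity

keys = torch.cat([past_keys, k_new], dim=1)      # [1, 5, d]
values = torch.cat([past_values, v_new], dim=1)  # [1, 5, d]

# The single new query attends over all cached positions plus itself.
attn = F.softmax(q @ keys.transpose(1, 2) / d**0.5, dim=-1)  # [1, 1, 5]
out = attn @ values                                          # [1, 1, d]
```

The attention matrix has one row per new query, so each step costs work proportional to the context length rather than its square.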

### Decoder-Only Versus Encoder-Decoder Models

Decoder-only models are trained to continue text. They are natural for open-ended generation, chat, code completion, and instruction following.

Encoder-decoder models separate input understanding from output generation. They are natural for translation, summarization, speech-to-text, and tasks where the source and target have different structures.

| Property | Decoder-only | Encoder-decoder |
|---|---|---|
| Input processing | Causal self-attention over one sequence | Encoder reads source, decoder generates target |
| Cross-attention | Usually absent | Present |
| Common tasks | Language modeling, chat, code | Translation, summarization, speech |
| Generation | Autoregressive | Autoregressive |
| Prompt format | Source and target in one sequence | Source and target separated |

Many modern systems use decoder-only models because one next-token objective can cover many tasks. Encoder-decoder models remain strong when the task has a clear input-output structure.

### Summary

A transformer decoder uses causal self-attention to produce representations that cannot depend on future tokens. This makes it suitable for autoregressive generation.

A decoder-only language model combines token embeddings, position embeddings, stacked decoder layers, and a vocabulary projection. It is trained by next-token prediction using shifted input and target sequences.

During inference, the model generates tokens one at a time. KV caching avoids recomputing past keys and values, making long generation practical.

