# Teacher Forcing

Teacher forcing is a training method for autoregressive sequence models. It is used when a model generates an output sequence one token at a time, but during training we already know the correct output sequence.

In a sequence-to-sequence task, the model learns

$$
p(y \mid x) =
\prod_{t=1}^{T}
p(y_t \mid y_{<t}, x).
$$

At step $t$, the decoder predicts $y_t$ using the input sequence $x$ and the previous target tokens $y_{<t}$. Teacher forcing means that, during training, the previous tokens are the correct tokens from the dataset, not the model’s own previous predictions.
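Here is a worked instance. Take a target such as `I like cats` followed by an end token `<eos>`, with a start token `<bos>` as the fixed first decoder input. The product then has one factor per predicted token:

$$
p(\texttt{I like cats <eos>} \mid x)
= p(\texttt{I} \mid \texttt{<bos>}, x)\,
p(\texttt{like} \mid \texttt{<bos> I}, x)\,
p(\texttt{cats} \mid \texttt{<bos> I like}, x)\,
p(\texttt{<eos>} \mid \texttt{<bos> I like cats}, x).
$$

Under teacher forcing, every conditioning prefix on the right-hand side is copied from the data. All four factors can therefore be scored without generating anything.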

### Autoregressive Decoding

Suppose the target sentence is

```text
<bos> I like cats <eos>
```

The decoder input is shifted right:

```text
<bos> I like cats
```

The training target is:

```text
I like cats <eos>
```

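The model receives the first sequence and is trained to predict the second. Position $t$ reads the input token at $t$ together with every earlier input token, and is trained to predict the target token at $t$.

| Position | Decoder input token | Visible prefix | Target prediction |
|---|---|---|---|
| 1 | `<bos>` | `<bos>` | `I` |
| 2 | `I` | `<bos> I` | `like` |
| 3 | `like` | `<bos> I like` | `cats` |
| 4 | `cats` | `<bos> I like cats` | `<eos>` |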

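Every prefix is already known from the dataset, so the loss at every target position can be computed without waiting for the model to generate anything. The model does not need to produce token 1 before it is trained on token 2. In a transformer decoder, all positions are processed in a single parallel pass. An RNN decoder still steps through time, but it never has to feed back its own predictions.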
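List slicing reproduces the shift and the visible prefixes from the table. This sketch uses string tokens for readability. A real pipeline applies the same slicing to tensors of token ids.

```python
tokens = ["<bos>", "I", "like", "cats", "<eos>"]

decoder_input = tokens[:-1]  # drop the final <eos>
target = tokens[1:]          # drop the leading <bos>

# Position t sees decoder_input[: t + 1] and is trained to predict target[t].
pairs = [(decoder_input[: t + 1], target[t]) for t in range(len(target))]
for prefix, y in pairs:
    print(" ".join(prefix), "->", y)
# <bos> -> I
# <bos> I -> like
# <bos> I like -> cats
# <bos> I like cats -> <eos>
```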

### Why Teacher Forcing Is Used

Without teacher forcing, training would require the model to sample or choose its own token at each step, then feed that token back into the decoder. Early in training, the model’s predictions are poor, so later decoder states receive poor inputs. This makes optimization unstable.

Teacher forcing gives the decoder clean input prefixes. This makes the learning problem easier:

$$
p(y_t \mid y_{<t}, x)
$$

is trained using the true $y_{<t}$, so each prediction step receives the correct history.

This has three practical benefits. It improves training stability, allows parallel loss computation over all target positions, and reduces the amount of error compounding during early optimization.

### Training with Shifted Targets

Let the full target token sequence be

$$
Y = (y_0, y_1, \ldots, y_T),
$$

where

$$
y_0 = \texttt{<bos>}
$$

and

$$
y_T = \texttt{<eos>}.
$$

The decoder input is

$$
Y_{\text{in}} = (y_0, y_1, \ldots, y_{T-1}),
$$

and the prediction target is

$$
Y_{\text{out}} = (y_1, y_2, \ldots, y_T).
$$

The model computes logits

$$
Z \in \mathbb{R}^{B \times T \times V},
$$

where $B$ is batch size, $T$ is the number of decoder positions (the full target sequence $Y$ has $T + 1$ tokens), and $V$ is vocabulary size.

The target tensor has shape

$$
Y_{\text{out}} \in \mathbb{N}^{B \times T}.
$$

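The loss is the cross-entropy over all non-padding target positions. Let $m_{b,t} = 1$ if $y_{b,t}$ is a real token and $m_{b,t} = 0$ if it is padding. Then

$$
\mathcal{L} =
-\frac{1}{N}
\sum_{b=1}^{B}
\sum_{t=1}^{T}
m_{b,t}
\log p_\theta(y_{b,t} \mid y_{b,<t}, x_b),
\qquad
N = \sum_{b=1}^{B}\sum_{t=1}^{T} m_{b,t}.
$$

Dividing by $N$ averages over the real target tokens. This matches the default `reduction="mean"` of PyTorch's `F.cross_entropy`. Dropping the factor gives the summed negative log-likelihood.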

In practice, padding tokens are ignored.

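```python
import torch
import torch.nn.functional as F

B, T, V = 32, 20, 50000
pad_id = 0

# Random stand-ins for decoder logits and shifted targets.
logits = torch.randn(B, T, V)           # [B, T, V]
targets = torch.randint(1, V, (B, T))   # [B, T], real token ids (never pad_id)
targets[:, 15:] = pad_id                # pretend the last positions are padding

# Flatten batch and time so that each row is one next-token prediction.
loss = F.cross_entropy(
    logits.reshape(B * T, V),
    targets.reshape(B * T),
    ignore_index=pad_id,
)
```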

The `ignore_index` argument prevents padding tokens from contributing to the loss.
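To see exactly what `ignore_index` computes, the sketch below recomputes the loss by hand. It gathers the log-probability of each target token, zeroes out padding positions, and divides by the number of real tokens. That is the average `F.cross_entropy` returns under its default `reduction="mean"`. The sizes and padding layout are arbitrary choices for the check.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, V = 4, 6, 10
pad_id = 0

logits = torch.randn(B, T, V)
targets = torch.randint(1, V, (B, T))  # real tokens only
targets[0, 4:] = pad_id                # make two sequences shorter than T
targets[2, 2:] = pad_id

# Built-in: mean cross-entropy over non-padding positions.
builtin = F.cross_entropy(
    logits.reshape(-1, V), targets.reshape(-1), ignore_index=pad_id
)

# By hand: masked negative log-likelihood, averaged over real tokens.
log_probs = F.log_softmax(logits, dim=-1)                                   # [B, T, V]
token_log_probs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)  # [B, T]
mask = (targets != pad_id).float()                                          # 1 = real token
manual = -(token_log_probs * mask).sum() / mask.sum()

print(torch.allclose(builtin, manual, atol=1e-6))  # True
```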

### Teacher Forcing in a Training Step

A typical sequence-to-sequence batch contains source tokens and target tokens.

```python
src_tokens.shape  # [B, S]
tgt_tokens.shape  # [B, T + 1]
```

Each target row contains `<bos>` at the start and `<eos>` at the end, which is why it has $T + 1$ positions.

```python
tgt_input = tgt_tokens[:, :-1]
tgt_target = tgt_tokens[:, 1:]
```

The model receives `tgt_input` and predicts `tgt_target`.

```python
logits = model(src_tokens, tgt_input)

loss = F.cross_entropy(
    logits.reshape(-1, logits.size(-1)),
    tgt_target.reshape(-1),
    ignore_index=pad_id,
)
```

This is teacher forcing in its most common form. The decoder does not receive its own sampled output during training. It receives the ground-truth prefix.
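The fragments above assume an existing `model`, `pad_id`, and batch. The sketch below wires the shift, a causal mask, and the padding-aware loss into one optimizer step so that everything runs. `TinySeq2Seq`, its sizes, and the random batch are hypothetical stand-ins built on `torch.nn.Transformer`. The model builds its causal mask internally, and learned position embeddings are used only to keep the example short.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinySeq2Seq(nn.Module):
    """Hypothetical encoder-decoder used only to illustrate a teacher-forced step."""

    def __init__(self, vocab_size, d_model=32, max_len=64):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos = nn.Embedding(max_len, d_model)
        self.transformer = nn.Transformer(
            d_model=d_model, nhead=4, num_encoder_layers=1, num_decoder_layers=1,
            dim_feedforward=64, dropout=0.0, batch_first=True,
        )
        self.head = nn.Linear(d_model, vocab_size)

    def add_pos(self, tokens):
        positions = torch.arange(tokens.size(1), device=tokens.device)
        return self.embed(tokens) + self.pos(positions)

    def forward(self, src_tokens, tgt_input):
        T = tgt_input.size(1)
        # Additive causal mask: -inf above the diagonal hides future positions.
        tgt_mask = torch.triu(
            torch.full((T, T), float("-inf"), device=tgt_input.device), diagonal=1
        )
        hidden = self.transformer(
            self.add_pos(src_tokens), self.add_pos(tgt_input), tgt_mask=tgt_mask
        )
        return self.head(hidden)  # [B, T, V]

torch.manual_seed(0)
B, S, T, V, pad_id = 8, 10, 12, 100, 0
src_tokens = torch.randint(1, V, (B, S))
tgt_tokens = torch.randint(1, V, (B, T + 1))  # stands in for <bos> ... <eos>

model = TinySeq2Seq(V)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Teacher forcing: the decoder sees the gold prefix, shifted right by one.
tgt_input = tgt_tokens[:, :-1]
tgt_target = tgt_tokens[:, 1:]

logits = model(src_tokens, tgt_input)
loss = F.cross_entropy(
    logits.reshape(-1, logits.size(-1)),
    tgt_target.reshape(-1),
    ignore_index=pad_id,
)

optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"loss = {loss.item():.3f}")
```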

### Inference Is Different

During inference, the target sequence is unknown. The model must generate it.

Generation begins with `<bos>`:

```text
<bos>
```

The model predicts the next token. That predicted token is appended to the decoder input. The process repeats until `<eos>` is generated or a maximum length is reached.

For greedy decoding:

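```python
import torch

@torch.no_grad()
def greedy_decode(model, src_tokens, bos_id, eos_id, pad_id, max_len):
    device = src_tokens.device
    B = src_tokens.size(0)

    # Every sequence starts from <bos>.
    generated = torch.full((B, 1), bos_id, dtype=torch.long, device=device)
    finished = torch.zeros(B, dtype=torch.bool, device=device)

    for _ in range(max_len):
        logits = model(src_tokens, generated)          # [B, t, V]
        next_token = logits[:, -1, :].argmax(dim=-1)   # [B], most likely next token

        # Sequences that already emitted <eos> are padded from then on.
        next_token = next_token.masked_fill(finished, pad_id)
        generated = torch.cat([generated, next_token.unsqueeze(1)], dim=1)

        # Stop only once every sequence in the batch has produced <eos>.
        finished |= next_token == eos_id
        if finished.all():
            break

    return generated
```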

Here the model conditions on its own previous predictions. This differs from training, where the previous tokens are correct dataset tokens.

### Exposure Bias

Teacher forcing creates a mismatch between training and inference.

During training, the model sees prefixes sampled from the data distribution:

$$
(y_1, \ldots, y_{t-1}).
$$

During inference, the model sees prefixes sampled from its own distribution:

$$
(\hat{y}_1, \ldots, \hat{y}_{t-1}).
$$

If the model makes an early mistake, later predictions are conditioned on a prefix it may not have seen during training. This can cause errors to compound.

This problem is called exposure bias.

For example, suppose the correct output is:

```text
the cat sat on the mat
```

If the model generates:

```text
the dog
```

then all future predictions are conditioned on `the dog`, not on `the cat`. The decoder must recover from its own mistake, but teacher forcing gave it little practice with such corrupted prefixes.

### Scheduled Sampling

Scheduled sampling is one attempt to reduce exposure bias. During training, the decoder sometimes receives the correct previous token and sometimes receives the model’s previous prediction.

At step $t$, the previous decoder input is chosen as

$$
\tilde{y}_{t-1} =
\begin{cases}
y_{t-1}, & \text{with probability } p, \\
\hat{y}_{t-1}, & \text{with probability } 1-p.
\end{cases}
$$

The probability $p$ may start near 1 and decrease during training.

The intuition is simple: early in training, use teacher forcing because the model is weak. Later in training, expose the model to its own predictions so it learns to recover from mistakes.

A simplified forward pass with scheduled sampling may look like this:

```python
import torch

def scheduled_sampling_decode(
    model,
    src_tokens,
    tgt_tokens,
    bos_id,
    teacher_forcing_prob,
):
    """Return logits of shape [B, T, V], to be scored against tgt_tokens[:, 1:]."""
    B, T_plus_1 = tgt_tokens.shape
    T = T_plus_1 - 1
    device = tgt_tokens.device

    # The decoder prefix starts with <bos> for every sequence.
    generated_inputs = torch.full((B, 1), bos_id, dtype=torch.long, device=device)
    logits_steps = []

    for t in range(T):
        # Rerun the decoder on the current mixed prefix; keep the last position.
        logits = model(src_tokens, generated_inputs)
        next_logits = logits[:, -1, :]
        logits_steps.append(next_logits)

        predicted = next_logits.argmax(dim=-1, keepdim=True)
        gold = tgt_tokens[:, t + 1 : t + 2]

        # Per sequence: feed the gold token with probability teacher_forcing_prob,
        # otherwise feed the model's own prediction.
        use_teacher = torch.rand(B, 1, device=device) < teacher_forcing_prob
        next_input = torch.where(use_teacher, gold, predicted)

        generated_inputs = torch.cat([generated_inputs, next_input], dim=1)

    return torch.stack(logits_steps, dim=1)
```

This illustrates the idea, but it is slower than full teacher forcing because decoding proceeds step by step.
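The section leaves the decay of $p$ open. The scheduled sampling paper by Bengio et al. (2015) proposed three shapes as functions of the training step: linear, exponential, and inverse-sigmoid decay. The sketch below implements those three. The constants `k`, `c`, and `p_min` are hypothetical defaults that would need tuning. The value for the current step could be passed as `teacher_forcing_prob` to `scheduled_sampling_decode`.

```python
import math

def linear_decay(step, k=1.0, c=1e-4, p_min=0.1):
    # p = max(p_min, k - c * step): falls linearly, then stays at p_min.
    return max(p_min, k - c * step)

def exponential_decay(step, k=0.9999):
    # p = k ** step with 0 < k < 1.
    return k ** step

def inverse_sigmoid_decay(step, k=1000.0):
    # p = k / (k + exp(step / k)) with k >= 1. The exponent is clamped so that
    # math.exp does not overflow for very large step counts.
    return k / (k + math.exp(min(step / k, 700.0)))

for step in (0, 5_000, 20_000):
    print(
        step,
        round(linear_decay(step), 3),
        round(exponential_decay(step), 3),
        round(inverse_sigmoid_decay(step), 3),
    )
```

All three start near 1 and shrink toward 0 (or toward `p_min` for the linear schedule). They differ in how long the model stays close to full teacher forcing before predictions take over.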

### Tradeoffs of Scheduled Sampling

Scheduled sampling tries to make training closer to inference, but it introduces complications.

It reduces parallelism because the model must generate intermediate predictions. It also changes the training distribution in a way that can make optimization less clean. The model may learn from prefixes that are partly correct and partly generated, which can be useful but also noisy.

For large transformer models, standard teacher forcing remains dominant because it is simple, parallel, and efficient. Exposure bias is often handled indirectly through better models, larger data, improved decoding, reinforcement learning, preference optimization, or task-specific fine-tuning.

### Teacher Forcing in Transformers

In transformer decoders, teacher forcing is implemented by passing the whole shifted target sequence at once.

The decoder receives

```text
<bos> y1 y2 ... yT-1
```

and predicts

```text
y1 y2 ... yT
```

A causal mask prevents position $t$ from seeing future target tokens.

Without the causal mask, the decoder could attend to the token it is supposed to predict. That would leak the answer.

```python
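def causal_mask(T, device):
    # Float mask added to attention scores: 0 on and below the diagonal,
    # -inf above it, so position t cannot attend to any position after t.
    mask = torch.triu(torch.ones(T, T, device=device), diagonal=1)
    return mask.masked_fill(mask == 1, float("-inf"))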

Then:

```python
tgt_input = tgt_tokens[:, :-1]
tgt_target = tgt_tokens[:, 1:]

mask = causal_mask(tgt_input.size(1), tgt_input.device)

logits = model(src_tokens, tgt_input, tgt_mask=mask)
```

The causal mask preserves autoregressive structure while still allowing parallel training.
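That claim can be checked numerically. A single masked pass over the whole shifted sequence should give the same logits as rerunning the decoder on one growing prefix at a time. An unmasked pass should not. The sketch below assumes a single `torch.nn.TransformerDecoderLayer` with random weights, a random tensor standing in for encoder output, and arbitrary sizes. It redefines `causal_mask` so that it runs on its own.

```python
import torch
import torch.nn as nn

def causal_mask(T, device=None):
    # 0 where attention is allowed, -inf above the diagonal (future positions).
    mask = torch.triu(torch.ones(T, T, device=device), diagonal=1)
    return mask.masked_fill(mask == 1, float("-inf"))

torch.manual_seed(0)
B, S, T, D, V = 2, 7, 5, 16, 11

# Hypothetical tiny decoder: embedding -> one decoder layer -> vocabulary projection.
embed = nn.Embedding(V, D)
layer = nn.TransformerDecoderLayer(
    d_model=D, nhead=2, dim_feedforward=32, dropout=0.0, batch_first=True
)
head = nn.Linear(D, V)

memory = torch.randn(B, S, D)            # stand-in for encoder output
tgt_input = torch.randint(0, V, (B, T))  # shifted gold tokens

def decode(tokens, mask):
    return head(layer(embed(tokens), memory, tgt_mask=mask))  # [B, len, V]

with torch.no_grad():
    # Teacher forcing: one masked pass scores all T positions.
    parallel = decode(tgt_input, causal_mask(T))
    # Reference: rerun on each growing prefix and keep only its last position.
    stepwise = torch.stack(
        [decode(tgt_input[:, : t + 1], causal_mask(t + 1))[:, -1] for t in range(T)],
        dim=1,
    )
    # Bug: no mask, so earlier positions can attend to later tokens.
    unmasked = decode(tgt_input, None)

print(torch.allclose(parallel, stepwise, atol=1e-5))  # True
print(torch.allclose(unmasked[:, :-1], stepwise[:, :-1], atol=1e-5))
```

The masked pass matches the stepwise reference up to floating-point rounding. In the unmasked pass, only the last position agrees, because its legitimate context already covers the whole prefix. Every earlier position has seen future tokens.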

### Teacher Forcing Ratio

In recurrent sequence-to-sequence code, the term teacher forcing ratio often appears. It means the probability of using the ground-truth previous token instead of the model’s predicted token.

A ratio of 1.0 means full teacher forcing. A ratio of 0.0 means the model always feeds back its own predictions. Intermediate values mix the two. Decaying the ratio over the course of training is scheduled sampling.

```python
teacher_forcing_ratio = 0.5
```

This concept is common in RNN implementations. In transformer implementations, full teacher forcing with a causal mask is more common.
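In RNN code, the ratio usually appears inside a stepwise decoding loop. The sketch below assumes a hypothetical `GRUDecoder` with a `step` method. It flips one coin per batch to decide whether the whole batch is teacher-forced. Choosing per token, as in the scheduled sampling sketch above, is the finer-grained alternative.

```python
import random

import torch
import torch.nn as nn
import torch.nn.functional as F

class GRUDecoder(nn.Module):
    """Hypothetical one-layer GRU decoder, used only for illustration."""

    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, vocab_size)

    def step(self, token, hidden):
        # token: [B, 1], hidden: [1, B, H] -> logits: [B, V]
        output, hidden = self.gru(self.embed(token), hidden)
        return self.out(output[:, -1]), hidden

def rnn_decoder_loss(decoder, hidden, tgt_tokens, pad_id, teacher_forcing_ratio):
    """hidden is the initial decoder state, e.g. the encoder's final state."""
    T = tgt_tokens.size(1) - 1
    # One coin flip decides the input source for the whole batch.
    use_teacher = random.random() < teacher_forcing_ratio

    token = tgt_tokens[:, :1]  # <bos>
    step_logits = []
    for t in range(T):
        logits, hidden = decoder.step(token, hidden)
        step_logits.append(logits)
        if use_teacher:
            token = tgt_tokens[:, t + 1 : t + 2]         # gold previous token
        else:
            token = logits.argmax(dim=-1, keepdim=True)  # model's own prediction

    logits = torch.stack(step_logits, dim=1)  # [B, T, V]
    return F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        tgt_tokens[:, 1:].reshape(-1),
        ignore_index=pad_id,
    )

torch.manual_seed(0)
B, T, V, H = 4, 6, 20, 16
decoder = GRUDecoder(V, H)
hidden = torch.zeros(1, B, H)
tgt_tokens = torch.randint(1, V, (B, T + 1))
loss = rnn_decoder_loss(decoder, hidden, tgt_tokens, pad_id=0, teacher_forcing_ratio=0.5)
```

With `teacher_forcing_ratio=1.0`, this is full teacher forcing. With `0.0`, the decoder always consumes its own argmax predictions.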

### Padding and Loss Masking

Teacher forcing must handle padding correctly. Suppose a batch contains target sequences with different lengths:

```text
<bos> I like cats <eos>
<bos> hello <eos> <pad> <pad>
```

The model should not be penalized for predictions at padding positions.

After shifting:

```text
decoder input:
<bos> I like cats
<bos> hello <eos> <pad>

target:
I like cats <eos>
hello <eos> <pad> <pad>
```

The loss should ignore `<pad>` targets.

```python
loss = F.cross_entropy(
    logits.reshape(-1, logits.size(-1)),
    tgt_target.reshape(-1),
    ignore_index=pad_id,
)
```

Padding masks are also needed inside attention layers so the model does not attend to padding tokens.
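The sketch below builds the two-sentence batch from the example with a hypothetical seven-token vocabulary and applies the shift. It then derives both kinds of mask. The loss skips `<pad>` targets through `ignore_index`. Boolean key-padding masks, where `True` marks a padded position, are passed to `torch.nn.Transformer` so that attention skips padding. The source sentences, the untrained model, and the missing positional encodings are simplifications. Only the masking is the point.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical toy vocabulary; <pad> maps to pad_id.
vocab = {"<pad>": 0, "<bos>": 1, "<eos>": 2, "I": 3, "like": 4, "cats": 5, "hello": 6}
pad_id = vocab["<pad>"]

def encode(words):
    return [vocab[w] for w in words.split()]

tgt_tokens = torch.tensor([
    encode("<bos> I like cats <eos>"),
    encode("<bos> hello <eos> <pad> <pad>"),
])
src_tokens = torch.tensor([encode("I like cats"), encode("hello <pad> <pad>")])  # made up

tgt_input = tgt_tokens[:, :-1]
tgt_target = tgt_tokens[:, 1:]

# Boolean masks: True means "do not attend here".
src_key_padding_mask = src_tokens == pad_id
tgt_key_padding_mask = tgt_input == pad_id
T = tgt_input.size(1)
tgt_mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)  # causal

torch.manual_seed(0)
d_model, V = 16, len(vocab)
embed = nn.Embedding(V, d_model)
transformer = nn.Transformer(
    d_model=d_model, nhead=2, num_encoder_layers=1, num_decoder_layers=1,
    dim_feedforward=32, dropout=0.0, batch_first=True,
)
head = nn.Linear(d_model, V)

hidden = transformer(
    embed(src_tokens),
    embed(tgt_input),
    tgt_mask=tgt_mask,
    src_key_padding_mask=src_key_padding_mask,
    tgt_key_padding_mask=tgt_key_padding_mask,
    memory_key_padding_mask=src_key_padding_mask,  # cross-attention skips padded source
)
logits = head(hidden)  # [2, 4, V]

loss = F.cross_entropy(logits.reshape(-1, V), tgt_target.reshape(-1), ignore_index=pad_id)
print((tgt_target != pad_id).sum().item())  # 6 target positions count toward the loss
```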

### Practical Rules

Use full teacher forcing for most transformer encoder-decoder models. It gives efficient parallel training and stable optimization.

Use a causal mask in the decoder. It is required precisely because the full target sequence is fed in at once. Without it, each position could read the token it must predict.

Shift target tokens carefully. The decoder input excludes the last token. The target excludes the first token.

Ignore padding tokens in the loss. Otherwise the model wastes capacity predicting padding.

When using RNN decoders, decide whether the implementation trains one step at a time or all steps at once. Stepwise implementations make scheduled sampling easier. Parallel implementations are faster.
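With full teacher forcing, every input token is known in advance, so the whole shifted target can go through `torch.nn.GRU` in one call. The recurrence still runs over time, but inside a single library call instead of a Python loop. The sketch below uses arbitrary sizes and random weights. It checks that the one-call result matches the stepwise loop that scheduled sampling would require.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, V, H = 3, 8, 50, 32
embed = nn.Embedding(V, H)
gru = nn.GRU(H, H, batch_first=True)
h0 = torch.zeros(1, B, H)                # initial decoder state
tgt_input = torch.randint(0, V, (B, T))  # shifted gold tokens

with torch.no_grad():
    # All steps at once: one call over the full [B, T] input.
    parallel, _ = gru(embed(tgt_input), h0)  # [B, T, H]

    # One step at a time: needed when the next input may be a model prediction.
    hidden, steps = h0, []
    for t in range(T):
        out, hidden = gru(embed(tgt_input[:, t : t + 1]), hidden)
        steps.append(out[:, 0])
    stepwise = torch.stack(steps, dim=1)  # [B, T, H]

print(torch.allclose(parallel, stepwise, atol=1e-5))  # True
```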

### Common Bugs

A common bug is using the same target tensor as both decoder input and prediction target. This lets the model see the token it is supposed to predict.

Incorrect:

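```python
logits = model(src_tokens, tgt_tokens)
loss = F.cross_entropy(
    logits.reshape(-1, logits.size(-1)),
    tgt_tokens.reshape(-1),  # position t is scored on the token it was just given
    ignore_index=pad_id,
)
```

Correct:

```python
tgt_input = tgt_tokens[:, :-1]
tgt_target = tgt_tokens[:, 1:]

logits = model(src_tokens, tgt_input)
loss = F.cross_entropy(
    logits.reshape(-1, logits.size(-1)),
    tgt_target.reshape(-1),
    ignore_index=pad_id,
)
```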

Another common bug is forgetting the causal mask in a transformer decoder. Training loss can then look excellent, because each position can copy its target from the next input position. The model still fails at inference, where those future tokens do not exist yet.

A third common bug is computing loss on padding tokens. This can make the model overproduce padding and distort evaluation.

### Summary

Teacher forcing trains an autoregressive decoder using the correct previous tokens from the dataset. It converts sequence generation into parallel next-token prediction.

The method is efficient and stable, especially for transformer models. Its main weakness is the training-inference mismatch: during training the model sees correct prefixes, while during inference it sees its own generated prefixes. This mismatch is called exposure bias.

In PyTorch, teacher forcing usually means shifting the target sequence into `tgt_input` and `tgt_target`, applying a causal mask in the decoder, and computing cross-entropy loss over non-padding target positions.

