Teacher forcing is a training method for autoregressive sequence models. It is used when a model generates an output sequence one token at a time, but during training we already know the correct output sequence.
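In a sequence-to-sequence task, the model learns the conditional distribution of a target sequence $y$ given an input sequence $x$, factorized one token at a time:

$$p_\theta(y \mid x) = \prod_{t=1}^{T} p_\theta(y_t \mid y_{<t}, x).$$

At step $t$, the decoder predicts $y_t$ using the input sequence $x$ and the previous target tokens $y_{<t}$. Teacher forcing means that, during training, the previous tokens are the correct tokens from the dataset, not the model's own previous predictions.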
Autoregressive Decoding
Suppose the target sentence is

```text
<bos> I like cats <eos>
```

The decoder input is the target shifted right:

```text
<bos> I like cats
```

The training target is:

```text
I like cats <eos>
```

The model receives the decoder input and is trained to predict the training target, position by position.
| Decoder input | Target prediction |
|---|---|
| `<bos>` | `I` |
| `I` | `like` |
| `like` | `cats` |
| `cats` | `<eos>` |
This lets the loss for every target position be computed in parallel, in a single forward pass. The model does not need to generate token 1 before training on token 2, because the correct prefix is already available.
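To make the shift concrete, here is a minimal plain-Python sketch that turns the example sentence into the two columns of the table. The helper name `shift_for_teacher_forcing` is hypothetical and used only for illustration.

```python
def shift_for_teacher_forcing(tokens):
    """Split one full target sequence into (decoder_input, target)."""
    decoder_input = tokens[:-1]  # everything except the final <eos>
    target = tokens[1:]          # everything except the leading <bos>
    return decoder_input, target


tokens = ["<bos>", "I", "like", "cats", "<eos>"]
decoder_input, target = shift_for_teacher_forcing(tokens)

# Each pair is one table row: newest input token -> token to predict.
# (The decoder also sees all earlier inputs in the prefix.)
for inp, tgt in zip(decoder_input, target):
    print(f"{inp:>6} -> {tgt}")
```

Both lists have the same length, so position $t$ of the input lines up with position $t$ of the target.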
Why Teacher Forcing Is Used
Without teacher forcing, training would require the model to sample or choose its own token at each step, then feed that token back into the decoder. Early in training, the model’s predictions are poor, so later decoder states receive poor inputs. This makes optimization unstable.
Teacher forcing gives the decoder clean input prefixes. This makes the learning problem easier: each conditional

$$p_\theta(y_t \mid y_{<t}, x)$$

is trained using the true prefix $y_{<t}$, so each prediction step receives the correct history.
This has three practical benefits. It improves training stability, allows parallel loss computation over all target positions, and reduces the amount of error compounding during early optimization.
Training with Shifted Targets
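Let the full target token sequence be $y = (y_0, y_1, \ldots, y_T)$, where $y_0 = \langle\text{bos}\rangle$ and $y_T = \langle\text{eos}\rangle$.

The decoder input is

$$y_{\text{in}} = (y_0, y_1, \ldots, y_{T-1})$$

and the prediction target is

$$y_{\text{out}} = (y_1, y_2, \ldots, y_T).$$

The model computes logits

$$Z \in \mathbb{R}^{B \times T \times V},$$

where $B$ is batch size, $T$ is sequence length, and $V$ is vocabulary size.

The target tensor has shape $B \times T$.

The loss is cross-entropy over all non-padding target positions:

$$\mathcal{L} = -\frac{1}{N} \sum_{b=1}^{B} \sum_{t=1}^{T} m_{b,t} \log p_\theta\left(y_{b,t} \mid y_{b,<t}, x_b\right), \qquad N = \sum_{b=1}^{B} \sum_{t=1}^{T} m_{b,t},$$

where $m_{b,t} = 1$ if $y_{b,t}$ is a real token and $m_{b,t} = 0$ if it is padding.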
In practice, padding tokens are ignored.
```python
import torch
import torch.nn.functional as F

B, T, V = 32, 20, 50000
pad_id = 0

logits = torch.randn(B, T, V)          # [B, T, V] model outputs
targets = torch.randint(0, V, (B, T))  # [B, T] target token ids

# Flatten batch and time so each row is one prediction over V classes.
loss = F.cross_entropy(
    logits.reshape(B * T, V),
    targets.reshape(B * T),
    ignore_index=pad_id,
)
```

The `ignore_index` argument prevents padding tokens from contributing to the loss.
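The masked mean that `ignore_index` computes can be checked against the loss formula written out by hand. A minimal sketch, assuming `pad_id = 0` and made-up padding positions in a small random batch:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, V = 4, 6, 11
pad_id = 0

logits = torch.randn(B, T, V)
targets = torch.randint(1, V, (B, T))  # real tokens use ids 1..V-1
targets[0, 4:] = pad_id                # pretend row 0 ends after 4 tokens
targets[2, 3:] = pad_id                # pretend row 2 ends after 3 tokens

# Built-in: padding targets are skipped and the mean runs over real tokens only.
loss_builtin = F.cross_entropy(
    logits.reshape(-1, V), targets.reshape(-1), ignore_index=pad_id
)

# By hand: per-token losses, zeroed at padding, divided by the non-padding count N.
per_token = F.cross_entropy(logits.reshape(-1, V), targets.reshape(-1), reduction="none")
mask = (targets.reshape(-1) != pad_id).float()
loss_manual = (per_token * mask).sum() / mask.sum()

print(loss_builtin.item(), loss_manual.item())
```

The two values agree, which is why padding can be handled entirely through `ignore_index` as long as every padding position carries the `pad_id` target.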
Teacher Forcing in a Training Step
A typical sequence-to-sequence batch contains source tokens and target tokens.
```python
src_tokens.shape  # [B, S]
tgt_tokens.shape  # [B, T + 1]
```

The target includes both the <bos> token at the start and the <eos> token at the end.
```python
tgt_input = tgt_tokens[:, :-1]  # drop the last token: [B, T]
tgt_target = tgt_tokens[:, 1:]  # drop the first token: [B, T]
```

The model receives `tgt_input` and predicts `tgt_target`.
```python
logits = model(src_tokens, tgt_input)  # [B, T, V]
loss = F.cross_entropy(
    logits.reshape(-1, logits.size(-1)),
    tgt_target.reshape(-1),
    ignore_index=pad_id,
)
```

This is teacher forcing in its most common form. The decoder does not receive its own sampled output during training. It receives the ground-truth prefix.
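For a complete picture, here is a minimal, self-contained training loop under stated assumptions: `ToySeq2Seq` is a hypothetical GRU encoder-decoder (not from any library), the data are random token ids, and `pad_id = 0`, `bos_id = 1` are arbitrary choices. An RNN decoder is causal by construction, so it can consume the whole teacher-forced prefix in one call without a mask.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToySeq2Seq(nn.Module):
    """Hypothetical GRU encoder-decoder, used only to exercise teacher forcing."""

    def __init__(self, vocab_size, hidden=32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.encoder = nn.GRU(hidden, hidden, batch_first=True)
        self.decoder = nn.GRU(hidden, hidden, batch_first=True)
        self.out = nn.Linear(hidden, vocab_size)

    def forward(self, src_tokens, tgt_input):
        _, h = self.encoder(self.embed(src_tokens))          # h: [1, B, H]
        # The whole ground-truth prefix goes in at once: this is teacher forcing.
        dec_out, _ = self.decoder(self.embed(tgt_input), h)  # [B, T, H]
        return self.out(dec_out)                             # [B, T, V]


torch.manual_seed(0)
B, S, T, V, pad_id, bos_id = 8, 7, 6, 20, 0, 1
src_tokens = torch.randint(2, V, (B, S))
tgt_tokens = torch.cat(
    [torch.full((B, 1), bos_id, dtype=torch.long), torch.randint(2, V, (B, T))], dim=1
)  # [B, T + 1]

model = ToySeq2Seq(V)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)


def train_step():
    tgt_input, tgt_target = tgt_tokens[:, :-1], tgt_tokens[:, 1:]
    logits = model(src_tokens, tgt_input)
    loss = F.cross_entropy(
        logits.reshape(-1, V), tgt_target.reshape(-1), ignore_index=pad_id
    )
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


losses = [train_step() for _ in range(100)]
print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```

On this fixed toy batch the loss falls as the model memorizes the targets; the point is only that one forward call covers all $T$ positions.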
Inference Is Different
During inference, the target sequence is unknown. The model must generate it.
Generation begins with <bos>:

```text
<bos>
```

The model predicts the next token. That predicted token is appended to the decoder input. The process repeats until <eos> is generated or a maximum length is reached.
For greedy decoding:
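```python
import torch


@torch.no_grad()
def greedy_decode(model, src_tokens, bos_id, eos_id, max_len, pad_id=0):
    device = src_tokens.device
    B = src_tokens.size(0)
    # Every sequence starts from <bos>.
    generated = torch.full((B, 1), bos_id, dtype=torch.long, device=device)
    finished = torch.zeros(B, dtype=torch.bool, device=device)
    for _ in range(max_len):
        logits = model(src_tokens, generated)         # [B, current_len, V]
        next_token = logits[:, -1, :].argmax(dim=-1)  # most likely next token, [B]
        # Sequences that already emitted <eos> are only extended with pad_id.
        next_token = next_token.masked_fill(finished, pad_id)
        generated = torch.cat([generated, next_token.unsqueeze(1)], dim=1)
        finished |= next_token == eos_id
        # Stop once every sequence has produced <eos> at some step.
        if finished.all():
            break
    return generated
```

Here the model conditions on its own previous predictions. This differs from training, where the previous tokens are correct dataset tokens.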
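Because finished sequences keep their slot in the batch, greedy output usually needs trimming before detokenization. A small sketch; the helper name `trim_generated` and the example ids (bos = 1, eos = 2, pad = 0) are assumptions for illustration:

```python
import torch


def trim_generated(generated, bos_id, eos_id):
    """Turn a [B, L] tensor of generated ids into one list per sequence,
    dropping the leading <bos> and everything from the first <eos> on."""
    results = []
    for row in generated.tolist():
        if row and row[0] == bos_id:
            row = row[1:]
        if eos_id in row:
            row = row[: row.index(eos_id)]
        results.append(row)
    return results


# Example ids: bos=1, eos=2, pad=0. The last row hit max_len without <eos>.
generated = torch.tensor([[1, 5, 6, 2, 0], [1, 7, 2, 0, 0], [1, 8, 9, 9, 9]])
print(trim_generated(generated, bos_id=1, eos_id=2))
```

Rows that never produced <eos> are kept whole, since they were cut off by the length limit rather than ended by the model.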
Exposure Bias
Teacher forcing creates a mismatch between training and inference.
During training, the model sees prefixes sampled from the data distribution:

$$y_{<t} \sim p_{\text{data}}(\cdot \mid x).$$

During inference, the model sees prefixes sampled from its own distribution:

$$\hat{y}_{<t} \sim p_\theta(\cdot \mid x).$$
If the model makes an early mistake, later predictions are conditioned on a prefix it may not have seen during training. This can cause errors to compound.
This problem is called exposure bias.
For example, suppose the correct output is:

```text
the cat sat on the mat
```

If the model generates:

```text
the dog
```

then all future predictions are conditioned on the prefix `the dog`, not on `the cat`. The decoder must recover from its own mistake, but teacher forcing gave it little practice with such corrupted prefixes.
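A rough back-of-the-envelope illustration of compounding, under a deliberately simplified assumption that is not part of the argument above: each step goes wrong independently with probability $q$, and the model never recovers. The chance of an error-free output of length $T$ is then $(1 - q)^T$, which shrinks quickly as $T$ grows.

```python
def prob_error_free(per_step_error, length):
    """P(no mistake in `length` steps), assuming independent errors and no recovery."""
    return (1.0 - per_step_error) ** length


for length in (5, 20, 50):
    print(length, round(prob_error_free(0.05, length), 3))
```

With a 5% per-step error rate, a 20-token output is error-free only about 36% of the time under this model. Real exposure bias is subtler, because errors are not independent and a corrupted prefix shifts the distribution of later inputs.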
Scheduled Sampling
Scheduled sampling is one attempt to reduce exposure bias. During training, the decoder sometimes receives the correct previous token and sometimes receives the model’s previous prediction.
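At step $t$, the previous decoder input is chosen as

$$\tilde{y}_{t-1} = \begin{cases} y_{t-1} & \text{with probability } \epsilon, \\ \hat{y}_{t-1} & \text{with probability } 1 - \epsilon, \end{cases}$$

where $y_{t-1}$ is the ground-truth token and $\hat{y}_{t-1}$ is the model's own prediction. The probability $\epsilon$ may start near 1 and decrease during training.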
The intuition is simple: early in training, use teacher forcing because the model is weak. Later in training, expose the model to its own predictions so it learns to recover from mistakes.
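How fast the teacher-forcing probability decays is a free design choice. Two shapes often used for this are a linear decay and an inverse-sigmoid decay of the form $k / (k + \exp(i / k))$ over training step $i$. The sketch below assumes those two shapes; the function names and default hyperparameters are illustrative only.

```python
import math


def linear_schedule(step, total_steps, start=1.0, end=0.0):
    """Teacher-forcing probability that falls linearly from `start` to `end`."""
    frac = min(step / total_steps, 1.0)
    return start + (end - start) * frac


def inverse_sigmoid_schedule(step, k=1000.0):
    """k / (k + exp(step / k)): close to 1 early, then decays smoothly toward 0."""
    return k / (k + math.exp(step / k))


for step in (0, 5_000, 10_000, 20_000):
    print(
        step,
        round(linear_schedule(step, 20_000), 3),
        round(inverse_sigmoid_schedule(step), 3),
    )
```

Either value would be recomputed at each training step and passed in as the teacher-forcing probability. The inverse sigmoid keeps training close to full teacher forcing for longer before switching over.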
A simplified training loop may look like this:
```python
import torch


def scheduled_sampling_decode(
    model,
    src_tokens,
    tgt_tokens,
    bos_id,
    teacher_forcing_prob,
):
    B, T_plus_1 = tgt_tokens.shape
    T = T_plus_1 - 1
    device = tgt_tokens.device

    # Decoder input starts as <bos> and grows by one token per step.
    generated_inputs = torch.full((B, 1), bos_id, dtype=torch.long, device=device)
    logits_steps = []

    for t in range(T):
        logits = model(src_tokens, generated_inputs)
        next_logits = logits[:, -1, :]  # prediction for target position t + 1
        logits_steps.append(next_logits)

        predicted = next_logits.argmax(dim=-1, keepdim=True)
        gold = tgt_tokens[:, t + 1 : t + 2]

        # Per example: feed the gold token with probability teacher_forcing_prob,
        # otherwise feed the model's own prediction.
        use_teacher = torch.rand(B, 1, device=device) < teacher_forcing_prob
        next_input = torch.where(use_teacher, gold, predicted)
        generated_inputs = torch.cat([generated_inputs, next_input], dim=1)

    # [B, T, V], aligned with tgt_tokens[:, 1:]
    return torch.stack(logits_steps, dim=1)
```

This illustrates the idea, but it is slower than full teacher forcing because the $T$ decoder calls run one after another instead of in a single parallel pass.
Tradeoffs of Scheduled Sampling
Scheduled sampling tries to make training closer to inference, but it introduces complications.
It reduces parallelism because the model must generate intermediate predictions. It also changes the training distribution in a way that can make optimization less clean. The model may learn from prefixes that are partly correct and partly generated, which can be useful but also noisy.
For large transformer models, standard teacher forcing remains dominant because it is simple, parallel, and efficient. Exposure bias is often handled indirectly through better models, more training data, improved decoding, reinforcement learning, preference optimization, or task-specific fine-tuning.
Teacher Forcing in Transformers
In transformer decoders, teacher forcing is implemented by passing the whole shifted target sequence at once.
The decoder receives

```text
<bos> y1 y2 ... yT-1
```

and predicts

```text
y1 y2 ... yT
```

A causal mask prevents position $t$ from attending to target positions after $t$.
Without the causal mask, the decoder could attend to the token it is supposed to predict. That would leak the answer.
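```python
import torch


def causal_mask(T, device):
    # 1s strictly above the diagonal mark future positions.
    mask = torch.triu(torch.ones(T, T, device=device), diagonal=1)
    # Additive mask: 0 where attention is allowed, -inf where it is blocked.
    return mask.masked_fill(mask == 1, float("-inf"))
```

Then: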
```python
tgt_input = tgt_tokens[:, :-1]
tgt_target = tgt_tokens[:, 1:]
mask = causal_mask(tgt_input.size(1), tgt_input.device)
logits = model(src_tokens, tgt_input, tgt_mask=mask)
```

The causal mask preserves autoregressive structure while still allowing parallel training.
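One way to check that the mask does its job is a perturbation test: change only the inputs at later positions and confirm that outputs at earlier positions stay the same. A minimal sketch, using a single `torch.nn.MultiheadAttention` layer as a stand-in for decoder self-attention; random vectors stand in for embedded target tokens, the sizes are arbitrary, and `causal_mask` is repeated with `device` made optional so the block runs on its own.

```python
import torch
import torch.nn as nn


def causal_mask(T, device=None):
    # Additive mask: 0 on/below the diagonal, -inf above it (future positions).
    mask = torch.triu(torch.ones(T, T, device=device), diagonal=1)
    return mask.masked_fill(mask == 1, float("-inf"))


torch.manual_seed(0)
B, T, D = 2, 5, 16
attn = nn.MultiheadAttention(D, num_heads=4, batch_first=True)  # dropout defaults to 0

x = torch.randn(B, T, D)                     # stand-in for embedded decoder inputs
x_changed = x.clone()
x_changed[:, 3:] = torch.randn(B, T - 3, D)  # alter only positions 3 and 4

mask = causal_mask(T)
with torch.no_grad():
    out, weights = attn(x, x, x, attn_mask=mask)
    out_changed, _ = attn(x_changed, x_changed, x_changed, attn_mask=mask)
    leak, _ = attn(x, x, x)  # no mask
    leak_changed, _ = attn(x_changed, x_changed, x_changed)

# Masked: positions 0..2 are unaffected by the change at positions 3..4.
print(torch.allclose(out[:, :3], out_changed[:, :3], atol=1e-5))
# Unmasked: earlier positions change, i.e. future tokens leak backward.
print(torch.allclose(leak[:, :3], leak_changed[:, :3], atol=1e-5))
```

The first check is expected to print `True` and the second `False`. The returned attention weights are also exactly zero above the diagonal.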
Teacher Forcing Ratio
In recurrent sequence-to-sequence code, the term teacher forcing ratio often appears. It means the probability of using the ground-truth previous token instead of the model’s predicted token.
A ratio of 1.0 means full teacher forcing. A ratio of 0.0 means the model always feeds back its own predictions. Intermediate values mix the two; decaying the ratio over the course of training turns this into scheduled sampling.
```python
teacher_forcing_ratio = 0.5
```

This concept is common in RNN implementations. In transformer implementations, full teacher forcing with a causal mask is more common.
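In RNN code the ratio is typically checked inside a step-by-step decoding loop. A minimal sketch, assuming a hypothetical one-step GRU decoder (`ToyStepDecoder`) and a single coin flip per call for the whole batch; flipping per token or per example is an equally valid variant.

```python
import random

import torch
import torch.nn as nn


class ToyStepDecoder(nn.Module):
    """Hypothetical one-step decoder: (previous token, hidden) -> (logits, hidden)."""

    def __init__(self, vocab_size, hidden_size=32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.cell = nn.GRUCell(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, vocab_size)

    def forward(self, prev_token, hidden):
        hidden = self.cell(self.embed(prev_token), hidden)
        return self.out(hidden), hidden


def decode_with_ratio(decoder, tgt_tokens, hidden, teacher_forcing_ratio):
    """Decode over tgt_tokens [B, T + 1] (column 0 is <bos>).

    Returns logits [B, T, V] aligned with tgt_tokens[:, 1:].
    """
    # One coin flip for the whole batch: ratio 1.0 is full teacher forcing,
    # ratio 0.0 always feeds back the model's own argmax prediction.
    use_teacher = random.random() < teacher_forcing_ratio
    prev = tgt_tokens[:, 0]
    step_logits = []
    for t in range(1, tgt_tokens.size(1)):
        logits, hidden = decoder(prev, hidden)
        step_logits.append(logits)
        prev = tgt_tokens[:, t] if use_teacher else logits.argmax(dim=-1)
    return torch.stack(step_logits, dim=1)


torch.manual_seed(0)
B, T, V = 3, 5, 12
decoder = ToyStepDecoder(V)
tgt_tokens = torch.randint(1, V, (B, T + 1))
h0 = torch.zeros(B, 32)
logits = decode_with_ratio(decoder, tgt_tokens, h0, teacher_forcing_ratio=0.5)
print(logits.shape)
```

Because the loop runs one step at a time, switching between gold and predicted inputs is trivial here. That is the flexibility a parallel implementation gives up in exchange for speed.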
Padding and Loss Masking
Teacher forcing must handle padding correctly. Suppose a batch contains target sequences with different lengths:
```text
<bos> I like cats <eos>
<bos> hello <eos> <pad> <pad>
```

The model should not be penalized for predictions at padding positions.
After shifting:

```text
decoder input:
<bos> I like cats
<bos> hello <eos> <pad>

target:
I like cats <eos>
hello <eos> <pad> <pad>
```

The loss should ignore <pad> targets.
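The same shift on the padded batch can be sketched with a hypothetical toy vocabulary (ids chosen arbitrarily, `pad_id = 0`). In the second row, the <eos> in the decoder input only ever has a <pad> target, so the loss mask drops that position too.

```python
import torch

# Hypothetical toy vocabulary; any ids work as long as <pad> is distinct.
vocab = {"<pad>": 0, "<bos>": 1, "<eos>": 2, "I": 3, "like": 4, "cats": 5, "hello": 6}
pad_id = vocab["<pad>"]
sentences = [
    ["<bos>", "I", "like", "cats", "<eos>"],
    ["<bos>", "hello", "<eos>"],
]

# Right-pad every sentence to the longest length in the batch.
max_len = max(len(s) for s in sentences)
tgt_tokens = torch.tensor(
    [[vocab[w] for w in s] + [pad_id] * (max_len - len(s)) for s in sentences]
)

tgt_input = tgt_tokens[:, :-1]
tgt_target = tgt_tokens[:, 1:]
loss_mask = tgt_target != pad_id  # True where a prediction is scored

print(tgt_input.tolist())
print(tgt_target.tolist())
print(loss_mask.sum().item(), "of", loss_mask.numel(), "target positions are scored")
```

Only 6 of the 8 target positions contribute to the loss: four for the first sentence and two (`hello`, <eos>) for the second.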
```python
loss = F.cross_entropy(
    logits.reshape(-1, logits.size(-1)),
    tgt_target.reshape(-1),
    ignore_index=pad_id,
)
```

Padding masks are also needed inside attention layers so the model does not attend to padding tokens.
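In PyTorch's `torch.nn.MultiheadAttention`, this is a boolean `key_padding_mask` (True marks a padding key) passed alongside the causal mask. A minimal sketch, reusing the shifted decoder input from the two-sentence example with assumed ids (`pad_id = 0`). Both masks are boolean here so they combine without a type mismatch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
pad_id, V, D = 0, 7, 16

# Shifted decoder input from the two-sentence example (hypothetical ids).
tgt_input = torch.tensor([[1, 3, 4, 5], [1, 6, 2, 0]])
T = tgt_input.size(1)

# Boolean masks: True means "may not attend here".
causal = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)  # [T, T]
key_padding_mask = tgt_input == pad_id                                # [B, T]

embed = nn.Embedding(V, D, padding_idx=pad_id)
attn = nn.MultiheadAttention(D, num_heads=4, batch_first=True)

with torch.no_grad():
    x = embed(tgt_input)
    _, weights = attn(x, x, x, attn_mask=causal, key_padding_mask=key_padding_mask)

# Row 1, key position 3 is <pad>: no query puts any weight on it.
print(weights[1, :, 3])
```

Every query still has at least one visible key (position 0 is always <bos>), so no attention row is fully masked.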
Practical Rules
Use full teacher forcing for most transformer encoder-decoder models. It gives efficient parallel training and stable optimization.
Use a causal mask in the decoder. It is required precisely because the full target sequence is passed in at once; without it, each position could see the tokens it is supposed to predict.
Shift target tokens carefully. The decoder input excludes the last token. The target excludes the first token.
Ignore padding tokens in the loss. Otherwise the model wastes capacity predicting padding.
When using RNN decoders, decide whether the implementation trains one step at a time or all steps at once. Stepwise implementations make scheduled sampling easier. Parallel implementations are faster.
Common Bugs
A common bug is using the same target tensor as both decoder input and prediction target. This lets the model see the token it is supposed to predict.
Incorrect:

```python
logits = model(src_tokens, tgt_tokens)
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), tgt_tokens.reshape(-1))
```

Correct:

```python
tgt_input = tgt_tokens[:, :-1]
tgt_target = tgt_tokens[:, 1:]
logits = model(src_tokens, tgt_input)
loss = F.cross_entropy(
    logits.reshape(-1, logits.size(-1)),
    tgt_target.reshape(-1),
    ignore_index=pad_id,
)
```

Another common bug is forgetting the causal mask in a transformer decoder. The model may train well but fail during inference because it learned to rely on future tokens.
A third common bug is computing loss on padding tokens. This can make the model overproduce padding and distort evaluation.
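The first bug is easy to demonstrate without training anything. In the sketch below, `copy_model` is a hypothetical stand-in for the shortcut a real network would learn: it confidently predicts its own current decoder input. When input and target are the same tensor, this shortcut reaches near-zero loss; with the correct shift it is heavily penalized. A causal mask would not prevent this leak, since each position may attend to itself.

```python
import torch
import torch.nn.functional as F


def copy_model(tgt_in, vocab_size):
    """Hypothetical shortcut 'model': confidently predicts its current input token."""
    return 20.0 * F.one_hot(tgt_in, vocab_size).float()  # [B, T, V] logits


torch.manual_seed(0)
B, T, V = 4, 12, 10
tgt_tokens = torch.randint(0, V, (B, T + 1))

# Buggy: the same tensor is both decoder input and prediction target.
logits = copy_model(tgt_tokens, V)
buggy_loss = F.cross_entropy(logits.reshape(-1, V), tgt_tokens.reshape(-1))

# Correct: shifted input and target.
tgt_input, tgt_target = tgt_tokens[:, :-1], tgt_tokens[:, 1:]
logits = copy_model(tgt_input, V)
correct_loss = F.cross_entropy(logits.reshape(-1, V), tgt_target.reshape(-1))

print(f"buggy: {buggy_loss.item():.2e}  correct: {correct_loss.item():.2f}")
```

A suspiciously low training loss with poor generation quality is the typical symptom of this bug.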
Summary
Teacher forcing trains an autoregressive decoder using the correct previous tokens from the dataset. It converts sequence generation into parallel next-token prediction.
The method is efficient and stable, especially for transformer models. Its main weakness is the training-inference mismatch: during training the model sees correct prefixes, while during inference it sees its own generated prefixes. This mismatch is called exposure bias.
In PyTorch, teacher forcing usually means shifting the target sequence into tgt_input and tgt_target, applying a causal mask in the decoder, and computing cross-entropy loss over non-padding target positions.