Masked language modeling trains a model to recover missing tokens from their surrounding context. Instead of predicting only the next token, the model receives a corrupted sequence and learns to reconstruct selected hidden tokens.
Given an original sequence

$$x = (x_1, x_2, \dots, x_n),$$

some tokens are replaced with a special mask token. The model predicts the original tokens at the masked positions.

For example, the sequence

"the cat sat on the mat"

may become

"the [MASK] sat on the [MASK]"

The training targets are "cat" and "mat" at the two masked positions.
Masked language modeling is most closely associated with BERT-style encoder models. It is especially useful when a model should produce bidirectional representations of text, rather than generate text strictly from left to right.
## Bidirectional Context
Autoregressive language models predict each token using only previous tokens:

$$p(x_t \mid x_1, \dots, x_{t-1}).$$

Masked language models predict a hidden token using both left and right context:

$$p(x_t \mid x_{\setminus t}),$$

where $x_{\setminus t}$ means the sequence $x$ with token $x_t$ hidden.
For example, in the sentence

"the model was [MASK] on a large corpus"

the word "trained" can be inferred from both sides:
The left context “the model was” suggests a past participle. The right context “on a large corpus” suggests a training-related verb.
This bidirectional access is useful for representation learning. The model can form token representations that depend on the whole sentence.
## Encoder-Only Transformers
Masked language modeling is usually implemented with an encoder-only transformer.
An encoder-only transformer differs from a decoder-only autoregressive transformer in its attention mask. In a decoder-only model, causal masking prevents each token from seeing future tokens. In an encoder-only model, every token can attend to every other token.
For a sequence of length $n$, the attention pattern is fully visible: every position $i \in \{1, \dots, n\}$ may attend to every position $j \in \{1, \dots, n\}$.
This allows each representation to integrate information from both directions.
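The two attention patterns can be written as boolean masks, where `True` means "may attend." A minimal sketch (the function names here are illustrative, not a standard API):

```python
import torch

def causal_mask(n):
    # Lower-triangular mask: position i may attend only to positions j <= i.
    return torch.tril(torch.ones(n, n, dtype=torch.bool))

def full_mask(n):
    # Fully visible mask: every position attends to every position.
    return torch.ones(n, n, dtype=torch.bool)

# For n = 4, the causal mask hides the upper triangle,
# while the encoder mask is all True.
print(causal_mask(4))
print(full_mask(4))
```

A decoder-only model applies the causal mask inside every attention layer; an encoder-only model effectively uses the full mask (modulo padding).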
If the input sequence is

$$(x_1, x_2, \dots, x_n),$$

the transformer produces contextual hidden states

$$(h_1, h_2, \dots, h_n).$$
The hidden state at a masked position is used to predict the original token.
## Corruption Process
Masked language modeling requires a corruption process. A subset of token positions is selected for prediction.
In the original BERT training procedure, 15 percent of tokens are selected. For selected positions:
| Replacement | Probability |
|---|---|
| Replace with [MASK] | 80 percent |
| Replace with a random token | 10 percent |
| Keep unchanged | 10 percent |
This design reduces the mismatch between pretraining and downstream use. The [MASK] token appears during pretraining but usually does not appear during fine-tuning or inference. Keeping some selected tokens unchanged and replacing some with random tokens forces the model to build robust contextual representations.
A simplified corruption function in PyTorch:

```python
import torch

def mask_tokens(input_ids, mask_token_id, vocab_size, mask_prob=0.15):
    labels = input_ids.clone()

    # Select positions for prediction.
    selected = torch.rand(input_ids.shape) < mask_prob
    # Unselected positions are ignored by the loss.
    labels[~selected] = -100

    corrupted = input_ids.clone()

    # 80 percent of selected positions: replace with [MASK].
    replace_with_mask = selected & (torch.rand(input_ids.shape) < 0.8)
    corrupted[replace_with_mask] = mask_token_id

    # Half of the remaining selected positions (10 percent overall):
    # replace with a random token. The final 10 percent stay unchanged.
    replace_with_random = selected & ~replace_with_mask & (
        torch.rand(input_ids.shape) < 0.5
    )
    random_tokens = torch.randint(
        low=0,
        high=vocab_size,
        size=input_ids.shape,
    )
    corrupted[replace_with_random] = random_tokens[replace_with_random]

    return corrupted, labels
```

The label value -100 is commonly used with torch.nn.CrossEntropyLoss to ignore unmasked positions.
## Training Objective
Let $\mathcal{M}$ be the set of masked positions. The model is trained to maximize the log probability of the original tokens at those positions:

$$\sum_{i \in \mathcal{M}} \log p_\theta(x_i \mid \tilde{x}),$$

where $\tilde{x}$ is the corrupted sequence.

The loss is the negated objective:

$$\mathcal{L} = -\sum_{i \in \mathcal{M}} \log p_\theta(x_i \mid \tilde{x}).$$
Only selected positions contribute to the loss. The model still computes hidden states for all positions, but the prediction objective is applied only to masked tokens.
In PyTorch, logits usually have shape

`(batch_size, seq_len, vocab_size)`

and labels have shape

`(batch_size, seq_len)`.
The loss can be computed as:
```python
import torch.nn.functional as F

loss = F.cross_entropy(
    logits.reshape(-1, vocab_size),
    labels.reshape(-1),
    ignore_index=-100,
)
```

The ignored positions do not contribute to the loss.
## Masked Language Model Head
An encoder transformer produces contextual hidden states:

$$h_1, h_2, \dots, h_n \in \mathbb{R}^{d}.$$

A language modeling head maps each hidden state to vocabulary logits:

$$z_i = W h_i + b,$$

where

$$W \in \mathbb{R}^{V \times d}$$

and

$$b \in \mathbb{R}^{V},$$

with $V$ the vocabulary size and $d$ the hidden dimension. Thus

$$z_i \in \mathbb{R}^{V}.$$

At each selected masked position, the corresponding vector of logits is used to predict the original token.
A minimal head:
```python
import torch.nn as nn

class MaskedLMHead(nn.Module):
    def __init__(self, hidden_dim, vocab_size):
        super().__init__()
        self.proj = nn.Linear(hidden_dim, vocab_size)

    def forward(self, hidden_states):
        # One logit per vocabulary entry, at every position.
        return self.proj(hidden_states)
```

Many production models use a small multilayer head with normalization and activation, but the mathematical role is the same.
## Difference from Autoregressive Modeling
Masked language modeling and autoregressive modeling solve different prediction problems.
| Property | Autoregressive LM | Masked LM |
|---|---|---|
| Context | Left context only | Left and right context |
| Typical architecture | Decoder-only transformer | Encoder-only transformer |
| Objective | Predict next token | Predict hidden tokens |
| Generation | Natural sequential generation | Indirect generation |
| Representation learning | Strong | Very strong for understanding tasks |
| Common examples | GPT-style models | BERT-style models |
Autoregressive models define an explicit left-to-right sequence probability:

$$p(x) = \prod_{t=1}^{n} p(x_t \mid x_1, \dots, x_{t-1}).$$
Masked language models usually do not define a simple normalized probability for the full sequence in the same way. They learn conditional predictions over masked positions. This makes them excellent encoders, but less natural as open-ended generators.
## Why Masked Models Work Well for Understanding
Many natural language understanding tasks require a representation of the whole input. Examples include classification, entailment, similarity, named entity recognition, retrieval, and question answering.
For these tasks, bidirectional context is valuable.
Consider sentiment classification of the sentence:

"the movie was not bad"
The representation of “bad” should depend on “not.” The representation of “not” should also depend on “bad.” Bidirectional attention gives each token access to both directions.
For sentence-level tasks, a special classification token such as [CLS] is often inserted at the beginning. Its final hidden state is used as a sequence representation:

$$h_{\text{[CLS]}} \in \mathbb{R}^{d}.$$

A classifier can then compute

$$y = \operatorname{softmax}(W_c \, h_{\text{[CLS]}} + b_c).$$
This design made encoder-only transformers highly effective for many supervised NLP benchmarks.
## Fine-Tuning Masked Language Models
After pretraining, a masked language model can be adapted to downstream tasks.
For text classification, attach a classification head to the sequence representation:
```python
import torch.nn as nn

class TextClassifier(nn.Module):
    def __init__(self, encoder, hidden_dim, num_classes):
        super().__init__()
        self.encoder = encoder
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, input_ids, attention_mask):
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        # Use the final hidden state of the [CLS] token (position 0).
        cls_state = outputs.last_hidden_state[:, 0, :]
        logits = self.classifier(cls_state)
        return logits
```

For token classification, attach a classifier to every token state:
$$z_i = W_{\text{tok}} \, h_i + b_{\text{tok}}, \qquad W_{\text{tok}} \in \mathbb{R}^{K \times d},$$

where $K$ is the number of labels.
This is used for named entity recognition, part-of-speech tagging, and sequence labeling.
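A per-token head for these tasks can be sketched as follows (the class name is hypothetical):

```python
import torch
import torch.nn as nn

class TokenClassificationHead(nn.Module):
    """Map each contextual hidden state to per-token label logits."""

    def __init__(self, hidden_dim, num_labels):
        super().__init__()
        self.classifier = nn.Linear(hidden_dim, num_labels)

    def forward(self, hidden_states):
        # hidden_states: (batch, seq_len, hidden_dim)
        # returns:       (batch, seq_len, num_labels)
        return self.classifier(hidden_states)

head = TokenClassificationHead(hidden_dim=16, num_labels=5)
logits = head(torch.randn(2, 8, 16))
print(logits.shape)  # torch.Size([2, 8, 5])
```

Unlike the sentence-level classifier, no pooling occurs: every position keeps its own prediction.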
For extractive question answering, the model predicts start and end positions in a passage. Each token is scored as a possible answer start and answer end:

$$p_{\text{start}}(i) = \operatorname{softmax}_i(w_s^{\top} h_i), \qquad p_{\text{end}}(i) = \operatorname{softmax}_i(w_e^{\top} h_i),$$

where $w_s, w_e \in \mathbb{R}^{d}$ are learned vectors.
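A start/end head in the same style as the classifiers above (a sketch; the class name is hypothetical, following the two-logits-per-token design used by BERT-style QA models):

```python
import torch
import torch.nn as nn

class SpanQAHead(nn.Module):
    """Score every token as a possible answer start and answer end."""

    def __init__(self, hidden_dim):
        super().__init__()
        # Two logits per token: one for "start here", one for "end here".
        self.qa_outputs = nn.Linear(hidden_dim, 2)

    def forward(self, hidden_states):
        # hidden_states: (batch, seq_len, hidden_dim)
        logits = self.qa_outputs(hidden_states)           # (batch, seq_len, 2)
        start_logits, end_logits = logits.unbind(dim=-1)  # each (batch, seq_len)
        return start_logits, end_logits

head = SpanQAHead(hidden_dim=16)
start_logits, end_logits = head(torch.randn(2, 8, 16))
print(start_logits.shape, end_logits.shape)
```

At inference time, the predicted answer span is typically the (start, end) pair with the highest combined score, subject to start ≤ end.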
Masked language model pretraining gives the encoder useful general-purpose representations, which can then be specialized with supervised data.
## Whole-Word and Span Masking
Subword tokenization creates a complication. A word may be split into several subword tokens. For example, with a WordPiece-style vocabulary:

"unhappiness" → "un", "##happi", "##ness"
If only one subword is masked, the task may become too easy. The model can recover the missing subword from the visible pieces.
Whole-word masking selects all subword tokens belonging to the same word. The model must reconstruct the full word from context.
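A sketch of whole-word selection over WordPiece-style tokens, where a `##` prefix marks a word continuation (the function names are hypothetical, and real implementations also skip special tokens):

```python
import random

def whole_word_spans(tokens):
    # Group subword token indices into words; "##" marks a continuation.
    spans, current = [], []
    for i, tok in enumerate(tokens):
        if tok.startswith("##") and current:
            current.append(i)
        else:
            if current:
                spans.append(current)
            current = [i]
    if current:
        spans.append(current)
    return spans

def whole_word_mask(tokens, mask_prob=0.15, seed=0):
    # Mask every subword of a selected word, not individual subwords.
    rng = random.Random(seed)
    masked = list(tokens)
    for span in whole_word_spans(tokens):
        if rng.random() < mask_prob:
            for i in span:
                masked[i] = "[MASK]"
    return masked

tokens = ["the", "un", "##happi", "##ness", "grew"]
print(whole_word_spans(tokens))  # [[0], [1, 2, 3], [4]]
```

Because selection happens per word, the model can no longer recover "##ness" just by seeing "un" and "##happi".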
Span masking goes further by masking contiguous spans of text. For example:

"the cat sat on the mat"

may become

"the [MASK] [MASK] [MASK] the mat"
Span masking is often closer to real language understanding because phrases, not only individual tokens, carry meaning.
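A minimal span-masking sketch with a fixed span length. This is a simplification: span-masking models such as SpanBERT sample span lengths from a distribution rather than fixing them.

```python
import random

def span_mask(tokens, span_len=3, num_spans=1, seed=0):
    """Replace `num_spans` contiguous spans of `span_len` tokens with [MASK]."""
    rng = random.Random(seed)
    masked = list(tokens)
    for _ in range(num_spans):
        if len(tokens) < span_len:
            break
        # Choose a random start so the span fits inside the sequence.
        start = rng.randrange(len(tokens) - span_len + 1)
        for i in range(start, start + span_len):
            masked[i] = "[MASK]"
    return masked

tokens = "the cat sat on the mat".split()
print(span_mask(tokens, span_len=3, num_spans=1))
```

With `num_spans > 1` the chosen spans may overlap in this sketch; production implementations usually enforce non-overlapping spans and a target corruption rate.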
## Denoising Objectives
Masked language modeling is part of a broader family of denoising objectives.
A denoising model receives corrupted input and learns to reconstruct the original input.
Possible corruptions include:
| Corruption | Example |
|---|---|
| Token masking | Replace tokens with [MASK] |
| Token deletion | Remove tokens |
| Token replacement | Replace tokens with random tokens |
| Span masking | Hide contiguous spans |
| Sentence permutation | Shuffle sentence order |
| Text infilling | Replace spans with sentinel tokens |
Encoder-decoder models such as T5 use text infilling objectives. The encoder reads corrupted text. The decoder generates the missing spans.
For example, writing the sentinel tokens as `<X>` and `<Y>`:

Input: `the cat <X> on the <Y>`

Target: `<X> sat <Y> mat`
This connects masked language modeling to sequence-to-sequence pretraining.
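The infilling corruption can be sketched as follows. Spans are replaced by sentinel tokens in the input, and the target lists each sentinel followed by the text it replaced (the function name is hypothetical; the sentinel format follows T5's `<extra_id_N>` convention):

```python
def infill_corrupt(tokens, spans):
    """Replace each half-open (start, end) span with a sentinel token.

    `spans` must be sorted and non-overlapping. Returns the corrupted
    input and the target sequence of sentinels plus the dropped text.
    """
    corrupted, target = [], []
    pos = 0
    for n, (start, end) in enumerate(spans):
        sentinel = f"<extra_id_{n}>"
        corrupted.extend(tokens[pos:start])  # keep text before the span
        corrupted.append(sentinel)           # stand-in for the dropped span
        target.append(sentinel)              # target: sentinel, then the span
        target.extend(tokens[start:end])
        pos = end
    corrupted.extend(tokens[pos:])           # keep the tail
    return corrupted, target

tokens = "the cat sat on the mat".split()
corrupted, target = infill_corrupt(tokens, [(1, 2), (5, 6)])
print(corrupted)  # ['the', '<extra_id_0>', 'sat', 'on', 'the', '<extra_id_1>']
print(target)     # ['<extra_id_0>', 'cat', '<extra_id_1>', 'mat']
```

The decoder's job is then ordinary left-to-right generation of the target, which is why this objective suits encoder-decoder models.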
## Limitations
Masked language modeling has several limitations.
First, the [MASK] token creates a pretraining and inference mismatch. The model sees artificial tokens during pretraining that may not appear later.
Second, only a fraction of tokens contribute to the loss. Autoregressive models receive a training signal at every position, while masked models often train on selected positions only.
Third, masked models are less direct for open-ended generation. They can fill blanks, but generating long text requires iterative masking strategies or different architectures.
Fourth, independent prediction of multiple masked tokens can be problematic. If two adjacent tokens are masked, the model may predict each one without fully modeling their joint dependency.
These limitations do not make masked language modeling weak. They explain why masked models are often used as encoders, while autoregressive models dominate open-ended generation.
## Role in Modern Deep Learning
Masked language modeling helped establish large-scale self-supervised pretraining for language understanding. It showed that a model could learn broadly useful representations from unlabeled text and then transfer to many supervised tasks.
The same denoising principle now appears beyond text. Vision models mask image patches. Audio models mask time-frequency regions. Multimodal models mask tokens, patches, frames, or latent codes.
The general pattern is: hide part of the input, then train the model to reconstruct the hidden part from the visible context.
Masked language modeling is therefore both a specific NLP objective and an instance of a larger principle: useful representations can be learned by reconstructing missing information from context.