Masked Language Modeling

Masked language modeling trains a model to recover missing tokens from their surrounding context. Instead of predicting only the next token, the model receives a corrupted sequence and learns to reconstruct selected hidden tokens.

Given an original sequence

x_{1:T} = (x_1, x_2, \ldots, x_T),

some tokens are replaced with a special mask token. The model predicts the original tokens at the masked positions.

For example:

deep learning models use tensors

may become

deep [MASK] models use [MASK].

The training targets are “learning” and “tensors”.

Masked language modeling is most closely associated with BERT-style encoder models. It is especially useful when a model should produce bidirectional representations of text, rather than generate text strictly from left to right.

Bidirectional Context

Autoregressive language models predict each token using only previous tokens:

p_\theta(x_t \mid x_{1:t-1}).

Masked language models predict a hidden token using both left and right context:

p_\theta(x_t \mid x_{\setminus t}),

where x_{\setminus t} means the sequence with token x_t hidden.

For example, in the sentence

the model was trained on a large corpus

the word “trained” can be inferred from both sides:

the model was [MASK] on a large corpus.

The left context “the model was” suggests a past participle. The right context “on a large corpus” suggests a training-related verb.

This bidirectional access is useful for representation learning. The model can form token representations that depend on the whole sentence.
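
As a quick illustration, a pretrained masked model can fill the blank in the sentence above. This sketch assumes the Hugging Face transformers library and the bert-base-uncased checkpoint, which are not part of this text:

from transformers import pipeline

# Load a pretrained BERT-style model behind a fill-mask pipeline.
unmasker = pipeline("fill-mask", model="bert-base-uncased")

# Each candidate uses both the left and the right context of [MASK].
for prediction in unmasker("the model was [MASK] on a large corpus"):
    print(prediction["token_str"], round(prediction["score"], 3))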

Encoder-Only Transformers

Masked language modeling is usually implemented with an encoder-only transformer.

An encoder-only transformer differs from a decoder-only autoregressive transformer in its attention mask. In a decoder-only model, causal masking prevents each token from seeing future tokens. In an encoder-only model, every token can attend to every other token.

For a sequence of length T, the attention pattern is fully visible:

\begin{bmatrix} 1 & 1 & \cdots & 1 \\ 1 & 1 & \cdots & 1 \\ \vdots & \vdots & \ddots & \vdots \\ 1 & 1 & \cdots & 1 \end{bmatrix}.

This allows each representation to integrate information from both directions.
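
A minimal way to see the difference in PyTorch: a causal mask restricts attention to earlier positions, while an encoder-style mask leaves everything visible. The snippet is purely illustrative:

import torch

T = 5

# Decoder-style causal mask: position t attends only to positions <= t.
causal_mask = torch.tril(torch.ones(T, T))

# Encoder-style mask: every position attends to every other position.
full_mask = torch.ones(T, T)

print(causal_mask)
print(full_mask)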

If the input sequence is

[\text{CLS}], x_1, x_2, \ldots, x_T, [\text{SEP}],

the transformer produces contextual hidden states

h_0, h_1, h_2, \ldots, h_T, h_{T+1}.

The hidden state at a masked position is used to predict the original token.

Corruption Process

Masked language modeling requires a corruption process. A subset of token positions is selected for prediction.

In the original BERT training procedure, 15 percent of tokens are selected. For selected positions:

Replacement                    Probability
Replace with [MASK]            80 percent
Replace with a random token    10 percent
Keep unchanged                 10 percent

This design reduces the mismatch between pretraining and downstream use. The [MASK] token appears during pretraining but usually does not appear during fine-tuning or inference. Keeping some selected tokens unchanged and replacing some with random tokens forces the model to build robust contextual representations.

A simplified corruption function in PyTorch:

import torch

def mask_tokens(input_ids, mask_token_id, vocab_size, mask_prob=0.15):
    # Simplified: special tokens such as [CLS] and [SEP] are not excluded here.
    # Labels keep the original ids; unselected positions become -100
    # so the loss ignores them.
    labels = input_ids.clone()

    # Select roughly mask_prob of all positions for prediction.
    selected = torch.rand(input_ids.shape) < mask_prob
    labels[~selected] = -100

    corrupted = input_ids.clone()

    # 80 percent of selected positions are replaced with [MASK].
    replace_with_mask = selected & (torch.rand(input_ids.shape) < 0.8)
    corrupted[replace_with_mask] = mask_token_id

    # Half of the remaining selected positions (10 percent overall) are
    # replaced with a random token; the rest stay unchanged.
    replace_with_random = selected & ~replace_with_mask & (
        torch.rand(input_ids.shape) < 0.5
    )
    random_tokens = torch.randint(
        low=0,
        high=vocab_size,
        size=input_ids.shape,
    )
    corrupted[replace_with_random] = random_tokens[replace_with_random]

    return corrupted, labels

The label value -100 is commonly used with torch.nn.CrossEntropyLoss to ignore unmasked positions.
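
A quick usage sketch with a toy vocabulary (the sizes and the mask token id below are arbitrary):

torch.manual_seed(0)

vocab_size = 100
mask_token_id = 99

input_ids = torch.randint(0, vocab_size - 1, (2, 10))  # [B, T] toy batch
corrupted, labels = mask_tokens(input_ids, mask_token_id, vocab_size)

# Selected positions keep their original id in labels;
# everything else is -100 and is ignored by the loss.
print(corrupted)
print(labels)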

Training Objective

Let M be the set of masked positions. The model is trained to maximize the log probability of the original tokens at those positions:

\max_\theta \sum_{t \in M} \log p_\theta(x_t \mid \tilde{x}),

where \tilde{x} is the corrupted sequence.

The loss is

\mathcal{L}(\theta) = - \sum_{t \in M} \log p_\theta(x_t \mid \tilde{x}).

Only selected positions contribute to the loss. The model still computes hidden states for all positions, but the prediction objective is applied only to masked tokens.

In PyTorch, logits usually have shape

[B, T, |V|],

and labels have shape

[B, T].

The loss can be computed as:

import torch.nn.functional as F

loss = F.cross_entropy(
    logits.reshape(-1, vocab_size),
    labels.reshape(-1),
    ignore_index=-100,
)

The ignored positions do not contribute to the loss.
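
Putting the pieces together, a toy end-to-end training step might look like the following. The embedding size, depth, and use of nn.TransformerEncoder are illustrative choices rather than a published configuration; mask_tokens is the helper defined earlier:

import torch
import torch.nn as nn
import torch.nn.functional as F

vocab_size, hidden_dim = 1000, 64
mask_token_id = vocab_size - 1

# Toy encoder: embedding + bidirectional transformer encoder + LM head.
embedding = nn.Embedding(vocab_size, hidden_dim)
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=4, batch_first=True),
    num_layers=2,
)
lm_head = nn.Linear(hidden_dim, vocab_size)

input_ids = torch.randint(0, vocab_size - 1, (8, 32))  # [B, T]
corrupted, labels = mask_tokens(input_ids, mask_token_id, vocab_size)

hidden = encoder(embedding(corrupted))  # [B, T, d]
logits = lm_head(hidden)                # [B, T, |V|]

loss = F.cross_entropy(
    logits.reshape(-1, vocab_size),
    labels.reshape(-1),
    ignore_index=-100,  # unselected positions do not contribute
)
loss.backward()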

Masked Language Model Head

An encoder transformer produces contextual hidden states:

H \in \mathbb{R}^{B \times T \times d}.

A language modeling head maps each hidden state to vocabulary logits:

Z = HW + b,

where

W \in \mathbb{R}^{d \times |V|}

and

b \in \mathbb{R}^{|V|}.

Thus

Z \in \mathbb{R}^{B \times T \times |V|}.

At each selected masked position, the corresponding vector of logits is used to predict the original token.

A minimal head:

import torch.nn as nn

class MaskedLMHead(nn.Module):
    def __init__(self, hidden_dim, vocab_size):
        super().__init__()
        self.proj = nn.Linear(hidden_dim, vocab_size)

    def forward(self, hidden_states):
        return self.proj(hidden_states)

Many production models use a small multilayer head with normalization and activation, but the mathematical role is the same.
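
For instance, a BERT-style head applies a dense transform, a GELU activation, and layer normalization before the vocabulary projection. The sketch below follows that pattern but is not tied to any specific implementation:

import torch.nn as nn

class TransformedMaskedLMHead(nn.Module):
    def __init__(self, hidden_dim, vocab_size):
        super().__init__()
        self.dense = nn.Linear(hidden_dim, hidden_dim)
        self.act = nn.GELU()
        self.norm = nn.LayerNorm(hidden_dim)
        self.decoder = nn.Linear(hidden_dim, vocab_size)

    def forward(self, hidden_states):
        # Transform each hidden state, then project to vocabulary logits.
        h = self.norm(self.act(self.dense(hidden_states)))
        return self.decoder(h)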

Difference from Autoregressive Modeling

Masked language modeling and autoregressive modeling solve different prediction problems.

Property                   Autoregressive LM               Masked LM
Context                    Left context only               Left and right context
Typical architecture       Decoder-only transformer        Encoder-only transformer
Objective                  Predict next token              Predict hidden tokens
Generation                 Natural sequential generation   Indirect generation
Representation learning    Strong                          Very strong for understanding tasks
Common examples            GPT-style models                BERT-style models

Autoregressive models define an explicit left-to-right sequence probability:

p(x_{1:T}) = \prod_{t=1}^{T} p(x_t \mid x_{1:t-1}).

Masked language models usually do not define a simple normalized probability for the full sequence in the same way. They learn conditional predictions over masked positions. This makes them excellent encoders, but less natural as open-ended generators.
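
The difference also shows up in how the loss is computed. In an autoregressive model every position except the last provides a target, obtained by shifting the inputs one step to the left, whereas the masked loss above keeps only the selected positions. The snippet reuses the logits, input_ids, and vocab_size names from the training sketch purely to show where the signal comes from; a real autoregressive model would also use a causal attention mask:

# Autoregressive objective: predict token t+1 from positions <= t.
ar_loss = F.cross_entropy(
    logits[:, :-1, :].reshape(-1, vocab_size),
    input_ids[:, 1:].reshape(-1),
)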

Why Masked Models Work Well for Understanding

Many natural language understanding tasks require a representation of the whole input. Examples include classification, entailment, similarity, named entity recognition, retrieval, and question answering.

For these tasks, bidirectional context is valuable.

Consider sentiment classification:

the film was not bad

The representation of “bad” should depend on “not.” The representation of “not” should also depend on “bad.” Bidirectional attention gives each token access to both directions.

For sentence-level tasks, a special classification token such as [CLS] is often inserted at the beginning. Its final hidden state is used as a sequence representation:

h_{\text{CLS}}.

A classifier can then compute

\hat{y} = \text{softmax}(W h_{\text{CLS}} + b).

This design made encoder-only transformers highly effective for many supervised NLP benchmarks.

Fine-Tuning Masked Language Models

After pretraining, a masked language model can be adapted to downstream tasks.

For text classification, attach a classification head to the sequence representation:

class TextClassifier(nn.Module):
    def __init__(self, encoder, hidden_dim, num_classes):
        super().__init__()
        self.encoder = encoder
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, input_ids, attention_mask):
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

        cls_state = outputs.last_hidden_state[:, 0, :]  # hidden state of the [CLS] token
        logits = self.classifier(cls_state)

        return logits
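
A usage sketch, assuming a pretrained encoder loaded from the Hugging Face transformers library (the checkpoint name and hidden size are illustrative):

from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
encoder = AutoModel.from_pretrained("bert-base-uncased")

model = TextClassifier(encoder, hidden_dim=768, num_classes=2)

batch = tokenizer(["the film was not bad"], return_tensors="pt", padding=True)
logits = model(batch["input_ids"], batch["attention_mask"])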

For token classification, attach a classifier to every token state:

Z \in \mathbb{R}^{B \times T \times K},

where K is the number of labels.

This is used for named entity recognition, part-of-speech tagging, and sequence labeling.
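
A minimal token-classification sketch in the same style as the classifier above (the encoder is again assumed to expose last_hidden_state):

import torch.nn as nn

class TokenClassifier(nn.Module):
    def __init__(self, encoder, hidden_dim, num_labels):
        super().__init__()
        self.encoder = encoder
        self.classifier = nn.Linear(hidden_dim, num_labels)

    def forward(self, input_ids, attention_mask):
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        # One set of label logits per token: [B, T, K].
        return self.classifier(outputs.last_hidden_state)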

For extractive question answering, the model predicts start and end positions in a passage:

p_{\text{start}}(t), \quad p_{\text{end}}(t).
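
A common implementation is a single linear layer that produces two logits per token, one for the start score and one for the end score. The sketch below is one possible version, not a particular library's API:

import torch.nn as nn

class SpanQAHead(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.qa_outputs = nn.Linear(hidden_dim, 2)

    def forward(self, hidden_states):
        logits = self.qa_outputs(hidden_states)           # [B, T, 2]
        start_logits, end_logits = logits.unbind(dim=-1)  # each [B, T]
        return start_logits, end_logits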

Masked language model pretraining gives the encoder useful general-purpose representations, which can then be specialized with supervised data.

Whole-Word and Span Masking

Subword tokenization creates a complication. A word may be split into several subword tokens. For example:

unbelievable → un, ##believ, ##able.

If only one subword is masked, the task may become too easy. The model can recover the missing subword from the visible pieces.

Whole-word masking selects all subword tokens belonging to the same word. The model must reconstruct the full word from context.

Span masking goes further by masking contiguous spans of text. For example:

the model was trained on a large dataset

may become

the model was [MASK] [MASK] a large dataset.

Span masking is often closer to real language understanding because phrases, not only individual tokens, carry meaning.
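
A simplified span-masking corruption function, in the same spirit as mask_tokens above. The fixed span length and the sampling scheme are arbitrary simplifications; real implementations sample span lengths and avoid special tokens:

import torch

def mask_spans(input_ids, mask_token_id, span_len=3, mask_prob=0.15):
    corrupted = input_ids.clone()
    labels = torch.full_like(input_ids, -100)
    batch_size, seq_len = input_ids.shape

    # Number of spans needed to cover roughly mask_prob of the tokens.
    num_spans = max(1, int(seq_len * mask_prob / span_len))

    for b in range(batch_size):
        starts = torch.randint(0, seq_len - span_len + 1, (num_spans,))
        for s in starts.tolist():
            # Spans may overlap in this toy version.
            labels[b, s : s + span_len] = input_ids[b, s : s + span_len]
            corrupted[b, s : s + span_len] = mask_token_id

    return corrupted, labels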

Denoising Objectives

Masked language modeling is part of a broader family of denoising objectives.

A denoising model receives corrupted input and learns to reconstruct the original input.

Possible corruptions include:

Corruption              Example
Token masking           Replace tokens with [MASK]
Token deletion          Remove tokens
Token replacement       Replace tokens with random tokens
Span masking            Hide contiguous spans
Sentence permutation    Shuffle sentence order
Text infilling          Replace spans with sentinel tokens

Encoder-decoder models such as T5 use text infilling objectives. The encoder reads corrupted text. The decoder generates the missing spans.

For example:

Input:

deep learning <extra_id_0> tensors

Target:

<extra_id_0> models use
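
A tiny string-level sketch of how such a pair can be built. The helper name and the single-span simplification are illustrative, not T5's actual preprocessing:

def make_infilling_pair(tokens, span_start, span_end, sentinel="<extra_id_0>"):
    # Replace one span with a sentinel and emit the span as the target.
    source = tokens[:span_start] + [sentinel] + tokens[span_end:]
    target = [sentinel] + tokens[span_start:span_end]
    return source, target

src, tgt = make_infilling_pair(
    ["deep", "learning", "models", "use", "tensors"], 2, 4
)
# src: ['deep', 'learning', '<extra_id_0>', 'tensors']
# tgt: ['<extra_id_0>', 'models', 'use']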

This connects masked language modeling to sequence-to-sequence pretraining.

Limitations

Masked language modeling has several limitations.

First, the [MASK] token creates a pretraining and inference mismatch. The model sees artificial tokens during pretraining that may not appear later.

Second, only a fraction of tokens contribute to the loss. Autoregressive models receive a training signal at every position, while masked models often train on selected positions only.

Third, masked models are less direct for open-ended generation. They can fill blanks, but generating long text requires iterative masking strategies or different architectures.

Fourth, independent prediction of multiple masked tokens can be problematic. If two adjacent tokens are masked, the model may predict each one without fully modeling their joint dependency.

These limitations do not make masked language modeling weak. They explain why masked models are often used as encoders, while autoregressive models dominate open-ended generation.

Role in Modern Deep Learning

Masked language modeling helped establish large-scale self-supervised pretraining for language understanding. It showed that a model could learn broadly useful representations from unlabeled text and then transfer to many supervised tasks.

The same denoising principle now appears beyond text. Vision models mask image patches. Audio models mask time-frequency regions. Multimodal models mask tokens, patches, frames, or latent codes.

The general pattern is:

corrupt input → predict missing structure → learn useful representations.

Masked language modeling is therefore both a specific NLP objective and an instance of a larger principle: useful representations can be learned by reconstructing missing information from context.