Multi-Head Attention

Multi-head attention runs several attention operations in parallel. Each attention operation is called a head. Each head has its own query, key, and value projections. The outputs of all heads are then concatenated and projected back into the model dimension.

Single-head attention gives each token one attention distribution. Multi-head attention gives each token several attention distributions. This lets the model read different kinds of context at the same time.

For example, in a language model, one head may attend to nearby tokens, another to a previous subject, another to punctuation, and another to a delimiter or special token. These patterns are learned from data. They are not manually assigned.

Motivation

Single-head self-attention computes one weighted average of value vectors for each token:

$$z_i = \sum_{j=1}^{T} a_{ij} v_j.$$

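Here $a_{ij}$ is the attention weight that token $i$ assigns to token $j$, and $v_j$ is the value vector at position $j$. A single average can be too restrictive: a token may need more than one relation at once.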

Consider:

The book that the students borrowed from the library was expensive.

The token "was" may need information about its subject "book", not the nearer noun "students". At the same time, the model may also need local phrase information, long-range dependency information, and semantic information. A single attention distribution has to combine all of these needs into one set of weights.

Multi-head attention separates this work. Each head forms its own representation subspace and its own attention pattern.
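To make the contrast concrete, the following sketch builds one attention distribution per token and then several per token. The sizes `T` and `H` and the random scores are illustrative assumptions; in a trained model the scores come from learned projections.

```python
import torch

torch.manual_seed(0)
T, H = 5, 4  # illustrative sequence length and head count (assumed values)

# Single head: one row of attention weights per token.
single = torch.softmax(torch.randn(T, T), dim=-1)    # [T, T]

# Multi-head: H independent rows of attention weights per token.
multi = torch.softmax(torch.randn(H, T, T), dim=-1)  # [H, T, T]

print(single.shape)  # torch.Size([5, 5])
print(multi.shape)   # torch.Size([4, 5, 5])

# Every row is a probability distribution over the T positions.
row_sums_single = single.sum(dim=-1)  # [T]
row_sums_multi = multi.sum(dim=-1)    # [H, T]
```

Each token therefore gets `H` separate distributions to read context with, instead of one.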

Head Dimensions

Let the input be

$$X \in \mathbb{R}^{B \times T \times D}.$$

Here $B$ is the batch size, $T$ is the sequence length, and $D$ is the model dimension. If we use $H$ heads, each head usually has dimension

$$d_h = \frac{D}{H}.$$

For example, if $D = 512$ and $H = 8$, then each head has dimension $d_h = 64$.

Each head $h$ has its own projections:

$$Q_h = XW^Q_h,\quad K_h = XW^K_h,\quad V_h = XW^V_h.$$

Then each head computes

$$Z_h = \operatorname{softmax}\left(\frac{Q_h K_h^\top}{\sqrt{d_h}}\right) V_h.$$

The head outputs are concatenated:

$$Z = \operatorname{concat}(Z_1, \ldots, Z_H).$$

Finally, an output projection mixes the heads:

$$Y = ZW^O.$$

The result has the same shape as the input:

$$Y \in \mathbb{R}^{B \times T \times D}.$$
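The equations in this section can be followed literally with one loop iteration per head. The sketch below is only an illustration: the sizes, the random weight matrices, and names such as `W_q` and `head_outputs` are assumptions, and real implementations fuse the heads into single matrix multiplications rather than looping.

```python
import math
import torch

torch.manual_seed(0)
B, T, D, H = 2, 6, 32, 4  # illustrative sizes (assumed)
d_h = D // H

X = torch.randn(B, T, D)

# One projection triple per head, each of shape [D, d_h].
W_q = [torch.randn(D, d_h) / math.sqrt(D) for _ in range(H)]
W_k = [torch.randn(D, d_h) / math.sqrt(D) for _ in range(H)]
W_v = [torch.randn(D, d_h) / math.sqrt(D) for _ in range(H)]
W_o = torch.randn(D, D) / math.sqrt(D)

head_outputs = []
for h in range(H):
    Q_h = X @ W_q[h]  # [B, T, d_h]
    K_h = X @ W_k[h]  # [B, T, d_h]
    V_h = X @ W_v[h]  # [B, T, d_h]

    scores = Q_h @ K_h.transpose(-2, -1) / math.sqrt(d_h)  # [B, T, T]
    A_h = torch.softmax(scores, dim=-1)
    head_outputs.append(A_h @ V_h)                          # [B, T, d_h]

Z = torch.cat(head_outputs, dim=-1)  # [B, T, D]
Y = Z @ W_o                          # [B, T, D]

print(Y.shape)  # torch.Size([2, 6, 32])
```

Concatenating `H` outputs of width `d_h` restores width `D`, which is why the output projection can be a square `D x D` matrix.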

Tensor Shapes

In PyTorch, the input starts as:

[B, T, D]

After projecting queries, keys, and values, we reshape each tensor into:

[B, T, H, d_h]

Then we transpose to place the head axis before the sequence axis:

[B, H, T, d_h]

This layout is convenient because attention is computed independently for each head.

The attention scores have shape:

[B, H, T, T]

The head outputs have shape:

[B, H, T, d_h]

After transposing and concatenating heads, the tensor returns to:

[B, T, D]

Shape discipline is important. Most bugs in multi-head attention come from incorrect reshaping, transposing, or forgetting to call .contiguous() before .view().
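The following sketch walks through the split and merge steps and shows the `.contiguous()` pitfall. The sizes are arbitrary assumptions, and `out` is a random stand-in for the per-head attention output, not a real attention result.

```python
import torch

B, T, H, d_h = 2, 3, 4, 5  # illustrative sizes (assumed)
D = H * d_h

x = torch.randn(B, T, D)                      # [B, T, D]
heads = x.view(B, T, H, d_h).transpose(1, 2)  # [B, H, T, d_h]
print(heads.shape)            # torch.Size([2, 4, 3, 5])
print(heads.is_contiguous())  # False

# Stand-in for the head outputs, laid out contiguously as [B, H, T, d_h].
out = torch.randn(B, H, T, d_h)
swapped = out.transpose(1, 2)                 # [B, T, H, d_h], non-contiguous

# .view cannot merge the last two axes of a non-contiguous tensor.
try:
    swapped.view(B, T, D)
    view_failed = False
except RuntimeError:
    view_failed = True
print(view_failed)  # True

merged = swapped.contiguous().view(B, T, D)   # [B, T, D]
merged_alt = swapped.reshape(B, T, D)         # reshape copies when it has to
print(torch.equal(merged, merged_alt))        # True
```

Using `.reshape()` avoids the error, but writing `.contiguous().view()` makes the memory copy explicit.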

PyTorch Implementation

A minimal multi-head self-attention layer can be implemented as follows:

import math
import torch
from torch import nn
import torch.nn.functional as F

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()

        if d_model % num_heads != 0:
            raise ValueError("d_model must be divisible by num_heads")

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads

        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, T, D]
        B, T, D = x.shape

        x = x.view(B, T, self.num_heads, self.d_head)
        x = x.transpose(1, 2)

        # [B, H, T, d_head]
        return x

    def merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, H, T, d_head]
        B, H, T, d_head = x.shape

        x = x.transpose(1, 2)
        x = x.contiguous()
        x = x.view(B, T, H * d_head)

        # [B, T, D]
        return x

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        causal: bool = False,
    ):
        # x: [B, T, D]
        B, T, D = x.shape

        q = self.split_heads(self.q_proj(x))
        k = self.split_heads(self.k_proj(x))
        v = self.split_heads(self.v_proj(x))

        # q, k, v: [B, H, T, d_head]
        scores = q @ k.transpose(-2, -1)
        scores = scores / math.sqrt(self.d_head)

        # scores: [B, H, T, T]
        if attention_mask is not None:
            # attention_mask: [B, T], 1 for real token, 0 for padding
            key_mask = attention_mask[:, None, None, :].bool()
            scores = scores.masked_fill(~key_mask, float("-inf"))

        if causal:
            causal_mask = torch.triu(
                torch.ones(T, T, device=x.device, dtype=torch.bool),
                diagonal=1,
            )
            scores = scores.masked_fill(causal_mask, float("-inf"))

        weights = F.softmax(scores, dim=-1)

        out = weights @ v
        out = self.merge_heads(out)
        out = self.out_proj(out)

        return out, weights

Example use:

x = torch.randn(4, 16, 128)

attn = MultiHeadSelfAttention(d_model=128, num_heads=8)
out, weights = attn(x, causal=True)

print(out.shape)      # torch.Size([4, 16, 128])
print(weights.shape)  # torch.Size([4, 8, 16, 16])

The output is a sequence of contextual vectors. The weights store one attention matrix per head.
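The padding and causal masks in `forward` rely on broadcasting against the `[B, H, T, T]` score tensor. The sketch below isolates that step with all-zero scores so the effect of each mask is easy to read. The sizes and the example mask are assumptions chosen for illustration.

```python
import torch

B, H, T = 2, 3, 4  # illustrative sizes (assumed)
scores = torch.zeros(B, H, T, T)

# 1 for real tokens, 0 for padding. The second sequence has two padded positions.
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [1, 1, 0, 0]])

# [B, T] -> [B, 1, 1, T]: broadcasts over heads and query positions.
key_mask = attention_mask[:, None, None, :].bool()
scores = scores.masked_fill(~key_mask, float("-inf"))

# [T, T] upper triangle: broadcasts over batch and heads.
causal_mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(causal_mask, float("-inf"))

weights = torch.softmax(scores, dim=-1)

# Second sequence, first head: padded keys 2 and 3 receive zero weight,
# and query 0 can only see key 0.
print(weights[1, 0])
```

With causal masking every query can still see at least position 0, so no row is fully masked in this example, provided position 0 is a real token.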

Using PyTorch Built-In MultiheadAttention

PyTorch provides nn.MultiheadAttention. It supports self-attention and cross-attention. Historically, this module used sequence-first tensors by default:

[T, B, D]

Modern code often sets batch_first=True, which allows:

[B, T, D]

Example:

import torch
from torch import nn

mha = nn.MultiheadAttention(
    embed_dim=128,
    num_heads=8,
    batch_first=True,
)

x = torch.randn(4, 16, 128)

out, weights = mha(
    query=x,
    key=x,
    value=x,
    need_weights=True,
)

print(out.shape)      # torch.Size([4, 16, 128])
print(weights.shape)  # torch.Size([4, 16, 16])

By default, the returned weights are averaged across heads. To return separate weights for each head:

out, weights = mha(
    query=x,
    key=x,
    value=x,
    need_weights=True,
    average_attn_weights=False,
)

print(weights.shape)  # torch.Size([4, 8, 16, 16])

For many production systems, nn.MultiheadAttention or torch.nn.functional.scaled_dot_product_attention is preferred over hand-written attention. The manual version remains useful because it reveals the tensor transformations.
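As a rough check that the hand-written computation matches the fused function, the sketch below compares the two on random per-head tensors. It assumes PyTorch 2.0 or later, where `torch.nn.functional.scaled_dot_product_attention` is available; the sizes are arbitrary.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, T, d_h = 2, 4, 6, 8  # illustrative sizes (assumed)

q = torch.randn(B, H, T, d_h)
k = torch.randn(B, H, T, d_h)
v = torch.randn(B, H, T, d_h)

# Manual causal attention, following the same steps as the layer above.
scores = q @ k.transpose(-2, -1) / math.sqrt(d_h)  # [B, H, T, T]
causal_mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(causal_mask, float("-inf"))
manual = torch.softmax(scores, dim=-1) @ v         # [B, H, T, d_h]

# Fused version: applies the 1/sqrt(d_h) scaling and the causal mask internally.
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

max_diff = (manual - fused).abs().max().item()
print(fused.shape)  # torch.Size([2, 4, 6, 8])
```

The two results should agree up to floating-point error. The fused function does not return attention weights, which is one reason to keep a manual version around for inspection.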

Multi-Head Cross-Attention

Multi-head attention can also be used for cross-attention. In cross-attention, the queries come from one sequence, while keys and values come from another.

Let

$$X \in \mathbb{R}^{B \times T_x \times D}$$

be the query sequence, and

$$C \in \mathbb{R}^{B \times T_c \times D}$$

be the context sequence. Then

$$Q = XW^Q, \quad K = CW^K, \quad V = CW^V.$$

The attention scores have shape

$$[B, H, T_x, T_c].$$

Cross-attention is used in encoder-decoder transformers, image captioning, speech recognition, retrieval-augmented models, and multimodal systems.

Using PyTorch:

mha = nn.MultiheadAttention(
    embed_dim=256,
    num_heads=8,
    batch_first=True,
)

decoder_states = torch.randn(2, 12, 256)
encoder_states = torch.randn(2, 20, 256)

out, weights = mha(
    query=decoder_states,
    key=encoder_states,
    value=encoder_states,
    need_weights=True,
    average_attn_weights=False,
)

print(out.shape)      # torch.Size([2, 12, 256])
print(weights.shape)  # torch.Size([2, 8, 12, 20])

Each decoder position attends to the encoder positions.
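For comparison with the built-in module, here is a minimal manual sketch of the same cross-attention shapes. The random weight matrices and the helper `split_heads` are assumptions for illustration; they are not the parameters of `nn.MultiheadAttention`.

```python
import math
import torch

torch.manual_seed(0)
B, T_x, T_c, D, H = 2, 12, 20, 256, 8  # sizes taken from the example above
d_h = D // H

X = torch.randn(B, T_x, D)  # query sequence (decoder states)
C = torch.randn(B, T_c, D)  # context sequence (encoder states)

# Illustrative random projections, each [D, D].
W_q = torch.randn(D, D) / math.sqrt(D)
W_k = torch.randn(D, D) / math.sqrt(D)
W_v = torch.randn(D, D) / math.sqrt(D)

def split_heads(t: torch.Tensor) -> torch.Tensor:
    # [B, N, D] -> [B, H, N, d_h]
    b, n, _ = t.shape
    return t.view(b, n, H, d_h).transpose(1, 2)

Q = split_heads(X @ W_q)  # [B, H, T_x, d_h]
K = split_heads(C @ W_k)  # [B, H, T_c, d_h]
V = split_heads(C @ W_v)  # [B, H, T_c, d_h]

scores = Q @ K.transpose(-2, -1) / math.sqrt(d_h)  # [B, H, T_x, T_c]
weights = torch.softmax(scores, dim=-1)
out = (weights @ V).transpose(1, 2).reshape(B, T_x, D)

print(scores.shape)  # torch.Size([2, 8, 12, 20])
print(out.shape)     # torch.Size([2, 12, 256])
```

The output length follows the query sequence, while the last axis of the weights follows the context sequence.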

Why Heads Use Smaller Dimensions

A common question is why each head uses $d_h = D/H$ instead of the full dimension $D$. The reason is computational cost.

If every head used the full dimension $D$, the cost would grow roughly linearly with the number of heads. By splitting the model dimension across heads, the total projected dimension stays $D$. This keeps the layer cost close to single-head attention while allowing multiple attention patterns.

For $D = 512$ and $H = 8$, the model computes eight heads of dimension $64$. After concatenation, the result returns to dimension $512$.

This design gives parallel attention diversity without multiplying the output width.
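The cost argument can be checked with rough arithmetic. The helper `attention_mults` below is a hypothetical counting function written for this illustration; it counts multiplications for one sequence and ignores softmax, the output projection, and constant factors.

```python
def attention_mults(T: int, D: int, H: int, head_dim: int) -> int:
    """Rough multiplication count for one sequence of length T."""
    proj = 3 * T * D * (H * head_dim)  # Q, K, V projections for all heads
    scores = H * T * T * head_dim      # Q_h K_h^T for every head
    mix = H * T * T * head_dim         # weights @ V_h for every head
    return proj + scores + mix

T, D, H = 128, 512, 8  # illustrative sizes (assumed)

split = attention_mults(T, D, H, head_dim=D // H)  # heads of size D/H
full = attention_mults(T, D, H, head_dim=D)        # heads of size D
single = attention_mults(T, D, 1, head_dim=D)      # one head of size D

print(split == single)  # True
print(full // split)    # 8
```

Under this simplified count, splitting the dimension costs the same as single-head attention, while full-width heads cost `H` times as much.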

Attention Heads as Representation Subspaces

Each head operates in its own learned subspace. The projections $W^Q_h$, $W^K_h$, and $W^V_h$ map the input into head-specific coordinates.

This means two heads can attend to the same token for different reasons. One head may compare syntactic features. Another may compare positional or semantic features. Another may specialize in copying information from a previous token.

However, attention heads should be interpreted carefully. A visible attention pattern is not always a complete explanation of model behavior. The value vectors, output projection, residual paths, and later layers all affect the final representation.
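In an implementation that uses one fused `nn.Linear` per projection, each head's matrix is a slice of the fused weight. The sketch below checks this for the query projection. The sizes and the chosen head index are assumptions, and the bias is disabled to keep the comparison simple.

```python
import torch
from torch import nn

torch.manual_seed(0)
B, T, D, H = 2, 5, 16, 4  # illustrative sizes (assumed)
d_h = D // H

q_proj = nn.Linear(D, D, bias=False)
x = torch.randn(B, T, D)

# Fused projection followed by the usual head split.
q = q_proj(x).view(B, T, H, d_h).transpose(1, 2)  # [B, H, T, d_h]

# nn.Linear stores its weight as [out_features, in_features], so head h
# owns output rows h * d_h up to (h + 1) * d_h.
h = 2
W_h = q_proj.weight[h * d_h:(h + 1) * d_h, :]     # [d_h, D]
q_h = x @ W_h.T                                   # [B, T, d_h]

same = torch.allclose(q[:, h], q_h, atol=1e-6)
print(same)  # True
```

The "head-specific coordinates" are therefore just a block of output features of the shared projection.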

Dropout in Attention

Transformer implementations often apply dropout to attention weights or to the output projection.

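Attention dropout randomly zeroes individual attention weights during training. With dropout probability $p$, the surviving weights are rescaled by $1/(1-p)$, so a row of weights no longer sums exactly to one. This discourages the model from depending too strongly on a small number of positions.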

A simple version:

class MultiHeadSelfAttentionWithDropout(MultiHeadSelfAttention):
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__(d_model, num_heads)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        causal: bool = False,
    ):
        B, T, D = x.shape

        q = self.split_heads(self.q_proj(x))
        k = self.split_heads(self.k_proj(x))
        v = self.split_heads(self.v_proj(x))

        scores = q @ k.transpose(-2, -1)
        scores = scores / math.sqrt(self.d_head)

        if attention_mask is not None:
            key_mask = attention_mask[:, None, None, :].bool()
            scores = scores.masked_fill(~key_mask, float("-inf"))

        if causal:
            causal_mask = torch.triu(
                torch.ones(T, T, device=x.device, dtype=torch.bool),
                diagonal=1,
            )
            scores = scores.masked_fill(causal_mask, float("-inf"))

        weights = F.softmax(scores, dim=-1)
        weights = self.dropout(weights)

        out = weights @ v
        out = self.merge_heads(out)
        out = self.out_proj(out)

        return out, weights

In practice, dropout is usually included as part of a full transformer block, along with residual connections, layer normalization, and a feedforward network.
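The effect of `nn.Dropout` on attention weights can be seen in isolation. In the sketch below, the dropout probability of 0.5 and the tensor sizes are assumptions chosen to make the rescaling easy to see.

```python
import torch
from torch import nn

torch.manual_seed(0)
dropout = nn.Dropout(p=0.5)  # illustrative probability (assumed)

weights = torch.softmax(torch.randn(1, 2, 4, 4), dim=-1)  # [B, H, T, T]

# Training mode: each weight is zeroed with probability p, and survivors
# are scaled by 1 / (1 - p) = 2, so rows generally stop summing to 1.
dropout.train()
dropped = dropout(weights)
row_sums = dropped.sum(dim=-1)

# Evaluation mode: dropout is the identity.
dropout.eval()
unchanged = dropout(weights)

print(torch.equal(unchanged, weights))  # True
```

Because the weights returned during training have already passed through dropout, they are best inspected in evaluation mode.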

Numerical Stability

Attention uses softmax, so numerical stability matters. Very large positive scores can dominate the distribution. Very large negative scores can underflow to zero. PyTorch’s softmax implementation is stable, but masking must still be handled carefully.

The usual masking pattern is:

scores = scores.masked_fill(mask, float("-inf"))
weights = F.softmax(scores, dim=-1)

This works when every row has at least one unmasked position. If an entire row is masked, softmax receives all -inf values and can produce NaN. This can happen with incorrect padding masks.

A robust implementation ensures that every query has at least one valid key, or handles fully masked rows explicitly.
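The failure mode is easy to reproduce. The sketch below masks one row completely and then applies one possible workaround, zeroing the weights of fully masked rows after softmax. That workaround is an assumption for illustration, not the only option, and names such as `fully_masked` are hypothetical.

```python
import torch

scores = torch.zeros(1, 1, 2, 3)  # [B, H, T_q, T_k], illustrative sizes

# True marks positions to hide. The second query row is fully masked.
mask = torch.tensor([[False, False, True],
                     [True,  True,  True]])

masked = scores.masked_fill(mask, float("-inf"))
weights = torch.softmax(masked, dim=-1)
print(torch.isnan(weights[0, 0, 1]).all().item())  # True

# Possible workaround: overwrite fully masked rows with zeros.
fully_masked = mask.all(dim=-1, keepdim=True)      # [T_q, 1]
safe = weights.masked_fill(fully_masked, 0.0)
print(torch.isnan(safe).any().item())              # False
```

A zeroed row contributes a zero vector to the head output. Whether that is acceptable depends on how the padded query positions are used downstream.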

Parameter Count

For model dimension $D$, multi-head attention usually has four learned linear projections:

$$W^Q,\quad W^K,\quad W^V,\quad W^O.$$

Each is roughly $D \times D$, ignoring bias. Therefore, the parameter count is approximately

$$4D^2.$$

The number of heads changes how the dimension is partitioned. It usually does not change the total parameter count if $D$ is fixed.

For $D = 768$, the attention projection parameters are approximately:

$$4 \times 768^2 = 2{,}359{,}296.$$

Bias terms add only about $4D$ parameters ($3{,}072$ for $D = 768$), which is small by comparison.
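The count can be checked against `nn.MultiheadAttention` with its default settings, which include bias terms on all four projections. The head counts tried below are arbitrary choices that divide 768.

```python
from torch import nn

D = 768
counts = {}
for H in (1, 8, 12):  # illustrative head counts (assumed)
    mha = nn.MultiheadAttention(embed_dim=D, num_heads=H)
    counts[H] = sum(p.numel() for p in mha.parameters())

expected = 4 * D * D + 4 * D  # weights plus biases
print(expected)               # 2362368
print(counts)
```

Every head count gives the same total, since the heads only partition the same $D \times D$ projections.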

Summary

Multi-head attention extends self-attention by computing several attention patterns in parallel. Each head uses its own projections and operates in a smaller subspace. The head outputs are concatenated and mixed through an output projection.

The input and output shapes usually remain [B, T, D]. Internally, the computation uses [B, H, T, d_head], where $H$ is the number of heads and $d_{\text{head}} = D/H$.

Multi-head attention is the central operation inside transformer layers. It gives the model multiple ways to retrieve context, while keeping the total model dimension fixed.