# Multi-Head Attention

Multi-head attention runs several attention operations in parallel. Each head has its own query, key, and value projections. The outputs of the heads are concatenated and projected back to the model dimension.

A single attention head computes one pattern of interaction. Multiple heads allow the model to learn several interaction patterns at the same time.

### Motivation

A sentence contains many kinds of relationships. One token may need its subject. Another may need its object. Another may need a nearby modifier. Another may need a distant entity introduced earlier.

A single attention distribution can focus on multiple positions, but it still produces one weighted mixture of values. This can be too restrictive. Multi-head attention gives the model several separate attention distributions.

One head may learn local syntax. Another may track long-range dependencies. Another may attend to punctuation or document boundaries. Another may specialize in semantic relationships.

The model does not receive these roles explicitly. It learns them from the training objective.

### Single-Head Attention

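For one attention head with input $X$ of shape $[B, T, D]$ (batch size, sequence length, model dimension), we compute:

$$
Q = XW_Q,\qquad K = XW_K,\qquad V = XW_V,
$$

where $W_Q, W_K \in \mathbb{R}^{D \times d_k}$ and $W_V \in \mathbb{R}^{D \times d_v}$ are learned projection matrices. Then: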

$$
\operatorname{head} =
\operatorname{softmax}
\left(
\frac{QK^\top}{\sqrt{d_k}}
\right)V.
$$

The head output has shape:

$$
[B, T, d_v].
$$

A single head therefore gives each token one contextualized vector.
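A minimal single-head sketch in PyTorch. It assumes random matrices as stand-ins for the learned projections and picks $d_k = d_v = 16$ purely for illustration:

```python
import math
import torch

torch.manual_seed(0)

B, T, D = 2, 5, 32   # batch size, sequence length, model dimension
d_k, d_v = 16, 16    # illustrative head dimensions (assumed)

X = torch.randn(B, T, D)

# Random stand-ins for the learned projections W_Q, W_K, W_V.
W_Q = torch.randn(D, d_k) / math.sqrt(D)
W_K = torch.randn(D, d_k) / math.sqrt(D)
W_V = torch.randn(D, d_v) / math.sqrt(D)

Q, K, V = X @ W_Q, X @ W_K, X @ W_V                 # [B, T, d_k], [B, T, d_k], [B, T, d_v]

scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)   # [B, T, T]
weights = torch.softmax(scores, dim=-1)             # each row is a distribution over keys
head = weights @ V                                  # [B, T, d_v]

print(head.shape)  # torch.Size([2, 5, 16])
```

Each row of `weights` sums to 1, so every output vector is one weighted mixture of the value vectors.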

### Multiple Projection Sets

In multi-head attention, we use $H$ heads. Each head has separate learned projections:

$$
W_Q^{(h)},\quad W_K^{(h)},\quad W_V^{(h)}
$$

for

$$
h = 1,\ldots,H.
$$

Each head computes:

$$
\operatorname{head}_h =
\operatorname{Attention}
\left(
XW_Q^{(h)},
XW_K^{(h)},
XW_V^{(h)}
\right).
$$

The heads are then concatenated:

$$
\operatorname{Concat}
(
\operatorname{head}_1,\ldots,\operatorname{head}_H
).
$$

Finally, a learned output projection mixes the head outputs:

$$
\operatorname{MHA}(X) =
\operatorname{Concat}
(
\operatorname{head}_1,\ldots,\operatorname{head}_H
)
W_O.
$$

genui{"math_block_widget_always_prefetch_v2":{"content":"\\operatorname{MHA}(X)=\\operatorname{Concat}(\\operatorname{head}_1,\\ldots,\\operatorname{head}_H)W_O"}}

This is the standard multi-head attention form used in transformers.
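The formula can be transcribed almost literally as a loop over heads. This sketch uses random matrices as stand-ins for learned weights and a Python loop for readability; practical implementations batch all heads into one tensor instead.

```python
import math
import torch

torch.manual_seed(0)

B, T, D, H = 2, 6, 32, 4
d_h = D // H  # per-head dimension

def attention(Q, K, V):
    """Single-head scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V."""
    scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))
    return torch.softmax(scores, dim=-1) @ V

X = torch.randn(B, T, D)

# One independent projection triple per head (random stand-ins for W_Q^(h), W_K^(h), W_V^(h)).
W_Q = [torch.randn(D, d_h) / math.sqrt(D) for _ in range(H)]
W_K = [torch.randn(D, d_h) / math.sqrt(D) for _ in range(H)]
W_V = [torch.randn(D, d_h) / math.sqrt(D) for _ in range(H)]
W_O = torch.randn(H * d_h, D) / math.sqrt(H * d_h)

heads = [attention(X @ W_Q[h], X @ W_K[h], X @ W_V[h]) for h in range(H)]  # each [B, T, d_h]
concat = torch.cat(heads, dim=-1)   # Concat(head_1, ..., head_H): [B, T, H * d_h]
mha_out = concat @ W_O              # MHA(X): [B, T, D]

print(mha_out.shape)  # torch.Size([2, 6, 32])
```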

### Shape Convention

Let the model dimension be $D$, the number of heads be $H$, and the per-head dimension be

$$
d_h = \frac{D}{H}.
$$

For example, if $D=768$ and $H=12$, then

$$
d_h = 64.
$$

The input has shape:

```python
[B, T, D]
```

After projection and reshaping, queries, keys, and values usually have shape:

```python
[B, H, T, d_h]
```

The attention scores have shape:

```python
[B, H, T, T]
```

The attention output per head has shape:

```python
[B, H, T, d_h]
```

After transposing and concatenating heads, the output returns to:

```python
[B, T, D]
```
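The sketch below traces these shapes for $B=4$, $T=10$, $D=768$, $H=12$. The projections are freshly initialized `nn.Linear` layers, so only the shapes, not the values, are meaningful here.

```python
import math
import torch
import torch.nn as nn

B, T, D, H = 4, 10, 768, 12
d_h = D // H                                     # 64

x = torch.randn(B, T, D)                         # input: [B, T, D]
q_proj, k_proj, v_proj = (nn.Linear(D, D) for _ in range(3))

def split_heads(t):
    # [B, T, D] -> [B, T, H, d_h] -> [B, H, T, d_h]
    return t.view(B, T, H, d_h).transpose(1, 2)

Q, K, V = split_heads(q_proj(x)), split_heads(k_proj(x)), split_heads(v_proj(x))
scores = Q @ K.transpose(-2, -1) / math.sqrt(d_h)   # [B, H, T, T]
out = torch.softmax(scores, dim=-1) @ V             # [B, H, T, d_h]
merged = out.transpose(1, 2).reshape(B, T, D)       # heads concatenated: [B, T, D]

for name, t in [("Q", Q), ("scores", scores), ("out", out), ("merged", merged)]:
    print(name, tuple(t.shape))
```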

### Why Split the Dimension

A common design keeps the total model dimension fixed. Instead of giving each head dimension $D$, the model splits $D$ across heads.

With $H$ heads, each head uses dimension $d_h = D/H$.

This keeps the total computation roughly similar to single-head attention with dimension $D$. The model gains multiple attention patterns without multiplying the hidden width by $H$.

The split may also encourage specialization: each head operates in its own $d_h$-dimensional subspace, and the output projection $W_O$ then recombines these subspaces.
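One way to check the cost claim: the parameter count of PyTorch's `nn.MultiheadAttention` does not depend on `num_heads` (it is $4D^2 + 4D$ with the default biases), and the multiply-accumulates for $QK^\top$ summed over heads are $H \cdot T^2 \cdot d_h = T^2 D$ for any $H$ that divides $D$. The helper names below are hypothetical.

```python
import torch.nn as nn

def num_params(module):
    """Total number of learnable scalars in a module (hypothetical helper)."""
    return sum(p.numel() for p in module.parameters())

def score_macs(T, D, H):
    """Multiply-accumulates for Q K^T across all heads (hypothetical helper)."""
    d_h = D // H
    return H * T * T * d_h

D, T = 768, 128
heads_to_try = (1, 12, 16)

params = {H: num_params(nn.MultiheadAttention(D, H)) for H in heads_to_try}
macs = {H: score_macs(T, D, H) for H in heads_to_try}

print(params)  # same count for every H
print(macs)    # same count for every H, equal to T * T * D
```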

### PyTorch Implementation

A minimal multi-head self-attention module is:

```python
import math
import torch
import torch.nn as nn

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()

        if embed_dim % num_heads != 0:
            raise ValueError("embed_dim must be divisible by num_heads")

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, mask=None):
        """
        x:    [B, T, D]
        mask: broadcastable to [B, H, T, T]; nonzero/True = may attend, 0/False = masked
        """

        B, T, D = x.shape

        Q = self.q_proj(x)
        K = self.k_proj(x)
        V = self.v_proj(x)

        Q = Q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        scores = Q @ K.transpose(-2, -1)
        scores = scores / math.sqrt(self.head_dim)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))

        weights = torch.softmax(scores, dim=-1)

        out = weights @ V
        out = out.transpose(1, 2).contiguous()
        out = out.view(B, T, D)

        out = self.out_proj(out)

        return out, weights
```

The important reshaping step, written with `H = num_heads` and `d_h = head_dim`, is:

```python
Q = Q.view(B, T, H, d_h).transpose(1, 2)
```

It changes the layout from

```python
[B, T, D]
```

to

```python
[B, H, T, d_h]
```

so each head can compute attention independently.
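A quick check of what the reshape does: head $h$ receives the $h$-th contiguous block of $d_h$ channels from the projected tensor. A single `nn.Linear(D, D)` projection is therefore equivalent to $H$ separate $D \times d_h$ projections placed side by side. The sketch uses a random tensor in place of a real projection output.

```python
import torch

B, T, H, d_h = 2, 5, 4, 8
D = H * d_h

Q = torch.randn(B, T, D)                         # projected queries: [B, T, D]
Q_heads = Q.view(B, T, H, d_h).transpose(1, 2)   # [B, H, T, d_h]

# Head h sees exactly channels h*d_h ... (h+1)*d_h - 1 of the projection.
for h in range(H):
    assert torch.equal(Q_heads[:, h], Q[..., h * d_h:(h + 1) * d_h])

# The inverse reshape (used after attention) recovers the original layout.
# transpose() makes the tensor non-contiguous, so .contiguous() is needed before .view().
Q_back = Q_heads.transpose(1, 2).contiguous().view(B, T, D)
assert torch.equal(Q_back, Q)
```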

### Using PyTorch Built-In MultiheadAttention

PyTorch provides `nn.MultiheadAttention`.

A simple use with `batch_first=True` is:

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(
    embed_dim=768,
    num_heads=12,
    batch_first=True,
)

x = torch.randn(4, 128, 768)

out, weights = mha(x, x, x)
print(out.shape)      # torch.Size([4, 128, 768])
print(weights.shape)  # torch.Size([4, 128, 128])
```

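By default, the returned weights are averaged over heads (`average_attn_weights=True`), which is why their shape is `[4, 128, 128]` rather than `[4, 12, 128, 128]`.

The arguments follow the query-key-value pattern: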

```python
out, weights = mha(query, key, value)
```

For self-attention, all three are the same tensor. For cross-attention, the query comes from one sequence, while key and value come from another.

### Multi-Head Cross-Attention

Multi-head attention also applies to cross-attention.

Suppose:

```python
query_states:  [B, T_y, D]
source_states: [B, T_x, D]
```

Then:

```python
out, weights = mha(
    query_states,
    source_states,
    source_states,
)
```

The query sequence receives information from the source sequence.

In an encoder-decoder transformer, this operation appears inside the decoder. The decoder hidden states are queries. The encoder outputs are keys and values.
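A sketch of cross-attention with `nn.MultiheadAttention`, assuming illustrative sizes in which the query and source lengths differ. Passing `average_attn_weights=False` (available in recent PyTorch versions) returns one attention map per head instead of the head average.

```python
import torch
import torch.nn as nn

B, T_y, T_x, D, H = 2, 7, 11, 64, 4

mha = nn.MultiheadAttention(embed_dim=D, num_heads=H, batch_first=True)

query_states = torch.randn(B, T_y, D)    # e.g. decoder hidden states
source_states = torch.randn(B, T_x, D)   # e.g. encoder outputs

out, weights = mha(
    query_states, source_states, source_states,
    average_attn_weights=False,          # keep one [T_y, T_x] map per head
)
print(out.shape)      # torch.Size([2, 7, 64])
print(weights.shape)  # torch.Size([2, 4, 7, 11])
```

The output length follows the queries ($T_y$), while each attention row is a distribution over the $T_x$ source positions.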

### Attention Head Specialization

Different heads can learn different patterns.

Some heads may attend locally. Some may attend to syntactic parents. Some may attend to delimiter tokens. Some may copy information from previous positions. Some may focus on task-specific cues.

However, head specialization should be interpreted carefully. Attention patterns are useful diagnostics, but a head’s visible attention weights do not fully explain model behavior. The output projection mixes heads, residual paths carry information around attention blocks, and later layers can transform or override earlier signals.

### Redundant Heads

Not every head is equally important. In many trained models, some heads can be pruned with little loss in performance.

This suggests that multi-head attention provides capacity and optimization flexibility, but some learned heads may be redundant. Redundancy can also improve robustness and make training easier.

Head pruning and head importance analysis are useful tools for compression and interpretability.
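A common mechanism behind head pruning is a per-head mask that scales each head's output before $W_O$. The sketch below assumes a fused Q/K/V projection and uses the relative change in the layer output as a crude importance proxy; in practice, importance is usually judged by the effect on a task loss or metric. The function name `mha_with_head_mask` is hypothetical.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)

B, T, D, H = 2, 6, 32, 4
d_h = D // H

qkv_proj = nn.Linear(D, 3 * D)   # fused Q/K/V projection (layout assumed for this sketch)
out_proj = nn.Linear(D, D)       # W_O

def mha_with_head_mask(x, head_mask):
    """Self-attention where head_mask[h] scales head h's output before W_O.
    A value of 1 keeps the head; 0 removes (prunes) it."""
    b, t, _ = x.shape
    q, k, v = qkv_proj(x).chunk(3, dim=-1)                                # each [B, T, D]
    q, k, v = (z.reshape(b, t, H, d_h).transpose(1, 2) for z in (q, k, v))  # [B, H, T, d_h]
    w = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d_h), dim=-1)
    heads = (w @ v) * head_mask.view(1, H, 1, 1)                          # zero out masked heads
    return out_proj(heads.transpose(1, 2).reshape(b, t, D))

x = torch.randn(B, T, D)
with torch.no_grad():
    full = mha_with_head_mask(x, torch.ones(H))
    for h in range(H):
        mask = torch.ones(H)
        mask[h] = 0.0
        change = (full - mha_with_head_mask(x, mask)).norm() / full.norm()
        print(f"head {h}: relative output change when removed = {change.item():.3f}")
```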

### Number of Heads

The number of heads is a design choice. Common transformer configurations use values such as 8, 12, 16, 32, or more.

The choice interacts with the model dimension. Since

$$
d_h = D/H,
$$

increasing the number of heads decreases the per-head dimension if $D$ is fixed.

Too few heads may limit the diversity of attention patterns. Too many heads may make each head too narrow. In practice, per-head dimensions around 64 or 128 are common in many transformer families.
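A small helper makes the trade-off concrete. The $(D, H)$ pairs below are illustrative examples, not a survey of specific models; with $D = 768$ fixed, raising $H$ from 6 to 24 shrinks $d_h$ from 128 to 32.

```python
def head_dim(D, H):
    """Per-head dimension d_h = D / H; most implementations require exact divisibility."""
    if D % H != 0:
        raise ValueError(f"D={D} is not divisible by H={H}")
    return D // H

# Illustrative configurations.
for D, H in [(512, 8), (768, 12), (1024, 16), (4096, 32)]:
    print(f"D={D:5d}  H={H:3d}  d_h={head_dim(D, H)}")

# Fixed D, varying H: more heads means narrower heads.
print([head_dim(768, H) for H in (6, 12, 24)])  # [128, 64, 32]
```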

### Masks in Multi-Head Attention

Masks usually broadcast across heads.

A causal mask may have shape:

```python
[T, T]
```

or

```python
[1, 1, T, T]
```

A padding mask may have shape:

```python
[B, 1, 1, T]
```

The final score tensor has shape:

```python
[B, H, T_q, T_k]
```

The mask must be broadcastable to this shape.

For causal self-attention:

```python
T = x.size(1)

causal_mask = torch.tril(
    torch.ones(T, T, device=x.device)
).bool()

causal_mask = causal_mask[None, None, :, :]
```

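This mask allows every head to use the same causal structure. It uses `True` for positions that may be attended to, which matches the `mask == 0` test in the module above and the boolean `attn_mask` convention of `F.scaled_dot_product_attention`. Note that `nn.MultiheadAttention` uses the opposite convention for boolean masks: there, `True` marks positions that are not allowed to be attended to.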
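A sketch of how a causal mask and a padding mask combine by broadcasting. It assumes the convention of the causal mask above (`True` means the key may be attended to) and hypothetical sequence lengths of 5 and 3.

```python
import torch

B, H, T = 2, 4, 5

# Causal mask: True where query i may attend to key j (j <= i). Shape [1, 1, T, T].
causal = torch.tril(torch.ones(T, T, dtype=torch.bool))[None, None, :, :]

# Padding mask: True for real tokens, False for padding. Shape [B, 1, 1, T].
lengths = torch.tensor([5, 3])   # second sequence ends with 2 padding tokens
padding = (torch.arange(T)[None, :] < lengths[:, None])[:, None, None, :]

mask = causal & padding          # broadcasts to [B, 1, T, T]

scores = torch.randn(B, H, T, T)
scores = scores.masked_fill(~mask, float("-inf"))   # the head dimension broadcasts
weights = torch.softmax(scores, dim=-1)

print(mask.shape)     # torch.Size([2, 1, 5, 5])
print(weights.shape)  # torch.Size([2, 4, 5, 5])
```

Because every query can see at least key 0, no row is fully masked; a fully masked row would produce NaN after the softmax.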

### Built-In Scaled Dot-Product Attention

PyTorch 2.0 and later provide `torch.nn.functional.scaled_dot_product_attention`:

```python
import torch.nn.functional as F

out = F.scaled_dot_product_attention(
    Q,
    K,
    V,
    attn_mask=None,
    is_causal=False,
)
```

Here `Q`, `K`, and `V` have shape:

```python
[B, H, T, d_h]
```

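This function may dispatch to fused kernels (such as FlashAttention or memory-efficient attention) depending on hardware, dtype, mask type, and PyTorch version.

When a fused kernel is selected, it avoids materializing the full $[B, H, T, T]$ score tensor, so it is usually faster and uses less memory than a handwritten attention implementation.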
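A sketch that checks the built-in against the explicit computation for causal self-attention. It assumes PyTorch 2.0 or later and random inputs that are already split into heads.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, T, d_h = 2, 4, 8, 16
Q, K, V = (torch.randn(B, H, T, d_h) for _ in range(3))

# Reference: explicit causal attention that materializes the [B, H, T, T] scores.
scores = Q @ K.transpose(-2, -1) / math.sqrt(d_h)
causal = torch.tril(torch.ones(T, T, dtype=torch.bool))
scores = scores.masked_fill(~causal, float("-inf"))
manual = torch.softmax(scores, dim=-1) @ V

# Built-in: is_causal=True applies the same lower-triangular mask internally,
# and the default scale is 1 / sqrt(d_h).
fused = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

print(torch.allclose(manual, fused, atol=1e-5))
```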

### Multi-Head Attention and Representation Mixing

Multi-head attention has two kinds of mixing.

First, each head mixes tokens through attention weights. This is sequence mixing.

Second, the output projection $W_O$ mixes the concatenated head channels. This is feature mixing.

A transformer block then adds a feedforward network, which further mixes features independently at each position.

Thus a transformer alternates between:

| Operation | Main role |
|---|---|
| Multi-head attention | Mix information across positions |
| Feedforward network | Transform features within each position |

This alternation gives transformers much of their expressive power.
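A minimal sketch of one transformer block built from these two operations. It assumes pre-norm residual connections, a GELU feedforward network with hidden width $4D$, and no dropout; the original transformer instead applied layer normalization after each residual addition and used ReLU. The class name `TransformerBlock` is illustrative.

```python
import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
    """Pre-norm block: x + MHA(LN(x)), then x + FFN(LN(x)). Layout is an assumption."""

    def __init__(self, D, H, ffn_mult=4):
        super().__init__()
        self.ln1 = nn.LayerNorm(D)
        self.attn = nn.MultiheadAttention(D, H, batch_first=True)
        self.ln2 = nn.LayerNorm(D)
        self.ffn = nn.Sequential(             # applied independently at each position
            nn.Linear(D, ffn_mult * D),
            nn.GELU(),
            nn.Linear(ffn_mult * D, D),
        )

    def forward(self, x, attn_mask=None):
        h = self.ln1(x)
        attn_out, _ = self.attn(h, h, h, attn_mask=attn_mask, need_weights=False)
        x = x + attn_out                      # mixes information across positions
        x = x + self.ffn(self.ln2(x))         # transforms features within each position
        return x

block = TransformerBlock(D=64, H=4)
x = torch.randn(2, 10, 64)
print(block(x).shape)  # torch.Size([2, 10, 64])
```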

### Summary

Multi-head attention runs several attention heads in parallel. Each head has its own projections and produces its own attention distribution. The head outputs are concatenated and projected back to the model dimension.

This gives the model multiple learned routes for information flow. Some heads may capture local structure, some long-range dependencies, and some task-specific patterns.

Multi-head attention is the core attention module in transformer architectures.

