# Attention Mechanisms

Attention is a method for letting a model choose which parts of an input are most relevant when producing an output. It replaces the idea that all input positions should contribute equally. In a sequence model, attention allows one token to look at other tokens. In an image model, it allows one patch to look at other patches. In a multimodal model, it allows text tokens to look at image regions, audio frames, or retrieved documents.

The central operation is simple. Given a query, attention compares it with a set of keys, converts the comparisons into weights, and uses those weights to form a weighted average of values.

The three objects are:

| Object | Meaning |
|---|---|
| Query | What the current position is looking for |
| Key | What each candidate position offers for comparison |
| Value | The information retrieved from each candidate position |

For a single query $q$, keys $k_1,\ldots,k_n$, and values $v_1,\ldots,v_n$, attention computes scores

$$
s_i = q^\top k_i.
$$

The scores are converted into normalized weights by softmax:

$$
\alpha_i =
\frac{\exp(s_i)}
{\sum_{j=1}^{n}\exp(s_j)}.
$$

The output is the weighted sum

$$
z = \sum_{i=1}^{n} \alpha_i v_i.
$$

Each weight $\alpha_i$ is positive and the weights sum to one, so $z$ is a convex combination of the values, and $\alpha_i$ tells us how much the output uses value $v_i$. Large score, large weight. Small score, small weight.
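Here is a small worked example of the three formulas in PyTorch, using one query and three key/value pairs. The numbers are arbitrary and were chosen only to make the arithmetic easy to follow. `keys` and `values` store $k_i$ and $v_i$ as rows.

```python
import torch

# One query and three key/value pairs (d = 2, d_v = 2); values are arbitrary.
q = torch.tensor([1.0, 0.0])
keys = torch.tensor([[1.0, 0.0],
                     [0.0, 1.0],
                     [-1.0, 0.0]])
values = torch.tensor([[10.0, 0.0],
                       [0.0, 10.0],
                       [5.0, 5.0]])

s = keys @ q                      # s_i = q^T k_i  -> [1, 0, -1]
alpha = torch.softmax(s, dim=0)   # alpha_i = exp(s_i) / sum_j exp(s_j), approx [0.665, 0.245, 0.090]
z = alpha @ values                # z = sum_i alpha_i v_i

print(s)
print(alpha)
print(z)
```

The largest score, $s_1 = 1$, receives the largest weight, so $z$ ends up closest to $v_1$.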

### From Fixed Representations to Attention

Older sequence-to-sequence models often compressed the whole input sequence into one fixed-size vector. For short sequences this can work. For long sequences it creates a bottleneck. The decoder must recover all relevant information from one vector, even when different output positions need different parts of the input.

Attention removes this bottleneck. Instead of using one fixed representation, the model keeps a collection of hidden states and learns which ones to read from at each step.

For machine translation, when producing a target word, the decoder can attend to the source words most relevant to that word. For summarization, the model can attend to important sentences or phrases. For language modeling, each token can attend to earlier tokens that provide useful context.

The principle is general: attention is differentiable content-based retrieval.

### Dot-Product Attention

The simplest common form is dot-product attention. Suppose we have matrices

$$
Q \in \mathbb{R}^{m \times d},
\quad
K \in \mathbb{R}^{n \times d},
\quad
V \in \mathbb{R}^{n \times d_v}.
$$

Here $Q$ contains $m$ queries, $K$ contains $n$ keys, and $V$ contains $n$ values.

The score matrix is

$$
S = QK^\top.
$$

Its shape is

$$
S \in \mathbb{R}^{m \times n}.
$$

Each entry $S_{ij}$ measures how strongly query $i$ matches key $j$.

After applying softmax along the key dimension, we obtain the attention weight matrix

$$
A = \operatorname{softmax}(S).
$$

Then the attention output is

$$
Z = AV.
$$

The shape of the output is

$$
Z \in \mathbb{R}^{m \times d_v}.
$$

This is the core of attention. Scores are computed by similarity. Weights are computed by softmax. Values are combined by weighted averaging.
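Shape bugs are easiest to spot when every size is different. The following sketch runs the three steps with $m = 3$ queries, $n = 5$ keys and values, $d = 4$, and $d_v = 6$. These sizes are purely illustrative.

```python
import torch

torch.manual_seed(0)
m, n, d, d_v = 3, 5, 4, 6      # illustrative sizes

Q = torch.randn(m, d)          # m queries
K = torch.randn(n, d)          # n keys
V = torch.randn(n, d_v)        # n values

S = Q @ K.T                    # score matrix, shape [m, n]
A = torch.softmax(S, dim=-1)   # softmax along the key dimension, one row per query
Z = A @ V                      # output, shape [m, d_v]

print(S.shape, A.shape, Z.shape)  # torch.Size([3, 5]) torch.Size([3, 5]) torch.Size([3, 6])
print(A.sum(dim=-1))              # each row sums to 1 (up to rounding)
```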

### Scaled Dot-Product Attention

In transformers, the dot products are usually divided by $\sqrt{d}$, where $d$ is the key dimension:

$$
\operatorname{Attention}(Q,K,V) =
\operatorname{softmax}
\left(
\frac{QK^\top}{\sqrt{d}}
\right)V.
$$

The scaling factor matters because dot products grow in magnitude as the feature dimension grows. Suppose the entries of $q$ and $k$ are independent, with zero mean and unit variance. Then $q^\top k$ is a sum of $d$ such products and has variance $d$. Large scores push the softmax toward nearly one-hot probabilities. There its gradients are small, which can make training slow or unstable.

Dividing by $\sqrt{d}$ brings the variance back to $1$, so the scale of the scores no longer grows with $d$.
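The variance argument can be checked empirically. As above, this sketch assumes the entries of $q$ and $k$ are independent standard normal samples. The sample size and the dimensions tried are arbitrary.

```python
import math
import torch

torch.manual_seed(0)
n_samples = 4000
results = {}

for d in (16, 64, 1024):
    # Assumption: entries are independent with zero mean and unit variance.
    q = torch.randn(n_samples, d)
    k = torch.randn(n_samples, d)
    raw = (q * k).sum(dim=-1)        # one q^T k per sample
    scaled = raw / math.sqrt(d)      # the scaled score
    results[d] = (raw.var().item(), scaled.var().item())
    print(f"d={d:5d}  var(raw)={results[d][0]:8.1f}  var(scaled)={results[d][1]:.2f}")
```

For every $d$, the raw variance should come out close to $d$ and the scaled variance close to $1$.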

In PyTorch:

```python
import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V):
    d = Q.shape[-1]
    scores = Q @ K.transpose(-2, -1)
    scores = scores / math.sqrt(d)
    weights = F.softmax(scores, dim=-1)
    output = weights @ V
    return output, weights
```

This function works for batched inputs because `transpose(-2, -1)` swaps the last two axes, and matrix multiplication applies over the leading batch axes.

Example:

```python
B = 2      # batch size
T = 5      # sequence length
D = 8      # feature dimension

Q = torch.randn(B, T, D)
K = torch.randn(B, T, D)
V = torch.randn(B, T, D)

out, weights = scaled_dot_product_attention(Q, K, V)

print(out.shape)      # torch.Size([2, 5, 8])
print(weights.shape)  # torch.Size([2, 5, 5])
```

The output has one vector per query position. The weight tensor has shape `[B, T, T]`, meaning that each token in each batch item has a distribution over all token positions.

### Attention as Weighted Averaging

Attention can be understood as a learned weighted average. The values $V$ contain candidate information. The attention weights decide how much of each value to include.

Suppose a sequence has five tokens. For token 3, the attention weights may be

$$
[0.05,\ 0.10,\ 0.20,\ 0.55,\ 0.10].
$$

Then the output for token 3 is mostly influenced by token 4, with smaller contributions from the other tokens.
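Behind this reading is a single weighted sum. The weights below are the ones given above. The two-dimensional value vectors are hypothetical stand-ins, chosen only to make each contribution easy to see.

```python
import torch

# Attention weights of token 3 from the example above (they sum to 1).
weights = torch.tensor([0.05, 0.10, 0.20, 0.55, 0.10])

# Hypothetical 2-dimensional value vectors for tokens 1..5.
values = torch.tensor([[1.0, 0.0],
                       [0.0, 1.0],
                       [1.0, 1.0],
                       [4.0, 2.0],
                       [0.0, 0.0]])

contributions = weights[:, None] * values   # alpha_i * v_i for each token
out = contributions.sum(dim=0)              # same as weights @ values

print(contributions)
print(out)
```

Token 4 alone contributes $0.55 \cdot (4, 2) = (2.2, 1.1)$ to the output $(2.45, 1.40)$.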

This view is useful but incomplete. What the model learns are the linear projections that produce the queries, keys, and values. The attention weights are not fixed linguistic rules. They are computed from trainable representations.

For an input tensor

$$
X \in \mathbb{R}^{B \times T \times D},
$$

a transformer layer usually computes

$$
Q = XW_Q,\quad K = XW_K,\quad V = XW_V.
$$

The matrices $W_Q$, $W_K$, and $W_V$ are learned parameters. Thus the model learns what to search for, what to compare against, and what information to retrieve.
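In PyTorch, these projections are usually `nn.Linear` layers. The minimal sketch below uses a bias-free layer so that it matches $XW_Q$ exactly. `nn.Linear` stores its weight with shape `[out_features, in_features]` and computes `x @ weight.T`, so $W_Q$ corresponds to the transpose of the stored weight. The modules later in this section keep PyTorch's default bias, which adds a learned vector to each projection.

```python
import torch
from torch import nn

torch.manual_seed(0)
B, T, D = 2, 5, 8
X = torch.randn(B, T, D)

# bias=False so that the layer computes exactly X W_Q (no added bias vector).
q_proj = nn.Linear(D, D, bias=False)

# nn.Linear stores its weight as [out_features, in_features] and computes x @ weight.T,
# so the matrix W_Q in the text is the transpose of the stored weight.
W_Q = q_proj.weight.T

Q = X @ W_Q                  # applied to every position of every batch item
print(Q.shape)               # torch.Size([2, 5, 8])
print(torch.allclose(Q, q_proj(X), atol=1e-6))
```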

### Self-Attention

Self-attention is attention where the queries, keys, and values are all computed from the same input sequence.

Given

$$
X \in \mathbb{R}^{B \times T \times D},
$$

we define

$$
Q = XW_Q,\quad
K = XW_K,\quad
V = XW_V.
$$

Then attention is computed between positions in the same sequence.

Self-attention allows each token to build a representation using information from other tokens. For example, in the sentence

```text
The animal did not cross the street because it was tired.
```

the word `it` may need to attend to `animal` to resolve its meaning. In another sentence,

```text
The animal did not cross the street because it was flooded.
```

the word `it` may need to attend to `street`.

Self-attention gives the model a mechanism for such contextual dependency. The meaning of a token becomes conditional on the surrounding sequence.

A minimal self-attention module in PyTorch can be written as:

```python
import math
import torch
from torch import nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        # x: [B, T, D]
        Q = self.q_proj(x)
        K = self.k_proj(x)
        V = self.v_proj(x)

        d = Q.shape[-1]
        scores = Q @ K.transpose(-2, -1)
        scores = scores / math.sqrt(d)

        weights = F.softmax(scores, dim=-1)
        out = weights @ V
        out = self.out_proj(out)

        return out, weights
```

Example:

```python
x = torch.randn(4, 10, 32)

attn = SelfAttention(d_model=32)
out, weights = attn(x)

print(out.shape)      # torch.Size([4, 10, 32])
print(weights.shape)  # torch.Size([4, 10, 10])
```

The model maps a sequence of 10 input vectors to a sequence of 10 output vectors. Each output vector is a context-dependent representation of the corresponding input position.

### Cross-Attention

Cross-attention uses queries from one source and keys and values from another source.

Suppose

$$
X \in \mathbb{R}^{B \times T_x \times D}
$$

is a target-side representation, and

$$
Y \in \mathbb{R}^{B \times T_y \times D}
$$

is a source-side representation. Cross-attention computes

$$
Q = XW_Q,
\quad
K = YW_K,
\quad
V = YW_V.
$$

The target sequence asks questions. The source sequence provides keys and values.

This is common in encoder-decoder transformers. The decoder uses cross-attention to read from the encoder output. It is also common in multimodal systems. A text decoder may use cross-attention to read from image features, audio features, retrieved documents, or tool outputs.

A minimal PyTorch implementation:

```python
class CrossAttention(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x, context):
        # x:       [B, T_x, D]
        # context: [B, T_y, D]
        Q = self.q_proj(x)
        K = self.k_proj(context)
        V = self.v_proj(context)

        d = Q.shape[-1]
        scores = Q @ K.transpose(-2, -1)
        scores = scores / math.sqrt(d)

        weights = F.softmax(scores, dim=-1)
        out = weights @ V
        out = self.out_proj(out)

        return out, weights
```

The attention weights have shape

```python
[B, T_x, T_y]
```

Each target position has a distribution over source positions.
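To see the `[B, T_x, T_y]` shape directly, run cross-attention with different target and source lengths. So that the sketch stands on its own, it repeats the computation inside `CrossAttention.forward` in functional form. Random matrices stand in for the learned projections, and the output projection is omitted. The sizes are arbitrary.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T_x, T_y, D = 2, 3, 7, 16       # illustrative sizes: T_x != T_y on purpose

x = torch.randn(B, T_x, D)         # target side: supplies the queries
context = torch.randn(B, T_y, D)   # source side: supplies the keys and values

# Random stand-ins for the learned projections W_Q, W_K, W_V (illustration only).
W_Q, W_K, W_V = (torch.randn(D, D) / math.sqrt(D) for _ in range(3))

Q = x @ W_Q                        # [B, T_x, D]
K = context @ W_K                  # [B, T_y, D]
V = context @ W_V                  # [B, T_y, D]

weights = F.softmax(Q @ K.transpose(-2, -1) / math.sqrt(D), dim=-1)  # [B, T_x, T_y]
out = weights @ V                                                     # [B, T_x, D]

print(weights.shape)  # torch.Size([2, 3, 7])
print(out.shape)      # torch.Size([2, 3, 16])
```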

### Masked Attention

Sometimes a model should not attend to every position. A mask prevents attention from using forbidden positions.

There are two common masks.

A padding mask hides padding tokens. In a batch, sequences often have different lengths. Shorter sequences are padded to match the longest one. The model should ignore those padding positions.

A causal mask hides future tokens. In autoregressive language modeling, token $t$ may attend only to tokens $1,\ldots,t$. It must not attend to tokens after $t$, because those tokens are unknown during generation.

For a sequence of length $T$, a causal mask has the form:

$$
M_{ij} =
\begin{cases}
0, & j \leq i, \\
-\infty, & j > i.
\end{cases}
$$

The mask is added to the attention scores before softmax. Positions with score $-\infty$ receive zero probability after softmax.

In PyTorch:

```python
def causal_attention(Q, K, V):
    B, T, D = Q.shape

    scores = Q @ K.transpose(-2, -1)
    scores = scores / math.sqrt(D)

    mask = torch.triu(
        torch.ones(T, T, device=Q.device, dtype=torch.bool),
        diagonal=1,
    )

    scores = scores.masked_fill(mask, float("-inf"))

    weights = F.softmax(scores, dim=-1)
    out = weights @ V

    return out, weights
```

Example:

```python
Q = torch.randn(2, 4, 8)
K = torch.randn(2, 4, 8)
V = torch.randn(2, 4, 8)

out, weights = causal_attention(Q, K, V)

print(weights[0])
```

After softmax, the entries above the diagonal are zero. Indexing from 0, as PyTorch does: position 0 can attend only to position 0, position 1 can attend to positions 0 and 1, and position 2 can attend to positions 0, 1, and 2.
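`causal_attention` uses `masked_fill`, whereas the equation for $M$ adds a matrix to the scores. The sketch below builds $M$ exactly as defined above for $T = 4$, adds it to the scores, and confirms that the two approaches give identical weights. The random queries and keys are for illustration only.

```python
import math
import torch
import torch.nn.functional as F

T = 4
mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)  # True where j > i
M = torch.zeros(T, T).masked_fill(mask, float("-inf"))            # M[i, j] = 0 or -inf

torch.manual_seed(0)
Q, K = torch.randn(T, 8), torch.randn(T, 8)   # illustrative queries and keys
scores = Q @ K.T / math.sqrt(8)

additive = F.softmax(scores + M, dim=-1)                            # add M, as in the equation
filled = F.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)  # as in causal_attention

print(M)
print(additive)                       # upper triangle is exactly 0
print(torch.equal(additive, filled))  # the two approaches give identical weights
```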

Causal masking is the core constraint behind autoregressive generation.
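The function above applies only the causal mask. A padding mask follows the same pattern. In the sketch below, the helper `padded_attention` and its `lengths` argument are hypothetical names. The sketch assumes right-padding (real tokens first) and at least one real token per sequence, because a row in which every score is $-\infty$ makes softmax return NaN.

```python
import math
import torch
import torch.nn.functional as F

def padded_attention(Q, K, V, lengths):
    # Hypothetical helper. Q: [B, T_q, D]; K, V: [B, T_k, D]
    # lengths: [B], number of real (non-padding) key positions per item.
    # Assumes right-padding and lengths >= 1 (an all -inf row would give NaN).
    T_k = K.shape[1]
    scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.shape[-1])      # [B, T_q, T_k]
    positions = torch.arange(T_k, device=K.device)                 # [T_k]
    pad = positions[None, :] >= lengths[:, None]                   # [B, T_k], True at padding
    scores = scores.masked_fill(pad[:, None, :], float("-inf"))    # same mask for every query
    weights = F.softmax(scores, dim=-1)
    return weights @ V, weights

torch.manual_seed(0)
Q = torch.randn(2, 4, 8)
K = torch.randn(2, 4, 8)
V = torch.randn(2, 4, 8)
lengths = torch.tensor([4, 2])   # the second item has two padding positions

out, weights = padded_attention(Q, K, V, lengths)
print(weights[1])                # columns 2 and 3 are zero
```

This sketch masks only key positions. Outputs at padding query positions are still computed; they are usually ignored downstream.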

### Attention and Permutation Sensitivity

Self-attention alone does not know the order of tokens. If we permute the input tokens, the queries, keys, and values are permuted in the same way, and the output is permuted accordingly. In other words, self-attention is permutation-equivariant. The operation itself has no built-in notion of first, second, or third position.
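Here is a small numerical check of this property with no positional information. Random matrices stand in for the learned projections; this is an assumption made purely for illustration.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, D = 5, 8
X = torch.randn(T, D)
# Random stand-ins for learned projections (illustration only).
W_Q, W_K, W_V = (torch.randn(D, D) / math.sqrt(D) for _ in range(3))

def self_attention(x):
    # x: [T, D]; single-head, unmasked, no positional information
    Q, K, V = x @ W_Q, x @ W_K, x @ W_V
    return F.softmax(Q @ K.T / math.sqrt(D), dim=-1) @ V

perm = torch.randperm(T)
out = self_attention(X)
out_perm = self_attention(X[perm])

# Permuting the input permutes the output in exactly the same way.
print(torch.allclose(out_perm, out[perm], atol=1e-6))
```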

This is why transformers need positional information. Positional encodings or positional embeddings are added to token representations so the model can distinguish order.

Without position information, self-attention cannot tell these two sequences apart except by reordering. Each token, including `bites`, receives exactly the same output vector in both:

```text
dog bites man
man bites dog
```

Language depends on order. Therefore, attention must be combined with positional information to model sequences properly.
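The sketch below makes this concrete. It uses hypothetical random embeddings for the three words and for positions 0 to 2, and random matrices in place of learned projections. Without positions, `bites` receives the same output in both sentences. Once the position embeddings are added, the outputs differ.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
D = 8
# Hypothetical embeddings for the three words and for positions 0, 1, 2.
word = {w: torch.randn(D) for w in ("dog", "bites", "man")}
pos = torch.randn(3, D)
# Random stand-ins for learned projections (illustration only).
W_Q, W_K, W_V = (torch.randn(D, D) / math.sqrt(D) for _ in range(3))

def self_attention(x):
    # x: [T, D]; single-head, unmasked self-attention
    Q, K, V = x @ W_Q, x @ W_K, x @ W_V
    return F.softmax(Q @ K.T / math.sqrt(D), dim=-1) @ V

a = torch.stack([word["dog"], word["bites"], word["man"]])   # dog bites man
b = torch.stack([word["man"], word["bites"], word["dog"]])   # man bites dog

# Without positions, `bites` (position 1 in both) gets the same output vector.
same = torch.allclose(self_attention(a)[1], self_attention(b)[1], atol=1e-6)

# With position embeddings added, the two sentences are no longer interchangeable.
differs = not torch.allclose(self_attention(a + pos)[1], self_attention(b + pos)[1], atol=1e-6)

print(same, differs)
```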

This topic is developed in Section 15.4.

### Attention Complexity

For a sequence of length $T$, self-attention computes a $T \times T$ score matrix. The time and memory complexity are therefore quadratic in sequence length:

$$
O(T^2).
$$

More precisely, for batch size $B$, sequence length $T$, and hidden dimension $D$, the attention score computation costs approximately

$$
O(BT^2D).
$$

For each attention head, the attention weights require memory proportional to

$$
O(BT^2).
$$

This quadratic scaling is acceptable for short and medium sequences, but expensive for very long sequences. A context length of 1,000 produces 1,000,000 pairwise scores per attention head. A context length of 100,000 produces 10,000,000,000 pairwise scores per head.
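Those counts translate directly into bytes. The sketch below assumes float32 (4 bytes per element) and a single dense score matrix per head and batch item. `score_matrix_bytes` is a hypothetical helper written only for this calculation.

```python
def score_matrix_bytes(T, batch=1, heads=1, bytes_per_element=4):
    # Hypothetical helper: bytes for dense [T, T] score matrices,
    # one per batch item and head. bytes_per_element=4 assumes float32.
    return batch * heads * T * T * bytes_per_element

for T in (1_000, 100_000):
    gb = score_matrix_bytes(T) / 1e9
    print(f"T={T:>7,}: {T * T:>14,} scores, {gb:.3f} GB")
```

In float32, that is about 4 MB at $T = 1{,}000$ and 40 GB at $T = 100{,}000$, for a single head of a single sequence. Fused kernels such as FlashAttention avoid storing the full matrix by processing attention in blocks. The number of score computations, however, remains quadratic.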

This is why efficient attention methods exist. Sparse attention, local attention, linear attention, memory compression, recurrence, and state-space models all attempt to reduce the cost of long-context modeling.

### PyTorch Built-In Scaled Dot-Product Attention

PyTorch 2.0 and later include an optimized scaled dot-product attention function:

```python
torch.nn.functional.scaled_dot_product_attention
```

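It computes the same mathematical operation. Depending on the hardware, data types, and inputs, it may dispatch to an optimized kernel, such as FlashAttention or a memory-efficient implementation.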

Example:

```python
import torch
import torch.nn.functional as F

B = 2
H = 4
T = 16
D = 32

Q = torch.randn(B, H, T, D)
K = torch.randn(B, H, T, D)
V = torch.randn(B, H, T, D)

out = F.scaled_dot_product_attention(Q, K, V)

print(out.shape)  # torch.Size([2, 4, 16, 32])
```

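Here $H$ is the number of attention heads and $D$ is the feature dimension of each head. The common layout for query, key, and value is

```python
[B, H, T, D]
```

However, the function accepts any number of leading batch dimensions, so `[B, T, D]` tensors also work.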
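Inputs in this layout are usually produced by splitting the model dimension across heads. The sketch below shows one common convention; it is assumed here and is not required by the function. A projection output of size `d_model` is reshaped into `H` heads of size `D_head = d_model // H` and then transposed. The inverse operation merges the heads back.

```python
import torch

B, T, H, D_head = 2, 16, 4, 32
d_model = H * D_head                      # 128

x = torch.randn(B, T, d_model)            # e.g. the output of a query projection

# Split: [B, T, d_model] -> [B, T, H, D_head] -> [B, H, T, D_head]
heads = x.view(B, T, H, D_head).transpose(1, 2)
print(heads.shape)                        # torch.Size([2, 4, 16, 32])

# Merge: [B, H, T, D_head] -> [B, T, H, D_head] -> [B, T, d_model]
merged = heads.transpose(1, 2).reshape(B, T, d_model)
print(torch.equal(merged, x))
```

Because `transpose` returns a non-contiguous view, the merge step uses `reshape`, which copies the data when needed, rather than `view`.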

For causal attention:

```python
out = F.scaled_dot_product_attention(
    Q,
    K,
    V,
    is_causal=True,
)
```

This applies the causal mask internally.

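For many real models, this function is preferable to implementing attention by hand. It returns only the output, not the attention weights. The manual implementation is therefore still useful when the weights need to be inspected, and because it exposes the mathematics.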
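A direct comparison confirms that the two agree. This sketch assumes PyTorch 2.0 or later and float32 inputs. `manual_attention` is a hypothetical helper that mirrors the functions earlier in this section, and the tolerance allows for small numerical differences between kernels.

```python
import math
import torch
import torch.nn.functional as F

def manual_attention(Q, K, V, causal=False):
    # Hypothetical helper mirroring the manual functions earlier in this section.
    scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.shape[-1])
    if causal:
        T_q, T_k = scores.shape[-2:]
        mask = torch.triu(
            torch.ones(T_q, T_k, dtype=torch.bool, device=Q.device), diagonal=1
        )
        scores = scores.masked_fill(mask, float("-inf"))
    return F.softmax(scores, dim=-1) @ V

torch.manual_seed(0)
Q, K, V = (torch.randn(2, 4, 16, 32) for _ in range(3))   # [B, H, T, D]

matches = []
for causal in (False, True):
    reference = manual_attention(Q, K, V, causal=causal)
    fused = F.scaled_dot_product_attention(Q, K, V, is_causal=causal)
    matches.append(torch.allclose(reference, fused, atol=1e-5))
    print(f"is_causal={causal}: match={matches[-1]}")
```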

### Attention Weights and Interpretability

Attention weights are sometimes inspected to see which tokens a model is using. This can be useful for debugging, but it should be interpreted carefully.

A high attention weight means that a value vector contributed strongly to a particular output vector. It does not always mean the model has a human-interpretable reason for that contribution. The value vectors are learned representations, not raw tokens. Multiple heads and layers interact in complex ways.

Attention maps are therefore diagnostic tools, not complete explanations.

### Summary

Attention is a differentiable retrieval mechanism. A query is compared with keys. The resulting scores are normalized into weights. The weights are used to combine values.

Self-attention uses one sequence as the source of queries, keys, and values. Cross-attention uses queries from one sequence and keys and values from another. Masks control which positions are visible. Positional information is needed because attention alone does not encode order.

The essential formula is

$$
\operatorname{Attention}(Q,K,V) =
\operatorname{softmax}
\left(
\frac{QK^\top}{\sqrt{d}}
\right)V.
$$

In PyTorch, attention can be implemented directly with matrix multiplication and softmax, or by using `torch.nn.functional.scaled_dot_product_attention` for optimized execution.

