Attention is a method for letting a model choose which parts of an input are most relevant when producing an output. It replaces the idea that all input positions should contribute equally. In a sequence model, attention allows one token to look at other tokens. In an image model, it allows one patch to look at other patches. In a multimodal model, it allows text tokens to look at image regions, audio frames, or retrieved documents.
The central operation is simple. Given a query, attention compares it with a set of keys, converts the comparisons into weights, and uses those weights to form a weighted average of values.
The three objects are:
| Object | Meaning |
|---|---|
| Query | What the current position is looking for |
| Key | What each candidate position offers for comparison |
| Value | The information retrieved from each candidate position |
For a single query $q$, keys $k_1, \dots, k_n$, and values $v_1, \dots, v_n$, attention computes scores

$$s_i = q \cdot k_i.$$

The scores are converted into normalized weights by softmax:

$$\alpha_i = \frac{\exp(s_i)}{\sum_{j=1}^{n} \exp(s_j)}.$$

The output is the weighted sum

$$o = \sum_{i=1}^{n} \alpha_i v_i.$$

The weight $\alpha_i$ tells us how much the output should use value $v_i$. Large score, large weight. Small score, small weight.
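To make this concrete, here is a minimal sketch of the single-query case with made-up numbers:

```python
import torch
import torch.nn.functional as F

# One query compared against three keys, each with an associated value.
q = torch.tensor([1.0, 0.0])
K = torch.tensor([[1.0, 0.0],    # points the same way as q
                  [0.0, 1.0],    # orthogonal to q
                  [-1.0, 0.0]])  # points opposite to q
V = torch.tensor([[10.0], [20.0], [30.0]])

scores = K @ q                      # dot-product similarity, shape [3]
weights = F.softmax(scores, dim=0)  # normalized weights, sum to 1
output = weights @ V                # weighted average of the values
print(weights)  # largest weight on the first key
print(output)   # pulled toward the first value, 10.0
```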
From Fixed Representations to Attention
Older sequence-to-sequence models often compressed the whole input sequence into one fixed-size vector. For short sequences this can work. For long sequences it creates a bottleneck. The decoder must recover all relevant information from one vector, even when different output positions need different parts of the input.
Attention removes this bottleneck. Instead of using one fixed representation, the model keeps a collection of hidden states and learns which ones to read from at each step.
For machine translation, when producing a target word, the decoder can attend to the source words most relevant to that word. For summarization, the model can attend to important sentences or phrases. For language modeling, each token can attend to earlier tokens that provide useful context.
The principle is general: attention is differentiable content-based retrieval.
Dot-Product Attention
The simplest common form is dot-product attention. Suppose we have matrices

$$Q \in \mathbb{R}^{n \times d}, \quad K \in \mathbb{R}^{m \times d}, \quad V \in \mathbb{R}^{m \times d_v}.$$

Here $Q$ contains $n$ queries, $K$ contains $m$ keys, and $V$ contains $m$ values.

The score matrix is

$$S = QK^\top.$$

Its shape is $n \times m$. Each entry $S_{ij}$ measures how strongly query $i$ matches key $j$.

After applying softmax along the key dimension, we obtain the attention weight matrix

$$A = \mathrm{softmax}(S).$$

Then the attention output is

$$O = AV.$$

The shape of the output is $n \times d_v$.
This is the core of attention. Scores are computed by similarity. Weights are computed by softmax. Values are combined by weighted averaging.
Scaled Dot-Product Attention
In transformers, dot-product attention is usually scaled by $1/\sqrt{d_k}$, where $d_k$ is the key dimension:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V.$$

The scaling factor matters because dot products grow in magnitude as the feature dimension grows. If the entries of $Q$ and $K$ have roughly unit variance, then each dot product $q \cdot k$ has variance proportional to $d_k$. Large scores push the softmax toward extreme probabilities. This can make gradients small and training unstable.

Dividing by $\sqrt{d_k}$ keeps the score scale more stable.
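The effect is easy to check empirically; a quick sketch comparing the variance of raw and scaled scores at two dimensions:

```python
import torch

for d in (16, 1024):
    q = torch.randn(10_000, d)  # roughly unit-variance entries
    k = torch.randn(10_000, d)
    scores = (q * k).sum(dim=-1)              # one dot product per row
    print(d, scores.var().item())             # grows roughly like d
    print(d, (scores / d**0.5).var().item())  # stays near 1 after scaling
```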
The full operation in PyTorch:
```python
import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V):
    # Q, K, V: [..., T, D]; leading axes are treated as batch dimensions.
    d = Q.shape[-1]
    scores = Q @ K.transpose(-2, -1)     # similarity of every query with every key
    scores = scores / math.sqrt(d)       # scale to stabilize the softmax
    weights = F.softmax(scores, dim=-1)  # normalize along the key dimension
    output = weights @ V                 # weighted average of the values
    return output, weights
```

This function works for batched inputs because `transpose(-2, -1)` swaps the last two axes, and matrix multiplication applies over the leading batch axes.
Example:

```python
B = 2  # batch size
T = 5  # sequence length
D = 8  # feature dimension

Q = torch.randn(B, T, D)
K = torch.randn(B, T, D)
V = torch.randn(B, T, D)

out, weights = scaled_dot_product_attention(Q, K, V)
print(out.shape)      # torch.Size([2, 5, 8])
print(weights.shape)  # torch.Size([2, 5, 5])
```

The output has one vector per query position. The weight tensor has shape `[B, T, T]`, meaning that each token in each batch item has a distribution over all token positions.
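Because softmax normalizes along the key dimension, each row of the weight tensor sums to one; a quick check:

```python
row_sums = weights.sum(dim=-1)  # sum over key positions
print(torch.allclose(row_sums, torch.ones(B, T)))  # True
```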
Attention as Weighted Averaging
Attention can be understood as a learned weighted average. The values contain candidate information. The attention weights decide how much of each value to include.
Suppose a sequence has five tokens. For token 3, the attention weights may be, for example,

$$\alpha = (0.05,\ 0.05,\ 0.10,\ 0.70,\ 0.10).$$

Then the output for token 3 is mostly influenced by token 4, with smaller contributions from the other tokens.
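Written out as code, with the illustrative weights above and random value vectors:

```python
import torch

weights = torch.tensor([0.05, 0.05, 0.10, 0.70, 0.10])  # weights for token 3
values = torch.randn(5, 8)                               # one value vector per token

output = weights @ values  # weighted average, shape [8]
# The result lies closest to values[3], the fourth token's value vector.
```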
This view is useful but incomplete. The model learns the queries, keys, and values through linear projections. The attention weights are not fixed linguistic rules. They are produced by trainable representations.
For an input tensor

$$X \in \mathbb{R}^{B \times T \times d},$$

a transformer layer usually computes

$$Q = XW_Q, \quad K = XW_K, \quad V = XW_V.$$

The matrices $W_Q$, $W_K$, and $W_V$ are learned parameters. Thus the model learns what to search for, what to compare against, and what information to retrieve.
Self-Attention
Self-attention is attention where the queries, keys, and values are all computed from the same input sequence.
Given

$$X \in \mathbb{R}^{B \times T \times d},$$

we define

$$Q = XW_Q, \quad K = XW_K, \quad V = XW_V.$$

Then attention is computed between positions in the same sequence.
Self-attention allows each token to build a representation using information from other tokens. For example, in the sentence

*The animal did not cross the street because it was tired.*

the word *it* may need to attend to *animal* to resolve its meaning. In another sentence,

*The animal did not cross the street because it was flooded.*

the word *it* may need to attend to *street*.
Self-attention gives the model a mechanism for such contextual dependency. The meaning of a token becomes conditional on the surrounding sequence.
A minimal self-attention module in PyTorch can be written as:
```python
import math
import torch
from torch import nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        # Learned projections produce queries, keys, and values from the same input.
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        # x: [B, T, D]
        Q = self.q_proj(x)
        K = self.k_proj(x)
        V = self.v_proj(x)
        d = Q.shape[-1]
        scores = Q @ K.transpose(-2, -1) / math.sqrt(d)  # [B, T, T]
        weights = F.softmax(scores, dim=-1)
        out = weights @ V
        out = self.out_proj(out)
        return out, weights
```

Example:
```python
x = torch.randn(4, 10, 32)
attn = SelfAttention(d_model=32)
out, weights = attn(x)
print(out.shape)      # torch.Size([4, 10, 32])
print(weights.shape)  # torch.Size([4, 10, 10])
```

The model maps a sequence of 10 input vectors to a sequence of 10 output vectors. Each output vector is a context-dependent representation of the corresponding input position.
Cross-Attention
Cross-attention uses queries from one source and keys and values from another source.
Suppose

$$X \in \mathbb{R}^{B \times T_x \times d}$$

is a target-side representation, and

$$Y \in \mathbb{R}^{B \times T_y \times d}$$

is a source-side representation. Cross-attention computes

$$Q = XW_Q, \quad K = YW_K, \quad V = YW_V.$$
The target sequence asks questions. The source sequence provides keys and values.
This is common in encoder-decoder transformers. The decoder uses cross-attention to read from the encoder output. It is also common in multimodal systems. A text decoder may use cross-attention to read from image features, audio features, retrieved documents, or tool outputs.
A minimal PyTorch implementation:
```python
class CrossAttention(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x, context):
        # x:       [B, T_x, D]  (target side, provides queries)
        # context: [B, T_y, D]  (source side, provides keys and values)
        Q = self.q_proj(x)
        K = self.k_proj(context)
        V = self.v_proj(context)
        d = Q.shape[-1]
        scores = Q @ K.transpose(-2, -1) / math.sqrt(d)  # [B, T_x, T_y]
        weights = F.softmax(scores, dim=-1)
        out = weights @ V
        out = self.out_proj(out)
        return out, weights
```

The attention weights have shape `[B, T_x, T_y]`. Each target position has a distribution over source positions.
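Example, with a target of 5 positions reading from a source of 7 positions:

```python
x = torch.randn(2, 5, 32)        # target sequence, provides queries
context = torch.randn(2, 7, 32)  # source sequence, provides keys and values

attn = CrossAttention(d_model=32)
out, weights = attn(x, context)
print(out.shape)      # torch.Size([2, 5, 32])
print(weights.shape)  # torch.Size([2, 5, 7])
```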
Masked Attention
Sometimes a model should not attend to every position. A mask prevents attention from using forbidden positions.
There are two common masks.
A padding mask hides padding tokens. In a batch, sequences often have different lengths. Shorter sequences are padded to match the longest one. The model should ignore those padding positions.
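A padding mask can be built from the true sequence lengths. The sketch below is illustrative: it assumes a `lengths` tensor holding the number of real tokens per sequence, and it reuses the imports from the earlier examples.

```python
def padded_attention(Q, K, V, lengths):
    # Q, K, V: [B, T, D]; lengths: [B] count of real (non-padding) tokens.
    # Assumes every sequence has at least one real token.
    B, T, D = K.shape
    scores = Q @ K.transpose(-2, -1) / math.sqrt(D)
    # True where a key position is padding and must be hidden.
    key_is_pad = torch.arange(T, device=K.device)[None, :] >= lengths[:, None]
    scores = scores.masked_fill(key_is_pad[:, None, :], float("-inf"))
    weights = F.softmax(scores, dim=-1)
    return weights @ V, weights
```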
A causal mask hides future tokens. In autoregressive language modeling, token $t$ may attend only to tokens $1, \dots, t$. It must not attend to tokens after $t$, because those tokens are unknown during generation.
For a sequence of length $T$, a causal mask has the form

$$M = \begin{pmatrix} 0 & -\infty & -\infty & \cdots & -\infty \\ 0 & 0 & -\infty & \cdots & -\infty \\ 0 & 0 & 0 & \cdots & -\infty \\ \vdots & & & \ddots & \vdots \\ 0 & 0 & 0 & \cdots & 0 \end{pmatrix}.$$

The mask is added to the attention scores before softmax. Positions with score $-\infty$ receive zero probability after softmax.
In PyTorch:
```python
def causal_attention(Q, K, V):
    B, T, D = Q.shape
    scores = Q @ K.transpose(-2, -1) / math.sqrt(D)
    # True above the diagonal: position t must not see positions after t.
    mask = torch.triu(
        torch.ones(T, T, device=Q.device, dtype=torch.bool),
        diagonal=1,
    )
    scores = scores.masked_fill(mask, float("-inf"))
    weights = F.softmax(scores, dim=-1)
    out = weights @ V
    return out, weights
```

Example:
```python
Q = torch.randn(2, 4, 8)
K = torch.randn(2, 4, 8)
V = torch.randn(2, 4, 8)

out, weights = causal_attention(Q, K, V)
print(weights[0])
```

The entries above the diagonal are zero after softmax. This means position 0 can attend only to position 0. Position 1 can attend to positions 0 and 1. Position 2 can attend to positions 0, 1, and 2.
Causal masking is the core constraint behind autoregressive generation.
Attention and Permutation Sensitivity
Self-attention alone does not know the order of tokens. If we permute the input tokens and apply the same permutation to queries, keys, and values, the attention operation follows the permutation. The operation itself has no built-in notion of first, second, or third position.
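This can be verified with the `SelfAttention` module above; a minimal sketch:

```python
attn = SelfAttention(d_model=32)
x = torch.randn(1, 6, 32)
perm = torch.randperm(6)

out, _ = attn(x)
out_permuted, _ = attn(x[:, perm])  # permute the input tokens
print(torch.allclose(out[:, perm], out_permuted, atol=1e-5))  # True
```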
This is why transformers need positional information. Positional encodings or positional embeddings are added to token representations so the model can distinguish order.
Without position information, these two sequences contain the same set of tokens from the model's perspective:

*dog bites man*
*man bites dog*

Language depends on order. Therefore, attention must be combined with positional information to model sequences properly.
This topic is developed in Section 15.4.
Attention Complexity
For a sequence of length $T$, self-attention computes a $T \times T$ score matrix. The time and memory complexity are therefore quadratic in sequence length:

$$O(T^2).$$

More precisely, for batch size $B$, sequence length $T$, and hidden dimension $d$, the attention score computation costs approximately

$$O(B \cdot T^2 \cdot d)$$

operations. The attention weights require memory proportional to

$$B \cdot T^2.$$
This quadratic scaling is acceptable for short and medium sequences, but expensive for very long sequences. A context length of 1,000 produces 1,000,000 pairwise scores per attention head. A context length of 100,000 produces 10,000,000,000 pairwise scores per head.
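A back-of-the-envelope sketch, assuming float32 weights and a single head, makes the memory cost concrete:

```python
def attention_matrix_bytes(T, dtype_bytes=4):
    # One T x T weight matrix per head per batch item.
    return T * T * dtype_bytes

print(attention_matrix_bytes(1_000))    # 4,000,000 bytes, about 4 MB
print(attention_matrix_bytes(100_000))  # 40,000,000,000 bytes, about 40 GB
```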
This is why efficient attention methods exist. Sparse attention, local attention, linear attention, memory compression, recurrence, and state-space models all attempt to reduce the cost of long-context modeling.
PyTorch Built-In Scaled Dot-Product Attention
Recent PyTorch versions include an optimized scaled dot-product attention function:
`torch.nn.functional.scaled_dot_product_attention`

It computes the same mathematical operation but may use optimized kernels when available.
Example:
```python
import torch
import torch.nn.functional as F

B = 2   # batch size
H = 4   # number of attention heads
T = 16  # sequence length
D = 32  # feature dimension per head

Q = torch.randn(B, H, T, D)
K = torch.randn(B, H, T, D)
V = torch.randn(B, H, T, D)

out = F.scaled_dot_product_attention(Q, K, V)
print(out.shape)  # torch.Size([2, 4, 16, 32])
```

Here $H$ is the number of attention heads. The expected shape is commonly `[B, H, T, D]` for query, key, and value.
For causal attention:
```python
out = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
```

This applies the causal mask internally.
For many real models, this function is preferable to manually implementing attention. The manual implementation is still useful because it exposes the mathematics.
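As a sanity check, the manual implementation from earlier should agree numerically with the built-in function (the manual version already broadcasts over the extra head axis):

```python
Q = torch.randn(2, 4, 16, 32)
K = torch.randn(2, 4, 16, 32)
V = torch.randn(2, 4, 16, 32)

manual, _ = scaled_dot_product_attention(Q, K, V)
builtin = F.scaled_dot_product_attention(Q, K, V)
print(torch.allclose(manual, builtin, atol=1e-5))  # True
```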
Attention Weights and Interpretability
Attention weights are sometimes inspected to see which tokens a model is using. This can be useful for debugging, but it should be interpreted carefully.
A high attention weight means that a value vector contributed strongly to a particular output vector. It does not always mean the model has a human-interpretable reason for that contribution. The value vectors are learned representations, not raw tokens. Multiple heads and layers interact in complex ways.
Attention maps are therefore diagnostic tools, not complete explanations.
Summary
Attention is a differentiable retrieval mechanism. A query is compared with keys. The resulting scores are normalized into weights. The weights are used to combine values.
Self-attention uses one sequence as the source of queries, keys, and values. Cross-attention uses queries from one sequence and keys and values from another. Masks control which positions are visible. Positional information is needed because attention alone does not encode order.
The essential formula is

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V.$$

In PyTorch, attention can be implemented directly with matrix multiplication and softmax, or by using `torch.nn.functional.scaled_dot_product_attention` for optimized execution.