Attention Mechanisms

Attention is a sequence operation that lets each position read information from other positions. Instead of compressing the whole past into one recurrent hidden state, attention builds a context-dependent representation by comparing queries with keys and mixing values.

Given an input sequence represented as a matrix

$$X \in \mathbb{R}^{T \times d_{\text{model}}},$$

we construct three projections:

$$Q = XW_Q, \qquad K = XW_K, \qquad V = XW_V.$$

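Here, $Q$ contains queries, $K$ contains keys, and $V$ contains values. The projection matrices $W_Q, W_K \in \mathbb{R}^{d_{\text{model}} \times d_k}$ and $W_V \in \mathbb{R}^{d_{\text{model}} \times d_v}$ are learned. The attention scores are scaled dot products between queries and keys: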

$$S = \frac{QK^\top}{\sqrt{d_k}}.$$

The scaling factor $\sqrt{d_k}$ keeps dot products from growing too large as the key dimension increases. Large scores can push softmax into saturated regions, where gradients are small.
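Here is a quick numerical check of the scaling argument. It assumes query and key entries are independent standard normal draws, which is an idealization and not a property trained models guarantee. Under that assumption $q^\top k$ is a sum of $d_k$ independent unit-variance terms, so its variance is $d_k$. Dividing by $\sqrt{d_k}$ brings it back to 1.

```python
import numpy as np

def score_variance(d_k, n_pairs=10000, seed=0):
    """Empirical variance of raw and scaled dot products q . k."""
    rng = np.random.default_rng(seed)
    # Assumption: query and key entries are i.i.d. standard normal
    q = rng.standard_normal((n_pairs, d_k))
    k = rng.standard_normal((n_pairs, d_k))
    raw = (q * k).sum(axis=1)          # one dot product per (q, k) pair
    scaled = raw / np.sqrt(d_k)
    return raw.var(), scaled.var()

for d_k in (16, 64, 256):
    raw_var, scaled_var = score_variance(d_k)
    print(f"d_k={d_k:4d}  var(q.k)={raw_var:8.1f}  var(q.k/sqrt(d_k))={scaled_var:.2f}")
```

The raw variance grows roughly linearly with $d_k$, while the scaled variance stays near 1.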

The attention weights apply softmax to each row of $S$:

$$A = \operatorname{softmax}(S),$$

and the output is:

$$Y = AV.$$

This is the core scaled dot-product attention operation.
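Below is a minimal NumPy sketch of the forward computation. It assumes a single unbatched sequence whose rows are positions. The function name `scaled_dot_product_attention` and the shapes are illustrative, not a particular library's API.

```python
import numpy as np

def softmax(S):
    # Row-wise softmax with the max shift described under Softmax Stability
    E = np.exp(S - S.max(axis=-1, keepdims=True))
    return E / E.sum(axis=-1, keepdims=True)

def scaled_dot_product_attention(X, W_Q, W_K, W_V):
    """Return (Y, A) for a single sequence X of shape (T, d_model)."""
    Q = X @ W_Q                      # (T, d_k)
    K = X @ W_K                      # (T, d_k)
    V = X @ W_V                      # (T, d_v)
    d_k = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d_k)       # (T, T) score matrix
    A = softmax(S)                   # each row is a distribution over positions
    Y = A @ V                        # (T, d_v)
    return Y, A

rng = np.random.default_rng(0)
T, d_model, d_k, d_v = 5, 8, 4, 6
X = rng.standard_normal((T, d_model))
W_Q = rng.standard_normal((d_model, d_k))
W_K = rng.standard_normal((d_model, d_k))
W_V = rng.standard_normal((d_model, d_v))
Y, A = scaled_dot_product_attention(X, W_Q, W_K, W_V)
print(Y.shape, A.sum(axis=1))
```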

Interpretation

For each position $i$, the query $q_i$ is compared with every key $k_j$. The result is a score:

$$s_{ij} = \frac{q_i^\top k_j}{\sqrt{d_k}}.$$

After softmax, the row $a_i$ becomes a probability distribution over source positions. The output at position $i$ is a weighted sum of value vectors:

$$y_i = \sum_{j=1}^{T} a_{ij}v_j.$$

Thus attention has three steps:

  1. compare positions,
  2. normalize comparisons,
  3. mix values.
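The three steps can also be written as an explicit loop over query positions. This sketch is deliberately slow. It only checks that the per-position formula $y_i = \sum_j a_{ij} v_j$ agrees with the matrix form $Y = AV$, and it uses random $Q$, $K$, $V$ in place of learned projections.

```python
import numpy as np

def attention_loop(Q, K, V):
    T, d_k = Q.shape
    Y = np.zeros((T, V.shape[1]))
    for i in range(T):
        # 1. compare: score query i against every key
        s = np.array([Q[i] @ K[j] for j in range(K.shape[0])]) / np.sqrt(d_k)
        # 2. normalize: softmax turns the scores into a distribution over j
        a = np.exp(s - s.max())
        a /= a.sum()
        # 3. mix: weighted sum of value vectors
        Y[i] = sum(a[j] * V[j] for j in range(V.shape[0]))
    return Y

def attention_matrix(Q, K, V):
    S = Q @ K.T / np.sqrt(Q.shape[1])
    A = np.exp(S - S.max(axis=1, keepdims=True))
    A /= A.sum(axis=1, keepdims=True)
    return A @ V

rng = np.random.default_rng(1)
Q, K, V = (rng.standard_normal((4, 3)) for _ in range(3))
print(np.allclose(attention_loop(Q, K, V), attention_matrix(Q, K, V)))
```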

The operation is differentiable with respect to $Q$, $K$, $V$, and the projection matrices that produced them.

Backward Pass Through Attention

Automatic differentiation treats attention as a composition of standard tensor operations:

Q = X @ W_Q
K = X @ W_K
V = X @ W_V

S = Q @ K.T / sqrt(d_k)
A = softmax(S)
Y = A @ V

The backward pass applies local derivative rules in reverse order.

Given the upstream adjoint of a scalar loss $L$,

$$\bar{Y} = \frac{\partial L}{\partial Y},$$

the final matrix multiplication gives:

$$\bar{A} = \bar{Y}V^\top, \qquad \bar{V} = A^\top \bar{Y}.$$

The softmax backward rule maps $\bar{A}$ into $\bar{S}$. For each row, with $a$, $\bar{a}$, and $\bar{s}$ the corresponding rows of $A$, $\bar{A}$, and $\bar{S}$:

$$\bar{s} = a \odot \left( \bar{a} - (a^\top \bar{a})\mathbf{1} \right).$$
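The row rule can be checked against the explicit softmax Jacobian $\partial a_i / \partial s_j = a_i(\delta_{ij} - a_j)$. The sketch below applies the vectorized rule to all rows at once and compares one row with the Jacobian-vector product. `softmax_backward` is an illustrative name.

```python
import numpy as np

def softmax(S):
    E = np.exp(S - S.max(axis=-1, keepdims=True))
    return E / E.sum(axis=-1, keepdims=True)

def softmax_backward(A, A_bar):
    """Map the adjoint of A = softmax(S) to the adjoint of S, row by row."""
    # a^T a_bar for every row, kept as a column so it broadcasts
    dot = (A * A_bar).sum(axis=-1, keepdims=True)
    return A * (A_bar - dot)

rng = np.random.default_rng(2)
S = rng.standard_normal((3, 4))
A_bar = rng.standard_normal((3, 4))
A = softmax(S)
S_bar = softmax_backward(A, A_bar)

# Explicit Jacobian for row 0: J[i, j] = a_i (delta_ij - a_j)
a = A[0]
J = np.diag(a) - np.outer(a, a)
print(np.allclose(S_bar[0], J.T @ A_bar[0]))
```

A useful side effect of the rule is that each row of $\bar{S}$ sums to zero, reflecting the fact that softmax ignores a constant shift of its inputs.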

Then the score matrix multiplication gives:

$$\bar{Q} = \frac{1}{\sqrt{d_k}}\bar{S}K, \qquad \bar{K} = \frac{1}{\sqrt{d_k}}\bar{S}^\top Q.$$

Finally, gradients flow into the projection matrices:

$$\bar{W}_Q = X^\top \bar{Q}, \qquad \bar{W}_K = X^\top \bar{K}, \qquad \bar{W}_V = X^\top \bar{V}.$$

If the same input $X$ is used to produce $Q$, $K$, and $V$, then the gradient with respect to $X$ receives three contributions:

$$\bar{X} = \bar{Q}W_Q^\top + \bar{K}W_K^\top + \bar{V}W_V^\top.$$

This additive accumulation is ordinary reverse-mode AD, and attention needs no special derivative principle. What sets it apart is the shape and cost of the tensor operations involved.
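The sketch below writes the whole backward pass by hand from the rules above. It checks the result against central finite differences on a scalar loss $L = \sum Y \odot G$, where $G$ is a fixed random array standing in for $\bar{Y}$. Only $\bar{X}$ is checked here; the weight gradients can be checked the same way.

```python
import numpy as np

def softmax(S):
    E = np.exp(S - S.max(axis=-1, keepdims=True))
    return E / E.sum(axis=-1, keepdims=True)

def forward(X, W_Q, W_K, W_V):
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    scale = 1.0 / np.sqrt(Q.shape[1])
    A = softmax(Q @ K.T * scale)
    return A @ V, (Q, K, V, A, scale)

def backward(X, W_Q, W_K, W_V, cache, Y_bar):
    Q, K, V, A, scale = cache
    A_bar = Y_bar @ V.T                                           # from Y = A V
    V_bar = A.T @ Y_bar
    S_bar = A * (A_bar - (A * A_bar).sum(axis=1, keepdims=True))  # softmax rule
    Q_bar = scale * S_bar @ K                                     # from S = Q K^T / sqrt(d_k)
    K_bar = scale * S_bar.T @ Q
    return {
        "W_Q": X.T @ Q_bar, "W_K": X.T @ K_bar, "W_V": X.T @ V_bar,
        # X feeds all three projections, so its adjoint has three terms
        "X": Q_bar @ W_Q.T + K_bar @ W_K.T + V_bar @ W_V.T,
    }

rng = np.random.default_rng(3)
T, d_model, d_k = 4, 5, 3
X = rng.standard_normal((T, d_model))
W_Q, W_K, W_V = (rng.standard_normal((d_model, d_k)) for _ in range(3))
G = rng.standard_normal((T, d_k))          # plays the role of Y_bar

Y, cache = forward(X, W_Q, W_K, W_V)
grads = backward(X, W_Q, W_K, W_V, cache, G)

# Central finite differences for dL/dX with L = sum(Y * G)
eps = 1e-6
X_bar_fd = np.zeros_like(X)
for idx in np.ndindex(*X.shape):
    Xp, Xm = X.copy(), X.copy()
    Xp[idx] += eps
    Xm[idx] -= eps
    Lp = (forward(Xp, W_Q, W_K, W_V)[0] * G).sum()
    Lm = (forward(Xm, W_Q, W_K, W_V)[0] * G).sum()
    X_bar_fd[idx] = (Lp - Lm) / (2 * eps)

print(np.max(np.abs(grads["X"] - X_bar_fd)))
```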

Softmax Stability

The softmax operation is usually implemented with a numerical shift:

$$\operatorname{softmax}(s)_i = \frac{\exp(s_i - m)}{\sum_j \exp(s_j - m)}, \qquad m = \max_j s_j.$$

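Subtracting $m$ does not change the result, because the factor $\exp(-m)$ cancels between numerator and denominator. It does keep every exponent at most zero, which removes the overflow risk.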
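Here is a small NumPy demonstration, assuming float64 scores. Exponentiating a score near 1000 overflows to infinity, so the unshifted formula yields `nan`. The shifted version returns a valid distribution.

```python
import numpy as np

def naive_softmax(s):
    e = np.exp(s)                # overflows to inf for large s
    return e / e.sum()

def stable_softmax(s):
    e = np.exp(s - s.max())      # largest exponent is exp(0) = 1
    return e / e.sum()

s = np.array([1000.0, 999.0, 998.0])
with np.errstate(over="ignore", invalid="ignore"):
    print(naive_softmax(s))      # inf / inf gives nan
print(stable_softmax(s))
```

Shifting the inputs by hand (for example `naive_softmax(s - 1000.0)`) gives the same probabilities as `stable_softmax(s)`, which is the invariance the shift relies on.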

In attention, this matters because scores can be large. Without the shift, exponentials may overflow. With very negative masked values, exponentials may underflow to zero, which is usually intended.

AD differentiates the implemented stable softmax. The mathematical result matches ordinary softmax where values are finite, but the floating point behavior is much safer.

Masks

Attention commonly uses masks. A mask changes which positions can be read.

A causal mask prevents a position from attending to future positions:

$$S_{ij} = -\infty \quad \text{when } j > i.$$

A padding mask prevents attention to padding tokens:

$$S_{ij} = -\infty \quad \text{when position } j \text{ is padding}.$$

In implementation, $-\infty$ is often represented by a large negative finite value. After softmax, masked positions receive probability zero or nearly zero.

The derivative through a masked position is also zero or nearly zero. The mask is usually treated as constant metadata, not as a differentiable input.
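The sketch below builds a causal mask and a padding mask as boolean arrays and applies them as a large negative finite score, as the text describes. The value `-1e9` and the helper names are illustrative choices, not fixed conventions.

```python
import numpy as np

NEG = -1e9  # stands in for -inf; assumes real scores are far smaller in magnitude

def masked_softmax(S, mask):
    """mask[i, j] is True where position i may read position j."""
    S = np.where(mask, S, NEG)
    E = np.exp(S - S.max(axis=-1, keepdims=True))
    return E / E.sum(axis=-1, keepdims=True)

T = 4
causal = np.tril(np.ones((T, T), dtype=bool))      # allow j <= i
is_pad = np.array([False, False, False, True])     # last token is padding
padding = np.broadcast_to(~is_pad, (T, T))         # block column j if it is padding
mask = causal & padding

S = np.random.default_rng(4).standard_normal((T, T))
A = masked_softmax(S, mask)
print(np.round(A, 3))
```

In float64, `exp(-1e9 - max)` underflows to exactly 0. Masked entries therefore get zero weight, and their gradient through the softmax is zero as well.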

Multi-Head Attention

Multi-head attention runs several attention operations in parallel. Each head has its own projections:

$$Q_h = XW_Q^{(h)}, \qquad K_h = XW_K^{(h)}, \qquad V_h = XW_V^{(h)}.$$

With $d_h$ the per-head key dimension (often $d_{\text{model}}/H$ for $H$ heads), each head produces:

$$Y_h = \operatorname{softmax}\left( \frac{Q_hK_h^\top}{\sqrt{d_h}} \right) V_h.$$

The head outputs are concatenated and projected:

$$Y = \operatorname{concat}(Y_1,\ldots,Y_H)W_O.$$

Multi-head attention gives the model several independent comparison spaces. One head may attend to local syntax, another to long-range dependencies, another to positional patterns. This interpretation is useful but informal; heads are learned components, not hand-coded roles.

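For AD, multi-head attention is mostly batching and reshaping. In the backward pass, the output projection $W_O$ splits $\bar{Y}$ into one adjoint per head. Each head is then differentiated independently, and the per-head contributions are summed wherever heads share an input such as $X$.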
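The sketch below expresses multi-head attention as a reshape. It assumes $d_{\text{model}}$ is divisible by $H$ and that the per-head projections are stored side by side as column blocks of single $d_{\text{model}} \times d_{\text{model}}$ matrices. That layout is common but not universal. The batched version is compared with an explicit loop over heads.

```python
import numpy as np

def softmax(S):
    E = np.exp(S - S.max(axis=-1, keepdims=True))
    return E / E.sum(axis=-1, keepdims=True)

def split_heads(M, H):
    T, d = M.shape
    return M.reshape(T, H, d // H).transpose(1, 0, 2)   # (H, T, d_h)

def multi_head_attention(X, W_Q, W_K, W_V, W_O, H):
    Q, K, V = (split_heads(X @ W, H) for W in (W_Q, W_K, W_V))
    d_h = Q.shape[-1]
    S = Q @ K.transpose(0, 2, 1) / np.sqrt(d_h)        # (H, T, T): one score matrix per head
    Y_heads = softmax(S) @ V                            # (H, T, d_h)
    concat = Y_heads.transpose(1, 0, 2).reshape(X.shape[0], -1)  # (T, H * d_h)
    return concat @ W_O

def multi_head_loop(X, W_Q, W_K, W_V, W_O, H):
    d_h = W_Q.shape[1] // H
    outs = []
    for h in range(H):
        cols = slice(h * d_h, (h + 1) * d_h)            # head h's projection columns
        Q, K, V = X @ W_Q[:, cols], X @ W_K[:, cols], X @ W_V[:, cols]
        outs.append(softmax(Q @ K.T / np.sqrt(d_h)) @ V)
    return np.concatenate(outs, axis=1) @ W_O

rng = np.random.default_rng(5)
T, d_model, H = 6, 8, 2
X = rng.standard_normal((T, d_model))
W_Q, W_K, W_V, W_O = (rng.standard_normal((d_model, d_model)) for _ in range(4))
print(np.allclose(multi_head_attention(X, W_Q, W_K, W_V, W_O, H),
                  multi_head_loop(X, W_Q, W_K, W_V, W_O, H)))
```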

Self-Attention and Cross-Attention

In self-attention, $Q$, $K$, and $V$ all come from the same sequence $X$. This is common in encoders and decoders.

In cross-attention, queries come from one sequence and keys and values come from another:

$$Q = X_{\text{target}}W_Q, \qquad K = X_{\text{source}}W_K, \qquad V = X_{\text{source}}W_V.$$

Cross-attention is used in encoder-decoder models, retrieval-augmented models, and multimodal models. It lets one stream read from another stream.

The backward pass respects this separation. Gradients from $Q$ flow into the target stream. Gradients from $K$ and $V$ flow into the source stream.
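The sketch below runs cross-attention with different lengths for the two streams. The score matrix is $T_{\text{target}} \times T_{\text{source}}$, so the output has one row per target position. Names and shapes are illustrative.

```python
import numpy as np

def softmax(S):
    E = np.exp(S - S.max(axis=-1, keepdims=True))
    return E / E.sum(axis=-1, keepdims=True)

def cross_attention(X_target, X_source, W_Q, W_K, W_V):
    Q = X_target @ W_Q             # queries come from the target stream
    K = X_source @ W_K             # keys and values come from the source stream
    V = X_source @ W_V
    A = softmax(Q @ K.T / np.sqrt(Q.shape[1]))   # (T_target, T_source)
    return A @ V, A

rng = np.random.default_rng(6)
d_model, d_k = 8, 4
X_target = rng.standard_normal((3, d_model))    # e.g. 3 decoder positions
X_source = rng.standard_normal((7, d_model))    # e.g. 7 encoder positions
W_Q, W_K, W_V = (rng.standard_normal((d_model, d_k)) for _ in range(3))
Y, A = cross_attention(X_target, X_source, W_Q, W_K, W_V)
print(Y.shape, A.shape)
```

Passing the same array as both `X_target` and `X_source` recovers self-attention.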

Complexity

For a sequence of length $T$, standard attention forms the score matrix:

$$S \in \mathbb{R}^{T \times T}.$$

The time and memory complexity are both quadratic in sequence length:

$$O(T^2 d) \quad \text{time}, \qquad O(T^2) \quad \text{attention matrix memory}.$$

This quadratic cost is the main bottleneck for long-context models.
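A back-of-the-envelope calculation makes the quadratic memory concrete. It counts the bytes for one materialized $T \times T$ matrix, assuming one head, one batch element, and 2-byte (fp16/bf16) entries by default. Real models multiply this by the number of heads and the batch size.

```python
def score_matrix_bytes(T, bytes_per_entry=2, heads=1, batch=1):
    """Bytes needed to store the full T x T attention matrix."""
    return T * T * bytes_per_entry * heads * batch

for T in (1024, 8192, 32768):
    gib = score_matrix_bytes(T) / 2**30
    print(f"T={T:6d}: {gib:8.3f} GiB per head (2-byte entries)")
```

Doubling $T$ quadruples the memory. At $T = 32768$ a single head already needs 2 GiB for the matrix alone.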

The backward pass has similar asymptotic cost and often higher memory pressure because it needs forward residuals such as attention weights, softmax statistics, or recomputable inputs.

Fused Attention Kernels

Production systems often use fused attention kernels. Instead of materializing every intermediate tensor, the kernel combines score computation, masking, softmax, dropout, and value mixing.

The goal is to reduce memory traffic. Attention is frequently limited by memory bandwidth rather than only by arithmetic throughput.

A fused forward kernel may avoid storing the full attention matrix. A fused backward kernel then recomputes selected values from compact saved statistics. This is a form of rematerialization specialized to attention.
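The compact-statistics idea can be sketched with the online-softmax technique used by memory-efficient kernels such as FlashAttention. Keys and values are processed in blocks, and each query row keeps a running maximum, a running denominator, and an unnormalized output. This NumPy version only illustrates the arithmetic. It is not a fused GPU kernel, and the block size is arbitrary.

```python
import numpy as np

def reference_attention(Q, K, V):
    S = Q @ K.T / np.sqrt(Q.shape[1])
    A = np.exp(S - S.max(axis=1, keepdims=True))
    return (A / A.sum(axis=1, keepdims=True)) @ V

def blockwise_attention(Q, K, V, block=3):
    """Attention without forming the full T_q x T_k score matrix at once."""
    scale = 1.0 / np.sqrt(Q.shape[1])
    m = np.full(Q.shape[0], -np.inf)          # running row maximum
    l = np.zeros(Q.shape[0])                  # running softmax denominator
    acc = np.zeros((Q.shape[0], V.shape[1]))  # running unnormalized output
    for start in range(0, K.shape[0], block):
        S = Q @ K[start:start + block].T * scale      # scores for this key block only
        m_new = np.maximum(m, S.max(axis=1))
        correction = np.exp(m - m_new)                # rescales old stats; 0 on the first block
        P = np.exp(S - m_new[:, None])
        l = l * correction + P.sum(axis=1)
        acc = acc * correction[:, None] + P @ V[start:start + block]
        m = m_new
    # m + log(l) is the per-row log-sum-exp, a compact statistic a backward pass can reuse
    return acc / l[:, None], m + np.log(l)

rng = np.random.default_rng(7)
Q, K, V = rng.standard_normal((4, 5)), rng.standard_normal((10, 5)), rng.standard_normal((10, 2))
Y, lse = blockwise_attention(Q, K, V)
print(np.allclose(Y, reference_attention(Q, K, V)))
```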

From the AD point of view, the fused kernel must expose a correct backward rule. The graph may contain one primitive operation:

Y = attention(Q, K, V, mask)

rather than many smaller operations. The derivative contract remains the same: given $\bar{Y}$, return $\bar{Q}$, $\bar{K}$, and $\bar{V}$.
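The sketch below shows this contract for a single attention primitive. The hypothetical `AttentionPrimitive` class saves only $Q$, $K$, $V$, and the per-row log-sum-exp from the forward pass, then recomputes $A$ in the backward pass. This is rematerialization in miniature. It is not any framework's API, and masking and dropout are omitted for brevity.

```python
import numpy as np

class AttentionPrimitive:
    """Hypothetical primitive: forward saves compact residuals, backward recomputes A."""

    @staticmethod
    def forward(Q, K, V):
        scale = 1.0 / np.sqrt(Q.shape[1])
        S = Q @ K.T * scale
        m = S.max(axis=1, keepdims=True)
        lse = m + np.log(np.exp(S - m).sum(axis=1, keepdims=True))  # (T, 1)
        Y = np.exp(S - lse) @ V
        residuals = (Q, K, V, lse)         # no T x T matrix is kept
        return Y, residuals

    @staticmethod
    def backward(residuals, Y_bar):
        Q, K, V, lse = residuals
        scale = 1.0 / np.sqrt(Q.shape[1])
        A = np.exp(Q @ K.T * scale - lse)  # recompute the attention weights
        V_bar = A.T @ Y_bar
        A_bar = Y_bar @ V.T
        S_bar = A * (A_bar - (A * A_bar).sum(axis=1, keepdims=True))
        return scale * S_bar @ K, scale * S_bar.T @ Q, V_bar   # (Q_bar, K_bar, V_bar)

rng = np.random.default_rng(8)
Q, K, V = (rng.standard_normal((4, 3)) for _ in range(3))
G = rng.standard_normal((4, 3))            # stands in for the upstream adjoint Y_bar
Y, res = AttentionPrimitive.forward(Q, K, V)
Q_bar, K_bar, V_bar = AttentionPrimitive.backward(res, G)

# Finite-difference check of one entry of K_bar for L = sum(Y * G)
eps = 1e-6
Kp, Km = K.copy(), K.copy()
Kp[1, 2] += eps
Km[1, 2] -= eps
Y_p, Y_m = AttentionPrimitive.forward(Q, Kp, V)[0], AttentionPrimitive.forward(Q, Km, V)[0]
fd = ((Y_p - Y_m) * G).sum() / (2 * eps)
print(abs(fd - K_bar[1, 2]))
```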

Dropout in Attention

Attention dropout randomly removes some attention weights during training. A simplified form is:

$$\tilde{A} = \frac{M \odot A}{p},$$

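where $M$ is a Bernoulli mask whose entries are 1 with probability $p$ and 0 otherwise, and $p$ is the keep probability. Dividing by $p$ keeps the expected value of $\tilde{A}$ equal to $A$.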

The output becomes:

$$Y = \tilde{A}V.$$

During the backward pass, gradients flow only through entries kept by the dropout mask. The mask must be saved or regenerated deterministically for the backward pass.

Dropout changes the sampled computation. AD computes the derivative of that sampled computation, not of the expectation over all masks.
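The sketch below implements inverted attention dropout that keeps its sampled mask for the backward pass. It assumes NumPy's random generator and a keep probability `p`; the helper names are illustrative. Because $\tilde{A} = (M \odot A)/p$ is linear in $A$, its backward rule is $\bar{A} = (M \odot \bar{\tilde{A}})/p$, so dropped entries receive zero gradient.

```python
import numpy as np

def dropout_forward(A, p, rng):
    """Inverted dropout on attention weights; p is the keep probability."""
    M = (rng.random(A.shape) < p).astype(A.dtype)   # 1 with probability p
    return M * A / p, M                              # save the mask for backward

def dropout_backward(M, A_tilde_bar, p):
    # Reuses the forward mask: dropped entries get zero gradient
    return M * A_tilde_bar / p

rng = np.random.default_rng(9)
A = np.full((4, 4), 0.25)                 # a uniform attention matrix
A_tilde, M = dropout_forward(A, p=0.8, rng=rng)
A_bar = dropout_backward(M, np.ones_like(A), p=0.8)
print(M)
print(A_bar)
```

The gradient is the derivative of this one sampled mask, matching the point above. Averaging over many masks recovers $A$ only in expectation.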

Positional Information

Attention alone does not know token order. If the rows of $X$ are permuted, plain self-attention with no positional information and no causal mask is equivariant to that permutation: its output rows are permuted in the same way.
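The following numerical check tests the equivariance claim for unmasked self-attention with no positional information, using a random permutation matrix $P$. It verifies that $\operatorname{attn}(PX) = P\,\operatorname{attn}(X)$. The rest of the setup is arbitrary.

```python
import numpy as np

def self_attention(X, W_Q, W_K, W_V):
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    S = Q @ K.T / np.sqrt(Q.shape[1])
    A = np.exp(S - S.max(axis=1, keepdims=True))
    return (A / A.sum(axis=1, keepdims=True)) @ V

rng = np.random.default_rng(10)
T, d = 5, 4
X = rng.standard_normal((T, d))
W_Q, W_K, W_V = (rng.standard_normal((d, d)) for _ in range(3))
perm = rng.permutation(T)
P = np.eye(T)[perm]                        # permutation matrix: (P X)[i] = X[perm[i]]

lhs = self_attention(P @ X, W_Q, W_K, W_V)
rhs = P @ self_attention(X, W_Q, W_K, W_V)
print(np.allclose(lhs, rhs))
```

Adding a position-dependent term to `X` before the call would generally break this equality, which is why positional information is needed.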

Sequence models therefore add positional information. Common methods include absolute position embeddings, relative position biases, and rotary position embeddings.

Some positional mechanisms are simple additions to $X$. Others modify $Q$, $K$, or the attention scores. AD differentiates through learned positional parameters if they are trainable. Fixed positional transforms are treated as constant operations.

Attention as Differentiable Memory Access

Attention can be viewed as soft memory access. The keys are addresses, the values are contents, and the query asks what to retrieve.

Unlike hard indexing, attention uses a differentiable weighted average. This permits gradients to flow through both the addressing mechanism and the retrieved values.

Hard lookup chooses one item:

y = V[index]

Soft attention mixes many items:

$$y = \sum_j a_j v_j.$$

The second form is differentiable with respect to the scores that produce $a_j$. This is one reason attention is useful in neural architectures: it provides content-based routing while preserving gradient flow.
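The sketch below contrasts hard and soft lookup. Multiplying the scores by a growing sharpness factor `beta` makes the softmax weights concentrate on the best-matching key. `beta` is an inverse temperature introduced here only for illustration. As it grows, the soft result approaches the hard lookup `V[index]`, while remaining differentiable at every finite `beta`.

```python
import numpy as np

rng = np.random.default_rng(11)
K = rng.standard_normal((6, 4))      # keys play the role of addresses
V = rng.standard_normal((6, 3))      # values play the role of contents
q = rng.standard_normal(4)           # the query asks what to retrieve

scores = K @ q
index = int(np.argmax(scores))
hard = V[index]                      # hard lookup: one row; the index is not differentiable

def soft_lookup(scores, beta):
    """Softmax-weighted average of all value rows; beta is an illustrative sharpness."""
    a = np.exp(beta * (scores - scores.max()))
    a /= a.sum()
    return a @ V

for beta in (1.0, 10.0, 100.0):
    print(beta, np.abs(soft_lookup(scores, beta) - hard).max())
```

At `beta = 0` every weight is equal and the result is the mean of the value rows. At large `beta` it converges to the hard lookup.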

Failure Modes

Attention mechanisms can fail numerically or statistically.

Large score magnitudes can saturate softmax, making most attention probabilities near zero and gradients small.

Poor masking can leak future tokens in language modeling or allow padding tokens to influence outputs.

Long sequences can exhaust memory because attention matrices scale as $T^2$.

Dropout masks and fused kernels can create reproducibility issues if random state or kernel determinism is poorly controlled.

Mixed precision can introduce instability in score computation or softmax if accumulation is not handled carefully.

These problems are usually solved at the systems layer: stable softmax, careful masks, fused kernels, checkpointing, precision policy, and tests for causal correctness.
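A test for causal correctness can be as simple as perturbing the inputs at future positions and asserting that outputs at earlier positions do not change. The sketch below applies that check to a NumPy causal self-attention function. The helper `check_no_future_leak` is a hypothetical name. A masking bug that lets position $i$ read position $j > i$ would make it fail.

```python
import numpy as np

def causal_self_attention(X, W_Q, W_K, W_V):
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    T = X.shape[0]
    S = Q @ K.T / np.sqrt(Q.shape[1])
    # -inf is safe here because every row keeps its diagonal entry unmasked
    S = np.where(np.tril(np.ones((T, T), dtype=bool)), S, -np.inf)
    A = np.exp(S - S.max(axis=1, keepdims=True))
    return (A / A.sum(axis=1, keepdims=True)) @ V

def check_no_future_leak(attn, X, weights, t, rng):
    """Outputs at positions <= t must not depend on inputs at positions > t."""
    X_perturbed = X.copy()
    X_perturbed[t + 1:] += rng.standard_normal(X[t + 1:].shape)
    before = attn(X, *weights)[: t + 1]
    after = attn(X_perturbed, *weights)[: t + 1]
    return np.allclose(before, after)

rng = np.random.default_rng(12)
T, d = 6, 4
X = rng.standard_normal((T, d))
weights = tuple(rng.standard_normal((d, d)) for _ in range(3))
print(all(check_no_future_leak(causal_self_attention, X, weights, t, rng) for t in range(T - 1)))
```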

Interface to AD Systems

An AD system may expose attention in two ways.

The first is decomposed attention: the graph contains matrix multiplications, scaling, masking, softmax, dropout, and another matrix multiplication. The AD engine derives the backward pass from primitive rules.

The second is primitive attention: the graph contains a single attention operation with a custom backward rule. This is common for performance.

Both designs should produce the same gradients up to floating point differences. The primitive design gives the runtime more control over memory layout and kernel fusion. The decomposed design is simpler and easier to inspect.

Attention is therefore a good example of the tension between mathematical clarity and systems efficiency. The clean formula is short. The production implementation is a carefully engineered differentiable kernel.