Cross-attention is attention between two different sequences or sources of information. The queries come from one sequence, while the keys and values come from another.
Self-attention asks: “Which positions in this same sequence should I use?”
Cross-attention asks: “Which positions in another sequence should I use?”
This distinction is central in encoder-decoder transformers, retrieval systems, image captioning models, text-to-image models, and multimodal systems.
Basic Idea
Suppose we have two sequences:

$$Y = (y_1, \dots, y_{T_y})$$

and

$$X = (x_1, \dots, x_{T_x})$$

In cross-attention, the query sequence $Y$ attends to the source sequence $X$.

The model computes:

$$Q = Y W_Q, \qquad K = X W_K, \qquad V = X W_V$$

Then it applies scaled dot-product attention:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V$$

The output has one vector for each query position in $Y$, but each vector is built from information in $X$.
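To make this concrete, here is a minimal numerical sketch in PyTorch; the sequence lengths and dimensions below are invented for illustration:

```python
import math

import torch

# Illustrative sizes: 5 query-side positions, 8 source-side positions.
T_y, T_x = 5, 8
D_y, D_x, d_k = 16, 32, 24

Y = torch.randn(T_y, D_y)   # query-side sequence
X = torch.randn(T_x, D_x)   # source-side sequence

W_Q = torch.randn(D_y, d_k)
W_K = torch.randn(D_x, d_k)
W_V = torch.randn(D_x, d_k)

Q = Y @ W_Q                              # [T_y, d_k]
K = X @ W_K                              # [T_x, d_k]
V = X @ W_V                              # [T_x, d_k]

scores = Q @ K.T / math.sqrt(d_k)        # [T_y, T_x]
weights = torch.softmax(scores, dim=-1)  # each row sums to 1 over source positions
output = weights @ V                     # one output vector per query position

print(scores.shape)  # torch.Size([5, 8])
print(output.shape)  # torch.Size([5, 24])
```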
Shape View
Let

$$Y \in \mathbb{R}^{T_y \times D_y}$$

be the query-side sequence, and let

$$X \in \mathbb{R}^{T_x \times D_x}$$

be the source-side sequence.

After projection:

$$Q \in \mathbb{R}^{T_y \times d_k}, \qquad K \in \mathbb{R}^{T_x \times d_k}, \qquad V \in \mathbb{R}^{T_x \times d_v}$$

The attention score matrix has shape:

$$T_y \times T_x$$
Each row corresponds to one query position. Each column corresponds to one source position.
Thus cross-attention produces an alignment matrix between two sequences.
Encoder-Decoder Attention
The standard transformer decoder contains two attention blocks.
First, it applies causal self-attention over the partial output sequence. This lets the decoder use tokens it has already generated.
Second, it applies cross-attention over encoder outputs. This lets the decoder read from the input sequence.
For machine translation, the encoder reads the source sentence:

“the cat sat on the mat”

The decoder generates the target sentence:

“le chat est assis sur le tapis”
At each target position, cross-attention lets the decoder select relevant source positions.
When generating “chat,” the decoder may attend strongly to “cat.” When generating “tapis,” it may attend strongly to “mat.”
Cross-Attention as Conditional Generation
Cross-attention is a mechanism for conditioning one sequence on another.
A decoder language model without cross-attention generates from its own prefix. A decoder with cross-attention generates from its prefix plus an external source.
This pattern appears in many tasks.
| Task | Query source | Key-value source |
|---|---|---|
| Translation | Target prefix | Source sentence |
| Summarization | Summary prefix | Document |
| Question answering | Answer prefix | Context passage |
| Image captioning | Text prefix | Image features |
| Text-to-image generation | Image latents | Text embeddings |
| Speech recognition | Text prefix | Audio features |
The query side asks what it needs. The source side provides information.
Cross-Attention in Multimodal Models
Cross-attention is especially useful when the model combines different modalities.
In image captioning, text tokens may attend to image patch features. The text decoder asks, at each step, which image regions are relevant for the next word.
In text-to-image diffusion models, image latent representations may attend to text embeddings. The image generation process asks which words or phrases should influence each spatial region.
In visual question answering, question tokens may attend to image features, or image tokens may attend to question tokens.
Cross-attention provides a general interface between modalities because the query and source sequences do not need to have the same length or structure.
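As a rough sketch of what this looks like in code, here is a text-to-image-patch example using PyTorch's nn.MultiheadAttention; the sizes (20 caption tokens, 196 image patches) are made up for illustration:

```python
import torch
import torch.nn as nn

# Hypothetical sizes: 20 caption tokens attend to 196 image patch features.
B, T_text, T_patches = 2, 20, 196
d_model = 64

text_states = torch.randn(B, T_text, d_model)      # query side
image_states = torch.randn(B, T_patches, d_model)  # key-value side

attn = nn.MultiheadAttention(d_model, num_heads=4, batch_first=True)
out, weights = attn(query=text_states, key=image_states, value=image_states)

print(out.shape)      # torch.Size([2, 20, 64])  -> one vector per text token
print(weights.shape)  # torch.Size([2, 20, 196]) -> alignment of tokens to patches
```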
Difference from Concatenated Self-Attention
One alternative is to concatenate the two sequences and apply self-attention over the result:

$$[X; Y]$$
This allows all tokens to attend to all other tokens. Cross-attention is more controlled.
In cross-attention from $Y$ to $X$:
The query sequence reads from the source sequence, but the source sequence does not necessarily read back from the query sequence.
This directionality is useful. In an encoder-decoder model, the encoder representation can be computed once. The decoder repeatedly attends to it during generation.
| Method | Interaction |
|---|---|
| Concatenated self-attention | Both sequences jointly interact |
| Cross-attention | One sequence reads from another |
| Encoder-decoder attention | Decoder reads encoder outputs |
Cross-attention can also be cheaper when the source representation is reused.
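A back-of-the-envelope comparison of per-layer score-matrix sizes, assuming the encoder representation of the source is precomputed and reused; the lengths are illustrative:

```python
# Illustrative lengths: a long source sequence and a shorter target prefix.
T_x, T_y = 4000, 200

concat_scores = (T_x + T_y) ** 2     # self-attention over the concatenation [X; Y]
cross_scores = T_y * T_x + T_y ** 2  # decoder self-attention + cross-attention

print(concat_scores)  # 17640000
print(cross_scores)   # 840000
```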
Masks in Cross-Attention
Cross-attention commonly uses source-side padding masks.
If the source sequence contains padding tokens, the query sequence should not attend to them.
Suppose source tokens have shape:

[B, T_x]

A source padding mask has shape:

[B, T_x]

It can be reshaped to:

[B, 1, T_x]

so it broadcasts over all query positions.
Unlike decoder self-attention, cross-attention usually does not use a causal mask over the source. The full source sequence is already known.
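Here is a minimal sketch of building such a mask from padded token ids; the pad id and the example ids are invented:

```python
import torch

PAD_ID = 0  # assumed padding token id

# Two source sequences padded to T_x = 5.
source_ids = torch.tensor([
    [11, 24, 37, 0, 0],
    [5, 18, 9, 42, 7],
])                                           # [B, T_x]

source_mask = (source_ids != PAD_ID).long()  # [B, T_x], 1 = real token, 0 = padding
source_mask = source_mask[:, None, :]        # [B, 1, T_x], broadcasts over T_y

print(source_mask.shape)  # torch.Size([2, 1, 5])
```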
PyTorch Implementation
A minimal cross-attention module looks like this:
```python
import math

import torch
import torch.nn as nn


class CrossAttention(nn.Module):
    def __init__(self, query_dim, source_dim, key_dim, value_dim):
        super().__init__()
        self.key_dim = key_dim
        # Queries come from the query-side sequence;
        # keys and values come from the source-side sequence.
        self.q_proj = nn.Linear(query_dim, key_dim)
        self.k_proj = nn.Linear(source_dim, key_dim)
        self.v_proj = nn.Linear(source_dim, value_dim)

    def forward(self, query_states, source_states, source_mask=None):
        """
        query_states:  [B, T_y, D_y]
        source_states: [B, T_x, D_x]
        source_mask:   [B, T_x] or broadcastable to [B, T_y, T_x]
        """
        Q = self.q_proj(query_states)   # [B, T_y, d_k]
        K = self.k_proj(source_states)  # [B, T_x, d_k]
        V = self.v_proj(source_states)  # [B, T_x, d_v]

        # Scaled dot-product scores: one row per query position,
        # one column per source position.
        scores = Q @ K.transpose(-2, -1)  # [B, T_y, T_x]
        scores = scores / math.sqrt(self.key_dim)

        if source_mask is not None:
            if source_mask.ndim == 2:
                source_mask = source_mask[:, None, :]  # [B, 1, T_x]
            scores = scores.masked_fill(source_mask == 0, float("-inf"))

        weights = torch.softmax(scores, dim=-1)  # [B, T_y, T_x]
        output = weights @ V                     # [B, T_y, d_v]
        return output, weights
```

The shape flow is:
| Tensor | Shape |
|---|---|
| Query states | [B, T_y, D_y] |
| Source states | [B, T_x, D_x] |
| Queries | [B, T_y, d_k] |
| Keys | [B, T_x, d_k] |
| Values | [B, T_x, d_v] |
| Scores | [B, T_y, T_x] |
| Weights | [B, T_y, T_x] |
| Output | [B, T_y, d_v] |
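A quick usage sketch of the module above, with invented sizes:

```python
import torch

# Illustrative sizes.
B, T_y, T_x = 2, 4, 6
D_y, D_x, d_k, d_v = 32, 48, 16, 16

attn = CrossAttention(query_dim=D_y, source_dim=D_x, key_dim=d_k, value_dim=d_v)

query_states = torch.randn(B, T_y, D_y)
source_states = torch.randn(B, T_x, D_x)
source_mask = torch.ones(B, T_x, dtype=torch.long)
source_mask[:, -2:] = 0  # pretend the last two source positions are padding

output, weights = attn(query_states, source_states, source_mask)
print(output.shape)   # torch.Size([2, 4, 16])
print(weights.shape)  # torch.Size([2, 4, 6])
print(weights[0, 0])  # last two entries are 0: padding is never attended to
```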
Cross-Attention in Decoder Layers
A transformer decoder layer usually has this structure:
1. causal self-attention
2. cross-attention
3. feedforward network

The causal self-attention block lets each generated token depend on previous generated tokens.
The cross-attention block lets each generated token depend on the encoded input.
The feedforward network then transforms each position independently.
This separation is useful. The decoder first builds an internal representation of the output prefix, then reads the source sequence, then refines the result.
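A compressed sketch of that layer structure, using PyTorch's built-in nn.MultiheadAttention and a pre-norm layout; this is illustrative, not a reference implementation:

```python
import torch.nn as nn


class DecoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model)
        )
        self.norm1, self.norm2, self.norm3 = (nn.LayerNorm(d_model) for _ in range(3))

    def forward(self, y, memory, causal_mask=None, memory_key_padding_mask=None):
        # 1. Causal self-attention over the output prefix y.
        h = self.norm1(y)
        y = y + self.self_attn(h, h, h, attn_mask=causal_mask, need_weights=False)[0]
        # 2. Cross-attention: queries from y, keys/values from encoder memory.
        h = self.norm2(y)
        y = y + self.cross_attn(h, memory, memory,
                                key_padding_mask=memory_key_padding_mask,
                                need_weights=False)[0]
        # 3. Position-wise feedforward network.
        return y + self.ff(self.norm3(y))
```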
Key-Value Caching
During autoregressive decoding, the encoder output does not change. Therefore the cross-attention keys and values can be computed once and reused at every decoding step.
This is called key-value caching.
For cross-attention:

$$K = X W_K, \qquad V = X W_V$$

can be cached because $X$ is fixed.
At each decoding step, only the new query must be computed.
Caching reduces repeated computation during inference. It is especially important for long source sequences.
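A minimal sketch of the caching pattern, reusing the CrossAttention module from earlier; the sizes and the helper name are illustrative:

```python
import torch

B, T_x, D_x, D_y = 1, 12, 48, 32
attn = CrossAttention(query_dim=D_y, source_dim=D_x, key_dim=16, value_dim=16)

encoder_output = torch.randn(B, T_x, D_x)

# Compute once, before decoding starts: the source never changes.
K_cache = attn.k_proj(encoder_output)  # [B, T_x, d_k]
V_cache = attn.v_proj(encoder_output)  # [B, T_x, d_v]


def cross_attend_step(new_query_state):
    """Attend one freshly generated position against the cached source."""
    Q = attn.q_proj(new_query_state)                        # [B, 1, d_k]
    scores = Q @ K_cache.transpose(-2, -1) / attn.key_dim ** 0.5
    weights = torch.softmax(scores, dim=-1)                 # [B, 1, T_x]
    return weights @ V_cache                                # [B, 1, d_v]


step_out = cross_attend_step(torch.randn(B, 1, D_y))
print(step_out.shape)  # torch.Size([1, 1, 16])
```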
Cross-Attention and Retrieval
Cross-attention can also be used to condition on retrieved documents.
A retrieval system first selects relevant passages. The model then attends to the passage tokens while generating an answer.
In this setting:
| Component | Role |
|---|---|
| User question and generated prefix | Query side |
| Retrieved documents | Key-value side |
| Cross-attention weights | Soft selection over retrieved content |
This design separates retrieval from generation. Retrieval provides candidate evidence. Cross-attention decides which parts of that evidence to use.
Limitations
Cross-attention still has pairwise cost. If the query length is $T_y$ and the source length is $T_x$, the score matrix has size:

$$T_y \times T_x$$
Very long source sequences can therefore be expensive.
Cross-attention also depends on source quality. If the encoder representation is weak, or the retrieved documents are irrelevant, cross-attention cannot fully repair the input.
In multimodal models, the alignment problem can also be difficult. Text tokens and image regions may correspond only indirectly.
Summary
Cross-attention lets one sequence read from another sequence. Queries come from the target or receiving side, while keys and values come from the source side.
It is the main mechanism behind encoder-decoder transformers, multimodal conditioning, retrieval-augmented generation, and many conditional generation systems.
Self-attention builds context within a sequence. Cross-attention transfers information across sequences.