Self-attention is attention applied within a single sequence. The same input supplies the queries, keys, and values. Each position builds a new representation by reading from other positions in the same sequence.
Given an input tensor $X \in \mathbb{R}^{B \times T \times D}$, where $B$ is the batch size, $T$ is the sequence length, and $D$ is the model dimension, self-attention first projects $X$ into three tensors:

$$Q = XW_Q, \qquad K = XW_K, \qquad V = XW_V$$

Here $W_Q$, $W_K$, and $W_V$ are learned matrices. They define what each token asks for, what each token exposes for comparison, and what information each token contributes.

The output is

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

where $d_k$ is the key dimension (here equal to $D$).
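In code, this formula is only a few tensor operations. Here is a minimal functional sketch (the class-based layer later in this section adds the learned projections):

```python
import math

import torch
import torch.nn.functional as F


def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # q, k, v: [..., T, d_k]
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])  # [..., T, T]
    return F.softmax(scores, dim=-1) @ v                       # [..., T, d_k]
```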
Self-attention changes each token representation from a local embedding into a context-dependent embedding.
Why Self-Attention Matters
A token often cannot be understood alone. Its meaning depends on nearby and distant tokens.
Consider:
*The bank raised interest rates.*

and:

*The boat reached the river bank.*

The token *bank* has different meanings in the two sentences. A good representation of *bank* must depend on the surrounding words. Self-attention gives the model a direct way to combine information from the whole sequence.
Unlike recurrent networks, self-attention does not process tokens strictly from left to right. All token pairs can interact in parallel. This makes transformer training efficient on GPUs and accelerators.
Unlike convolutional networks, self-attention has a global receptive field from the first layer. A token can attend to any other visible token in one operation.
Pairwise Token Interaction
For a sequence of length $T$, self-attention computes a score between every query token and every key token.

If $Q, K \in \mathbb{R}^{T \times d_k}$, the score matrix has shape $T \times T$:

$$S = \frac{QK^\top}{\sqrt{d_k}}$$

The entry $S_{ij}$ measures how strongly token $i$ attends to token $j$.

After softmax, each row becomes a probability distribution:

$$A_{ij} = \frac{\exp(S_{ij})}{\sum_{k=1}^{T} \exp(S_{ik})}$$

The output for token $i$ is

$$z_i = \sum_{j=1}^{T} A_{ij}\, v_j$$

Thus each output token is a weighted mixture of value vectors from the sequence.
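A tiny worked example (with made-up score values for $T = 3$) shows the row-wise normalization:

```python
import torch

S = torch.tensor([[2.0, 0.0, 0.0],
                  [0.0, 2.0, 0.0],
                  [1.0, 1.0, 1.0]])   # pairwise scores, already scaled
A = torch.softmax(S, dim=-1)          # normalize each row
print(A[2])                           # tensor([0.3333, 0.3333, 0.3333])
print(A.sum(dim=-1))                  # tensor([1., 1., 1.])
```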
Tensor Shapes in PyTorch
A standard input tensor has shape `[B, T, D]`. A simple self-attention layer preserves this shape:

`[B, T, D] -> [B, T, D]`

Internally, the score tensor has shape `[B, T, T]`. For each batch item, it stores all pairwise token scores.
A minimal implementation:
```python
import math

import torch
from torch import nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
    def __init__(self, d_model: int):
        super().__init__()
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor):
        # x: [B, T, D]
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        d_k = q.shape[-1]
        scores = q @ k.transpose(-2, -1)  # [B, T, T]
        scores = scores / math.sqrt(d_k)
        weights = F.softmax(scores, dim=-1)
        z = weights @ v                   # [B, T, D]
        return self.out_proj(z), weights
```

Usage:

```python
x = torch.randn(8, 16, 64)
attn = SelfAttention(d_model=64)
z, weights = attn(x)
print(z.shape)        # torch.Size([8, 16, 64])
print(weights.shape)  # torch.Size([8, 16, 16])
```

The output has the same sequence length as the input. The weight matrix shows how each token reads from all tokens.
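As a quick sanity check (continuing the usage example above), each row of `weights` should sum to 1, since softmax normalizes over the key dimension:

```python
row_sums = weights.sum(dim=-1)                      # [B, T]
print(torch.allclose(row_sums, torch.ones(8, 16)))  # True
```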
Learned Projections
The projections $W_Q$, $W_K$, and $W_V$ are important. Self-attention does not compare the raw input vectors directly. It compares learned query and key representations.

For an input vector $x_i$, the model computes:

$$q_i = W_Q x_i, \qquad k_i = W_K x_i, \qquad v_i = W_V x_i$$
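As a concrete check of these equations (an illustrative sketch reusing the `attn` instance from the usage example above), the per-token projection matches what `nn.Linear` computes on the whole tensor; note that `nn.Linear` stores its weight as `[out, in]`:

```python
x = torch.randn(2, 4, 64)
manual_q = x @ attn.q_proj.weight.T + attn.q_proj.bias      # q_i = W_Q x_i + b_Q
print(torch.allclose(manual_q, attn.q_proj(x), atol=1e-6))  # True
```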
These projections let the model separate three roles:
| Projection | Role |
|---|---|
| Query | What this position wants to find |
| Key | How this position can be found |
| Value | What this position contributes |
For example, in a language model, one attention head may learn patterns related to subject-verb agreement. Another may learn local phrase structure. Another may track delimiter matching. These roles are learned from data through gradient descent.
Self-Attention as Message Passing
Self-attention can also be viewed as message passing on a fully connected graph. Each token is a node. Each token sends a value vector to every other token. Attention weights determine how strongly each message is received.
For token $i$, the update is

$$z_i = \sum_{j} A_{ij}\, v_j$$
This resembles graph neural network message passing, except the graph is dense and the edge weights are computed dynamically from the token representations.
This view is useful because it explains why self-attention is flexible. The model does not need a fixed neighborhood. It constructs a data-dependent interaction pattern at every layer.
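To make the graph view concrete, here is an explicit per-node aggregation loop (illustrative only; it computes the same result as the batched `weights @ v`):

```python
def message_passing(weights: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # weights: [T, T] attention matrix, v: [T, D] value vectors (one sequence)
    T, D = v.shape
    z = torch.zeros(T, D)
    for i in range(T):                    # receiving node
        for j in range(T):                # sending node
            z[i] += weights[i, j] * v[j]  # message from j to i
    return z

w = torch.softmax(torch.randn(5, 5), dim=-1)
v = torch.randn(5, 6)
print(torch.allclose(message_passing(w, v), w @ v, atol=1e-6))  # True
```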
Bidirectional and Causal Self-Attention
There are two major forms of self-attention.
Bidirectional self-attention allows each token to attend to all tokens in the sequence. This is common in encoder models such as BERT-style systems. It is useful when the whole input is available.
Causal self-attention allows each token to attend only to itself and previous tokens. This is common in decoder-only language models. It is required for next-token prediction, where future tokens must remain hidden.
For causal self-attention, the attention matrix is lower triangular:

$$A_{ij} = 0 \quad \text{for all } j > i$$
In PyTorch, causal masking can be done manually. A correct implementation must apply the mask to the scores before the softmax; masking the weights afterward would leave each row incorrectly normalized:
```python
class CausalSelfAttention(nn.Module):
    def __init__(self, d_model: int):
        super().__init__()
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor):
        # x: [B, T, D]
        B, T, D = x.shape
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        scores = q @ k.transpose(-2, -1)
        scores = scores / math.sqrt(D)
        # True above the diagonal marks the future positions to hide.
        mask = torch.triu(
            torch.ones(T, T, device=x.device, dtype=torch.bool),
            diagonal=1,
        )
        scores = scores.masked_fill(mask, float("-inf"))
        weights = F.softmax(scores, dim=-1)
        z = weights @ v
        return self.out_proj(z), weights
```
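As a quick check (an illustrative sketch), the resulting weight matrix should be lower triangular, with zero weight on every future position:

```python
layer = CausalSelfAttention(d_model=64)
x = torch.randn(2, 10, 64)
z, weights = layer(x)
# Entries above the diagonal were set to -inf before softmax, so they are exactly 0.
upper = torch.triu(torch.ones(10, 10, dtype=torch.bool), diagonal=1)
print(weights[:, upper].abs().max())  # tensor(0.)
```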
For production code, PyTorch's built-in scaled dot-product attention is preferred:

```python
z = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```

Padding Masks
Batches often contain sequences of different lengths. To process them together, we pad short sequences with a special padding token.
For example:
```
[the, cat, sleeps, <pad>, <pad>]
[the, small, dog, runs, fast]
```

The model should not use the padding tokens as meaningful context. A padding mask hides them from attention.
Suppose `attention_mask` has shape `[B, T]`, where 1 marks a real token and 0 marks padding. We can convert it into a mask over keys:
```python
def apply_padding_mask(scores: torch.Tensor, attention_mask: torch.Tensor):
    # scores: [B, T, T]
    # attention_mask: [B, T], 1 for real tokens, 0 for padding
    key_mask = attention_mask[:, None, :].bool()  # [B, 1, T], broadcast over queries
    scores = scores.masked_fill(~key_mask, float("-inf"))
    return scores
```

The mask is applied before softmax. After softmax, padding positions receive zero attention weight.
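A short illustrative example (made-up scores, using the padded sentence from above with $T = 5$):

```python
scores = torch.randn(2, 5, 5)
attention_mask = torch.tensor([
    [1, 1, 1, 0, 0],   # "the cat sleeps <pad> <pad>"
    [1, 1, 1, 1, 1],   # "the small dog runs fast"
])
weights = torch.softmax(apply_padding_mask(scores, attention_mask), dim=-1)
print(weights[0, :, 3:])   # the two <pad> key columns receive zero weight
```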
Self-Attention With Masks
A general self-attention layer often needs both causal and padding masks. Padding masks remove invalid tokens. Causal masks remove future tokens.
```python
class MaskedSelfAttention(nn.Module):
    def __init__(self, d_model: int):
        super().__init__()
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        causal: bool = False,
    ):
        # x: [B, T, D]
        B, T, D = x.shape
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        scores = q @ k.transpose(-2, -1)
        scores = scores / math.sqrt(D)
        if attention_mask is not None:
            # attention_mask: [B, T], 1 for real token, 0 for pad
            key_mask = attention_mask[:, None, :].bool()
            scores = scores.masked_fill(~key_mask, float("-inf"))
        if causal:
            causal_mask = torch.triu(
                torch.ones(T, T, device=x.device, dtype=torch.bool),
                diagonal=1,
            )
            scores = scores.masked_fill(causal_mask, float("-inf"))
        weights = F.softmax(scores, dim=-1)
        z = weights @ v
        return self.out_proj(z), weights
```

This implementation is useful for study. Real transformer code usually folds these details into optimized attention kernels.
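A short usage sketch (the mask values are illustrative) combining both kinds of masking:

```python
layer = MaskedSelfAttention(d_model=64)
x = torch.randn(2, 5, 64)
attention_mask = torch.tensor([
    [1, 1, 1, 0, 0],   # last two positions are padding
    [1, 1, 1, 1, 1],
])
z, weights = layer(x, attention_mask=attention_mask, causal=True)
print(z.shape)         # torch.Size([2, 5, 64])
print(weights[0, 2])   # zero weight on future tokens (3, 4) and on padding
```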
Limitations of Single-Head Self-Attention
The simple self-attention layer above uses one attention pattern. That means each token forms one weighted mixture of the sequence.
This is limiting. A token may need several kinds of context at once. In a sentence, it may need syntactic context, semantic context, positional context, and coreference context. One attention distribution may be too narrow.
Multi-head attention addresses this by running several attention mechanisms in parallel. Each head has its own projections and its own attention weights. The results are concatenated and projected back into the model dimension.
This is the topic of Section 15.3.
Computational Cost
Self-attention requires all pairwise token comparisons. For sequence length $T$, the attention matrix has $T^2$ entries.

The main cost is

$$O(T^2 \cdot D)$$

time to compute the scores and the weighted sums. The memory for the attention weights is

$$O(T^2)$$

per batch item.
This quadratic scaling is the main weakness of standard self-attention. It is efficient for moderate sequence lengths and expensive for very long contexts.
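To see the scaling concretely, here is a small back-of-the-envelope sketch (float32 weights, one head, batch size 1; the numbers are illustrative):

```python
for T in [1_024, 8_192, 65_536]:
    n_bytes = T * T * 4  # one [T, T] float32 attention matrix
    print(f"T={T:>6}: {n_bytes / 2**30:.2f} GiB")
# T=  1024: 0.00 GiB
# T=  8192: 0.25 GiB
# T= 65536: 16.00 GiB
```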
For this reason, long-context models often modify attention. Common strategies include local attention, sparse attention, sliding windows, memory tokens, recurrence, linear attention, and state-space layers.
Summary
Self-attention lets each position in a sequence read from other positions in the same sequence. It computes queries, keys, and values from one input tensor, forms pairwise scores, normalizes them with softmax, and combines value vectors.
Bidirectional self-attention is used when the full input is visible. Causal self-attention is used for autoregressive generation. Padding masks prevent the model from reading artificial padding tokens.
The result is a context-dependent sequence representation. This mechanism is the core building block of transformer models.