
Self-Attention

Self-attention is attention applied within a single sequence. The same input supplies the queries, keys, and values. Each position builds a new representation by reading from other positions in the same sequence.

Given an input tensor

$$X \in \mathbb{R}^{B \times T \times D},$$

where $B$ is the batch size, $T$ is the sequence length, and $D$ is the model dimension, self-attention first projects $X$ into three tensors:

$$Q = XW_Q,\quad K = XW_K,\quad V = XW_V.$$

Here $W_Q$, $W_K$, and $W_V$ are learned matrices. They define what each token asks for, what each token exposes for comparison, and what information each token contributes.

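The output is

$$Z = \operatorname{softmax}\left( \frac{QK^\top}{\sqrt{d_k}} \right)V,$$

where $d_k$ is the dimension of the query and key vectors and the softmax is taken row-wise, over the keys. Dividing by $\sqrt{d_k}$ keeps the scale of the dot products roughly independent of the dimension, so the softmax does not saturate.

Self-attention turns each token's representation from a local embedding into a context-dependent one.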
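The formula can be written directly with tensor operations. The sketch below is illustrative rather than the layer used later in this section: it assumes bias-free projection matrices (the names `W_q`, `W_k`, `W_v` are hypothetical) of shape [D, D], so $d_k = D$, and uses random inputs.

```python
import math
import torch

def self_attention(x, W_q, W_k, W_v):
    # x: [B, T, D]; each W: [D, D] (bias-free projections, an assumption)
    q = x @ W_q                                          # [B, T, D]
    k = x @ W_k                                          # [B, T, D]
    v = x @ W_v                                          # [B, T, D]
    d_k = q.shape[-1]
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)    # [B, T, T]
    weights = torch.softmax(scores, dim=-1)              # softmax over keys
    return weights @ v, weights                          # [B, T, D], [B, T, T]

torch.manual_seed(0)
B, T, D = 2, 5, 8
x = torch.randn(B, T, D)
W_q, W_k, W_v = (torch.randn(D, D) / math.sqrt(D) for _ in range(3))
z, weights = self_attention(x, W_q, W_k, W_v)
print(z.shape, weights.shape)  # torch.Size([2, 5, 8]) torch.Size([2, 5, 5])
```

Each of the $B$ weight matrices has one row per query token, and every row sums to 1.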

Why Self-Attention Matters

A token often cannot be understood alone. Its meaning depends on nearby and distant tokens.

Consider:

The bank raised interest rates.

and:

The boat reached the river bank.

The token "bank" has different meanings in the two sentences. A good representation of "bank" must depend on the surrounding words. Self-attention gives the model a direct way to combine information from the whole sequence.

Unlike recurrent networks, self-attention does not process tokens strictly from left to right. All token pairs can interact in parallel. This makes transformer training efficient on GPUs and other accelerators.

Unlike convolutional networks, self-attention has a global receptive field from the first layer. A token can attend to any other visible token in one operation.

Pairwise Token Interaction

For a sequence of length $T$, self-attention computes a score between every query token and every key token.

If $T=4$, the score matrix has shape $4 \times 4$:

$$S = \begin{bmatrix} s_{11} & s_{12} & s_{13} & s_{14} \\ s_{21} & s_{22} & s_{23} & s_{24} \\ s_{31} & s_{32} & s_{33} & s_{34} \\ s_{41} & s_{42} & s_{43} & s_{44} \end{bmatrix}.$$

The entry $s_{ij} = q_i k_j^\top / \sqrt{d_k}$ compares the query of token $i$ with the key of token $j$; it measures how strongly token $i$ attends to token $j$.

After softmax, each row becomes a probability distribution:

$$a_{ij} = \frac{\exp(s_{ij})}{\sum_{l=1}^{T} \exp(s_{il})}, \qquad \sum_{j=1}^{T} a_{ij} = 1.$$

The output for token $i$ is

$$z_i = \sum_{j=1}^{T} a_{ij} v_j.$$

Thus each output token is a weighted mixture of value vectors from the sequence.
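To connect the per-token formula with the matrix form, the sketch below computes each $z_i = \sum_j a_{ij} v_j$ with an explicit loop and checks it against the single product `a @ v`. Random scores and values stand in for real projections, and a single sequence (no batch dimension) is assumed for readability.

```python
import torch

torch.manual_seed(0)
T, d_v = 4, 3
scores = torch.randn(T, T)            # s_ij: raw scores for one sequence
v = torch.randn(T, d_v)               # value vectors v_j, one per row
a = torch.softmax(scores, dim=-1)     # a_ij: each row sums to 1

# Per-token form: z_i = sum_j a_ij * v_j
z_loop = torch.stack([
    sum(a[i, j] * v[j] for j in range(T))
    for i in range(T)
])

# Matrix form: all rows at once
z_matrix = a @ v

print(torch.allclose(z_loop, z_matrix, atol=1e-6))  # True
```

The matrix form is what implementations use, since it computes every $z_i$ in one batched operation.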

Tensor Shapes in PyTorch

A standard input tensor has shape:

[B, T, D]

A simple self-attention layer preserves this shape:

[B, T, D] -> [B, T, D]

Internally, the score tensor has shape:

[B, T, T]

For each batch item, it stores all pairwise token scores.

A minimal implementation:

import math
import torch
from torch import nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    def __init__(self, d_model: int):
        super().__init__()
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

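    def forward(self, x: torch.Tensor):
        # x: [B, T, D]
        q = self.q_proj(x)  # queries: [B, T, D]
        k = self.k_proj(x)  # keys:    [B, T, D]
        v = self.v_proj(x)  # values:  [B, T, D]

        d_k = q.shape[-1]   # single head, so d_k = D
        scores = q @ k.transpose(-2, -1)  # [B, T, T]
        scores = scores / math.sqrt(d_k)

        weights = F.softmax(scores, dim=-1)  # each row sums to 1 over the keys
        z = weights @ v                      # [B, T, D]

        return self.out_proj(z), weights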

Usage:

x = torch.randn(8, 16, 64)

attn = SelfAttention(d_model=64)
z, weights = attn(x)

print(z.shape)       # torch.Size([8, 16, 64])
print(weights.shape) # torch.Size([8, 16, 16])

The output has the same sequence length as the input. The weight matrix shows how each token reads from all tokens.

Learned Projections

The projections $W_Q$, $W_K$, and $W_V$ matter because self-attention does not compare the raw input vectors directly. It compares learned query and key representations.

For an input row vector $x_i$, the model computes:

$$q_i = x_i W_Q, \quad k_i = x_i W_K, \quad v_i = x_i W_V.$$

These projections let the model separate three roles:

| Projection | Role |
|---|---|
| Query | What this position wants to find |
| Key | How this position can be found |
| Value | What this position contributes |

For example, in a language model, one attention head may learn patterns related to subject-verb agreement. Another may learn local phrase structure. Another may track delimiter matching. These roles are learned from data through gradient descent.
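One consequence of using separate query and key projections is that the score matrix need not be symmetric: $s_{ij}$ generally differs from $s_{ji}$, so token $i$ can attend strongly to token $j$ without the reverse holding. Comparing raw inputs with $x_i x_j^\top$ would always give a symmetric matrix. The sketch below checks this with random, bias-free projection matrices standing in for learned ones (an illustrative assumption).

```python
import math
import torch

torch.manual_seed(0)
T, D = 4, 8
x = torch.randn(T, D)                       # one sequence, no batch dimension
W_q = torch.randn(D, D) / math.sqrt(D)      # hypothetical learned projections
W_k = torch.randn(D, D) / math.sqrt(D)

raw = x @ x.T                                       # compare raw inputs directly
scores = (x @ W_q) @ (x @ W_k).T / math.sqrt(D)     # compare queries with keys

print(torch.allclose(raw, raw.T))         # True: raw similarity is symmetric
print(torch.allclose(scores, scores.T))   # False: s_ij != s_ji in general
```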

Self-Attention as Message Passing

Self-attention can also be viewed as message passing on a fully connected graph. Each token is a node. Each token sends its value vector to every token, including itself. Attention weights determine how strongly each message is received.

For token $i$, the update is

$$z_i = \sum_j a_{ij} v_j.$$

This resembles graph neural network message passing, except the graph is dense and the edge weights are computed dynamically from the token representations.

This view is useful because it explains why self-attention is flexible. The model does not need a fixed neighborhood. It constructs a data-dependent interaction pattern at every layer.
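The message-passing view can be made literal. The sketch below lists every directed edge $j \to i$ of the fully connected graph, scales each value vector $v_j$ by its edge weight $a_{ij}$, and sums the incoming messages at each node with `Tensor.index_add_`. It is far slower than a matrix product and only meant to show that the two views compute the same thing; the random weights and values are placeholders.

```python
import torch

torch.manual_seed(0)
T, d_v = 5, 4
a = torch.softmax(torch.randn(T, T), dim=-1)   # edge weights a_ij (rows sum to 1)
v = torch.randn(T, d_v)                        # value vector v_j sent by node j

# Every directed edge j -> i of the fully connected graph (self-loops included)
dst = torch.arange(T).repeat_interleave(T)     # receiving node i
src = torch.arange(T).repeat(T)                # sending node j

messages = a[dst, src].unsqueeze(-1) * v[src]          # message a_ij * v_j per edge
z = torch.zeros(T, d_v).index_add_(0, dst, messages)   # sum messages at receivers

print(torch.allclose(z, a @ v, atol=1e-6))     # True
```

A sparse attention pattern would correspond to keeping only some of these edges.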

Bidirectional and Causal Self-Attention

There are two major forms of self-attention.

Bidirectional self-attention allows each token to attend to all tokens in the sequence. This is common in encoder models such as BERT-style systems. It is useful when the whole input is available.

Causal self-attention allows each token to attend only to itself and previous tokens. This is common in decoder-only language models. It is required for next-token prediction, where future tokens must remain hidden.

For causal self-attention, the attention matrix is lower triangular:

$$a_{ij} = 0 \quad \text{when } j > i.$$
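A small example makes the lower-triangular structure visible. The sketch below uses the same `torch.triu(..., diagonal=1)` construction as the causal implementation in this section to mark the future positions $j > i$, fills those scores with $-\infty$, and applies softmax. Random scores are an illustrative assumption. The first token can only attend to itself, so its row is exactly $[1, 0, 0, 0]$.

```python
import torch

torch.manual_seed(0)
T = 4
scores = torch.randn(T, T)

# True above the diagonal: position j is in the future of position i
future = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
weights = torch.softmax(scores.masked_fill(future, float("-inf")), dim=-1)

print(future.int())
# tensor([[0, 1, 1, 1],
#         [0, 0, 1, 1],
#         [0, 0, 0, 1],
#         [0, 0, 0, 0]], dtype=torch.int32)
print(weights[0])  # tensor([1., 0., 0., 0.]): first token sees only itself
```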

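In PyTorch, causal masking can be done manually, but it must happen inside the attention computation. Wrapping the unmasked SelfAttention layer, as below, does not enforce causality:

def causal_self_attention(x: torch.Tensor, layer: SelfAttention):
    # Incorrect: SelfAttention applies no mask, so every token
    # still attends to future tokens.
    z, weights = layer(x)
    return z, weights

A correct causal implementation must apply the mask before softmax: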

class CausalSelfAttention(nn.Module):
    def __init__(self, d_model: int):
        super().__init__()
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor):
        # x: [B, T, D]
        B, T, D = x.shape

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        scores = q @ k.transpose(-2, -1)
        scores = scores / math.sqrt(D)

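        # True above the diagonal marks future positions (j > i) to hide
        mask = torch.triu(
            torch.ones(T, T, device=x.device, dtype=torch.bool),
            diagonal=1,
        )

        # Masking before softmax gives future positions exactly zero weight
        scores = scores.masked_fill(mask, float("-inf"))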

        weights = F.softmax(scores, dim=-1)
        z = weights @ v

        return self.out_proj(z), weights

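For production code, PyTorch's built-in F.scaled_dot_product_attention (PyTorch 2.0 and later) is preferred. Given the projected q, k, and v, it applies the $1/\sqrt{d_k}$ scaling and the causal mask internally and can dispatch to fused kernels:

z = F.scaled_dot_product_attention(q, k, v, is_causal=True)

It returns only the attended values, not the attention weights, and the output projection is still applied afterward.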
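To check that the built-in call matches the manual version, the sketch below computes causal attention both ways on random tensors standing in for the outputs of `q_proj`, `k_proj`, and `v_proj`. It assumes PyTorch 2.0 or later, where `F.scaled_dot_product_attention` uses the default scale $1/\sqrt{d_k}$ with $d_k$ taken from the last dimension of `q`.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, D = 2, 6, 16
q, k, v = (torch.randn(B, T, D) for _ in range(3))  # stand-ins for projected tensors

# Manual causal attention, as in CausalSelfAttention (without out_proj)
scores = q @ k.transpose(-2, -1) / math.sqrt(D)
future = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
weights = F.softmax(scores.masked_fill(future, float("-inf")), dim=-1)
z_manual = weights @ v

# Built-in version (PyTorch >= 2.0)
z_builtin = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(z_manual, z_builtin, atol=1e-5))  # True
```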

Padding Masks

Batches often contain sequences of different lengths. To process them together, we pad short sequences with a special padding token.

For example:

[the, cat, sleeps, <pad>, <pad>]
[the, small, dog, runs, fast]

The model should not use the padding tokens as meaningful context. A padding mask hides them from attention.

Suppose attention_mask has shape [B, T], where 1 means real token and 0 means padding. We can convert it into a mask over keys:

def apply_padding_mask(scores: torch.Tensor, attention_mask: torch.Tensor):
    # scores: [B, T, T]
    # attention_mask: [B, T], 1 for real token, 0 for padding
    key_mask = attention_mask[:, None, :].bool()  # [B, 1, T], broadcast over queries
    scores = scores.masked_fill(~key_mask, float("-inf"))  # hide padded keys
    return scores

The mask is applied before softmax. After softmax, padding positions receive zero attention weight.
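Here is the two-sentence batch from above in code. The sketch repeats `apply_padding_mask` so the block runs on its own, uses random scores in place of real query-key products, and checks that the two `<pad>` positions in the first sentence receive zero weight from every query. One caveat: if every key in a row is masked, softmax over all $-\infty$ values returns NaN, so each row must keep at least one real token.

```python
import torch
import torch.nn.functional as F

def apply_padding_mask(scores: torch.Tensor, attention_mask: torch.Tensor):
    # scores: [B, T, T]; attention_mask: [B, T], 1 = real token, 0 = padding
    key_mask = attention_mask[:, None, :].bool()   # [B, 1, T], broadcast over queries
    return scores.masked_fill(~key_mask, float("-inf"))

# [the, cat, sleeps, <pad>, <pad>] and [the, small, dog, runs, fast]
attention_mask = torch.tensor([
    [1, 1, 1, 0, 0],
    [1, 1, 1, 1, 1],
])

torch.manual_seed(0)
scores = torch.randn(2, 5, 5)
weights = F.softmax(apply_padding_mask(scores, attention_mask), dim=-1)

print(weights[0, :, 3:])  # all zeros: no query reads from the <pad> keys
```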

Self-Attention With Masks

A general self-attention layer often needs both causal and padding masks. Padding masks remove invalid tokens. Causal masks remove future tokens.

class MaskedSelfAttention(nn.Module):
    def __init__(self, d_model: int):
        super().__init__()
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        causal: bool = False,
    ):
        # x: [B, T, D]
        B, T, D = x.shape

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        scores = q @ k.transpose(-2, -1)
        scores = scores / math.sqrt(D)

        if attention_mask is not None:
            # attention_mask: [B, T], 1 for real token, 0 for pad
            key_mask = attention_mask[:, None, :].bool()
            scores = scores.masked_fill(~key_mask, float("-inf"))

        if causal:
            causal_mask = torch.triu(
                torch.ones(T, T, device=x.device, dtype=torch.bool),
                diagonal=1,
            )
            scores = scores.masked_fill(causal_mask, float("-inf"))

        weights = F.softmax(scores, dim=-1)
        z = weights @ v

        return self.out_proj(z), weights

This implementation is useful for study. Real transformer code usually folds these details into optimized attention kernels.
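As an example of handing both masks to an optimized kernel, the sketch below merges the padding and causal masks into one boolean `attn_mask` and passes it to `F.scaled_dot_product_attention`. For boolean masks, PyTorch treats `True` as "may attend", the opposite convention from the `masked_fill` masks above, and it raises an error if both `attn_mask` and `is_causal=True` are given, so the causal part is folded into the mask. Random tensors stand in for projected queries, keys, and values; PyTorch 2.0 or later is assumed.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, D = 2, 5, 8
q, k, v = (torch.randn(B, T, D) for _ in range(3))   # stand-ins for projections
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])     # 1 = real token, 0 = padding

# True means "allowed to attend" for SDPA boolean masks
key_ok = attention_mask[:, None, :].bool()               # [B, 1, T]
past_ok = torch.tril(torch.ones(T, T, dtype=torch.bool)) # [T, T], j <= i
allowed = key_ok & past_ok                               # [B, T, T]

z_fused = F.scaled_dot_product_attention(q, k, v, attn_mask=allowed)

# Same computation with the manual masked_fill approach
scores = q @ k.transpose(-2, -1) / math.sqrt(D)
weights = F.softmax(scores.masked_fill(~allowed, float("-inf")), dim=-1)
z_manual = weights @ v

print(torch.allclose(z_fused, z_manual, atol=1e-5))  # True
```

With right padding, every row still allows at least key 0, so no row is fully masked.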

Limitations of Single-Head Self-Attention

The simple self-attention layer above uses one attention pattern. That means each token forms one weighted mixture of the sequence.

This is limiting. A token may need several kinds of context at once. In a sentence, it may need syntactic context, semantic context, positional context, and coreference context. One attention distribution may be too narrow.

Multi-head attention addresses this by running several attention mechanisms in parallel. Each head has its own projections and its own attention weights. The results are concatenated and projected back into the model dimension.

This is the topic of Section 15.3.

Computational Cost

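Self-attention requires all pairwise token comparisons. For sequence length $T$, the attention matrix has $T^2$ entries per batch item.

Computing the scores $QK^\top$ and the weighted sum with $V$ each takes about $BT^2D$ multiply-adds, so the attention step costs

$$O(BT^2D).$$

The query, key, value, and output projections add $O(BTD^2)$, so the $T^2$ term dominates once $T$ exceeds $D$.

The memory for attention weights is

$$O(BT^2).$$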

This quadratic scaling is the main weakness of standard self-attention. It is efficient for moderate sequence lengths and expensive for very long contexts.
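A quick calculation shows how fast the quadratic term grows. The sketch below counts the bytes needed to store one [B, T, T] attention-weight tensor, assuming 4-byte float32 entries and a single head in a single layer; `attention_weight_bytes` is a hypothetical helper written for this illustration.

```python
def attention_weight_bytes(B: int, T: int, bytes_per_entry: int = 4) -> int:
    # One weight per (batch item, query, key) triple: B * T * T entries
    return B * T * T * bytes_per_entry

for T in (1024, 2048, 4096):
    mib = attention_weight_bytes(B=8, T=T) / 2**20
    print(f"T={T:5d}: {mib:8.1f} MiB")
# T= 1024:     32.0 MiB
# T= 2048:    128.0 MiB
# T= 4096:    512.0 MiB
```

Each doubling of $T$ multiplies the memory by four.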

For this reason, long-context models often modify attention. Common strategies include local attention, sparse attention, sliding windows, memory tokens, recurrence, linear attention, and state-space layers.

Summary

Self-attention lets each position in a sequence read from other positions in the same sequence. It computes queries, keys, and values from one input tensor, forms pairwise scores, normalizes them with softmax, and combines value vectors.

Bidirectional self-attention is used when the full input is visible. Causal self-attention is used for autoregressive generation. Padding masks prevent the model from reading artificial padding tokens.

The result is a context-dependent sequence representation. This mechanism is the core building block of transformer models.