
Efficient Attention Methods

Standard self-attention compares every token with every other token. For a sequence of length $T$, this produces a $T \times T$ attention matrix. The cost grows quadratically with sequence length.


For short and medium sequences, this is acceptable. For long documents, audio streams, videos, high-resolution images, and long-context language models, quadratic attention becomes the main bottleneck.

The standard attention operation is

$$\operatorname{Attention}(Q,K,V) = \operatorname{softmax} \left( \frac{QK^\top}{\sqrt{d}} \right) V.$$

If

$$Q, K, V \in \mathbb{R}^{B \times H \times T \times d},$$

then the score matrix has shape

$$[B, H, T, T].$$

This section studies methods that reduce memory, reduce compute, or improve kernel efficiency.

The Cost of Standard Attention

For each head, attention computes all pairwise query-key dot products. The main score computation has cost

$$O(T^2 d).$$

For batch size $B$ and number of heads $H$, the cost is

$$O(B H T^2 d).$$

The attention weights require memory

$$O(B H T^2).$$

This memory cost is often more restrictive than arithmetic cost. During training, intermediate tensors must be saved for backpropagation. Long sequences therefore increase activation memory sharply.

For example, with $B=1$, $H=32$, and $T=32768$, the attention matrix contains

$$32 \times 32768^2$$

entries. That is more than 34 billion attention scores before accounting for data type, gradients, and other activations.
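A quick back-of-the-envelope calculation (a minimal sketch, assuming float16 storage for the scores) makes the scale concrete:

# Rough size of the attention score matrix for the example above (illustrative only).
B, H, T = 1, 32, 32768
bytes_per_entry = 2  # assumption: float16 scores

entries = B * H * T * T
gib = entries * bytes_per_entry / 1024**3

print(f"{entries:,} scores")        # 34,359,738,368
print(f"{gib:.0f} GiB in float16")  # 64 GiB, before gradients and other activations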

Efficient attention methods attack this problem in different ways.

Efficient Attention Has Several Meanings

The phrase "efficient attention" can mean different things.

| Method type | Main idea |
| --- | --- |
| Kernel-efficient exact attention | Compute exact attention with less memory traffic |
| Sparse attention | Attend to only selected positions |
| Local attention | Attend to nearby windows |
| Linear attention | Avoid explicitly forming the $T \times T$ matrix |
| Low-rank attention | Approximate attention through compressed representations |
| Memory attention | Attend to compressed memory tokens |
| KV cache optimization | Reduce autoregressive decoding cost |
| Quantized attention | Use lower precision for memory and speed |

These methods make different tradeoffs. Some preserve exact attention. Others change the model. Some help training. Others mainly help inference.

Flash Attention

Flash attention is an exact attention algorithm designed for modern GPUs. It computes the same result as standard attention, up to numerical precision, but avoids materializing the full attention matrix in high-bandwidth memory.

The key idea is tiling. Instead of computing and storing the full matrix

$$QK^\top,$$

the algorithm processes blocks of queries and keys. It keeps partial softmax statistics and partial outputs in fast on-chip memory. This reduces memory reads and writes.

Flash attention improves efficiency because standard attention is often memory-bandwidth bound. Moving the large attention matrix to and from GPU memory is expensive. Tiled attention reduces this traffic.

Conceptually, flash attention still computes

$$\operatorname{softmax} \left( \frac{QK^\top}{\sqrt{d}} \right) V.$$

It changes how the computation is scheduled, not the mathematical operation.

In PyTorch, optimized kernels may be used through:

import torch
import torch.nn.functional as F

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

The tensors usually have shape:

[B, H, T, d_head]

PyTorch selects an available backend depending on hardware, dtype, tensor shapes, masking, and configuration.

Flash attention is often the first efficient attention method to use because it preserves exact attention while improving speed and memory behavior.

Local Attention

Local attention restricts each token to a window around itself. Instead of attending to all $T$ tokens, each token attends to at most $w$ nearby tokens.

For token $i$, the visible keys are

$$j \in [i-r, i+r],$$

where $w = 2r + 1$.

The cost becomes approximately

$$O(Twd)$$

per head instead of

$$O(T^2 d).$$

Local attention is natural for images, audio, and many time series. Nearby elements are often more relevant than distant ones. It also resembles convolution, but with content-dependent weights.

For causal local attention, token $i$ attends only to recent history:

$$j \in [i-w+1, i].$$

This is useful in autoregressive models where long-range attention is expensive.
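One way to prototype this pattern (a minimal sketch; the window size, shapes, and helper name are illustrative, and the mask is still materialized as a full $T \times T$ boolean tensor, so this shows the pattern rather than the memory savings) is to pass a banded causal mask to scaled_dot_product_attention:

import torch
import torch.nn.functional as F

def causal_local_mask(T: int, w: int, device=None) -> torch.Tensor:
    # True where attention is allowed: j in [i - w + 1, i].
    i = torch.arange(T, device=device).unsqueeze(1)
    j = torch.arange(T, device=device).unsqueeze(0)
    return (j <= i) & (j > i - w)

q = torch.randn(2, 8, 1024, 64)
k = torch.randn(2, 8, 1024, 64)
v = torch.randn(2, 8, 1024, 64)

mask = causal_local_mask(T=1024, w=128)  # [T, T], broadcast over batch and heads
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)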

The limitation is clear: information cannot move globally in one layer. A distant token can influence another token only after many layers, or through special global tokens.

Sparse Attention

Sparse attention uses a structured subset of token pairs. Instead of a dense $T \times T$ attention matrix, it defines an attention graph.

Common sparse patterns include:

| Pattern | Description |
| --- | --- |
| Sliding window | Each token attends to nearby tokens |
| Dilated window | Tokens attend at spaced intervals |
| Global tokens | Special tokens attend everywhere |
| Block sparse | Attention is computed between selected blocks |
| Random sparse | Some connections are sampled |
| Routing-based | Tokens are grouped by learned or approximate similarity |

Sparse attention reduces cost by making the number of attended positions per token much smaller than $T$.

If each token attends to $s$ positions, the cost becomes roughly

$$O(Tsd).$$

Sparse attention changes the model’s inductive bias. It can work well when the sparse pattern matches the task. It can fail when important information lies outside the chosen pattern.

Global Tokens

A common way to improve local or sparse attention is to add global tokens. These tokens can attend to all positions, and all positions can attend to them.

For example, in a document encoder, a special classification token may attend globally. This allows information from the whole sequence to collect into one representation.

The attention pattern becomes a mixture:

| Token type | Visible positions |
| --- | --- |
| Local token | Nearby window plus global tokens |
| Global token | All tokens |

This gives the model a cheap communication path across the sequence. It does not fully match dense attention, but it often works well for classification and retrieval tasks.
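As a small illustration (a sketch; the window size and the choice of global positions are assumptions), this mixed pattern can be written as a boolean mask in which global positions both see and are seen by everyone:

import torch

def local_plus_global_mask(T: int, w: int, global_idx: list) -> torch.Tensor:
    i = torch.arange(T).unsqueeze(1)
    j = torch.arange(T).unsqueeze(0)
    mask = (j - i).abs() <= w // 2   # symmetric local window around each token
    mask[global_idx, :] = True       # global tokens attend everywhere
    mask[:, global_idx] = True       # every token attends to the global tokens
    return mask

mask = local_plus_global_mask(T=512, w=64, global_idx=[0])  # e.g. a classification token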

Linear Attention

Linear attention tries to avoid the explicit $T \times T$ attention matrix. It rewrites attention using a feature map $\phi$ so that the softmax kernel is approximated by an inner product:

$$\exp(q^\top k) \approx \phi(q)^\top \phi(k).$$

Then attention can be written approximately as

$$\operatorname{Attention}(Q,K,V)_i = \frac{ \phi(q_i)^\top \sum_j \phi(k_j) v_j^\top }{ \phi(q_i)^\top \sum_j \phi(k_j) }.$$

The sums over keys and values can be computed once, avoiding all pairwise comparisons.

The cost becomes roughly linear in sequence length:

$$O(T d^2)$$

or similar, depending on the feature dimension.

Linear attention is attractive for very long contexts. However, it changes the attention function. It may underperform exact softmax attention on tasks that need sharp retrieval or precise token selection.
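A minimal non-causal sketch of this idea, using the common $\operatorname{elu}(x) + 1$ feature map (the feature map, shapes, and function name are assumptions; other choices exist):

import torch
import torch.nn.functional as F

def linear_attention(q, k, v, eps: float = 1e-6):
    # q, k, v: [B, H, T, d]
    q = F.elu(q) + 1                               # phi(q), nonnegative
    k = F.elu(k) + 1                               # phi(k), nonnegative
    kv = torch.einsum("bhtd,bhte->bhde", k, v)     # sum_j phi(k_j) v_j^T
    z = k.sum(dim=2)                               # sum_j phi(k_j)
    num = torch.einsum("bhtd,bhde->bhte", q, kv)
    den = torch.einsum("bhtd,bhd->bht", q, z).unsqueeze(-1)
    return num / (den + eps)

out = linear_attention(torch.randn(2, 8, 4096, 64),
                       torch.randn(2, 8, 4096, 64),
                       torch.randn(2, 8, 4096, 64))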

Recurrent View of Causal Linear Attention

For causal linear attention, we can maintain running summaries:

$$S_t = S_{t-1} + \phi(k_t) v_t^\top, \qquad z_t = z_{t-1} + \phi(k_t).$$

Then the output at time $t$ is

$$y_t = \frac{\phi(q_t)^\top S_t}{\phi(q_t)^\top z_t}.$$

This creates a recurrent form. Each new token updates a fixed-size state. The model can process long streams without storing all past tokens.

The advantage is memory efficiency. The limitation is compression. The fixed-size state may lose details needed for exact retrieval.
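A step-by-step sketch of this recurrence for a single head (the $\operatorname{elu}+1$ feature map, shapes, and the Python loop are simplifying assumptions; practical implementations vectorize or chunk the loop):

import torch
import torch.nn.functional as F

def causal_linear_attention(q, k, v, eps: float = 1e-6):
    # q, k, v: [T, d] for a single head; the state has fixed size regardless of T.
    T, d = q.shape
    phi_q = F.elu(q) + 1
    phi_k = F.elu(k) + 1
    S = torch.zeros(d, v.shape[-1])   # running sum of phi(k_t) v_t^T
    z = torch.zeros(d)                # running sum of phi(k_t)
    outputs = []
    for t in range(T):
        S = S + torch.outer(phi_k[t], v[t])
        z = z + phi_k[t]
        y = (phi_q[t] @ S) / (phi_q[t] @ z + eps)
        outputs.append(y)
    return torch.stack(outputs)       # [T, d_v]

out = causal_linear_attention(torch.randn(128, 64),
                              torch.randn(128, 64),
                              torch.randn(128, 64))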

Low-Rank Attention

Low-rank attention approximates the full attention matrix using a smaller set of landmarks or projected tokens.

Instead of comparing all $T$ queries with all $T$ keys, the model projects the sequence into $r$ representative positions, where

$$r \ll T.$$

The attention computation then uses the compressed representation.

A rough pattern is:

$$K, V \in \mathbb{R}^{T \times d} \quad\rightarrow\quad \tilde{K}, \tilde{V} \in \mathbb{R}^{r \times d}.$$

Then queries attend to $\tilde{K}$ and $\tilde{V}$, reducing the cost from $T^2$ to $Tr$.
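One possible realization (a sketch, similar in spirit to Linformer; the learned projection over the sequence dimension, the module name, and the sizes are assumptions):

import torch
from torch import nn

class LowRankAttention(nn.Module):
    def __init__(self, T: int, r: int, d: int):
        super().__init__()
        self.proj = nn.Linear(T, r, bias=False)  # compresses T positions into r landmarks
        self.scale = d ** -0.5

    def forward(self, q, k, v):
        # q, k, v: [B, H, T, d]
        k_r = self.proj(k.transpose(-2, -1)).transpose(-2, -1)  # [B, H, r, d]
        v_r = self.proj(v.transpose(-2, -1)).transpose(-2, -1)  # [B, H, r, d]
        scores = (q @ k_r.transpose(-2, -1)) * self.scale       # [B, H, T, r]
        return scores.softmax(dim=-1) @ v_r                     # [B, H, T, d]

attn = LowRankAttention(T=4096, r=256, d=64)
out = attn(torch.randn(1, 8, 4096, 64),
           torch.randn(1, 8, 4096, 64),
           torch.randn(1, 8, 4096, 64))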

Low-rank methods are useful when the sequence has redundancy. They can be less effective for tasks requiring exact access to arbitrary individual tokens.

Memory Tokens and Compressed Context

Another approach is to compress earlier context into memory tokens. Instead of attending to every previous token, the model attends to a smaller memory representation.

For a long document, the model may split text into chunks. Each chunk produces memory vectors. Later chunks attend to those memory vectors rather than all previous tokens.

This gives a hierarchy:

tokens -> chunk summaries -> document memory

The cost depends on memory size, not full sequence length.
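A very rough sketch of how chunk summaries might be produced (mean pooling per chunk and the function name are simplifying assumptions; real systems typically use learned compression):

import torch

def chunk_memory(hidden: torch.Tensor, chunk_size: int) -> torch.Tensor:
    # hidden: [B, T, D] token representations of earlier context.
    B, T, D = hidden.shape
    n_chunks = T // chunk_size
    chunks = hidden[:, : n_chunks * chunk_size].view(B, n_chunks, chunk_size, D)
    return chunks.mean(dim=2)  # [B, n_chunks, D] memory vectors

memory = chunk_memory(torch.randn(2, 8192, 512), chunk_size=256)  # [2, 32, 512]
# Later chunks attend to `memory` instead of all 8192 earlier tokens.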

The tradeoff is information loss. Compression can preserve high-level meaning but may lose exact details. This matters for tasks such as code, math, citation lookup, and long-context retrieval.

KV Caching for Decoding

During autoregressive generation, each new token attends to all previous tokens. Without caching, the model recomputes keys and values for the entire context at every step.

KV caching stores keys and values from previous steps.

For each layer, the cache contains:

keys:   [B, H, T_past, d_head]
values: [B, H, T_past, d_head]

At a new step, the model computes only the new key and value, then appends them to the cache.
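A minimal single-step sketch for one layer (tensor names and the concatenation strategy are illustrative; production systems usually preallocate the cache):

import torch
import torch.nn.functional as F

def decode_step(q_new, k_new, v_new, k_cache, v_cache):
    # q_new, k_new, v_new: [B, H, 1, d_head] for the newly generated token.
    # k_cache, v_cache:    [B, H, T_past, d_head] from previous steps.
    k_cache = torch.cat([k_cache, k_new], dim=2)
    v_cache = torch.cat([v_cache, v_new], dim=2)
    # The single new query attends to every cached key and value.
    out = F.scaled_dot_product_attention(q_new, k_cache, v_cache)
    return out, k_cache, v_cache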

This reduces repeated computation during inference. However, the cache itself grows linearly with context length. Across $L$ layers, with head dimension $d_h$, the cache requires memory

$$O(B H L T d_h).$$

For long contexts and many layers, KV cache memory can dominate inference cost.

Grouped-Query and Multi-Query Attention

Standard multi-head attention has separate keys and values for each head. Grouped-query attention reduces KV cache size by sharing keys and values across groups of query heads.

Multi-query attention is the most aggressive form: all query heads share one set of keys and values.

| Method | Queries | Keys and values |
| --- | --- | --- |
| Multi-head attention | Many heads | Separate K/V per head |
| Grouped-query attention | Many heads | Shared K/V per group |
| Multi-query attention | Many heads | One shared K/V set |

These methods are especially useful for large language model inference because they reduce KV cache memory and memory bandwidth.

The tradeoff is reduced key-value diversity. In practice, grouped-query attention often gives a good balance between quality and inference efficiency.
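A sketch of the sharing mechanism (head counts are assumptions; newer PyTorch versions also accept grouped K/V directly via an enable_gqa flag on scaled_dot_product_attention, but the explicit expansion below works everywhere):

import torch
import torch.nn.functional as F

B, T, d_head = 2, 1024, 64
n_q_heads, n_kv_heads = 32, 8               # 4 query heads share each K/V head

q = torch.randn(B, n_q_heads, T, d_head)
k = torch.randn(B, n_kv_heads, T, d_head)   # the KV cache only stores 8 heads
v = torch.randn(B, n_kv_heads, T, d_head)

# Expand shared K/V heads at attention time so shapes match the query heads.
group = n_q_heads // n_kv_heads
k = k.repeat_interleave(group, dim=1)       # [B, n_q_heads, T, d_head]
v = v.repeat_interleave(group, dim=1)

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)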

Quantized Attention and KV Cache

Attention tensors can use lower precision to reduce memory. Training often uses float16 or bfloat16. Inference can quantize weights, activations, or KV cache entries.

KV cache quantization is particularly useful for long-context inference. Reducing keys and values from 16-bit to 8-bit or lower can greatly reduce memory usage.
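A simplified int8 sketch (the per-position absmax scaling scheme and function names are assumptions; real systems use more careful calibration and fused dequantization):

import torch

def quantize_int8(x: torch.Tensor):
    # One scale per position: absmax over the head dimension.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = (x / scale).round().clamp(-127, 127).to(torch.int8)
    return q, scale

def dequantize(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return q.float() * scale

k = torch.randn(1, 8, 4096, 64)
k_q, k_scale = quantize_int8(k)        # store int8 values plus a small tensor of scales
k_restored = dequantize(k_q, k_scale)  # dequantize before the attention computation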

The risk is numerical degradation. Attention is sensitive to dot products and softmax. Poor quantization can change attention distributions and reduce model quality.

A practical system measures the quality-speed-memory tradeoff rather than assuming quantization is free.

Efficient Attention in PyTorch

For many models, the best starting point is PyTorch’s built-in scaled dot-product attention:

import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)

out = F.scaled_dot_product_attention(
    q,
    k,
    v,
    is_causal=True,
)

This API allows PyTorch to use optimized attention kernels when supported.

A simple wrapper module:

import torch
from torch import nn
import torch.nn.functional as F

class EfficientSelfAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()

        if d_model % num_heads != 0:
            raise ValueError("d_model must be divisible by num_heads")

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads

        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor, causal: bool = False):
        # x: [B, T, D]
        B, T, D = x.shape

        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)

        q = q.view(B, T, self.num_heads, self.d_head).transpose(1, 2)
        k = k.view(B, T, self.num_heads, self.d_head).transpose(1, 2)
        v = v.view(B, T, self.num_heads, self.d_head).transpose(1, 2)

        # q, k, v: [B, H, T, d_head]
        out = F.scaled_dot_product_attention(
            q,
            k,
            v,
            is_causal=causal,
        )

        out = out.transpose(1, 2).contiguous().view(B, T, D)
        return self.out_proj(out)

This implementation avoids manually materializing attention weights in Python. It also gives the backend more freedom to select an optimized kernel.

Choosing an Efficient Attention Method

The correct method depends on the bottleneck.

| Problem | Useful method |
| --- | --- |
| Training uses too much attention memory | Flash attention, gradient checkpointing |
| Inference KV cache is too large | Grouped-query attention, multi-query attention, KV quantization |
| Sequence is very long but mostly local | Local or sliding-window attention |
| Need document-level classification | Sparse attention with global tokens |
| Need streaming sequence processing | Causal linear attention or recurrent memory |
| Need exact retrieval over long context | Exact attention, retrieval augmentation, memory hierarchy |
| Need lower latency | Optimized kernels, shorter context, KV caching |

Efficient attention should be chosen by measuring quality, latency, memory, and implementation complexity. A method that improves benchmark speed can still damage task performance if its attention pattern removes necessary information.

Summary

Standard attention has quadratic time and memory in sequence length. Efficient attention methods reduce this cost by improving kernels, restricting attention patterns, approximating attention, compressing context, or optimizing inference memory.

Flash attention keeps exact attention but computes it with better memory behavior. Local and sparse attention reduce the number of token pairs. Linear and low-rank attention approximate dense attention. Grouped-query and multi-query attention reduce KV cache cost during decoding.

For PyTorch models, torch.nn.functional.scaled_dot_product_attention is usually the first practical tool to use. More specialized methods should be selected only after identifying the actual bottleneck.