# Efficient Attention Methods

Standard self-attention compares every token with every other token. For a sequence of length $T$, this produces a $T \times T$ attention matrix. The cost grows quadratically with sequence length.

For short and medium sequences, this is acceptable. For long documents, audio streams, videos, high-resolution images, and long-context language models, quadratic attention becomes the main bottleneck.

The standard attention operation is

$$
\operatorname{Attention}(Q,K,V) =
\operatorname{softmax}
\left(
\frac{QK^\top}{\sqrt{d}}
\right)V.
$$

If

$$
Q,K,V \in \mathbb{R}^{B \times H \times T \times d},
$$

then the score matrix has shape

$$
[B,H,T,T].
$$

This section studies methods that reduce memory, reduce compute, or improve kernel efficiency.

### The Cost of Standard Attention

For each head, attention computes all pairwise query-key dot products. The main score computation has cost

$$
O(T^2d).
$$

For batch size $B$ and number of heads $H$, the cost is

$$
O(BHT^2d).
$$

The attention weights require memory

$$
O(BHT^2).
$$

This memory cost is often more restrictive than arithmetic cost. During training, intermediate tensors must be saved for backpropagation. Long sequences therefore increase activation memory sharply.

For example, with $B=1$, $H=32$, and $T=32768$, the attention matrix contains

$$
32 \times 32768^2
$$

entries. That is more than 34 billion attention scores before accounting for data type, gradients, and other activations.
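A quick back-of-the-envelope check makes the scale concrete. The sketch below assumes `float16` scores, at 2 bytes each:

```python
B, H, T = 1, 32, 32768

entries = B * H * T * T       # attention scores per forward pass
bytes_fp16 = entries * 2      # 2 bytes per float16 score

print(entries)                # 34_359_738_368 scores
print(bytes_fp16 / 2**30)     # 64.0 GiB for the scores alone
```

At this scale the score matrix alone would need about 64 GiB in `float16`, which is why it is never materialized directly.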

Efficient attention methods attack this problem in different ways.

### Efficient Attention Has Several Meanings

The phrase *efficient attention* can refer to several different families of methods.

| Method type | Main idea |
|---|---|
| Kernel-efficient exact attention | Compute exact attention with less memory traffic |
| Sparse attention | Attend to only selected positions |
| Local attention | Attend to nearby windows |
| Linear attention | Avoid explicitly forming the $T \times T$ matrix |
| Low-rank attention | Approximate attention through compressed representations |
| Memory attention | Attend to compressed memory tokens |
| KV cache optimization | Reduce autoregressive decoding cost |
| Quantized attention | Use lower precision for memory and speed |

These methods make different tradeoffs. Some preserve exact attention. Others change the model. Some help training. Others mainly help inference.

### Flash Attention

Flash attention is an exact attention algorithm designed for modern GPUs. It computes the same result as standard attention, up to numerical precision, but avoids materializing the full attention matrix in high-bandwidth memory.

The key idea is tiling. Instead of computing and storing the full matrix

$$
QK^\top,
$$

the algorithm processes blocks of queries and keys. It keeps partial softmax statistics and partial outputs in fast on-chip memory. This reduces memory reads and writes.

Flash attention improves efficiency because standard attention is often memory-bandwidth bound. Moving the large attention matrix to and from GPU memory is expensive. Tiled attention reduces this traffic.

Conceptually, flash attention still computes

$$
\operatorname{softmax}
\left(
\frac{QK^\top}{\sqrt{d}}
\right)V.
$$

It changes how the computation is scheduled, not the mathematical operation.
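The scheduling idea can be illustrated with a plain, unfused PyTorch sketch. This is not the actual flash attention kernel; it is a single-head, non-causal version that processes keys and values in blocks while maintaining a running maximum, normalizer, and partial output, and it matches standard attention up to floating-point error:

```python
import torch

def tiled_attention(q, k, v, block=64):
    # Exact attention over key/value blocks with an online softmax:
    # keep a running row max m, normalizer l, and partial output o.
    T, d = q.shape
    scale = d ** -0.5
    m = torch.full((T, 1), float("-inf"))
    l = torch.zeros(T, 1)
    o = torch.zeros(T, v.shape[-1])
    for s in range(0, T, block):
        kb, vb = k[s:s + block], v[s:s + block]
        scores = (q @ kb.T) * scale                         # [T, block]
        m_new = torch.maximum(m, scores.amax(dim=-1, keepdim=True))
        p = torch.exp(scores - m_new)                       # safe: scores <= m_new
        corr = torch.exp(m - m_new)                         # rescale old partials
        l = l * corr + p.sum(dim=-1, keepdim=True)
        o = o * corr + p @ vb
        m = m_new
    return o / l
```

A real kernel fuses these steps in on-chip memory; the point here is only that block-by-block processing with online softmax statistics reproduces the exact result.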

In PyTorch, optimized kernels may be used through:

```python
import torch
import torch.nn.functional as F

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```

The tensors usually have shape:

```python
[B, H, T, d_head]
```

PyTorch selects an available backend depending on hardware, dtype, tensor shapes, masking, and configuration.

Flash attention is often the first efficient attention method to use because it preserves exact attention while improving speed and memory behavior.

### Local Attention

Local attention restricts each token to a window around itself. Instead of attending to all $T$ tokens, each token attends to at most $w$ nearby tokens.

For token $i$, the visible keys are

$$
j \in [i-r, i+r],
$$

where $w=2r+1$.

The cost becomes approximately

$$
O(Twd)
$$

per head instead of

$$
O(T^2d).
$$

Local attention is natural for images, audio, and many time series. Nearby elements are often more relevant than distant ones. It also resembles convolution, but with content-dependent weights.

For causal local attention, token $i$ attends only to recent history:

$$
j \in [i-w+1, i].
$$

This is useful in autoregressive models where long-range attention is expensive.

The limitation is clear: information cannot move globally in one layer. A distant token can influence another token only after many layers, or through special global tokens.
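Without a specialized kernel, causal local attention can be prototyped with a band-shaped boolean mask (a sketch; the helper name is illustrative):

```python
import torch
import torch.nn.functional as F

def causal_local_mask(T, w):
    # mask[i, j] is True where token i may attend to token j,
    # i.e. the causal band j in [i - w + 1, i]
    i = torch.arange(T).unsqueeze(1)
    j = torch.arange(T).unsqueeze(0)
    return (j <= i) & (j > i - w)

T, w = 16, 4
mask = causal_local_mask(T, w)     # [T, T], broadcasts over B and H
q = torch.randn(1, 2, T, 8)
k = torch.randn(1, 2, T, 8)
v = torch.randn(1, 2, T, 8)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```

A dense boolean mask still costs $O(T^2)$ memory, so the real savings come from kernels that exploit the band structure and never touch masked-out entries.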

### Sparse Attention

Sparse attention uses a structured subset of token pairs. Instead of a dense $T \times T$ attention matrix, it defines an attention graph.

Common sparse patterns include:

| Pattern | Description |
|---|---|
| Sliding window | Each token attends to nearby tokens |
| Dilated window | Tokens attend at spaced intervals |
| Global tokens | Special tokens attend everywhere |
| Block sparse | Attention is computed between selected blocks |
| Random sparse | Some connections are sampled |
| Routing-based | Tokens are grouped by learned or approximate similarity |

Sparse attention reduces cost by making the number of attended positions per token much smaller than $T$.

If each token attends to $s$ positions, the cost becomes roughly

$$
O(Tsd).
$$

Sparse attention changes the model’s inductive bias. It can work well when the sparse pattern matches the task. It can fail when important information lies outside the chosen pattern.

### Global Tokens

A common way to improve local or sparse attention is to add global tokens. These tokens can attend to all positions, and all positions can attend to them.

For example, in a document encoder, a special classification token may attend globally. This allows information from the whole sequence to collect into one representation.

The attention pattern becomes a mixture:

| Token type | Visible positions |
|---|---|
| Local token | Nearby window plus global tokens |
| Global token | All tokens |

This gives the model a cheap communication path across the sequence. It does not fully match dense attention, but it often works well for classification and retrieval tasks.
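The mixed pattern can be written as a boolean mask in the same way. A sketch, assuming the first $g$ positions are the global tokens:

```python
import torch

def local_plus_global_mask(T, w, g):
    # True where attention is allowed; the first g positions are global
    i = torch.arange(T).unsqueeze(1)
    j = torch.arange(T).unsqueeze(0)
    local = (j - i).abs() <= w // 2   # symmetric local window
    to_global = j < g                 # every token sees the global tokens
    from_global = i < g               # global tokens see every token
    return local | to_global | from_global

mask = local_plus_global_mask(T=16, w=4, g=2)
```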

### Linear Attention

Linear attention avoids explicitly forming the $T \times T$ attention matrix. It rewrites attention using a feature map $\phi$ so that the exponentiated query-key dot product (with the $1/\sqrt{d}$ scaling absorbed into $q$ and $k$) is approximated by an inner product:

$$
\exp(q^\top k) \approx \phi(q)^\top \phi(k).
$$

Then attention can be written approximately as

$$
\operatorname{Attention}(Q,K,V)_i =
\frac{
\phi(q_i)^\top \sum_j \phi(k_j)v_j^\top
}{
\phi(q_i)^\top \sum_j \phi(k_j)
}.
$$

The sums over keys and values can be computed once, avoiding all pairwise comparisons.

The cost becomes linear in sequence length:

$$
O(T d_\phi d),
$$

where $d_\phi$ is the dimension of the feature map, often comparable to $d$, giving roughly $O(Td^2)$.

Linear attention is attractive for very long contexts. However, it changes the attention function. It may underperform exact softmax attention on tasks that need sharp retrieval or precise token selection.
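A minimal non-causal sketch, using $\phi(x)=\operatorname{elu}(x)+1$ as the feature map (one common choice among several in the literature):

```python
import torch
import torch.nn.functional as F

def phi(x):
    # positive feature map; elu(x) + 1 is one common choice
    return F.elu(x) + 1

def linear_attention(q, k, v):
    # q, k: [T, d]; v: [T, d_v]; no T x T matrix is ever formed
    qf, kf = phi(q), phi(k)
    kv = kf.T @ v                  # [d, d_v], summed over all keys once
    z = kf.sum(dim=0)              # [d], normalizer statistics
    return (qf @ kv) / (qf @ z).unsqueeze(-1)
```

The key-side sums `kv` and `z` are computed once and reused for every query, which is where the linear cost comes from.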

### Recurrent View of Causal Linear Attention

For causal linear attention, we can maintain running summaries:

$$
S_t = S_{t-1} + \phi(k_t)v_t^\top,
$$

$$
z_t = z_{t-1} + \phi(k_t).
$$

Then the output at time $t$ is

$$
y_t =
\frac{\phi(q_t)^\top S_t}
{\phi(q_t)^\top z_t}.
$$

This creates a recurrent form. Each new token updates a fixed-size state. The model can process long streams without storing all past tokens.

The advantage is memory efficiency. The limitation is compression. The fixed-size state may lose details needed for exact retrieval.
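The recurrence above translates directly into code. A sketch, again with the $\operatorname{elu}(x)+1$ feature map:

```python
import torch
import torch.nn.functional as F

def phi(x):
    return F.elu(x) + 1  # positive feature map

def causal_linear_attention(q, k, v):
    # q, k: [T, d]; v: [T, d_v]; the state (S, z) has fixed size
    T, d = q.shape
    S = torch.zeros(d, v.shape[-1])   # running sum of phi(k_t) v_t^T
    z = torch.zeros(d)                # running sum of phi(k_t)
    out = torch.zeros(T, v.shape[-1])
    for t in range(T):
        kt, qt = phi(k[t]), phi(q[t])
        S = S + torch.outer(kt, v[t])
        z = z + kt
        out[t] = (qt @ S) / (qt @ z)
    return out
```

Each step updates a $d \times d_v$ matrix and a $d$-vector, regardless of how long the stream is.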

### Low-Rank Attention

Low-rank attention approximates the full attention matrix using a smaller set of landmarks or projected tokens.

Instead of comparing all $T$ queries with all $T$ keys, the model projects the sequence into $r$ representative positions, where

$$
r \ll T.
$$

The attention computation then uses the compressed representation.

A rough pattern is:

$$
K,V \in \mathbb{R}^{T \times d}
\quad\rightarrow\quad
\tilde{K},\tilde{V} \in \mathbb{R}^{r \times d}.
$$

Then queries attend to $\tilde{K}$ and $\tilde{V}$, reducing the score cost from $O(T^2 d)$ to $O(Trd)$.

Low-rank methods are useful when the sequence has redundancy. They can be less effective for tasks requiring exact access to arbitrary individual tokens.
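A crude sketch of this pattern, compressing keys and values to $r$ landmarks by average pooling (real methods typically learn the compression or choose landmarks adaptively; the helper name is illustrative):

```python
import torch
import torch.nn.functional as F

def landmark_attention(q, k, v, r):
    # q, k, v: [B, H, T, d]; pool K and V down to r landmark positions
    B, H, T, d = k.shape

    def pool(x):
        x = x.reshape(B * H, T, d).transpose(1, 2)  # [B*H, d, T]
        x = F.adaptive_avg_pool1d(x, r)             # [B*H, d, r]
        return x.transpose(1, 2).reshape(B, H, r, d)

    k_lm, v_lm = pool(k), pool(v)
    # each query attends to r landmarks instead of T keys: O(T r d)
    return F.scaled_dot_product_attention(q, k_lm, v_lm)
```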

### Memory Tokens and Compressed Context

Another approach is to compress earlier context into memory tokens. Instead of attending to every previous token, the model attends to a smaller memory representation.

For a long document, the model may split text into chunks. Each chunk produces memory vectors. Later chunks attend to those memory vectors rather than all previous tokens.

This gives a hierarchy:

```text
tokens -> chunk summaries -> document memory
```

The cost depends on memory size, not full sequence length.
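The first level of this hierarchy can be sketched with simple mean pooling per chunk (real systems usually learn the compression):

```python
import torch

def chunk_summaries(x, chunk_size):
    # x: [T, d] -> [T // chunk_size, d]: one mean-pooled summary per chunk
    T, d = x.shape
    n = T // chunk_size
    return x[: n * chunk_size].reshape(n, chunk_size, d).mean(dim=1)
```

Later chunks would then attend to these summary vectors instead of all earlier tokens.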

The tradeoff is information loss. Compression can preserve high-level meaning but may lose exact details. This matters for tasks such as code, math, citation lookup, and long-context retrieval.

### KV Caching for Decoding

During autoregressive generation, each new token attends to all previous tokens. Without caching, the model recomputes keys and values for the entire context at every step.

KV caching stores keys and values from previous steps.

For each layer, the cache contains:

```python
keys:   [B, H, T_past, d_head]
values: [B, H, T_past, d_head]
```

At a new step, the model computes only the new key and value, then appends them to the cache.
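A minimal sketch of one decoding step with such a cache (a hypothetical helper, not a specific library API):

```python
import torch
import torch.nn.functional as F

def decode_step(q_new, k_new, v_new, cache_k, cache_v):
    # q_new, k_new, v_new: [B, H, 1, d_head] for the current token
    # cache_k, cache_v:    [B, H, T_past, d_head]
    cache_k = torch.cat([cache_k, k_new], dim=2)
    cache_v = torch.cat([cache_v, v_new], dim=2)
    # the single new query attends to the whole updated cache;
    # no causal mask is needed within this step
    out = F.scaled_dot_product_attention(q_new, cache_k, cache_v)
    return out, cache_k, cache_v
```

In practice the cache is usually preallocated and written in place rather than concatenated, to avoid repeated copies.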

This reduces repeated computation during inference. However, the cache itself grows linearly with context length. Across $L$ layers, its size is

$$
O(BHLTd_h),
$$

where $d_h$ is the per-head dimension. For long contexts and deep models, KV cache memory can dominate inference cost.

### Grouped-Query and Multi-Query Attention

Standard multi-head attention has separate keys and values for each head. Grouped-query attention reduces KV cache size by sharing keys and values across groups of query heads.

Multi-query attention is the most aggressive form: all query heads share one set of keys and values.

| Method | Queries | Keys and values |
|---|---|---|
| Multi-head attention | Many heads | Separate K/V per head |
| Grouped-query attention | Many heads | Shared K/V per group |
| Multi-query attention | Many heads | One shared K/V set |

These methods are especially useful for large language model inference because they reduce KV cache memory and memory bandwidth.

The tradeoff is reduced key-value diversity. In practice, grouped-query attention often gives a good balance between quality and inference efficiency.
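Grouped-query attention can be emulated by expanding the shared key/value heads to match the query heads. A sketch (recent PyTorch versions also expose an `enable_gqa` argument on `scaled_dot_product_attention` that avoids the explicit expansion):

```python
import torch
import torch.nn.functional as F

def grouped_query_attention(q, k, v):
    # q: [B, H_q, T, d]; k, v: [B, H_kv, T, d] with H_kv dividing H_q
    group = q.shape[1] // k.shape[1]
    k = k.repeat_interleave(group, dim=1)   # expand shared KV heads
    v = v.repeat_interleave(group, dim=1)
    return F.scaled_dot_product_attention(q, k, v)
```

With `H_kv == H_q` this reduces to multi-head attention, and with `H_kv == 1` to multi-query attention; only the small K/V tensors need to be cached.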

### Quantized Attention and KV Cache

Attention tensors can use lower precision to reduce memory. Training often uses `float16` or `bfloat16`. Inference can quantize weights, activations, or KV cache entries.

KV cache quantization is particularly useful for long-context inference. Reducing keys and values from 16-bit to 8-bit or lower can greatly reduce memory usage.

The risk is numerical degradation. Attention is sensitive to dot products and softmax. Poor quantization can change attention distributions and reduce model quality.

A practical system measures the quality-speed-memory tradeoff rather than assuming quantization is free.
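A minimal sketch of symmetric per-token int8 quantization, the kind of scheme a KV cache might use (illustrative, not a production quantizer):

```python
import torch

def quantize_int8(x):
    # symmetric per-token int8 quantization along the last dimension
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.round(x / scale).clamp(-127, 127).to(torch.int8)
    return q, scale

def dequantize_int8(q, scale):
    return q.float() * scale
```

Each token stores one `int8` vector plus one scale, roughly halving memory relative to `float16` at the cost of rounding error.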

### Efficient Attention in PyTorch

For many models, the best starting point is PyTorch’s built-in scaled dot-product attention:

```python
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)

out = F.scaled_dot_product_attention(
    q,
    k,
    v,
    is_causal=True,
)
```

This API allows PyTorch to use optimized attention kernels when supported.

A simple wrapper module:

```python
import torch
from torch import nn
import torch.nn.functional as F

class EfficientSelfAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()

        if d_model % num_heads != 0:
            raise ValueError("d_model must be divisible by num_heads")

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads

        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor, causal: bool = False):
        # x: [B, T, D]
        B, T, D = x.shape

        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)

        q = q.view(B, T, self.num_heads, self.d_head).transpose(1, 2)
        k = k.view(B, T, self.num_heads, self.d_head).transpose(1, 2)
        v = v.view(B, T, self.num_heads, self.d_head).transpose(1, 2)

        # q, k, v: [B, H, T, d_head]
        out = F.scaled_dot_product_attention(
            q,
            k,
            v,
            is_causal=causal,
        )

        out = out.transpose(1, 2).contiguous().view(B, T, D)
        return self.out_proj(out)
```

This implementation avoids manually materializing attention weights in Python. It also gives the backend more freedom to select an optimized kernel.

### Choosing an Efficient Attention Method

The correct method depends on the bottleneck.

| Problem | Useful method |
|---|---|
| Training uses too much attention memory | Flash attention, gradient checkpointing |
| Inference KV cache is too large | Grouped-query attention, multi-query attention, KV quantization |
| Sequence is very long but mostly local | Local or sliding-window attention |
| Need document-level classification | Sparse attention with global tokens |
| Need streaming sequence processing | Causal linear attention or recurrent memory |
| Need exact retrieval over long context | Exact attention, retrieval augmentation, memory hierarchy |
| Need lower latency | Optimized kernels, shorter context, KV caching |

Efficient attention should be chosen by measuring quality, latency, memory, and implementation complexity. A method that improves benchmark speed can still damage task performance if its attention pattern removes necessary information.

### Summary

Standard attention has quadratic time and memory in sequence length. Efficient attention methods reduce this cost by improving kernels, restricting attention patterns, approximating attention, compressing context, or optimizing inference memory.

Flash attention keeps exact attention but computes it with better memory behavior. Local and sparse attention reduce the number of token pairs. Linear and low-rank attention approximate dense attention. Grouped-query and multi-query attention reduce KV cache cost during decoding.

For PyTorch models, `torch.nn.functional.scaled_dot_product_attention` is usually the first practical tool to use. More specialized methods should be selected only after identifying the actual bottleneck.

