# Memory Explosion

Reverse-mode automatic differentiation trades computation for memory. To compute gradients efficiently, the backward pass requires access to intermediate values produced during the forward pass. For large computational graphs, storing these intermediates can dominate total system cost.

This phenomenon is known as memory explosion, and it is one of the central engineering constraints in modern automatic differentiation systems.

Memory pressure limits:

- model size,
- batch size,
- sequence length,
- simulation resolution,
- graph depth,
- and hardware utilization.

In many large systems, computation is abundant while memory bandwidth and capacity are scarce. Modern AD design therefore revolves around managing intermediate state efficiently.

### Why Reverse Mode Requires Memory

Consider a computation:

$$
x \to v_1 \to v_2 \to \cdots \to y.
$$

Reverse mode computes gradients backward:

$$
\bar{v}_i = \frac{\partial y}{\partial v_i}.
$$

Most local reverse rules require primal values saved from the forward pass.

Example:

$$
z = xy.
$$

Backward propagation uses:

$$
\bar{x} \mathrel{+}= \bar{z}y,
$$

$$
\bar{y} \mathrel{+}= \bar{z}x.
$$

The backward pass therefore needs access to both $x$ and $y$.

If the forward computation has millions or billions of operations, storing every intermediate becomes extremely expensive.
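As a minimal illustration, a hand-written product rule in Python must capture both primal inputs in its backward closure; this captured state is exactly what reverse mode accumulates (the names here are illustrative):

```python
def mul_with_vjp(x, y):
    """Forward product plus its vector-Jacobian product (VJP) closure."""
    z = x * y

    def vjp(z_bar):
        # x_bar += z_bar * y and y_bar += z_bar * x:
        # both primal inputs must still be alive here.
        return z_bar * y, z_bar * x

    return z, vjp

z, vjp = mul_with_vjp(3.0, 4.0)
x_bar, y_bar = vjp(1.0)          # (4.0, 3.0)
```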

### Linear Growth of Activation Memory

Suppose a computation graph contains $n$ operations, each producing an activation of size $s$. Naively, reverse mode stores $O(ns)$ memory.

For deep neural networks:

| Component | Typical scale |
|---|---|
| Activations | GBs to TBs |
| Parameters | MBs to hundreds of GBs |
| Gradients | Similar to parameter size |
| Optimizer state | 2× to 8× parameter size |

Activation memory often dominates total memory usage during training.

### Memory in Deep Networks

Consider a transformer with:

- depth $L$,
- sequence length $T$,
- hidden dimension $d$,
- batch size $B$.

Activation tensors scale roughly as:

$$
O(BTdL).
$$

Attention layers introduce an additional per-layer term, quadratic in sequence length:

$$
O(BT^2).
$$

Long-context transformers therefore experience rapid memory growth.

Example (score memory grows as $T^2$, normalized to $T = 1\text{K}$):

| Sequence length | Relative score memory | Practical impact |
|---:|---:|---|
| 1K | 1× | manageable |
| 8K | 64× | large |
| 32K | 1,024× | severe |
| 128K | 16,384× | often impractical |

Memory becomes the bottleneck before computation does.

### Reverse Mode as a Tape

Many AD systems conceptualize reverse mode as a tape.

During forward execution:

1. execute operation,
2. record metadata,
3. store required intermediates.

During backward execution:

1. traverse tape backward,
2. retrieve stored values,
3. apply reverse rules.

The tape may contain:

| Stored item | Purpose |
|---|---|
| Primal values | Local derivatives |
| Tensor shapes | Broadcasting and reduction |
| Data types | Correct gradient kernels |
| Operation identifiers | Backward dispatch |
| Control flow metadata | Dynamic graph reconstruction |

Large graphs therefore create large tapes.
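A sketch of such a tape for scalar programs, with just three operations; a real system records far more metadata, but the storage pattern is the same:

```python
import math

class Tape:
    """A minimal scalar Wengert tape (a sketch, not a real AD system).

    Each operation appends one record: the backward rule plus the
    slots of its inputs. The primal values the rule needs are
    captured in the closure -- this is the stored intermediate state.
    """
    def __init__(self):
        self.vals = []   # primal value per slot
        self.recs = []   # (backward_fn, input_slots) per slot

    def _push(self, val, bwd, inputs):
        self.vals.append(val)
        self.recs.append((bwd, inputs))
        return len(self.vals) - 1           # slot id of the new value

    def leaf(self, x):
        return self._push(x, None, ())

    def mul(self, a, b):
        x, y = self.vals[a], self.vals[b]   # both primals must be kept
        return self._push(x * y, lambda g: (g * y, g * x), (a, b))

    def sin(self, a):
        x = self.vals[a]                    # primal input must be kept
        return self._push(math.sin(x), lambda g: (g * math.cos(x),), (a,))

    def add(self, a, b):
        return self._push(self.vals[a] + self.vals[b],
                          lambda g: (g, g), (a, b))   # needs no primals

    def backward(self, out):
        bar = [0.0] * len(self.vals)
        bar[out] = 1.0
        for slot in range(out, -1, -1):     # traverse the tape backward
            bwd, inputs = self.recs[slot]
            if bwd is None:
                continue
            for i, g in zip(inputs, bwd(bar[slot])):
                bar[i] += g
        return bar
```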

### Wengert Lists

A Wengert list stores intermediate variables explicitly:

$$
v_1, v_2, \dots, v_n.
$$

Each variable depends on previous variables.

Example:

$$
v_1 = x_1 x_2,
$$

$$
v_2 = \sin(v_1),
$$

$$
v_3 = v_2 + x_3.
$$

Reverse mode traverses:

$$
v_3 \to v_2 \to v_1.
$$

The larger the dependency chain, the larger the stored state.
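Using the `Tape` sketch above, this exact Wengert list runs as:

```python
t = Tape()
x1, x2, x3 = t.leaf(2.0), t.leaf(3.0), t.leaf(1.0)
v1 = t.mul(x1, x2)            # v1 = x1 * x2
v2 = t.sin(v1)                # v2 = sin(v1)
v3 = t.add(v2, x3)            # v3 = v2 + x3

bar = t.backward(v3)          # reverse sweep: v3 -> v2 -> v1
# d v3/d x1 = cos(x1 x2) * x2,  d v3/d x2 = cos(x1 x2) * x1,  d v3/d x3 = 1
print(bar[x1], bar[x2], bar[x3])
```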

### Dynamic Graphs

Dynamic graph systems allocate graph structures at runtime.

Examples include:

- eager execution frameworks,
- dynamic control flow,
- recursive differentiable programs.

Dynamic graphs increase memory pressure because:

- graph structure itself consumes memory,
- metadata cannot always be statically optimized,
- allocations become fragmented,
- runtime bookkeeping increases overhead.

Static graph compilers can optimize memory reuse more aggressively.

### Memory Fragmentation

Memory explosion is not only about total size.

Fragmentation matters.

Suppose free memory exists in many small blocks rather than one contiguous region. Large tensor allocations may fail even though nominal free memory appears sufficient.

GPU allocators therefore use:

- pooling,
- caching allocators,
- arena allocation,
- and tensor reuse strategies.

Fragmentation becomes severe in dynamic workloads with varying tensor shapes.

### Gradient Storage

Backward propagation requires gradient buffers.

For each parameter $\theta_i$, systems store its gradient $\nabla_{\theta_i} L$.

Optimizer state often multiplies memory further.

Example: Adam optimizer.

For a parameter tensor $\theta$, Adam stores:

Adam stores:

| Quantity | Memory multiplier |
|---|---:|
| Parameters | 1× |
| Gradients | 1× |
| First moment | 1× |
| Second moment | 1× |

Total: $4\times$ parameter memory before activations are considered.
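A back-of-the-envelope check in Python, assuming fp32 state and a hypothetical one-billion-parameter model:

```python
n_params  = 1_000_000_000         # hypothetical 1B-parameter model
bytes_per = 4                     # fp32

state = {
    "parameters":    n_params * bytes_per,
    "gradients":     n_params * bytes_per,
    "first moment":  n_params * bytes_per,   # Adam m
    "second moment": n_params * bytes_per,   # Adam v
}
total = sum(state.values())
print(f"{total / 2**30:.1f} GiB before any activations")   # ~14.9 GiB
```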

### Higher-Order AD

Higher-order differentiation dramatically increases memory cost.

Suppose reverse mode is nested inside reverse mode.

The outer reverse pass must preserve:

- primal values,
- inner reverse state,
- gradient intermediates,
- higher-order adjoints.

Naively nested reverse mode can produce memory growth exponential in the differentiation order.

This is one reason why Hessian computation is substantially harder than gradient computation.
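A minimal PyTorch sketch of reverse-over-reverse for a scalar function; `create_graph=True` forces the first backward pass to build, and keep in memory, its own differentiable graph:

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x ** 3

# create_graph=True keeps the first backward graph alive in memory --
# exactly the extra state the outer reverse pass must preserve.
(g,) = torch.autograd.grad(y, x, create_graph=True)   # g = 3x^2 = 12
(h,) = torch.autograd.grad(g, x)                      # h = 6x  = 12
print(g.item(), h.item())
```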

### Recomputation Tradeoffs

Memory can be reduced by recomputing values instead of storing them.

This creates a tradeoff:

| Strategy | Memory | Computation |
|---|---|---|
| Store everything | High | Low |
| Recompute everything | Low | High |
| Checkpointing | Moderate | Moderate |

The central idea:

Instead of saving all activations, save only selected checkpoints.

Missing values are recomputed during backward execution.

### Checkpointing

Checkpointing partitions the graph into segments.

Suppose:

$$
x_0 \to x_1 \to \cdots \to x_n.
$$

Rather than storing every $x_i$, store only selected states:

$$
x_0, x_k, x_{2k}, \dots
$$

During backward execution:

1. reload nearest checkpoint,
2. recompute intermediate states,
3. continue backward pass.

This reduces activation memory from $O(n)$ toward $O(\sqrt{n})$, or even logarithmic scaling, depending on the schedule.
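A pure-Python sketch of the $O(\sqrt{n})$ schedule for a chain $x_{i+1} = \text{step}(x_i)$; `step` and `vjp` are assumed to be pure functions, and the schedule is simpler than an optimal one:

```python
import math

def checkpointed_grad(n, step, vjp, x0, y_bar):
    """Reverse sweep over x_{i+1} = step(x_i), storing only every k-th
    state (k ~ sqrt(n)) and replaying each segment once on the way back.
    Peak storage: O(sqrt(n)) checkpoints plus one O(sqrt(n)) replay buffer.
    """
    k = max(1, math.isqrt(n))
    ckpts = {0: x0}
    x = x0
    for i in range(n):                    # forward pass, sparse storage
        x = step(x)
        if (i + 1) % k == 0 and i + 1 < n:
            ckpts[i + 1] = x

    g = y_bar
    hi = n
    while hi > 0:
        lo = (hi - 1) // k * k            # nearest checkpoint below hi
        xs = [ckpts[lo]]
        for _ in range(hi - lo - 1):      # replay the segment once
            xs.append(step(xs[-1]))
        for x in reversed(xs):            # backward through the segment
            g = vjp(x, g)
        hi = lo
    return g

# d/dx of sin applied 10 times, evaluated at 0.5
print(checkpointed_grad(10, math.sin, lambda x, g: g * math.cos(x), 0.5, 1.0))
```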

### Revolve Algorithm

Optimal checkpoint scheduling is a classical problem.

The Revolve algorithm of Griewank and Walther computes checkpoint schedules that provably minimize recomputation given a fixed number of checkpoint slots.

It treats the reverse pass as a scheduling problem: which states to snapshot, and when to replay forward segments.

This becomes important in:

- PDE solvers,
- climate simulation,
- differentiable physics,
- and long time-horizon systems.

### Activation Checkpointing in Deep Learning

Modern deep learning systems commonly use activation checkpointing.

Typical policy:

- save activations at layer boundaries,
- recompute inside segments.

This enables training larger models on limited hardware.

Tradeoff:

| Effect | Result |
|---|---|
| Lower memory | Larger models |
| Higher recomputation | Slower training |

Large language models rely heavily on this technique.
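In PyTorch, for instance, this policy is one call per segment via `torch.utils.checkpoint`; the MLP blocks below are placeholders:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Placeholder transformer-style MLP blocks; the 4096-wide hidden
# activation inside each block is what checkpointing avoids storing.
blocks = nn.ModuleList(
    nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))
    for _ in range(8)
)

def forward(x, use_ckpt=True):
    for block in blocks:
        if use_ckpt:
            # Only the block input is saved; interior activations are
            # recomputed from it during the backward pass.
            x = checkpoint(block, x, use_reentrant=False)
        else:
            x = block(x)
    return x

x = torch.randn(32, 1024, requires_grad=True)
forward(x).sum().backward()
```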

### Reversible Networks

Some architectures reconstruct activations exactly during backward execution.

Example:

$$
y_1 = x_1 + f(x_2),
$$

$$
y_2 = x_2 + g(y_1).
$$

The original inputs can be recovered:

$$
x_2 = y_2 - g(y_1),
$$

$$
x_1 = y_1 - f(x_2).
$$

Thus activations need not be stored explicitly.

Reversible networks trade:

- extra recomputation,
- stricter architectural constraints,

for dramatically lower memory use.
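A minimal NumPy sketch of the additive coupling equations above; `f` and `g` are arbitrary placeholder functions:

```python
import numpy as np

def f(x): return np.tanh(x)     # hypothetical coupling functions
def g(x): return 0.5 * x

def rev_forward(x1, x2):
    y1 = x1 + f(x2)
    y2 = x2 + g(y1)
    return y1, y2

def rev_inverse(y1, y2):
    x2 = y2 - g(y1)             # invert the second coupling
    x1 = y1 - f(x2)             # then the first
    return x1, x2

x1, x2 = np.random.randn(4), np.random.randn(4)
r1, r2 = rev_inverse(*rev_forward(x1, x2))
assert np.allclose(x1, r1) and np.allclose(x2, r2)   # exact reconstruction
```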

### Gradient Checkpoint Granularity

Checkpoint placement matters.

The two extremes compare as follows:

| Granularity | Memory | Recomputation | Scheduling |
|---|---|---|---|
| Fine-grained | High | Low | Complex, more metadata |
| Coarse-grained | Lower | High | Simpler, less flexible |

Optimal checkpoint placement depends on:

- graph structure,
- tensor sizes,
- recomputation cost,
- hardware bandwidth.

### Memory in Attention

Self-attention is especially memory intensive.

Attention scores require:

$$
QK^T.
$$

For sequence length $T$:

$$
O(T^2)
$$

memory.

Backward propagation also requires:

- attention probabilities,
- normalization statistics,
- softmax intermediates.

Long-context transformers therefore become memory-bound quickly.

### Flash Attention

Flash Attention reduces memory usage by avoiding explicit materialization of large attention matrices.

Instead:

- compute attention in blocks,
- fuse operations,
- recompute partial quantities as needed.

This reduces the attention memory footprint from quadratic to linear in sequence length.

The key idea:

Trade extra arithmetic for reduced memory traffic.

Modern accelerators are often compute-rich but bandwidth-limited, making this trade favorable.
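A NumPy sketch of the core trick, online softmax over key/value blocks; it is far simpler than the real fused kernel, but it shows why the full $T \times T$ matrix never needs to exist:

```python
import numpy as np

def blocked_attention(Q, K, V, block=64):
    """softmax(Q K^T / sqrt(d)) V computed one key/value block at a time,
    with a running row max and running denominator, so only a T x block
    slice of scores is ever materialized."""
    T, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.zeros_like(V)
    m = np.full(T, -np.inf)               # running row max
    l = np.zeros(T)                       # running softmax denominator
    for s in range(0, T, block):
        S = (Q @ K[s:s + block].T) * scale        # score slice
        m_new = np.maximum(m, S.max(axis=1))
        alpha = np.exp(m - m_new)                 # rescale old statistics
        p = np.exp(S - m_new[:, None])
        l = l * alpha + p.sum(axis=1)
        out = out * alpha[:, None] + p @ V[s:s + block]
        m = m_new
    return out / l[:, None]

T, d = 256, 32
Q, K, V = (np.random.randn(T, d) for _ in range(3))
ref = np.exp(Q @ K.T / np.sqrt(d))
ref = (ref / ref.sum(axis=1, keepdims=True)) @ V
assert np.allclose(blocked_attention(Q, K, V), ref)
```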

### Offloading

Large systems sometimes move tensors between devices.

Examples:

| Offload target | Purpose |
|---|---|
| CPU RAM | Extend GPU capacity |
| NVMe SSD | Very large models |
| Remote memory | Distributed systems |

Offloading reduces peak GPU memory but introduces transfer latency.

Efficient scheduling becomes essential.

### Tensor Compression

Activation storage can be compressed.

Methods include:

| Method | Idea |
|---|---|
| Reduced precision | float16 or bfloat16 |
| Quantization | Integer representation |
| Sparsification | Store only important entries |
| Delta encoding | Store differences |
| Low-rank compression | Factorized activations |

Compression reduces memory but may degrade gradient accuracy.

### In-Place Operations

In-place updates reuse memory:

$$
x \leftarrow x + y.
$$

This saves allocations.

However, reverse mode may still require the original value of $x$.

Unsafe in-place mutation can therefore destroy information needed for gradients.

AD systems carefully track aliasing and mutation dependencies.
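PyTorch, for example, guards saved tensors with version counters, so an unsafe mutation fails loudly instead of silently corrupting gradients:

```python
import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)
h = x.exp()        # exp's backward rule reuses its saved *output*
h.add_(1.0)        # in-place mutation bumps h's version counter

try:
    h.sum().backward()
except RuntimeError as err:
    # autograd detects that a saved tensor was modified in place
    print("caught:", err)
```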

### Static Memory Planning

Static graph compilers can analyze tensor lifetimes.

Suppose two tensors are never simultaneously live.

Then they may share memory.

This resembles register allocation in classical compilers.

Static planning enables:

- tensor reuse,
- buffer recycling,
- preallocation,
- memory pooling.

Framework compilers aggressively optimize these schedules.
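A toy greedy planner over assumed lifetime intervals (hypothetical tensor names and sizes), sharing a buffer whenever two tensors' live ranges do not overlap:

```python
# Hypothetical tensors: (name, first_use, last_use, size in units)
tensors = [("a", 0, 3, 256), ("b", 1, 2, 256),
           ("c", 4, 6, 256), ("d", 5, 7, 128)]

buffers = []   # physical buffers: (size, last_use of current occupant)
plan = {}
for name, first, last, size in sorted(tensors, key=lambda t: t[1]):
    for i, (bsize, blast) in enumerate(buffers):
        if blast < first and bsize >= size:   # dead and big enough: share
            buffers[i] = (bsize, last)
            plan[name] = i
            break
    else:
        buffers.append((size, last))          # otherwise allocate fresh
        plan[name] = len(buffers) - 1

print(plan)                        # {'a': 0, 'b': 1, 'c': 0, 'd': 1}
print(sum(s for s, _ in buffers))  # 512 units instead of 896
```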

### Distributed Memory

Large models exceed single-device memory.

Distributed strategies include:

| Strategy | Partition |
|---|---|
| Data parallelism | Batch dimension |
| Tensor parallelism | Tensor dimensions |
| Pipeline parallelism | Layers |
| ZeRO optimization | Optimizer state, gradients, and parameters (by stage) |

Memory becomes a distributed systems problem rather than a single-device problem.

### Sparse Activation Systems

Sparse models activate only subsets of parameters.

Mixture-of-experts architectures are a major example.

Sparse activation reduces:

- activation memory,
- optimizer memory,
- communication volume.

However, routing metadata and irregular execution introduce new complexity.

### Memory Bandwidth vs Capacity

Capacity is only part of the problem.

Bandwidth matters equally.

Backward propagation repeatedly loads:

- activations,
- weights,
- gradients,
- optimizer states.

Memory traffic often dominates runtime.

Modern AD systems therefore optimize:

- tensor locality,
- kernel fusion,
- cache reuse,
- recomputation balance.

### Computational Graph Lifetime

Some graphs are short-lived.

Others persist across iterations.

Persistent graphs consume memory through:

- retained references,
- closure capture,
- caching,
- graph history accumulation.

Improper graph cleanup is a common source of memory leaks.

### Gradient Accumulation

Large effective batch sizes may exceed device memory.

Gradient accumulation simulates larger batches:

1. compute gradients on microbatches,
2. accumulate gradients,
3. update parameters later.

This trades:

- longer steps between parameter updates,
- gradient buffers that stay live across all microbatches,

for lower activation memory per step.
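A typical PyTorch loop; the model, data, and accumulation factor are placeholders:

```python
import torch

accum_steps = 8                       # microbatches per effective batch
model = torch.nn.Linear(64, 1)        # placeholder model
opt = torch.optim.SGD(model.parameters(), lr=0.1)

opt.zero_grad()
for step in range(32):
    x = torch.randn(16, 64)           # one microbatch (placeholder data)
    loss = model(x).pow(2).mean() / accum_steps   # rescale for the mean
    loss.backward()                   # gradients accumulate into .grad
    if (step + 1) % accum_steps == 0:
        opt.step()                    # one update per effective batch
        opt.zero_grad()
```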

### Memory Complexity of Reverse Mode

Forward mode complexity:

| Quantity | Complexity |
|---|---|
| Memory | Small |
| Compute | Scales with inputs |

Reverse mode:

| Quantity | Complexity |
|---|---|
| Memory | Potentially very large |
| Compute | Scales with outputs |

The low computational cost of reverse mode comes partly from high memory consumption.

This is a fundamental tradeoff.

### Core Idea

Reverse-mode automatic differentiation requires access to intermediate program state during backward propagation. As computational graphs grow larger, storing these intermediates becomes a dominant systems constraint. Memory explosion is therefore not an implementation accident but a structural consequence of reverse accumulation. Modern AD systems manage this through checkpointing, recomputation, reversible computation, compression, static planning, and distributed execution strategies.

