Reverse mode automatic differentiation is computationally efficient for scalar-output functions, but it has a major systems cost: it needs information from the forward pass during the backward pass.
The reverse pass cannot usually compute derivatives from the output value alone. It needs intermediate primal values, shapes, branches, and sometimes full solver or kernel state. The system must either store this information or recompute it.
This creates the central tradeoff:
| Strategy | Memory | Compute |
|---|---|---|
| Store everything | high | low |
| Recompute everything | low | high |
| Checkpoint selectively | medium | medium |
Why Memory Is Needed
Consider multiplication: $y = a b$.
The reverse rules are
$$\bar a = \bar y \, b, \qquad \bar b = \bar y \, a.$$
The backward pass needs the original forward values $a$ and $b$. If they were discarded, the system must recompute them.
For a deep computation $y = f_n(f_{n-1}(\cdots f_1(x)))$,
the backward pass visits operations in reverse order: $f_n, f_{n-1}, \ldots, f_1$.
Each reverse step may need values produced much earlier in the forward pass. Keeping all of them can dominate memory usage.
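The dependence on forward values can be made concrete with the multiplication rule above. The sketch below (illustrative names, not a real AD API) shows a multiply node that must save its inputs so the backward pass can apply the reverse rules:

```python
# A multiply node saves its inputs during the forward pass so the
# backward pass can apply the reverse rules
#   a_bar = y_bar * b,  b_bar = y_bar * a.
def mul_forward(a, b):
    saved = (a, b)            # stored for the backward pass
    return a * b, saved

def mul_backward(y_bar, saved):
    a, b = saved
    return y_bar * b, y_bar * a   # (a_bar, b_bar)

y, saved = mul_forward(3.0, 4.0)
a_bar, b_bar = mul_backward(1.0, saved)   # a_bar = 4.0, b_bar = 3.0
```

If `saved` is discarded, the only alternatives are recomputing `a` and `b` or failing.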
The Full Storage Strategy
The simplest reverse mode implementation stores every intermediate value.
Forward pass:

```
for op in program:
    compute output
    save output and needed inputs
```

Backward pass:

```
for op in reverse(program):
    read saved values
    apply reverse rule
```

This strategy is fast because the backward pass has direct access to every required value.
Its cost is memory.
If a model has $L$ layers and each layer activation requires $A$ bytes, activation storage is roughly $L \cdot A$.
For large neural networks, activations can use more memory than parameters.
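A minimal executable sketch of the store-everything strategy for a scalar chain, assuming a toy tape rather than any real AD library:

```python
import math

# Each entry is (f, f'): the op and its derivative.
ops = [
    (math.sin, math.cos),
    (math.exp, math.exp),
    (lambda x: x * x, lambda x: 2 * x),
]

def forward(x):
    tape = []                 # stores the input of every op
    for f, _ in ops:
        tape.append(x)
        x = f(x)
    return x, tape

def backward(tape):
    grad = 1.0
    for (_, df), x in zip(reversed(ops), reversed(tape)):
        grad *= df(x)         # chain rule, applied in reverse order
    return grad

y, tape = forward(0.5)        # y = exp(sin(0.5))**2
g = backward(tape)            # d y / d x at x = 0.5
```

The tape grows linearly with the number of ops; that linear growth is exactly the activation-memory cost discussed above.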
The Recomputation Strategy
The opposite strategy stores almost nothing.
During the backward pass, whenever a needed value is missing, the system reruns part of the forward computation to recover it.
Memory falls, but compute increases.
In the extreme case, each backward step may require recomputing a long prefix of the program. For a chain of $n$ operations, this can increase cost from linear, $O(n)$, to quadratic, $O(n^2)$.
| Strategy | Forward storage | Backward cost for a chain of $n$ ops |
|---|---|---|
| Store all values | $O(n)$ | $O(n)$ |
| Recompute from start each time | $O(1)$ | $O(n^2)$ |
Pure recomputation is therefore rarely acceptable for large programs.
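The quadratic blow-up follows from a simple sum. Under the illustrative cost model below (one unit per forward op), recovering the input of op $i$ from scratch costs $i-1$ ops, and the total over all backward steps is $0 + 1 + \cdots + (n-1)$:

```python
# Toy cost model: count extra forward evaluations per strategy
# for a chain of n ops (illustrative, not a measured benchmark).
def recompute_cost(n):
    # Backward step for op i reruns ops 1..i-1 from the start,
    # so the total is 0 + 1 + ... + (n-1) = n*(n-1)/2.
    return sum(range(n))

def store_all_cost(n):
    return 0                  # every value was already saved

# recompute_cost(n) grows like n**2 / 2
```

For $n = 1000$, pure recomputation performs about 500,000 extra forward ops where full storage performs none.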
Checkpointing
Checkpointing stores selected intermediate states and recomputes only between those states.
Suppose a computation has $n$ stages. Instead of storing all activations, the system stores checkpoints every $k$ stages.
During the backward pass, to differentiate a segment, it reloads the nearest checkpoint and recomputes forward values inside that segment.
Memory drops from $n$ stored states to roughly $n/k$ checkpoints plus the $k$ values of the segment currently being differentiated, while the extra compute is at most one additional forward pass over the program, far less than full recomputation.
Conceptually:

```
forward:
    save selected checkpoints

backward:
    for each segment in reverse:
        restore checkpoint
        recompute segment forward
        run local backward pass
```

Checkpointing is one of the main techniques for training deeper models under fixed memory budgets.
Simple Chain Example
Consider a chain:
$$x_1 = f_1(x_0), \quad x_2 = f_2(x_1), \quad x_3 = f_3(x_2), \quad x_4 = f_4(x_3).$$
Full storage saves: $x_0, x_1, x_2, x_3, x_4$.
A checkpointed strategy might save only: $x_0$ and $x_2$.
To compute derivatives for the segment $x_2 \to x_4$, the system recomputes $x_3$ and $x_4$ from $x_2$.
To compute derivatives for the segment $x_0 \to x_2$, it recomputes $x_1$ and $x_2$ from $x_0$.
This reduces stored values from all intermediates to a smaller set of boundary states.
Activation Memory in Neural Networks
In neural network training, the largest reverse-mode memory cost is often activation storage.
For a layer
$$h_{l+1} = \sigma(W_l h_l + b_l),$$
the backward rule often requires $h_l$, $W_l$, and sometimes $h_{l+1}$.
Parameters are already stored. Activations are produced dynamically and must be retained or recomputed.
For a network with many layers, the activation memory scales with depth and batch size.
Important drivers include:
| Factor | Memory Effect |
|---|---|
| batch size | linear growth |
| sequence length | often linear or quadratic growth |
| hidden dimension | linear growth |
| number of layers | linear growth |
| saved attention maps | can be very large |
This is why memory optimization is central to training large transformer models.
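The table's linear factors multiply together, which a back-of-envelope estimate makes vivid. All numbers below are illustrative assumptions (fp16 activations, one saved tensor per layer), not measurements of any particular model:

```python
# Rough activation-memory estimate for a transformer-like stack.
def activation_bytes(batch, seq_len, hidden, layers, bytes_per_elem=2):
    per_layer = batch * seq_len * hidden * bytes_per_elem
    return layers * per_layer

# Hypothetical configuration: batch 8, 4096-token context,
# hidden size 4096, 32 layers, fp16 (2 bytes per element).
gib = activation_bytes(batch=8, seq_len=4096, hidden=4096,
                       layers=32) / 2**30   # -> 8.0 GiB
```

Even this understated model (it ignores attention maps and per-layer temporaries) already reaches several GiB, and doubling the batch or sequence length doubles it.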
Invertible Computations
Some computations allow the input to be reconstructed from the output.
If $y = f(x)$ is invertible, then the backward pass may recover $x = f^{-1}(y)$ instead of storing $x$.
Invertible neural networks use this idea to reduce activation memory. They store less forward state and reconstruct hidden activations during the backward pass.
The tradeoff is that the inverse computation must be stable and cheaper than storing the missing values.
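A standard construction for an exactly invertible layer is additive coupling (the RealNVP/RevNet pattern). The inner functions `F` and `G` can be arbitrary; here they are toy scalar functions for illustration:

```python
# Additive coupling: the input (x1, x2) is reconstructed exactly
# from the output (y1, y2), so it need not be stored.
def F(t): return 3.0 * t
def G(t): return t * t

def couple(x1, x2):
    y1 = x1 + F(x2)
    y2 = x2 + G(y1)
    return y1, y2

def uncouple(y1, y2):
    x2 = y2 - G(y1)      # undo the second addition
    x1 = y1 - F(x2)      # then the first
    return x1, x2
```

Because the inverse only adds and subtracts the same quantities, reconstruction is exact in floating point here; in general, stability of the inverse is exactly the tradeoff noted above.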
Rematerialization
Rematerialization is another name for recomputing discarded intermediates.
A compiler or runtime may choose to rematerialize cheap values instead of loading them from memory.
For example, recomputing a cheap elementwise value such as $z = x + y$ may be cheaper than storing $z$ and later reading it.
For expensive operations, such as matrix multiplication or factorization, storage is often preferable.
The decision depends on:
- compute cost;
- memory bandwidth;
- tensor size;
- hardware architecture;
- reuse count.
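These factors can be folded into a simple cost model. The constants below are illustrative hardware assumptions (1 TFLOP/s compute, 100 GB/s memory bandwidth), not real device specs, and real compilers use measured costs:

```python
# Toy rematerialize-vs-store decision: compare time spent
# recomputing against time spent writing once and reading per reuse.
def should_rematerialize(flops, size_bytes, reuse_count,
                         flops_per_sec=1e12, bytes_per_sec=1e11):
    recompute_time = reuse_count * flops / flops_per_sec
    store_time = size_bytes / bytes_per_sec            # one write
    read_time = reuse_count * size_bytes / bytes_per_sec
    return recompute_time < store_time + read_time

# Cheap elementwise op on a 4 MB tensor: recompute wins.
# Expensive matmul producing the same 4 MB tensor: store wins.
```

Under this model, a cheap elementwise result is rematerialized while an expensive factorization is stored, matching the policy described above.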
Memory Bandwidth Versus Arithmetic
On modern accelerators, memory traffic often limits performance.
Saving all activations can make the backward pass memory-bound. Recomputing selected values may reduce memory pressure and improve throughput, even though it performs extra arithmetic.
This is counterintuitive but common: more floating-point operations can make a program faster if it avoids slow memory movement.
A practical AD system therefore optimizes both memory footprint and memory bandwidth, not only operation count.
Granularity of Checkpoints
Checkpointing can operate at different levels.
| Level | Example |
|---|---|
| primitive operation | individual arithmetic op |
| tensor operation | matrix multiplication, convolution |
| layer | transformer block |
| module | encoder stack |
| program region | loop body, solver step |
Fine-grained checkpointing gives more control but adds overhead.
Coarse-grained checkpointing is simpler and usually matches how users think about models.
In deep learning systems, checkpointing is often applied at the layer or block level.
Schedules
A checkpoint schedule decides which values to store and which regions to recompute.
For a simple chain, uniform spacing may work well. For irregular computation graphs, the optimal schedule depends on graph shape and operation cost.
A schedule should consider:
- long-lived values;
- branch fan-out;
- expensive operations;
- tensor sizes;
- available device memory.
Poor schedules can save small tensors while recomputing expensive kernels. Good schedules save high-value boundary states and rematerialize cheap interiors.
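For a uniform chain, a classic heuristic is to space checkpoints about $\sqrt{n}$ apart, balancing roughly $n/k$ stored states against $k$ recomputed steps per segment. A minimal sketch of that spacing rule (optimal schedules for general graphs instead use dynamic programming):

```python
import math

# Uniform checkpoint placement for a length-n chain:
# spacing k ~ sqrt(n) gives O(sqrt(n)) memory with roughly
# one extra forward pass of recomputation.
def uniform_checkpoints(n):
    k = max(1, round(math.sqrt(n)))
    return list(range(0, n, k))

# uniform_checkpoints(100) places a checkpoint every 10 steps.
```

This is only the simplest schedule; cost-aware schedules additionally weight each position by the expense of recomputing the ops behind it.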
Interaction With Control Flow
Dynamic control flow complicates memory-time tradeoffs.
If the forward pass follows an input-dependent branch, the backward pass must follow the same executed branch. The system must record enough control-flow information to replay or reverse the computation correctly.
For loops, the system may need per-iteration checkpoints.
Example:
```
for i in 1..T:
    h = step(h, x[i])
```

Full storage saves every $h_i$.
Checkpointing saves selected values and recomputes loop segments during backward execution.
Long recurrent computations and neural ODE solvers often require careful checkpointing because $T$ can be large.
Interaction With Mutation
Mutation makes recomputation harder.
If a value is overwritten during the forward pass, the backward pass may need the old value.
Example:
```
x = f(x)
x = g(x)
```

The reverse rule for $g$ may need the input to $g$ (the output of $f$), while the reverse rule for $f$ may need the original input to $f$.
A correct system must preserve or reconstruct each version.
This is why many AD systems use immutable intermediate values or compiler forms such as SSA. Each assignment creates a distinct version, so the reverse pass can refer to the correct value.
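The versioning idea can be shown directly: instead of mutating one variable, each assignment binds a fresh name, so every value the reverse pass might need remains reachable. A minimal sketch with toy functions:

```python
# SSA-like renaming: each assignment creates a new version instead
# of overwriting, so the backward pass can read the right one.
def f(x): return x + 1.0
def g(x): return 2.0 * x

# Mutating style:   x = f(x); x = g(x)   -- the input to g is lost.
# Versioned style:  every intermediate survives by name.
x0 = 3.0
x1 = f(x0)    # input to f is x0; input to g is x1 -- both still live
x2 = g(x1)
```

A tracing AD system performs this renaming automatically when it records the forward pass.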
Numerical Effects of Recomputation
Recomputation may not produce bit-identical values.
Sources of differences include:
- nondeterministic parallel reductions;
- random number generation;
- floating-point associativity;
- stateful kernels;
- mixed precision.
For stable training, recomputed values must be consistent enough for the backward pass.
Random operations require special handling. The system may store random seeds or masks rather than full outputs. Dropout is a common example.
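The seed-storing trick for dropout can be sketched as follows. The function names are illustrative, not a real framework API; the key point is that forward and backward reseed the same generator, so the regenerated mask is bit-identical:

```python
import random

def dropout_forward(xs, p, seed):
    rng = random.Random(seed)          # the seed is the only saved state
    mask = [1.0 if rng.random() >= p else 0.0 for _ in xs]
    return [x * m / (1 - p) for x, m in zip(xs, mask)]

def dropout_backward(grads, p, seed):
    rng = random.Random(seed)          # same seed -> identical mask
    mask = [1.0 if rng.random() >= p else 0.0 for _ in grads]
    return [g * m / (1 - p) for g, m in zip(grads, mask)]
```

Storing one integer seed replaces storing a full boolean mask per activation tensor, at the cost of regenerating the mask once in the backward pass.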
Practical Policy
A practical reverse-mode engine usually follows a mixed policy:
| Value Type | Common Policy |
|---|---|
| small scalar metadata | store |
| large cheap activations | recompute |
| expensive factorization results | store |
| random masks | store seed or mask |
| control-flow decisions | store |
| tensor shapes and strides | store |
| outputs of non-deterministic kernels | store |
The right policy is workload-dependent.
Conceptual Summary
Reverse mode gains time efficiency by traversing a recorded computation backward. That efficiency requires access to forward information.
Memory-time tradeoffs determine how that information is made available.
The three basic choices are:
- store values from the forward pass;
- recompute values during the backward pass;
- checkpoint selected states and recompute locally.
Production AD systems combine these strategies. The goal is to minimize memory pressure without making the backward pass too expensive or numerically inconsistent.