Memory-Time Tradeoffs

Reverse mode automatic differentiation is computationally efficient for scalar-output functions, but it has a major systems cost: it needs information from the forward pass during the backward pass.

The reverse pass cannot usually compute derivatives from the output value alone. It needs intermediate primal values, shapes, branches, and sometimes full solver or kernel state. The system must either store this information or recompute it.

This creates the central tradeoff:

| Strategy | Memory | Compute |
| --- | --- | --- |
| Store everything | high | low |
| Recompute everything | low | high |
| Checkpoint selectively | medium | medium |

Why Memory Is Needed

Consider multiplication:

z = xy.

The reverse rules are

\bar{x} \mathrel{+}= \bar{z}\,y, \qquad \bar{y} \mathrel{+}= \bar{z}\,x.

The backward pass needs the original forward values x and y. If they were discarded, the system must recompute them.
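As a minimal sketch (the class and method names here are hypothetical, not from any particular AD library), a multiply node can save its inputs during the forward pass so the reverse rule has them available:

```python
# Hypothetical minimal node: saves the primal inputs that the
# reverse rule for multiplication will need.
class Mul:
    def forward(self, x, y):
        self.x, self.y = x, y      # saved forward values
        return x * y

    def backward(self, z_bar):
        # x_bar += z_bar * y ;  y_bar += z_bar * x
        return z_bar * self.y, z_bar * self.x

node = Mul()
z = node.forward(3.0, 4.0)          # z = 12.0
x_bar, y_bar = node.backward(1.0)   # (4.0, 3.0)
```

If `self.x` and `self.y` were not saved, the backward call would have nothing to multiply the incoming adjoint by.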

For a deep computation,

v_1 \to v_2 \to \cdots \to v_k \to y,

the backward pass visits operations in reverse order:

v_k \to v_{k-1} \to \cdots \to v_1.

Each reverse step may need values produced much earlier in the forward pass. Keeping all of them can dominate memory usage.

The Full Storage Strategy

The simplest reverse mode implementation stores every intermediate value.

Forward pass:

for op in program:
    compute output
    save output and needed inputs

Backward pass:

for op in reverse(program):
    read saved values
    apply reverse rule

This strategy is fast because the backward pass has direct access to every required value.

Its cost is memory.
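The pseudocode above can be sketched concretely for a chain of unary operations. This is an illustrative tape structure with assumed names, not a real library API:

```python
# Full-storage tape sketch: each op's needed input is appended in
# forward order and read back in reverse order.
def forward_with_tape(x, ops):
    tape = []
    for f, df in ops:
        tape.append((x, df))   # save the input each reverse rule needs
        x = f(x)
    return x, tape

def backward(tape, y_bar=1.0):
    for x, df in reversed(tape):
        y_bar = y_bar * df(x)  # chain rule, newest op first
    return y_bar

ops = [(lambda x: x * x, lambda x: 2 * x),   # square
       (lambda x: 3 * x, lambda x: 3.0)]     # scale by 3
y, tape = forward_with_tape(2.0, ops)        # y = 3 * 2^2 = 12.0
grad = backward(tape)                        # d(3x^2)/dx at x=2 -> 12.0
```

Every saved `(x, df)` pair stays alive until the backward pass reaches it, which is exactly the memory cost the section describes.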

If a model has L layers and each layer activation requires A bytes, activation storage is roughly

O(LA).

For large neural networks, activations can use more memory than parameters.

The Recomputation Strategy

The opposite strategy stores almost nothing.

During the backward pass, whenever a needed value is missing, the system reruns part of the forward computation to recover it.

Memory falls, but compute increases.

In the extreme case, each backward step may require recomputing a long prefix of the program. For a chain of L operations, this can increase total cost from linear to quadratic.

| Strategy | Forward storage | Backward cost for chain |
| --- | --- | --- |
| Store all values | O(L) | O(L) |
| Recompute from start each time | O(1) | O(L^2) |

Pure recomputation is therefore rarely acceptable for large programs.
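The quadratic blow-up is easy to verify by counting: to recover the input of step i under pure recomputation, the system replays steps 0 through i-1 from the original input. A small illustrative counter:

```python
# Count forward-op evaluations needed by pure recomputation:
# backward step i must replay a prefix of length i.
def backward_recompute_count(L):
    count = 0
    for i in reversed(range(L)):   # backward over ops i = L-1 .. 0
        count += i                 # replay steps 0 .. i-1
    return count

# Sum of 0..L-1 = L(L-1)/2, i.e. roughly L^2 / 2 extra forward steps.
```

For L = 10 this is already 45 replayed steps versus 0 under full storage.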

Checkpointing

Checkpointing stores selected intermediate states and recomputes only between those states.

Suppose a computation has L stages. Instead of storing all L activations, the system stores checkpoints every s stages.

During the backward pass, to differentiate a segment, it reloads the nearest checkpoint and recomputes forward values inside that segment.

Memory becomes smaller than full storage, while compute remains much smaller than full recomputation.
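A standard back-of-envelope argument suggests how to choose the spacing s for a uniform chain: the system holds about L/s checkpoints plus up to s recomputed values inside the active segment, so

```latex
M(s) \approx \frac{L}{s} + s,
\qquad
\frac{dM}{ds} = -\frac{L}{s^2} + 1 = 0
\;\Rightarrow\;
s^\star = \sqrt{L},
\qquad
M(s^\star) = O(\sqrt{L}),
```

at the price of roughly one extra forward pass of recomputation. More elaborate schedules (e.g. recursive checkpointing) trade further along this curve.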

Conceptually:

forward:
    save selected checkpoints

backward:
    for each segment in reverse:
        restore checkpoint
        recompute segment forward
        run local backward pass

Checkpointing is one of the main techniques for training deeper models under fixed memory budgets.
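The pseudocode above can be made concrete for a chain of identical steps. This is a sketch with assumed helper names, for a scalar chain where each step has a known local derivative:

```python
# Segment-wise checkpointing sketch: save boundary states every `s`
# steps, then replay each segment forward before its local backward pass.
def grad_with_checkpoints(x0, step, dstep, L, s):
    ckpts, x = [x0], x0
    for i in range(L):                   # forward: keep only boundaries
        x = step(x)
        if (i + 1) % s == 0 and i + 1 < L:
            ckpts.append(x)
    y_bar = 1.0
    for seg in reversed(range(len(ckpts))):
        start = seg * s
        n = min(s, L - start)
        inner = [ckpts[seg]]             # restore checkpoint
        for _ in range(n - 1):
            inner.append(step(inner[-1]))  # recompute segment forward
        for v in reversed(inner):          # run local backward pass
            y_bar *= dstep(v)
    return y_bar

# Chain of L doublings: y = 2^L * x, so dy/dx = 2^L.
g = grad_with_checkpoints(1.0, lambda v: 2 * v, lambda v: 2.0, L=6, s=2)
```

Only the checkpoint list and one segment's interior are live at any time, which is the memory saving the section describes.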

Simple Chain Example

Consider a chain:

x \to v_1 \to v_2 \to v_3 \to v_4 \to y.

Full storage saves:

x, v_1, v_2, v_3, v_4, y.

A checkpointed strategy might save only:

x, v_2, y.

To compute derivatives for the segment v_2 \to v_3 \to v_4 \to y, the system recomputes v_3 and v_4 from v_2.

To compute derivatives for the segment x \to v_1 \to v_2, it recomputes v_1 and v_2 from x.

This reduces stored values from all intermediates to a smaller set of boundary states.

Activation Memory in Neural Networks

In neural network training, the largest reverse-mode memory cost is often activation storage.

For a layer

h_{i+1} = F_i(h_i, \theta_i),

the backward rule often requires h_i, \theta_i, and sometimes h_{i+1}.

Parameters \theta_i are already stored. Activations h_i are produced dynamically and must be retained or recomputed.

For a network with many layers, the activation memory scales with depth and batch size.

Important drivers include:

| Factor | Memory Effect |
| --- | --- |
| batch size | linear growth |
| sequence length | often linear or quadratic growth |
| hidden dimension | linear growth |
| number of layers | linear growth |
| saved attention maps | can be very large |

This is why memory optimization is central to training large transformer models.
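The scaling in the table can be made tangible with back-of-envelope arithmetic. All numbers below are assumed for illustration, not taken from any specific model:

```python
# Illustrative activation-memory estimate: one float32 activation
# tensor of shape (batch, seq, hidden) saved per layer.
batch, seq, hidden, layers = 8, 1024, 4096, 32
bytes_per_value = 4  # float32
per_layer = batch * seq * hidden * bytes_per_value   # 128 MiB per layer
total_gb = layers * per_layer / 2**30                # 4.0 GiB total
```

Even this single saved tensor per layer reaches several GiB, before counting attention maps or other saved intermediates, and it grows linearly with batch size, depth, and hidden dimension.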

Invertible Computations

Some computations allow the input to be reconstructed from the output.

If

y = g(x)

is invertible, then the backward pass may recover

x = g^{-1}(y)

instead of storing x.

Invertible neural networks use this idea to reduce activation memory. They store less forward state and reconstruct hidden activations during the backward pass.

The tradeoff is that the inverse computation must be stable and cheaper than storing the missing values.
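A common concrete instance is an additive coupling step in the style of NICE/RevNet-type architectures. This sketch (with assumed function names) shows how the input can be reconstructed exactly from the output instead of being saved:

```python
# Additive coupling sketch: the forward map is invertible by
# construction, so x1 need not be stored for the backward pass.
def couple(x1, x2, f):
    y1 = x1 + f(x2)    # forward: only x2 and f(x2) are combined
    return y1, x2

def uncouple(y1, y2, f):
    x1 = y1 - f(y2)    # inverse: recovers the discarded input
    return x1, y2

f = lambda v: 3 * v * v
y1, y2 = couple(2.0, 5.0, f)         # (77.0, 5.0)
x1, x2 = uncouple(y1, y2, f)         # reconstructs (2.0, 5.0)
```

Because `f` is only ever evaluated in the forward direction, it can be an arbitrary (non-invertible) network; invertibility comes from the coupling structure.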

Rematerialization

Rematerialization is another name for recomputing discarded intermediates.

A compiler or runtime may choose to rematerialize cheap values instead of loading them from memory.

For example, recomputing

x + 1

may be cheaper than storing and later reading it.

For expensive operations, such as matrix multiplication or factorization, storage is often preferable.

The decision depends on:

  1. compute cost;
  2. memory bandwidth;
  3. tensor size;
  4. hardware architecture;
  5. reuse count.

Memory Bandwidth Versus Arithmetic

On modern accelerators, memory traffic often limits performance.

Saving all activations can make the backward pass memory-bound. Recomputing selected values may reduce memory pressure and improve throughput, even though it performs extra arithmetic.

This is counterintuitive but common: more floating-point operations can make a program faster if it avoids slow memory movement.

A practical AD system therefore optimizes both memory footprint and memory bandwidth, not only operation count.

Granularity of Checkpoints

Checkpointing can operate at different levels.

| Level | Example |
| --- | --- |
| primitive operation | individual arithmetic op |
| tensor operation | matrix multiplication, convolution |
| layer | transformer block |
| module | encoder stack |
| program region | loop body, solver step |

Fine-grained checkpointing gives more control but adds overhead.

Coarse-grained checkpointing is simpler and usually matches how users think about models.

In deep learning systems, checkpointing is often applied at the layer or block level.

Schedules

A checkpoint schedule decides which values to store and which regions to recompute.

For a simple chain, uniform spacing may work well. For irregular computation graphs, the optimal schedule depends on graph shape and operation cost.

A schedule should consider:

  1. long-lived values;
  2. branch fan-out;
  3. expensive operations;
  4. tensor sizes;
  5. available device memory.

Poor schedules can save small tensors while recomputing expensive kernels. Good schedules save high-value boundary states and rematerialize cheap interiors.

Interaction With Control Flow

Dynamic control flow complicates memory-time tradeoffs.

If the forward pass follows an input-dependent branch, the backward pass must follow the same executed branch. The system must record enough control-flow information to replay or reverse the computation correctly.

For loops, the system may need per-iteration checkpoints.

Example:

for i in 1..T:
    h = step(h, x[i])

Full storage saves every h_i.

Checkpointing saves selected h_i values and recomputes loop segments during backward execution.

Long recurrent computations and neural ODE solvers often require careful checkpointing because T can be large.

Interaction With Mutation

Mutation makes recomputation harder.

If a value is overwritten during the forward pass, the backward pass may need the old value.

Example:

x = f(x)
x = g(x)

The reverse rule for g may need the input to g, while the reverse rule for f may need the input to f.

A correct system must preserve or reconstruct each version.

This is why many AD systems use immutable intermediate values or compiler forms such as SSA. Each assignment creates a distinct version, so the reverse pass can refer to the correct value.
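The versioning idea can be shown in a few lines. In the sketch below (illustrative functions, not a real AD system), each assignment is given its own name, so both inputs that the reverse rules need remain available:

```python
# SSA-style versioning sketch: x0 and x1 are distinct values, so the
# reverse pass can read the exact input each rule needs.
def f(v): return v * v       # f'(v) = 2v
def g(v): return v + 10.0    # g'(v) = 1

x0 = 3.0
x1 = f(x0)     # version 1: the input g will need
x2 = g(x1)     # version 2: the final output

# reverse pass for d x2 / d x0: both saved versions are still live
x1_bar = 1.0 * 1.0         # g'(x1) = 1
x0_bar = x1_bar * 2 * x0   # f'(x0) = 2 * x0 = 6
```

Had `x` been overwritten in place, the value `x0 = 3.0` needed by `f`'s reverse rule would already be gone.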

Numerical Effects of Recomputation

Recomputation may not produce bit-identical values.

Sources of differences include:

  1. nondeterministic parallel reductions;
  2. random number generation;
  3. floating-point associativity;
  4. stateful kernels;
  5. mixed precision.

For stable training, recomputed values must be consistent enough for the backward pass.

Random operations require special handling. The system may store random seeds or masks rather than full outputs. Dropout is a common example.
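One way to sketch the seed-storage idea (using Python's standard `random` module for illustration, not any framework's dropout implementation):

```python
import random

# Store only the seed; regenerate the identical dropout mask on demand
# during the backward pass instead of saving the masked output.
def dropout_mask(n, p, seed):
    rng = random.Random(seed)            # same seed -> same mask
    return [0.0 if rng.random() < p else 1.0 / (1 - p) for _ in range(n)]

def dropout_forward(x, p, seed):
    mask = dropout_mask(len(x), p, seed)
    return [xi * mi for xi, mi in zip(x, mask)]

x = [1.0, 2.0, 3.0, 4.0]
y = dropout_forward(x, p=0.5, seed=123)      # forward: mask from seed 123
m = dropout_mask(len(x), p=0.5, seed=123)    # backward: same mask, replayed
```

The regenerated mask is bit-identical to the one used in the forward pass, so the backward pass stays consistent while storing only a small seed.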

Practical Policy

A practical reverse-mode engine usually follows a mixed policy:

| Value Type | Common Policy |
| --- | --- |
| small scalar metadata | store |
| large cheap activations | recompute |
| expensive factorization results | store |
| random masks | store seed or mask |
| control-flow decisions | store |
| tensor shapes and strides | store |
| outputs of non-deterministic kernels | store |

The right policy is workload-dependent.

Conceptual Summary

Reverse mode gains time efficiency by traversing a recorded computation backward. That efficiency requires access to forward information.

Memory-time tradeoffs determine how that information is made available.

The three basic choices are:

  1. store values from the forward pass;
  2. recompute values during the backward pass;
  3. checkpoint selected states and recompute locally.

Production AD systems combine these strategies. The goal is to minimize memory pressure without making the backward pass too expensive or numerically inconsistent.