
Checkpointing


Checkpointing is a technique for reducing the memory cost of reverse mode automatic differentiation by selectively storing intermediate states and recomputing missing values during the backward pass.

It addresses the central problem of reverse mode:

The backward pass needs information from the forward pass, but storing all intermediate values may exceed available memory.

Checkpointing balances memory usage against recomputation cost.

Basic Idea

Suppose a forward computation produces a sequence of states:

x_0 \to x_1 \to x_2 \to \cdots \to x_n.

A naive reverse mode implementation stores every intermediate state:

x_0, x_1, x_2, \ldots, x_n.

Memory grows linearly with the length of the computation.

Checkpointing instead stores only selected states, called checkpoints:

x_0, x_k, x_{2k}, \ldots.

During the backward pass, missing intermediate values are recomputed from the nearest checkpoint.

The system therefore trades extra computation for reduced memory.

Simple Example

Consider a computation chain:

x_0 \to x_1 \to x_2 \to x_3 \to x_4 \to y.

Full reverse mode stores:

x_0, x_1, x_2, x_3, x_4.

Suppose checkpointing stores only:

x_0, x_2, x_4.

Backward execution proceeds as follows.

Reverse Through Final Segment

To differentiate the segment

x_2 \to x_3 \to x_4,

the system recomputes:

x_3

starting from the stored checkpoint x_2.

Then it performs reverse propagation through that segment.

Reverse Through Earlier Segment

To differentiate

x_0 \to x_1 \to x_2,

the system recomputes:

x_1

from the checkpoint x_0.

This avoids storing all intermediates simultaneously.
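
To make this concrete, here is a minimal, self-contained Python sketch of the five-step chain above with checkpoints at x_0, x_2, x_4. The specific functions and their hand-written derivatives are invented purely for illustration; the point is the restore-recompute-reverse pattern.

import math

# Chain x_0 -> x_1 -> x_2 -> x_3 -> x_4 -> y, with hand-written derivatives.
# The concrete functions are made up for this example.
fwd  = [lambda x: x * x,      # x_1 = x_0^2
        math.sin,             # x_2 = sin(x_1)
        math.exp,             # x_3 = exp(x_2)
        lambda x: 3.0 * x,    # x_4 = 3 x_3
        lambda x: x + 1.0]    # y   = x_4 + 1
dfwd = [lambda x: 2.0 * x,
        math.cos,
        math.exp,
        lambda x: 3.0,
        lambda x: 1.0]

def forward_with_checkpoints(x0, every=2):
    """Run the chain, storing only x_0, x_2, x_4 (every second state)."""
    ckpts = {0: x0}
    x = x0
    for i, f in enumerate(fwd):
        x = f(x)
        if (i + 1) % every == 0 and (i + 1) < len(fwd):
            ckpts[i + 1] = x
    return x, ckpts                    # y and the stored checkpoints

def backward_with_checkpoints(ckpts, every=2):
    """Restore each checkpoint, recompute its segment, then reverse it."""
    grad = 1.0                         # dy/dy
    n = len(fwd)
    for start in sorted(ckpts, reverse=True):
        end = min(start + every, n)
        xs = [ckpts[start]]            # recompute states inside the segment
        for i in range(start, end):
            xs.append(fwd[i](xs[-1]))
        for i in reversed(range(start, end)):
            grad *= dfwd[i](xs[i - start])
    return grad                        # dy/dx_0

y, ckpts = forward_with_checkpoints(0.5)
print(y, backward_with_checkpoints(ckpts))

At no point does the backward pass hold more than one checkpoint plus the recomputed states of a single segment.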

Computational Interpretation

Checkpointing partitions the computational graph into regions.

Each region has:

  1. a stored entry state;
  2. recomputed internal states;
  3. a local backward sweep.

Conceptually:

forward:
    save selected checkpoints

backward:
    restore checkpoint
    recompute local forward region
    execute local reverse pass

The backward pass therefore becomes an alternating process of:

  1. restoration;
  2. recomputation;
  3. reverse accumulation.

Why Checkpointing Matters

Modern deep learning models often exceed accelerator memory limits if every activation is stored.

For a transformer with hundreds of layers:

Resource         Scaling
parameters       large
optimizer state  very large
activations      often dominant

Checkpointing allows training larger models by reducing activation memory.

Without checkpointing, many large models would not fit into available GPU memory.

Memory-Compute Tradeoff

Checkpointing reduces memory at the cost of additional forward recomputation.

Suppose the computation has L layers.

Full storage:

Metric            Cost
memory            O(L)
forward compute   O(L)
backward compute  O(L)

No storage:

Metric                  Cost
memory                  O(1)
backward recomputation  O(L^2)

Checkpointing:

Metric   Cost
memory   reduced
compute  moderately increased

The goal is to approach minimal memory without excessive recomputation.

Segment Checkpointing

The simplest strategy divides the program into fixed-size segments.

Suppose a computation has L layers and segment size s.

Store checkpoints every s layers:

x_0, x_s, x_{2s}, \ldots.

During backward execution:

  1. reload checkpoint;
  2. recompute segment forward;
  3. execute segment backward.

This reduces peak activation memory from O(L) to roughly O(L/s + s), which is minimized near s \approx \sqrt{L}, at the cost of about one additional forward pass, since each non-checkpoint layer is recomputed once during the backward sweep.
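
A back-of-the-envelope sketch of this tradeoff, under the simplifying assumption that every layer produces one activation of equal size and costs one unit of forward work:

import math

def segment_costs(L, s):
    """Rough peak-memory and extra-compute estimates for segment size s."""
    checkpoints = math.ceil(L / s)   # stored segment entry states
    live_segment = s                 # activations held while reversing one segment
    extra_forward = L                # each layer is recomputed at most once
    return checkpoints + live_segment, extra_forward

for s in (1, 8, 32, 128):
    peak, extra = segment_costs(1024, s)
    print(f"s={s:4d}  peak activations ~{peak:5d}  extra forward steps ~{extra}")

For L = 1024 the peak is minimized near s = 32 \approx \sqrt{L}, which is the familiar square-root checkpointing rule.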

Example in Neural Networks

Suppose a neural network computes:

h_{i+1} = F_i(h_i, \theta_i).

Without checkpointing, reverse mode stores every activation:

h_0, h_1, \ldots, h_L.

With checkpointing:

h_0, h_4, h_8, \ldots

may be stored instead.

To backpropagate through layers 4 to 8:

  1. restore h_4;
  2. recompute h_5, h_6, h_7, h_8;
  3. run backward pass.

Memory falls substantially.
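
In practice this is usually done through a framework utility rather than by hand. A minimal sketch using PyTorch's torch.utils.checkpoint module (the layer count and sizes here are arbitrary and chosen only for illustration):

import torch
from torch.utils.checkpoint import checkpoint_sequential

# 16 identical blocks; only segment boundaries are kept, the rest is recomputed.
layers = torch.nn.Sequential(*[
    torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.ReLU())
    for _ in range(16)
])

x = torch.randn(32, 256, requires_grad=True)

# Split the 16 blocks into 4 segments: activations inside each segment are
# discarded after the forward pass and recomputed during backward.
y = checkpoint_sequential(layers, 4, x)
y.sum().backward()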

Recursive Checkpointing

More advanced strategies recursively subdivide the computation.

Instead of uniform segments, the system:

  1. stores a checkpoint;
  2. recursively recomputes subregions;
  3. recursively backpropagates.

This can asymptotically reduce memory further.

A famous result is the revolve algorithm, which computes optimal schedules for certain sequential computations.
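
A minimal sketch of the simplest recursive strategy, binary splitting of the chain. This halving schedule is not the optimal revolve schedule, but it already brings peak memory down to O(log L) stored states; the step function and its derivative below are made up for illustration.

import math

# Chain x_{i+1} = f_i(x_i) with hand-written per-step derivatives df_i.
fwd  = [math.sin] * 64
dfwd = [math.cos] * 64

def backward_segment(x_start, start, end, grad_out):
    """Return the gradient through the segment [start, end), holding only
    O(log(end - start)) states at any one time."""
    if end - start == 1:
        return grad_out * dfwd[start](x_start)
    mid = (start + end) // 2
    # Recompute the midpoint state from the segment's entry checkpoint.
    x_mid = x_start
    for i in range(start, mid):
        x_mid = fwd[i](x_mid)
    # Reverse the second half first, then the first half.
    grad_mid = backward_segment(x_mid, mid, end, grad_out)
    return backward_segment(x_start, start, mid, grad_mid)

x0 = 0.3
print(backward_segment(x0, 0, len(fwd), 1.0))   # dy/dx_0 for the whole chain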

Recursive checkpointing becomes important for:

  1. very deep networks;
  2. long simulations;
  3. adjoint PDE solvers;
  4. neural ODEs.

Binomial Checkpointing

Binomial checkpointing uses recursive schedules that achieve sublinear memory growth.

The key idea is:

  1. save carefully chosen checkpoints;
  2. recompute strategically;
  3. avoid repeated recomputation explosion.

For long chains, binomial schedules can provide:

Quantity       Complexity
memory         O(\log L)
recomputation  moderate

These methods are mathematically elegant but operationally more complex.
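
A classical result due to Griewank and Walther quantifies the tradeoff: with c checkpoint slots and each forward step recomputed at most r times, chains of length up to \binom{c + r}{c} can be reversed, so both memory and the recomputation factor grow only logarithmically in L when c and r are increased together.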

Online and Offline Schedules

Checkpoint schedules may be:

Type     Description
offline  planned before execution
online   adapts dynamically during execution

Offline schedules assume known graph structure.

Online schedules adapt to:

  1. dynamic control flow;
  2. runtime memory pressure;
  3. variable tensor sizes;
  4. device availability.

Dynamic models often require online checkpointing decisions.

Tensor Checkpointing

Not all tensors have equal value.

Some activations are:

  1. large but cheap to recompute;
  2. small but expensive to recompute.

Good checkpoint systems consider:

Property             Importance
tensor size          memory cost
recomputation FLOPs  compute cost
reuse frequency      backward demand
kernel latency       hardware effect

Example:

Tensor Type                 Common Strategy
ReLU outputs                recompute
attention logits            often store
matrix factorization state  store
random masks                store seed

The optimal policy depends on hardware and workload.
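
A toy selection heuristic, purely illustrative (both the cost model and the threshold are made up), showing how tensor size and recomputation cost might be weighed against each other:

def should_recompute(tensor_bytes, recompute_flops, flops_per_byte_budget=100.0):
    """Recompute a tensor if regenerating it costs few FLOPs per byte of
    memory saved; otherwise keep it resident.  The budget is arbitrary."""
    flops_per_byte = recompute_flops / max(tensor_bytes, 1)
    return flops_per_byte < flops_per_byte_budget

# A large ReLU output that is nearly free to regenerate -> recompute it.
print(should_recompute(tensor_bytes=64 * 2**20, recompute_flops=64 * 2**20))
# A small tensor produced by an expensive matmul -> keep it stored.
print(should_recompute(tensor_bytes=2**20, recompute_flops=10**10))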

Checkpointing and Randomness

Random operations complicate recomputation.

Example:

y = \text{dropout}(x).

Backward propagation requires the same dropout mask used during the forward pass.

If the mask changes during recomputation, gradients become incorrect.

Solutions include:

Method                    Strategy
store random masks        simplest
store RNG seed            lower memory
deterministic RNG replay  reproducible recomputation

Modern systems frequently save RNG state rather than full masks.
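
A small pure-Python sketch of the RNG-state approach; the dropout implementation here is hypothetical and only illustrates the replay mechanism:

import random

def dropout_forward(x, p, rng):
    """Apply dropout, returning the output plus the RNG state needed to
    regenerate the exact same mask later (instead of storing the mask)."""
    state = rng.getstate()
    mask = [0.0 if rng.random() < p else 1.0 / (1.0 - p) for _ in x]
    return [m * v for m, v in zip(mask, x)], state

def replay_mask(n, p, rng, state):
    """During backward recomputation, restore the state and rebuild the mask."""
    rng.setstate(state)
    return [0.0 if rng.random() < p else 1.0 / (1.0 - p) for _ in range(n)]

rng = random.Random(0)
x = [1.0, 2.0, 3.0, 4.0]
y, state = dropout_forward(x, 0.5, rng)
print(replay_mask(len(x), 0.5, rng, state))   # identical to the forward mask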

Checkpointing and Control Flow

Dynamic branches require replaying the same execution path.

Example:

if x > 0:
    y = f(x)
else:
    y = g(x)

Backward recomputation must execute the same branch as the original forward pass.

Therefore the system must preserve:

  1. branch decisions;
  2. loop iteration counts;
  3. dynamic execution traces.

This metadata is often much smaller than full tensor storage.
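
A minimal sketch of recording and replaying a branch decision (f and g are placeholder functions invented for this example):

def f(x):          # placeholder branch bodies
    return x * x

def g(x):
    return -x

def forward(x, trace):
    """Run the branch and record which side was taken."""
    if x > 0:
        trace.append("f")
        return f(x)
    trace.append("g")
    return g(x)

def recompute(x, trace, i):
    """Replay the recorded decision instead of re-evaluating the predicate."""
    return f(x) if trace[i] == "f" else g(x)

trace = []
y = forward(-3.0, trace)
assert recompute(-3.0, trace, 0) == y   # same path, same value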

Checkpointing in Recurrent Models

Long recurrent computations are especially memory-intensive.

Example:

for t in range(1, T + 1):
    h[t] = step(h[t - 1], x[t])

Naive reverse mode stores all hidden states:

h_1, h_2, \ldots, h_T.

For large T, this becomes infeasible.

Checkpointing stores selected hidden states and recomputes missing segments during backpropagation through time.
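
A self-contained sketch of checkpointed backpropagation through time for a scalar recurrence; the cell and its derivative are hand-written and purely illustrative, and the loss gradient with respect to the final state is taken to be 1.

import math

def step(h, x):
    """Hypothetical recurrent cell."""
    return math.tanh(h + x)

def dstep_dh(h, x):
    """Derivative of step with respect to its hidden-state input."""
    t = math.tanh(h + x)
    return 1.0 - t * t

def bptt_with_checkpoints(h0, xs, every=8):
    # Forward: keep only every `every`-th hidden state.
    ckpts = {0: h0}
    h = h0
    for t, x in enumerate(xs, start=1):
        h = step(h, x)
        if t % every == 0 and t < len(xs):
            ckpts[t] = h
    # Backward through time, one segment at a time.
    T = len(xs)
    grad = 1.0                                # d loss / d h_T taken as 1
    for start in sorted(ckpts, reverse=True):
        end = min(start + every, T)
        # Recompute the hidden states of this segment from its checkpoint.
        hs = [ckpts[start]]
        for t in range(start, end):
            hs.append(step(hs[-1], xs[t]))
        # Local reverse sweep through the recomputed segment.
        for t in reversed(range(start, end)):
            grad *= dstep_dh(hs[t - start], xs[t])
    return grad                               # d h_T / d h_0

xs = [0.1 * t for t in range(20)]             # arbitrary inputs
print(bptt_with_checkpoints(0.0, xs, every=8))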

Checkpointing in Neural ODEs

Neural ordinary differential equations use continuous-time dynamics:

\frac{dh}{dt} = f(h, t, \theta).

The backward pass may reconstruct trajectories by solving another differential equation backward in time.

This can dramatically reduce memory usage.
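
In the standard adjoint formulation, the backward pass integrates the adjoint state a(t) = \partial L / \partial h(t) backward from t = T to t = 0,

\frac{da}{dt} = -\, a(t)^{\top} \frac{\partial f}{\partial h}\big(h(t), t, \theta\big),

typically alongside a reconstruction of h(t) obtained by solving the forward dynamics in reverse, so that no stored trajectory is required.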

However, backward reconstruction may introduce:

  1. numerical instability;
  2. trajectory drift;
  3. mismatch between forward and backward states.

Continuous-time checkpointing therefore becomes both a numerical and systems problem.

Hardware Effects

Checkpointing changes the balance between arithmetic and memory traffic.

Modern accelerators often have:

Resource               Relative Speed
arithmetic throughput  extremely high
memory bandwidth       comparatively limited

Recomputation may therefore be cheaper than storing and loading large tensors.

Checkpointing can improve throughput even while increasing FLOPs.

This is increasingly true on modern GPUs and TPUs.

Compiler-Assisted Checkpointing

Compilers may automatically choose checkpoint locations.

Optimization objectives include:

  1. minimize peak memory;
  2. minimize recomputation;
  3. satisfy hardware constraints;
  4. maximize parallelism.

Compiler-based systems can analyze:

Information       Use
tensor sizes      memory estimation
dependency graph  recomputation planning
operation costs   scheduling
device topology   placement

Automatic checkpoint scheduling remains an active research area.

Gradient Correctness

Checkpointing must preserve mathematical equivalence with full reverse mode.

Recomputed values must match the original forward computation closely enough that reverse propagation remains valid.

Problems arise with:

  1. nondeterministic kernels;
  2. floating-point instability;
  3. stateful operations;
  4. stochastic algorithms.

Correct checkpointing therefore requires careful management of execution determinism.

Conceptual Summary

Checkpointing reduces reverse-mode memory usage by storing selected intermediate states and recomputing missing values during the backward pass.

It transforms the reverse pass into an alternating sequence of:

  1. checkpoint restoration;
  2. local forward recomputation;
  3. local backward propagation.

The fundamental tradeoff is:

lower memory in exchange for higher compute.

Modern automatic differentiation systems rely heavily on checkpointing because memory capacity, rather than arithmetic throughput, is often the primary bottleneck in large-scale differentiable computation.