Checkpointing is a technique for reducing the memory cost of reverse mode automatic differentiation by selectively storing intermediate states and recomputing missing values during the backward pass.
It addresses the central problem of reverse mode: the backward pass needs information from the forward pass, but storing all intermediate values may exceed available memory.
Checkpointing balances memory usage against recomputation cost.
Basic Idea
Suppose a forward computation produces a sequence of states $s_0 \to s_1 \to s_2 \to \dots \to s_n$.
A naive reverse mode implementation stores every intermediate state $s_0, s_1, \dots, s_n$.
Memory grows linearly with the length of the computation.
Checkpointing instead stores only selected states, called checkpoints, for example $s_0, s_k, s_{2k}, \dots$.
During the backward pass, missing intermediate values are recomputed from the nearest checkpoint.
The system therefore trades extra computation for reduced memory.
Simple Example
Consider a computation chain $x \to s_1 \to s_2 \to s_3 \to s_4 \to y$.
Full reverse mode stores $s_1, s_2, s_3, s_4$.
Suppose checkpointing stores only $s_2$ (the input $x$ is always available).
Backward execution proceeds as follows.
Reverse Through Final Segment
To differentiate the final segment $s_2 \to s_3 \to s_4 \to y$, the system recomputes $s_3$ and $s_4$, starting from the stored checkpoint $s_2$.
Then it performs reverse propagation through that segment.
Reverse Through Earlier Segment
To differentiate the earlier segment $x \to s_1 \to s_2$, the system recomputes $s_1$ from the input $x$.
This avoids storing all intermediates simultaneously.
Computational Interpretation
Checkpointing partitions the computational graph into regions.
Each region has:
- a stored entry state;
- recomputed internal states;
- a local backward sweep.
Conceptually:

```
forward:
    save selected checkpoints

backward:
    restore checkpoint
    recompute local forward region
    execute local reverse pass
```

The backward pass therefore becomes an alternating process of:
- restoration;
- recomputation;
- reverse accumulation.
Why Checkpointing Matters
Modern deep learning models often exceed accelerator memory limits if every activation is stored.
For a transformer with hundreds of layers:
| Resource | Scaling |
|---|---|
| parameters | grows with width and depth |
| optimizer state | a small constant multiple of the parameters (e.g. Adam keeps two extra buffers per parameter) |
| activations | grows with depth, batch size, and sequence length; often dominant |
Checkpointing allows training larger models by reducing activation memory.
Without checkpointing, many large models would not fit into available GPU memory.
Memory-Compute Tradeoff
Checkpointing reduces memory at the cost of additional forward recomputation.
Suppose the network has $n$ layers.

Full storage:

| Metric | Cost |
|---|---|
| memory | $O(n)$ |
| forward compute | $O(n)$ |
| backward compute | $O(n)$ |

No storage (every needed value recomputed from the input):

| Metric | Cost |
|---|---|
| memory | $O(1)$ |
| backward recomputation | $O(n^2)$ |

Checkpointing:

| Metric | Cost |
|---|---|
| memory | reduced, e.g. $O(\sqrt{n})$ with uniform segments |
| compute | moderately increased, roughly one extra forward pass |
The goal is to approach minimal memory without excessive recomputation.
Segment Checkpointing
The simplest strategy divides the program into fixed-size segments.
Suppose a computation has $n$ layers and segment size $k$.
Store checkpoints every $k$ layers: $s_0, s_k, s_{2k}, \dots$
During backward execution:
- reload checkpoint;
- recompute segment forward;
- execute segment backward.
This reduces stored activations from $n$ to roughly $n/k + k$ (the checkpoints plus one active segment) while adding about one extra forward pass of compute; choosing $k \approx \sqrt{n}$ minimizes memory at $O(\sqrt{n})$.
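The following is a minimal runnable sketch of this strategy, assuming a chain of scalar steps; `grad_checkpointed`, the `(f, vjp)` step pairs, and all other names are illustrative, not a real framework API.

```python
def grad_checkpointed(steps, x0, g, k):
    """Gradient of g through the chain at x0, storing one checkpoint every k steps."""
    # Forward pass: keep only every k-th state (plus the input).
    checkpoints = {0: x0}
    x = x0
    for i, (f, _) in enumerate(steps, start=1):
        x = f(x)
        if i % k == 0:
            checkpoints[i] = x

    n = len(steps)
    # Backward pass, one segment at a time, starting from the last segment.
    for start in range(((n - 1) // k) * k, -1, -k):
        end = min(start + k, n)
        xs = [checkpoints[start]]                 # restore checkpoint
        for f, _ in steps[start:end]:             # recompute segment forward
            xs.append(f(xs[-1]))
        for i in range(end - 1, start - 1, -1):   # local reverse sweep
            _, vjp = steps[i]
            g = vjp(xs[i - start], g)
    return g

# Example: 6 doubling steps, so the chain computes (2**6) * x and d/dx = 64.
steps = [(lambda x: 2 * x, lambda x, g: 2 * g)] * 6
print(grad_checkpointed(steps, x0=3.0, g=1.0, k=2))   # -> 64.0
```

The forward pass stores roughly $n/k$ states, and each backward segment revives at most $k$ more, matching the $n/k + k$ memory profile described above.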
Example in Neural Networks
Suppose a neural network computes activations $a_1 = f_1(x)$, $a_2 = f_2(a_1)$, $\dots$, $a_n = f_n(a_{n-1})$.
Without checkpointing, reverse mode stores every activation $a_1, \dots, a_n$.
With checkpointing, only every $k$-th activation $a_k, a_{2k}, \dots$ may be stored instead.
To backpropagate through layers $jk+1$ to $(j+1)k$:
- restore $a_{jk}$;
- recompute $a_{jk+1}, \dots, a_{(j+1)k}$;
- run backward pass.
Memory falls substantially.
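In a real framework this is close to a one-liner. As one illustration (the section names no framework, so this choice is an assumption), PyTorch ships this segment strategy as `torch.utils.checkpoint.checkpoint_sequential`; a sketch, assuming a recent PyTorch:

```python
import torch
from torch.utils.checkpoint import checkpoint_sequential

# A deep sequential model whose full activation set would be costly to store.
model = torch.nn.Sequential(
    *[torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU())
      for _ in range(100)]
)

x = torch.randn(32, 64, requires_grad=True)
# Split into 4 segments: only segment boundaries are kept; activations
# inside each segment are recomputed during the backward pass.
y = checkpoint_sequential(model, 4, x, use_reentrant=False)
y.sum().backward()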
Recursive Checkpointing
More advanced strategies recursively subdivide the computation.
Instead of uniform segments, the system:
- stores a checkpoint;
- recursively recomputes subregions;
- recursively backpropagates.
This can asymptotically reduce memory further.
A famous result is the revolve algorithm, which computes optimal schedules for certain sequential computations.
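A minimal sketch of the recursive idea (not the revolve schedule itself): split the chain at its midpoint, recompute forward to the midpoint, backpropagate the second half recursively, then the first. Stored state is bounded by the recursion depth, $O(\log n)$, at the price of $O(n \log n)$ recomputation. All names here are illustrative.

```python
def grad_chain(steps, x, g):
    """Backprop g through a chain of (f, vjp) pairs starting at input x.
    Only O(log n) states live on the recursion stack at any time."""
    if len(steps) == 1:
        f, vjp = steps[0]
        return vjp(x, g)
    mid = len(steps) // 2
    x_mid = x
    for f, _ in steps[:mid]:                     # recompute to the midpoint
        x_mid = f(x_mid)
    g_mid = grad_chain(steps[mid:], x_mid, g)    # reverse the second half
    return grad_chain(steps[:mid], x, g_mid)     # then the first half

# Example: 8 squaring steps compute x**256, so d/dx = 256 * x**255.
steps = [(lambda x: x * x, lambda x, g: g * 2 * x)] * 8
print(grad_chain(steps, 1.5, 1.0))   # matches 256 * 1.5**255
```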
Recursive checkpointing becomes important for:
- very deep networks;
- long simulations;
- adjoint PDE solvers;
- neural ODEs.
Binomial Checkpointing
Binomial checkpointing uses recursive schedules that achieve sublinear memory growth.
The key idea is:
- save carefully chosen checkpoints;
- recompute strategically;
- avoid repeated recomputation explosion.
For long chains, binomial schedules can provide:
| Quantity | Complexity |
|---|---|
| memory | $O(\log n)$ |
| recomputation | $O(n \log n)$ total, i.e. a logarithmic slowdown |
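A classical result behind these schedules (from the revolve line of work) is that with $c$ stored checkpoints and at most $r$ forward sweeps over any step, the longest reversible chain has length $\binom{c+r}{c}$; inverting this bound is what lets both memory and the recomputation factor grow only logarithmically in the chain length.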
These methods are mathematically elegant but operationally more complex.
Online and Offline Schedules
Checkpoint schedules may be:
| Type | Description |
|---|---|
| offline | planned before execution |
| online | adapt dynamically during execution |
Offline schedules assume known graph structure.
Online schedules adapt to:
- dynamic control flow;
- runtime memory pressure;
- variable tensor sizes;
- device availability.
Dynamic models often require online checkpointing decisions.
Tensor Checkpointing
Not all tensors have equal value.
Some activations are:
- large but cheap to recompute;
- small but expensive to recompute.
Good checkpoint systems consider:
| Property | Importance |
|---|---|
| tensor size | memory cost |
| recomputation FLOPs | compute cost |
| reuse frequency | backward demand |
| kernel latency | hardware effect |
Example:
| Tensor Type | Common Strategy |
|---|---|
| ReLU outputs | recompute |
| attention logits | often store |
| matrix factorization state | store |
| random masks | store seed |
The optimal policy depends on hardware and workload.
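A toy cost model makes the tradeoff concrete; the threshold below is an invented tuning knob, not taken from any real system:

```python
# Decide store-vs-recompute per tensor: store when recomputation costs more
# FLOPs than the tensor's bytes are "worth" on the target hardware.
def should_store(tensor_bytes, recompute_flops, flops_per_byte=4.0):
    return recompute_flops > flops_per_byte * tensor_bytes

# Small but expensive to recompute: store it.
print(should_store(tensor_bytes=4 << 20, recompute_flops=1e9))    # True
# Large but trivially recomputable (e.g. a ReLU output): recompute it.
print(should_store(tensor_bytes=64 << 20, recompute_flops=1e6))   # False
```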
Checkpointing and Randomness
Random operations complicate recomputation.
Consider dropout: backward propagation requires the same dropout mask used during the forward pass.
If the mask changes during recomputation, gradients become incorrect.
Solutions include:
| Method | Strategy |
|---|---|
| store random mask | simplest |
| store RNG seed | lower memory |
| deterministic RNG replay | reproducible recomputation |
Modern systems frequently save RNG state rather than full masks.
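A minimal sketch of the seed-saving strategy, using Python's standard `random` module as a stand-in for a framework RNG (all names illustrative):

```python
import random

def dropout_mask(n, p):
    # One boolean per element: True means "keep".
    return [random.random() >= p for _ in range(n)]

state = random.getstate()            # tiny, compared with the mask itself
mask = dropout_mask(1000, p=0.5)     # original forward pass

random.setstate(state)               # during recomputation: restore, replay
replayed = dropout_mask(1000, p=0.5)
assert mask == replayed              # identical mask, so gradients match
```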
Checkpointing and Control Flow
Dynamic branches require replaying the same execution path.
Example:
```python
if x > 0:
    y = f(x)
else:
    y = g(x)
```

Backward recomputation must execute the same branch as the original forward pass.
Therefore the system must preserve:
- branch decisions;
- loop iteration counts;
- dynamic execution traces.
This metadata is often much smaller than full tensor storage.
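A sketch of recording that metadata, with `f`, `g`, and `trace` as illustrative placeholders:

```python
def f(x): return x * x   # placeholder branch bodies
def g(x): return -x

trace = []               # one boolean per branch decision: tiny metadata

def forward(x):
    taken = x > 0
    trace.append(taken)
    return f(x) if taken else g(x)

def recompute(x, step):
    # Replay the recorded decision instead of re-evaluating the predicate.
    return f(x) if trace[step] else g(x)

y = forward(-2.0)
assert recompute(-2.0, step=0) == y
```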
Checkpointing in Recurrent Models
Long recurrent computations are especially memory-intensive.
Example:
```python
for t in range(1, T + 1):
    h[t] = step(h[t - 1], x[t])
```

Naive reverse mode stores all hidden states $h_1, \dots, h_T$.
For large $T$, this becomes infeasible.
Checkpointing stores selected hidden states and recomputes missing segments during backpropagation through time.
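A sketch of the storage pattern, with `step` standing in for the real recurrence:

```python
def step(h, x):
    return 0.5 * h + x           # stand-in recurrence

T, k = 12, 4
xs = [float(t) for t in range(T + 1)]

checkpoints = {0: 0.0}           # h[0]
h = 0.0
for t in range(1, T + 1):        # forward pass
    h = step(h, xs[t])
    if t % k == 0:
        checkpoints[t] = h       # keep only every k-th hidden state

def restore(t):
    """Rebuild h[t] by replaying from the nearest earlier checkpoint."""
    base = (t // k) * k
    h = checkpoints[base]
    for s in range(base + 1, t + 1):
        h = step(h, xs[s])
    return h

# Backpropagation through time would call restore(t) segment by segment.
print(restore(7))
```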
Checkpointing in Neural ODEs
Neural ordinary differential equations use continuous-time dynamics $\frac{dy}{dt} = f(y, t, \theta)$.
The backward pass may reconstruct trajectories by solving another differential equation backward in time.
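Concretely, in the standard adjoint formulation (stated here for orientation; the section does not derive it), the quantity $a(t) = \partial L / \partial y(t)$ is obtained by integrating

$$\frac{da}{dt} = -a(t)^{\top} \frac{\partial f}{\partial y}, \qquad a(T) = \frac{\partial L}{\partial y(T)},$$

backward from $t = T$, typically alongside a backward reconstruction of $y(t)$ itself.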
This can dramatically reduce memory usage.
However, backward reconstruction may introduce:
- numerical instability;
- trajectory drift;
- mismatch between forward and backward states.
Continuous-time checkpointing therefore becomes both a numerical and systems problem.
Hardware Effects
Checkpointing changes the balance between arithmetic and memory traffic.
Modern accelerators often have:
| Resource | Relative Speed |
|---|---|
| arithmetic throughput | extremely high |
| memory bandwidth | comparatively limited |
Recomputation may therefore be cheaper than storing and loading large tensors.
Checkpointing can improve throughput even while increasing FLOPs.
This is increasingly true on modern GPUs and TPUs.
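As a rough, hypothetical illustration: an accelerator sustaining $10^{14}$ FLOP/s of arithmetic but only $10^{12}$ bytes/s of memory bandwidth breaks even at about $100$ FLOPs per byte moved, so any activation whose recomputation costs less than that per byte is faster to recompute than to store and reload.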
Compiler-Assisted Checkpointing
Compilers may automatically choose checkpoint locations.
Optimization objectives include:
- minimize peak memory;
- minimize recomputation;
- satisfy hardware constraints;
- maximize parallelism.
Compiler-based systems can analyze:
| Information | Use |
|---|---|
| tensor sizes | memory estimation |
| dependency graph | recomputation planning |
| operation costs | scheduling |
| device topology | placement |
Automatic checkpoint scheduling remains an active research area.
Gradient Correctness
Checkpointing must preserve mathematical equivalence with full reverse mode.
Recomputed values must match the original forward computation closely enough that reverse propagation remains valid.
Problems arise with:
- nondeterministic kernels;
- floating-point instability;
- stateful operations;
- stochastic algorithms.
Correct checkpointing therefore requires careful management of execution determinism.
Conceptual Summary
Checkpointing reduces reverse-mode memory usage by storing selected intermediate states and recomputing missing values during the backward pass.
It transforms the reverse pass into an alternating sequence of:
- checkpoint restoration;
- local forward recomputation;
- local backward propagation.
The fundamental tradeoff is lower memory in exchange for higher compute.
Modern automatic differentiation systems rely heavily on checkpointing because memory capacity, rather than arithmetic throughput, is often the primary bottleneck in large-scale differentiable computation.