# Memory-Time Tradeoffs

Reverse mode automatic differentiation is computationally efficient for scalar-output functions, but it has a major systems cost: it needs information from the forward pass during the backward pass.

The reverse pass usually cannot compute derivatives from the output value alone. It needs intermediate primal values, shapes, branch decisions, and sometimes full solver or kernel state. The system must either store this information or recompute it.

This creates the central tradeoff:

| Strategy | Memory | Compute |
|---|---:|---:|
| Store everything | high | low |
| Recompute everything | low | high |
| Checkpoint selectively | medium | medium |

### Why Memory Is Needed

Consider multiplication:

$$
z = xy.
$$

The reverse rules are

$$
\bar x \mathrel{+}= \bar z y,
$$

$$
\bar y \mathrel{+}= \bar z x.
$$

The backward pass needs the original forward values $x$ and $y$. If they were discarded, the system must recompute them.
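As a concrete sketch, a tape-based implementation might record the inputs of each multiplication so the reverse rule can read them later. The class and method names below are illustrative, not from any particular library.

```python
class MulNode:
    """Minimal sketch: a multiply node that saves its inputs for the backward pass."""

    def __init__(self, x, y):
        self.x, self.y = x, y      # saved primal values: this is the memory cost
        self.grad_x = 0.0
        self.grad_y = 0.0

    def forward(self):
        return self.x * self.y

    def backward(self, z_bar):
        # Reverse rules: x_bar += z_bar * y and y_bar += z_bar * x.
        self.grad_x += z_bar * self.y
        self.grad_y += z_bar * self.x

node = MulNode(3.0, 4.0)
z = node.forward()      # 12.0
node.backward(1.0)      # node.grad_x == 4.0, node.grad_y == 3.0
```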

For a deep computation,

$$
v_1 \to v_2 \to \cdots \to v_k \to y,
$$

the backward pass visits operations in reverse order:

$$
v_k \to v_{k-1} \to \cdots \to v_1.
$$

Each reverse step may need values produced much earlier in the forward pass. Keeping all of them can dominate memory usage.

### The Full Storage Strategy

The simplest reverse mode implementation stores every intermediate value.

Forward pass:

```text
for op in program:
    compute output
    save output and needed inputs
```

Backward pass:

```text
for op in reverse(program):
    read saved values
    apply reverse rule
```

This strategy is fast because the backward pass has direct access to every required value.

Its cost is memory.

If a model has $L$ layers and each layer activation requires $A$ bytes, activation storage is roughly

$$
O(LA).
$$

For large neural networks, activations can use more memory than parameters.
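A minimal sketch of the store-everything approach for a chain of unary operations, with each operation given as a `(function, derivative)` pair; the helper names are illustrative.

```python
import math

def forward_with_tape(x, ops):
    """Run a chain of (f, df) operations and store every intermediate value."""
    tape = [x]
    for f, _ in ops:
        tape.append(f(tape[-1]))
    return tape                      # O(L) stored values

def backward_from_tape(tape, ops):
    """Reverse sweep: each saved input is read directly from the tape."""
    grad = 1.0
    for (_, df), v_in in zip(reversed(ops), reversed(tape[:-1])):
        grad *= df(v_in)
    return grad

ops = [(math.sin, math.cos), (math.exp, math.exp)]   # y = exp(sin(x))
tape = forward_with_tape(0.5, ops)
print(tape[-1], backward_from_tape(tape, ops))       # y and dy/dx
```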

### The Recomputation Strategy

The opposite strategy stores almost nothing.

During the backward pass, whenever a needed value is missing, the system reruns part of the forward computation to recover it.

Memory falls, but compute increases.

In the extreme case, each backward step may require recomputing a long prefix of the program. For a chain of $L$ operations, this can increase cost from linear to quadratic.

| Strategy | Forward storage | Backward cost for chain |
|---|---:|---:|
| Store all values | $O(L)$ | $O(L)$ |
| Recompute from start each time | $O(1)$ | $O(L^2)$ |

Pure recomputation is therefore rarely acceptable for large programs.
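For contrast, a sketch of the recompute-everything extreme, reusing the `(f, df)` chain format above, shows where the quadratic cost comes from.

```python
def recompute_prefix(x, ops, k):
    """Rerun the first k operations to recover the input of operation k."""
    v = x
    for f, _ in ops[:k]:
        v = f(v)
    return v

def backward_recompute_all(x, ops):
    """Each reverse step reruns a forward prefix: O(L) work per step, O(L^2) total."""
    grad = 1.0
    for i in reversed(range(len(ops))):
        v_in = recompute_prefix(x, ops, i)
        _, df = ops[i]
        grad *= df(v_in)
    return grad
```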

### Checkpointing

Checkpointing stores selected intermediate states and recomputes only between those states.

Suppose a computation has $L$ stages. Instead of storing all $L$ activations, the system stores checkpoints every $s$ stages.

During the backward pass, to differentiate a segment, it reloads the nearest checkpoint and recomputes forward values inside that segment.

Memory use falls below full storage, while compute stays well below full recomputation.

Conceptually:

```text
forward:
    save selected checkpoints

backward:
    for each segment in reverse:
        restore checkpoint
        recompute segment forward
        run local backward pass
```

Checkpointing is one of the main techniques for training deeper models under fixed memory budgets.
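The sketch below applies this idea to the same `(f, df)` chain format used earlier: store a boundary value every `s` stages, then recompute each segment once during the reverse sweep. It is illustrative, not a production schedule.

```python
def backward_with_checkpoints(x, ops, s):
    """Checkpoint every s stages; recompute each segment once during backward."""
    # Forward pass: keep only segment-boundary values.
    checkpoints = [x]
    v = x
    for i, (f, _) in enumerate(ops, start=1):
        v = f(v)
        if i % s == 0 and i < len(ops):
            checkpoints.append(v)

    grad = 1.0
    for seg in reversed(range(len(checkpoints))):
        start, end = seg * s, min(seg * s + s, len(ops))
        # Restore the checkpoint and recompute this segment's intermediates.
        segment = [checkpoints[seg]]
        for f, _ in ops[start:end]:
            segment.append(f(segment[-1]))
        # Local backward pass over the recomputed segment.
        for (_, df), v_in in zip(reversed(ops[start:end]), reversed(segment[:-1])):
            grad *= df(v_in)
    return grad
```

Stored state drops from $O(L)$ values to roughly $L/s$ checkpoints plus one live segment, at the cost of about one extra forward pass of recomputation.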

### Simple Chain Example

Consider a chain:

$$
x \to v_1 \to v_2 \to v_3 \to v_4 \to y.
$$

Full storage saves:

$$
x, v_1, v_2, v_3, v_4, y.
$$

A checkpointed strategy might save only:

$$
x, v_2, y.
$$

To compute derivatives for the segment $v_2 \to v_3 \to v_4 \to y$, the system recomputes $v_3$ and $v_4$ from $v_2$.

To compute derivatives for the segment $x \to v_2$, it recomputes $v_1$ and $v_2$ from $x$.

This reduces stored values from all intermediates to a smaller set of boundary states.

### Activation Memory in Neural Networks

In neural network training, the largest reverse-mode memory cost is often activation storage.

For a layer

$$
h_{i+1} = F_i(h_i, \theta_i),
$$

the backward rule often requires $h_i$, $\theta_i$, and sometimes $h_{i+1}$.

Parameters $\theta_i$ are already stored. Activations $h_i$ are produced dynamically and must be retained or recomputed.

For a network with many layers, the activation memory scales with depth and batch size.

Important drivers include:

| Factor | Memory Effect |
|---|---|
| batch size | linear growth |
| sequence length | often linear or quadratic growth |
| hidden dimension | linear growth |
| number of layers | linear growth |
| saved attention maps | can be very large |

This is why memory optimization is central to training large transformer models.
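A back-of-the-envelope estimate makes these drivers concrete. The sketch below assumes a single saved activation tensor of shape (batch, sequence, hidden) per layer in fp16; real transformer blocks keep several such tensors plus attention state, so treat it as a lower bound.

```python
def activation_bytes(batch, seq_len, hidden, layers, bytes_per_elem=2):
    """Rough lower bound: one (batch, seq_len, hidden) activation per layer."""
    return batch * seq_len * hidden * layers * bytes_per_elem

# Example: batch 8, sequence length 4096, hidden size 8192, 48 layers, fp16.
gib = activation_bytes(8, 4096, 8192, 48) / 2**30
print(f"{gib:.1f} GiB")   # 24.0 GiB for just one saved tensor per layer
```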

### Invertible Computations

Some computations allow the input to be reconstructed from the output.

If

$$
y = g(x)
$$

is invertible, then the backward pass may recover

$$
x = g^{-1}(y)
$$

instead of storing $x$.

Invertible neural networks use this idea to reduce activation memory. They store less forward state and reconstruct hidden activations during the backward pass.

The tradeoff is that the inverse computation must be stable and cheaper than storing the missing values.
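A simple example is an additive coupling layer: the input can be rebuilt exactly from the output, so no activation needs to be stored. In the sketch, `F` stands for an arbitrary sub-network.

```python
def coupling_forward(x1, x2, F):
    # y1 = x1, y2 = x2 + F(x1): invertible by construction.
    return x1, x2 + F(x1)

def coupling_inverse(y1, y2, F):
    # Reconstruct the input from the output instead of storing it.
    return y1, y2 - F(y1)

F = lambda v: 3.0 * v + 1.0
x1, x2 = 2.0, 5.0
y1, y2 = coupling_forward(x1, x2, F)
assert coupling_inverse(y1, y2, F) == (x1, x2)
```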

### Rematerialization

Rematerialization is another name for recomputing discarded intermediates.

A compiler or runtime may choose to rematerialize cheap values instead of loading them from memory.

For example, recomputing

$$
x + 1
$$

may be cheaper than storing and later reading it.

For expensive operations, such as matrix multiplication or factorization, storage is often preferable.

The decision depends on several factors (a rough cost heuristic is sketched after this list):

1. compute cost;
2. memory bandwidth;
3. tensor size;
4. hardware architecture;
5. reuse count.
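A runtime might weigh these factors with a rough cost model like the one below. The peak-throughput and bandwidth numbers, and the decision rule itself, are purely illustrative.

```python
def should_rematerialize(flops_to_recompute, bytes_if_stored, reuse_count,
                         peak_flops=1e14, mem_bandwidth=2e12):
    """Illustrative heuristic: rematerialize when redoing the arithmetic is
    cheaper than writing the value out and reading it back at each reuse."""
    recompute_time = reuse_count * flops_to_recompute / peak_flops
    traffic_time = (1 + reuse_count) * bytes_if_stored / mem_bandwidth
    return recompute_time < traffic_time

# x + 1 over a 1e8-element fp32 tensor, reused once: recompute it.
print(should_rematerialize(1e8, 4e8, 1))     # True
# A large matmul producing a tensor of the same size: store it.
print(should_rematerialize(1e12, 4e8, 1))    # False
```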

### Memory Bandwidth Versus Arithmetic

On modern accelerators, memory traffic often limits performance.

Saving all activations can make the backward pass memory-bound. Recomputing selected values may reduce memory pressure and improve throughput, even though it performs extra arithmetic.

This is counterintuitive but common: more floating-point operations can make a program faster if it avoids slow memory movement.

A practical AD system therefore optimizes both memory footprint and memory bandwidth, not only operation count.

### Granularity of Checkpoints

Checkpointing can operate at different levels.

| Level | Example |
|---|---|
| primitive operation | individual arithmetic op |
| tensor operation | matrix multiplication, convolution |
| layer | transformer block |
| module | encoder stack |
| program region | loop body, solver step |

Fine-grained checkpointing gives more control but adds overhead.

Coarse-grained checkpointing is simpler and usually matches how users think about models.

In deep learning systems, checkpointing is often applied at the layer or block level.
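In PyTorch, for example, block-level checkpointing is commonly applied with `torch.utils.checkpoint.checkpoint`. The sketch below assumes a `blocks` list of transformer blocks and omits everything else about the model; it only illustrates the granularity.

```python
import torch
from torch.utils.checkpoint import checkpoint

def forward_with_block_checkpointing(blocks, x):
    # Each block's internal activations are freed after the forward pass and
    # recomputed during backward; only block inputs and outputs stay live.
    for block in blocks:
        x = checkpoint(block, x, use_reentrant=False)
    return x
```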

### Schedules

A checkpoint schedule decides which values to store and which regions to recompute.

For a simple chain, uniform spacing may work well. For irregular computation graphs, the optimal schedule depends on graph shape and operation cost.

A schedule should consider:

1. long-lived values;
2. branch fan-out;
3. expensive operations;
4. tensor sizes;
5. available device memory.

Poor schedules can save small tensors while recomputing expensive kernels. Good schedules save high-value boundary states and rematerialize cheap interiors.
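For a uniform schedule on a homogeneous chain of $L$ stages, the tradeoff can be written down directly: a checkpoint every $s$ stages keeps roughly $L/s$ boundary values plus one live segment of about $s$ values, and pays about one extra forward pass of recomputation. A small sketch under those simplifying assumptions:

```python
import math

def uniform_schedule_cost(L, s):
    """Rough cost of a checkpoint-every-s schedule on an L-stage chain:
    peak stored values and extra recomputed stages (uniform stage costs assumed)."""
    peak_memory = math.ceil(L / s) + s   # checkpoints + one recomputed segment
    extra_compute = L                    # about one extra forward pass
    return peak_memory, extra_compute

L = 10_000
for s in (1, 10, int(math.sqrt(L)), 1_000, L):
    print(s, uniform_schedule_cost(L, s))
# Peak memory is minimized near s = sqrt(L), at roughly 2 * sqrt(L) stored values.
```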

### Interaction With Control Flow

Dynamic control flow complicates memory-time tradeoffs.

If the forward pass follows an input-dependent branch, the backward pass must follow the same executed branch. The system must record enough control-flow information to replay or reverse the computation correctly.

For loops, the system may need per-iteration checkpoints.

Example:

```text
for i in 1..T:
    h = step(h, x[i])
```

Full storage saves every $h_i$.

Checkpointing saves selected $h_i$ values and recomputes loop segments during backward execution.

Long recurrent computations and neural ODE solvers often require careful checkpointing because $T$ can be large.
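A sketch of per-segment checkpointing for this loop stores the state every `k` iterations and replays each segment when its backward pass runs. The `step` function and its reverse rules are left abstract here.

```python
def forward_loop_with_checkpoints(h0, xs, step, k):
    """Run the recurrence, storing the state only every k iterations."""
    checkpoints = {0: h0}
    h = h0
    for i, x in enumerate(xs, start=1):
        h = step(h, x)
        if i % k == 0:
            checkpoints[i] = h
    return h, checkpoints

def replay_segment(checkpoints, xs, step, start, end):
    """Recompute one segment's states from its checkpoint, e.g. just before
    running the local backward pass over iterations start..end."""
    states = [checkpoints[start]]
    for x in xs[start:end]:
        states.append(step(states[-1], x))
    return states
```

A backward pass would then visit segments from last to first, calling `replay_segment` for each and applying the reverse rules of `step` over the recovered states.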

### Interaction With Mutation

Mutation makes recomputation harder.

If a value is overwritten during the forward pass, the backward pass may need the old value.

Example:

```text
x = f(x)
x = g(x)
```

The reverse rule for $g$ may need the input to $g$, while the reverse rule for $f$ may need the input to $f$.

A correct system must preserve or reconstruct each version.

This is why many AD systems use immutable intermediate values or compiler forms such as SSA. Each assignment creates a distinct version, so the reverse pass can refer to the correct value.
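A small sketch shows the hazard and the fix: once `x` is overwritten, the input that the reverse rule for `g` needs is gone unless each assignment keeps its own version, as an SSA-style tape would. The choice of `sin` for `f` and `exp` for `g` is purely for illustration.

```python
import math

def forward_versioned(x0):
    """Keep one value per assignment instead of overwriting a single slot."""
    x1 = math.sin(x0)     # x = f(x) with f = sin
    x2 = math.exp(x1)     # x = g(x) with g = exp
    return x2, (x0, x1)   # save both versions for the backward pass

def backward_versioned(grad_out, saved):
    x0, x1 = saved
    grad = grad_out * math.exp(x1)   # reverse rule for g needs g's input, x1
    grad = grad * math.cos(x0)       # reverse rule for f needs f's input, x0
    return grad

y, saved = forward_versioned(0.5)
print(backward_versioned(1.0, saved))   # d/dx exp(sin(x)) at x = 0.5
```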

### Numerical Effects of Recomputation

Recomputation may not produce bit-identical values.

Sources of differences include:

1. nondeterministic parallel reductions;
2. random number generation;
3. floating-point associativity;
4. stateful kernels;
5. mixed precision.

For stable training, recomputed values must be consistent enough that the backward pass still produces accurate gradients.

Random operations require special handling. The system may store random seeds or masks rather than full outputs. Dropout is a common example.
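Dropout shows the seed-storage approach: rather than saving the full dropout output, the system saves the seed (or the mask) and regenerates the same mask on recomputation. The sketch uses NumPy for illustration; real frameworks manage their own RNG state.

```python
import numpy as np

def dropout_forward(x, p, seed):
    """Save only the seed; the mask can be regenerated bit-for-bit."""
    mask = np.random.default_rng(seed).random(x.shape) >= p
    return x * mask / (1.0 - p), seed

def recompute_dropout_mask(shape, p, seed):
    """Backward-pass recomputation reproduces the identical mask."""
    return np.random.default_rng(seed).random(shape) >= p

x = np.ones(5)
y, seed = dropout_forward(x, 0.5, seed=1234)
mask = recompute_dropout_mask(x.shape, 0.5, seed)
assert np.array_equal(y, x * mask / 0.5)
```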

### Practical Policy

A practical reverse-mode engine usually follows a mixed policy:

| Value Type | Common Policy |
|---|---|
| small scalar metadata | store |
| large cheap activations | recompute |
| expensive factorization results | store |
| random masks | store seed or mask |
| control-flow decisions | store |
| tensor shapes and strides | store |
| outputs of non-deterministic kernels | store |

The right policy is workload-dependent.

### Conceptual Summary

Reverse mode gains time efficiency by traversing a recorded computation backward. That efficiency requires access to forward information.

Memory-time tradeoffs determine how that information is made available.

The three basic choices are:

1. store values from the forward pass;
2. recompute values during the backward pass;
3. checkpoint selected states and recompute locally.

Production AD systems combine these strategies. The goal is to minimize memory pressure without making the backward pass too expensive or numerically inconsistent.

