
Loops

A loop repeats a computation until a condition fails or a fixed iteration count is reached. In automatic differentiation, loops are important because many numerical algorithms are fundamentally iterative:

  • optimization methods
  • recurrent neural networks
  • numerical solvers
  • simulations
  • dynamic programming
  • graph algorithms

A loop introduces repeated dependence between program states. The derivative of the loop therefore depends on how perturbations propagate through the sequence of iterations.

A loop can be viewed as repeated function composition.

Consider the program:

x = x0

for i in range(n):
    x = f(x)

After $n$ iterations,

$$x_n = f^{(n)}(x_0)$$

where $f^{(n)}$ denotes $n$-fold composition.

The derivative is obtained by repeated application of the chain rule:

$$\frac{dx_n}{dx_0} = f'(x_{n-1})\, f'(x_{n-2}) \cdots f'(x_0)$$

A loop is therefore a dynamic chain rule accumulator.
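This accumulation can be sketched directly in code. A minimal example, assuming a loop body `f(x) = sin(x)` with a hand-written derivative `fprime` (both are illustrative choices, not anything fixed by the text):

```python
import math

# Sketch: accumulate the chain-rule product through a loop.
# The body f and its derivative fprime are assumed examples.
def f(x):
    return math.sin(x)

def fprime(x):
    return math.cos(x)

def loop_derivative(x0, n):
    """Return (x_n, dx_n/dx_0) for the loop x_{k+1} = f(x_k)."""
    x, d = x0, 1.0
    for _ in range(n):
        d = fprime(x) * d   # multiply in the local derivative f'(x_k)
        x = f(x)            # then advance the primal state
    return x, d

xn, dn = loop_derivative(0.5, 4)
```

Note the ordering: the local derivative is taken at the pre-update state, then the primal advances, so after the loop `d` equals $f'(x_{n-1}) \cdots f'(x_0)$.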

Loops as State Transitions

Most loops operate on a mutable program state. Let the loop state at iteration $k$ be

$$s_k$$

and let the loop body define a transition function

$$s_{k+1} = T(s_k)$$

Then after $n$ iterations,

$$s_n = T^{(n)}(s_0)$$

The Jacobian of the full loop is

$$\frac{\partial s_n}{\partial s_0} = \prod_{k=0}^{n-1} \frac{\partial s_{k+1}}{\partial s_k}$$

This is the fundamental differentiation rule for loops.

The loop body contributes a local Jacobian at every iteration. The full derivative is the product of all local Jacobians in execution order.

For scalar states:

$$\frac{ds_n}{ds_0} = \prod_{k=0}^{n-1} T'(s_k)$$

For vector states:

$$J = J_{n-1} J_{n-2} \cdots J_0$$

where

$$J_k = \frac{\partial s_{k+1}}{\partial s_k}$$
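A small sketch of this Jacobian product, using an assumed linear body `T(s) = A @ s`, so that every local Jacobian $J_k$ is the constant matrix `A` and the loop Jacobian reduces to a matrix power:

```python
import numpy as np

# Sketch: for an assumed linear body T(s) = A @ s, every local Jacobian
# J_k equals A, so the loop Jacobian is the ordered product A^n.
A = np.array([[0.9, 0.1],
              [0.0, 0.8]])
n = 5

J = np.eye(2)
for _ in range(n):
    J = A @ J          # J = J_{n-1} ... J_0, newest factor on the left

# J now equals the n-th matrix power of A
```

For a nonlinear body the factors would differ at every iteration, but the left-multiplication pattern is the same.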

Forward Mode Through Loops

Forward mode propagates tangents alongside primal values through every iteration.

Suppose:

x = x0

for i in range(n):
    x = f(x)

Forward mode augments the state:

$$(x_k, \dot{x}_k)$$

At iteration $k$,

$$x_{k+1} = f(x_k)$$

and

$$\dot{x}_{k+1} = f'(x_k)\,\dot{x}_k$$

The tangent evolves according to the linearized loop dynamics.

Conceptually:

x  = x0
dx = dx0

for i in range(n):
    x_new  = f(x)
    dx_new = f'(x) * dx

    x  = x_new
    dx = dx_new

The tangent update follows the same control flow and iteration count as the primal computation.

Forward mode is naturally streaming. It does not need to store previous loop states. Memory usage is therefore $O(1)$ with respect to iteration count, assuming fixed-size state.

The cost is proportional to the number of propagated tangent directions.

Example: Repeated Squaring

Consider:

x = x0

for i in range(3):
    x = x * x

The iterations are:

$$x_1 = x_0^2, \qquad x_2 = x_1^2 = x_0^4, \qquad x_3 = x_2^2 = x_0^8$$

Thus,

$$\frac{dx_3}{dx_0} = 8x_0^7$$

Forward mode computes:

Iteration 1:

$$\dot{x}_1 = 2x_0\dot{x}_0$$

Iteration 2:

$$\dot{x}_2 = 2x_1\dot{x}_1 = 4x_0^3\dot{x}_0$$

Iteration 3:

$$\dot{x}_3 = 2x_2\dot{x}_2 = 8x_0^7\dot{x}_0$$

The tangent accumulates multiplicatively through the loop.
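The trace above can be run directly, with an assumed starting value `x0 = 1.5` and tangent seed `dx0 = 1`:

```python
# Forward mode through the repeated-squaring loop, with an assumed
# starting value x0 = 1.5 and tangent seed dx0 = 1.
x0 = 1.5
x, dx = x0, 1.0

for _ in range(3):
    dx = 2 * x * dx   # tangent of x * x, using the pre-update primal
    x = x * x

# x == x0**8 and dx == 8 * x0**7
```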

Reverse Mode Through Loops

Reverse mode differentiates loops by traversing iterations backward.

Suppose:

x = x0

for i in range(n):
    x = f(x)

The forward execution produces states:

$$x_0, x_1, \dots, x_n$$

The reverse pass propagates adjoints backward:

$$\bar{x}_n \rightarrow \bar{x}_{n-1} \rightarrow \cdots \rightarrow \bar{x}_0$$

using

$$\bar{x}_k = f'(x_k)^T \bar{x}_{k+1}$$

The backward pass therefore reverses the loop.

Conceptually:

# forward
states = [x0]

x = x0

for i in range(n):
    x = f(x)
    states.append(x)

# backward
x_bar = seed

for i in reversed(range(n)):
    x_prev = states[i]
    x_bar  = f'(x_prev)^T * x_bar

The reverse pass requires access to the intermediate states from the forward pass.

This introduces the major systems problem of reverse-mode loops:

Reverse mode usually needs loop-state storage.

If the loop runs for millions of iterations, storing all intermediate states may be prohibitively expensive.
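The store-then-reverse pattern can be sketched concretely, again assuming the scalar body `f(x) = sin(x)` as an example:

```python
import math

# Sketch of reverse mode through a loop for the assumed body f(x) = sin(x):
# the forward pass stores every state, the backward pass walks them in reverse.
def loop_grad(x0, n):
    states = [x0]
    x = x0
    for _ in range(n):
        x = math.sin(x)
        states.append(x)

    x_bar = 1.0                               # seed adjoint for x_n
    for i in reversed(range(n)):
        x_bar = math.cos(states[i]) * x_bar   # adjoint update: f'(x_i) * x_bar
    return x, x_bar

xn, grad = loop_grad(0.7, 4)
```

The `states` list is exactly the linear-in-`n` storage the text describes: one entry per iteration.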

Wengert Lists and Dynamic Tapes

In tape-based reverse mode, each loop iteration appends operations to a tape.

For example:

for i in range(n):
    x = sin(x)

generates a tape containing:

x1 = sin(x0)
x2 = sin(x1)
x3 = sin(x2)
...
xn = sin(xn-1)

The backward pass walks this tape in reverse order:

$$\bar{x}_k = \cos(x_k)\,\bar{x}_{k+1}$$

The tape length grows linearly with iteration count.

If the loop body is large or the iteration count is data-dependent, memory consumption can dominate execution cost.

Static vs Dynamic Loops

Loops fall into two broad categories.

Static Loops

A static loop has a compile-time known iteration count.

Example:

for i in range(10):
    x = f(x)

Compilers may unroll such loops partially or fully.

The differentiated program becomes a finite repeated composition.

Static loops are easier to optimize because:

  • iteration count is known
  • memory requirements are predictable
  • control flow is stable
  • vectorization is easier

Dynamic Loops

A dynamic loop depends on runtime values.

Example:

while norm(x) > eps:
    x = update(x)

The number of iterations depends on the input.

Dynamic loops are harder because:

  • forward and backward passes must agree on iteration count
  • loop states may vary in shape or structure
  • compilation becomes more difficult
  • tracing systems may specialize to observed iteration counts

The backward pass must replay exactly the same iteration structure as the forward pass.
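One way to see the agreement requirement is forward mode through a while loop: the tangent update simply runs for whatever iteration count the primal run takes. A sketch, with an assumed body `x = x / 2`:

```python
# Forward mode through a data-dependent loop: the tangent update runs for
# exactly the iteration count the primal takes. The body x = x / 2 is an
# assumed example.
def halve_until_small(x0, dx0=1.0, eps=1e-3):
    x, dx, steps = x0, dx0, 0
    while x > eps:      # runtime-dependent termination
        dx = 0.5 * dx   # tangent of x / 2
        x = x / 2
        steps += 1
    return x, dx, steps

x, dx, steps = halve_until_small(1.0)   # 10 halvings: x == dx == 2**-10
```

A reverse pass over the same loop would need to record `steps` (or the states) during the forward run, since the backward sweep must execute the same number of adjoint updates.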

Loop-Carried Dependencies

A loop-carried dependency occurs when one iteration depends on results from previous iterations.

Example:

h = h0

for t in range(T):
    h = tanh(W @ h + b)

This defines a recurrence:

$$h_{t+1} = \tanh(W h_t + b)$$

The derivative across time becomes:

$$\frac{\partial h_T}{\partial h_0} = \prod_{t=0}^{T-1} D_t$$

where

$$D_t = \operatorname{diag}\left(1 - \tanh^2(W h_t + b)\right) W$$

This repeated matrix multiplication causes two important phenomena:

  • vanishing gradients
  • exploding gradients

If the spectral norm satisfies

$$\|D_t\| < 1$$

then gradients shrink exponentially.

If

$$\|D_t\| > 1$$

then gradients grow exponentially.

This is a direct consequence of loop differentiation.
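Both regimes are easy to observe numerically. The sketch below uses diagonal stand-in Jacobians (an assumption for illustration) so the per-step norm is explicit:

```python
import numpy as np

# Sketch: repeated Jacobian products vanish or explode depending on
# whether the per-step norm is below or above one. The diagonal
# stand-in Jacobians are an assumed simplification.
def product_norm(scale, T=50, dim=4):
    D = scale * np.eye(dim)          # local Jacobian with spectral norm = scale
    J = np.eye(dim)
    for _ in range(T):
        J = D @ J
    return np.linalg.norm(J, 2)      # spectral norm of the accumulated product

shrink = product_norm(0.9)   # about 0.9**50 ~ 5e-3: vanishing
grow   = product_norm(1.1)   # about 1.1**50 ~ 1e2:  exploding
```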

Backpropagation Through Time

Recurrent neural networks are differentiated by unrolling loops across time.

The recurrent program:

h = h0

for t in range(T):
    h = F(h, x[t])

is transformed into:

h1 = F(h0, x1)
h2 = F(h1, x2)
...
hT = F(hT-1, xT)

Reverse mode then propagates adjoints backward through the unrolled graph.

This procedure is called backpropagation through time (BPTT).

The loop becomes an explicit chain of repeated operations.

The cost scales linearly with sequence length:

  • forward cost: $O(T)$
  • backward cost: $O(T)$
  • memory cost: $O(T)$

for fixed hidden-state size.

Long sequences therefore create large memory pressure.

Checkpointing for Long Loops

Reverse mode stores intermediate states so the backward pass can reconstruct local derivatives.

For long loops, storing every state may be impossible.

Checkpointing trades computation for memory.

Instead of storing every state, the system stores selected checkpoints:

x0 ---- x100 ---- x200 ---- x300

During the backward pass, missing states are recomputed from nearby checkpoints.

This changes complexity:

Strategy               Memory     Recomputation
Store all states       High       None
Recompute everything   Low        Very high
Checkpointing          Moderate   Moderate

Optimal checkpoint schedules are an important research topic in reverse-mode AD.
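A minimal sketch of fixed-stride checkpointing (optimal schedules are more sophisticated), assuming a scalar body `f` with hand-written derivative `fprime`:

```python
import math

# Sketch of fixed-stride checkpointing for the loop x_{k+1} = f(x_k).
# The body f, its derivative fprime, and the stride are assumed examples.
def checkpointed_grad(x0, n, stride=10, f=math.sin, fprime=math.cos):
    # forward: keep only every stride-th state
    ckpts = {0: x0}
    x = x0
    for k in range(n):
        x = f(x)
        if (k + 1) % stride == 0:
            ckpts[k + 1] = x

    # backward: rebuild each needed state from the nearest earlier checkpoint
    x_bar = 1.0
    for k in reversed(range(n)):
        base = (k // stride) * stride
        xk = ckpts[base]
        for _ in range(k - base):      # recompute x_k from x_base
            xk = f(xk)
        x_bar = fprime(xk) * x_bar
    return x_bar

grad = checkpointed_grad(0.3, 25)
```

Memory drops from $n$ stored states to roughly $n/\text{stride}$, at the cost of recomputing each segment during the backward sweep.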

Differentiating While Loops

A while loop introduces a termination condition:

while c(x):
    x = f(x)

This defines a variable-length composition:

$$x_n = f^{(n)}(x_0)$$

where $n$ depends on the input.

Differentiation assumes the iteration count remains locally stable.

If small perturbations change the number of iterations, the derivative may become discontinuous.

Example:

while x < 1:
    x = 2 * x

The number of iterations changes abruptly near powers of two.

The resulting function is piecewise smooth but globally non-smooth.

Most AD systems differentiate the observed execution path and treat the iteration count as fixed during local differentiation.

Fixed-Point Iterations

Many loops seek a fixed point:

x=f(x) x^* = f(x^*)

Example:

x = x0

for i in range(max_iter):
    x = f(x)

    if converged(x):
        break

Naively differentiating through all iterations may be expensive.

An alternative is implicit differentiation.

If

$$x^* = f(x^*, \theta)$$

then differentiating gives

$$\frac{dx^*}{d\theta} = \left(I - \frac{\partial f}{\partial x}\right)^{-1} \frac{\partial f}{\partial \theta}$$

This avoids storing all intermediate iterations.

Implicit differentiation is discussed later in the book, but loops are the motivating structure behind it.
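A sketch of the idea for an assumed scalar map $f(x, \theta) = \theta \cos(x)$: solve the fixed point by iteration, then apply the implicit formula using only quantities at the solution.

```python
import math

# Implicit differentiation at a fixed point of the assumed map
# f(x, theta) = theta * cos(x).
def fixed_point(theta, x=0.0, iters=100):
    for _ in range(iters):
        x = theta * math.cos(x)      # contraction for |theta| < 1
    return x

theta = 0.5
x_star = fixed_point(theta)

df_dx = -theta * math.sin(x_star)    # partial f / partial x at the fixed point
df_dtheta = math.cos(x_star)         # partial f / partial theta

# dx*/dtheta = (1 - df/dx)^{-1} * df/dtheta
grad = df_dtheta / (1.0 - df_dx)
```

Only the converged state is needed; no intermediate iterates are stored.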

Parallel Loops and Associativity

Some loops represent reductions:

s = 0

for i in range(n):
    s += x[i]

The derivative is straightforward:

$$\frac{\partial s}{\partial x_i} = 1$$

But systems implementations may parallelize the reduction as a tree:

$$((x_1 + x_2) + (x_3 + x_4))$$

instead of sequential accumulation.

Because floating-point addition is not associative, the exact primal and derivative values may vary across execution orders.

Parallel differentiation therefore introduces reproducibility concerns.
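A small demonstration of the non-associativity itself, summing the same ten values sequentially and as a tree:

```python
# Floating-point addition is not associative: summing the same ten 0.1s
# sequentially and as a pairwise tree produces different doubles.
vals = [0.1] * 10

seq = 0.0
for v in vals:
    seq += v                        # sequential accumulation

pairs = [vals[i] + vals[i + 1] for i in range(0, 10, 2)]
tree = ((pairs[0] + pairs[1]) + (pairs[2] + pairs[3])) + pairs[4]

# seq != tree, although both are within one ulp of 1.0
```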

Mutation Inside Loops

Loops often mutate arrays or tensors in place:

for i in range(n):
    x[i] = x[i] * 2

Mutation complicates reverse mode because previous values may be overwritten before gradients are computed.

A reverse pass may require:

  • storing old values
  • versioning arrays
  • converting mutation into functional updates
  • using SSA-style representations

Many differentiable systems internally lower mutable loops into immutable graph representations.
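A minimal sketch of the first option, storing each overwritten value so the reverse pass can restore the state the forward operation actually consumed:

```python
# Sketch: support in-place mutation in reverse mode by saving each
# overwritten value before the write.
x = [1.0, 2.0, 3.0]
x_bar = [1.0, 1.0, 1.0]         # seed adjoints for each output element
saved = []

# forward: x[i] = x[i] * 2, remembering overwritten values
for i in range(len(x)):
    saved.append(x[i])
    x[i] = x[i] * 2

# reverse: undo mutations and update adjoints in reverse order
for i in reversed(range(len(x))):
    x[i] = saved[i]             # restore the value the forward op consumed
    x_bar[i] = 2.0 * x_bar[i]   # adjoint of x[i] = x[i] * 2
```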

Loop Fusion and Optimization

AD compilers frequently optimize loops after differentiation.

Example:

for i in range(n):
    y[i] = sin(x[i])

for i in range(n):
    z[i] = y[i] * 2

may be fused into:

for i in range(n):
    z[i] = 2 * sin(x[i])

Fusion reduces:

  • memory traffic
  • intermediate allocations
  • kernel launches
  • tape size

Loop optimization is therefore tightly connected to AD performance.

Correctness Rule

For a loop

s = s0

for i in range(n):
    s = T(s)

the derivative is obtained by repeated chain-rule application:

$$\frac{\partial s_n}{\partial s_0} = \prod_{k=0}^{n-1} \frac{\partial s_{k+1}}{\partial s_k}$$

Forward mode propagates tangents through iterations in execution order.

Reverse mode propagates adjoints backward through iterations in reverse order.

The central systems problem is state management:

Forward mode:
    low memory, streaming propagation.

Reverse mode:
    efficient for scalar outputs,
    but usually requires loop-state storage.

Loops therefore transform differentiation from a local operation into a dynamical process across execution time.