Recursion is control flow where a function calls itself. In automatic differentiation, recursion behaves like a loop with a call stack. Each recursive call contributes one layer to the computation, and differentiation applies the chain rule across the call tree.
A simple recursive function has the form:
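```python
def f(x, n):
    # g is an arbitrary differentiable scalar function
    if n == 0:
        return x
    return g(f(x, n - 1))
```

This denotes repeated composition:

$$f(x, n) = \underbrace{g \circ g \circ \cdots \circ g}_{n \text{ times}}(x) = g^{n}(x).$$

So the derivative is:

$$\frac{\partial f}{\partial x}(x, n) = g'(y_{n-1})\, g'(y_{n-2}) \cdots g'(y_0) = \prod_{k=0}^{n-1} g'(y_k),$$

where $y_0 = x$ and $y_{k+1} = g(y_k)$.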
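A minimal sketch of the product formula, assuming g = math.sin purely for illustration (the text leaves g abstract; f_and_grad and g_prime are names introduced here). Each recursive call returns the primal together with the derivative accumulated so far, and a central finite difference serves as a rough check.

```python
import math


def g(x):
    # hypothetical choice of g; the text leaves g abstract
    return math.sin(x)


def g_prime(x):
    return math.cos(x)


def f_and_grad(x, n):
    """Return (f(x, n), df/dx) where f(x, 0) = x and f(x, n) = g(f(x, n - 1))."""
    if n == 0:
        return x, 1.0                 # y_0 = x, dy_0/dx = 1
    y, dy = f_and_grad(x, n - 1)      # y = y_{n-1}, dy = dy_{n-1}/dx
    return g(y), g_prime(y) * dy      # multiply in one more factor g'(y_{n-1})


x, n, h = 0.7, 5, 1e-6
value, grad = f_and_grad(x, n)
fd = (f_and_grad(x + h, n)[0] - f_and_grad(x - h, n)[0]) / (2 * h)
print(value, grad, fd)               # grad and fd should agree closely
```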
Recursion as a Dynamic Computation Graph
Every recursive call creates a new frame. The frame contains local variables, arguments, return values, and sometimes saved intermediates.
For example:
```python
def pow2(x, n):
    # computes x ** (2 ** n) by repeated squaring
    if n == 0:
        return x
    y = pow2(x, n - 1)
    return y * y
```

The call pow2(x, 3) expands into:

```python
y0 = x
y1 = y0 * y0
y2 = y1 * y1
y3 = y2 * y2
```

This is the same computation as a loop, but the structure is produced by nested calls rather than iteration.
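To make the equivalence concrete, here is a sketch of the recursive pow2 next to an iterative version (pow2_loop is a name introduced here). Both produce the chain y0, y1, y2, ... and compute x ** (2 ** n).

```python
def pow2(x, n):
    # recursive form: the chain is built by nested calls
    if n == 0:
        return x
    y = pow2(x, n - 1)
    return y * y


def pow2_loop(x, n):
    # iterative form: the same chain y_{k+1} = y_k * y_k, built by a loop
    y = x
    for _ in range(n):
        y = y * y
    return y


print(pow2(3, 3), pow2_loop(3, 3))  # 6561 6561, i.e. 3 ** 8
```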
Forward mode can propagate tangents through the recursive calls as they return. Reverse mode records the call tree and walks it backward.
Forward Mode Through Recursion
Forward mode pairs each value $v$ with a tangent $\dot v$:

$$v \mapsto (v, \dot v).$$
For a recursive function, each call receives both the primal argument and its tangent. The recursive structure is unchanged.
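```python
def pow2_jvp(x, dx, n):
    # returns pow2(x, n) and its tangent in the direction dx
    if n == 0:
        return x, dx
    y, dy = pow2_jvp(x, dx, n - 1)
    out = y * y
    dout = y * dy + dy * y   # product rule for y * y, i.e. 2 * y * dy
    return out, dout
```

For depth $n$, the primal result is:

$$y_n = x^{2^n}$$

and the tangent is:

$$\dot y_n = 2^n x^{2^n - 1}\, \dot x.$$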
Forward mode needs only the active call stack. It does not need to store a separate global tape unless the implementation chooses to do so.
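Operator overloading gives the same tangents without writing pow2_jvp by hand. The sketch below assumes a minimal Dual class (a hypothetical helper that supports only multiplication); the recursive pow2 is left unchanged, which shows that forward mode needs nothing beyond the ordinary call stack.

```python
from dataclasses import dataclass


@dataclass
class Dual:
    """A primal value paired with its tangent (forward-mode AD)."""
    val: float
    tan: float

    def __mul__(self, other):
        # product rule: d(u * v) = du * v + u * dv
        return Dual(self.val * other.val, self.tan * other.val + self.val * other.tan)


def pow2(x, n):
    # unchanged recursive definition; works for floats and Dual values alike
    if n == 0:
        return x
    y = pow2(x, n - 1)
    return y * y


out = pow2(Dual(2.0, 1.0), 3)
print(out)  # Dual(val=256.0, tan=1024.0), i.e. 2**8 and 8 * 2**7
```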
Reverse Mode Through Recursion
Reverse mode handles recursion by recording the operations and calls that occurred during the forward pass. The backward pass follows the reverse of the actual call tree.
For a recursive function:
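```python
def f(x):
    # stop is a predicate; base, step, and combine are differentiable pieces
    if stop(x):
        return base(x)
    y = f(step(x))
    return combine(x, y)
```

The forward pass builds a dynamic chain of calls (a tree, when the recursion branches). The backward pass propagates adjoints from the return value back through every frame that was recorded.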
A conceptual reverse pass looks like:
```text
forward:
    push frame
    call recursively
    compute output

backward:
    start from output adjoint
    pop frames in reverse order
    apply local adjoint rules
```

For linear recursion, the call structure is a stack. For branching recursion, the call structure is a tree.
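A minimal sketch of this forward/backward split for the linear recursion pow2, assuming a plain Python list as the tape (pow2_record and pow2_grad are names introduced here for illustration). Each frame pushes the value its local adjoint rule needs; the backward loop pops frames in reverse order.

```python
def pow2_record(x, n, tape):
    """Forward pass: recurse as usual, pushing each frame's input y onto tape."""
    if n == 0:
        return x
    y = pow2_record(x, n - 1, tape)
    tape.append(y)              # saved intermediate for this frame's adjoint rule
    return y * y


def pow2_grad(x, n):
    tape = []
    out = pow2_record(x, n, tape)
    ybar = 1.0                  # start from the output adjoint
    while tape:
        y = tape.pop()          # frames come off in reverse order
        ybar = ybar * 2 * y     # local adjoint rule for out = y * y
    return out, ybar            # ybar is now d out / d x


print(pow2_grad(2.0, 3))        # (256.0, 1024.0)
```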
Branching Recursion
Some recursive functions call themselves more than once:
```python
def tree(x, n):
    if n == 0:
        return h(x)
    a = tree(left(x), n - 1)
    b = tree(right(x), n - 1)
    return combine(a, b)
```

The computation graph is now a tree. Reverse mode must accumulate gradients from all branches that depend on the same input.
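If both left(x) and right(x) depend on x, then the adjoint of x receives contributions from both subcalls. Writing $\ell = \mathrm{left}(x)$ and $r = \mathrm{right}(x)$, and letting $\bar\ell$ and $\bar r$ be the adjoints returned by the two recursive calls:

$$\bar x = \bar\ell\,\frac{\partial \ell}{\partial x} + \bar r\,\frac{\partial r}{\partial x}.$$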
The key operation is accumulation. Shared inputs receive summed adjoints.
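A sketch of reverse mode for the branching tree function using closures (pullbacks) instead of an explicit tape. The text leaves h, left, right, and combine abstract; the concrete choices below (h(x) = x * x, left(x) = 2 * x, right(x) = x + 1, combine(a, b) = a + b) are hypothetical. The pullback sums the adjoints coming back from both branches.

```python
def tree_vjp(x, n):
    """Return (tree(x, n), pullback), where pullback maps d out to d x.

    Hypothetical helpers: h(x) = x * x, left(x) = 2 * x,
    right(x) = x + 1, combine(a, b) = a + b.
    """
    if n == 0:
        return x * x, lambda ybar: ybar * 2 * x      # h'(x) = 2x
    a, a_pull = tree_vjp(2 * x, n - 1)               # left branch
    b, b_pull = tree_vjp(x + 1, n - 1)               # right branch

    def pullback(ybar):
        # combine = a + b passes ybar unchanged to both branches
        x_from_left = a_pull(ybar) * 2               # d left / dx = 2
        x_from_right = b_pull(ybar) * 1              # d right / dx = 1
        return x_from_left + x_from_right            # shared input: sum the adjoints

    return a + b, pullback


value, pull = tree_vjp(1.0, 1)
print(value, pull(1.0))  # 8.0 12.0, since tree(x, 1) = 5x^2 + 2x + 1
```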
Base Cases and Differentiability
A recursive function must eventually reach a base case. The base case is also part of the differentiated program.
```python
def f(x, n):
    if n == 0:
        return abs(x)
    return f(x * x, n - 1)
```

The derivative depends on the base case. If the base case is non-smooth, the recursive function can inherit that non-smoothness at inputs whose base-case argument lands on a kink (here, where abs is evaluated at 0).
As with ordinary conditionals, AD differentiates the executed base case. It does not prove that all possible recursive paths are smooth.
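A sketch of forward mode through this function (f_jvp is a name introduced here). The recursive steps push the tangent through each squaring, and only the base case that actually runs decides which derivative rule for abs is applied. The value 0 at the kink is an assumed convention in this sketch, not a universal rule.

```python
import math


def f_jvp(x, dx, n):
    """Forward mode for f(x, n) = f(x * x, n - 1) with base case abs(x)."""
    if n == 0:
        # derivative of abs at the executed point; at the kink x == 0
        # this sketch assumes 0, one common convention
        slope = math.copysign(1.0, x) if x != 0 else 0.0
        return abs(x), slope * dx
    return f_jvp(x * x, 2 * x * dx, n - 1)   # tangent of x * x is 2 x dx


print(f_jvp(-1.5, 1.0, 2))  # (5.0625, -13.5): d/dx of x**4 at -1.5 is 4 * (-1.5)**3
```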
Data-Dependent Recursion Depth
Recursion depth may depend on input values:
```python
def f(x):
    if x < 1:
        return x
    return f(x / 2)
```

The number of recursive calls changes discontinuously as $x$ crosses the thresholds $1, 2, 4, \dots$. AD treats the observed recursion depth as fixed during differentiation.
This is locally valid where small perturbations do not change the number of calls. At thresholds where recursion depth changes, the mathematical derivative may fail to exist or may differ from the derivative of the executed path.
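A small sketch of both effects (f_jvp is a name introduced here). Forward mode follows exactly the depth the input produces, so the slope is 1 / 2**k for k halvings; a finite difference straddling the threshold x = 2 instead sees the jump in f.

```python
def f(x):
    if x < 1:
        return x
    return f(x / 2)


def f_jvp(x, dx):
    """Forward mode: follows exactly the recursion depth that this x produces."""
    if x < 1:
        return x, dx
    return f_jvp(x / 2, dx / 2)


print(f_jvp(3.0, 1.0))           # (0.75, 0.25): two halvings, slope 1/4

# Across x = 2 the depth changes from 1 to 2 and f jumps, so a finite
# difference there matches neither executed-path slope.
h = 1e-4
fd = (f(2.0 + h) - f(2.0 - h)) / (2 * h)
print(fd)                        # a large negative number
```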
Recursive Data Structures
Recursion often appears with trees, graphs, lists, and symbolic expressions.
Example:
```python
def eval(node):
    # recursively evaluates an expression tree (note: shadows the built-in eval)
    if node.kind == "const":
        return node.value
    if node.kind == "add":
        return eval(node.left) + eval(node.right)
    if node.kind == "mul":
        return eval(node.left) * eval(node.right)
    raise ValueError(f"unknown node kind: {node.kind!r}")
```

Differentiating such a program means differentiating the evaluation of a recursive data structure. The derivative follows the structure that was traversed.
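For expression trees, reverse mode resembles symbolic backpropagation over the tree. For graph-like structures with sharing, the derivative must respect shared subexpressions: each shared node receives the sum of the adjoints from all of its uses, and its own adjoint rule is applied once, after that sum is complete, rather than once per path through the graph.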
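A sketch of reverse mode over an expression graph with sharing, assuming a hypothetical Node class with the kind, value, left, and right fields used by eval above. A post-order traversal lists each node once, so a shared node is evaluated once, collects the sum of adjoints from all its parents, and propagates that total to its children exactly once.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(eq=False)  # eq=False keeps identity hashing, so shared nodes are tracked by object
class Node:
    kind: str                        # "const", "add", or "mul"
    value: float = 0.0               # used only by "const" nodes
    left: Optional["Node"] = None
    right: Optional["Node"] = None


def topo_order(root):
    """Post-order traversal: children before parents, each node exactly once."""
    order, seen = [], set()

    def visit(node):
        if node in seen:
            return
        seen.add(node)
        if node.kind != "const":
            visit(node.left)
            visit(node.right)
        order.append(node)

    visit(root)
    return order


def value_and_grads(root):
    order = topo_order(root)
    val = {}
    for node in order:               # forward: each shared node is evaluated once
        if node.kind == "const":
            val[node] = node.value
        elif node.kind == "add":
            val[node] = val[node.left] + val[node.right]
        else:                        # "mul"
            val[node] = val[node.left] * val[node.right]
    adj = {node: 0.0 for node in order}
    adj[root] = 1.0
    for node in reversed(order):     # backward: a node's adjoint is complete before use
        if node.kind == "add":
            adj[node.left] += adj[node]
            adj[node.right] += adj[node]
        elif node.kind == "mul":
            adj[node.left] += adj[node] * val[node.right]
            adj[node.right] += adj[node] * val[node.left]
    return val[root], adj


x = Node("const", 3.0)
s = Node("add", left=x, right=x)     # s = 2x, uses x twice
root = Node("mul", left=s, right=s)  # root = s * s = 4x^2, uses s twice
value, adj = value_and_grads(root)
print(value, adj[x])                 # 36.0 24.0, since d(4x^2)/dx = 8x
```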
Memoization and Sharing
Memoized recursion stores results to avoid repeated work:
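```python
cache = {}

def f(x, node):
    # the cache is keyed by node alone, so it assumes x is fixed for one
    # evaluation; clear it before evaluating at a different x
    if node in cache:
        return cache[node]
    y = compute(x, node)   # compute is abstract; it typically calls f on sub-nodes
    cache[node] = y
    return y
```

Memoization changes the execution graph from a tree into a directed acyclic graph. Reverse mode must accumulate adjoints from every use of a cached value.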
If a cached value is used three times, its adjoint receives three contributions. Treating the computation as a tree would recompute the value; treating it as a graph stores it once and sums incoming adjoints.
This distinction matters for both correctness and performance.
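To see both effects, consider a hypothetical Fibonacci-style recurrence F(0) = F(1) = x, F(k) = F(k-1) + F(k-2). Without a cache the call tree is exponential in n; in this sketch functools.lru_cache stands in for the explicit cache dictionary, so each F(k) is computed once. Each F(k) is then used by up to two later values, and its adjoint sums those contributions.

```python
from functools import lru_cache


def fib_value_and_grad(x, n):
    """Hypothetical recurrence F(0) = F(1) = x, F(k) = F(k-1) + F(k-2).

    Returns (F(n), dF(n)/dx, number of distinct subproblems computed).
    """
    calls = 0

    @lru_cache(maxsize=None)           # memo table: each F(k) is computed once
    def F(k):
        nonlocal calls
        calls += 1
        if k < 2:
            return x
        return F(k - 1) + F(k - 2)

    value = F(n)

    # Backward over the DAG. Descending k is a valid reverse order, so each
    # adjoint has all its contributions (from F(k+1) and F(k+2)) before use.
    adj = [0.0] * (n + 1)
    adj[n] = 1.0
    for k in range(n, 1, -1):
        adj[k - 1] += adj[k]
        adj[k - 2] += adj[k]
    xbar = sum(adj[:2])                # both base cases return x
    return value, xbar, calls


print(fib_value_and_grad(2.0, 10))     # (178.0, 89.0, 11)
```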
Tail Recursion
A tail-recursive function has the recursive call as its final action:
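```python
def sum_loop(xs, i, acc):
    if i == len(xs):
        return acc
    return sum_loop(xs, i + 1, acc + xs[i])
```

Tail recursion is equivalent to a loop. A compiler may transform it into iteration (tail-call elimination); CPython does not, so deep tail recursion in Python still grows the call stack.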
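For AD, this means a tail-recursive function can often be differentiated with loop-differentiation machinery:

```python
acc = acc0                  # acc0 is the initial accumulator
for i in range(len(xs)):
    acc = acc + xs[i]
```

Forward mode still propagates tangents in a single streaming pass. Reverse mode still needs stored (or recomputed) intermediate values when the loop body contains nonlinear operations; this sum is linear, so its adjoint needs none.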
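For contrast, a sketch with a nonlinear body, assuming a hypothetical product accumulator acc = acc * xs[i] in place of the sum (prod_with_grad is a name introduced here). The backward loop needs every intermediate acc, so the forward loop stores them.

```python
def prod_with_grad(xs):
    """Value and gradient of acc = acc * xs[i] over a loop (acc starts at 1)."""
    accs = [1.0]                        # forward: store every intermediate acc
    for x in xs:
        accs.append(accs[-1] * x)

    grads = [0.0] * len(xs)
    acc_bar = 1.0                       # adjoint of the final acc
    for i in reversed(range(len(xs))):  # backward: reverse loop order
        # acc_{i+1} = acc_i * xs[i]
        grads[i] = acc_bar * accs[i]    # needs the stored acc_i
        acc_bar = acc_bar * xs[i]       # adjoint flowing to acc_i
    return accs[-1], grads


print(prod_with_grad([2.0, 3.0, 4.0]))  # (24.0, [12.0, 8.0, 6.0])
```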
Recursive Fixed Points
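Some recursive definitions describe fixed points:

$$x^\star = F(x^\star, \theta).$$

For example, a recursive solver may call itself until convergence. Differentiating through every call gives the derivative of the finite execution. Differentiating the limiting fixed point gives an implicit derivative.

These are different objects.

Finite recursion, with $x_{k+1} = F(x_k, \theta)$ run for $K$ calls, gives:

$$\frac{dx_{k+1}}{d\theta} = \frac{\partial F}{\partial x}(x_k, \theta)\,\frac{dx_k}{d\theta} + \frac{\partial F}{\partial \theta}(x_k, \theta), \qquad k = 0, \dots, K-1.$$

Implicit recursion gives:

$$\frac{dx^\star}{d\theta} = \left(I - \frac{\partial F}{\partial x}(x^\star, \theta)\right)^{-1} \frac{\partial F}{\partial \theta}(x^\star, \theta).$$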
The finite derivative depends on the number of calls. The implicit derivative depends on the fixed-point equation.
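A sketch comparing the two, assuming a hypothetical scalar map F(x, theta) = 0.5 * cos(x) + theta, which is a contraction in x, started from x_0 = 0 (so dx_0/dtheta = 0). The recursive solver carries the finite-recursion tangent; implicit_derivative evaluates the implicit formula at the final iterate. After 3 calls the two differ noticeably; after 60 they agree to rounding.

```python
import math


def F(x, theta):
    # hypothetical contraction in x: |dF/dx| = |0.5 sin x| <= 0.5
    return 0.5 * math.cos(x) + theta


def solve_jvp(x, dx, theta, k):
    """Recursive solver that also carries dx = d x_k / d theta (forward mode)."""
    if k == 0:
        return x, dx
    x_next = F(x, theta)
    dx_next = -0.5 * math.sin(x) * dx + 1.0   # dF/dx * dx + dF/dtheta
    return solve_jvp(x_next, dx_next, theta, k - 1)


def implicit_derivative(x_star):
    # (1 - dF/dx)^{-1} * dF/dtheta, evaluated at the fixed point
    return 1.0 / (1.0 + 0.5 * math.sin(x_star))


theta = 0.3
x3, d3 = solve_jvp(0.0, 0.0, theta, 3)      # finite execution: 3 calls
x60, d60 = solve_jvp(0.0, 0.0, theta, 60)   # effectively converged
print(d3, d60, implicit_derivative(x60))
```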
A later section treats implicit differentiation in detail.
Stack, Tape, and Memory
Recursion consumes memory in two ways.
First, the ordinary runtime call stack stores active frames. Second, reverse-mode AD may store intermediate values for the backward pass.
For deep recursion, this can cause stack overflow or excessive tape growth.
Common implementation strategies include:
| Strategy | Description |
|---|---|
| Convert recursion to loops | Replace tail recursion or linear recursion with iteration |
| Use explicit stacks | Store frames in heap-allocated data structures |
| Checkpoint recursive calls | Store selected frames and recompute others |
| Memoize subproblems | Avoid repeated recursive work |
| Use custom adjoints | Replace recursive backward pass with a closed-form derivative |
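A minimal sketch of the checkpointing row, applied to the pow2 chain y_{k+1} = y_k * y_k after converting the recursion to a loop. Only every stride-th intermediate is stored; during the backward pass each segment is recomputed from its checkpoint. The function name and the stride parameter are assumptions for illustration.

```python
def pow2_grad_checkpointed(x, n, stride=4):
    """Value and d/dx of the chain y_{k+1} = y_k * y_k (y_0 = x, n steps),
    storing only every stride-th intermediate and recomputing the rest."""
    checkpoints = {}
    y = x
    for k in range(n):                 # forward: save y_k only at checkpoint steps
        if k % stride == 0:
            checkpoints[k] = y
        y = y * y
    out = y

    ybar = 1.0                         # adjoint of the output
    for start in sorted(checkpoints, reverse=True):   # last segment first
        end = min(start + stride, n)
        segment = [checkpoints[start]]                 # recompute y_start .. y_{end-1}
        for _ in range(start, end - 1):
            segment.append(segment[-1] * segment[-1])
        for yk in reversed(segment):                   # walk the segment backward
            ybar = ybar * 2 * yk                       # adjoint rule for y_k * y_k
    return out, ybar


print(pow2_grad_checkpointed(2.0, 3, stride=2))   # (256.0, 1024.0)
```

Smaller strides store more and recompute less; stride=1 degenerates to storing every frame.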
Correctness Rule
For recursion, AD differentiates the call tree produced by the primal execution.
```text
Forward mode:
    propagates tangents through recursive calls.
Reverse mode:
    records executed calls and propagates adjoints backward.
Branching recursion:
    accumulates adjoints from all executed branches.
Data-dependent depth:
    differentiates the observed depth.
```

Recursion therefore requires no new mathematical rule beyond the chain rule. Its difficulty is operational: dynamic call structure, stack growth, sharing, memoization, and data-dependent termination.