
Nested AD

Nested automatic differentiation means applying automatic differentiation inside another automatic differentiation computation.

Conceptually, this means differentiating derivative computations.

Examples include:

| Expression | Meaning |
| --- | --- |
| $\nabla(\nabla f)$ | Hessian |
| $D(Df)$ | second directional derivative |
| $\nabla(Df[v])$ | Hessian-vector product |
| $D(\nabla f)$ | Jacobian of gradient |
| $\nabla(\text{optimization step})$ | meta-gradient |

Nested AD is essential for higher-order derivatives, implicit differentiation, meta-learning, differentiable optimization, and differentiable programming systems.

The mathematics is straightforward. The implementation is subtle.

First-Order AD as a Transformation

An AD transform maps a program into another program.

If

$$ f : X \to Y, $$

then forward mode produces a transformed program:

$$ Df : TX \to TY, $$

where $T$ represents tangent information.

Reverse mode produces a program that computes pullbacks:

$$ f^\ast : T^\ast Y \to T^\ast X. $$

Nested AD applies these transformations repeatedly.

For example:

$$ D(Df) $$

means applying forward-mode transformation twice.

Similarly,

$$ \nabla(\nabla f) $$

means applying reverse mode to a reverse-mode derivative computation.

The resulting program has multiple derivative levels active simultaneously.

Example: Forward-over-Forward

Consider

$$ f(x) = x^3. $$

First forward mode computes:

$$ f(x + \epsilon v) = x^3 + 3x^2 v\,\epsilon. $$

The tangent coefficient is:

$$ Df(x)[v] = 3x^2 v. $$

Apply forward mode again with another perturbation:

$$ x + \epsilon_1 v + \epsilon_2 w + \epsilon_1\epsilon_2 u. $$

Now mixed infinitesimal terms appear:

$$ f(x + \cdots) = x^3 + 3x^2(\cdots) + 6xvw\,\epsilon_1\epsilon_2 + \cdots $$

The coefficient of

$$ \epsilon_1\epsilon_2 $$

contains second-order information.

Nested forward mode therefore naturally produces higher derivatives.
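
As a concrete sketch, the same computation can be written with nested forward-mode transforms in JAX (used here only as one example system; `jax.jvp` returns a primal/tangent pair):

```python
import jax

def f(x):
    return x ** 3

# One forward-mode transform: Df(x)[v] = 3 x^2 v.
def df(x, v):
    return jax.jvp(f, (x,), (v,))[1]

# Forward mode applied again, perturbing x in direction w:
# the nested tangent is the epsilon_1 epsilon_2 coefficient, 6 x v w.
def d2f(x, v, w):
    return jax.jvp(lambda z: df(z, v), (x,), (w,))[1]

print(df(2.0, 1.0))        # 12.0 = 3 * 2^2
print(d2f(2.0, 1.0, 1.0))  # 12.0 = 6 * 2
```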

Dual Number Nesting

Forward mode is commonly implemented using dual numbers:

$$ a + b\epsilon, \quad \epsilon^2 = 0. $$

Nested forward mode uses nested dual structures:

$$ (a + b\epsilon_1) + (c + d\epsilon_1)\epsilon_2. $$

Expanding gives:

$$ a + b\epsilon_1 + c\epsilon_2 + d\epsilon_1\epsilon_2. $$

Each infinitesimal direction corresponds to a derivative level.

The mixed term

$$ d\,\epsilon_1\epsilon_2 $$

encodes second-order interaction.

This construction generalizes to arbitrary derivative order, but the algebra grows rapidly.
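
A minimal untagged sketch in plain Python (illustrative only, with operator coverage limited to `+` and `*`) shows how nesting falls out of the dual-number arithmetic itself:

```python
class Dual:
    """Minimal dual number a + b*eps with eps^2 = 0 (untagged)."""
    def __init__(self, value, tangent=0.0):
        self.value, self.tangent = value, tangent

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.value + other.value, self.tangent + other.tangent)
    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.value * other.value,
                    self.value * other.tangent + self.tangent * other.value)
    __rmul__ = __mul__

def derivative(f, x):
    return f(Dual(x, 1.0)).tangent

# Nesting: the inner call operates on Dual-of-Dual values, so the mixed
# eps1*eps2 coefficient carries the second derivative.
f = lambda x: x * x * x
print(derivative(f, 2.0))                           # 12.0 = 3 x^2 at x = 2
print(derivative(lambda y: derivative(f, y), 2.0))  # 12.0 = 6 x at x = 2
```

This untagged version happens to work here; the Perturbation Confusion section below shows where it breaks.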

Reverse-over-Reverse Nesting

Nested reverse mode is more complicated.

A reverse-mode computation creates adjoints and backward passes. Differentiating that computation introduces adjoints of adjoints.

Suppose:

g = grad(f)
H = grad(g)

The second gradient differentiates the reverse pass used by the first gradient.

This means:

  1. The original primal computation must remain differentiable.
  2. The backward pass must itself behave like differentiable code.
  3. Adjoint accumulation logic becomes part of the differentiated computation.

The system must distinguish derivative levels carefully.
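
In a system that supports the composition directly (JAX is used here as one example), reverse-over-reverse for a scalar function is simply `grad` applied to `grad`; the second call differentiates the backward pass built by the first:

```python
import jax

def f(x):
    return x ** 3 + 2.0 * x

g = jax.grad(f)    # reverse mode: g(x) = 3 x^2 + 2
h = jax.grad(g)    # reverse over reverse: h(x) = 6 x
print(g(2.0), h(2.0))   # 14.0 12.0
```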

Mixed-Mode Nesting

Many practical systems use mixed-mode nesting.

Examples include:

| Nesting | Purpose |
| --- | --- |
| forward-over-reverse | Hessian-vector products |
| reverse-over-forward | directional derivative gradients |
| reverse-over-reverse | higher-order scalar derivatives |
| forward-over-forward | Taylor coefficients |
| reverse-over-forward-over-reverse | advanced implicit differentiation |

Mixed mode is often preferable because different AD modes have complementary strengths.

Each mode is efficient in a complementary regime:

| Mode | Efficient regime |
| --- | --- |
| forward mode | small input dimension |
| reverse mode | small output dimension |

Nested systems combine these strengths.
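
As a sketch (in JAX, assuming `jax.numpy` and `jax.jvp`), forward-over-reverse yields a Hessian-vector product with one reverse pass and one forward pass, without ever forming the Hessian:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.sin(x) * x)

# Forward-over-reverse: push the tangent v through the reverse-mode gradient.
def hvp(f, x, v):
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

x = jnp.arange(3.0)
v = jnp.ones(3)
print(hvp(f, x, v))             # H @ v, no n x n matrix materialized
print(jax.hessian(f)(x) @ v)    # reference: builds the full Hessian first
```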

Perturbation Confusion

The most famous nested-AD failure mode is perturbation confusion.

Suppose two forward-mode computations accidentally share the same infinitesimal symbol:

$$ \epsilon. $$

Then derivative information from different levels mixes incorrectly.

For example:

$$ (a + b\epsilon) + (c + d\epsilon) = (a+c) + (b+d)\epsilon. $$

If these infinitesimals were meant to represent different derivative levels, the result is wrong.

The problem becomes severe when derivative computations are passed through higher-order functions or closures.
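
A sketch of the failure, reusing the untagged `Dual` class and `derivative()` from the earlier section: the expression $x \cdot \frac{d}{dy}(x + y)$ equals $x$, so its derivative at $x = 1$ should be 1, but the shared $\epsilon$ leaks the outer perturbation into the inner derivative.

```python
# Reusing the untagged Dual class and derivative() sketched earlier.
# The inner derivative should be 1, so the whole expression equals x.
def confused(x):
    inner = derivative(lambda y: x + y, 1.0)   # reports 2 instead of 1
    return x * inner

print(derivative(confused, 1.0))   # 2.0 -- wrong; the correct answer is 1.0
```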

Tagging Derivative Levels

Correct nested forward mode assigns a unique tag to each perturbation level.

Instead of one universal infinitesimal:

$$ \epsilon, $$

the system uses:

$$ \epsilon_1, \epsilon_2, \epsilon_3, \ldots $$

with independent nilpotent behavior.

Then:

$$ \epsilon_i \epsilon_j \ne 0 \quad \text{for } i \ne j, $$

while

$$ \epsilon_i^2 = 0. $$

Each derivative transform introduces a fresh perturbation identity.

Implementation-wise, systems usually represent this with:

| Technique | Description |
| --- | --- |
| unique IDs | each AD transform gets a fresh perturbation tag |
| lexical scoping | perturbations scoped to derivative level |
| type-level tagging | derivative levels encoded in types |
| runtime tagging | tags stored dynamically |
| staged transforms | derivative levels separated during compilation |

Without tagging, nested forward mode is unreliable.
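
A minimal tagged sketch (plain Python, illustrative rather than production code): each `derivative` call draws a fresh tag, values carrying a different tag are treated as constants at that level, and tangent extraction is tag-aware. With this change the confusion example above returns the correct 1.0.

```python
import itertools

_fresh_tag = itertools.count()

def _lift(x, tag):
    """Treat anything not carrying this tag as a constant at this level."""
    return x if isinstance(x, TDual) and x.tag == tag else TDual(x, 0.0, tag)

def _tangent(x, tag):
    """Extract the tangent of x with respect to one specific tag."""
    if isinstance(x, TDual):
        if x.tag == tag:
            return x.tangent
        # Different level: recurse into both slots, keep the outer structure.
        return TDual(_tangent(x.value, tag), _tangent(x.tangent, tag), x.tag)
    return 0.0

class TDual:
    """Dual number tagged with the derivative level that created it."""
    def __init__(self, value, tangent, tag):
        self.value, self.tangent, self.tag = value, tangent, tag

    def __add__(self, other):
        other = _lift(other, self.tag)
        return TDual(self.value + other.value,
                     self.tangent + other.tangent, self.tag)
    __radd__ = __add__

    def __mul__(self, other):
        other = _lift(other, self.tag)
        return TDual(self.value * other.value,
                     self.value * other.tangent + self.tangent * other.value,
                     self.tag)
    __rmul__ = __mul__

def derivative(f, x):
    tag = next(_fresh_tag)                       # fresh perturbation per transform
    return _tangent(f(TDual(x, 1.0, tag)), tag)

print(derivative(lambda x: x * derivative(lambda y: x + y, 1.0), 1.0))  # 1.0
```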

Cotangent Level Separation

Reverse mode has a similar issue.

Adjoints from different derivative levels must remain separate.

Suppose:

outer gradient
    inner gradient

The inner reverse pass should not accidentally consume cotangents from the outer reverse pass.

A correct nested reverse system tracks cotangent levels explicitly.

Conceptually:

| Level | Meaning |
| --- | --- |
| primal level | original computation |
| tangent/cotangent level 1 | first derivative |
| tangent/cotangent level 2 | second derivative |
| tangent/cotangent level k | kth derivative |

This separation is part of the semantics, not merely debugging metadata.

Closures and Higher-Order Functions

Nested AD becomes harder in languages with closures and higher-order functions.

Example:

def outer(x):
    def inner(y):
        return x * y
    return grad(inner)

The derivative transform captures free variables from the outer scope.

The AD system must decide:

| Question | Determines |
| --- | --- |
| which values are primal | original computation |
| which values are tangent | derivative level |
| which values are cotangent | reverse accumulation |
| which closures capture derivative state | nested differentiation |
| which tapes belong to which level | nesting correctness |

Functional languages often model this more cleanly because closures are explicit semantic objects.
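
A sketch of the pattern above in JAX: the inner `grad` must treat the captured `x` as a constant, while the outer `grad` still sees it as the differentiated input.

```python
import jax

def outer(x):
    def inner(y):
        return x * y                 # x is a captured free variable
    return jax.grad(inner)(3.0)      # d(inner)/dy = x, for any y

print(outer(2.0))              # 2.0
print(jax.grad(outer)(2.0))    # 1.0, since outer(x) == x
```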

Dynamic Computation Graphs

Nested AD interacts strongly with dynamic graph systems.

If the computation graph depends on runtime control flow, the graph structure itself may differ between derivative levels.

For example:

if grad(f)(x) > 0:
    ...

The derivative computation influences the primal control flow.

A nested AD system must specify:

  1. Which computations are traced.
  2. Which branches are differentiated.
  3. Whether graph structure is static or dynamic.
  4. Whether nested derivative traces are compositional.

Different frameworks make different choices.

Differentiating Optimizers

Nested AD is central in differentiable optimization.

Suppose gradient descent performs:

$$ x_{t+1} = x_t - \eta \nabla f(x_t). $$

Now suppose we want derivatives with respect to hyperparameters:

$$ \frac{\partial x_T}{\partial \eta}. $$

The optimizer itself becomes part of the differentiated program.

This requires differentiating through:

| Object | Derivative target |
| --- | --- |
| gradients | first-order derivatives |
| update rules | optimization dynamics |
| momentum accumulators | optimizer state |
| learning rate schedules | hyperparameters |
| inner training loops | meta-learning |

Nested AD enables this.
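
A sketch of a meta-gradient (JAX, with a hypothetical toy objective): reverse mode is applied over a training loop whose update rule itself calls reverse mode.

```python
import jax
import jax.numpy as jnp

def loss(w):
    return jnp.sum((w - 3.0) ** 2)

def train(eta, w0, steps=5):
    w = w0
    for _ in range(steps):
        w = w - eta * jax.grad(loss)(w)   # inner reverse-mode gradient step
    return loss(w)

w0 = jnp.array([0.0, 1.0])
# Outer reverse mode differentiates the unrolled optimizer w.r.t. the learning rate.
print(jax.grad(train)(0.1, w0))
```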

Implicit Differentiation

Nested AD also appears in implicit differentiation.

Suppose:

$$ g(x, y(x)) = 0. $$

We may avoid differentiating every iteration of a solver by differentiating the fixed-point condition itself.

Still, the resulting derivative computations often involve nested linearizations and reverse passes.

This is especially common in:

| Area | Example |
| --- | --- |
| meta-learning | differentiating equilibrium states |
| optimization layers | differentiating argmin operators |
| physics simulation | differentiating steady states |
| probabilistic inference | differentiating fixed-point solvers |

Nested AD provides the underlying machinery.
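
A sketch of the contrast (JAX, with a hypothetical toy relation $g(x, y) = y - 0.5\cos(y) - x = 0$): the implicit-function-theorem route differentiates $g$ only at the solution, while the naive route differentiates every solver iteration.

```python
import jax
import jax.numpy as jnp

def g(x, y):                       # implicit relation g(x, y(x)) = 0
    return y - 0.5 * jnp.cos(y) - x

def solve(x, y0=0.0, iters=50):
    y = y0
    for _ in range(iters):
        y = x + 0.5 * jnp.cos(y)   # fixed-point iteration for g = 0
    return y

def dy_dx(x):
    y = solve(x)
    # Implicit function theorem: dy/dx = -(dg/dy)^{-1} dg/dx at the solution.
    return -jax.grad(g, argnums=0)(x, y) / jax.grad(g, argnums=1)(x, y)

x = 1.0
print(dy_dx(x))              # via the fixed-point condition only
print(jax.grad(solve)(x))    # reference: differentiate every iteration
```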

Tape Nesting

In tape-based reverse mode, nested AD introduces nested tapes.

Possible models include:

| Model | Description |
| --- | --- |
| independent tapes | each derivative level has a separate tape |
| hierarchical tapes | tapes reference parent tapes |
| reentrant tapes | backward passes may themselves record operations |
| staged tapes | derivative levels separated during compilation |

Incorrect tape interaction can cause:

| Failure | Meaning |
| --- | --- |
| tape corruption | nested passes overwrite state |
| missing gradients | tape lifetime ends too early |
| duplicated gradients | nested replay occurs twice |
| memory explosion | all nested levels retained simultaneously |

Robust nested systems need explicit tape ownership semantics.

Compiler Perspective

A compiler-based AD system often handles nesting better than runtime operator overloading systems.

Instead of dynamically stacking derivative objects, the compiler transforms intermediate representations explicitly:

program
→ linearized program
→ transposed program
→ differentiated transposed program

This exposes derivative levels structurally.

Compiler IRs can annotate:

| IR annotation | Purpose |
| --- | --- |
| primal variable | original value |
| tangent variable | forward derivative |
| cotangent variable | reverse derivative |
| residual | saved intermediate |
| derivative level | nesting separation |

This structure reduces ambiguity.

Complexity Explosion

Higher-order derivatives grow rapidly.

For input dimension $n$:

| Derivative order | Tensor size |
| --- | --- |
| gradient | $n$ |
| Hessian | $n^2$ |
| third derivative tensor | $n^3$ |
| fourth derivative tensor | $n^4$ |

Nested AD can therefore create storage and computation costs that grow exponentially with the derivative order.

Most practical systems avoid materializing full higher-order tensors.

Instead they compute:

| Operation | Scalable form |
| --- | --- |
| Hessian-vector product | $Hv$ |
| Jacobian-vector product | $Jv$ |
| vector-Hessian-vector | $v^\top H v$ |
| directional kth derivative | scalar directional expansion |

Operator forms scale better than explicit tensors.
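
For example (a JAX sketch), the quadratic form $v^\top H v$ can be computed as the nested directional derivative $\frac{d^2}{dt^2} f(x + tv)$ at $t = 0$ using forward-over-forward, with no $n \times n$ tensor ever built:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.exp(x) * x)

x = jnp.arange(3.0)
v = jnp.ones(3)

# First directional derivative along v, as a scalar function of t.
def dir1(t):
    return jax.jvp(f, (x + t * v,), (v,))[1]

# Differentiate once more in t: forward-over-forward, giving v^T H v.
vHv = jax.jvp(dir1, (0.0,), (1.0,))[1]
print(vHv)
print(v @ jax.hessian(f)(x) @ v)   # reference via the explicit Hessian
```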

Practical Design Principles

A robust nested AD system should:

| Principle | Reason |
| --- | --- |
| separate derivative levels | avoid perturbation confusion |
| represent derivatives explicitly | improve correctness |
| isolate tapes per level | prevent state corruption |
| distinguish primal/tangent/cotangent data | preserve semantics |
| expose operator APIs | avoid tensor explosion |
| support compositional transforms | enable higher-order programming |

Nested AD is fundamentally about composing derivative transformations safely and predictably.

Conceptual View

Automatic differentiation is often introduced as a technique for computing gradients.

Nested AD reveals a deeper interpretation.

Differentiation becomes a compositional program transform:

$$ \mathcal{D}(\mathcal{D}(f)), \quad \mathcal{D}(\mathcal{D}(\mathcal{D}(f))), \quad \ldots $$

The challenge is no longer merely computing derivatives. The challenge is preserving semantic structure across multiple interacting derivative levels while controlling memory, complexity, and numerical stability.