# Nested AD

## Overview

Nested automatic differentiation means applying automatic differentiation inside another automatic differentiation computation.

Conceptually, this amounts to differentiating derivative computations.

Examples include:

| Expression | Meaning |
|---|---|
| $\nabla(\nabla f)$ | Hessian |
| $D(Df)$ | second directional derivative |
| $\nabla(Df[v])$ | Hessian-vector product |
| $D(\nabla f)$ | Jacobian of gradient |
| $\nabla(\text{optimization step})$ | meta-gradient |

Nested AD is essential for higher-order derivatives, implicit differentiation, meta-learning, differentiable optimization, and differentiable programming systems.

The mathematics is straightforward. The implementation is subtle.

## First-Order AD as a Transformation

An AD transform maps a program into another program.

If

$$
f : X \to Y,
$$

then forward mode produces a transformed program:

$$
Df : TX \to TY,
$$

where $T$ represents tangent information.

Reverse mode produces a program that computes pullbacks at each primal point:

$$
f^\ast : T^\ast Y \to T^\ast X.
$$

Nested AD applies these transformations repeatedly.

For example:

$$
D(Df)
$$

means applying forward-mode transformation twice.

Similarly,

$$
\nabla(\nabla f)
$$

means applying reverse mode to a reverse-mode derivative computation.

The resulting program has multiple derivative levels active simultaneously.
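
In JAX, for instance, these transforms are ordinary higher-order functions, so nesting is literally function composition. A minimal sketch using the standard `jax.jvp`, `jax.vjp`, and `jax.grad` APIs:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x

x, v = 1.0, 1.0

# Forward mode, Df : TX -> TY: push a tangent v through f.
y, dy = jax.jvp(f, (x,), (v,))

# Reverse mode, f* : T*Y -> T*X: pull a cotangent back through f.
y, pullback = jax.vjp(f, x)
(xbar,) = pullback(1.0)

# Nesting is composition of the transforms themselves.
d2f = jax.grad(jax.grad(f))

print(dy, xbar, d2f(x))
```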

## Example: Forward-over-Forward

Consider

$$
f(x) = x^3.
$$

First forward mode computes:

$$
f(x + \epsilon v) =
x^3 + 3x^2v\,\epsilon.
$$

The tangent coefficient is:

$$
Df(x)[v] = 3x^2v.
$$

Apply forward mode again with another perturbation:

$$
x + \epsilon_1 v + \epsilon_2 w + \epsilon_1\epsilon_2 u.
$$

Now mixed infinitesimal terms appear:

$$
f(x + \epsilon_1 v + \epsilon_2 w + \epsilon_1\epsilon_2 u)
=
x^3
+
3x^2v\,\epsilon_1
+
3x^2w\,\epsilon_2
+
(3x^2u + 6xvw)\,\epsilon_1\epsilon_2.
$$

The coefficient of $\epsilon_1\epsilon_2$ contains second-order information: setting $u = 0$ leaves $6xvw = f''(x)\,vw$, the second directional derivative.

Nested forward mode therefore naturally produces higher derivatives.
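
The same computation as code: a sketch with nested `jax.jvp` calls, corresponding to the expansion above with $u = 0$:

```python
import jax

def f(x):
    return x ** 3

# First level: Df(x)[v] via forward mode.
def df(x, v):
    return jax.jvp(f, (x,), (v,))[1]

# Second level: differentiate df itself in direction w.
def d2f(x, v, w):
    return jax.jvp(lambda t: df(t, v), (x,), (w,))[1]

x, v, w = 2.0, 1.0, 1.0
print(df(x, v))      # 3 x^2 v = 12.0
print(d2f(x, v, w))  # 6 x v w = 12.0, the eps1*eps2 coefficient
```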

## Dual Number Nesting

Forward mode is commonly implemented using dual numbers:

$$
a + b\epsilon,
\quad
\epsilon^2 = 0.
$$

Nested forward mode uses nested dual structures:

$$
(a + b\epsilon_1)
+
(c + d\epsilon_1)\epsilon_2.
$$

Expanding gives:

$$
a
+
b\epsilon_1
+
c\epsilon_2
+
d\epsilon_1\epsilon_2.
$$

Each infinitesimal direction corresponds to a derivative level.

The mixed term

$$
d\epsilon_1\epsilon_2
$$

encodes second-order interaction.

This construction generalizes to arbitrary derivative order, but the algebra grows rapidly.
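
A minimal untagged dual-number sketch in Python makes the nesting concrete (illustration only: just `+` and `*`, no tag safety):

```python
# Untagged dual numbers: a + b*eps with eps^2 = 0.
class Dual:
    def __init__(self, a, b):
        self.a, self.b = a, b

    def __add__(self, o):
        o = o if isinstance(o, Dual) else Dual(o, 0.0)
        return Dual(self.a + o.a, self.b + o.b)
    __radd__ = __add__

    def __mul__(self, o):
        o = o if isinstance(o, Dual) else Dual(o, 0.0)
        # (a + b eps)(c + d eps) = ac + (ad + bc) eps
        return Dual(self.a * o.a, self.a * o.b + self.b * o.a)
    __rmul__ = __mul__

def f(x):
    return x * x * x

# Nested duals encode x + v*eps1 + w*eps2 with v = w = 1, u = 0.
x = Dual(Dual(2.0, 1.0), Dual(1.0, 0.0))
y = f(x)
print(y.a.a)  # x^3        =  8.0
print(y.a.b)  # f'(x) v    = 12.0  (eps1 coefficient)
print(y.b.a)  # f'(x) w    = 12.0  (eps2 coefficient)
print(y.b.b)  # f''(x) v w = 12.0  (eps1*eps2 coefficient)
```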

## Reverse-over-Reverse Nesting

Nested reverse mode is more complicated.

A reverse-mode computation creates adjoints and backward passes. Differentiating that computation introduces adjoints of adjoints.

Suppose:

```text
g = grad(f)   # first reverse pass
h = grad(g)   # differentiates that reverse pass (scalar x)
```

The second gradient differentiates the reverse pass used by the first gradient.

This means:

1. The original primal computation must remain differentiable.
2. The backward pass must itself behave like differentiable code.
3. Adjoint accumulation logic becomes part of the differentiated computation.

The system must distinguish derivative levels carefully.
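
In JAX this is exactly what composing `jax.grad` with itself does; a small sketch:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x

df = jax.grad(f)    # builds and runs a reverse pass
d2f = jax.grad(df)  # differentiates that reverse pass itself

print(d2f(1.0))     # 2 cos(1) - 1 sin(1) ~= 0.239
```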

## Mixed-Mode Nesting

Many practical systems use mixed-mode nesting.

Examples include:

| Nesting | Purpose |
|---|---|
| forward-over-reverse | Hessian-vector products |
| reverse-over-forward | gradients of directional derivatives |
| reverse-over-reverse | higher-order scalar derivatives |
| forward-over-forward | Taylor coefficients |
| reverse-over-forward-over-reverse | advanced implicit differentiation |

Mixed mode is often preferable because different AD modes have complementary strengths.

As a rule of thumb:

| Mode | Efficient regime |
|---|---|
| forward mode | few inputs (small input dimension) |
| reverse mode | few outputs (e.g., a scalar loss) |

Nested systems combine these strengths.
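
For example, forward-over-reverse yields Hessian-vector products without materializing the Hessian; a sketch in JAX:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 3)

# Forward-over-reverse: push direction v through the reverse-mode gradient.
def hvp(f, x, v):
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

x = jnp.array([1.0, 2.0])
v = jnp.array([1.0, 0.0])
print(hvp(f, x, v))  # Hessian is diag(6x), so Hv = [6., 0.]
```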

## Perturbation Confusion

The most famous nested-AD failure mode is perturbation confusion.

Suppose two forward-mode computations accidentally share the same infinitesimal symbol:

$$
\epsilon.
$$

Then derivative information from different levels mixes incorrectly.

For example:

$$
(a + b\epsilon) + (c + d\epsilon) =
(a+c) + (b+d)\epsilon.
$$

If these infinitesimals were meant to represent different derivative levels, the result is wrong.

The problem becomes severe when derivative computations are passed through higher-order functions or closures.
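
The classic demonstration, often attributed to Siskind and Pearlmutter, fits in a few lines with an untagged dual-number `deriv` (a toy operator, not a library API):

```python
# Untagged dual numbers, as in the earlier sketch.
class Dual:
    def __init__(self, a, b):
        self.a, self.b = a, b

    def __add__(self, o):
        o = o if isinstance(o, Dual) else Dual(o, 0.0)
        return Dual(self.a + o.a, self.b + o.b)
    __radd__ = __add__

    def __mul__(self, o):
        o = o if isinstance(o, Dual) else Dual(o, 0.0)
        return Dual(self.a * o.a, self.a * o.b + self.b * o.a)
    __rmul__ = __mul__

def deriv(f, x):
    out = f(Dual(x, 1.0))  # every call reuses the SAME epsilon
    return out.b if isinstance(out, Dual) else 0.0

# d/dx [ x * d/dy (x + y) ] at x = 1 should be 1,
# because the inner derivative is the constant 1.
print(deriv(lambda x: x * deriv(lambda y: x + y, 1.0), 1.0))  # prints 2.0
```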

## Tagging Derivative Levels

Correct nested forward mode assigns a unique tag to each perturbation level.

Instead of one universal infinitesimal:

$$
\epsilon,
$$

the system uses:

$$
\epsilon_1,
\epsilon_2,
\epsilon_3,
\ldots
$$

with independent nilpotent behavior.

Then:

$$
\epsilon_i \epsilon_j \ne 0
\quad \text{for } i \ne j,
$$

while

$$
\epsilon_i^2 = 0.
$$

Each derivative transform introduces a fresh perturbation identity.

Implementation-wise, systems usually represent this with:

| Technique | Description |
|---|---|
| unique IDs | each AD transform gets a fresh perturbation tag |
| lexical scoping | perturbations scoped to derivative level |
| type-level tagging | derivative levels encoded in types |
| runtime tagging | tags stored dynamically |
| staged transforms | derivative levels separated during compilation |

Without tagging, nested forward mode is unreliable.
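
A runtime-tagging sketch of the fix, extending the toy dual numbers above. The two essential pieces are a fresh tag per transform and a `tangent` extraction that recurses through values belonging to other levels:

```python
import itertools

_fresh = itertools.count()  # one fresh tag per derivative transform

class TDual:
    def __init__(self, a, b, tag):
        self.a, self.b, self.tag = a, b, tag  # a + b * eps_tag

    def _lift(self, o):
        # Values from other levels are constants with respect to this tag.
        if isinstance(o, TDual) and o.tag == self.tag:
            return o
        return TDual(o, 0.0, self.tag)

    def __add__(self, o):
        o = self._lift(o)
        return TDual(self.a + o.a, self.b + o.b, self.tag)
    __radd__ = __add__

    def __mul__(self, o):
        o = self._lift(o)
        return TDual(self.a * o.a, self.a * o.b + self.b * o.a, self.tag)
    __rmul__ = __mul__

def tangent(v, tag):
    # Extract the eps_tag coefficient, recursing through other levels.
    if isinstance(v, TDual):
        if v.tag == tag:
            return v.b
        return TDual(tangent(v.a, tag), tangent(v.b, tag), v.tag)
    return 0.0

def deriv(f, x):
    tag = next(_fresh)  # fresh perturbation identity
    return tangent(f(TDual(x, 1.0, tag)), tag)

# The confused example now gives the correct answer.
print(deriv(lambda x: x * deriv(lambda y: x + y, 1.0), 1.0))  # prints 1.0
```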

## Cotangent Level Separation

Reverse mode has a similar issue.

Adjoints from different derivative levels must remain separate.

Suppose:

```text
outer gradient
    inner gradient
```

The inner reverse pass should not accidentally consume cotangents from the outer reverse pass.

A correct nested reverse system tracks cotangent levels explicitly.

Conceptually:

| Level | Meaning |
|---|---|
| primal level | original computation |
| tangent/cotangent level 1 | first derivative |
| tangent/cotangent level 2 | second derivative |
| tangent/cotangent level k | kth derivative |

This separation is part of the semantics, not merely debugging metadata.
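
In a system with correct level separation, such nesting simply works; a JAX sketch in which an outer gradient flows through an inner gradient:

```python
import jax

def inner_loss(x, y):
    return (x * y) ** 2

def outer(x):
    # Inner reverse pass: cotangents at the inner level.
    g = jax.grad(inner_loss, argnums=1)(x, 3.0)  # 2 x^2 y at y = 3
    return x * g                                 # 6 x^3

# Outer reverse pass: its cotangents stay separate from the inner ones.
print(jax.grad(outer)(2.0))  # 18 x^2 = 72.0
```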

## Closures and Higher-Order Functions

Nested AD becomes harder in languages with closures and higher-order functions.

Example:

```text
def outer(x):
    def inner(y):
        return x * y
    return grad(inner)
```

The derivative transform captures free variables from the outer scope.
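
A runnable version of this sketch using `jax.grad`; the captured `x` is a constant for the inner transform, yet may still carry derivative information from an outer one:

```python
import jax

def outer(x):
    def inner(y):
        return x * y        # x is a captured free variable
    return jax.grad(inner)  # differentiate with respect to y only

print(outer(3.0)(2.0))  # d/dy (x*y) = x = 3.0

# x itself may live at an outer derivative level:
print(jax.grad(lambda x: outer(x)(2.0))(3.0))  # d/dx x = 1.0
```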

The AD system must decide:

| Question | Requirement |
|---|---|
| which values are primal | original computation |
| which values are tangent | derivative level |
| which values are cotangent | reverse accumulation |
| which closures capture derivative state | nested differentiation |
| which tapes belong to which level | nesting correctness |

Functional languages often model this more cleanly because closures are explicit semantic objects.

## Dynamic Computation Graphs

Nested AD interacts strongly with dynamic graph systems.

If the computation graph depends on runtime control flow, the graph structure itself may differ between derivative levels.

For example:

```text
if grad(f)(x) > 0:
    ...
```

The derivative computation influences the primal control flow.

A nested AD system must specify:

1. Which computations are traced.
2. Which branches are differentiated.
3. Whether graph structure is static or dynamic.
4. Whether nested derivative traces are compositional.

Different frameworks make different choices.
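
A sketch of derivative-dependent control flow; this runs eagerly, whereas a tracing compiler such as `jax.jit` would reject a concrete branch on a traced value:

```python
import jax

def f(x):
    return x ** 2

def step(x):
    # Primal control flow depends on a derivative computation.
    if jax.grad(f)(x) > 0:
        return x - 0.1
    return x + 0.1

print(step(1.0))  # grad(f)(1.0) = 2.0 > 0, so the first branch runs: 0.9
```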

## Differentiating Optimizers

Nested AD is central in differentiable optimization.

Suppose gradient descent performs:

$$
x_{t+1} =
x_t - \eta \nabla f(x_t).
$$

Now suppose we want derivatives with respect to hyperparameters:

$$
\frac{\partial x_T}{\partial \eta}.
$$

The optimizer itself becomes part of the differentiated program.

This requires differentiating through:

| Object | Derivative target |
|---|---|
| gradients | first-order derivatives |
| update rules | optimization dynamics |
| momentum accumulators | optimizer state |
| learning rate schedules | hyperparameters |
| inner training loops | meta-learning |

Nested AD enables this.
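
For instance, a meta-gradient of a small gradient-descent loop with respect to $\eta$ (a sketch; the loop, including each inner reverse pass, becomes part of the differentiated program):

```python
import jax

def f(x):
    return (x - 1.0) ** 2

def run_gd(eta, x0=0.0, steps=5):
    x = x0
    for _ in range(steps):
        x = x - eta * jax.grad(f)(x)  # each step nests a reverse pass
    return x

# Meta-gradient: d x_T / d eta, by reverse mode over the whole loop.
print(jax.grad(run_gd)(0.1))
```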

## Implicit Differentiation

Nested AD also appears in implicit differentiation.

Suppose:

$$
g(x, y(x)) = 0.
$$

We may avoid differentiating every iteration of a solver by differentiating the solution condition $g(x, y) = 0$ itself.

Still, the resulting derivative computations often involve nested linearizations and reverse passes.

This is especially common in:

| Area | Example |
|---|---|
| meta-learning | differentiating equilibrium states |
| optimization layers | differentiating argmin operators |
| physics simulation | differentiating steady states |
| probabilistic inference | differentiating fixed-point solvers |

Nested AD provides the underlying machinery.
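
A scalar sketch: solve $g(x, y) = 0$ with a few Newton steps, then apply the implicit function theorem instead of differentiating the solver loop. The particular $g$ is only an illustrative choice:

```python
import jax

def g(x, y):
    return y ** 3 + y - x  # implicitly defines y(x)

def solve(x, y=0.0, iters=20):
    for _ in range(iters):  # plain Newton iterations
        y = y - g(x, y) / jax.grad(g, argnums=1)(x, y)
    return y

def dy_dx(x):
    y = solve(x)
    # Implicit function theorem: dy/dx = -(dg/dx) / (dg/dy).
    gx = jax.grad(g, argnums=0)(x, y)
    gy = jax.grad(g, argnums=1)(x, y)
    return -gx / gy

print(dy_dx(2.0))  # y(2) = 1, so dy/dx = 1/(3y^2 + 1) = 0.25
```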

## Tape Nesting

In tape-based reverse mode, nested AD introduces nested tapes.

Possible models include:

| Model | Description |
|---|---|
| independent tapes | each derivative level has separate tape |
| hierarchical tapes | tapes reference parent tapes |
| reentrant tapes | backward passes may themselves record operations |
| staged tapes | derivative levels separated during compilation |

Incorrect tape interaction can cause:

| Failure | Meaning |
|---|---|
| tape corruption | nested passes overwrite state |
| missing gradients | tape lifetime ends too early |
| duplicated gradients | nested replay occurs twice |
| memory explosion | all nested levels retained simultaneously |

Robust nested systems need explicit tape ownership semantics.
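
PyTorch's `create_graph=True` is a concrete instance of the reentrant model: the backward pass itself records operations, so its result can be differentiated again:

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x ** 3

# create_graph=True: the backward pass records onto the tape,
# making its output g differentiable in turn.
(g,) = torch.autograd.grad(y, x, create_graph=True)  # 3 x^2
(h,) = torch.autograd.grad(g, x)                     # 6 x
print(g.item(), h.item())  # 12.0 12.0
```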

## Compiler Perspective

A compiler-based AD system often handles nesting better than a runtime operator-overloading system.

Instead of dynamically stacking derivative objects, the compiler transforms intermediate representations explicitly:

```text
program
→ linearized program
→ transposed program
→ differentiated transposed program
```

This exposes derivative levels structurally.

Compiler IRs can annotate:

| IR annotation | Purpose |
|---|---|
| primal variable | original value |
| tangent variable | forward derivative |
| cotangent variable | reverse derivative |
| residual | saved intermediate |
| derivative level | nesting separation |

This structure reduces ambiguity.
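
In JAX, for example, the transformed IR can be inspected directly; both derivative levels appear as explicit program structure rather than as runtime objects:

```python
import jax

def f(x):
    return x ** 3

# The jaxpr of grad(grad(f)) is an ordinary program in which
# both derivative levels have already been expanded.
print(jax.make_jaxpr(jax.grad(jax.grad(f)))(2.0))
```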

## Complexity Explosion

Higher-order derivatives grow rapidly.

For dimension $n$:

| Derivative order | Tensor size |
|---|---:|
| gradient | $n$ |
| Hessian | $n^2$ |
| third derivative tensor | $n^3$ |
| fourth derivative tensor | $n^4$ |

Nested AD can therefore create storage and computation costs that grow exponentially with derivative order.

Most practical systems avoid materializing full higher-order tensors.

Instead they compute:

| Operation | Scalable form |
|---|---|
| Hessian-vector product | $Hv$ |
| Jacobian-vector product | $Jv$ |
| vector-Hessian-vector | $v^\top Hv$ |
| directional kth derivative | scalar directional expansion |

Operator forms scale better than explicit tensors.
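
The contrast in code (a sketch): `jax.hessian` materializes the full $n \times n$ matrix, while the operator form only ever touches length-$n$ vectors:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.sin(x) * x)

x = jnp.ones(1000)
v = jnp.ones(1000)

H = jax.hessian(f)(x)                     # O(n^2) storage
hv = jax.jvp(jax.grad(f), (x,), (v,))[1]  # O(n) storage
print(H.shape, hv.shape)                  # (1000, 1000) (1000,)
```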

## Practical Design Principles

A robust nested AD system should:

| Principle | Reason |
|---|---|
| separate derivative levels | avoid perturbation confusion |
| represent derivatives explicitly | improve correctness |
| isolate tapes per level | prevent state corruption |
| distinguish primal/tangent/cotangent data | preserve semantics |
| expose operator APIs | avoid tensor explosion |
| support compositional transforms | enable higher-order programming |

Nested AD is fundamentally about composing derivative transformations safely and predictably.

## Conceptual View

Automatic differentiation is often introduced as a technique for computing gradients.

Nested AD reveals a deeper interpretation.

Differentiation becomes a compositional program transform:

$$
\mathcal{D}(\mathcal{D}(f)),
\quad
\mathcal{D}(\mathcal{D}(\mathcal{D}(f))),
\quad \ldots
$$

The challenge is no longer merely computing derivatives. The challenge is preserving semantic structure across multiple interacting derivative levels while controlling memory, complexity, and numerical stability.

