Higher-Order Reverse Mode

Reverse mode is efficient for scalar-output functions because it propagates one adjoint backward through the computation and produces a full gradient. For

f : \mathbb{R}^n \to \mathbb{R},

one reverse pass computes

\nabla f(x)

at a cost comparable to a small constant multiple of evaluating f.

Higher-order reverse mode asks for derivatives of reverse-mode derivative computations. The simplest case is differentiating the gradient:

D(\nabla f)(x) = \nabla^2 f(x).

The idea is mathematically clean. The implementation is much harder.

First-Order Reverse Mode Recap

A program computes intermediate values:

v_1, v_2, \ldots, v_k.

The final output is

y = v_k.

Reverse mode associates each intermediate value v_i with an adjoint:

\bar{v}_i = \frac{\partial y}{\partial v_i}.

The backward pass starts with

\bar{y} = 1.

Then it propagates adjoints backward through local derivative rules.

For example, if

z = xy,

then reverse mode applies

\bar{x} \mathrel{+}= \bar{z}y, \qquad \bar{y} \mathrel{+}= \bar{z}x.

This produces first derivatives.
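The recap above can be written out directly. The following is a minimal sketch of one forward and one backward sweep for f(x, y) = sin(x·y); the variable names (`z_bar`, `x_bar`, and so on) are illustrative, not any particular library's API.

```python
import math

# Forward pass: compute and keep the intermediate values.
x, y = 3.0, 4.0
z = x * y            # v1
out = math.sin(z)    # v2, the scalar output

# Backward pass: seed the output adjoint, then apply local rules in reverse.
out_bar = 1.0                    # adjoint of the output w.r.t. itself
z_bar = out_bar * math.cos(z)    # rule for sin: z_bar += out_bar * cos(z)
x_bar = z_bar * y                # rule for z = x*y: x_bar += z_bar * y
y_bar = z_bar * x                #                   y_bar += z_bar * x

# One backward sweep yields the full gradient (d f/dx, d f/dy).
```

Note that the single seed `out_bar = 1.0` was enough to recover both partial derivatives; this is the efficiency claim from the opening paragraph.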

Differentiating the Backward Pass

Higher-order reverse mode differentiates these backward computations.

For

z = xy,

the backward rules are:

\bar{x} \mathrel{+}= \bar{z}y, \qquad \bar{y} \mathrel{+}= \bar{z}x.

These rules are themselves programs. They use multiplication, addition, saved primal values, and adjoints. If we differentiate this backward program, we obtain second-order information.

This means a higher-order reverse system must treat the backward pass as differentiable code, not as an opaque implementation detail.

Reverse-over-Reverse

Reverse-over-reverse means applying reverse mode to a computation that was itself produced by reverse mode.

Conceptually:

g = grad(f)
H = jacobian(g)

If the second Jacobian is computed by reverse mode, the system is doing reverse-over-reverse.

This can compute second derivatives, but it creates several problems.

The first reverse pass records the primal computation. The second reverse pass may need to record the backward computation. This may include adjoint accumulation, tape reads, saved intermediates, and control flow in derivative rules.

The resulting computation can be much larger than the original one.
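One way to make this concrete is a toy reverse-mode engine whose adjoint arithmetic goes through the same overloaded operators it differentiates, so an outer tape records the inner backward pass. This is a minimal sketch, not a production design; `Var` and `grad` are illustrative names.

```python
class Var:
    def __init__(self, value, tape):
        self.value = value   # a float, or a Var from an enclosing tape
        self.adj = 0.0
        self.tape = tape

    def _record(self, out, back):
        self.tape.append((out, back))
        return out

    def __mul__(self, other):
        if isinstance(other, Var) and other.tape is self.tape:
            out = Var(self.value * other.value, self.tape)
            def back(g, a=self, b=other):
                a.adj = a.adj + g * b.value
                b.adj = b.adj + g * a.value
        else:  # constant w.r.t. this tape (a float, or a Var of an outer tape)
            out = Var(self.value * other, self.tape)
            def back(g, a=self, c=other):
                a.adj = a.adj + g * c
        return self._record(out, back)

    __rmul__ = __mul__

    def __add__(self, other):
        if isinstance(other, Var) and other.tape is self.tape:
            out = Var(self.value + other.value, self.tape)
            def back(g, a=self, b=other):
                a.adj = a.adj + g
                b.adj = b.adj + g
        else:
            out = Var(self.value + other, self.tape)
            def back(g, a=self):
                a.adj = a.adj + g
        return self._record(out, back)

    __radd__ = __add__

def grad(f):
    def df(x):
        tape = []
        xv = Var(x, tape)
        out = f(xv)
        out.adj = 1.0
        for node, back in reversed(tape):   # backward sweep
            back(node.adj)
        return xv.adj
    return df

f = lambda x: x * x * x   # f(x) = x^3, f'(x) = 3x^2, f''(x) = 6x
```

Because adjoint updates like `a.adj + g * b.value` are themselves `Var` operations when the values come from an enclosing tape, `grad(grad(f))` records the inner backward pass on the outer tape and differentiates it, exactly as the text describes.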

Tapes Become Part of the Program

Many reverse-mode systems use a tape. The tape records operations from the forward pass so that the backward pass can replay them in reverse.

For first-order AD, the tape is an implementation device.

For higher-order reverse mode, the tape can become part of the differentiated computation.

This creates hard questions:

tape allocation: allocation behavior may affect differentiated execution
mutation: adjoints are often accumulated destructively
aliasing: multiple references may point to the same storage
saved values: higher-order rules may need more saved intermediates
custom gradients: first-order custom rules may not have valid higher derivatives
control flow: backward control flow must remain differentiable

A clean higher-order system separates mathematical derivative semantics from tape mechanics.

Mutation and Adjoint Accumulation

Reverse mode commonly updates adjoints by mutation:

x_bar += z_bar * y
y_bar += z_bar * x

This is natural and efficient.

But higher-order differentiation must account for how these updates depend on primal values and incoming adjoints. Mutation introduces ordering and aliasing concerns.

For example, if two paths contribute to the same adjoint, reverse mode sums them. The sum is mathematically commutative, but floating point mutation has an order. Higher-order differentiation can expose this ordering when numerical reproducibility matters.
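The ordering effect is easy to demonstrate: floating-point addition is not associative, so two accumulation orders for the same adjoint contributions can differ in the last bit.

```python
# Three paths contribute 0.1, 0.2, 0.3 to one adjoint. The mathematical sum
# is order-independent; the floating-point sum is not.
left_to_right = (0.1 + 0.2) + 0.3   # accumulate in program order
right_to_left = 0.1 + (0.2 + 0.3)   # accumulate in the reverse order

print(left_to_right == right_to_left)   # False: the results differ in the low bits
```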

A robust implementation may use an intermediate representation where adjoint accumulation is explicit and analyzable.

Saved Primal Values

Reverse-mode rules often need primal values from the forward pass.

For

z = \sin x,

the backward rule is

\bar{x} \mathrel{+}= \bar{z}\cos x.

The backward pass needs x, or some equivalent saved value.

For second derivatives, differentiating the backward rule also needs the derivative of \cos x:

d(\bar{x}) = d(\bar{z})\cos x - \bar{z}\sin x\,dx.

So a second-order system must preserve enough information to differentiate the backward rule itself.

If a first-order system saves too little, higher-order differentiation may be impossible or incorrect.
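The differentiated backward rule can be checked numerically. The sketch below treats the backward rule for sin as a function of the saved x and verifies that its derivative is the -z̄·sin x term from the equation above; the names are illustrative.

```python
import math

x, z_bar = 1.0, 2.0   # saved primal value and incoming adjoint

def backward_rule(x):
    # First-order rule for z = sin(x): contribution to x_bar is z_bar * cos(x).
    return z_bar * math.cos(x)

# Differentiating the backward rule w.r.t. x (the dx term in the text):
d_backward_dx = -z_bar * math.sin(x)

# Central-difference check that the analytic derivative matches.
h = 1e-6
estimate = (backward_rule(x + h) - backward_rule(x - h)) / (2 * h)
```

The check only works because x was saved; a system that discarded x after computing cos(x) once could still run the first-order rule, but not differentiate it.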

Custom Gradients

Many AD systems allow custom first-order gradients.

For example, a user may define:

forward: y = op(x)
backward: x_bar = custom_rule(x, y, y_bar)

This is enough for first-order AD.

For higher-order AD, the custom backward rule must also be differentiable and mathematically correct.

A custom rule can be first-order correct but second-order wrong.

For example, a rule may stop gradients through an intermediate value for numerical reasons. That may preserve the first derivative while destroying the second derivative.

Therefore, production systems should distinguish:

first-order custom gradient: valid for gradients only
higher-order custom gradient: valid under nested AD
non-differentiable custom gradient: blocks higher-order differentiation
symbolic derivative rule: supplies explicit higher-order behavior

This distinction prevents silent errors.
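A toy version of the failure mode: a hypothetical custom backward rule for y = sin(x) that freezes cos(x) at the point of definition (a stop-gradient, say for numerical reasons). It agrees with the true rule at that point, so first derivatives are correct there, but its own derivative is zero, so second derivatives vanish.

```python
import math

x0 = 1.0
frozen_cos = math.cos(x0)   # stop-gradient: treated as a constant w.r.t. x

def custom_backward(x, y_bar=1.0):
    return y_bar * frozen_cos          # first-order correct at x0 only

def true_backward(x, y_bar=1.0):
    return y_bar * math.cos(x)         # differentiable backward rule

# Second derivative = derivative of the backward rule, estimated numerically.
h = 1e-6
second_custom = (custom_backward(x0 + h) - custom_backward(x0 - h)) / (2 * h)
second_true = (true_backward(x0 + h) - true_backward(x0 - h)) / (2 * h)
# second_custom is exactly 0.0; second_true approximates -sin(x0).
```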

Reverse Mode for Vector Outputs

For

F : \mathbb{R}^n \to \mathbb{R}^m,

reverse mode computes vector-Jacobian products:

w^\top J_F(x).

Higher-order reverse mode can differentiate these products.

If

\phi(x) = w^\top F(x),

then reverse mode computes

\nabla \phi(x) = J_F(x)^\top w.

Differentiating this gradient gives

\nabla^2 \phi(x).

So higher-order reverse mode naturally handles scalarizations of vector-output functions.

This is important in machine learning, where losses are scalar but models are vector-valued internally.
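The scalarization is easy to see on a toy example. The sketch below picks an illustrative F : R² → R² and covector w, reads the Jacobian off by hand, and forms ∇φ = Jᵀw.

```python
def F(x0, x1):
    # F : R^2 -> R^2, F(x) = (x0*x1, x0 + x1)
    return (x0 * x1, x0 + x1)

w = (1.0, 2.0)   # covector scalarizing F: phi(x) = w^T F(x)

def grad_phi(x0, x1):
    # J_F = [[x1, x0], [1, 1]], read off from F; the gradient is J^T w.
    return (w[0] * x1 + w[1] * 1.0,
            w[0] * x0 + w[1] * 1.0)

# At (3, 4): phi = x0*x1 + 2*(x0 + x1), so grad phi = (x1 + 2, x0 + 2) = (6, 5).
```

Differentiating `grad_phi` once more would give the ∇²φ(x) of the text.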

Reverse Mode and Hessian-Vector Products

Higher-order reverse mode can compute Hessian-vector products, but pure reverse-over-reverse is often not the best route.

For scalar f, a Hessian-vector product can be computed as:

H_f(x)v = \nabla(\nabla f(x)^\top v).

This uses reverse mode on the scalar function

\phi(x) = \nabla f(x)^\top v.

The inner gradient usually comes from reverse mode. The outer gradient also uses reverse mode. That is reverse-over-reverse.

It works, but may require differentiating the backward pass. Forward-over-reverse often avoids some of this complexity by pushing a tangent through the gradient computation instead.
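Forward-over-reverse can be sketched with dual numbers: take a reverse-mode gradient (hand-derived here, so the example stays self-contained) and push a tangent v through it; the tangent components of the gradient are H·v. The `Dual` class and function names are illustrative.

```python
class Dual:
    """Forward-mode dual number: carries a value and a tangent."""
    def __init__(self, val, tan):
        self.val, self.tan = val, tan
    def __mul__(self, o):
        o = o if isinstance(o, Dual) else Dual(o, 0.0)
        return Dual(self.val * o.val, self.val * o.tan + self.tan * o.val)
    __rmul__ = __mul__

def grad_f(x0, x1):
    # Reverse-mode gradient of f(x0, x1) = x0^2 * x1, derived by hand:
    # grad f = (2*x0*x1, x0^2).
    return (2.0 * x0 * x1, x0 * x0)

# Push the tangent v = (1, 1) through the gradient computation.
x0, x1 = Dual(3.0, 1.0), Dual(4.0, 1.0)
g0, g1 = grad_f(x0, x1)
# H = [[2*x1, 2*x0], [2*x0, 0]], so H*v = (2*x1 + 2*x0, 2*x0) = (14, 6) here.
hvp = (g0.tan, g1.tan)
```

No backward pass is differentiated: the tangent rides along through the gradient computation, which is why this route often avoids the reverse-over-reverse complexity above.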

Complexity and Memory

Higher-order reverse mode can be expensive because it nests derivative computations.

For first-order reverse mode, memory is dominated by saved intermediates.

For higher-order reverse mode, memory may include:

primal tape: operations from the original forward pass
backward tape: operations from the derivative computation
saved primal values: values needed by the first backward pass
saved adjoint values: values needed by the differentiated backward pass
nested AD metadata: tags, levels, tangent or adjoint structures
temporary derivative arrays: intermediate higher-order values

The exact cost depends on the program and implementation, but the key point holds regardless: higher-order reverse mode can consume much more memory than first-order reverse mode.

Perturbation and Cotangent Levels

Nested AD needs distinct derivative levels.

In forward mode, this prevents perturbation confusion. In reverse mode, the analogous issue concerns adjoint levels.

A nested AD system must know which derivative level each tangent or cotangent belongs to.

Otherwise, an inner derivative computation may accidentally consume or modify derivative information intended for an outer computation.

Correct systems usually track derivative levels explicitly:

level 0: primal computation
level 1: first derivative
level 2: second derivative

This bookkeeping is not cosmetic. It is part of the semantics of nested differentiation.
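A small forward-mode sketch shows what level tracking buys. Each `derivative` call mints a fresh tag, and arithmetic treats a number tagged for a different level as a constant, so an inner derivative cannot consume an outer perturbation. All names here are illustrative.

```python
class Tagged:
    """A dual number tagged with the derivative level it belongs to."""
    def __init__(self, val, tan, tag):
        self.val, self.tan, self.tag = val, tan, tag

def add(a, b):
    if not isinstance(a, Tagged) and not isinstance(b, Tagged):
        return a + b
    if isinstance(a, Tagged) and isinstance(b, Tagged) and a.tag == b.tag:
        return Tagged(add(a.val, b.val), add(a.tan, b.tan), a.tag)
    if not isinstance(a, Tagged) or (isinstance(b, Tagged) and b.tag > a.tag):
        a, b = b, a   # make `a` the one varying at the innermost level
    return Tagged(add(a.val, b), a.tan, a.tag)   # `b` is constant at this level

def mul(a, b):
    if not isinstance(a, Tagged) and not isinstance(b, Tagged):
        return a * b
    if isinstance(a, Tagged) and isinstance(b, Tagged) and a.tag == b.tag:
        return Tagged(mul(a.val, b.val),
                      add(mul(a.val, b.tan), mul(a.tan, b.val)), a.tag)
    if not isinstance(a, Tagged) or (isinstance(b, Tagged) and b.tag > a.tag):
        a, b = b, a
    return Tagged(mul(a.val, b), mul(a.tan, b), a.tag)

_next_tag = [0]

def derivative(f, x):
    _next_tag[0] += 1
    tag = _next_tag[0]                  # a fresh level for this call
    out = f(Tagged(x, 1.0, tag))
    _next_tag[0] -= 1
    if isinstance(out, Tagged) and out.tag == tag:
        return out.tan
    return 0.0                          # output did not depend on this level

# Nested use: d/dx [ x * (d/dy of x*y) ] = d/dx [x * x] = 2x, so 4.0 at x = 2.
result = derivative(lambda x: mul(x, derivative(lambda y: mul(x, y), 3.0)), 2.0)
```

Without the tag comparison, the inner derivative would mix the outer x's tangent into its own sweep, which is exactly the confusion the text warns about.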

Checkpointing for Higher-Order Reverse

Checkpointing trades recomputation for memory.

In first-order reverse mode, checkpointing avoids saving every intermediate value. During the backward pass, some values are recomputed as needed.

In higher-order reverse mode, checkpointing becomes more complicated because recomputation itself may occur inside a differentiated computation.

A checkpoint must specify:

what is saved: primal values, derivative values, or both
what is recomputed: forward code, backward code, or nested code
at which AD level: the derivative level must remain consistent
with what side effects: recomputation must preserve semantics

Checkpointing remains essential, but the implementation must be level-aware.
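A first-order sketch of the trade: for a chain of sin applications, save only every K-th intermediate and recompute the rest segment by segment during the backward sweep. The segmentation scheme and names are illustrative.

```python
import math

N, K = 12, 4   # chain length and checkpoint spacing

def forward_with_checkpoints(x):
    checkpoints = {}
    for i in range(N):
        if i % K == 0:
            checkpoints[i] = x   # save segment boundaries only
        x = math.sin(x)
    return x, checkpoints

def backward(x0):
    _, checkpoints = forward_with_checkpoints(x0)
    adj = 1.0
    for seg_start in range(((N - 1) // K) * K, -1, -K):
        # Recompute this segment's intermediates from its checkpoint.
        vals, v = [], checkpoints[seg_start]
        for _ in range(seg_start, min(seg_start + K, N)):
            vals.append(v)
            v = math.sin(v)
        # Apply the chain rule through the segment in reverse.
        for v in reversed(vals):
            adj = adj * math.cos(v)
    return adj

def backward_full(x0):
    # Full-storage reference: keep every intermediate, then sweep once.
    adj, v, vals = 1.0, x0, []
    for _ in range(N):
        vals.append(v)
        v = math.sin(v)
    for v in reversed(vals):
        adj = adj * math.cos(v)
    return adj
```

Memory drops from N saved values to N/K checkpoints plus one segment; the price is one extra forward recomputation per segment. The level-consistency questions above arise when `backward` itself sits inside another differentiation.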

Implementation Strategy

A practical higher-order reverse system usually needs a disciplined internal representation.

Good designs tend to make these objects explicit:

primal value
tangent value
cotangent value
AD level
saved residuals
linearized function
transpose rule

The system should avoid hiding derivative semantics inside opaque mutation-heavy runtime code.

One useful design is to split differentiation into two phases:

linearize(f, x) -> y, pullback
transpose(linearized_program) -> reverse program

Then higher-order differentiation can operate on the linearized program representation rather than on ad hoc tape operations.
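A sketch of the two-phase split, with the linear map materialized as a Jacobian for clarity (a real system would build it from JVP rules rather than finite differences; the representation here is only illustrative):

```python
def linearize(f, x, h=1e-6):
    # Run the primal and build the tangent map at x. Here the linear map is
    # materialized column by column via central differences.
    y = f(x)
    jac = []   # jac[j][i] = d f_i / d x_j
    for j in range(len(x)):
        xp, xm = list(x), list(x)
        xp[j] += h
        xm[j] -= h
        jac.append([(fp - fm) / (2 * h) for fp, fm in zip(f(xp), f(xm))])

    def jvp(v):
        # The linearized program: v -> J v.
        return [sum(jac[j][i] * v[j] for j in range(len(x)))
                for i in range(len(y))]
    return y, jvp, jac

def transpose(jac):
    # Transposing the linear map yields the reverse program: w -> J^T w.
    def vjp(w):
        return [sum(jac[j][i] * w[i] for i in range(len(jac[0])))
                for j in range(len(jac))]
    return vjp

F = lambda x: [x[0] * x[1], x[0] + x[1]]
y, jvp, jac = linearize(F, [3.0, 4.0])
vjp = transpose(jac)
# J = [[4, 3], [1, 1]] at (3, 4); with w = (1, 2), J^T w = (6, 5).
cotangent = vjp([1.0, 2.0])
```

Because the backward program is derived by transposing an explicit linear object rather than replaying an opaque tape, differentiating it again is just differentiating another well-defined program.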

Practical Guidance

Higher-order reverse mode is powerful, but it should be used with care.

Use it when:

differentiating optimization procedures: outer gradients require gradients of gradients
meta-learning: training rules are themselves differentiated
implicit layers: derivatives of solver outputs are needed
curvature analysis: second-order information is required
differentiable programming languages: nested AD is part of the language model

Avoid treating it as a default replacement for simpler methods.

For Hessian-vector products, prefer forward-over-reverse when available. For full Hessians, prefer structured or sparse methods when possible. For higher-order derivatives beyond second order, consider Taylor mode or specialized higher-order representations.

Design Principle

First-order reverse mode can be implemented as a runtime technique.

Higher-order reverse mode needs semantic discipline.

The backward pass must be a differentiable program with clear rules for values, adjoints, mutation, saved residuals, and derivative levels. Without that structure, nested reverse mode becomes fragile, memory-heavy, and prone to silent mathematical errors.