AD as Program Transformation

Automatic differentiation can be understood as a transformation from one program into another program.

The original program computes values.

The transformed program computes values and derivatives.

This view is important because real AD systems operate on code, traces, graphs, or compiler intermediate representations. They are not merely applying calculus rules on paper. They rewrite computation.

From Mathematical Function to Program

A mathematical function may be written as:

$$f(x)=\sin(x^2+3x)$$

A program evaluates it as a sequence of operations:

v1 = x * x
v2 = 3 * x
v3 = v1 + v2
v4 = sin(v3)
return v4

Automatic differentiation transforms this program into another program.

For forward mode, the transformed program carries tangent values.

For reverse mode, the transformed program records operations and later propagates adjoints.

Forward-Mode Transformation

Forward mode rewrites every variable into a pair:

$$v \mapsto (v, \dot v)$$

The primal component stores the ordinary value.

The tangent component stores the directional derivative.

Original operation:

$$z = x + y$$

becomes:

$$(z, \dot z) = (x + y,\ \dot x + \dot y)$$

Original operation:

$$z = xy$$

becomes:

$$(z, \dot z) = (xy,\ y\dot x + x\dot y)$$

For the example:

v1 = x * x
v2 = 3 * x
v3 = v1 + v2
v4 = sin(v3)
return v4

the forward-transformed program is:

v1     = x * x
dot_v1 = x * dot_x + x * dot_x

v2     = 3 * x
dot_v2 = 3 * dot_x

v3     = v1 + v2
dot_v3 = dot_v1 + dot_v2

v4     = sin(v3)
dot_v4 = cos(v3) * dot_v3

return v4, dot_v4

The program still runs in the original order. The derivative computation is interleaved with the primal computation.
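The transformed program can be run almost verbatim. The sketch below is a Python transcription (using `math.sin` and `math.cos`; the primal `f` is added only for comparison). Seeding `dot_x = 1.0` makes `dot_v4` the ordinary derivative f'(x) = (2x + 3) cos(x^2 + 3x), which can also be cross-checked with a central finite difference.

```python
import math

def f(x):
    # primal program: v4 = sin(x*x + 3*x)
    return math.sin(x * x + 3 * x)

def f_forward(x, dot_x):
    # each primal line is followed by its tangent line, in the original order
    v1 = x * x
    dot_v1 = x * dot_x + x * dot_x

    v2 = 3 * x
    dot_v2 = 3 * dot_x

    v3 = v1 + v2
    dot_v3 = dot_v1 + dot_v2

    v4 = math.sin(v3)
    dot_v4 = math.cos(v3) * dot_v3

    return v4, dot_v4

x = 0.7
y, dy = f_forward(x, 1.0)              # seed dot_x = 1 to obtain df/dx
h = 1e-6
fd = (f(x + h) - f(x - h)) / (2 * h)   # central finite difference for comparison
```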

Reverse-Mode Transformation

Reverse mode rewrites the program into two phases.

The first phase runs forward and records enough information for differentiation.

The second phase runs backward and accumulates adjoints.

Original program:

v1 = x * x
v2 = 3 * x
v3 = v1 + v2
v4 = sin(v3)
return v4

Reverse-transformed program:

# forward pass
v1 = x * x
record(mul, x, x, v1)

v2 = 3 * x
record(mul_const, 3, x, v2)

v3 = v1 + v2
record(add, v1, v2, v3)

v4 = sin(v3)
record(sin, v3, v4)

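# reverse pass: all adjoints start at zero, the output adjoint is seeded with 1
bar_v1 = bar_v2 = bar_v3 = bar_x = 0
bar_v4 = 1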

backward sin:
    bar_v3 += cos(v3) * bar_v4

backward add:
    bar_v1 += bar_v3
    bar_v2 += bar_v3

backward mul_const:
    bar_x += 3 * bar_v2

backward mul:
    bar_x += x * bar_v1
    bar_x += x * bar_v1

return v4, bar_x

This transformation changes the evaluation structure more than forward mode does. Reverse mode must preserve information from the forward pass (here, the values of x and v3) so that the local derivative rules can be evaluated later.
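A minimal runnable sketch of the two phases, assuming a tape of backward closures. The names `Var`, `mul`, `scale`, `add`, `sin`, and `grad` are illustrative and not taken from any particular library. Each primitive computes its value in the forward pass and appends a closure that later applies its local derivative rule; the reverse pass replays the tape backward.

```python
import math

class Var:
    """A value that records onto a shared tape and accumulates an adjoint."""
    def __init__(self, value, tape):
        self.value = value
        self.grad = 0.0          # adjoint, starts at zero
        self.tape = tape

def mul(a, b):
    out = Var(a.value * b.value, a.tape)
    def backward():
        a.grad += b.value * out.grad
        b.grad += a.value * out.grad
    a.tape.append(backward)
    return out

def scale(c, a):
    out = Var(c * a.value, a.tape)
    def backward():
        a.grad += c * out.grad
    a.tape.append(backward)
    return out

def add(a, b):
    out = Var(a.value + b.value, a.tape)
    def backward():
        a.grad += out.grad
        b.grad += out.grad
    a.tape.append(backward)
    return out

def sin(a):
    out = Var(math.sin(a.value), a.tape)
    def backward():
        a.grad += math.cos(a.value) * out.grad
    a.tape.append(backward)
    return out

def grad(x_value):
    tape = []
    x = Var(x_value, tape)
    # forward pass: evaluate and record
    v1 = mul(x, x)
    v2 = scale(3.0, x)
    v3 = add(v1, v2)
    v4 = sin(v3)
    # reverse pass: seed the output adjoint and replay the tape backward
    v4.grad = 1.0
    for backward in reversed(tape):
        backward()
    return v4.value, x.grad

value, bar_x = grad(0.7)
```

Because both operands of `mul(x, x)` are `x`, its closure adds two contributions to `x.grad`, mirroring the two `bar_x += x * bar_v1` lines in the pseudocode.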

Program Transformation Targets

AD can transform different representations.

| Target | Description |
|---|---|
| Source code | Rewrite the user program directly |
| Operator overload trace | Record operations during execution |
| Bytecode | Transform lower-level executable instructions |
| Compiler IR | Transform SSA, graph IR, MLIR, or LLVM IR |
| Tensor graph | Transform tensor-level dataflow graph |

Each target gives different tradeoffs.

Source transformation can produce efficient code, but handling a full programming language is difficult.

Operator overloading is simpler to implement, but it may have higher runtime overhead.

Compiler IR transformation enables optimization, but requires integration with compiler infrastructure.

Source Transformation

Source transformation generates derivative source code from ordinary source code.

Example:

func f(x):
    return sin(x * x)

may become:

func f_forward(x, dot_x):
    v1 = x * x
    dot_v1 = x * dot_x + x * dot_x

    y = sin(v1)
    dot_y = cos(v1) * dot_v1

    return y, dot_y
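To make "generates derivative source code" concrete, here is a toy source transformer. It assumes a hypothetical straight-line IR of `(target, op, args)` tuples supporting only `mul` and `sin`; `forward_source` emits the text of a forward-mode function, which is then compiled with Python's `exec`. A real system would parse and rewrite the host language itself.

```python
import math

# hypothetical straight-line IR for  f(x) = sin(x * x)
PROGRAM = [
    ("v1", "mul", ("x", "x")),
    ("y", "sin", ("v1",)),
]

def forward_source(name, program, inp, out):
    """Emit Python source for the forward-mode version of `program`."""
    lines = [f"def {name}({inp}, dot_{inp}):"]
    for target, op, args in program:
        if op == "mul":
            a, b = args
            lines.append(f"    {target} = {a} * {b}")
            lines.append(f"    dot_{target} = {b} * dot_{a} + {a} * dot_{b}")
        elif op == "sin":
            (a,) = args
            lines.append(f"    {target} = math.sin({a})")
            lines.append(f"    dot_{target} = math.cos({a}) * dot_{a}")
        else:
            raise NotImplementedError(f"no transformation rule for {op!r}")
    lines.append(f"    return {out}, dot_{out}")
    return "\n".join(lines)

src = forward_source("f_forward", PROGRAM, "x", "y")
namespace = {"math": math}
exec(src, namespace)               # compile the generated derivative code
f_forward = namespace["f_forward"]
```

Printing `src` shows generated code matching the hand-written `f_forward` above, which is one of the advantages listed next: the derivative code is explicit and can be inspected.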

This approach has several advantages:

  • derivative code is explicit
  • compiler optimizations can apply
  • no runtime tracing is required
  • generated code can be inspected

But it must handle:

  • function calls
  • loops
  • mutation
  • aliasing
  • closures
  • exceptions
  • external libraries

Source transformation is powerful but language dependent.

Operator Overloading

Operator overloading changes the meaning of primitive operations for special AD values.

Instead of using ordinary numbers, the program uses objects such as:

Dual(value, tangent)

Then ordinary operations are overloaded:

Dual(x, dx) * Dual(y, dy)
=
Dual(x * y, x * dy + y * dx)

The user program may remain almost unchanged.

Example:

x = Dual(3, 1)
y = sin(x * x)

evaluates to a Dual that carries both:

  • primal value
  • derivative value
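A minimal Python sketch of such a `Dual` type, assuming only `+`, `*`, and `sin` need overloading (the `lift` helper and the free function `sin` are illustrative choices). The function `g` is ordinary code that runs unchanged on floats and on `Dual` values; only the `sin` it calls has to know about `Dual`.

```python
import math

class Dual:
    """A primal value paired with a tangent (directional derivative)."""
    def __init__(self, value, tangent=0.0):
        self.value = value
        self.tangent = tangent

    @staticmethod
    def lift(other):
        # constants become duals with zero tangent
        return other if isinstance(other, Dual) else Dual(other, 0.0)

    def __add__(self, other):
        other = Dual.lift(other)
        return Dual(self.value + other.value, self.tangent + other.tangent)

    __radd__ = __add__

    def __mul__(self, other):
        other = Dual.lift(other)
        return Dual(self.value * other.value,
                    self.value * other.tangent + other.value * self.tangent)

    __rmul__ = __mul__    # multiplication is commutative

def sin(x):
    if isinstance(x, Dual):
        return Dual(math.sin(x.value), math.cos(x.value) * x.tangent)
    return math.sin(x)

def g(x):
    # user code: unchanged whether x is a float or a Dual
    return sin(x * x + 3 * x)

x = Dual(3.0, 1.0)
y = sin(x * x)        # y.value == sin(9), y.tangent == 6 * cos(9)
```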

Operator overloading is common because it is easy to embed in host languages.

Its limitations include:

  • runtime overhead
  • less compiler visibility
  • difficulty optimizing across operation boundaries
  • possible interaction problems with mutation and control flow

Tracing

Tracing records operations while the program executes.

A tracer replaces ordinary values with trace values. Each primitive operation appends a node to a graph or tape.

Example:

x = trace_input("x")
v1 = x * x        # records mul
v2 = sin(v1)      # records sin

The trace can then be transformed into derivative code.

Tracing works well for numerical programs whose control flow does not depend on input values, or whose traced path is the one taken by all inputs of interest.

However, tracing captures the executed path. If control flow depends on input values, different inputs may produce different traces.
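The path dependence is easy to reproduce with a toy tracer. This sketch assumes tracers that carry concrete values (so a Python `if` can inspect them) and record each primitive as a tuple; `Traced`, `tsin`, and `trace` are hypothetical names.

```python
import math

class Traced:
    """Carries a concrete value and appends every operation to a shared tape."""
    def __init__(self, name, value, tape):
        self.name, self.value, self.tape = name, value, tape

    def _emit(self, op, args, value):
        name = f"v{len(self.tape) + 1}"
        self.tape.append((name, op, *args))
        return Traced(name, value, self.tape)

    def __mul__(self, other):
        return self._emit("mul", (self.name, other.name), self.value * other.value)

def tsin(x):
    return x._emit("sin", (x.name,), math.sin(x.value))

def trace(fn, x_value):
    tape = []
    fn(Traced("x", x_value, tape))
    return tape

def f(x):
    v1 = x * x
    if v1.value > 1.0:      # branch on a concrete runtime value
        return tsin(v1)
    return v1

trace(f, 2.0)   # [('v1', 'mul', 'x', 'x'), ('v2', 'sin', 'v1')]
trace(f, 0.5)   # [('v1', 'mul', 'x', 'x')]
```

Derivative code built from the first trace would be wrong for an input such as 0.5, because the branch it recorded is not the one that input takes.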

Compiler IR Transformation

Modern AD systems often transform intermediate representations.

A compiler may lower code into SSA form:

v1 = mul x x
v2 = sin v1
return v2

AD can then operate on this lower-level form.

Benefits:

  • every value has a unique definition
  • data dependencies are explicit
  • optimization passes already exist
  • reverse-mode adjoints can be inserted systematically
  • dead derivative code can be eliminated

SSA form is especially useful for reverse mode because mutation is converted into value versions.

AD and SSA

In static single assignment form, each variable is assigned once.

Imperative code:

x = x * x
x = x + 1

becomes:

x1 = x0 * x0
x2 = x1 + 1

This makes derivative propagation clearer.

Reverse mode can assign adjoints to value versions:

bar_x2 = 1
bar_x1 += bar_x2
bar_x0 += 2 * x0 * bar_x1

Without SSA, the name x would refer to several different program states, which complicates reverse accumulation.
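The problem can be reproduced with a small interpreter. The hypothetical `run` and `backward` helpers below work on `(dst, op, a, b)` tuples, keeping one environment of values and one dictionary of adjoints keyed by variable name. With SSA names the backward pass can still read `x0`; with a reused name it reads the final value of `x` and produces the wrong gradient.

```python
def run(program, inputs):
    """Evaluate a straight-line program, returning the final environment."""
    env = dict(inputs)
    for dst, op, a, b in program:
        if op == "mul":
            env[dst] = env[a] * env[b]
        elif op == "add_const":
            env[dst] = env[a] + b          # b is a literal constant here
    return env

def backward(program, env, out):
    """Reverse accumulation with one adjoint slot per variable name."""
    adj = {out: 1.0}
    for dst, op, a, b in reversed(program):
        g = adj.pop(dst, 0.0)
        if op == "mul":
            adj[a] = adj.get(a, 0.0) + g * env[b]
            adj[b] = adj.get(b, 0.0) + g * env[a]
        elif op == "add_const":
            adj[a] = adj.get(a, 0.0) + g
    return adj

SSA = [("x1", "mul", "x0", "x0"), ("x2", "add_const", "x1", 1.0)]
REUSED = [("x", "mul", "x", "x"), ("x", "add_const", "x", 1.0)]

backward(SSA, run(SSA, {"x0": 3.0}), "x2")["x0"]      # 6.0, i.e. 2 * x0
backward(REUSED, run(REUSED, {"x": 3.0}), "x")["x"]   # 20.0, wrong: x0 was overwritten
```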

Transformation and Optimization

Once derivative code exists, normal compiler optimization becomes important.

Common optimizations include:

  • common subexpression elimination
  • dead code elimination
  • algebraic simplification
  • memory reuse
  • fusion
  • inlining
  • checkpoint placement
  • layout optimization

Derivative code is often larger than the primal code, and optimization removes much of that extra overhead.

For example:

dot_v1 = x * dot_x + x * dot_x

can be simplified to:

dot_v1 = 2 * x * dot_x

Similarly, if only some inputs require derivatives, unused adjoint computations can be removed.
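Both optimizations can be sketched on toy representations. The names `simplify` and `dce` are hypothetical: `simplify` applies a single algebraic rule (e + e becomes 2 * e) to nested tuples, and `dce` keeps only the instructions that a requested output transitively depends on.

```python
def simplify(expr):
    """Apply the rule  e + e  ->  2 * e  bottom-up on nested tuples."""
    if not isinstance(expr, tuple):
        return expr                      # a variable name or constant
    op, *args = expr
    args = [simplify(a) for a in args]
    if op == "add" and args[0] == args[1]:
        return ("mul", 2, args[0])
    return (op, *args)

def dce(program, live):
    """Keep only instructions whose result is (transitively) needed."""
    live, kept = set(live), []
    for dst, op, args in reversed(program):
        if dst in live:
            kept.append((dst, op, args))
            live.update(args)
    return kept[::-1]

dot_v1 = ("add", ("mul", "x", "dot_x"), ("mul", "x", "dot_x"))
simplify(dot_v1)        # ('mul', 2, ('mul', 'x', 'dot_x'))

# adjoint code for y = sin(v), v = w * x, when only bar_x is requested
adjoint = [
    ("bar_v", "cos_mul", ("v", "bar_y")),   # bar_v = cos(v) * bar_y
    ("bar_x", "mul", ("w", "bar_v")),       # bar_x = w * bar_v
    ("bar_w", "mul", ("x", "bar_v")),       # bar_w = x * bar_v
]
dce(adjoint, ["bar_x"])  # drops the bar_w instruction
```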

Transformation Boundaries

AD systems must decide what to transform.

A program may call:

  • built-in primitives
  • user-defined functions
  • library functions
  • foreign functions
  • system APIs

For each call, the AD system needs one of:

| Case | Handling |
|---|---|
| Differentiable primitive | Use registered derivative rule |
| User function | Recursively transform |
| Library function | Use custom derivative rule |
| Opaque external call | Stop differentiation or require user rule |
| Non-differentiable operation | Return zero, error, or subgradient |

Boundary handling determines how useful an AD system is in real programs.
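A minimal sketch of boundary dispatch, assuming a hypothetical registry that maps primitive names to scalar derivative functions. Registered primitives use their rules; `relu` stands in for a non-differentiable operation handled by choosing a subgradient (0 at the kink is one common convention); an unregistered, opaque call stops differentiation with an error.

```python
import math

# hypothetical registry: primitive name -> function computing d(op)/dx at x
DERIVATIVE_RULES = {}

def register_rule(name, rule):
    DERIVATIVE_RULES[name] = rule

register_rule("sin", math.cos)
register_rule("exp", math.exp)
# relu is not differentiable at 0: choose the subgradient 0 there
register_rule("relu", lambda x: 1.0 if x > 0 else 0.0)

def derivative_of(name, x):
    rule = DERIVATIVE_RULES.get(name)
    if rule is None:
        # opaque external call: refuse rather than silently returning zero
        raise NotImplementedError(f"no derivative rule registered for {name!r}")
    return rule(x)
```

Raising on unknown calls is a deliberate choice in this sketch: silently treating an opaque call as constant would produce wrong gradients without any warning.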

Custom Derivative Rules

Some functions should not be differentiated by expanding their implementation.

Example:

$$\operatorname{logsumexp}(x) = \log\sum_i e^{x_i}$$

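A naive expansion may be numerically unstable: when some x_i is large, exp(x_i) overflows the floating-point range even though the final result is moderate.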

A system should use a stable primitive and a custom derivative rule.
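A sketch of such a primitive and its rule in plain Python. It relies on the identity logsumexp(x) = m + log Σ_i exp(x_i − m) with m = max_i x_i, and on the fact that the gradient of logsumexp is softmax(x), so the rule never differentiates through the exp/log expansion. The function names are illustrative.

```python
import math

def naive_logsumexp(xs):
    # direct expansion: math.exp raises OverflowError for inputs like 1000.0
    return math.log(sum(math.exp(x) for x in xs))

def logsumexp(xs):
    """Stable primitive: shift by the maximum before exponentiating."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def logsumexp_vjp(xs, bar_y):
    """Custom reverse rule: bar_x = bar_y * softmax(xs), computed stably."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [bar_y * e / total for e in exps]

logsumexp([1000.0, 1000.0])            # 1000 + log(2)
logsumexp_vjp([1000.0, 1000.0], 1.0)   # [0.5, 0.5]
```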

Custom rules are also useful for:

  • solvers
  • linear algebra decompositions
  • special functions
  • external kernels
  • probabilistic operations
  • implicit functions

A transformation-based AD system needs an extension mechanism for such rules.

Transformation of Control Flow

Forward mode usually preserves control flow directly.

Reverse mode must invert control dependencies carefully.

For a loop:

for i in range(n):
    x = step(x, i)

reverse mode often stores all intermediate states:

x0, x1, ..., xn

Then it runs backward:

for i in reversed(range(n)):
    propagate through step(x_i, i)

This is why long loops can be memory intensive.

Checkpointing changes the transformed program by recomputing some intermediate states instead of storing them all.
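The two strategies can be compared on a scalar loop. This sketch assumes a hypothetical `step(x, i) = sin(x) + 0.1 * i`, whose local derivative with respect to x is cos(x). `grad_store_all` keeps every state x0, ..., xn; `grad_checkpointed` keeps only every k-th state and recomputes each segment during the backward pass, so it holds roughly n / k checkpoints plus one segment of k states instead of all n.

```python
import math

def step(x, i):
    return math.sin(x) + 0.1 * i

def dstep_dx(x, i):
    return math.cos(x)              # local derivative of step with respect to x

def grad_store_all(x0, n):
    xs = [x0]                       # stores x0, x1, ..., xn
    for i in range(n):
        xs.append(step(xs[-1], i))
    bar = 1.0
    for i in reversed(range(n)):
        bar *= dstep_dx(xs[i], i)
    return xs[-1], bar

def grad_checkpointed(x0, n, k):
    checkpoints = {}                # only every k-th state is kept
    x = x0
    for i in range(n):
        if i % k == 0:
            checkpoints[i] = x
        x = step(x, i)
    bar = 1.0
    for s in reversed(range(0, n, k)):
        e = min(s + k, n)
        seg = [checkpoints[s]]      # recompute x_s, ..., x_{e-1} from the checkpoint
        for i in range(s, e - 1):
            seg.append(step(seg[-1], i))
        for i in reversed(range(s, e)):
            bar *= dstep_dx(seg[i - s], i)
    return x, bar
```

Both functions return the final state and dx_n/dx_0; the checkpointed version pays for its smaller memory footprint with one extra forward evaluation of each segment.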

Transformation of Mutation

Mutation requires careful semantics.

Example:

a[i] = a[i] + x

Differentiating this operation correctly depends on:

  • previous value of a[i]
  • aliasing
  • whether other references observe the mutation
  • order of writes

A robust AD compiler must model mutation explicitly.

Common strategies:

  • convert mutation into pure updates
  • use SSA value versions
  • record mutation logs
  • restrict in-place mutation
  • define custom adjoints for update operations

In-place mutation can improve performance, but it complicates correctness.
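The first and last strategies above, converting mutation into a pure update with its own adjoint, can be sketched for `a[i] = a[i] + x` on Python lists. `add_at` and `add_at_vjp` are hypothetical names: the update returns a new list, and its adjoint passes every output adjoint straight back to `a` (each new[j] depends on a[j] with slope 1) and passes only bar_new[i] to `x`.

```python
def add_at(a, i, x):
    """Pure version of  a[i] = a[i] + x : returns a new list, leaves a intact."""
    new = list(a)
    new[i] = a[i] + x
    return new

def add_at_vjp(i, bar_new):
    """Adjoint of add_at, given the adjoint of its output."""
    bar_a = list(bar_new)   # d new[j] / d a[j] = 1 for every j
    bar_x = bar_new[i]      # only new[i] depends on x
    return bar_a, bar_x

a = [1.0, 2.0, 3.0]
a2 = add_at(a, 1, 0.5)      # a is unchanged, so a reverse pass can still read it
bar_a, bar_x = add_at_vjp(1, [0.1, 0.2, 0.3])
```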

Transformation of Data Structures

Programs use data structures, not only scalars and tensors.

An AD system may need to handle:

  • arrays
  • structs
  • tuples
  • trees
  • dictionaries
  • sparse matrices
  • graphs

The transformation must map each differentiable field to a corresponding tangent or adjoint field.

Example:

State {
    position
    velocity
    mass
}

If mass is constant and position and velocity are differentiable, the tangent state contains only the relevant derivative fields.
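A sketch with Python dataclasses, assuming a hypothetical explicit-Euler spring step `step` (force = -position). The tangent type `StateTangent` simply omits `mass`, so `step_jvp` never allocates or propagates a derivative for it.

```python
from dataclasses import dataclass

@dataclass
class State:
    position: float
    velocity: float
    mass: float              # constant parameter: not differentiated

@dataclass
class StateTangent:
    position: float          # tangent fields exist only for differentiable fields
    velocity: float

def step(s, dt):
    """Explicit Euler step of a unit-stiffness spring (force = -position)."""
    return State(s.position + dt * s.velocity,
                 s.velocity - dt * s.position / s.mass,
                 s.mass)

def step_jvp(s, ds, dt):
    """Return step(s, dt) together with its tangent along ds."""
    ds_new = StateTangent(ds.position + dt * ds.velocity,
                          ds.velocity - dt * ds.position / s.mass)
    return step(s, dt), ds_new

s = State(position=1.0, velocity=0.0, mass=2.0)
s_new, ds_new = step_jvp(s, StateTangent(position=1.0, velocity=0.0), dt=0.1)
```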

This is where type systems help.

Type-Level View

Forward mode transforms a type T into a tangent bundle type:

$$T \mapsto T \times \dot T$$

Reverse mode transforms computations so that each type has an associated cotangent type:

$$T \mapsto T^*$$

For simple arrays, the tangent and cotangent types often match the primal type.

For structured objects, they may differ.

Examples:

  • integer indices have no tangent
  • booleans have no tangent
  • floating arrays have floating tangents
  • sparse structures may have sparse or dense cotangents

A well-designed AD system needs precise derivative types.
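The type mapping can be sketched as a function that builds the zero tangent of a Python value. `zero_tangent` is a hypothetical helper that uses `None` to mean "no tangent": integers and booleans get `None`, floats get `0.0`, and containers are mapped element by element.

```python
def zero_tangent(value):
    """Zero tangent of a value; None marks values that have no tangent."""
    if isinstance(value, (bool, int)):     # indices and flags
        return None
    if isinstance(value, float):
        return 0.0
    if isinstance(value, (list, tuple)):
        return type(value)(zero_tangent(v) for v in value)
    if isinstance(value, dict):
        return {k: zero_tangent(v) for k, v in value.items()}
    raise TypeError(f"no tangent type known for {type(value).__name__}")

zero_tangent((1, 2.5, True, [3.0, 4]))             # (None, 0.0, None, [0.0, None])
zero_tangent({"weights": [0.5, -1.0], "step": 3})  # {'weights': [0.0, 0.0], 'step': None}
```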

AD as a Compiler Pass

In a compiler pipeline, AD can be placed as a transformation pass.

A simplified pipeline:

source program
typed AST
intermediate representation
AD transformation
optimization
lowering
machine code / kernel code

This view treats differentiation as compilation, not runtime magic.

The AD pass introduces tangent or adjoint computations. Later passes optimize and lower them.

Correctness of Transformation

A program transformation is correct if the transformed program computes the derivative of the original program.

For forward mode, correctness means:

$$\text{forward}(p)(x,\dot x) = \big(p(x),\ Dp(x)\,\dot x\big)$$

For reverse mode, correctness means:

$$\text{reverse}(p)(x,\bar y) = \big(p(x),\ Dp(x)^T \bar y\big)$$

These equations define the semantic contract of AD.

Everything else is implementation.
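The two equations suggest a standard way to test an AD implementation. For a hypothetical program p(x1, x2) = (x1 * x2, sin(x1)) with hand-written `jvp` and `vjp`, the forward result should match a central finite difference, and the reverse result must satisfy the transpose identity ⟨ȳ, Dp(x) ẋ⟩ = ⟨Dp(x)^T ȳ, ẋ⟩ for any ẋ and ȳ.

```python
import math

def p(x1, x2):
    return (x1 * x2, math.sin(x1))

def jvp(x, xdot):
    # forward(p)(x, xdot) = (p(x), Dp(x) xdot)
    x1, x2 = x
    d1, d2 = xdot
    return p(x1, x2), (x2 * d1 + x1 * d2, math.cos(x1) * d1)

def vjp(x, ybar):
    # reverse(p)(x, ybar) = (p(x), Dp(x)^T ybar)
    x1, x2 = x
    b1, b2 = ybar
    return p(x1, x2), (x2 * b1 + math.cos(x1) * b2, x1 * b1)

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

x, xdot, ybar = (0.3, -1.2), (0.5, 1.0), (1.5, -0.7)
_, jv = jvp(x, xdot)
_, vj = vjp(x, ybar)

# check 1: forward mode against a central finite difference along xdot
h = 1e-6
plus = p(x[0] + h * xdot[0], x[1] + h * xdot[1])
minus = p(x[0] - h * xdot[0], x[1] - h * xdot[1])
fd = [(a - b) / (2 * h) for a, b in zip(plus, minus)]

# check 2: reverse mode against forward mode via the transpose identity
lhs, rhs = dot(ybar, jv), dot(vj, xdot)
```

The second check needs no finite differences at all, which makes it a useful consistency test between the forward and reverse transformations of the same program.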

Summary

Automatic differentiation is best understood as program transformation.

Forward mode transforms each value into a primal-tangent pair.

Reverse mode transforms evaluation into a forward recording pass plus a backward adjoint pass.

The transformation may operate on source code, traces, compiler IR, tensor graphs, or runtime tapes.

This perspective explains why AD is both mathematical and systems-oriented: it depends on the chain rule, but it must also handle real programming language features such as control flow, mutation, types, memory, and external calls.