Automatic differentiation can be understood as a transformation from one program into another program.
The original program computes values.
The transformed program computes values and derivatives.
This view is important because real AD systems operate on code, traces, graphs, or compiler intermediate representations. They are not merely applying calculus rules on paper. They rewrite computation.
From Mathematical Function to Program
A mathematical function may be written as:

f(x) = sin(x^2 + 3x)
A program evaluates it as a sequence of operations:
v1 = x * x
v2 = 3 * x
v3 = v1 + v2
v4 = sin(v3)
return v4

Automatic differentiation transforms this program into another program.
For forward mode, the transformed program carries tangent values.
For reverse mode, the transformed program records operations and later propagates adjoints.
Forward-Mode Transformation
Forward mode rewrites every variable v into a pair (v, dot_v):
The primal component stores the ordinary value.
The tangent component stores the directional derivative.
Original operation:

v = a * b

becomes:

v = a * b
dot_v = a * dot_b + b * dot_a

Original operation:

v = sin(u)

becomes:

v = sin(u)
dot_v = cos(u) * dot_u
For the example:
v1 = x * x
v2 = 3 * x
v3 = v1 + v2
v4 = sin(v3)
return v4

the forward-transformed program is:
v1 = x * x
dot_v1 = x * dot_x + x * dot_x
v2 = 3 * x
dot_v2 = 3 * dot_x
v3 = v1 + v2
dot_v3 = dot_v1 + dot_v2
v4 = sin(v3)
dot_v4 = cos(v3) * dot_v3
return v4, dot_v4

The program still runs in the original order. The derivative computation is interleaved with the primal computation.
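The forward-transformed pseudocode above can be written directly as runnable Python. This is a sketch: the function name `f_forward` and the `dot_` naming convention are ours, mirroring the pseudocode.

```python
import math

def f_forward(x, dot_x):
    """Forward-mode transform of f(x) = sin(x*x + 3*x).

    Every primal value carries a tangent value alongside it."""
    v1 = x * x
    dot_v1 = x * dot_x + x * dot_x   # product rule; both operands are x
    v2 = 3 * x
    dot_v2 = 3 * dot_x
    v3 = v1 + v2
    dot_v3 = dot_v1 + dot_v2
    v4 = math.sin(v3)
    dot_v4 = math.cos(v3) * dot_v3
    return v4, dot_v4

# Seeding dot_x = 1 makes the tangent equal to f'(x) = (2x + 3) * cos(x^2 + 3x).
y, dy = f_forward(2.0, 1.0)
```

Calling with `dot_x = 1.0` extracts the ordinary derivative; other seed values give directional derivatives.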
Reverse-Mode Transformation
Reverse mode rewrites the program into two phases.
The first phase runs forward and records enough information for differentiation.
The second phase runs backward and accumulates adjoints.
Original program:
v1 = x * x
v2 = 3 * x
v3 = v1 + v2
v4 = sin(v3)
return v4

Reverse-transformed program:
# forward pass
v1 = x * x
record(mul, x, x, v1)
v2 = 3 * x
record(mul_const, 3, x, v2)
v3 = v1 + v2
record(add, v1, v2, v3)
v4 = sin(v3)
record(sin, v3, v4)
# reverse pass
bar_v4 = 1
backward sin:
bar_v3 += cos(v3) * bar_v4
backward add:
bar_v1 += bar_v3
bar_v2 += bar_v3
backward mul_const:
bar_x += 3 * bar_v2
backward mul:
bar_x += x * bar_v1   # contribution from the left operand
bar_x += x * bar_v1   # contribution from the right operand; both are x
return v4, bar_x

This transformation changes evaluation structure more than forward mode. Reverse mode must preserve information from the forward pass so that local derivative rules can be evaluated later.
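For this straight-line example the recorded tape can be unrolled by hand, giving a runnable Python sketch of the two phases (the name `f_reverse` and the `bar_` convention follow the pseudocode above):

```python
import math

def f_reverse(x):
    """Reverse-mode transform of f(x) = sin(x*x + 3*x)."""
    # forward pass: compute and keep the values the backward rules need
    v1 = x * x
    v2 = 3 * x
    v3 = v1 + v2
    v4 = math.sin(v3)

    # reverse pass: accumulate adjoints in the opposite order
    bar_v4 = 1.0
    bar_v3 = math.cos(v3) * bar_v4   # backward sin
    bar_v1 = bar_v3                  # backward add
    bar_v2 = bar_v3
    bar_x = 3 * bar_v2               # backward mul_const
    bar_x += x * bar_v1              # backward mul: both operands are x,
    bar_x += x * bar_v1              # so x receives two contributions
    return v4, bar_x

y, grad = f_reverse(2.0)
```

Note that `v3` must survive until the backward phase so that `cos(v3)` can be evaluated; this is exactly the "preserve information" requirement mentioned above.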
Program Transformation Targets
AD can transform different representations.
| Target | Description |
|---|---|
| Source code | Rewrite the user program directly |
| Operator overload trace | Record operations during execution |
| Bytecode | Transform lower-level executable instructions |
| Compiler IR | Transform SSA, graph IR, MLIR, or LLVM IR |
| Tensor graph | Transform tensor-level dataflow graph |
Each target gives different tradeoffs.
Source transformation can produce efficient code, but handling a full programming language is difficult.
Operator overloading is simpler to implement, but it may have higher runtime overhead.
Compiler IR transformation enables optimization, but requires integration with compiler infrastructure.
Source Transformation
Source transformation generates derivative source code from ordinary source code.
Example:
func f(x):
return sin(x * x)

may become:
func f_forward(x, dot_x):
v1 = x * x
dot_v1 = x * dot_x + x * dot_x
y = sin(v1)
dot_y = cos(v1) * dot_v1
return y, dot_y

This approach has several advantages:
- derivative code is explicit
- compiler optimizations can apply
- no runtime tracing is required
- generated code can be inspected
But it must handle:
- function calls
- loops
- mutation
- aliasing
- closures
- exceptions
- external libraries
Source transformation is powerful but language dependent.
Operator Overloading
Operator overloading changes the meaning of primitive operations for special AD values.
Instead of using ordinary numbers, the program uses objects such as:
Dual(value, tangent)

Then ordinary operations are overloaded:
Dual(x, dx) * Dual(y, dy)
=
Dual(x * y, x * dy + y * dx)

The user program may remain almost unchanged.
Example:
x = Dual(3, 1)
y = sin(x * x)

returns both:
- primal value
- derivative value
Operator overloading is common because it is easy to embed in host languages.
Its limits are:
- runtime overhead
- less compiler visibility
- difficulty optimizing across operation boundaries
- possible interaction problems with mutation and control flow
Tracing
Tracing records operations while the program executes.
A tracer replaces ordinary values with trace values. Each primitive operation appends a node to a graph or tape.
Example:
x = trace_input("x")
v1 = x * x # records mul
v2 = sin(v1)   # records sin

The trace can then be transformed into derivative code.
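A tracer can be sketched in a few lines of Python. The class name `TraceValue` and the tape layout are illustrative assumptions, not any particular system's API:

```python
import math

class TraceValue:
    """Wraps a value; each primitive operation appends a node to a shared tape."""
    def __init__(self, value, tape, name):
        self.value = value
        self.tape = tape
        self.name = name

    def __mul__(self, other):
        out = TraceValue(self.value * other.value, self.tape,
                         f"v{len(self.tape)}")
        self.tape.append(("mul", self.name, other.name, out.name))
        return out

def trace_sin(u):
    # sin is not an operator, so it is wrapped as a traced primitive
    out = TraceValue(math.sin(u.value), u.tape, f"v{len(u.tape)}")
    u.tape.append(("sin", u.name, out.name))
    return out

tape = []
x = TraceValue(2.0, tape, "x")
v1 = x * x            # appends ("mul", "x", "x", "v0")
v2 = trace_sin(v1)    # appends ("sin", "v0", "v1")
```

The resulting tape is itself a small program, and the reverse transformation can be applied to it rather than to the original source.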
Tracing works well for numerical programs that execute representative control flow.
However, tracing captures the executed path. If control flow depends on input values, different inputs may produce different traces.
Compiler IR Transformation
Modern AD systems often transform intermediate representations.
A compiler may lower code into SSA form:
v1 = mul x x
v2 = sin v1
return v2

AD can then operate on this lower-level form.
Benefits:
- every value has a unique definition
- data dependencies are explicit
- optimization passes already exist
- reverse-mode adjoints can be inserted systematically
- dead derivative code can be eliminated
SSA form is especially useful for reverse mode because mutation is converted into value versions.
AD and SSA
In static single assignment form, each variable is assigned once.
Imperative code:
x = x * x
x = x + 1

becomes:
x1 = x0 * x0
x2 = x1 + 1

This makes derivative propagation clearer.
Reverse mode can assign adjoints to value versions:
bar_x2 = 1
bar_x1 += bar_x2
bar_x0 += 2 * x0 * bar_x1

Without SSA, the name x would refer to several different program states, which complicates reverse accumulation.
Transformation and Optimization
Once derivative code exists, normal compiler optimization becomes important.
Common optimizations include:
- common subexpression elimination
- dead code elimination
- algebraic simplification
- memory reuse
- fusion
- inlining
- checkpoint placement
- layout optimization
Derivative code can be larger than primal code. Optimization prevents unnecessary overhead.
For example:
dot_v1 = x * dot_x + x * dot_x

can be simplified to:

dot_v1 = 2 * x * dot_x

Similarly, if only some inputs require derivatives, unused adjoint computations can be removed.
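The simplification above is an ordinary peephole rewrite. A toy version, assuming expressions are represented as nested tuples (our representation, not any real compiler's IR):

```python
def simplify(expr):
    """Toy peephole pass: rewrites a + a into 2 * a, bottom-up."""
    if isinstance(expr, tuple):
        op, *args = expr
        args = [simplify(a) for a in args]
        if op == "add" and args[0] == args[1]:
            return ("mul", 2, args[0])
        return (op, *args)
    return expr   # leaves: variable names and constants

# dot_v1 = x * dot_x + x * dot_x   ->   2 * (x * dot_x)
expr = ("add", ("mul", "x", "dot_x"), ("mul", "x", "dot_x"))
simplified = simplify(expr)
```

Real AD systems get this essentially for free by handing the generated derivative code to existing optimization passes.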
Transformation Boundaries
AD systems must decide what to transform.
A program may call:
- built-in primitives
- user-defined functions
- library functions
- foreign functions
- system APIs
For each call, the AD system needs one of:
| Case | Handling |
|---|---|
| Differentiable primitive | use registered derivative rule |
| User function | recursively transform |
| Library function | use custom derivative rule |
| Opaque external call | stop differentiation or require user rule |
| Non-differentiable operation | return zero, error, or subgradient |
Boundary handling determines how useful an AD system is in real programs.
Custom Derivative Rules
Some functions should not be differentiated by expanding their implementation.
Expanding such a function's implementation naively may be numerically unstable. A system should instead treat it as a single stable primitive with a custom derivative rule.
Custom rules are also useful for:
- solvers
- linear algebra decompositions
- special functions
- external kernels
- probabilistic operations
- implicit functions
A transformation-based AD system needs an extension mechanism for such rules.
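One common shape for such an extension mechanism is a registry of derivative (VJP) rules keyed by primitive name. The names `defvjp` and `get_vjp` below are illustrative, not taken from any real system:

```python
import math

# registry mapping primitive names to custom derivative (VJP) rules
vjp_rules = {}

def defvjp(name, rule):
    """Register a custom derivative rule for a primitive."""
    vjp_rules[name] = rule

def get_vjp(name):
    """Opaque calls without a registered rule stop differentiation with an error."""
    try:
        return vjp_rules[name]
    except KeyError:
        raise NotImplementedError(f"no derivative rule for '{name}'") from None

# Register a rule instead of differentiating through sin's implementation.
# The rule maps (input, output adjoint) to the input adjoint.
defvjp("sin", lambda x, bar_y: math.cos(x) * bar_y)

grad = get_vjp("sin")(0.0, 1.0)   # cos(0) * 1 = 1.0
```

The same mechanism covers solvers and external kernels: the user supplies the rule, and the transformation treats the call as atomic.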
Transformation of Control Flow
Forward mode usually preserves control flow directly.
Reverse mode must invert control dependencies carefully.
For a loop:
for i in range(n):
x = step(x, i)

reverse mode often stores all intermediate states:
x0, x1, ..., xn

Then it runs backward:
for i in reversed(range(n)):
propagate through step(x_i, i)

This is why long loops can be memory intensive.
Checkpointing changes the transformed program by recomputing some intermediate states instead of storing them all.
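The trade-off can be sketched concretely. Here `step` and `step_adjoint` are placeholder assumptions standing in for a real loop body and its derivative rule; the second variant keeps only the initial state and recomputes forward from it:

```python
def step(x, i):
    """Placeholder loop body."""
    return x * (1.0 + 0.1 * i)

def step_adjoint(x, i, bar_out):
    """Adjoint of step with respect to x: d(step)/dx = 1 + 0.1*i."""
    return (1.0 + 0.1 * i) * bar_out

def grad_loop_store_all(x0, n):
    """Store every intermediate state, then sweep backward once."""
    xs = [x0]
    for i in range(n):
        xs.append(step(xs[-1], i))
    bar_x = 1.0
    for i in reversed(range(n)):
        bar_x = step_adjoint(xs[i], i, bar_x)
    return bar_x

def grad_loop_checkpoint(x0, n):
    """Store only x0; recompute each x_i from the checkpoint when needed.

    O(1) stored states, at the cost of O(n^2) forward recomputation."""
    bar_x = 1.0
    for i in reversed(range(n)):
        x = x0
        for j in range(i):
            x = step(x, j)
        bar_x = step_adjoint(x, i, bar_x)
    return bar_x
```

Practical schemes place checkpoints at intervals, interpolating between these two extremes.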
Transformation of Mutation
Mutation requires careful semantics.
Example:
a[i] = a[i] + x

The derivative of this operation depends on:
- previous value of a[i]
- aliasing
- whether other references observe the mutation
- order of writes
A robust AD compiler must model mutation explicitly.
Common strategies:
- convert mutation into pure updates
- use SSA value versions
- record mutation logs
- restrict in-place mutation
- define custom adjoints for update operations
In-place mutation can improve performance, but it complicates correctness.
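Two of the strategies above combine naturally: convert the mutation into a pure update and give that update a custom adjoint. A sketch with plain Python lists (the names `index_add` and `index_add_adjoint` are ours):

```python
def index_add(a, i, x):
    """Pure version of a[i] = a[i] + x: returns a new list, a is untouched."""
    b = list(a)
    b[i] = b[i] + x
    return b

def index_add_adjoint(bar_b, i):
    """Adjoint of index_add.

    The gradient flows to every element of a unchanged, and additionally
    to x through position i."""
    bar_a = list(bar_b)
    bar_x = bar_b[i]
    return bar_a, bar_x

b = index_add([1.0, 2.0, 3.0], 1, 10.0)          # [1.0, 12.0, 3.0]
bar_a, bar_x = index_add_adjoint([0.0, 1.0, 0.0], 1)
```

Because the update is pure, no aliasing or write-ordering questions arise; the cost is the copy, which an optimizer may later remove when it can prove the original is dead.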
Transformation of Data Structures
Programs use data structures, not only scalars and tensors.
An AD system may need to handle:
- arrays
- structs
- tuples
- trees
- dictionaries
- sparse matrices
- graphs
The transformation must map each differentiable field to a corresponding tangent or adjoint field.
Example:
State {
position
velocity
mass
}

If mass is constant while position and velocity are differentiable, the tangent state contains only the relevant derivative fields.
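This mapping can be made explicit with a separate tangent type. A sketch, assuming only `position` and `velocity` are differentiable:

```python
from dataclasses import dataclass

@dataclass
class State:
    position: float
    velocity: float
    mass: float          # constant: no derivative is tracked for it

@dataclass
class StateTangent:
    """Tangent type for State: only the differentiable fields appear."""
    position: float
    velocity: float

def scale_tangent(t: StateTangent, c: float) -> StateTangent:
    # tangents support linear operations; mass needs no counterpart
    return StateTangent(t.position * c, t.velocity * c)

t = scale_tangent(StateTangent(1.0, 2.0), 3.0)
```

A type system can generate `StateTangent` mechanically from `State` by mapping each differentiable field to its tangent type and dropping the rest.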
This is where type systems help.
Type-Level View
Forward mode transforms a type into a tangent bundle type:

T  →  (T, Tangent(T))

Reverse mode transforms computations so that each type has an associated cotangent type:

T  →  Cotangent(T)
For simple arrays, the tangent and cotangent types often match the primal type.
For structured objects, they may differ.
Examples:
- integer indices have no tangent
- booleans have no tangent
- floating arrays have floating tangents
- sparse structures may have sparse or dense cotangents
A well-designed AD system needs precise derivative types.
AD as a Compiler Pass
In a compiler pipeline, AD can be placed as a transformation pass.
A simplified pipeline:
source program
↓
typed AST
↓
intermediate representation
↓
AD transformation
↓
optimization
↓
lowering
↓
machine code / kernel code

This view treats differentiation as compilation, not runtime magic.
The AD pass introduces tangent or adjoint computations. Later passes optimize and lower them.
Correctness of Transformation
A program transformation is correct if the transformed program computes the derivative of the original program.
For forward mode, correctness means:

(y, dot_y) = (f(x), Jf(x) · dot_x)

For reverse mode, correctness means:

(y, bar_x) = (f(x), Jf(x)ᵀ · bar_y)

where Jf(x) is the Jacobian of the original program f at x.
These equations define the semantic contract of AD.
Everything else is implementation.
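The contract can be checked numerically: the transformed program's derivative should agree with a finite-difference approximation of the primal program. A sketch using the running example:

```python
import math

def f(x):
    """The primal program: f(x) = sin(x*x + 3*x)."""
    return math.sin(x * x + 3 * x)

def df(x):
    """The derivative the transformation should produce: (2x+3) cos(x^2+3x)."""
    return (2 * x + 3) * math.cos(x * x + 3 * x)

def check(x, eps=1e-6, tol=1e-4):
    """Compare the AD derivative against a central finite difference."""
    fd = (f(x + eps) - f(x - eps)) / (2 * eps)
    return abs(fd - df(x)) < tol

ok = check(0.7)
```

Finite differences are too inaccurate and too expensive to replace AD, but they make a useful independent test of a transformation's correctness at sampled points.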
Summary
Automatic differentiation is best understood as program transformation.
Forward mode transforms each value into a primal-tangent pair.
Reverse mode transforms evaluation into a forward recording pass plus a backward adjoint pass.
The transformation may operate on source code, traces, compiler IR, tensor graphs, or runtime tapes.
This perspective explains why AD is both mathematical and systems-oriented: it depends on the chain rule, but it must also handle real programming language features such as control flow, mutation, types, memory, and external calls.