# AD as Program Transformation

Automatic differentiation can be understood as a transformation from one program into another program.

The original program computes values.

The transformed program computes values and derivatives.

This view is important because real AD systems operate on code, traces, graphs, or compiler intermediate representations. They are not merely applying calculus rules on paper. They rewrite computation.

## From Mathematical Function to Program

A mathematical function may be written as:

$$
f(x)=\sin(x^2+3x)
$$

A program evaluates it as a sequence of operations:

```text
v1 = x * x
v2 = 3 * x
v3 = v1 + v2
v4 = sin(v3)
return v4
```

Automatic differentiation transforms this program into another program.

For forward mode, the transformed program carries tangent values.

For reverse mode, the transformed program records operations and later propagates adjoints.

## Forward-Mode Transformation

Forward mode rewrites every variable into a pair:

$$
v \mapsto (v,\dot v)
$$

The primal component stores the ordinary value.

The tangent component stores the directional derivative of that value along the chosen input direction $\dot x$.

Original operation:

$$
z=x+y
$$

becomes:

$$
(z,\dot z)=(x+y,\dot x+\dot y)
$$

Original operation:

$$
z=xy
$$

becomes:

$$
(z,\dot z)=(xy,y\dot x+x\dot y)
$$

For the example:

```text
v1 = x * x
v2 = 3 * x
v3 = v1 + v2
v4 = sin(v3)
return v4
```

the forward-transformed program is:

```text
v1     = x * x
dot_v1 = x * dot_x + x * dot_x

v2     = 3 * x
dot_v2 = 3 * dot_x

v3     = v1 + v2
dot_v3 = dot_v1 + dot_v2

v4     = sin(v3)
dot_v4 = cos(v3) * dot_v3

return v4, dot_v4
```

The program still runs in the original order. The derivative computation is interleaved with the primal computation.
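As a minimal sketch, the transformed listing above can be written out as runnable Python and compared against the closed-form derivative $\cos(x^2+3x)(2x+3)$. The function names `f` and `f_forward` and the test point `x = 0.5` are illustrative choices, not part of any particular AD system.

```python
import math

def f(x):
    """Original program: f(x) = sin(x^2 + 3x)."""
    v1 = x * x
    v2 = 3 * x
    v3 = v1 + v2
    return math.sin(v3)

def f_forward(x, dot_x):
    """Hand-transformed program: each primal line is followed by its tangent line."""
    v1 = x * x
    dot_v1 = x * dot_x + x * dot_x

    v2 = 3 * x
    dot_v2 = 3 * dot_x

    v3 = v1 + v2
    dot_v3 = dot_v1 + dot_v2

    v4 = math.sin(v3)
    dot_v4 = math.cos(v3) * dot_v3

    return v4, dot_v4

x = 0.5
value, tangent = f_forward(x, 1.0)  # seeding dot_x = 1 yields df/dx
closed_form = math.cos(x * x + 3 * x) * (2 * x + 3)
print(value, tangent, closed_form)
```

Seeding `dot_x` with a value other than 1 scales the result, which is how a directional derivative is obtained when there are several inputs.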

## Reverse-Mode Transformation

Reverse mode rewrites the program into two phases.

The first phase runs forward and records enough information for differentiation.

The second phase runs backward and accumulates adjoints.

Original program:

```text
v1 = x * x
v2 = 3 * x
v3 = v1 + v2
v4 = sin(v3)
return v4
```

Reverse-transformed program:

```text
# forward pass
v1 = x * x
record(mul, x, x, v1)

v2 = 3 * x
record(mul_const, 3, x, v2)

v3 = v1 + v2
record(add, v1, v2, v3)

v4 = sin(v3)
record(sin, v3, v4)

# reverse pass
bar_v4 = 1                              # seed the output adjoint
bar_v3 = bar_v2 = bar_v1 = bar_x = 0    # all other adjoints start at zero

backward sin:
    bar_v3 += cos(v3) * bar_v4

backward add:
    bar_v1 += bar_v3
    bar_v2 += bar_v3

backward mul_const:
    bar_x += 3 * bar_v2

backward mul:
    bar_x += x * bar_v1
    bar_x += x * bar_v1

return v4, bar_x
```

This transformation changes evaluation structure more than forward mode. Reverse mode must preserve information from the forward pass so that local derivative rules can be evaluated later.
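The two phases can be sketched as straight-line Python. This is a hand translation under simplifying assumptions: the "recording" is done implicitly by keeping the forward-pass locals alive, whereas a real system would store them on a tape or in generated code. The name `f_reverse` is illustrative.

```python
import math

def f_reverse(x):
    # Forward pass: compute and keep the intermediates the backward rules need.
    v1 = x * x
    v2 = 3 * x
    v3 = v1 + v2
    v4 = math.sin(v3)

    # Reverse pass: adjoints start at zero, the output adjoint is seeded with 1.
    bar_x = bar_v1 = bar_v2 = bar_v3 = 0.0
    bar_v4 = 1.0

    bar_v3 += math.cos(v3) * bar_v4   # backward sin (uses the saved v3)
    bar_v1 += bar_v3                  # backward add
    bar_v2 += bar_v3
    bar_x += 3 * bar_v2               # backward mul_const
    bar_x += x * bar_v1               # backward mul, first operand
    bar_x += x * bar_v1               # backward mul, second operand

    return v4, bar_x

value, grad = f_reverse(0.5)
print(value, grad)
```

Note that `v3` is read in the reverse pass long after it was computed, which is exactly the information that must survive the forward pass.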

## Program Transformation Targets

AD can transform different representations.

| Target | Description |
|---|---|
| Source code | Rewrite the user program directly |
| Operator-overloading trace | Record operations during execution |
| Bytecode | Transform lower-level executable instructions |
| Compiler IR | Transform SSA, graph IR, MLIR, or LLVM IR |
| Tensor graph | Transform tensor-level dataflow graph |

Each target gives different tradeoffs.

Source transformation can produce efficient code, but handling a full programming language is difficult.

Operator overloading is simpler to implement, but it may have higher runtime overhead.

Compiler IR transformation enables optimization, but requires integration with compiler infrastructure.

## Source Transformation

Source transformation generates derivative source code from ordinary source code.

Example:

```text
func f(x):
    return sin(x * x)
```

may become:

```text
func f_forward(x, dot_x):
    v1 = x * x
    dot_v1 = x * dot_x + x * dot_x

    y = sin(v1)
    dot_y = cos(v1) * dot_v1

    return y, dot_y
```

This approach has several advantages:
- derivative code is explicit
- compiler optimizations can apply
- no runtime tracing is required
- generated code can be inspected

But it must handle:
- function calls
- loops
- mutation
- aliasing
- closures
- exceptions
- external libraries

Source transformation is powerful but language dependent.
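To make the idea concrete, here is a toy source-to-source transformer built on Python's standard `ast` module. It is a sketch under strong assumptions: the input function contains only single-target assignments whose right-hand side is one `+`, one `*`, or a call to `sin` on atomic operands, followed by `return <name>`. The helper names (`tangent_of`, `tangent_rule`, `forward_transform`) are hypothetical, and `ast.unparse` requires Python 3.9 or later.

```python
import ast
import math

def tangent_of(node):
    """Source text for the tangent of an atomic operand (a name or a constant)."""
    if isinstance(node, ast.Name):
        return f"dot_{node.id}"
    if isinstance(node, ast.Constant):
        return "0.0"
    raise NotImplementedError("operands must be names or constants")

def tangent_rule(value):
    """Source text for the tangent of one primitive operation."""
    if isinstance(value, ast.BinOp):
        l, r = ast.unparse(value.left), ast.unparse(value.right)
        dl, dr = tangent_of(value.left), tangent_of(value.right)
        if isinstance(value.op, ast.Add):
            return f"{dl} + {dr}"
        if isinstance(value.op, ast.Mult):
            return f"{r} * {dl} + {l} * {dr}"
    if isinstance(value, ast.Call) and ast.unparse(value.func) == "sin":
        arg = value.args[0]
        return f"cos({ast.unparse(arg)}) * {tangent_of(arg)}"
    raise NotImplementedError(ast.unparse(value))

def forward_transform(source):
    """Generate forward-mode source text for a restricted straight-line function."""
    fn = ast.parse(source).body[0]
    params = [a.arg for a in fn.args.args]
    header = ", ".join(params + [f"dot_{p}" for p in params])
    lines = [f"def {fn.name}_forward({header}):"]
    for stmt in fn.body:
        if isinstance(stmt, ast.Assign):
            target = stmt.targets[0].id
            lines.append(f"    {ast.unparse(stmt)}")
            lines.append(f"    dot_{target} = {tangent_rule(stmt.value)}")
        elif isinstance(stmt, ast.Return):
            name = stmt.value.id
            lines.append(f"    return {name}, dot_{name}")
        else:
            raise NotImplementedError(ast.unparse(stmt))
    return "\n".join(lines)

source = """
def f(x):
    v1 = x * x
    y = sin(v1)
    return y
"""
generated = forward_transform(source)
print(generated)

# Compile the generated text and call it like ordinary code.
namespace = {"sin": math.sin, "cos": math.cos}
exec(generated, namespace)
value, tangent = namespace["f_forward"](3.0, 1.0)
```

Every construct outside the supported subset raises `NotImplementedError`, which is a small-scale picture of why covering a full language is the hard part.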

## Operator Overloading

Operator overloading changes the meaning of primitive operations for special AD values.

Instead of using ordinary numbers, the program uses objects such as:

```text
Dual(value, tangent)
```

Then ordinary operations are overloaded:

```text
Dual(x, dx) * Dual(y, dy)
=
Dual(x * y, x * dy + y * dx)
```

The user program may remain almost unchanged.

Example:

```text
x = Dual(3, 1)
y = sin(x * x)
```

After these lines run, `y` holds both:
- the primal value $\sin(9)$
- the derivative value $6\cos(9)$

Operator overloading is common because it is easy to embed in host languages.

Its limits are:
- runtime overhead
- less compiler visibility
- difficulty optimizing across operation boundaries
- possible interaction problems with mutation and control flow
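A minimal dual-number sketch in Python shows how little machinery the approach needs. The class layout, the `_lift` helper for promoting plain numbers, and the free function `sin` are assumptions made for illustration; only `+`, `*`, and `sin` are covered.

```python
import math
from dataclasses import dataclass

@dataclass(frozen=True)
class Dual:
    value: float
    tangent: float

    def __add__(self, other):
        other = _lift(other)
        return Dual(self.value + other.value, self.tangent + other.tangent)

    __radd__ = __add__  # addition is commutative

    def __mul__(self, other):
        other = _lift(other)
        return Dual(self.value * other.value,
                    self.value * other.tangent + other.value * self.tangent)

    __rmul__ = __mul__  # multiplication is commutative

def _lift(x):
    """Promote a plain number to a Dual with zero tangent (a constant)."""
    return x if isinstance(x, Dual) else Dual(float(x), 0.0)

def sin(x):
    """Overloaded sine: propagates cos(value) * tangent."""
    x = _lift(x)
    return Dual(math.sin(x.value), math.cos(x.value) * x.tangent)

x = Dual(3.0, 1.0)
y = sin(x * x)
print(y.value, y.tangent)

# The same user code works for the earlier example f(x) = sin(x^2 + 3x).
z = sin(x * x + 3 * x)
```

Each arithmetic operation allocates a new `Dual` object, which illustrates the runtime overhead mentioned in the list above.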

## Tracing

Tracing records operations while the program executes.

A tracer replaces ordinary values with trace values. Each primitive operation appends a node to a graph or tape.

Example:

```text
x = trace_input("x")
v1 = x * x        # records mul
v2 = sin(v1)      # records sin
```

The trace can then be transformed into derivative code.

Tracing works well for numerical programs whose control flow does not depend on input values, so that a single recorded execution is representative of all executions.

However, tracing captures the executed path. If control flow depends on input values, different inputs may produce different traces.
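The path dependence can be demonstrated with a small tracer. This is a sketch: the `Traced` class, the tuple layout of tape entries, and the example `program` are all hypothetical, and only `*`, `+`, and `>` are supported.

```python
class Traced:
    """A value that appends every operation applied to it to a shared tape."""

    def __init__(self, value, name, tape):
        self.value, self.name, self.tape = value, name, tape

    def _record(self, op, other, value):
        other_name = other.name if isinstance(other, Traced) else repr(other)
        out = Traced(value, f"v{len(self.tape) + 1}", self.tape)
        self.tape.append((out.name, op, self.name, other_name))
        return out

    def __mul__(self, other):
        o = other.value if isinstance(other, Traced) else other
        return self._record("mul", other, self.value * o)

    def __add__(self, other):
        o = other.value if isinstance(other, Traced) else other
        return self._record("add", other, self.value + o)

    def __gt__(self, other):
        # Comparisons return a plain bool: the branch is decided by the
        # concrete value and never appears on the tape.
        return self.value > other

def program(x):
    if x > 0:
        return x * x
    return x + 1

def trace(x0):
    tape = []
    program(Traced(x0, "x", tape))
    return tape

print(trace(2.0))   # [('v1', 'mul', 'x', 'x')]
print(trace(-2.0))  # [('v1', 'add', 'x', '1')]
```

Neither tape contains the `if`, so a derivative program built from one of them is only valid for inputs that take the same branch.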

## Compiler IR Transformation

Modern AD systems often transform intermediate representations.

A compiler may lower code into SSA form:

```text
v1 = mul x x
v2 = sin v1
return v2
```

AD can then operate on this lower-level form.

Benefits:
- every value has a unique definition
- data dependencies are explicit
- optimization passes already exist
- reverse-mode adjoints can be inserted systematically
- dead derivative code can be eliminated

SSA form is especially useful for reverse mode because reassignment of a variable is converted into distinct value versions, so each adjoint can be attached to exactly one definition.
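As a sketch of how adjoints can be inserted systematically, the following treats an SSA program as a list of `(result, op, operands)` tuples and walks it in reverse. The tuple encoding and the function names `run` and `reverse` are assumptions for illustration and do not correspond to any real compiler IR; only `mul` and `sin` are handled.

```python
import math

# SSA program for sin(x * x): every result name is defined exactly once.
program = [
    ("v1", "mul", ("x", "x")),
    ("v2", "sin", ("v1",)),
]

def run(program, inputs):
    """Evaluate the program and return the environment of all SSA values."""
    env = dict(inputs)
    for result, op, operands in program:
        args = [env[name] for name in operands]
        if op == "mul":
            env[result] = args[0] * args[1]
        elif op == "sin":
            env[result] = math.sin(args[0])
        else:
            raise NotImplementedError(op)
    return env

def reverse(program, env, output):
    """Accumulate one adjoint per SSA value by visiting instructions backward."""
    bar = {name: 0.0 for name in env}
    bar[output] = 1.0
    for result, op, operands in reversed(program):
        if op == "mul":
            a, b = operands
            bar[a] += env[b] * bar[result]
            bar[b] += env[a] * bar[result]
        elif op == "sin":
            (a,) = operands
            bar[a] += math.cos(env[a]) * bar[result]
    return bar

env = run(program, {"x": 3.0})
bar = reverse(program, env, "v2")
print(env["v2"], bar["x"])
```

Because each name has a single definition, a dictionary keyed by name is enough to hold the adjoints, and the two `+=` updates for `mul x x` naturally sum to $2x\,\bar v_1$.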

## AD and SSA

In static single assignment form, each variable is assigned once.

Imperative code:

```text
x = x * x
x = x + 1
```

becomes:

```text
x1 = x0 * x0
x2 = x1 + 1
```

This makes derivative propagation clearer.

Reverse mode can assign adjoints to value versions:

```text
bar_x2 = 1
bar_x1 += bar_x2
bar_x0 += 2 * x0 * bar_x1
```

Without SSA, the name `x` would refer to several different program states, which complicates reverse accumulation.

## Transformation and Optimization

Once derivative code exists, normal compiler optimization becomes important.

Common optimizations include:
- common subexpression elimination
- dead code elimination
- algebraic simplification
- memory reuse
- fusion
- inlining
- checkpoint placement
- layout optimization

Derivative code can be larger than primal code, and a mechanical transformation often emits redundant work. Optimization removes much of that overhead.

For example:

```text
dot_v1 = x * dot_x + x * dot_x
```

can be simplified to:

```text
dot_v1 = 2 * x * dot_x
```

Similarly, if only some inputs require derivatives, unused adjoint computations can be removed.
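Removing unused adjoint computations is ordinary dead code elimination. A minimal sketch, assuming an SSA-style instruction list of `(result, op, operands)` tuples (a made-up encoding) for the reverse program of $z = xy$ where only $\bar x$ is requested:

```python
def eliminate_dead_code(instructions, outputs):
    """Keep only instructions whose result is needed, directly or
    transitively, by one of the requested outputs.

    Assumes SSA form: each result name is defined by exactly one instruction.
    """
    live = set(outputs)
    kept = []
    for result, op, operands in reversed(instructions):
        if result in live:
            kept.append((result, op, operands))
            live.update(operands)  # operands of a live instruction are live
    kept.reverse()
    return kept

reverse_program = [
    ("z",     "mul",   ("x", "y")),
    ("bar_z", "const", ()),            # seed: bar_z = 1
    ("bar_x", "mul",   ("y", "bar_z")),
    ("bar_y", "mul",   ("x", "bar_z")),
]

# Only the derivative with respect to x is requested.
pruned = eliminate_dead_code(reverse_program, outputs=["z", "bar_x"])
for instruction in pruned:
    print(instruction)
```

The `bar_y` instruction disappears because nothing live depends on it. In a larger program, the same backward sweep would also drop any primal values that were kept only to serve the removed adjoints.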

## Transformation Boundaries

AD systems must decide what to transform.

A program may call:
- built-in primitives
- user-defined functions
- library functions
- foreign functions
- system APIs

For each call, the AD system needs one of:

| Case | Handling |
|---|---|
| Differentiable primitive | use registered derivative rule |
| User function | recursively transform |
| Library function | use custom derivative rule |
| Opaque external call | stop differentiation or require user rule |
| Non-differentiable operation | return zero, error, or subgradient |

Boundary handling determines how useful an AD system is in real programs.

## Custom Derivative Rules

Some functions should not be differentiated by expanding their implementation.

Example:

$$
\operatorname{logsumexp}(x) =
\log\sum_i e^{x_i}
$$

A naive expansion may be numerically unstable.

A system should use a stable primitive and a custom derivative rule.

Custom rules are also useful for:
- solvers
- linear algebra decompositions
- special functions
- external kernels
- probabilistic operations
- implicit functions

A transformation-based AD system needs an extension mechanism for such rules.
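For `logsumexp`, a stable primitive and its custom reverse rule might look like the sketch below. The gradient of logsumexp is the softmax of its input, which is a standard identity. The function names, the list-of-floats representation, and the test input are illustrative assumptions.

```python
import math

def naive_logsumexp(xs):
    """Direct transcription of the formula; exp overflows for large inputs."""
    return math.log(sum(math.exp(x) for x in xs))

def logsumexp(xs):
    """Stable primitive: subtract the maximum before exponentiating."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def logsumexp_vjp(xs, bar_y):
    """Custom reverse rule: bar_x[i] = bar_y * softmax(xs)[i].

    Written in terms of the stable primitive, so no large exponentials occur.
    """
    lse = logsumexp(xs)
    return [bar_y * math.exp(x - lse) for x in xs]

xs = [1000.0, 1001.0, 1002.0]

try:
    naive_logsumexp(xs)
    naive_overflowed = False
except OverflowError:
    naive_overflowed = True

value = logsumexp(xs)
grad = logsumexp_vjp(xs, 1.0)
print(naive_overflowed, value, grad)
```

Differentiating `naive_logsumexp` operation by operation would inherit its overflow, while the custom rule stays finite and its entries sum to one.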

## Transformation of Control Flow

Forward mode usually preserves control flow directly.

Reverse mode must invert control dependencies carefully.

For a loop:

```text
for i in range(n):
    x = step(x, i)
```

reverse mode often stores all intermediate states:

```text
x0, x1, ..., xn
```

Then it runs backward:

```text
for i in reversed(range(n)):
    propagate through step(x_i, i)
```

This is why long loops can be memory intensive.

Checkpointing changes the transformed program by recomputing some intermediate states instead of storing them all.
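The store-everything and checkpointed variants can be compared in a short sketch. The `step` function, its hand-written rule `step_vjp`, and the fixed checkpoint spacing `every` are assumptions chosen for illustration; practical checkpointing schedules are more sophisticated.

```python
import math

def step(x, i):
    return math.sin(x) + 0.1 * i

def step_vjp(x, i, bar_out):
    """Adjoint of step with respect to x: d/dx [sin(x) + 0.1 * i] = cos(x)."""
    return math.cos(x) * bar_out

def loop_reverse(x0, n):
    """Store every intermediate state, then sweep backward."""
    states = [x0]
    x = x0
    for i in range(n):
        x = step(x, i)
        states.append(x)
    bar_x = 1.0
    for i in reversed(range(n)):
        bar_x = step_vjp(states[i], i, bar_x)
    return x, bar_x

def loop_reverse_checkpointed(x0, n, every):
    """Store one state per segment and recompute the rest during the backward sweep."""
    checkpoints = {}
    x = x0
    for i in range(n):
        if i % every == 0:
            checkpoints[i] = x  # state at the start of iteration i
        x = step(x, i)
    bar_x = 1.0
    for start in reversed(range(0, n, every)):
        end = min(start + every, n)
        # Recompute the inputs x_start, ..., x_{end-1} of this segment.
        segment = [checkpoints[start]]
        for i in range(start, end - 1):
            segment.append(step(segment[-1], i))
        for i in reversed(range(start, end)):
            bar_x = step_vjp(segment[i - start], i, bar_x)
    return x, bar_x

full = loop_reverse(0.3, 10)
ckpt = loop_reverse_checkpointed(0.3, 10, every=3)
print(full, ckpt)
```

With `n = 10` and `every = 3`, the checkpointed version keeps 4 stored states plus at most 3 recomputed ones at a time, instead of 11, at the cost of roughly one extra forward evaluation per step.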

## Transformation of Mutation

Mutation requires careful semantics.

Example:

```text
a[i] = a[i] + x
```

The derivative of this operation depends on:
- previous value of `a[i]`
- aliasing
- whether other references observe the mutation
- order of writes

A robust AD compiler must model mutation explicitly.

Common strategies:
- convert mutation into pure updates
- use SSA value versions
- record mutation logs
- restrict in-place mutation
- define custom adjoints for update operations

In-place mutation can improve performance, but it complicates correctness.
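The mutation-log strategy can be sketched for two in-place updates on the same array slot. Everything here is an illustrative assumption: the function names, the choice to log the overwritten value before each write, and the second update `a[0] = a[0] * x`, which is added so that the backward pass actually needs a value that was overwritten.

```python
def run(a, x):
    """Forward pass with in-place updates. Returns a log of overwritten values."""
    log = []
    log.append(a[0])
    a[0] = a[0] + x      # update 1
    log.append(a[0])
    a[0] = a[0] * x      # update 2
    return log

def run_backward(a, x, bar_a, log):
    """Undo the updates in reverse order while propagating adjoints.

    bar_a holds the adjoint of the final array and is updated in place to
    become the adjoint of the initial array. Returns the adjoint of x.
    """
    bar_x = 0.0

    # Reverse update 2: new = old * x
    old = log.pop()
    bar_x += old * bar_a[0]
    bar_a[0] = x * bar_a[0]
    a[0] = old           # restore the pre-update value

    # Reverse update 1: new = old + x
    old = log.pop()
    bar_x += bar_a[0]
    # bar_a[0] is unchanged: the new a[0] depends on the old a[0] with derivative 1.
    a[0] = old

    return bar_x

a = [2.0, 5.0]
x = 3.0
log = run(a, x)                 # a[0] is now (2 + 3) * 3 = 15
bar_a = [1.0, 0.0]              # differentiate the final a[0]
bar_x = run_backward(a, x, bar_a, log)
print(bar_x, bar_a, a)
```

The final `a[0]` equals $(a_0 + x)\,x$, so the expected results are $\bar x = a_0 + 2x$ and $\bar a_0 = x$. Aliasing is not modeled here; two references to the same storage would need to share both the log and the adjoint slot.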

## Transformation of Data Structures

Programs use data structures, not only scalars and tensors.

An AD system may need to handle:
- arrays
- structs
- tuples
- trees
- dictionaries
- sparse matrices
- graphs

The transformation must map each differentiable field to a corresponding tangent or adjoint field.

Example:

```text
State {
    position
    velocity
    mass
}
```

If `mass` is constant and `position`, `velocity` are differentiable, the tangent state contains only the relevant derivative fields.

This is where type systems help.
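One way to express this in code is to give the tangent its own type. The sketch below uses Python dataclasses; the `StateTangent` type, the `advance` function, and its explicit-Euler update are assumptions made for illustration.

```python
from dataclasses import dataclass

@dataclass
class State:
    position: float
    velocity: float
    mass: float          # treated as a constant, not differentiated

@dataclass
class StateTangent:
    position: float
    velocity: float
    # No mass field: a constant has no tangent.

def advance(s, dt):
    """Move the state forward by one explicit-Euler step with no forces."""
    return State(s.position + dt * s.velocity, s.velocity, s.mass)

def advance_forward(s, ds, dt):
    """Forward-mode version: maps (State, StateTangent) to (State, StateTangent)."""
    out = advance(s, dt)
    dout = StateTangent(ds.position + dt * ds.velocity, ds.velocity)
    return out, dout

s = State(position=1.0, velocity=2.0, mass=3.0)
ds = StateTangent(position=0.0, velocity=1.0)   # direction: perturb velocity only
out, dout = advance_forward(s, ds, dt=0.5)
print(out, dout)
```

Attempting to read `dout.mass` fails, so the absence of a derivative for `mass` is visible in the structure of the type rather than being represented by a stored zero.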

## Type-Level View

Forward mode transforms a type $T$ into a tangent bundle type:

$$
T \mapsto T \times \dot T
$$

Reverse mode transforms computations so that each type has an associated cotangent type:

$$
T \mapsto T^*
$$

For simple arrays, the tangent and cotangent types often match the primal type.

For structured objects, they may differ.

Examples:
- integer indices have no tangent
- booleans have no tangent
- floating arrays have floating tangents
- sparse structures may have sparse or dense cotangents

A well-designed AD system needs precise derivative types.

## AD as a Compiler Pass

In a compiler pipeline, AD can be placed as a transformation pass.

A simplified pipeline:

```text
source program
    ↓
typed AST
    ↓
intermediate representation
    ↓
AD transformation
    ↓
optimization
    ↓
lowering
    ↓
machine code / kernel code
```

This view treats differentiation as compilation, not runtime magic.

The AD pass introduces tangent or adjoint computations. Later passes optimize and lower them.

## Correctness of Transformation

A program transformation is correct if the transformed program computes the derivative of the function computed by the original program, at every input where that function is differentiable.

For forward mode, correctness means:

$$
\text{forward}(p)(x,\dot x) =
(p(x),Dp(x)\dot x)
$$

For reverse mode, correctness means:

$$
\text{reverse}(p)(x,\bar y) =
(p(x),Dp(x)^T\bar y)
$$

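These equations define the semantic contract of AD. Here $Dp(x)$ is the Jacobian of $p$ at $x$: forward mode returns a Jacobian-vector product and reverse mode a vector-Jacobian product.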

Everything else is implementation.
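The two contracts can be checked numerically. A minimal sketch, assuming a small hand-differentiated function $p(x_0, x_1) = (x_0 x_1, \sin x_0)$: the forward rule is compared with a central finite difference, and the reverse rule is checked against the forward rule through the identity $\langle \bar y, Dp(x)\dot x\rangle = \langle Dp(x)^T\bar y, \dot x\rangle$. All names and test values are illustrative.

```python
import math

def p(x):
    return [x[0] * x[1], math.sin(x[0])]

def p_forward(x, dot_x):
    """Returns (p(x), Dp(x) dot_x)."""
    dot_y = [x[1] * dot_x[0] + x[0] * dot_x[1],
             math.cos(x[0]) * dot_x[0]]
    return p(x), dot_y

def p_reverse(x, bar_y):
    """Returns (p(x), Dp(x)^T bar_y)."""
    bar_x = [x[1] * bar_y[0] + math.cos(x[0]) * bar_y[1],
             x[0] * bar_y[0]]
    return p(x), bar_x

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

x = [0.7, -1.3]
dot_x = [0.4, 2.0]
bar_y = [1.5, -0.5]

_, dot_y = p_forward(x, dot_x)
_, bar_x = p_reverse(x, bar_y)

# Forward contract: compare with a central finite difference along dot_x.
h = 1e-6
x_plus = [xi + h * di for xi, di in zip(x, dot_x)]
x_minus = [xi - h * di for xi, di in zip(x, dot_x)]
fd = [(a - b) / (2 * h) for a, b in zip(p(x_plus), p(x_minus))]

# Reverse contract: both sides of the adjoint identity must agree.
lhs = dot(bar_y, dot_y)
rhs = dot(bar_x, dot_x)
print(fd, dot_y, lhs, rhs)
```

The finite-difference comparison only holds up to truncation and rounding error, while the adjoint identity should hold to near machine precision, which makes it a useful test for reverse-mode rules.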

## Summary

Automatic differentiation is best understood as program transformation.

Forward mode transforms each value into a primal-tangent pair.

Reverse mode transforms evaluation into a forward recording pass plus a backward adjoint pass.

The transformation may operate on source code, traces, compiler IR, tensor graphs, or runtime tapes.

This perspective explains why AD is both mathematical and systems-oriented: it depends on the chain rule, but it must also handle real programming language features such as control flow, mutation, types, memory, and external calls.

