Chapter 20. Building an AD Engine

Minimal Forward Mode Engine

A minimal forward mode automatic differentiation engine has one job: evaluate a program while carrying both a value and its derivative. The engine does not build a graph. It does not store a tape. It computes derivatives in the same order as the original computation.

Forward mode is best understood as ordinary evaluation over a richer number type.

Instead of evaluating with real numbers, we evaluate with pairs:

$$x \mapsto (x, \dot{x})$$

The first component is the primal value. The second component is the tangent value. The tangent records how the output changes when the input is perturbed in a chosen direction.

For a scalar function

$$f : \mathbb{R} \to \mathbb{R}$$

we initialize the input as

$$x \mapsto (x, 1)$$

Then the final tangent is exactly $f'(x)$.
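As a quick check, take $f(x) = x \cdot x$ at $x = 3$ and multiply the pairs with the product rule $(a, \dot{a}) \cdot (b, \dot{b}) = (ab,\; \dot{a}b + a\dot{b})$:

$$
(3, 1) \cdot (3, 1) = (3 \cdot 3,\; 1 \cdot 3 + 3 \cdot 1) = (9, 6)
$$

The primal 9 is $f(3)$, and the tangent 6 is $f'(3) = 2 \cdot 3$.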

For a multivariate function

$$f : \mathbb{R}^n \to \mathbb{R}$$

the tangent represents a direction vector $v$. Forward mode computes the Jacobian-vector product:

$$J_f(x)\,v$$

This is the central operation of forward mode.
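A minimal Go sketch of a Jacobian-vector product. The example map H(x, y) = (xy, x + y) and the JVP helper are illustrative assumptions, not part of the engine described in this chapter. The only change from scalar differentiation is that each input tangent is seeded with a component of $v$ instead of 1:

```go
package main

import "fmt"

// Dual carries a primal value and a tangent (directional derivative).
type Dual struct {
	Value float64
	Deriv float64
}

func Add(a, b Dual) Dual {
	return Dual{Value: a.Value + b.Value, Deriv: a.Deriv + b.Deriv}
}

func Mul(a, b Dual) Dual {
	return Dual{Value: a.Value * b.Value, Deriv: a.Deriv*b.Value + a.Value*b.Deriv}
}

// H is a hypothetical example map from R^2 to R^2: H(x, y) = (x*y, x + y).
func H(x, y Dual) (Dual, Dual) {
	return Mul(x, y), Add(x, y)
}

// JVP evaluates H at (x, y) with the input tangents seeded by v = (vx, vy)
// and returns the output tangents, which form J_H(x, y) v.
func JVP(x, y, vx, vy float64) [2]float64 {
	u, w := H(Dual{Value: x, Deriv: vx}, Dual{Value: y, Deriv: vy})
	return [2]float64{u.Deriv, w.Deriv}
}

func main() {
	// J_H(2, 3) = [[3, 2], [1, 1]], so J_H(2, 3) (1, 2) = (3 + 4, 1 + 2).
	fmt.Println(JVP(2, 3, 1, 2)) // [7 3]
}
```

One pass yields the whole product $J_H v$ without ever forming the matrix $J_H$.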

The Core Data Type

A minimal engine can start with a single type:

```go
type Dual struct {
    Value float64
    Deriv float64
}
```

Value is the ordinary numeric value. Deriv is the derivative with respect to the chosen seed direction.

For scalar differentiation, the input variable receives derivative 1.

```go
func Var(x float64) Dual {
    return Dual{
        Value: x,
        Deriv: 1,
    }
}
```

Constants receive derivative 0.

```go
func Const(x float64) Dual {
    return Dual{
        Value: x,
        Deriv: 0,
    }
}
```

This distinction matters. A variable changes when the input changes. A constant does not.
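A small sketch of what goes wrong when a constant is seeded like a variable. The program repeats the Dual, Var, Const, and Mul definitions from this section so that it runs on its own:

```go
package main

import "fmt"

type Dual struct {
	Value float64
	Deriv float64
}

func Var(x float64) Dual   { return Dual{Value: x, Deriv: 1} }
func Const(x float64) Dual { return Dual{Value: x, Deriv: 0} }

func Mul(a, b Dual) Dual {
	return Dual{Value: a.Value * b.Value, Deriv: a.Deriv*b.Value + a.Value*b.Deriv}
}

func main() {
	// d/dx (3x) at x = 2 is 3: the factor 3 does not move with x.
	right := Mul(Const(3), Var(2))
	// Seeding 3 as a variable differentiates along a second input as well:
	// d/dt (3 + t)(2 + t) at t = 0 is 2 + 3 = 5, not the intended 3.
	wrong := Mul(Var(3), Var(2))
	fmt.Println(right.Deriv, wrong.Deriv) // 3 5
}
```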

Arithmetic Rules

Each primitive operation must define how values and derivatives propagate.

For addition:

```go
func Add(a, b Dual) Dual {
    return Dual{
        Value: a.Value + b.Value,
        Deriv: a.Deriv + b.Deriv,
    }
}
```

For subtraction:

```go
func Sub(a, b Dual) Dual {
    return Dual{
        Value: a.Value - b.Value,
        Deriv: a.Deriv - b.Deriv,
    }
}
```

For multiplication:

```go
func Mul(a, b Dual) Dual {
    return Dual{
        Value: a.Value * b.Value,
        Deriv: a.Deriv*b.Value + a.Value*b.Deriv,
    }
}
```

For division:

```go
func Div(a, b Dual) Dual {
    return Dual{
        Value: a.Value / b.Value,
        Deriv: (a.Deriv*b.Value - a.Value*b.Deriv) / (b.Value * b.Value),
    }
}
```

These are just the ordinary derivative rules encoded as executable code.
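These rules compose. As an illustrative check (Ratio is a hypothetical example function), $h(x) = (x - 1)/(x + 1)$ has $h'(x) = 2/(x + 1)^2$, so the engine should report 0.5 at $x = 1$ and 0.125 at $x = 3$:

```go
package main

import "fmt"

type Dual struct{ Value, Deriv float64 }

func Var(x float64) Dual   { return Dual{Value: x, Deriv: 1} }
func Const(x float64) Dual { return Dual{Value: x, Deriv: 0} }

func Add(a, b Dual) Dual { return Dual{Value: a.Value + b.Value, Deriv: a.Deriv + b.Deriv} }
func Sub(a, b Dual) Dual { return Dual{Value: a.Value - b.Value, Deriv: a.Deriv - b.Deriv} }

// Div applies the quotient rule to the tangents.
func Div(a, b Dual) Dual {
	return Dual{
		Value: a.Value / b.Value,
		Deriv: (a.Deriv*b.Value - a.Value*b.Deriv) / (b.Value * b.Value),
	}
}

// Ratio computes h(x) = (x - 1) / (x + 1) from Sub, Add, and Div.
func Ratio(x Dual) Dual {
	return Div(Sub(x, Const(1)), Add(x, Const(1)))
}

func main() {
	y := Ratio(Var(1))
	fmt.Println(y.Value, y.Deriv) // 0 0.5
	z := Ratio(Var(3))
	fmt.Println(z.Value, z.Deriv) // 0.5 0.125
}
```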

Elementary Functions

The same pattern extends to functions such as sin, cos, exp, and log.

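```go
import "math"

// Each rule computes the primal value and scales the incoming
// tangent x.Deriv by the local derivative f'(x.Value).
func Sin(x Dual) Dual {
    return Dual{
        Value: math.Sin(x.Value),
        Deriv: math.Cos(x.Value) * x.Deriv,
    }
}

func Cos(x Dual) Dual {
    return Dual{
        Value: math.Cos(x.Value),
        Deriv: -math.Sin(x.Value) * x.Deriv,
    }
}

func Exp(x Dual) Dual {
    e := math.Exp(x.Value) // reused: exp is its own derivative
    return Dual{
        Value: e,
        Deriv: e * x.Deriv,
    }
}

// Log is only defined for x.Value > 0.
func Log(x Dual) Dual {
    return Dual{
        Value: math.Log(x.Value),
        Deriv: x.Deriv / x.Value,
    }
}
```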

Each rule has the same shape:

$$y = f(x), \qquad \dot{y} = f'(x)\,\dot{x}$$

Forward mode applies the chain rule locally at every operation.

A Complete Example

Consider:

$$f(x) = x^2 + 3x + 2$$

In Go:

```go
func F(x Dual) Dual {
    return Add(
        Add(
            Mul(x, x),
            Mul(Const(3), x),
        ),
        Const(2),
    )
}
```

Evaluate at $x = 5$:

```go
import "fmt"

func main() {
    y := F(Var(5))

    fmt.Println("value:", y.Value)
    fmt.Println("derivative:", y.Deriv)
}
```

The result is:

```text
value: 42
derivative: 13
```

The derivative is correct because:

$$f'(x) = 2x + 3$$

and therefore:

$$f'(5) = 13$$

The engine never constructs the symbolic expression $2x + 3$. It also never estimates the derivative using finite differences. It computes the derivative exactly through the program structure, subject only to floating-point arithmetic.

Multivariate Inputs

For a function

$$f(x, y) = xy + \sin(x)$$

we can compute partial derivatives by choosing different seeds.

```go
func G(x, y Dual) Dual {
    return Add(
        Mul(x, y),
        Sin(x),
    )
}
```

To compute $\partial f / \partial x$, seed x with 1 and y with 0.

```go
x := Dual{Value: 2, Deriv: 1}
y := Dual{Value: 3, Deriv: 0}

out := G(x, y)
fmt.Println(out.Deriv) // ∂f/∂x at (2, 3): 3 + cos(2) ≈ 2.584
```

The tangent gives:

$$\frac{\partial f}{\partial x} = y + \cos(x)$$

To compute $\partial f / \partial y$, seed x with 0 and y with 1.

```go
x := Dual{Value: 2, Deriv: 0}
y := Dual{Value: 3, Deriv: 1}

out := G(x, y)
fmt.Println(out.Deriv) // ∂f/∂y at (2, 3): 2
```

The tangent gives:

$$\frac{\partial f}{\partial y} = x$$

A full gradient for $f : \mathbb{R}^n \to \mathbb{R}$ requires $n$ forward passes if we use scalar tangents. Each pass seeds one input direction.

Vector Tangents

A more general engine stores a vector of derivatives.

```go
type DualVec struct {
    Value float64
    Deriv []float64
}
```

Now one forward pass can carry multiple seed directions.

For example, for two inputs:

```go
x := DualVec{Value: 2, Deriv: []float64{1, 0}}
y := DualVec{Value: 3, Deriv: []float64{0, 1}}
```

The output derivative vector contains both partial derivatives.

This is convenient, but it changes the cost model. Every primitive operation now performs vector arithmetic on the tangent field.

Scalar tangent:

```text
cost per primitive: O(1)
```

Vector tangent of width k:

```text
cost per primitive: O(k)
```

The choice depends on the shape of the problem.

Minimal Engine Interface

A small Go-style API can expose only the essential operations:

```go
// API summary: signatures only. The bodies are defined earlier in this section.
type Dual struct {
    Value float64
    Deriv float64
}

func Var(x float64) Dual
func Const(x float64) Dual

func Add(a, b Dual) Dual
func Sub(a, b Dual) Dual
func Mul(a, b Dual) Dual
func Div(a, b Dual) Dual

func Sin(x Dual) Dual
func Cos(x Dual) Dual
func Exp(x Dual) Dual
func Log(x Dual) Dual
```

This is enough to differentiate many scalar programs. More functions can be added incrementally.
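Adding a primitive means writing one more local rule of the same shape. A sketch with two hypothetical additions, Sqrt and Tanh, neither of which is in the API listed above:

```go
package main

import (
	"fmt"
	"math"
)

type Dual struct{ Value, Deriv float64 }

func Var(x float64) Dual { return Dual{Value: x, Deriv: 1} }

// Sqrt: f(x) = sqrt(x), f'(x) = 1 / (2 sqrt(x)). Like Log, only valid
// for x > 0; at x = 0 the tangent becomes +Inf or NaN.
func Sqrt(x Dual) Dual {
	s := math.Sqrt(x.Value)
	return Dual{Value: s, Deriv: x.Deriv / (2 * s)}
}

// Tanh: f(x) = tanh(x), f'(x) = 1 - tanh(x)^2, reusing the primal value.
func Tanh(x Dual) Dual {
	t := math.Tanh(x.Value)
	return Dual{Value: t, Deriv: (1 - t*t) * x.Deriv}
}

func main() {
	fmt.Println(Sqrt(Var(4))) // {2 0.25}
	fmt.Println(Tanh(Var(0))) // {0 1}
}
```

Neither addition touches existing code, which is why the engine can grow one primitive at a time.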

The important design rule is that every primitive must preserve the invariant:

```text
Dual.Value = primal value
Dual.Deriv = derivative of primal value with respect to the seed
```

Once this invariant holds for constants, variables, and primitive operations, it holds for every expression built from them.

Why Forward Mode Is Simple

Forward mode is simple because the derivative flows in the same direction as evaluation.

Original program:

```text
inputs -> intermediate values -> output
```

Forward mode:

```text
input tangents -> intermediate tangents -> output tangent
```

There is no need to revisit earlier operations. There is no backward pass. There is no tape. The engine can run with constant extra memory per active value.

This makes forward mode attractive for:

| Problem shape | Why forward mode fits |
|---|---|
| Few inputs, many outputs | One seed direction can update all outputs |
| Jacobian-vector products | Directly computed |
| Local sensitivity analysis | Cheap for selected directions |
| Embedded systems | Simple memory model |
| Small numerical kernels | Low implementation overhead |
| Higher-order Taylor methods | Natural extension through richer number types |

Limitations

Forward mode becomes expensive when the input dimension is large and the output dimension is small.

For a function:

$$f : \mathbb{R}^n \to \mathbb{R}$$

a full gradient needs $n$ scalar forward passes, or one pass with tangent width $n$. Either way, the work scales with the number of inputs.

This is why reverse mode dominates deep learning. Neural networks often have millions or billions of parameters and a scalar loss. Reverse mode can compute the full gradient with work comparable to a small constant multiple of the primal evaluation.

Forward mode remains valuable because it is predictable, local, and easy to implement. It is also the natural primitive for Jacobian-vector products, higher-order methods, and testing reverse-mode systems.
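A sketch of the testing use case: any gradient reported by a reverse-mode system must satisfy $\nabla f(x) \cdot v = J_f(x)\,v$ for every direction $v$, and the right-hand side costs one forward pass. Here reverseGrad is a hypothetical, hand-derived stand-in for the output of the system under test, and the tolerance is an arbitrary choice:

```go
package main

import (
	"fmt"
	"math"
)

type Dual struct{ Value, Deriv float64 }

func Add(a, b Dual) Dual { return Dual{Value: a.Value + b.Value, Deriv: a.Deriv + b.Deriv} }
func Mul(a, b Dual) Dual {
	return Dual{Value: a.Value * b.Value, Deriv: a.Deriv*b.Value + a.Value*b.Deriv}
}
func Sin(x Dual) Dual {
	return Dual{Value: math.Sin(x.Value), Deriv: math.Cos(x.Value) * x.Deriv}
}

// G(x, y) = x*y + sin(x).
func G(x, y Dual) Dual { return Add(Mul(x, y), Sin(x)) }

// reverseGrad stands in for a gradient produced by a reverse-mode engine.
// In a real test it would come from the system being checked.
func reverseGrad(x, y float64) [2]float64 {
	return [2]float64{y + math.Cos(x), x}
}

// checkGradient compares <grad, v> with one forward-mode pass seeded by v.
func checkGradient(x, y, vx, vy float64) bool {
	jvp := G(Dual{Value: x, Deriv: vx}, Dual{Value: y, Deriv: vy}).Deriv
	g := reverseGrad(x, y)
	dot := g[0]*vx + g[1]*vy
	return math.Abs(jvp-dot) <= 1e-9*(1+math.Abs(dot))
}

func main() {
	fmt.Println(checkGradient(2, 3, 1, -2)) // true
}
```

Random directions $v$ make this a cheap property test: a wrong gradient entry will almost always show up as a mismatch.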

Minimal Correctness Argument

The correctness proof is structural.

For each expression $e$, the engine computes:

$$\bigl(e(x),\; De(x)\,v\bigr)$$

where $v$ is the seed direction.

For variables, this holds by construction. For constants, the derivative is zero. For each primitive operation, the implementation applies the corresponding derivative rule. For composition, the tangent propagation is exactly the chain rule.
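For example, the multiplication case of the inductive step: if the subexpressions $a$ and $b$ already carry $(a(x), Da(x)v)$ and $(b(x), Db(x)v)$, then Mul returns

$$
\bigl(a(x)\,b(x),\; (Da(x)v)\,b(x) + a(x)\,(Db(x)v)\bigr) = \bigl((ab)(x),\; D(ab)(x)\,v\bigr)
$$

by the product rule $D(ab)(x) = b(x)\,Da(x) + a(x)\,Db(x)$. A unary primitive $f$ is the chain rule in the same form: $\bigl(f(e(x)),\; f'(e(x))\,De(x)v\bigr) = \bigl((f \circ e)(x),\; D(f \circ e)(x)\,v\bigr)$.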

Therefore every expression built from supported primitives has a correct forward-mode derivative.

The engine is small because automatic differentiation does not require symbolic algebra. It only requires local derivative rules and ordinary program evaluation.