
Chain Rule as Composition Algebra

The chain rule is the central theorem behind automatic differentiation. Every useful AD algorithm is a disciplined way of applying the chain rule to a program.


A program is built from smaller operations. Each operation maps input values to output values. When one operation feeds another, the corresponding mathematical functions are composed. AD differentiates the whole computation by differentiating each small mapping and composing their local derivative information.

Composition of Scalar Functions

For scalar functions,

$$g : \mathbb{R} \to \mathbb{R}$$

and

$$h : \mathbb{R} \to \mathbb{R}$$

their composition is

$$f = h \circ g$$

meaning

$$f(x) = h(g(x))$$

The derivative is

$$f'(x) = h'(g(x))\,g'(x)$$

This simple rule already contains the main structure of AD. We evaluate the inner function $g$, use its result as the input to $h$, and multiply local derivatives in the same dependency order.

For example, let

$$f(x) = \sin(x^2)$$

Define

$$u = x^2, \qquad y = \sin u$$

Then

$$\frac{dy}{dx} = \frac{dy}{du}\,\frac{du}{dx} = \cos(u) \cdot 2x$$

Substituting $u = x^2$,

$$\frac{dy}{dx} = 2x\cos(x^2)$$

AD keeps the intermediate value $u$ available because the local derivative of $\sin u$, namely $\cos u$, depends on it.
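A minimal Python sketch of this two-stage computation follows. The helper name `f_and_derivative` is illustrative. It keeps $u$ from the primal step and multiplies the two local derivatives, each evaluated at its own stored input:

```python
import math

def f_and_derivative(x):
    """Return f(x) = sin(x**2) and f'(x), computed via the chain rule."""
    # Primal step: evaluate the inner function and keep the intermediate u.
    u = x * x
    y = math.sin(u)
    # Local derivatives, each evaluated at its own stored input.
    du_dx = 2.0 * x
    dy_du = math.cos(u)  # this is why u must stay available
    # Chain rule: dy/dx = dy/du * du/dx.
    return y, dy_du * du_dx

y, dy_dx = f_and_derivative(1.3)
```

The result should agree with the closed form $2x\cos(x^2)$ up to floating-point rounding.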

Composition of Vector Functions

In automatic differentiation, scalar composition is only the simplest case. Most computations involve vector-valued mappings.

Let

$$g : \mathbb{R}^n \to \mathbb{R}^k$$

and

$$h : \mathbb{R}^k \to \mathbb{R}^m$$

Then

$$f = h \circ g$$

has type

$$f : \mathbb{R}^n \to \mathbb{R}^m$$

The Jacobian chain rule is

$$J_f(x) = J_h(g(x))\,J_g(x)$$

The matrix shapes are important:

$$J_g(x) \in \mathbb{R}^{k \times n}, \qquad J_h(g(x)) \in \mathbb{R}^{m \times k}, \qquad J_f(x) \in \mathbb{R}^{m \times n}$$

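The product is valid because the intermediate dimension $k$ appears once as the output dimension of $g$ and once as the input dimension of $h$. An $(m \times k)$ matrix times a $(k \times n)$ matrix gives the $(m \times n)$ Jacobian of $f$.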

This is composition algebra: derivative objects compose according to the same wiring as the original computation, but with linear maps replacing nonlinear maps.
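A small numerical check can make the shapes concrete. The sketch below uses hypothetical functions $g : \mathbb{R}^2 \to \mathbb{R}^3$ and $h : \mathbb{R}^3 \to \mathbb{R}^2$ with hand-written Jacobians. It forms $J_h(g(x))\,J_g(x)$ and compares the result with central finite differences of $h \circ g$. It uses plain Python lists, so no array library is assumed.

```python
import math

def matmul(A, B):
    # (m x k) times (k x n) -> (m x n); the shared k must match.
    assert len(A[0]) == len(B), "inner dimensions must agree"
    return [[sum(A[i][p] * B[p][j] for p in range(len(B)))
             for j in range(len(B[0]))]
            for i in range(len(A))]

def g(x):  # g : R^2 -> R^3
    x1, x2 = x
    return [x1 * x2, x1 + x2, math.sin(x1)]

def J_g(x):  # 3 x 2
    x1, x2 = x
    return [[x2, x1],
            [1.0, 1.0],
            [math.cos(x1), 0.0]]

def h(y):  # h : R^3 -> R^2
    y1, y2, y3 = y
    return [y1 * y3, y2 ** 2]

def J_h(y):  # 2 x 3
    y1, y2, y3 = y
    return [[y3, 0.0, y1],
            [0.0, 2.0 * y2, 0.0]]

def J_f(x):
    # Chain rule: (2 x 3) times (3 x 2) -> 2 x 2, with J_h evaluated at g(x).
    return matmul(J_h(g(x)), J_g(x))

def J_f_central_difference(x, eps=1e-6):
    # Numerical reference: one column of J_f per input coordinate.
    columns = []
    for j in range(len(x)):
        xp, xm = list(x), list(x)
        xp[j] += eps
        xm[j] -= eps
        fp, fm = h(g(xp)), h(g(xm))
        columns.append([(a - b) / (2 * eps) for a, b in zip(fp, fm)])
    return [list(row) for row in zip(*columns)]  # transpose to m x n
```

The `assert` in `matmul` acts as a shape check: swapping the factors to `matmul(J_g(x), J_h(g(x)))` happens to be dimensionally valid here, but it gives a $3 \times 3$ matrix rather than the $2 \times 2$ Jacobian of $f$.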

Local Linear Maps

Each differentiable function has a local linear approximation. If

$$y = g(x)$$

then a small perturbation $\Delta x$ produces

$$\Delta y \approx J_g(x)\,\Delta x$$

If

$$z = h(y)$$

then

$$\Delta z \approx J_h(y)\,\Delta y$$

Substituting the first approximation into the second gives

$$\Delta z \approx J_h(y)\,J_g(x)\,\Delta x$$

Therefore,

$$J_{h \circ g}(x) = J_h(g(x))\,J_g(x)$$

This interpretation is more useful for AD than the formula alone. AD propagates local linear maps through the same structure as values.
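To see the local linear maps directly, the sketch below uses hypothetical functions $g(x_1, x_2) = (x_1 x_2, \sin x_1)$ and $h(y_1, y_2) = y_1 + y_2^2$. Each returns its value together with its linearization as a closure, and the closures compose in evaluation order. The last helper measures the gap between the true change $\Delta z$ and the linear prediction. That gap should shrink roughly quadratically with the perturbation size.

```python
import math

def linearize_g(x):
    # g(x1, x2) = (x1 * x2, sin x1); returns g(x) and the map dx -> J_g(x) dx.
    x1, x2 = x
    y = (x1 * x2, math.sin(x1))
    def lin(dx):
        d1, d2 = dx
        return (x2 * d1 + x1 * d2, math.cos(x1) * d1)
    return y, lin

def linearize_h(y):
    # h(y1, y2) = y1 + y2**2; returns h(y) and the map dy -> J_h(y) dy.
    y1, y2 = y
    z = y1 + y2 ** 2
    def lin(dy):
        e1, e2 = dy
        return e1 + 2.0 * y2 * e2
    return z, lin

def linearize_composite(x):
    # Values flow through g, then h; the local linear maps compose in the same order.
    y, lin_g = linearize_g(x)
    z, lin_h = linearize_h(y)
    return z, lambda dx: lin_h(lin_g(dx))

def linearization_error(scale, x=(0.4, 1.5), direction=(1.0, -2.0)):
    # |actual change in z - linear prediction| for dx = scale * direction.
    z, lin = linearize_composite(x)
    dx = (scale * direction[0], scale * direction[1])
    z_new, _ = linearize_composite((x[0] + dx[0], x[1] + dx[1]))
    return abs((z_new - z) - lin(dx))

for s in (1e-1, 1e-2, 1e-3):
    print(s, linearization_error(s))
```

The composed closure is never turned into a matrix. This mirrors how AD can propagate the action of $J_h(g(x))\,J_g(x)$ on a vector without forming the product.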

Computational Graph View

Consider a computation

x -> u -> v -> y

with

$$u = g_1(x), \qquad v = g_2(u), \qquad y = g_3(v)$$

The full function is

$$f = g_3 \circ g_2 \circ g_1$$

The Jacobian is

$$J_f(x) = J_{g_3}(v)\,J_{g_2}(u)\,J_{g_1}(x)$$

The order matters. Values flow forward:

$$x \to u \to v \to y$$

The Jacobian product is written in the reverse textual order because matrix multiplication applies right to left:

$$\Delta y = J_{g_3}(v) \left( J_{g_2}(u) \left( J_{g_1}(x)\,\Delta x \right) \right)$$

Forward mode evaluates this from the inside outward, propagating tangent perturbations along with values.

Reverse mode evaluates the transposed action backward, propagating adjoints from outputs to inputs.

Forward Accumulation

Forward mode applies the chain rule in the same direction as value evaluation.

For a composition

y=g3(g2(g1(x))) y = g_3(g_2(g_1(x)))

and an input tangent $\dot{x}$, forward mode computes

$$\dot{u} = J_{g_1}(x)\,\dot{x}, \qquad \dot{v} = J_{g_2}(u)\,\dot{u}, \qquad \dot{y} = J_{g_3}(v)\,\dot{v}$$

Thus,

$$\dot{y} = J_{g_3}(v)\,J_{g_2}(u)\,J_{g_1}(x)\,\dot{x}$$

This is a Jacobian-vector product:

$$\dot{y} = J_f(x)\,\dot{x}$$

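Forward mode never has to build the full Jacobian. It only propagates the effect of one input direction. Recovering all of $J_f(x) \in \mathbb{R}^{m \times n}$ this way takes $n$ passes, one per input basis direction.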
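For a scalar chain, forward accumulation can be sketched with a small dual-number type that carries a value and its tangent together. `Dual`, `square`, `sin`, and `exp` below are illustrative names, not any particular library's API. The chain is $g_1(x) = x^2$, $g_2(u) = \sin u$, and $g_3(v) = \exp v$. Seeding $\dot{x} = 1$ yields $f'(x)$ in a single forward pass.

```python
import math
from dataclasses import dataclass

@dataclass
class Dual:
    """A primal value paired with its tangent."""
    val: float
    tan: float

def square(a: Dual) -> Dual:  # g1: u = x**2, u_dot = 2x * x_dot
    return Dual(a.val * a.val, 2.0 * a.val * a.tan)

def sin(a: Dual) -> Dual:     # g2: v = sin u, v_dot = cos(u) * u_dot
    return Dual(math.sin(a.val), math.cos(a.val) * a.tan)

def exp(a: Dual) -> Dual:     # g3: y = exp v, y_dot = exp(v) * v_dot
    e = math.exp(a.val)
    return Dual(e, e * a.tan)

def f(x: Dual) -> Dual:
    # Tangents travel alongside values, in the same order as evaluation.
    return exp(sin(square(x)))

out = f(Dual(0.8, 1.0))  # seed x_dot = 1; out.tan is f'(0.8)
```

Each primitive applies its own local Jacobian to the incoming tangent. No Jacobian matrix is stored anywhere, which is the point made above.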

Reverse Accumulation

Reverse mode applies the chain rule backward.

For the same computation,

$$x \to u \to v \to y$$

reverse mode starts with an output cotangent, or adjoint, $\bar{y}$. It propagates sensitivities backward:

$$\bar{v} = J_{g_3}(v)^T \bar{y}, \qquad \bar{u} = J_{g_2}(u)^T \bar{v}, \qquad \bar{x} = J_{g_1}(x)^T \bar{u}$$

Combining these,

$$\bar{x} = J_{g_1}(x)^T J_{g_2}(u)^T J_{g_3}(v)^T \bar{y}$$

Equivalently,

$$\bar{x} = J_f(x)^T \bar{y}$$

This is a vector-Jacobian product, written in column-vector convention. As a row vector it reads $\bar{x}^T = \bar{y}^T J_f(x)$.

Reverse mode is efficient when the output dimension is small. For a scalar loss, one reverse pass computes the gradient with respect to all inputs.
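The three-stage chain can be run in reverse with hand-written VJP rules. In the sketch below, the stages are hypothetical, chosen so that $x \in \mathbb{R}^3$ and $y$ is a scalar. The forward sweep stores $x$, $u$, and $v$. The reverse sweep seeds $\bar{y} = 1$ and applies the transposed Jacobians in reverse order. One backward pass produces all three gradient components.

```python
import math

# Hypothetical stages for x -> u -> v -> y (0-based indices in comments).
def g1(x):  # R^3 -> R^2
    return [x[0] * x[1], x[1] + x[2]]

def g2(u):  # R^2 -> R^2
    return [math.sin(u[0]), u[1] ** 2]

def g3(v):  # R^2 -> R
    return v[0] + v[1]

# VJP rules: each returns J^T applied to the output adjoint.
def g1_vjp(x, u_bar):
    # J_g1 = [[x1, x0, 0], [0, 1, 1]]
    return [x[1] * u_bar[0], x[0] * u_bar[0] + u_bar[1], u_bar[1]]

def g2_vjp(u, v_bar):
    # J_g2 = diag(cos u0, 2 u1)
    return [math.cos(u[0]) * v_bar[0], 2.0 * u[1] * v_bar[1]]

def g3_vjp(v, y_bar):
    # J_g3 = [1, 1]
    return [y_bar, y_bar]

def three_stage_value_and_grad(x):
    # Forward sweep: compute and store the primal intermediates.
    u = g1(x)
    v = g2(u)
    y = g3(v)
    # Reverse sweep: seed y_bar = 1, then apply transposed Jacobians last-to-first.
    v_bar = g3_vjp(v, 1.0)
    u_bar = g2_vjp(u, v_bar)
    x_bar = g1_vjp(x, u_bar)
    return y, x_bar
```

The rules `g2_vjp` and `g1_vjp` read `u` and `x` from the forward sweep. This is the storage requirement of reverse mode in miniature.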

Branching Computations

Programs often have values that feed more than one later operation.

Example:

u = x * x
v = sin(u)
w = exp(u)
y = v + w

Here, $u$ is used twice. The output is

$$y = \sin(x^2) + \exp(x^2)$$

The derivative with respect to $u$ receives contributions from both branches:

$$\frac{dy}{du} = \frac{dv}{du}\frac{dy}{dv} + \frac{dw}{du}\frac{dy}{dw}$$

Since

$$y = v + w$$

we have

$$\frac{dy}{dv} = 1, \qquad \frac{dy}{dw} = 1$$

Therefore,

$$\frac{dy}{du} = \cos(u) + \exp(u)$$

Reverse mode implements this by accumulation. If a variable contributes to several downstream paths, its adjoint is the sum of contributions from all paths.

This is one reason reverse mode needs careful handling of storage and mutation. Intermediate variables may receive adjoint updates from multiple consumers.
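The branching example can be differentiated by hand in the style of a reverse sweep. The sketch below uses an illustrative function name. Every adjoint starts at zero, and each consumer of $u$ adds its own contribution. Visiting the branches in either order gives the same $\bar{u} = \cos(u) + \exp(u)$.

```python
import math

def branching_value_and_grad(x):
    # Forward sweep for the example above: u feeds two consumers.
    u = x * x
    v = math.sin(u)
    w = math.exp(u)
    y = v + w
    # Reverse sweep: every adjoint starts at zero and accumulates with +=.
    y_bar = 1.0
    v_bar = w_bar = u_bar = x_bar = 0.0
    v_bar += y_bar                # y = v + w
    w_bar += y_bar
    u_bar += math.exp(u) * w_bar  # contribution through w = exp(u)
    u_bar += math.cos(u) * v_bar  # contribution through v = sin(u)
    x_bar += 2.0 * x * u_bar      # u = x * x
    return y, u_bar, x_bar
```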

Multiple Inputs

For a primitive operation with multiple inputs,

$$z = \phi(x, y)$$

the local linear rule is

$$\dot{z} = \frac{\partial \phi}{\partial x}\dot{x} + \frac{\partial \phi}{\partial y}\dot{y}$$

For multiplication,

$$z = xy$$

the rule is

$$\dot{z} = y\dot{x} + x\dot{y}$$

Reverse mode uses the transposed rule. Given $\bar{z}$, the adjoints are

$$\bar{x} \mathrel{+}= y\bar{z}, \qquad \bar{y} \mathrel{+}= x\bar{z}$$

The symbol $\mathrel{+}=$ matters: a variable may receive sensitivity from more than one downstream use.
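The extreme case is $z = x \cdot x$, where the same variable occupies both operand slots. The sketch below writes the two multiplication rules as plain functions. The names `mul_jvp` and `mul_vjp`, and the dictionary used as an adjoint table, are illustrative assumptions.

```python
from collections import defaultdict

def mul_jvp(x, y, x_dot, y_dot):
    # Forward rule for z = x * y: returns (z, z_dot).
    return x * y, y * x_dot + x * y_dot

def mul_vjp(adj, x_name, y_name, x, y, z_bar):
    # Reverse rule for z = x * y, written as accumulations into an adjoint table.
    adj[x_name] += y * z_bar
    adj[y_name] += x * z_bar

x = 3.0
# z = x * x: both operands are the same variable.
z, z_dot = mul_jvp(x, x, 1.0, 1.0)  # the tangent of x enters both slots
adj = defaultdict(float)
mul_vjp(adj, "x", "x", x, x, z_bar=1.0)
print(z_dot, adj["x"])  # 6.0 6.0, i.e. dz/dx = 2x
```

With `=` in place of `+=`, the second update would overwrite the first and leave 3.0, half the correct derivative.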

Multiple Outputs

For a primitive operation with multiple outputs,

$$(y_1, y_2) = \phi(x)$$

forward mode propagates tangents to each output:

$$\dot{y}_1 = \frac{\partial y_1}{\partial x}\dot{x}, \qquad \dot{y}_2 = \frac{\partial y_2}{\partial x}\dot{x}$$

Reverse mode accumulates input sensitivity from each output:

$$\bar{x} \mathrel{+}= \frac{\partial y_1}{\partial x}\bar{y}_1 + \frac{\partial y_2}{\partial x}\bar{y}_2$$

This is the same chain rule, expressed through a local input-output interface.
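A concrete two-output primitive is $(y_1, y_2) = (\sin x, \cos x)$. The sketch below uses the illustrative name `sincos`. The forward rule pushes one tangent to both outputs, and the reverse rule sums the two output adjoints into $\bar{x}$. For a downstream loss $L = y_1 + 3y_2$, the seeds are $\bar{y}_1 = 1$ and $\bar{y}_2 = 3$.

```python
import math

def sincos(x):
    # Primal: one input, two outputs (y1, y2) = (sin x, cos x).
    return math.sin(x), math.cos(x)

def sincos_jvp(x, x_dot):
    # Forward: the single input tangent is pushed to each output.
    return math.cos(x) * x_dot, -math.sin(x) * x_dot

def sincos_vjp(x, y1_bar, y2_bar):
    # Reverse: contributions from both output adjoints are summed into x_bar.
    return math.cos(x) * y1_bar + (-math.sin(x)) * y2_bar

# Example loss L = y1 + 3 * y2, so y1_bar = 1 and y2_bar = 3.
x = 0.5
x_bar = sincos_vjp(x, 1.0, 3.0)  # cos(x) - 3 sin(x)
```

The two rules are transposes of each other. For any $\dot{x}$, the value $\bar{y}_1 \dot{y}_1 + \bar{y}_2 \dot{y}_2$ computed from the forward rule equals $\bar{x}\,\dot{x}$ computed from the reverse rule.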

Chain Rule as an Interface Contract

An AD system can be organized around a simple contract.

Each primitive operation must provide:

  1. A primal evaluation rule.
  2. A forward derivative rule, or JVP rule.
  3. A reverse derivative rule, or VJP rule.

For example, multiplication has primal rule:

$$z = xy$$

forward rule:

$$\dot{z} = y\dot{x} + x\dot{y}$$

reverse rule:

$$\bar{x} \mathrel{+}= y\bar{z}, \qquad \bar{y} \mathrel{+}= x\bar{z}$$

Once every primitive provides these local rules, a full AD system can differentiate large programs by composition.
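One way to encode this contract is a record holding the three rules. The `Primitive` class below, including its field names and call conventions, is a hypothetical sketch and not the interface of any particular AD library. A useful consequence of the contract is a cheap consistency test. Because the JVP applies $J$ and the VJP applies $J^T$, the identity $\langle \bar{z}, J\dot{x}\rangle = \langle J^T\bar{z}, \dot{x}\rangle$ must hold for any inputs.

```python
import math
from dataclasses import dataclass
from typing import Callable, Tuple

@dataclass(frozen=True)
class Primitive:
    """Hypothetical interface: one primal rule plus its two derivative rules."""
    name: str
    primal: Callable[..., float]              # (*inputs) -> output
    jvp: Callable[..., float]                 # (inputs, input tangents) -> output tangent
    vjp: Callable[..., Tuple[float, ...]]     # (inputs, output adjoint) -> input adjoints

MUL = Primitive(
    name="mul",
    primal=lambda x, y: x * y,
    jvp=lambda xs, dots: xs[1] * dots[0] + xs[0] * dots[1],
    vjp=lambda xs, z_bar: (xs[1] * z_bar, xs[0] * z_bar),
)

SIN = Primitive(
    name="sin",
    primal=lambda x: math.sin(x),
    jvp=lambda xs, dots: math.cos(xs[0]) * dots[0],
    vjp=lambda xs, z_bar: (math.cos(xs[0]) * z_bar,),
)

def check_transpose(prim, xs, dots, z_bar, tol=1e-12):
    # For a linear map J: <z_bar, J dot> must equal <J^T z_bar, dot>.
    lhs = z_bar * prim.jvp(xs, dots)
    rhs = sum(a * d for a, d in zip(prim.vjp(xs, z_bar), dots))
    return abs(lhs - rhs) < tol
```

This dot-product test catches a common class of bugs in hand-written rules, such as a swapped partial in a VJP, without needing finite differences.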

Algebraic Interpretation

A computation can be seen as a graph of typed mappings. Each node is a primitive map. Each edge carries values. Differentiation replaces each primitive map with a rule for propagating linear information.

Forward mode composes maps of the form:

$$(x, \dot{x}) \mapsto (y, \dot{y})$$

Reverse mode composes maps of the form:

$$(y, \bar{y}) \mapsto (x, \bar{x})$$

The same program can therefore support different derivative computations depending on how we traverse the composition.

This is why the chain rule is not merely a calculus identity in AD. It is the algebra that makes derivative programs compositional.
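The compositional structure can be sketched with ordinary function composition. In the code below, forward-mode primitives are functions on $(x, \dot{x})$ pairs and compose in evaluation order. Reverse-mode primitives use a closure-based variant: each maps $x$ to $(y, \text{pullback})$, and the pullback maps $\bar{y}$ to $\bar{x}$ using the primal input it captured. The pullbacks then run in reverse order. All names are illustrative assumptions. Both paths differentiate $f(x) = \sin(x^2)$.

```python
import math
from functools import reduce

# Forward-mode primitives: (x, x_dot) -> (y, y_dot).
def sq_fwd(p):
    x, xd = p
    return x * x, 2.0 * x * xd

def sin_fwd(p):
    x, xd = p
    return math.sin(x), math.cos(x) * xd

def compose_forward(*maps):
    # Plain function composition, applied in evaluation order.
    return lambda p: reduce(lambda acc, m: m(acc), maps, p)

# Reverse-mode primitives: x -> (y, pullback), pullback: y_bar -> x_bar.
def sq_rev(x):
    return x * x, lambda y_bar: 2.0 * x * y_bar

def sin_rev(x):
    return math.sin(x), lambda y_bar: math.cos(x) * y_bar

def compose_reverse(*maps):
    def run(x):
        pullbacks = []
        for m in maps:                      # forward sweep stores one pullback per stage
            x, pb = m(x)
            pullbacks.append(pb)
        def pullback(y_bar):
            for pb in reversed(pullbacks):  # backward sweep in reverse order
                y_bar = pb(y_bar)
            return y_bar
        return x, pullback
    return run

fwd = compose_forward(sq_fwd, sin_fwd)
rev = compose_reverse(sq_rev, sin_rev)
y, y_dot = fwd((1.1, 1.0))
y2, pb = rev(1.1)
x_bar = pb(1.0)  # equals y_dot for this scalar function
```

The same list of primitives serves both traversals. Only the direction in which the local linear pieces are chained changes.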

Practical Consequences

The chain rule explains several design facts about AD systems.

First, AD needs intermediate primal values. The derivative of $\sin u$ needs $\cos u$. The derivative of multiplication $xy$ needs the values of $x$ and $y$. Reverse mode therefore often stores a tape of intermediate values or recomputes them later.

Second, derivative propagation follows data dependencies. If a value does not affect the output, it receives zero sensitivity. If it affects the output through several paths, its sensitivities add.
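The first two points can be seen together in a toy tape. The `Tape` class and its methods below are an illustrative sketch, not a library API. In this variant, each operation records its local partials during the forward sweep; those partials are computed from the primal values. The reverse sweep then adds contributions along every recorded edge. A variable that never reaches the output keeps its initial zero adjoint.

```python
import math

class Tape:
    """A minimal reverse-mode tape (illustrative sketch)."""
    def __init__(self):
        self.values = []   # primal value of each node
        self.records = []  # (output_index, [(input_index, local_partial), ...])

    def var(self, value):
        self.values.append(value)
        return len(self.values) - 1

    def _push(self, value, parents):
        out = self.var(value)
        self.records.append((out, parents))
        return out

    def mul(self, a, b):
        va, vb = self.values[a], self.values[b]
        return self._push(va * vb, [(a, vb), (b, va)])

    def add(self, a, b):
        return self._push(self.values[a] + self.values[b], [(a, 1.0), (b, 1.0)])

    def sin(self, a):
        va = self.values[a]
        return self._push(math.sin(va), [(a, math.cos(va))])

    def grad(self, out):
        adj = [0.0] * len(self.values)  # nodes off every path keep zero sensitivity
        adj[out] = 1.0
        for node, parents in reversed(self.records):
            for parent, partial in parents:
                adj[parent] += partial * adj[node]  # contributions from paths add
        return adj

t = Tape()
x, y, z = t.var(0.5), t.var(2.0), t.var(7.0)  # z is never used
out = t.add(t.mul(x, y), t.sin(x))           # out = x*y + sin(x)
g = t.grad(out)  # g[x] = y + cos(x), g[y] = x, g[z] = 0.0
```

Here `x` reaches the output through two paths, so its adjoint is the sum $y + \cos x$. Recording the partials eagerly is one design choice; a tape could instead store the primal inputs and compute the partials during the reverse sweep.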

Third, AD scales because the chain rule is local. A large program can be differentiated by giving derivative rules for a small set of primitive operations.

The rest of automatic differentiation is mostly about how to execute this composition efficiently: which direction to accumulate, what to store, what to recompute, how to handle control flow, and how to represent derivative information without materializing enormous matrices.