The Chain Rule

The chain rule is the mathematical rule that makes backpropagation possible. Neural networks are built by composing many functions. The chain rule tells us how to differentiate such compositions.

A deep network may contain hundreds or thousands of operations: matrix multiplications, additions, nonlinear activations, normalization layers, attention blocks, losses, and regularization terms. PyTorch does not need a separate derivative formula for the entire network. It only needs derivative rules for each local operation. The chain rule combines these local derivatives into gradients for the full computation.

Composition of Functions

Suppose a value x is transformed by one function and then another:

y = g(x), \quad z = f(y).

Then z is a function of x:

z = f(g(x)).

The chain rule says

\frac{dz}{dx} = \frac{dz}{dy} \frac{dy}{dx}.

The derivative of the whole computation is the product of the derivative of the outer function and the derivative of the inner function.

For example, let

y = x + 1, \quad z = y^2.

Then

z = (x+1)^2.

The local derivatives are

\frac{dz}{dy} = 2y, \quad \frac{dy}{dx} = 1.

Therefore

\frac{dz}{dx} = 2y.

Since y = x + 1,

\frac{dz}{dx} = 2(x+1).

At x = 3, this gives

\frac{dz}{dx} = 8.

In PyTorch:

import torch

x = torch.tensor(3.0, requires_grad=True)

y = x + 1
z = y ** 2

z.backward()

print(x.grad)  # tensor(8.)

PyTorch applies the same rule during backward().

Local Derivatives

A computational graph is useful because each operation has a local derivative. The global derivative is assembled from these local pieces.

Consider

z = (x+y)^2.

Introduce an intermediate value:

a = x + y, \quad z = a^2.

The local derivatives are

\frac{\partial a}{\partial x} = 1, \quad \frac{\partial a}{\partial y} = 1, \quad \frac{\partial z}{\partial a} = 2a.

The chain rule gives

\frac{\partial z}{\partial x} = \frac{\partial z}{\partial a} \frac{\partial a}{\partial x} = 2a, \quad \frac{\partial z}{\partial y} = \frac{\partial z}{\partial a} \frac{\partial a}{\partial y} = 2a.

At x = 2, y = 3, we have a = 5, so

\frac{\partial z}{\partial x} = 10, \quad \frac{\partial z}{\partial y} = 10.

In PyTorch:

x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)

a = x + y
z = a ** 2

z.backward()

print(x.grad)  # tensor(10.)
print(y.grad)  # tensor(10.)

The graph stores enough information to know that z came from squaring a, and a came from adding x and y.

Multiple Paths

A variable may affect the output through more than one path. In that case, the total derivative is the sum of all path contributions.

Consider

z = x^2 + x.

There are two paths from x to z. One path goes through x^2. The other path goes directly through x.

Break the computation into smaller steps:

a = x^2, \quad b = x, \quad z = a + b.

The derivative through the first path is

\frac{da}{dx} = 2x.

The derivative through the second path is

\frac{db}{dx} = 1.

Since

z = a + b,

the total derivative is

\frac{dz}{dx} = 2x + 1.

At x = 3,

\frac{dz}{dx} = 7.

In PyTorch:

x = torch.tensor(3.0, requires_grad=True)

z = x ** 2 + x
z.backward()

print(x.grad)  # tensor(7.)

This is why gradients are accumulated. If one tensor contributes to the loss through several routes, PyTorch adds all derivative contributions.
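
The two contributions can also be inspected separately with torch.autograd.grad. This is a small illustrative sketch, not a normal training pattern; the clone exists only to give the second path its own node:

x = torch.tensor(3.0, requires_grad=True)

a = x ** 2       # first path
b = x.clone()    # second path (clone keeps the graph connection)
z = a + b

ga, = torch.autograd.grad(a, x, retain_graph=True)  # 2x = 6
gb, = torch.autograd.grad(b, x, retain_graph=True)  # 1
print(ga, gb)    # tensor(6.) tensor(1.)

z.backward()
print(x.grad)    # tensor(7.), the sum of both paths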

Chain Rule on Graphs

For a general computational graph, the chain rule can be stated in terms of parents and children.

Suppose a variable v affects a final scalar loss L through several downstream variables u_1, …, u_k. Then

\frac{\partial L}{\partial v} = \sum_{i=1}^{k} \frac{\partial L}{\partial u_i} \frac{\partial u_i}{\partial v}.

This formula is the core of backpropagation.

Each node receives gradients from its children. It multiplies those incoming gradients by local derivatives. Then it adds the results and sends gradients to its own parents.

Using adjoint notation,

\bar{v} = \frac{\partial L}{\partial v}.

Then the rule becomes

\bar{v} = \sum_{i=1}^{k} \bar{u}_i \frac{\partial u_i}{\partial v}.

This is the backward pass in one equation.
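
To make the bookkeeping concrete, here is that equation carried out by hand for z = (x + y) * x, in plain Python. This is only a sketch of the arithmetic, not how PyTorch is implemented:

x, y = 2.0, 3.0

# Forward pass, recording the intermediate value
a = x + y    # a = 5
z = a * x    # z = 10

# Backward pass: each adjoint sums (child adjoint) * (local derivative)
z_bar = 1.0                      # seed: dz/dz = 1
a_bar = z_bar * x                # z = a * x, so dz/da = x
x_bar = z_bar * a + a_bar * 1.0  # x reaches z directly and through a
y_bar = a_bar * 1.0              # y reaches z only through a

print(x_bar, y_bar)  # 7.0 2.0, matching dz/dx = 2x + y and dz/dy = x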

Vector-Valued Functions

Neural networks usually operate on vectors and tensors, not only scalars. The chain rule still applies, but the derivatives become Jacobians.

Suppose

y = g(x), \quad z = f(y),

where

x \in \mathbb{R}^n, \quad y \in \mathbb{R}^m, \quad z \in \mathbb{R}.

The gradient of z with respect to x is

\nabla_x z = J_g(x)^\top \nabla_y z.

Here

J_g(x) = \frac{\partial y}{\partial x} \in \mathbb{R}^{m \times n}

is the Jacobian of g.

In practice, PyTorch does not usually build the full Jacobian. It computes the product

J_g(x)^\top \nabla_y z.

This is a vector-Jacobian product. It is enough for backpropagation and much cheaper than materializing the full Jacobian.
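
This can be seen directly: calling backward() on a non-scalar output requires a vector argument, and the result stored in x.grad is exactly the vector-Jacobian product. A small check with an elementwise square, whose Jacobian is diag(2x):

x = torch.randn(3, requires_grad=True)
y = x * x                    # Jacobian of y with respect to x is diag(2x)
v = torch.tensor([1.0, 2.0, 3.0])

y.backward(v)                # computes J^T v without building J
print(x.grad)                # elementwise 2 * x * v
print(2 * x.detach() * v)    # same values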

A Linear Layer Example

Consider a linear layer without bias:

y = Wx.

Let

W \in \mathbb{R}^{m \times n}, \quad x \in \mathbb{R}^{n}, \quad y \in \mathbb{R}^{m}.

Suppose a scalar loss L depends on y. Let

\bar{y} = \nabla_y L.

The gradient with respect to x is

\bar{x} = W^\top \bar{y}.

The gradient with respect to W is

\bar{W} = \bar{y} x^\top.

These formulas are applications of the chain rule.
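They can be checked numerically by feeding an arbitrary upstream gradient into backward(); the vector playing the role of ȳ below is just a stand-in for whatever gradient the loss would supply:

m, n = 2, 3
W = torch.randn(m, n, requires_grad=True)
x = torch.randn(n, requires_grad=True)

y = W @ x
y_bar = torch.tensor([1.0, -1.0])   # stand-in upstream gradient
y.backward(y_bar)

print(torch.allclose(x.grad, W.detach().T @ y_bar))            # True
print(torch.allclose(W.grad, torch.outer(y_bar, x.detach())))  # True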

For a batch,

Y = X W^\top,

where

X \in \mathbb{R}^{B \times n}, \quad W \in \mathbb{R}^{m \times n}, \quad Y \in \mathbb{R}^{B \times m}.

If

\bar{Y} = \nabla_Y L,

then

\nabla_X L = \bar{Y} W, \quad \nabla_W L = \bar{Y}^\top X.

In PyTorch:

B = 4
n = 3
m = 2

X = torch.randn(B, n, requires_grad=True)
W = torch.randn(m, n, requires_grad=True)

Y = X @ W.T
L = Y.sum()

L.backward()

print(X.grad.shape)  # torch.Size([4, 3])
print(W.grad.shape)  # torch.Size([2, 3])

The gradient shapes match the original tensor shapes.

Chain Rule Through Nonlinearities

Activation functions are applied elementwise. This makes their backward rules simple.

For ReLU,

y = \max(0, x).

The derivative is

\frac{dy}{dx} = \begin{cases} 1, & x > 0, \\ 0, & x < 0. \end{cases}

At x = 0, the derivative is not uniquely defined. PyTorch uses the conventional subgradient 0 there.

If a loss L depends on y, then

\frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} \frac{\partial y}{\partial x}.

So ReLU passes the upstream gradient through positive inputs and blocks it for negative inputs.

In PyTorch:

x = torch.tensor([-2.0, 0.5, 3.0], requires_grad=True)

y = torch.relu(x)
L = y.sum()

L.backward()

print(x.grad)  # tensor([0., 1., 1.])

For sigmoid,

y = \sigma(x) = \frac{1}{1 + e^{-x}}.

Its derivative is

\frac{dy}{dx} = y(1-y).

The backward pass multiplies the upstream gradient by y(1-y).

x = torch.tensor([0.0], requires_grad=True)

y = torch.sigmoid(x)
L = y.sum()

L.backward()

print(x.grad)  # tensor([0.2500])

At x = 0, σ(x) = 0.5, so the derivative is 0.25.

Chain Rule Through Loss Functions

Loss functions also participate in the computational graph.

For mean squared error,

L = \frac{1}{n} \sum_{i=1}^{n} (\hat{y}_i - y_i)^2.

The derivative with respect to a prediction is

\frac{\partial L}{\partial \hat{y}_i} = \frac{2}{n} (\hat{y}_i - y_i).

This gradient is then propagated backward through the model that produced ŷ.

In PyTorch:

pred = torch.tensor([2.0, 4.0, 6.0], requires_grad=True)
target = torch.tensor([1.0, 5.0, 2.0])

loss = ((pred - target) ** 2).mean()
loss.backward()

print(pred.grad)  # tensor([ 0.6667, -0.6667,  2.6667])

The loss gradient is the starting signal for backpropagation. The model receives this signal at its output and propagates it backward to all parameters that contributed to the predictions.

Chain Rule in a Training Step

A complete training step applies the chain rule over the whole model.

import torch
from torch import nn

model = nn.Sequential(
    nn.Linear(10, 32),
    nn.ReLU(),
    nn.Linear(32, 1),
)

X = torch.randn(16, 10)
target = torch.randn(16, 1)

pred = model(X)
loss = ((pred - target) ** 2).mean()

model.zero_grad()
loss.backward()

The computation can be read as a composition:

X \longrightarrow \text{Linear}_1 \longrightarrow \text{ReLU} \longrightarrow \text{Linear}_2 \longrightarrow \hat{y} \longrightarrow L.

The backward pass applies the chain rule in reverse order:

L \longrightarrow \hat{y} \longrightarrow \text{Linear}_2 \longrightarrow \text{ReLU} \longrightarrow \text{Linear}_1.

Each module contributes a local backward rule. PyTorch combines those rules into parameter gradients.
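
Continuing the example above, the resulting gradients can be inspected through named_parameters(); each gradient has the same shape as its parameter:

for name, p in model.named_parameters():
    print(name, p.grad.shape)
# 0.weight torch.Size([32, 10])
# 0.bias torch.Size([32])
# 2.weight torch.Size([1, 32])
# 2.bias torch.Size([1])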

Why Deep Networks Can Be Hard to Train

The chain rule also explains some training difficulties. A deep network multiplies many derivative terms together. If many terms are smaller than one, gradients may shrink as they move backward. This is the vanishing gradient problem.

If many terms are larger than one, gradients may grow rapidly. This is the exploding gradient problem.

For a simple repeated composition,

h_t = f(h_{t-1}),

the derivative after many steps contains a product:

\frac{\partial h_T}{\partial h_0} = \prod_{t=1}^{T} \frac{\partial h_t}{\partial h_{t-1}}.

This product may become very small or very large.
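
A toy version with f(h) = 0.5h makes the effect visible. After 30 steps the gradient is 0.5^30, roughly 9.3e-10; replacing 0.5 with 2.0 gives about 1.1e9 instead:

x = torch.tensor(1.0, requires_grad=True)

h = x
for _ in range(30):
    h = 0.5 * h    # each step multiplies the backward signal by 0.5

h.backward()
print(x.grad)      # tensor(9.3132e-10), a vanished gradient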

Modern architectures reduce these problems through careful initialization, normalization, residual connections, gated units, and gradient clipping. These techniques will appear throughout later chapters.

Common PyTorch Mistakes

A common mistake is breaking the graph accidentally by converting tensors to Python numbers:

loss_value = loss.item()
loss_value.backward()  # AttributeError: 'float' object has no attribute 'backward'

The method .item() extracts a plain Python number. That number has no computational graph.

Use .item() for logging only:

print(loss.item())
loss.backward()

Another mistake is detaching an intermediate value too early:

h = encoder(x)
h = h.detach()  # cuts the graph: encoder will receive no gradient
out = decoder(h)
loss = loss_fn(out, target)
loss.backward()

This prevents gradients from flowing into encoder. This may be intentional in some algorithms, but it should not happen accidentally.
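
A minimal reproduction, with two small Linear layers standing in for the encoder and decoder:

enc = nn.Linear(4, 4)
dec = nn.Linear(4, 1)
x = torch.randn(2, 4)

out = dec(enc(x).detach())    # the detach cuts the graph between the two
out.sum().backward()

print(enc.weight.grad)        # None: no gradient reached the encoder
print(dec.weight.grad.shape)  # torch.Size([1, 4])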

A third mistake is using non-differentiable operations where gradients are needed. For example, argmax produces discrete indices. Gradients cannot flow through a hard index choice in the usual way.

logits = model(x)
idx = logits.argmax(dim=-1)

This is fine for prediction, but it should usually not be placed inside a differentiable training path.
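
The break is visible on the tensor itself: the indices are integers and carry no graph, so nothing computed from them can reach logits through autograd.

logits = torch.randn(2, 5, requires_grad=True)
idx = logits.argmax(dim=-1)

print(idx.dtype)          # torch.int64
print(idx.requires_grad)  # False: the graph ends here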

Summary

The chain rule gives the derivative of a composed function by multiplying local derivatives. In computational graphs, the same rule is applied node by node. If a variable affects the loss through multiple paths, the gradient contributions from all paths are added.

PyTorch autograd implements this process automatically. During the forward pass, it records operations. During the backward pass, it starts from the loss and applies local backward rules in reverse order.

Backpropagation is the chain rule applied efficiently to large computational graphs.