# The Chain Rule

The chain rule is the mathematical rule that makes backpropagation possible. Neural networks are built by composing many functions. The chain rule tells us how to differentiate such compositions.

A deep network may contain hundreds or thousands of operations: matrix multiplications, additions, nonlinear activations, normalization layers, attention blocks, losses, and regularization terms. PyTorch does not need a separate derivative formula for the entire network. It only needs derivative rules for each local operation. The chain rule combines these local derivatives into gradients for the full computation.

### Composition of Functions

Suppose a value \(x\) is transformed by one function and then another:

$$
y = g(x),
$$

$$
z = f(y).
$$

Then \(z\) is a function of \(x\):

$$
z = f(g(x)).
$$

The chain rule says

$$
\frac{dz}{dx} =
\frac{dz}{dy}
\frac{dy}{dx}.
$$

The derivative of the whole computation is the product of the derivative of the outer function, evaluated at \(y=g(x)\), and the derivative of the inner function, evaluated at \(x\).

For example, let

$$
y = x + 1,
$$

$$
z = y^2.
$$

Then

$$
z = (x+1)^2.
$$

The local derivatives are

$$
\frac{dz}{dy}=2y,
$$

$$
\frac{dy}{dx}=1.
$$

Therefore

$$
\frac{dz}{dx}=2y.
$$

Since \(y=x+1\),

$$
\frac{dz}{dx}=2(x+1).
$$

At \(x=3\), this gives

$$
\frac{dz}{dx}=8.
$$

In PyTorch:

```python
import torch

x = torch.tensor(3.0, requires_grad=True)

y = x + 1
z = y ** 2

z.backward()

print(x.grad)  # tensor(8.)
```

PyTorch applies the same rule during `backward()`.
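One way to sanity-check the value \(8\) is a numerical estimate. The sketch below compares autograd with a central finite difference; the helper `f` and the step size `eps` are assumptions introduced here for illustration.

```python
import torch

def f(x):
    # Same composition as above: y = x + 1, then z = y ** 2.
    y = x + 1
    return y ** 2

# float64 keeps the finite-difference rounding error small.
x = torch.tensor(3.0, dtype=torch.float64, requires_grad=True)
z = f(x)
z.backward()

# Central finite difference with a hypothetical step size.
eps = 1e-6
with torch.no_grad():
    numeric = (f(x + eps) - f(x - eps)) / (2 * eps)

print(x.grad.item())   # autograd result
print(numeric.item())  # numerical estimate, close to the autograd result
```

The two numbers should agree to several decimal places. Finite differences are useful for checking gradients, but they need one function evaluation per input, which is why they are not used for training.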

### Local Derivatives

A computational graph is useful because each operation has a local derivative. The global derivative is assembled from these local pieces.

Consider

$$
z = (x+y)^2.
$$

Introduce an intermediate value:

$$
a = x+y,
$$

$$
z = a^2.
$$

The local derivatives are

$$
\frac{\partial a}{\partial x}=1,
\quad
\frac{\partial a}{\partial y}=1,
$$

$$
\frac{\partial z}{\partial a}=2a.
$$

The chain rule gives

$$
\frac{\partial z}{\partial x} =
\frac{\partial z}{\partial a}
\frac{\partial a}{\partial x} =
2a,
$$

$$
\frac{\partial z}{\partial y} =
\frac{\partial z}{\partial a}
\frac{\partial a}{\partial y} =
2a.
$$

At \(x=2\), \(y=3\), we have \(a=5\), so

$$
\frac{\partial z}{\partial x}=10,
\quad
\frac{\partial z}{\partial y}=10.
$$

In PyTorch:

```python
x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)

a = x + y
z = a ** 2

z.backward()

print(x.grad)  # tensor(10.)
print(y.grad)  # tensor(10.)
```

The graph stores enough information to know that `z` came from squaring `a`, and `a` came from adding `x` and `y`.
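The intermediate gradient \(\partial z/\partial a\) can be inspected directly. The following is a small sketch that uses `retain_grad()` to keep the gradient of the non-leaf tensor `a`; the names `dz_da` and `da_dx` are introduced here only for illustration.

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)

a = x + y
a.retain_grad()  # keep the gradient of this non-leaf tensor after backward()
z = a ** 2

z.backward()

dz_da = a.grad               # local derivative of the square: 2a
da_dx = torch.tensor(1.0)    # local derivative of the addition w.r.t. x

print(dz_da)                 # tensor(10.)
print(dz_da * da_dx)         # tensor(10.), the same value as x.grad
print(z.grad_fn)             # backward node recorded for the power operation
print(a.grad_fn)             # backward node recorded for the addition
```

The `grad_fn` attributes are the recorded graph nodes: each one knows how to turn an incoming gradient into gradients for its inputs.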

### Multiple Paths

A variable may affect the output through more than one path. In that case, the total derivative is the sum of all path contributions.

Consider

$$
z = x^2 + x.
$$

There are two paths from \(x\) to \(z\). One path goes through \(x^2\). The other path goes directly through \(x\).

Break the computation into smaller steps:

$$
a = x^2,
$$

$$
b = x,
$$

$$
z = a+b.
$$

The derivative through the first path is

$$
\frac{da}{dx}=2x.
$$

The derivative through the second path is

$$
\frac{db}{dx}=1.
$$

Since

$$
z=a+b,
$$

the total derivative is

$$
\frac{dz}{dx}=2x+1.
$$

At \(x=3\),

$$
\frac{dz}{dx}=7.
$$

In PyTorch:

```python
x = torch.tensor(3.0, requires_grad=True)

z = x ** 2 + x
z.backward()

print(x.grad)  # tensor(7.)
```

This is why gradients are accumulated. If one tensor contributes to the loss through several routes, PyTorch adds all derivative contributions.
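The two path contributions can also be computed separately. This sketch uses `torch.autograd.grad`, which returns gradients instead of writing them to `x.grad`; multiplying by `1.0` is an illustrative trick to give the second path its own graph node.

```python
import torch

x = torch.tensor(3.0, requires_grad=True)

a = x ** 2     # first path
b = x * 1.0    # second path, given its own graph node for illustration
z = a + b

# retain_graph=True keeps the graph alive for the later calls.
(path_a,) = torch.autograd.grad(a, x, retain_graph=True)
(path_b,) = torch.autograd.grad(b, x, retain_graph=True)
(total,) = torch.autograd.grad(z, x)

print(path_a, path_b, total)  # tensor(6.) tensor(1.) tensor(7.)
```

The total gradient is the sum of the per-path gradients, \(6 + 1 = 7\).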

### Chain Rule on Graphs

For a general computational graph, the chain rule can be stated in terms of parents and children. Here the children of a node are the downstream nodes that take it as input.

Suppose a variable \(v\) affects a final scalar loss \(L\) through several downstream variables \(u_1,\ldots,u_k\). Then

$$
\frac{\partial L}{\partial v} =
\sum_{i=1}^{k}
\frac{\partial L}{\partial u_i}
\frac{\partial u_i}{\partial v}.
$$

This formula is the core of backpropagation.

Each node receives gradients from its children. It multiplies those incoming gradients by local derivatives. Then it adds the results and sends gradients to its own parents.

Using adjoint notation,

$$
\bar{v} = \frac{\partial L}{\partial v}.
$$

Then the rule becomes

$$
\bar{v} =
\sum_{i=1}^{k}
\bar{u}_i
\frac{\partial u_i}{\partial v}.
$$

This is the backward pass in one equation.
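To make the equation concrete, here is a minimal reverse-mode sketch in plain Python. It is not how PyTorch is implemented; the `Node` class and the `add`, `mul`, and `backward` helpers are hypothetical names chosen for this illustration.

```python
class Node:
    """A scalar value in a tiny, hypothetical computational graph."""

    def __init__(self, value, parents=()):
        self.value = value
        # Each entry is (parent_node, local derivative d(self)/d(parent)).
        self.parents = parents
        self.grad = 0.0  # adjoint, i.e. dL/d(self)


def add(u, v):
    return Node(u.value + v.value, ((u, 1.0), (v, 1.0)))


def mul(u, v):
    return Node(u.value * v.value, ((u, v.value), (v, u.value)))


def backward(loss):
    # Depth-first search gives a topological order (inputs before outputs).
    order, seen = [], set()

    def visit(node):
        if id(node) in seen:
            return
        seen.add(id(node))
        for parent, _ in node.parents:
            visit(parent)
        order.append(node)

    visit(loss)
    loss.grad = 1.0
    # Reverse order: a node is processed only after all of its children.
    for node in reversed(order):
        for parent, local in node.parents:
            parent.grad += node.grad * local  # adjoint times local derivative


x = Node(3.0)
z = add(mul(x, x), x)  # z = x^2 + x
backward(z)
print(x.grad)  # 7.0
```

The `+=` in the inner loop is the sum over children in the formula: `x` receives two contributions through the product and one through the addition.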

### Vector-Valued Functions

Neural networks usually operate on vectors and tensors, not only scalars. The chain rule still applies, but the derivatives become Jacobians.

Suppose

$$
y = g(x),
\quad
z = f(y),
$$

where

$$
x\in\mathbb{R}^n,
\quad
y\in\mathbb{R}^m,
\quad
z\in\mathbb{R}.
$$

The gradient of \(z\) with respect to \(x\) is

$$
\nabla_x z =
J_g(x)^\top \nabla_y z.
$$

Here

$$
J_g(x) =
\frac{\partial y}{\partial x}
\in\mathbb{R}^{m\times n}
$$

is the Jacobian of \(g\).

In practice, PyTorch does not usually build the full Jacobian. It computes the product

$$
J_g(x)^\top \nabla_y z.
$$

This is a vector-Jacobian product. It is enough for backpropagation and much cheaper than materializing the full Jacobian.
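The difference between the two computations can be seen on a small example. In this sketch, the function `g` and the vector `v` (standing in for \(\nabla_y z\)) are made up for illustration; the full Jacobian is built only to check the result.

```python
import torch

def g(x):
    # Hypothetical map from R^3 to R^2.
    return torch.stack([x[0] * x[1], x[1] + x[2] ** 2])

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = g(x)

v = torch.tensor([0.5, -1.0])  # stands in for the upstream gradient

# Vector-Jacobian product: one backward sweep, no Jacobian is stored.
(vjp,) = torch.autograd.grad(y, x, grad_outputs=v)

# Full Jacobian, materialized here only for comparison.
J = torch.autograd.functional.jacobian(g, x)  # shape (2, 3)

print(vjp)
print(J.T @ v)  # same values as vjp
```

For a layer with \(n\) inputs and \(m\) outputs, the Jacobian has \(mn\) entries, while the vector-Jacobian product only ever produces a vector of length \(n\).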

### A Linear Layer Example

Consider a linear layer without bias:

$$
y = Wx.
$$

Let

$$
W\in\mathbb{R}^{m\times n},
\quad
x\in\mathbb{R}^{n},
\quad
y\in\mathbb{R}^{m}.
$$

Suppose a scalar loss \(L\) depends on \(y\). Let

$$
\bar{y} = \nabla_y L.
$$

The gradient with respect to \(x\) is

$$
\bar{x} = W^\top \bar{y}.
$$

The gradient with respect to \(W\) is

$$
\bar{W} = \bar{y}x^\top.
$$

These formulas are applications of the chain rule.

For a batch,

$$
Y = XW^\top,
$$

where

$$
X\in\mathbb{R}^{B\times n},
\quad
W\in\mathbb{R}^{m\times n},
\quad
Y\in\mathbb{R}^{B\times m}.
$$

If

$$
\bar{Y} = \nabla_Y L,
$$

then

$$
\nabla_X L = \bar{Y}W,
$$

$$
\nabla_W L = \bar{Y}^\top X.
$$

In PyTorch:

```python
B = 4
n = 3
m = 2

X = torch.randn(B, n, requires_grad=True)
W = torch.randn(m, n, requires_grad=True)

Y = X @ W.T
L = Y.sum()

L.backward()

print(X.grad.shape)  # torch.Size([4, 3])
print(W.grad.shape)  # torch.Size([2, 3])
```

The gradient shapes match the original tensor shapes.
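The batch formulas can be checked against autograd. The sketch below seeds the backward pass with an arbitrary upstream gradient `Y_bar` (an assumption made here so the check is not limited to `Y.sum()`) and compares the results with \(\bar{Y}W\) and \(\bar{Y}^\top X\).

```python
import torch

torch.manual_seed(0)
B, n, m = 4, 3, 2

X = torch.randn(B, n, requires_grad=True)
W = torch.randn(m, n, requires_grad=True)

Y = X @ W.T
Y_bar = torch.randn(B, m)  # arbitrary upstream gradient, stands in for dL/dY
Y.backward(Y_bar)          # start the backward pass from Y_bar

with torch.no_grad():
    manual_X_grad = Y_bar @ W      # shape (B, n)
    manual_W_grad = Y_bar.T @ X    # shape (m, n)

print(torch.allclose(X.grad, manual_X_grad))  # True
print(torch.allclose(W.grad, manual_W_grad))  # True
```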

### Chain Rule Through Nonlinearities

Many activation functions, including ReLU and sigmoid, are applied elementwise. This makes their backward rules simple: each input element affects only the matching output element.

For ReLU,

$$
y = \max(0,x).
$$

The derivative is

$$
\frac{dy}{dx} =
\begin{cases}
1, & x>0, \\
0, & x<0.
\end{cases}
$$

At \(x=0\), the derivative is not uniquely defined. PyTorch uses the value \(0\) there, which is a valid subgradient.

If a loss \(L\) depends on \(y\), then

$$
\frac{\partial L}{\partial x} =
\frac{\partial L}{\partial y}
\frac{\partial y}{\partial x}.
$$

So ReLU passes the upstream gradient through positive inputs and blocks it for negative inputs.

In PyTorch:

```python
x = torch.tensor([-2.0, 0.5, 3.0], requires_grad=True)

y = torch.relu(x)
L = y.sum()

L.backward()

print(x.grad)  # tensor([0., 1., 1.])
```
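With `y.sum()` the upstream gradient is all ones, which hides the multiplication. The following sketch uses a hypothetical weight vector `w` so that the upstream gradient is visibly masked by the local ReLU derivative.

```python
import torch

x = torch.tensor([-2.0, 0.5, 3.0], requires_grad=True)
w = torch.tensor([10.0, 20.0, 30.0])  # hypothetical weights, so dL/dy = w

y = torch.relu(x)
L = (w * y).sum()
L.backward()

mask = (x > 0).float()  # local derivative of ReLU at each element

print(x.grad)    # values 0, 20, 30
print(w * mask)  # the same values: upstream gradient times local derivative
```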

For sigmoid,

$$
y = \sigma(x)=\frac{1}{1+e^{-x}}.
$$

Its derivative is

$$
\frac{dy}{dx}=y(1-y).
$$

The backward pass multiplies the upstream gradient by \(y(1-y)\).

```python
x = torch.tensor([0.0], requires_grad=True)

y = torch.sigmoid(x)
L = y.sum()

L.backward()

print(x.grad)  # tensor([0.2500])
```

At \(x=0\), \(\sigma(x)=0.5\), so the derivative is \(0.25\).
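A local backward rule can be written by hand with `torch.autograd.Function`. The class `MySigmoid` below is a hypothetical illustration, not the built-in implementation; it saves \(y\) in the forward pass and multiplies the upstream gradient by \(y(1-y)\) in the backward pass.

```python
import torch

class MySigmoid(torch.autograd.Function):
    """Hypothetical hand-written sigmoid, for illustration only."""

    @staticmethod
    def forward(ctx, x):
        y = 1.0 / (1.0 + torch.exp(-x))
        ctx.save_for_backward(y)  # keep y for the backward rule
        return y

    @staticmethod
    def backward(ctx, grad_output):
        (y,) = ctx.saved_tensors
        # Upstream gradient times the local derivative y(1 - y).
        return grad_output * y * (1.0 - y)

x = torch.tensor([-1.0, 0.0, 2.0], requires_grad=True)

MySigmoid.apply(x).sum().backward()
custom_grad = x.grad.clone()

x.grad = None  # reset before the second backward pass
torch.sigmoid(x).sum().backward()

print(torch.allclose(custom_grad, x.grad))  # True
```

In practice `torch.sigmoid` should be preferred; the custom version only shows where the rule lives.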

### Chain Rule Through Loss Functions

Loss functions also participate in the computational graph.

For mean squared error,

$$
L = \frac{1}{n}\sum_{i=1}^{n}(\hat{y}_i-y_i)^2.
$$

The derivative with respect to a prediction is

$$
\frac{\partial L}{\partial \hat{y}_i} =
\frac{2}{n}(\hat{y}_i-y_i).
$$

This gradient is then propagated backward through the model that produced \(\hat{y}\).

In PyTorch:

```python
pred = torch.tensor([2.0, 4.0, 6.0], requires_grad=True)
target = torch.tensor([1.0, 5.0, 2.0])

loss = ((pred - target) ** 2).mean()
loss.backward()

print(pred.grad)  # tensor([ 0.6667, -0.6667,  2.6667]), i.e. (2/3) * (pred - target)
```

The loss gradient is the starting signal for backpropagation. The model receives this signal at its output and propagates it backward to all parameters that contributed to the predictions.
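To see the signal travel one step further, consider a hypothetical one-parameter model \(\hat{y}_i = w x_i\). The sketch below assumes this model and made-up inputs `x`, and checks that \(\partial L/\partial w = \sum_i \frac{\partial L}{\partial \hat{y}_i} x_i\).

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0])
target = torch.tensor([1.0, 5.0, 2.0])
w = torch.tensor(2.0, requires_grad=True)  # hypothetical single parameter

pred = w * x  # predictions [2., 4., 6.]
loss = ((pred - target) ** 2).mean()
loss.backward()

with torch.no_grad():
    dL_dpred = 2.0 / pred.numel() * (pred - target)  # gradient of MSE
    dL_dw = (dL_dpred * x).sum()                     # chain rule through pred = w * x

print(w.grad)
print(dL_dw)  # same value as w.grad
```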

### Chain Rule in a Training Step

A complete training step applies the chain rule over the whole model.

```python
import torch
from torch import nn

model = nn.Sequential(
    nn.Linear(10, 32),
    nn.ReLU(),
    nn.Linear(32, 1),
)

X = torch.randn(16, 10)
target = torch.randn(16, 1)

pred = model(X)
loss = ((pred - target) ** 2).mean()

model.zero_grad()
loss.backward()
```

The computation can be read as a composition:

$$
X
\longrightarrow
\text{Linear}_1
\longrightarrow
\text{ReLU}
\longrightarrow
\text{Linear}_2
\longrightarrow
\hat{y}
\longrightarrow
L.
$$

The backward pass applies the chain rule in reverse order:

$$
L
\longrightarrow
\hat{y}
\longrightarrow
\text{Linear}_2
\longrightarrow
\text{ReLU}
\longrightarrow
\text{Linear}_1.
$$

Each module contributes a local backward rule. PyTorch combines those rules into parameter gradients.
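The reverse walk can be written out by hand for this small model. The following sketch recomputes the intermediates and applies the linear, ReLU, and MSE backward rules from the earlier sections; variable names such as `h_pre` and `d_W1` are introduced here for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))
X = torch.randn(16, 10)
target = torch.randn(16, 1)

pred = model(X)
loss = ((pred - target) ** 2).mean()
model.zero_grad()
loss.backward()

lin1, lin2 = model[0], model[2]
with torch.no_grad():
    # Recompute the forward intermediates.
    h_pre = X @ lin1.weight.T + lin1.bias
    h = torch.relu(h_pre)

    # Walk backward: loss -> pred -> Linear_2 -> ReLU -> Linear_1.
    d_pred = 2.0 / pred.numel() * (pred - target)  # MSE rule
    d_W2 = d_pred.T @ h                            # linear rule for the weight
    d_h = d_pred @ lin2.weight                     # linear rule for the input
    d_h_pre = d_h * (h_pre > 0).float()            # ReLU rule
    d_W1 = d_h_pre.T @ X
    d_b1 = d_h_pre.sum(dim=0)                      # bias is shared across the batch

print(torch.allclose(lin2.weight.grad, d_W2, atol=1e-5))
print(torch.allclose(lin1.weight.grad, d_W1, atol=1e-5))
print(torch.allclose(lin1.bias.grad, d_b1, atol=1e-5))
```

The manual gradients should match the ones autograd stored on the parameters, up to floating-point tolerance.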

### Why Deep Networks Can Be Hard to Train

The chain rule also explains some training difficulties. A deep network multiplies many derivative terms together. If many terms are smaller than one in magnitude, gradients may shrink as they move backward. This is the vanishing gradient problem.

If many terms are larger than one in magnitude, gradients may grow rapidly. This is the exploding gradient problem.

For a simple repeated composition,

$$
h_t = f(h_{t-1}),
$$

the derivative after many steps contains a product:

$$
\frac{\partial h_T}{\partial h_0} =
\prod_{t=1}^{T}
\frac{\partial h_t}{\partial h_{t-1}}.
$$

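This product may become very small or very large. When \(h_t\) is a vector, each factor is a Jacobian matrix, and the size of the matrix product plays the same role.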

Modern architectures reduce these problems through careful initialization, normalization, residual connections, gated units, and gradient clipping. These techniques will appear throughout later chapters.
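The effect of the product is easy to reproduce on a toy recurrence. This sketch assumes the simplest possible case, \(h_t = w\,h_{t-1}\) with a scalar `w`, so that \(\partial h_T/\partial h_0 = w^T\); the helper name is hypothetical.

```python
import torch

def end_to_start_gradient(w, T):
    """Gradient of h_T w.r.t. h_0 for the toy recurrence h_t = w * h_{t-1}."""
    h0 = torch.tensor(1.0, dtype=torch.float64, requires_grad=True)
    h = h0
    for _ in range(T):
        h = w * h  # each step contributes a local derivative of w
    h.backward()
    return h0.grad.item()

T = 50
small = end_to_start_gradient(0.5, T)  # about 0.5 ** 50, on the order of 1e-15
large = end_to_start_gradient(1.5, T)  # about 1.5 ** 50, on the order of 1e8

print(small, large)
```

Fifty steps are enough to move the gradient about 24 orders of magnitude apart between the two cases.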

### Common PyTorch Mistakes

A common mistake is breaking the graph accidentally by converting tensors to Python numbers:

```python
loss_value = loss.item()  # plain Python float, detached from the graph
loss_value.backward()     # AttributeError: a float has no backward()
```

The method `.item()` extracts a plain Python number. That number has no computational graph.

Use `.item()` for logging only:

```python
print(loss.item())
loss.backward()
```

Another mistake is detaching an intermediate value too early:

```python
h = encoder(x)
h = h.detach()
out = decoder(h)
loss = loss_fn(out, target)
loss.backward()
```

This prevents gradients from flowing into `encoder`. This may be intentional in some algorithms, but it should not happen accidentally.
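The effect is visible in the parameter gradients. In this sketch, two small `nn.Linear` layers are hypothetical stand-ins for `encoder` and `decoder`, and the data is random.

```python
import torch
from torch import nn

torch.manual_seed(0)
encoder = nn.Linear(4, 3)  # hypothetical stand-in for an encoder
decoder = nn.Linear(3, 1)  # hypothetical stand-in for a decoder
x = torch.randn(5, 4)
target = torch.randn(5, 1)

h = encoder(x).detach()    # cuts the graph between encoder and decoder
loss = ((decoder(h) - target) ** 2).mean()
loss.backward()

print(encoder.weight.grad)              # None: no gradient reached the encoder
print(decoder.weight.grad is not None)  # True
```

An optimizer step after this backward pass would update the decoder only.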

A third mistake is using non-differentiable operations where gradients are needed. For example, `argmax` produces discrete indices. Gradients cannot flow through a hard index choice in the usual way.

```python
logits = model(x)
idx = logits.argmax(dim=-1)
```

This is fine for prediction, but it should usually not be placed inside a differentiable training path.
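A quick check shows why. In the sketch below, the `argmax` result is an integer tensor that is not attached to the graph, while a softmax-based loss (shown here as one common differentiable alternative, with made-up logits and class index) does produce a gradient.

```python
import torch

logits = torch.tensor([[1.0, 3.0, 2.0]], requires_grad=True)

idx = logits.argmax(dim=-1)
print(idx.dtype, idx.requires_grad)  # torch.int64 False

# A differentiable quantity built from the same logits.
probs = torch.softmax(logits, dim=-1)
loss = -torch.log(probs[0, 1])  # negative log-probability of class 1
loss.backward()

print(logits.grad)  # a gradient exists for every logit
```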

### Summary

The chain rule gives the derivative of a composed function by multiplying local derivatives. In computational graphs, the same rule is applied node by node. If a variable affects the loss through multiple paths, the gradient contributions from all paths are added.

PyTorch autograd implements this process automatically. During the forward pass, it records operations. During the backward pass, it starts from the loss and applies local backward rules in reverse order.

Backpropagation is the chain rule applied efficiently to large computational graphs.

