# Reverse-Mode Differentiation

Reverse-mode differentiation is the method used by backpropagation. It computes derivatives by first evaluating a function forward, then propagating gradient information backward from the output to the inputs.

This method is especially useful in deep learning because neural networks usually have many parameters and one scalar loss. Reverse-mode differentiation can compute the gradient of one scalar output with respect to millions or billions of parameters efficiently.

### The Problem Setting

Assume we have a function

$$
L = f(\theta),
$$

where \(\theta\) is a vector of parameters and \(L\) is a scalar loss.

If

$$
\theta =
\begin{bmatrix}
\theta_1 \\
\theta_2 \\
\vdots \\
\theta_n
\end{bmatrix},
$$

then the gradient is

$$
\nabla_\theta L =
\begin{bmatrix}
\frac{\partial L}{\partial \theta_1} \\
\frac{\partial L}{\partial \theta_2} \\
\vdots \\
\frac{\partial L}{\partial \theta_n}
\end{bmatrix}.
$$

A deep model may have millions of parameters. Computing each partial derivative with a separate function evaluation, as finite differences do, would require on the order of \(n\) passes. Reverse-mode differentiation avoids that cost by reusing intermediate derivative information: one forward pass and one backward pass yield the entire gradient.

### Forward Mode Versus Reverse Mode

There are two broad modes of automatic differentiation: forward mode and reverse mode.

Forward mode propagates derivatives from inputs to outputs. It answers: if one input changes, how does each later value change?

Reverse mode propagates derivatives from outputs to inputs. It answers: if the final output changes, how much did each earlier value contribute?

For a function

$$
f:\mathbb{R}^n\to\mathbb{R}^m,
$$

forward mode is efficient when \(n\) is small. Reverse mode is efficient when \(m\) is small.

Deep learning training usually has

$$
f:\mathbb{R}^n\to\mathbb{R},
$$

where \(n\) may be very large. The output is one scalar loss. This is the ideal case for reverse mode.
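Both modes are exposed in PyTorch through `torch.autograd.functional`, which makes the asymmetry concrete: for a scalar-valued function, one reverse-mode VJP returns the entire gradient, while one forward-mode JVP returns only a single directional derivative.

```python
import torch
from torch.autograd.functional import jvp, vjp

def f(x):
    return (x ** 2).sum()

x = torch.tensor([1.0, 2.0, 3.0])

# Reverse mode: one VJP with upstream gradient 1 yields the full gradient.
_, grad = vjp(f, x, torch.tensor(1.0))
print(grad)  # tensor([2., 4., 6.])

# Forward mode: one JVP gives the derivative along a single direction.
_, dd = jvp(f, x, torch.tensor([1.0, 0.0, 0.0]))
print(dd)  # tensor(2.)
```

Recovering the full gradient with forward mode would require one JVP per input dimension.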

### A Simple Example

Consider

$$
z = (x+y)^2.
$$

Break it into intermediate variables:

$$
a = x+y,
$$

$$
z = a^2.
$$

The forward pass computes values:

$$
x=2,\quad y=3,
$$

$$
a=5,
$$

$$
z=25.
$$

The reverse pass computes derivatives of the final output \(z\) with respect to each intermediate value.

We start with

$$
\frac{\partial z}{\partial z}=1.
$$

Then

$$
\frac{\partial z}{\partial a}=2a=10.
$$

Since

$$
a=x+y,
$$

we have

$$
\frac{\partial a}{\partial x}=1,
\quad
\frac{\partial a}{\partial y}=1.
$$

Therefore

$$
\frac{\partial z}{\partial x} =
\frac{\partial z}{\partial a}
\frac{\partial a}{\partial x} =
10,
$$

and

$$
\frac{\partial z}{\partial y} =
\frac{\partial z}{\partial a}
\frac{\partial a}{\partial y} =
10.
$$

In PyTorch:

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)

a = x + y
z = a ** 2

z.backward()

print(x.grad)  # tensor(10.)
print(y.grad)  # tensor(10.)
```

PyTorch performs this reverse traversal automatically.
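The same reverse traversal, written out by hand with plain Python floats, mirrors the equations above:

```python
# Forward pass: compute and remember intermediate values.
x, y = 2.0, 3.0
a = x + y          # a = 5
z = a ** 2         # z = 25

# Reverse pass: start at the output and apply the chain rule backward.
z_bar = 1.0                # dz/dz
a_bar = z_bar * 2 * a      # dz/da = 2a
x_bar = a_bar * 1.0        # da/dx = 1
y_bar = a_bar * 1.0        # da/dy = 1

print(x_bar, y_bar)  # 10.0 10.0
```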

### Adjoints

Reverse-mode differentiation often uses the term adjoint. The adjoint of a variable is the derivative of the final scalar output with respect to that variable.

For a variable \(v\), its adjoint is written as

$$
\bar{v} = \frac{\partial L}{\partial v}.
$$

Here \(L\) is the final scalar loss.

For the computation

$$
a=x+y,
\quad
z=a^2,
\quad
L=z,
$$

the adjoints are

$$
\bar{z} = \frac{\partial L}{\partial z}=1,
$$

$$
\bar{a} = \frac{\partial L}{\partial a}=2a,
$$

$$
\bar{x} = \frac{\partial L}{\partial x}=\bar{a}\frac{\partial a}{\partial x},
$$

$$
\bar{y} = \frac{\partial L}{\partial y}=\bar{a}\frac{\partial a}{\partial y}.
$$

At \(a=5\), we get

$$
\bar{x}=10,
\quad
\bar{y}=10.
$$

The adjoint notation is useful because it describes the backward pass locally. Each operation receives an upstream adjoint and distributes it to its inputs.

### Local Backward Rules

Reverse-mode differentiation depends on local backward rules. Each primitive operation knows how to send gradients to its inputs.

For addition,

$$
z = x+y.
$$

If the upstream gradient is \(\bar{z}\), then

$$
\bar{x} \mathrel{+}= \bar{z},
$$

$$
\bar{y} \mathrel{+}= \bar{z}.
$$

For multiplication,

$$
z = xy.
$$

The local derivatives are

$$
\frac{\partial z}{\partial x}=y,
\quad
\frac{\partial z}{\partial y}=x.
$$

Thus

$$
\bar{x} \mathrel{+}= \bar{z}y,
$$

$$
\bar{y} \mathrel{+}= \bar{z}x.
$$

For squaring,

$$
z=x^2.
$$

The backward rule is

$$
\bar{x} \mathrel{+}= \bar{z}\,2x.
$$

The symbol \(\mathrel{+}=\) means that gradients are accumulated. A variable may influence the output through multiple paths, so all contributions must be added.
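These local rules are enough to build a tiny reverse-mode engine. The sketch below (a micrograd-style `Value` class, invented here for illustration) implements exactly the addition and multiplication rules above, including gradient accumulation with `+=`:

```python
class Value:
    """Minimal reverse-mode autodiff node."""

    def __init__(self, data, parents=()):
        self.data = data
        self.grad = 0.0              # the adjoint, accumulated with +=
        self._parents = parents
        self._backward = lambda: None

    def __add__(self, other):
        out = Value(self.data + other.data, (self, other))
        def _backward():
            self.grad += out.grad        # addition passes the adjoint through
            other.grad += out.grad
        out._backward = _backward
        return out

    def __mul__(self, other):
        out = Value(self.data * other.data, (self, other))
        def _backward():
            self.grad += out.grad * other.data
            other.grad += out.grad * self.data
        out._backward = _backward
        return out

    def backward(self):
        # Topologically sort the graph, then run local rules in reverse.
        order, seen = [], set()
        def visit(v):
            if v not in seen:
                seen.add(v)
                for p in v._parents:
                    visit(p)
                order.append(v)
        visit(self)
        self.grad = 1.0
        for v in reversed(order):
            v._backward()

x, y = Value(2.0), Value(3.0)
z = (x + y) * (x + y)    # (x + y)^2 built from the add and mul rules
z.backward()
print(x.grad, y.grad)    # 10.0 10.0
```

Each node only knows its local rule; the topological sort guarantees every upstream adjoint is complete before it is distributed further.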

### Why Gradients Accumulate

Consider

$$
z = x^2 + x.
$$

Break it into steps:

$$
a = x^2,
$$

$$
b = x,
$$

$$
z = a+b.
$$

The variable \(x\) contributes to \(z\) through two paths: one through \(a\), and one through \(b\).

The derivative is

$$
\frac{dz}{dx}=2x+1.
$$

Reverse mode obtains this by adding contributions from both paths.

At \(x=3\),

$$
\frac{dz}{dx}=7.
$$

In PyTorch:

```python
x = torch.tensor(3.0, requires_grad=True)

z = x ** 2 + x
z.backward()

print(x.grad)  # tensor(7.)
```

This accumulation behavior is the reason PyTorch adds gradients into `.grad` fields rather than replacing them automatically.
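The same accumulation happens across repeated `backward()` calls, which is why training loops reset gradients every step. A small demonstration:

```python
import torch

x = torch.tensor(3.0, requires_grad=True)

(x ** 2).backward()
print(x.grad)        # tensor(6.)

(x ** 2).backward()  # the second backward adds into .grad
print(x.grad)        # tensor(12.)

x.grad = None        # manual reset; optimizers do this via zero_grad()
(x ** 2).backward()
print(x.grad)        # tensor(6.)
```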

### Reverse Pass on a Larger Graph

Consider the computation

$$
u = xy,
$$

$$
v = x+y,
$$

$$
L = u+v.
$$

The forward pass computes \(u\), \(v\), and \(L\). The reverse pass starts with

$$
\bar{L}=1.
$$

Since

$$
L=u+v,
$$

we get

$$
\bar{u} = 1,
\quad
\bar{v}=1.
$$

Since

$$
u=xy,
$$

we add

$$
\bar{x} \mathrel{+}= \bar{u}y,
$$

$$
\bar{y} \mathrel{+}= \bar{u}x.
$$

Since

$$
v=x+y,
$$

we add

$$
\bar{x} \mathrel{+}= \bar{v},
$$

$$
\bar{y} \mathrel{+}= \bar{v}.
$$

Thus

$$
\bar{x}=y+1,
$$

$$
\bar{y}=x+1.
$$

At \(x=2\), \(y=3\),

$$
\bar{x}=4,
\quad
\bar{y}=3.
$$

PyTorch:

```python
x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)

u = x * y
v = x + y
L = u + v

L.backward()

print(x.grad)  # tensor(4.)
print(y.grad)  # tensor(3.)
```

### Vector-Jacobian Products

Reverse mode can be understood through vector-Jacobian products.

Suppose

$$
y = f(x),
$$

where \(x\in\mathbb{R}^n\) and \(y\in\mathbb{R}^m\). The Jacobian is

$$
J =
\frac{\partial y}{\partial x}
\in\mathbb{R}^{m\times n}.
$$

If a later scalar loss \(L\) depends on \(y\), then the upstream gradient is

$$
\bar{y} =
\frac{\partial L}{\partial y}
\in\mathbb{R}^{m}.
$$

The gradient with respect to \(x\) is

$$
\bar{x} =
J^\top \bar{y}.
$$

This operation is a vector-Jacobian product. It avoids explicitly forming the full Jacobian. This matters because Jacobians in deep learning can be extremely large.

PyTorch’s backward pass is primarily a system for computing vector-Jacobian products efficiently.
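`torch.autograd.grad` with the `grad_outputs` argument computes exactly this product; the Jacobian itself is never materialized. For an elementwise function the result is easy to check by hand:

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x ** 3                    # elementwise, so J = diag(3 x^2)

v = torch.tensor([1.0, 0.5, 0.0])            # plays the role of ȳ
(x_bar,) = torch.autograd.grad(y, x, grad_outputs=v)
print(x_bar)  # tensor([3., 6., 0.])
```

Here \(J^\top v\) multiplies each \(\bar{y}_i\) by the local derivative \(3x_i^2\), without ever building the \(3\times 3\) matrix.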

### Non-Scalar Outputs

When the output is scalar, PyTorch implicitly uses an upstream gradient of \(1\).

```python
x = torch.tensor(3.0, requires_grad=True)
y = x ** 2

y.backward()
```

This means

$$
\frac{dy}{dy}=1.
$$

When the output is not scalar, the user must provide the upstream gradient.

```python
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

y = x ** 2
y.backward(torch.tensor([1.0, 1.0, 1.0]))

print(x.grad)  # tensor([2., 4., 6.])
```

Here the provided vector acts as \(\bar{y}\). Since

$$
y_i=x_i^2,
$$

the vector-Jacobian product gives

$$
\bar{x}_i=\bar{y}_i 2x_i.
$$

With \(\bar{y}_i=1\), this gives \(2x_i\).

A different upstream gradient gives a different vector-Jacobian product:

```python
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

y = x ** 2
y.backward(torch.tensor([10.0, 1.0, 0.1]))

print(x.grad)  # tensor([20.0000, 4.0000, 0.6000])
```

### Reverse Mode in Neural Networks

A neural network is a composition of many operations:

$$
L =
\ell(f_\theta(X), y).
$$

The forward pass computes predictions and loss. The reverse pass computes

$$
\nabla_\theta L.
$$

For a multilayer network,

$$
h_1 = \sigma(XW_1^\top+b_1),
$$

$$
h_2 = \sigma(h_1W_2^\top+b_2),
$$

$$
\hat{y}=h_2W_3^\top+b_3,
$$

$$
L=\ell(\hat{y}, y).
$$

Reverse mode starts from \(L\), then moves backward through the loss, final linear layer, activation, hidden linear layer, another activation, and first linear layer. Each operation applies its local backward rule.

PyTorch code:

```python
import torch
from torch import nn

model = nn.Sequential(
    nn.Linear(10, 32),
    nn.ReLU(),
    nn.Linear(32, 32),
    nn.ReLU(),
    nn.Linear(32, 1),
)

X = torch.randn(16, 10)
target = torch.randn(16, 1)

pred = model(X)
loss = ((pred - target) ** 2).mean()

loss.backward()

for name, param in model.named_parameters():
    print(name, param.grad.shape)
```

The programmer does not write the derivative rules for the whole network. PyTorch composes local backward rules automatically.

### Cost of Reverse Mode

Reverse mode has two important costs.

The first cost is computation. The backward pass typically costs a small constant multiple of the forward pass, on the order of two to three times. For many neural networks, one training step therefore costs roughly one forward pass plus one comparably sized backward pass.

The second cost is memory. Reverse mode needs intermediate values from the forward pass. These values are stored so that backward rules can use them later.

For example, the derivative of

$$
z=x^2
$$

needs the value of \(x\). The derivative of ReLU needs to know which entries were positive. The derivative of batch normalization needs normalization statistics. The derivative of attention needs intermediate tensors related to queries, keys, values, attention scores, and probabilities.

This is why training uses more memory than inference.
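The saved values are visible on the graph PyTorch records. The `_saved_self` attribute used below is an internal, version-dependent detail, but it shows that the input to the squaring operation is kept alive for the backward pass:

```python
import torch

x = torch.randn(4, requires_grad=True)
z = x ** 2

# The pow backward rule needs its input, so autograd stores it on the
# graph node. _saved_self is an internal attribute and may change
# between PyTorch versions.
print(torch.equal(z.grad_fn._saved_self, x))  # True
```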

### Checkpointing

Gradient checkpointing reduces memory use by storing fewer intermediate activations during the forward pass. During the backward pass, missing activations are recomputed.

This trades computation for memory.

Without checkpointing, the system stores many intermediate values: more memory, less recomputation.

With checkpointing, the system stores fewer intermediate values: less memory, more recomputation.

In PyTorch, checkpointing can be applied with `torch.utils.checkpoint`:

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

layer1 = nn.Linear(128, 128)
layer2 = nn.Linear(128, 128)

def block(x):
    # Activations inside this block are not stored; they are
    # recomputed when the backward pass reaches the checkpoint.
    return layer2(torch.relu(layer1(x)))

x = torch.randn(16, 128, requires_grad=True)
y = checkpoint(block, x, use_reentrant=False)
loss = y.sum()
loss.backward()
```

Checkpointing is common when training large models that would otherwise exceed GPU memory.

### In-Place Operations and Reverse Mode

Reverse mode depends on saved forward values. In-place operations can overwrite those values and break gradient computation.

Example:

```python
x = torch.randn(4, requires_grad=True)

a = x * 1.0      # non-leaf tensor
y = a ** 2       # the backward rule for ** saves a
a.add_(1.0)      # unsafe: overwrites the saved value
loss = y.sum()
loss.backward()  # raises a RuntimeError
```

The backward rule for squaring needs the original value of its input. Autograd tracks a version counter on every tensor it saves; when a saved value has been modified in place, the backward pass raises a RuntimeError rather than silently returning wrong gradients. (Modifying a leaf tensor that requires grad, such as `x` itself, is rejected even earlier, at the in-place call.)

A safer version uses out-of-place operations:

```python
x = torch.randn(4, requires_grad=True)

y = x ** 2
x2 = x + 1.0
loss = y.sum()
loss.backward()
```

In-place operations are sometimes useful for memory efficiency, but they should be used carefully when gradients are involved.

### Detaching and Stopping Gradients

Sometimes a computation should stop gradient flow. PyTorch uses `.detach()` for this.

```python
x = torch.tensor(2.0, requires_grad=True)

y = x ** 2
z = y.detach()
w = 3 * z

print(w.requires_grad)  # False
```

The tensor `z` has the same numerical value as `y`, but it has no connection to the graph that produced `y`.

This is used in target networks, contrastive learning, reinforcement learning, teacher-student methods, logging, and some optimization algorithms.

A common pattern is:

```python
import torch
from torch import nn

teacher_model = nn.Linear(10, 1)   # stand-in models for illustration
student_model = nn.Linear(10, 1)
x = torch.randn(8, 10)

with torch.no_grad():
    target = teacher_model(x)      # no graph is recorded here

pred = student_model(x)
loss = nn.functional.mse_loss(pred, target)
loss.backward()
```

The target is treated as a fixed value. Gradients update the student model but not the teacher computation.

### Summary

Reverse-mode differentiation computes gradients by traversing a computational graph backward from a scalar output. It starts with an upstream gradient of \(1\) at the loss, then applies local backward rules for each operation.

This method is efficient for deep learning because training usually requires the gradient of one scalar loss with respect to many parameters. PyTorch implements reverse mode through autograd. It records operations during the forward pass, stores needed intermediate values, and computes vector-Jacobian products during the backward pass.

