Reverse-Mode Differentiation

Reverse-mode differentiation is the method used by backpropagation. It computes derivatives by first evaluating a function forward, then propagating gradient information backward from the output to the inputs.

This method is especially useful in deep learning because neural networks usually have many parameters and one scalar loss. Reverse-mode differentiation can compute the gradient of one scalar output with respect to millions or billions of parameters efficiently.

The Problem Setting

Assume we have a function

L = f(\theta),

where \theta is a vector of parameters and L is a scalar loss.

If

\theta = \begin{bmatrix} \theta_1 \\ \theta_2 \\ \vdots \\ \theta_n \end{bmatrix},

then the gradient is

\nabla_\theta L = \begin{bmatrix} \frac{\partial L}{\partial \theta_1} \\ \frac{\partial L}{\partial \theta_2} \\ \vdots \\ \frac{\partial L}{\partial \theta_n} \end{bmatrix}.

A deep model may have millions of parameters. Computing each partial derivative separately would be too expensive. Reverse-mode differentiation avoids that cost by reusing intermediate derivative information.
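
As a rough sketch of that cost difference (the quadratic loss f and the step size h below are made up for illustration), a finite-difference estimate needs one extra forward evaluation per parameter, while reverse mode needs a single backward pass:

import torch

def f(theta):
    # Hypothetical scalar loss over a parameter vector.
    return (theta ** 2).sum()

theta = torch.randn(100, dtype=torch.float64, requires_grad=True)

# Finite differences: one extra forward evaluation per parameter.
h = 1e-4
fd_grad = torch.zeros_like(theta)
with torch.no_grad():
    base = f(theta)
    for i in range(theta.numel()):
        pert = theta.clone()
        pert[i] += h
        fd_grad[i] = (f(pert) - base) / h

# Reverse mode: one forward pass and one backward pass, regardless of n.
f(theta).backward()

print(torch.allclose(theta.grad, fd_grad, atol=1e-2))  # True, up to finite-difference error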

Forward Mode Versus Reverse Mode

There are two broad modes of automatic differentiation: forward mode and reverse mode.

Forward mode propagates derivatives from inputs to outputs. It answers: if one input changes, how does each later value change?

Reverse mode propagates derivatives from outputs to inputs. It answers: if the final output changes, how much did each earlier value contribute?

For a function

f:\mathbb{R}^n\to\mathbb{R}^m,

forward mode is efficient when n is small. Reverse mode is efficient when m is small.

Deep learning training usually has

f:\mathbb{R}^n\to\mathbb{R},

where n may be very large. The output is one scalar loss. This is the ideal case for reverse mode.
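
As a sketch of this contrast (assuming PyTorch 2.x, where torch.func exposes both modes), a jvp call pushes one input tangent forward, so recovering the full gradient would take n calls, while a single vjp call pulls one output cotangent backward and yields the whole gradient at once:

import torch
from torch.func import jvp, vjp

def f(x):
    # Many inputs, one scalar output.
    return (x ** 2).sum()

x = torch.randn(5)

# Forward mode: one directional derivative per call.
_, dir_deriv = jvp(f, (x,), (torch.ones_like(x),))

# Reverse mode: the full gradient from a single cotangent.
out, vjp_fn = vjp(f, x)
(grad,) = vjp_fn(torch.tensor(1.0))

print(dir_deriv)                     # derivative of f along the all-ones direction
print(torch.allclose(grad, 2 * x))   # True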

A Simple Example

Consider

z = (x+y)^2.

Break it into intermediate variables:

a = x+y, \quad z = a^2.

The forward pass computes values:

x=2, \quad y=3, \quad a=5, \quad z=25.

The reverse pass computes derivatives of the final output z with respect to each intermediate value.

We start with

\frac{\partial z}{\partial z}=1.

Then

\frac{\partial z}{\partial a}=2a=10.

Since

a = x+y,

we have

\frac{\partial a}{\partial x}=1, \quad \frac{\partial a}{\partial y}=1.

Therefore

\frac{\partial z}{\partial x} = \frac{\partial z}{\partial a} \frac{\partial a}{\partial x} = 10,

and

\frac{\partial z}{\partial y} = \frac{\partial z}{\partial a} \frac{\partial a}{\partial y} = 10.

In PyTorch:

import torch

x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)

a = x + y
z = a ** 2

z.backward()

print(x.grad)  # tensor(10.)
print(y.grad)  # tensor(10.)

PyTorch performs this reverse traversal automatically.

Adjoints

Reverse-mode differentiation often uses the term adjoint. The adjoint of a variable is the derivative of the final scalar output with respect to that variable.

For a variable v, its adjoint is written as

\bar{v} = \frac{\partial L}{\partial v}.

Here L is the final scalar loss.

For the computation

a=x+y, \quad z=a^2, \quad L=z,

the adjoints are

\bar{z} = \frac{\partial L}{\partial z}=1, \quad \bar{a} = \frac{\partial L}{\partial a}=2a, \quad \bar{x} = \frac{\partial L}{\partial x}=\bar{a}\frac{\partial a}{\partial x}, \quad \bar{y} = \frac{\partial L}{\partial y}=\bar{a}\frac{\partial a}{\partial y}.

At a=5, we get

\bar{x}=10, \quad \bar{y}=10.

The adjoint notation is useful because it describes the backward pass locally. Each operation receives an upstream adjoint and distributes it to its inputs.

Local Backward Rules

Reverse-mode differentiation depends on local backward rules. Each primitive operation knows how to send gradients to its inputs.

For addition,

z = x+y.

If the upstream gradient is \bar{z}, then

\bar{x} \mathrel{+}= \bar{z}, \quad \bar{y} \mathrel{+}= \bar{z}.

For multiplication,

z = xy.

The local derivatives are

\frac{\partial z}{\partial x}=y, \quad \frac{\partial z}{\partial y}=x.

Thus

\bar{x} \mathrel{+}= \bar{z}\,y, \quad \bar{y} \mathrel{+}= \bar{z}\,x.

For squaring,

z=x^2.

The backward rule is

\bar{x} \mathrel{+}= \bar{z}\,2x.

The symbol += means that gradients are accumulated. A variable may influence the output through multiple paths, so all contributions must be added.
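
As a minimal sketch of how these local rules compose (plain Python with an explicit adjoint dictionary, not how PyTorch stores gradients internally), consider z = (x + y) * x:

# Forward pass: record intermediate values.
x, y = 2.0, 3.0
a = x + y          # a = 5
z = a * x          # z = 10

# Reverse pass: one adjoint per variable, accumulated with +=.
adj = {"x": 0.0, "y": 0.0, "a": 0.0, "z": 1.0}  # dz/dz = 1

# Rule for z = a * x (multiplication).
adj["a"] += adj["z"] * x
adj["x"] += adj["z"] * a

# Rule for a = x + y (addition).
adj["x"] += adj["a"]
adj["y"] += adj["a"]

print(adj["x"], adj["y"])  # 7.0 2.0, matching dz/dx = 2x + y and dz/dy = x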

Why Gradients Accumulate

Consider

z = x^2 + x.

Break it into steps:

a = x^2, \quad b = x, \quad z = a+b.

The variable x contributes to z through two paths: one through a, and one through b.

The derivative is

\frac{dz}{dx}=2x+1.

Reverse mode obtains this by adding contributions from both paths.

At x=3,

\frac{dz}{dx}=7.

In PyTorch:

x = torch.tensor(3.0, requires_grad=True)

z = x ** 2 + x
z.backward()

print(x.grad)  # tensor(7.)

This accumulation behavior is the reason PyTorch adds gradients into .grad fields rather than replacing them automatically.
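
The same accumulation happens across separate backward calls: each call adds into .grad, so the field must be cleared between steps (for example by calling zero_grad() on a module or optimizer, or by setting .grad to None). A small check:

x = torch.tensor(3.0, requires_grad=True)

(x ** 2 + x).backward()
print(x.grad)  # tensor(7.)

(x ** 2 + x).backward()
print(x.grad)  # tensor(14.), the second call accumulated into .grad

x.grad = None  # reset before an unrelated gradient computation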

Reverse Pass on a Larger Graph

Consider the computation

u = xy, \quad v = x+y, \quad L = u+v.

The forward pass computes u, v, and L. The reverse pass starts with

\bar{L}=1.

Since

L=u+v,

we get

\bar{u} = 1, \quad \bar{v}=1.

Since

u=xy,

we add

\bar{x} \mathrel{+}= \bar{u}\,y, \quad \bar{y} \mathrel{+}= \bar{u}\,x.

Since

v=x+y,

we add

\bar{x} \mathrel{+}= \bar{v}, \quad \bar{y} \mathrel{+}= \bar{v}.

Thus

\bar{x}=y+1, \quad \bar{y}=x+1.

At x=2, y=3,

\bar{x}=4, \quad \bar{y}=3.

PyTorch:

x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)

u = x * y
v = x + y
L = u + v

L.backward()

print(x.grad)  # tensor(4.)
print(y.grad)  # tensor(3.)

Vector-Jacobian Products

Reverse mode can be understood through vector-Jacobian products.

Suppose

y = f(x),

where x \in \mathbb{R}^n and y \in \mathbb{R}^m. The Jacobian is

J = \frac{\partial y}{\partial x} \in \mathbb{R}^{m\times n}.

If a later scalar loss L depends on y, then the upstream gradient is

\bar{y} = \frac{\partial L}{\partial y} \in \mathbb{R}^{m}.

The gradient with respect to xx is

\bar{x} = J^\top \bar{y}.

This operation is a vector-Jacobian product. It avoids explicitly forming the full Jacobian. This matters because Jacobians in deep learning can be extremely large.

PyTorch’s backward pass is primarily a system for computing vector-Jacobian products efficiently.
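
As a small sketch of this view (the function f and the cotangent vector v below are made up for illustration), torch.autograd.functional.vjp computes J^\top v directly, without building the Jacobian:

import torch

def f(x):
    # A map from R^3 to R^2, so the Jacobian is 2 x 3.
    return torch.stack([x[0] * x[1], x[1] + x[2]])

x = torch.tensor([1.0, 2.0, 3.0])
v = torch.tensor([1.0, 10.0])  # plays the role of the upstream gradient \bar{y}

out, x_bar = torch.autograd.functional.vjp(f, x, v)
print(x_bar)  # tensor([ 2., 11., 10.])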

Non-Scalar Outputs

When the output is scalar, PyTorch implicitly uses an upstream gradient of 1.

x = torch.tensor(3.0, requires_grad=True)
y = x ** 2

y.backward()

This means

\frac{dy}{dy}=1.

When the output is not scalar, the user must provide the upstream gradient.

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

y = x ** 2
y.backward(torch.tensor([1.0, 1.0, 1.0]))

print(x.grad)  # tensor([2., 4., 6.])

Here the provided vector acts as \bar{y}. Since

y_i=x_i^2,

the vector-Jacobian product gives

\bar{x}_i=\bar{y}_i\,2x_i.

With \bar{y}_i=1, this gives 2x_i.

A different upstream gradient gives a different vector-Jacobian product:

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

y = x ** 2
y.backward(torch.tensor([10.0, 1.0, 0.1]))

print(x.grad)  # tensor([20.0000, 4.0000, 0.6000])

Reverse Mode in Neural Networks

A neural network is a composition of many operations:

L = \ell(f_\theta(X), y).

The forward pass computes predictions and loss. The reverse pass computes

\nabla_\theta L.

For a multilayer network,

h_1 = \sigma(XW_1^\top+b_1), \quad h_2 = \sigma(h_1W_2^\top+b_2), \quad \hat{y}=h_2W_3^\top+b_3, \quad L=\ell(\hat{y}, y).

Reverse mode starts from L, then moves backward through the loss, final linear layer, activation, hidden linear layer, another activation, and first linear layer. Each operation applies its local backward rule.

PyTorch code:

import torch
from torch import nn

model = nn.Sequential(
    nn.Linear(10, 32),
    nn.ReLU(),
    nn.Linear(32, 32),
    nn.ReLU(),
    nn.Linear(32, 1),
)

X = torch.randn(16, 10)
target = torch.randn(16, 1)

pred = model(X)
loss = ((pred - target) ** 2).mean()

loss.backward()

for name, param in model.named_parameters():
    print(name, param.grad.shape)

The programmer does not write the derivative rules for the whole network. PyTorch composes local backward rules automatically.
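
Those local rules can also be written by hand. A minimal sketch of a custom square operation (a toy example, not how PyTorch implements its built-in power function) shows the pattern: forward saves what the backward rule needs, and backward turns the upstream gradient into an input gradient:

import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Save the forward value needed by the backward rule.
        ctx.save_for_backward(x)
        return x ** 2

    @staticmethod
    def backward(ctx, grad_output):
        # Local rule: x_bar = z_bar * 2x.
        (x,) = ctx.saved_tensors
        return grad_output * 2 * x

x = torch.tensor(3.0, requires_grad=True)
z = Square.apply(x)
z.backward()
print(x.grad)  # tensor(6.)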

Cost of Reverse Mode

Reverse mode has two important costs.

The first cost is computation. The backward pass usually costs the same order of magnitude as the forward pass. For many neural networks, one training step costs roughly one forward pass plus one backward pass.

The second cost is memory. Reverse mode needs intermediate values from the forward pass. These values are stored so that backward rules can use them later.

For example, the derivative of

z=x^2

needs the value of x. The derivative of ReLU needs to know which entries were positive. The derivative of batch normalization needs normalization statistics. The derivative of attention needs intermediate tensors related to queries, keys, values, attention scores, and probabilities.

This is why training uses more memory than inference.
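
A small illustration of that difference: under torch.no_grad() no backward graph is recorded and no intermediate values are saved for it, which is roughly how inference runs.

x = torch.randn(8, 10, requires_grad=True)

y = torch.relu(x).sum()
print(y.grad_fn is not None)  # True: a backward graph was recorded, with saved values

with torch.no_grad():
    z = torch.relu(x).sum()
print(z.grad_fn is None)      # True: nothing was recorded or saved for backward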

Checkpointing

Gradient checkpointing reduces memory use by storing fewer intermediate activations during the forward pass. During the backward pass, missing activations are recomputed.

This trades computation for memory.

Without checkpointing, the system stores many intermediate values:

more memory, less recomputation.

With checkpointing, the system stores fewer intermediate values:

less memory, more recomputation.

In PyTorch, checkpointing can be applied with torch.utils.checkpoint:

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

# Hypothetical layers so the snippet is self-contained.
layer1 = nn.Linear(128, 128)
layer2 = nn.Linear(128, 128)

def block(x):
    # Activations inside this block are not stored; they are recomputed during backward.
    return layer2(torch.relu(layer1(x)))

x = torch.randn(16, 128, requires_grad=True)
y = checkpoint(block, x, use_reentrant=False)
loss = y.sum()
loss.backward()

Checkpointing is common when training large models that would otherwise exceed GPU memory.

In-Place Operations and Reverse Mode

Reverse mode depends on saved forward values. In-place operations can overwrite those values and break gradient computation.

Example:

x = torch.randn(4, requires_grad=True)

y = x ** 2
x.add_(1.0)  # unsafe
loss = y.sum()
loss.backward()

The backward rule for x^2 needs the original value of x. If that value is modified in place, PyTorch may raise an error.

A safer version uses out-of-place operations:

x = torch.randn(4, requires_grad=True)

y = x ** 2
x2 = x + 1.0
loss = y.sum()
loss.backward()

In-place operations are sometimes useful for memory efficiency, but they should be used carefully when gradients are involved.

Detaching and Stopping Gradients

Sometimes a computation should stop gradient flow. PyTorch uses .detach() for this.

x = torch.tensor(2.0, requires_grad=True)

y = x ** 2
z = y.detach()
w = 3 * z

print(w.requires_grad)  # False

The tensor z has the same numerical value as y, but it has no connection to the graph that produced y.

This is used in target networks, contrastive learning, reinforcement learning, teacher-student methods, logging, and some optimization algorithms.

A common pattern is:

with torch.no_grad():
    target = teacher_model(x)

pred = student_model(x)
loss = loss_fn(pred, target)
loss.backward()

The target is treated as a fixed value. Gradients update the student model but not the teacher computation.

Summary

Reverse-mode differentiation computes gradients by traversing a computational graph backward from a scalar output. It starts with an upstream gradient of 1 at the loss, then applies local backward rules for each operation.

This method is efficient for deep learning because training usually requires the gradient of one scalar loss with respect to many parameters. PyTorch implements reverse mode through autograd. It records operations during the forward pass, stores needed intermediate values, and computes vector-Jacobian products during the backward pass.