A computational graph is a graph that represents a numerical computation. The nodes represent values or operations. The edges describe how data flows from one operation to the next.
In deep learning, computational graphs are important because they give a precise way to describe forward computation and gradient computation. A neural network takes input tensors, applies a sequence of operations, produces an output tensor, computes a loss, and then differentiates that loss with respect to the model parameters. This entire process can be represented as a graph.
A Simple Computation
Consider the scalar computation

z = (x + y)^2

This computation can be broken into two steps:

a = x + y
z = a^2

The value a is an intermediate result. The final output z depends on a, and a depends on x and y.

The computational graph is:

x ──┐
    ├─► (+) ─► a ─► (^2) ─► z
y ──┘

This graph records the dependency structure of the computation. It says that z cannot be computed until a has been computed, and a cannot be computed until x and y are known.
In PyTorch:
import torch
x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)
a = x + y
z = a ** 2
print(z) # tensor(25., grad_fn=<PowBackward0>)

The result is 25, since (2 + 3)^2 = 25.
The printed tensor also contains a grad_fn. This tells us that PyTorch has recorded how z was computed. Since z depends on previous tensors that require gradients, PyTorch attaches gradient information to it.
Nodes and Edges
A computational graph has two main components.
A node may represent a tensor value, such as x, y, or z. A node may also represent an operation, such as addition, multiplication, matrix multiplication, or exponentiation.
An edge represents dependency. If a value is needed to compute another value, there is an edge connecting them.
For example,

a = x + y

depends on x and y. Thus the graph has edges from x to the addition operation and from y to the addition operation.

Then

z = a^2

depends on a. Thus the graph has an edge from a to the squaring operation.
The graph is directed because computation has a direction. Values flow from inputs to outputs during the forward pass.
Forward Computation
The forward pass evaluates the graph from inputs to outputs.
For the computation

z = (x + y)^2

the forward pass is:

a = x + y
z = a^2
In a neural network, the forward pass is the computation that maps input data to predictions.
For a linear layer,

Y = X W^T + b

where

X has shape (batch_size, in_features) and W has shape (out_features, in_features),

and

b has shape (out_features,).

The computational graph contains nodes for X, W, b, the matrix multiplication, the addition, and the output Y.
In PyTorch:
import torch
from torch import nn
linear = nn.Linear(3, 4)
X = torch.randn(5, 3)
Y = linear(X)
print(Y.shape) # torch.Size([5, 4])

The layer internally computes a matrix multiplication and adds a bias. PyTorch records these operations when gradients are needed.
Loss as the Final Node
Training usually ends the forward pass with a scalar loss. The loss measures how far the model prediction is from the target.
For example, suppose a model produces predictions ŷ_1, ..., ŷ_n, and the true targets are y_1, ..., y_n. A mean squared error loss is

L = (1/n) Σ_i (ŷ_i - y_i)^2
The loss is a scalar. This scalar is the final node of the training graph.
A typical training graph has the form

input → model operations → prediction → loss

The model parameters are also inputs to the graph. If the model has parameters θ, then more precisely:

loss = L(f(x; θ), y)

Training asks how the loss changes when θ changes. That question is answered by gradients:

∂loss/∂θ
In PyTorch:
import torch
from torch import nn
model = nn.Linear(3, 1)
X = torch.randn(8, 3)
y = torch.randn(8, 1)
pred = model(X)
loss = ((pred - y) ** 2).mean()
print(loss.shape) # torch.Size([])

The loss has shape [], so it is a scalar tensor.
Backward Dependencies
The computational graph is used again during backpropagation. During the forward pass, PyTorch records the operations that produced each tensor. During the backward pass, PyTorch walks backward through the graph and applies the chain rule.
For the simple computation

z = (x + y)^2

we can compute

∂z/∂x and ∂z/∂y.

Let

a = x + y.

Then

z = a^2.

By the chain rule,

∂z/∂x = (∂z/∂a) · (∂a/∂x).

Since

∂z/∂a = 2a

and

∂a/∂x = 1,

we get

∂z/∂x = 2a = 2(x + y).

Similarly,

∂z/∂y = 2(x + y).

At x = 2 and y = 3, we have a = 5, so

∂z/∂x = ∂z/∂y = 10.
In PyTorch:
import torch

x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)
z = (x + y) ** 2
z.backward()
print(x.grad) # tensor(10.)
print(y.grad) # tensor(10.)

The gradients match the manual calculation.
Leaf Tensors
A leaf tensor is a tensor that is created directly by the user and has no previous operation that produced it. Model parameters are usually leaf tensors.
For example:
x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)
a = x + y
z = a ** 2

Here x and y are leaf tensors. The tensor a is not a leaf tensor because it was produced by addition. The tensor z is not a leaf tensor because it was produced by exponentiation.
print(x.is_leaf) # True
print(y.is_leaf) # True
print(a.is_leaf) # False
print(z.is_leaf) # False

By default, PyTorch stores gradients in the .grad field of leaf tensors that have requires_grad=True.
z.backward()
print(x.grad) # tensor(10.)
print(y.grad) # tensor(10.)
print(a.grad) # usually None

The intermediate tensor a participates in gradient computation, but PyTorch does not store its gradient by default. This saves memory.
To keep the gradient of an intermediate tensor, call retain_grad():
x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)
a = x + y
a.retain_grad()
z = a ** 2
z.backward()
print(a.grad) # tensor(10.)

Dynamic Computational Graphs
PyTorch uses dynamic computational graphs. This means the graph is built during ordinary Python execution.
Each time the forward computation runs, PyTorch creates a new graph from the operations that actually occur. This gives PyTorch a natural programming model. Loops, conditionals, recursion, and ordinary Python control flow can affect the graph.
Example:
import torch
x = torch.tensor(2.0, requires_grad=True)
if x.item() > 0:
    y = x ** 2
else:
    y = -x
y.backward()
print(x.grad) # tensor(4.)

Since x = 2 > 0, the branch y = x ** 2 is used. The derivative is dy/dx = 2x = 4.
If the input changed and a different branch ran, PyTorch would build a different graph.
This is one of the central differences between PyTorch and older static-graph systems. In a static graph system, the graph is usually defined first and executed later. In PyTorch, the graph is created as the computation runs.
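As an illustration of this difference, the following sketch doubles a value until it crosses a threshold. The number of recorded multiplications, and therefore the shape of the graph, depends on the data:

```python
import torch

x = torch.tensor(1.5, requires_grad=True)

# Each loop iteration adds one multiplication node to the graph,
# so different starting values produce different graphs.
y = x
while y.item() < 10.0:
    y = y * 2

y.backward()
print(x.grad)  # tensor(8.): three doublings were recorded, so y = 8 * x
```

A static-graph system would need a special looping construct to express this; here the graph simply follows whatever the Python loop does.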
Computational Graphs for Neural Networks
A neural network is a large computational graph composed of many simple operations.
Consider a two-layer neural network:

h = σ(X W_1^T + b_1)
ŷ = h W_2^T + b_2
L = loss(ŷ, y)

The graph contains the input X, the parameters W_1, b_1, W_2, b_2, the matrix multiplications, the additions, the activation function σ, the prediction ŷ, and the loss L.
In PyTorch:
import torch
from torch import nn
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(10, 32)
        self.activation = nn.ReLU()
        self.layer2 = nn.Linear(32, 1)

    def forward(self, x):
        h = self.layer1(x)
        h = self.activation(h)
        y = self.layer2(h)
        return y
model = MLP()
X = torch.randn(16, 10)
target = torch.randn(16, 1)
pred = model(X)
loss = ((pred - target) ** 2).mean()
loss.backward()

When loss.backward() runs, PyTorch computes gradients for every parameter that contributed to the loss:
for name, param in model.named_parameters():
    print(name, param.shape, param.grad.shape)

Typical output:
layer1.weight torch.Size([32, 10]) torch.Size([32, 10])
layer1.bias torch.Size([32]) torch.Size([32])
layer2.weight torch.Size([1, 32]) torch.Size([1, 32])
layer2.bias torch.Size([1]) torch.Size([1])

Each gradient tensor has the same shape as the corresponding parameter tensor.
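To close the loop, the gradients filled in by backward() can be used to update the parameters. The sketch below applies one plain gradient-descent step with an illustrative learning rate (the value 0.01 is arbitrary, not a recommendation):

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(3, 1)
X = torch.randn(8, 3)
y = torch.randn(8, 1)

loss = ((model(X) - y) ** 2).mean()
loss.backward()

lr = 0.01  # illustrative learning rate
with torch.no_grad():  # the update itself should not be recorded in a graph
    for param in model.parameters():
        param -= lr * param.grad

new_loss = ((model(X) - y) ** 2).mean()
print(new_loss < loss)  # a small step along the negative gradient reduces the loss
```

In practice this bookkeeping is handled by an optimizer, but the optimizer does essentially the same thing with the .grad fields.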
Graph Construction and requires_grad
PyTorch only records operations for tensors that require gradients. A tensor requires gradients when its requires_grad field is True.
x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0)
z = x * y
print(z.requires_grad) # True

Although y does not require gradients, z does, because z depends on x.
If no input requires gradients, PyTorch does not build a gradient graph:
x = torch.tensor(2.0)
y = torch.tensor(3.0)
z = x * y
print(z.requires_grad) # False
print(z.grad_fn) # None

Model parameters created by PyTorch modules usually require gradients by default:
layer = nn.Linear(3, 4)
print(layer.weight.requires_grad) # True
print(layer.bias.requires_grad) # True

This is why calling loss.backward() computes gradients for the model parameters.
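The reverse direction also works: setting requires_grad to False removes a parameter from the gradient graph, which is one way to freeze part of a model. A sketch:

```python
import torch
from torch import nn

layer = nn.Linear(3, 4)
layer.weight.requires_grad_(False)  # freeze the weight matrix

X = torch.randn(2, 3)
out = layer(X).sum()
out.backward()

print(layer.weight.grad)       # None: the frozen weight receives no gradient
print(layer.bias.grad.shape)   # torch.Size([4]): the bias still does
```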
Detaching from the Graph
Sometimes a tensor should be treated as a constant, even though it was produced by previous operations. PyTorch provides .detach() for this purpose.
x = torch.tensor(2.0, requires_grad=True)
y = x ** 2
z = y.detach()
print(y.requires_grad) # True
print(z.requires_grad) # False

The tensor z shares the numerical value of y, but it no longer participates in the gradient graph.
This is useful in several situations: logging values, stopping gradients through part of a model, updating target networks in reinforcement learning, and preventing unnecessary memory use during evaluation.
Example:
x = torch.tensor(2.0, requires_grad=True)
y = x ** 2
z = y.detach()
w = z * 3
# w has no path back to x
print(w.requires_grad) # False

Once detached, the computation no longer contributes gradients to the original tensor.
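As a small sketch of the stop-gradient use mentioned above, the target below is computed from x but detached, so the backward pass treats it as a constant:

```python
import torch

x = torch.tensor(1.0, requires_grad=True)
pred = x * 3
target = (x * 2).detach()  # computed from x, but treated as a constant

loss = (pred - target) ** 2
loss.backward()

# Only pred contributes: d/dx (3x - 2)^2 = 2 * (3x - 2) * 3 = 6 at x = 1
print(x.grad)  # tensor(6.)
```

Without the detach, gradient would also flow through the target branch and the result would differ.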
Graph Lifetime
By default, PyTorch frees the computational graph after backward() finishes. This saves memory.
x = torch.tensor(2.0, requires_grad=True)
y = x ** 2
y.backward()
# Calling y.backward() again would fail unless the graph was retained.

To run backward through the same graph more than once, use retain_graph=True:
x = torch.tensor(2.0, requires_grad=True)
y = x ** 2
y.backward(retain_graph=True)
y.backward()

This is rarely needed in ordinary supervised learning. It is more common in advanced algorithms that use multiple gradient computations from the same forward pass.
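One case where the graph is needed more than once is higher-order differentiation. A sketch using torch.autograd.grad, which keeps the graph alive with create_graph=True:

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x ** 3

# First derivative: dy/dx = 3x^2. create_graph=True records the
# gradient computation itself so it can be differentiated again.
(grad1,) = torch.autograd.grad(y, x, create_graph=True)
print(grad1.item())  # 12.0

# Second derivative: d2y/dx2 = 6x.
(grad2,) = torch.autograd.grad(grad1, x)
print(grad2.item())  # 12.0
```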
A typical training loop builds a new graph on every iteration:
for X, y in dataloader:
    pred = model(X)
    loss = loss_fn(pred, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Each call to model(X) constructs a fresh graph. Each call to loss.backward() uses and then releases that graph.
Computational Graphs and Memory
Computational graphs require memory. During the forward pass, PyTorch stores intermediate values needed for the backward pass. For example, to compute the gradient of

z = x^2

PyTorch needs the value of x. For more complex operations, it may need activations, masks, normalization statistics, or other saved tensors.
This explains why training usually uses more memory than inference. During inference, gradients are unnecessary, so the graph does not need to be stored.
PyTorch provides torch.no_grad() for inference:
model.eval()
with torch.no_grad():
    pred = model(X)

Inside this block, PyTorch does not build a gradient graph. This reduces memory use and usually improves speed.
For newer PyTorch code, torch.inference_mode() can be used for inference when tensors will not later participate in gradient computation:
model.eval()
with torch.inference_mode():
    pred = model(X)

Both tools prevent unnecessary graph construction during evaluation.
A Complete Example
The following example shows a full forward and backward computation.
import torch
from torch import nn
torch.manual_seed(0)
model = nn.Sequential(
    nn.Linear(4, 8),
    nn.ReLU(),
    nn.Linear(8, 1),
)
X = torch.randn(10, 4)
y = torch.randn(10, 1)
pred = model(X)
loss = ((pred - y) ** 2).mean()
print(loss)
print(loss.grad_fn)
model.zero_grad()
loss.backward()
for name, param in model.named_parameters():
    print(name, param.shape, param.grad.shape)

The forward pass constructs the graph. The loss is the final scalar node. The backward pass traverses the graph in reverse and fills the gradients of the model parameters.
The important point is that the user does not manually construct the graph. The graph is created by ordinary tensor operations.
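The recorded graph can even be inspected by hand. Each grad_fn node exposes a next_functions attribute pointing at the nodes that feed into it; the sketch below walks one chain from the output back to a leaf:

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)
z = (x + y) ** 2

# Follow the first input of each backward node until the chain ends.
# Prints PowBackward0, AddBackward0, AccumulateGrad.
node = z.grad_fn
while node is not None:
    print(type(node).__name__)
    node = node.next_functions[0][0] if node.next_functions else None
```

The final AccumulateGrad node is where the gradient is deposited into the leaf tensor's .grad field.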
Common Errors
A common error is trying to call backward() on a tensor that has no gradient graph:
x = torch.tensor(2.0)
y = x ** 2
y.backward()

This fails because x does not require gradients, so y also does not require gradients.
Correct version:
x = torch.tensor(2.0, requires_grad=True)
y = x ** 2
y.backward()

Another common error is forgetting to clear accumulated gradients:
loss.backward()
optimizer.step()

PyTorch accumulates gradients by default. Therefore a training loop should usually include:
optimizer.zero_grad()
loss.backward()
optimizer.step()

A third common error is accidentally detaching a tensor that should remain differentiable:
h = model.encoder(x)
h = h.detach()
y = model.decoder(h)

In this code, the decoder still receives gradients, but no gradient flows back into the encoder from y. This may be intentional, but it often happens by mistake.
Summary
A computational graph records how tensors are produced from other tensors. In PyTorch, the graph is built dynamically during the forward computation. If tensors require gradients, PyTorch records the operations needed to compute derivatives later.
The forward pass evaluates the graph from inputs to outputs. The backward pass traverses the graph in reverse and applies the chain rule. Model parameters are leaf tensors, and their gradients are stored after backward().
Understanding computational graphs makes PyTorch behavior predictable. It explains requires_grad, grad_fn, detach, no_grad, gradient accumulation, and the memory difference between training and inference.