Computational Graphs

A computational graph is a graph that represents a numerical computation. The nodes represent values or operations. The edges describe how data flows from one operation to the next.

In deep learning, computational graphs are important because they give a precise way to describe forward computation and gradient computation. A neural network takes input tensors, applies a sequence of operations, produces an output tensor, computes a loss, and then differentiates that loss with respect to the model parameters. This entire process can be represented as a graph.

A Simple Computation

Consider the scalar computation

$$z = (x + y)^2.$$

This computation can be broken into two steps:

$$a = x + y,$$
$$z = a^2.$$

The value $a$ is an intermediate result. The final output $z$ depends on $a$, and $a$ depends on $x$ and $y$.

The computational graph is:

$$x, y \longrightarrow a = x + y \longrightarrow z = a^2.$$

This graph records the dependency structure of the computation. It says that $z$ cannot be computed until $a$ has been computed, and $a$ cannot be computed until $x$ and $y$ are known.

In PyTorch:

import torch

x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)

a = x + y
z = a ** 2

print(z)  # tensor(25., grad_fn=<PowBackward0>)

The result is $25$, since

$$(2+3)^2 = 25.$$

The printed tensor also contains a grad_fn. This tells us that PyTorch has recorded how z was computed. Since z depends on previous tensors that require gradients, PyTorch attaches gradient information to it.
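The grad_fn attribute is the entry point into the recorded graph. As a small sketch, each grad_fn exposes next_functions, a tuple of (node, index) pairs that link to the backward nodes of its inputs, so the chain $z \leftarrow a \leftarrow (x, y)$ can be walked by hand. The class names shown (PowBackward0, AddBackward0, AccumulateGrad) are what recent PyTorch versions report; they are internal names and may change between versions.

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)

a = x + y
z = a ** 2

# The backward node for the squaring operation that produced z.
pow_node = z.grad_fn
# next_functions holds (node, input_index) pairs; the squaring has one tensor input, a.
add_node = pow_node.next_functions[0][0]
# The addition's inputs are leaf tensors; their nodes write into x.grad and y.grad.
leaf_nodes = [fn for fn, _ in add_node.next_functions]

print(type(pow_node).__name__)                   # PowBackward0
print(type(add_node).__name__)                   # AddBackward0
print([type(fn).__name__ for fn in leaf_nodes])  # ['AccumulateGrad', 'AccumulateGrad']
```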

Nodes and Edges

A computational graph has two main components.

A node may represent a tensor value, such as $x$, $y$, or $a$. A node may also represent an operation, such as addition, multiplication, matrix multiplication, or exponentiation.

An edge represents dependency. If a value is needed to compute another value, there is an edge connecting them.

For example,

$$a = x + y$$

depends on $x$ and $y$. Thus the graph has edges from $x$ to the addition operation, and from $y$ to the addition operation.

Then

$$z = a^2$$

depends on $a$. Thus the graph has an edge from $a$ to the squaring operation.

The graph is directed because computation has a direction. Values flow from inputs to outputs during the forward pass.
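To make nodes and edges concrete, here is a minimal sketch in plain Python (no PyTorch) that stores the graph for $z = (x+y)^2$ as a dictionary mapping each node to its operation and its input nodes, then evaluates it in dependency order. The names graph, OPS, and evaluate and the operation labels are hypothetical illustrations, not part of any library; the ordering uses graphlib.TopologicalSorter from the standard library (Python 3.9+).

```python
from graphlib import TopologicalSorter

# Each node maps to (operation, input nodes). Edges run from each input to the node.
graph = {
    "x": ("input", []),
    "y": ("input", []),
    "a": ("add", ["x", "y"]),
    "z": ("square", ["a"]),
}

OPS = {
    "add": lambda u, v: u + v,
    "square": lambda u: u ** 2,
}

def evaluate(graph, inputs):
    """Forward pass: visit nodes so that every node comes after its inputs."""
    predecessors = {node: deps for node, (_, deps) in graph.items()}
    values = {}
    for node in TopologicalSorter(predecessors).static_order():
        op, deps = graph[node]
        if op == "input":
            values[node] = inputs[node]
        else:
            values[node] = OPS[op](*(values[d] for d in deps))
    return values

values = evaluate(graph, {"x": 2.0, "y": 3.0})
print(values["a"], values["z"])  # 5.0 25.0
```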

Forward Computation

The forward pass evaluates the graph from inputs to outputs.

For the computation

$$z = (x+y)^2,$$

the forward pass is:

$$x = 2,\quad y = 3,$$
$$a = x + y = 5,$$
$$z = a^2 = 25.$$

In a neural network, the forward pass is the computation that maps input data to predictions.

For a linear layer,

$$Y = XW^\top + b,$$

where

$$X \in \mathbb{R}^{B\times d},\quad W \in \mathbb{R}^{h\times d},\quad b \in \mathbb{R}^{h},$$

and

$$Y \in \mathbb{R}^{B\times h}.$$

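Here $B$ is the batch size, $d$ the number of input features, and $h$ the number of output features; the bias $b$ is added to every row of $XW^\top$. The computational graph contains nodes for $X$, $W$, $b$, the matrix multiplication, the addition, and the output $Y$.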

In PyTorch:

import torch
from torch import nn

linear = nn.Linear(3, 4)

X = torch.randn(5, 3)
Y = linear(X)

print(Y.shape)  # torch.Size([5, 4])

The layer internally computes a matrix multiplication and adds a bias. PyTorch records these operations when gradients are needed.
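As a quick check of the formula $Y = XW^\top + b$, the sketch below recomputes the output of nn.Linear by hand from its weight attribute (shape $h \times d$) and bias attribute (shape $h$). The bias broadcasts across the $B$ rows. A small tolerance is used because the fused and unfused computations can differ in the last floating-point bits.

```python
import torch
from torch import nn

torch.manual_seed(0)
linear = nn.Linear(3, 4)   # d = 3 input features, h = 4 output features

X = torch.randn(5, 3)      # B = 5 rows
Y = linear(X)

# Recompute Y = X W^T + b explicitly; b (shape [4]) broadcasts over the 5 rows.
Y_manual = X @ linear.weight.T + linear.bias

print(linear.weight.shape, linear.bias.shape)  # torch.Size([4, 3]) torch.Size([4])
print(torch.allclose(Y, Y_manual, atol=1e-6))  # True
```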

Loss as the Final Node

Training usually ends the forward pass with a scalar loss. The loss measures how far the model prediction is from the target.

For example, suppose a model produces predictions $\hat{y}$, and the true targets are $y$. A mean squared error loss is

$$L = \frac{1}{n}\sum_{i=1}^{n}(\hat{y}_i - y_i)^2.$$

The loss is a scalar. This scalar is the final node of the training graph.
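One practical reason for ending with a scalar: backward() called with no arguments needs a single output value to start from. In the sketch below, calling backward() on a non-scalar prediction raises a RuntimeError, while reducing it to a scalar first, or passing an explicit upstream gradient through the gradient argument, works. Passing a tensor of ones gives the same parameter gradients as calling backward() on the sum.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(3, 1)
X = torch.randn(8, 3)

pred = model(X)                  # shape [8, 1], not a scalar
try:
    pred.backward()              # no implicit starting gradient for a non-scalar output
    raised = False
except RuntimeError:
    raised = True
print(raised)  # True

# Option 1: reduce to a scalar first.
model.zero_grad()
model(X).sum().backward()
grad_from_sum = model.weight.grad.clone()

# Option 2: supply the upstream gradient dL/dpred explicitly.
model.zero_grad()
model(X).backward(gradient=torch.ones(8, 1))
grad_from_ones = model.weight.grad.clone()

print(torch.allclose(grad_from_sum, grad_from_ones))  # True
```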

A typical training graph has the form

$$X \longrightarrow \text{model} \longrightarrow \hat{y} \longrightarrow L.$$

The model parameters are also inputs to the graph. If the model has parameters $\theta$, then more precisely:

$$(X, \theta) \longrightarrow \hat{y} \longrightarrow L.$$

Training asks how $L$ changes when $\theta$ changes. That question is answered by gradients:

$$\nabla_\theta L.$$

In PyTorch:

import torch
from torch import nn

model = nn.Linear(3, 1)

X = torch.randn(8, 3)
y = torch.randn(8, 1)

pred = model(X)
loss = ((pred - y) ** 2).mean()

print(loss.shape)  # torch.Size([])

The loss has shape [], so it is a scalar tensor.
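For this linear model, $\nabla_\theta L$ can also be derived by hand and compared with autograd. With residuals $r = \hat{y} - y$ and $n = 8$ predictions, the chain rule gives $\partial L / \partial W = \frac{2}{n} r^\top X$ and $\partial L / \partial b = \frac{2}{n}\sum_i r_i$. These closed forms are specific to the linear-plus-MSE setup in the sketch below.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(3, 1)
X = torch.randn(8, 3)
y = torch.randn(8, 1)

pred = model(X)
loss = ((pred - y) ** 2).mean()
loss.backward()                          # autograd fills model.weight.grad and model.bias.grad

# Manual gradients from the chain rule; no_grad keeps this check out of the graph.
with torch.no_grad():
    n = X.shape[0]
    r = pred - y                         # residuals, shape [8, 1]
    grad_W = (2.0 / n) * r.T @ X         # shape [1, 3], same as model.weight
    grad_b = (2.0 / n) * r.sum(dim=0)    # shape [1], same as model.bias

print(torch.allclose(model.weight.grad, grad_W, atol=1e-6))  # True
print(torch.allclose(model.bias.grad, grad_b, atol=1e-6))    # True
```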

Backward Dependencies

The computational graph is used again during backpropagation. During the forward pass, PyTorch records the operations that produced each tensor. During the backward pass, PyTorch walks backward through the graph and applies the chain rule.

For the simple computation

$$z = (x+y)^2,$$

we can compute

$$\frac{\partial z}{\partial x} \quad\text{and}\quad \frac{\partial z}{\partial y}.$$

Let

$$a = x + y.$$

Then

$$z = a^2.$$

By the chain rule,

$$\frac{\partial z}{\partial x} = \frac{\partial z}{\partial a}\,\frac{\partial a}{\partial x}.$$

Since

$$\frac{\partial z}{\partial a} = 2a$$

and

$$\frac{\partial a}{\partial x} = 1,$$

we get

$$\frac{\partial z}{\partial x} = 2a.$$

Similarly, since $\partial a / \partial y = 1$,

$$\frac{\partial z}{\partial y} = 2a.$$

At $x = 2$ and $y = 3$, we have $a = 5$, so

$$\frac{\partial z}{\partial x} = 10, \quad \frac{\partial z}{\partial y} = 10.$$

In PyTorch:

x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)

z = (x + y) ** 2
z.backward()

print(x.grad)  # tensor(10.)
print(y.grad)  # tensor(10.)

The gradients match the manual calculation.
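To show what walking backward through the graph means mechanically, here is a minimal scalar reverse-mode sketch in plain Python. The Value class is a hypothetical teaching aid, not how PyTorch is implemented: each node stores its inputs and a local chain-rule step, and backward() visits nodes in reverse topological order, multiplying each local derivative by the incoming gradient and summing contributions.

```python
class Value:
    """A scalar node in a computational graph (hypothetical, for illustration)."""

    def __init__(self, data, parents=()):
        self.data = data
        self.grad = 0.0
        self._parents = parents
        self._backward = lambda: None   # local chain-rule step, set by each operation

    def __add__(self, other):
        out = Value(self.data + other.data, (self, other))

        def _backward():
            # d(u + v)/du = 1 and d(u + v)/dv = 1
            self.grad += out.grad
            other.grad += out.grad

        out._backward = _backward
        return out

    def __pow__(self, k):
        out = Value(self.data ** k, (self,))

        def _backward():
            # d(u^k)/du = k * u^(k - 1)
            self.grad += k * self.data ** (k - 1) * out.grad

        out._backward = _backward
        return out

    def backward(self):
        # Order nodes so that every node comes after the nodes it depends on.
        order, seen = [], set()

        def visit(node):
            if id(node) not in seen:
                seen.add(id(node))
                for parent in node._parents:
                    visit(parent)
                order.append(node)

        visit(self)
        self.grad = 1.0                 # dz/dz = 1
        for node in reversed(order):    # output first, inputs last
            node._backward()


x = Value(2.0)
y = Value(3.0)
z = (x + y) ** 2
z.backward()
print(z.data, x.grad, y.grad)  # 25.0 10.0 10.0
```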

Leaf Tensors

A leaf tensor is a tensor that was not produced by a recorded operation, typically one created directly by the user. By convention, tensors with requires_grad=False are also leaves. Model parameters are leaf tensors.

For example:

x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)

a = x + y
z = a ** 2

Here x and y are leaf tensors. The tensor a is not a leaf tensor because it was produced by addition. The tensor z is not a leaf tensor because it was produced by exponentiation.

print(x.is_leaf)  # True
print(y.is_leaf)  # True
print(a.is_leaf)  # False
print(z.is_leaf)  # False

By default, PyTorch stores gradients in the .grad field of leaf tensors that have requires_grad=True.

z.backward()

print(x.grad)  # tensor(10.)
print(y.grad)  # tensor(10.)
print(a.grad)  # None (PyTorch also warns when .grad of a non-leaf tensor is accessed)

The intermediate tensor a participates in gradient computation, but PyTorch does not store its gradient by default. This saves memory.

To keep the gradient of an intermediate tensor, call retain_grad():

x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)

a = x + y
a.retain_grad()

z = a ** 2
z.backward()

print(a.grad)  # tensor(10.)
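An alternative to retain_grad() is a tensor hook. As a sketch, Tensor.register_hook attaches a function that autograd calls with the gradient of that tensor during the backward pass, so an intermediate gradient can be inspected or recorded without storing it in .grad. The list captured is just an illustrative container.

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)

a = x + y
captured = []
# Called with dz/da when backward reaches a; returning None leaves the gradient unchanged.
a.register_hook(lambda grad: captured.append(grad.clone()))

z = a ** 2
z.backward()

print(captured[0])  # tensor(10.)
print(x.grad)       # tensor(10.)
```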

Dynamic Computational Graphs

PyTorch uses dynamic computational graphs. This means the graph is built during ordinary Python execution.

Each time the forward computation runs, PyTorch creates a new graph from the operations that actually occur. This gives PyTorch a natural programming model. Loops, conditionals, recursion, and ordinary Python control flow can affect the graph.

Example:

import torch

x = torch.tensor(2.0, requires_grad=True)

if x.item() > 0:
    y = x ** 2
else:
    y = -x

y.backward()

print(x.grad)  # tensor(4.)

Since $x = 2$, the branch $y = x^2$ is used. The derivative is

$$\frac{dy}{dx} = 2x = 4.$$

If the input changed and a different branch ran, PyTorch would build a different graph.

This is one of the central differences between PyTorch and older static-graph systems. In a static graph system, the graph is usually defined first and executed later. In PyTorch, the graph is created as the computation runs.
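To see the graph change with the data, the sketch below wraps the branch in a function (the name f is just for illustration) and differentiates it at a positive and at a negative input. For $x = -2$ the else branch $y = -x$ runs, so the recorded graph is a negation and $dy/dx = -1$.

```python
import torch

def f(x):
    # Ordinary Python control flow decides which operations get recorded.
    if x.item() > 0:
        return x ** 2
    return -x

grads = {}
for value in (2.0, -2.0):
    x = torch.tensor(value, requires_grad=True)
    y = f(x)          # a fresh graph is built on each call
    y.backward()
    grads[value] = x.grad.item()

print(grads)  # {2.0: 4.0, -2.0: -1.0}
```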

Computational Graphs for Neural Networks

A neural network is a large computational graph composed of many simple operations.

Consider a two-layer neural network:

$$h = \sigma(XW_1^\top + b_1),$$
$$\hat{y} = hW_2^\top + b_2,$$
$$L = \ell(\hat{y}, y).$$

The graph contains the input $X$, the parameters $W_1, b_1, W_2, b_2$, the matrix multiplications, the additions, the activation function $\sigma$, the prediction $\hat{y}$, and the loss $L$. In the code below, $\sigma$ is ReLU and $\ell$ is the mean squared error.

In PyTorch:

import torch
from torch import nn

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(10, 32)
        self.activation = nn.ReLU()
        self.layer2 = nn.Linear(32, 1)

    def forward(self, x):
        h = self.layer1(x)
        h = self.activation(h)
        y = self.layer2(h)
        return y

model = MLP()

X = torch.randn(16, 10)
target = torch.randn(16, 1)

pred = model(X)
loss = ((pred - target) ** 2).mean()

loss.backward()

When loss.backward() runs, PyTorch computes gradients for every parameter that contributed to the loss:

for name, param in model.named_parameters():
    print(name, param.shape, param.grad.shape)

Typical output:

layer1.weight torch.Size([32, 10]) torch.Size([32, 10])
layer1.bias torch.Size([32]) torch.Size([32])
layer2.weight torch.Size([1, 32]) torch.Size([1, 32])
layer2.bias torch.Size([1]) torch.Size([1])

Each gradient tensor has the same shape as the corresponding parameter tensor.
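The phrase "every parameter that contributed to the loss" matters. In the sketch below, the MLP gets an extra layer named unused (a hypothetical addition) that is registered but never called in forward(). After backward(), its .grad is still None, because no path in the graph connects it to the loss.

```python
import torch
from torch import nn

class MLPWithUnusedLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(10, 32)
        self.activation = nn.ReLU()
        self.layer2 = nn.Linear(32, 1)
        self.unused = nn.Linear(10, 1)   # registered as a parameter, never used in forward()

    def forward(self, x):
        return self.layer2(self.activation(self.layer1(x)))

torch.manual_seed(0)
model = MLPWithUnusedLayer()
X = torch.randn(16, 10)
target = torch.randn(16, 1)

loss = ((model(X) - target) ** 2).mean()
loss.backward()

for name, param in model.named_parameters():
    grad_shape = None if param.grad is None else tuple(param.grad.shape)
    print(name, grad_shape)
# layer1.weight (32, 10)
# layer1.bias (32,)
# layer2.weight (1, 32)
# layer2.bias (1,)
# unused.weight None
# unused.bias None
```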

Graph Construction and requires_grad

PyTorch records an operation in the graph only if at least one of its inputs requires gradients. A tensor requires gradients when its requires_grad attribute is True.

x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0)

z = x * y
print(z.requires_grad)  # True

Although y does not require gradients, z does because z depends on x.

If no input requires gradients, PyTorch does not build a gradient graph:

x = torch.tensor(2.0)
y = torch.tensor(3.0)

z = x * y
print(z.requires_grad)  # False
print(z.grad_fn)        # None

Parameters of PyTorch modules (nn.Parameter objects) require gradients by default:

layer = nn.Linear(3, 4)

print(layer.weight.requires_grad)  # True
print(layer.bias.requires_grad)    # True

This is why calling loss.backward() computes gradients for the model parameters.
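Because graph construction follows requires_grad, switching it off for a parameter freezes that parameter. In this sketch, after layer.weight.requires_grad_(False), backward() still fills bias.grad, but weight.grad stays None. This is one common way to freeze part of a model, for example during fine-tuning.

```python
import torch
from torch import nn

torch.manual_seed(0)
layer = nn.Linear(3, 4)
layer.weight.requires_grad_(False)   # freeze the weight; the bias stays trainable

X = torch.randn(5, 3)
loss = layer(X).pow(2).mean()
loss.backward()

print(layer.weight.grad)       # None: no gradient is computed for the frozen weight
print(layer.bias.grad.shape)   # torch.Size([4])
print(loss.requires_grad)      # True, because the bias still requires gradients
```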

Detaching from the Graph

Sometimes a tensor should be treated as a constant, even though it was produced by previous operations. PyTorch provides .detach() for this purpose.

x = torch.tensor(2.0, requires_grad=True)

y = x ** 2
z = y.detach()

print(y.requires_grad)  # True
print(z.requires_grad)  # False

The tensor z shares the same underlying storage as y, so it holds the same values, but it no longer participates in the gradient graph.

This is useful in several situations: logging or storing values without keeping their graph (and its memory) alive, stopping gradients through part of a model, and treating targets as constants, for example when a target network computes them in reinforcement learning.

Example:

x = torch.tensor(2.0, requires_grad=True)

y = x ** 2
z = y.detach()
w = z * 3

# w has no path back to x
print(w.requires_grad)  # False

Once a tensor is detached, computations that use it contribute no gradients to the tensors it was computed from (here, x).
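Detaching acts as a stop-gradient on one path while other paths stay intact. In the sketch below, the value $y = x^2$ is used twice, once directly and once detached. Only the direct term is differentiable, so the gradient is $2x = 4$ rather than the $8$ that $2y$ would give.

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x ** 2

out = y + y.detach()   # same value as 2 * y, but only the first term is differentiable
out.backward()

print(out)     # tensor(8., grad_fn=<AddBackward0>)
print(x.grad)  # tensor(4.), i.e. 2x from the non-detached term only
```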

Graph Lifetime

By default, PyTorch frees the computational graph after backward() finishes. This saves memory.

x = torch.tensor(2.0, requires_grad=True)

y = x ** 2
y.backward()

# Calling y.backward() again would fail unless the graph was retained.

To run backward through the same graph more than once, use retain_graph=True:

x = torch.tensor(2.0, requires_grad=True)

y = x ** 2
y.backward(retain_graph=True)
y.backward()

This is rarely needed in ordinary supervised learning. It is more common in advanced algorithms that use multiple gradient computations from the same forward pass.
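Both behaviors can be observed directly. In this sketch, the second backward() without retain_graph=True raises a RuntimeError because the tensors saved for the backward pass (here, x, which the squaring needs) were freed by the first call. With retain_graph=True both calls succeed, and since gradients accumulate, x.grad ends up as $4 + 4 = 8$.

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x ** 2
y.backward()
try:
    y.backward()                 # saved tensors were freed by the first call
    second_call_failed = False
except RuntimeError:
    second_call_failed = True
print(second_call_failed)        # True

x = torch.tensor(2.0, requires_grad=True)
y = x ** 2
y.backward(retain_graph=True)    # keep saved tensors for another pass
y.backward()
print(x.grad)                    # tensor(8.): 4 from each call, accumulated
```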

A typical training loop builds a new graph on every iteration:

for X, y in dataloader:
    pred = model(X)
    loss = loss_fn(pred, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Each call to model(X) constructs a fresh graph. Each call to loss.backward() uses and then releases that graph.
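The loop above assumes that dataloader, model, loss_fn, and optimizer already exist. As a self-contained sketch, the version below fills them in with illustrative choices: a small synthetic regression dataset wrapped in TensorDataset and DataLoader, a single nn.Linear layer, nn.MSELoss, and plain SGD. It also keeps each iteration's loss.grad_fn to confirm that every batch produced its own graph root.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Synthetic data: targets are a fixed linear function of the inputs plus a little noise.
X_all = torch.randn(64, 3)
y_all = X_all @ torch.tensor([[1.0], [-2.0], [0.5]]) + 0.1 * torch.randn(64, 1)
dataloader = DataLoader(TensorDataset(X_all, y_all), batch_size=16, shuffle=True)

model = nn.Linear(3, 1)
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

graph_roots = []
for X, y in dataloader:
    pred = model(X)           # builds a new graph for this batch
    loss = loss_fn(pred, y)

    optimizer.zero_grad()     # clear gradients left over from the previous iteration
    loss.backward()           # traverse this batch's graph, then free its saved tensors
    optimizer.step()          # in-place parameter update, not recorded in the graph

    graph_roots.append(loss.grad_fn)

print(len(graph_roots))                     # 4 batches of 16
print(len({id(fn) for fn in graph_roots}))  # 4 distinct graph roots
```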

Computational Graphs and Memory

Computational graphs require memory. During the forward pass, PyTorch stores intermediate values needed for the backward pass. For example, to compute the gradient of

$$y = x^2,$$

PyTorch needs the value of $x$, because $dy/dx = 2x$. For more complex operations, it may need to save activations, masks, normalization statistics, or other tensors.

This explains why training usually uses more memory than inference. During inference, gradients are unnecessary, so the graph does not need to be stored.
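PyTorch exposes some of these saved values for inspection. The sketch below follows the saved-tensor attributes described in PyTorch's autograd notes; the underscore-prefixed names (_saved_self, _saved_result) are introspection details available in recent versions and may differ in others. The backward node of x.pow(2) keeps x itself, while exp keeps its output, because $\frac{d}{dx} e^x = e^x$.

```python
import torch

x = torch.randn(5, requires_grad=True)

y = x.pow(2)
# The squaring's backward node saves its input x, since dy/dx = 2x needs it.
print(x.equal(y.grad_fn._saved_self))    # True

w = x.exp()
# The exponential's backward node saves its output, since d(exp x)/dx = exp x.
print(w.equal(w.grad_fn._saved_result))  # True
```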

PyTorch provides torch.no_grad() for inference:

model.eval()

with torch.no_grad():
    pred = model(X)

Inside this block, PyTorch does not build a gradient graph. This reduces memory use and usually improves speed.

In PyTorch 1.9 and later, torch.inference_mode() can also be used for inference, provided the resulting tensors will not later participate in gradient computation:

model.eval()

with torch.inference_mode():
    pred = model(X)

Both tools prevent unnecessary graph construction during evaluation.
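A quick sketch of what changes inside these contexts: outputs have requires_grad=False and no grad_fn, even though the model's parameters still require gradients. By contrast, model.eval() on its own does not stop graph construction; it only switches layers such as dropout and batch normalization to evaluation behavior.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
X = torch.randn(10, 4)

model.eval()                     # changes layer behavior only; a graph is still recorded
pred_eval = model(X)

with torch.no_grad():
    pred_no_grad = model(X)

with torch.inference_mode():
    pred_inference = model(X)

print(pred_eval.requires_grad, pred_eval.grad_fn is not None)       # True True
print(pred_no_grad.requires_grad, pred_no_grad.grad_fn)             # False None
print(pred_inference.requires_grad, pred_inference.is_inference())  # False True
```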

A Complete Example

The following example shows a full forward and backward computation.

import torch
from torch import nn

torch.manual_seed(0)

model = nn.Sequential(
    nn.Linear(4, 8),
    nn.ReLU(),
    nn.Linear(8, 1),
)

X = torch.randn(10, 4)
y = torch.randn(10, 1)

pred = model(X)
loss = ((pred - y) ** 2).mean()

print(loss)
print(loss.grad_fn)

model.zero_grad()
loss.backward()

for name, param in model.named_parameters():
    print(name, param.shape, param.grad.shape)

The forward pass constructs the graph. The loss is the final scalar node. The backward pass traverses the graph in reverse and fills the gradients of the model parameters.

The important point is that the user does not manually construct the graph. The graph is created by ordinary tensor operations.

Common Errors

A common error is trying to call backward() on a tensor that has no gradient graph:

x = torch.tensor(2.0)
y = x ** 2
y.backward()

This fails with a RuntimeError: x does not require gradients, so y does not require gradients either and has no grad_fn to start the backward pass from.

Correct version:

x = torch.tensor(2.0, requires_grad=True)
y = x ** 2
y.backward()

Another common error is forgetting to clear accumulated gradients:

loss.backward()
optimizer.step()

PyTorch accumulates gradients by default: each call to backward() adds to the existing .grad values instead of overwriting them. Therefore a training loop should usually include:

optimizer.zero_grad()
loss.backward()
optimizer.step()
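To see why zero_grad() matters, this sketch runs backward() after two identical forward passes without clearing in between. The second gradient is added to the first, so .grad doubles; after zero_grad(), a fresh pass gives the single-pass gradient again. The small nn.Linear model and the helper compute_loss are illustrative stand-ins.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(3, 1)
X = torch.randn(8, 3)
y = torch.randn(8, 1)

def compute_loss():
    # Each call runs a new forward pass and builds a new graph.
    return ((model(X) - y) ** 2).mean()

compute_loss().backward()
first = model.weight.grad.clone()

compute_loss().backward()      # no zero_grad(): adds to the existing .grad
accumulated = model.weight.grad.clone()

model.zero_grad()              # clear the accumulated gradients
compute_loss().backward()
fresh = model.weight.grad.clone()

print(torch.allclose(accumulated, 2 * first))  # True
print(torch.allclose(fresh, first))            # True
```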

A third common error is accidentally detaching a tensor that should remain differentiable:

h = model.encoder(x)
h = h.detach()
y = model.decoder(h)

In this code, the decoder can receive gradients, but the encoder cannot receive gradients from y. This may be intentional, but it often happens by mistake.
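A sketch with plain nn.Linear modules standing in for the hypothetical model.encoder and model.decoder confirms the effect: after the detach, only the decoder's parameters receive gradients.

```python
import torch
from torch import nn

torch.manual_seed(0)
encoder = nn.Linear(4, 3)   # stand-in for model.encoder
decoder = nn.Linear(3, 2)   # stand-in for model.decoder
x = torch.randn(5, 4)

h = encoder(x)
h = h.detach()              # cuts the graph between encoder and decoder
y = decoder(h)
y.sum().backward()

print(encoder.weight.grad)              # None: no path from y back to the encoder
print(decoder.weight.grad is not None)  # True
```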

Summary

A computational graph records how tensors are produced from other tensors. In PyTorch, the graph is built dynamically during the forward computation. If tensors require gradients, PyTorch records the operations needed to compute derivatives later.

The forward pass evaluates the graph from inputs to outputs. The backward pass traverses the graph in reverse and applies the chain rule. Model parameters are leaf tensors, and their gradients are stored after backward().

Understanding computational graphs makes PyTorch behavior predictable. It explains requires_grad, grad_fn, detach, no_grad, gradient accumulation, and the memory difference between training and inference.