A graph representation makes the structure of a differentiated computation explicit. In reverse mode, this structure is required because the backward pass must know which intermediate values depend on which earlier values. In compiler-based systems, the graph may also serve as an optimization target.
A computation graph is a directed graph whose nodes represent values or operations, and whose edges represent data dependencies.
For an expression such as:

```
z = x*y + sin(x)
```

the graph has inputs x and y, intermediate nodes for multiplication and sine, and an output node for addition.
```
x ─┬─> mul ─┐
   └─> sin ─┼─> add ─> z
y ───> mul ─┘
```

The graph records more than arithmetic structure. It records the dependency order needed to propagate gradients correctly.
Values, Operations, and Edges
There are two common graph styles.
| Style | Node means | Edge means |
|---|---|---|
| Value graph | A computed value | Dependency from parent value |
| Operation graph | An operation | Tensor or scalar flowing between operations |
A small educational engine usually uses a value graph. Each node stores:
- the computed value
- the accumulated gradient
- pointers to parent nodes
- a backward function
```go
type Node struct {
	Value    float64
	Grad     float64
	Prev     []*Node
	Op       string
	Backward func()
}
```

The Op field is optional but useful for debugging and visualization.
```go
func Var(x float64) *Node {
	return &Node{
		Value: x,
		Op:    "var",
	}
}
```

For a production system, nodes usually carry more metadata: shape, dtype, device, layout, aliasing information, and source-location information.
Local Graph Construction
Each primitive operation creates a new node.
```go
func Add(a, b *Node) *Node {
	out := &Node{
		Value: a.Value + b.Value,
		Prev:  []*Node{a, b},
		Op:    "add",
	}
	out.Backward = func() {
		a.Grad += out.Grad
		b.Grad += out.Grad
	}
	return out
}
```

The graph is built incrementally as the user program runs.
```go
x := Var(2)
y := Var(3)
z := Add(Mul(x, y), Sin(x))
```

This creates the graph dynamically. The program itself decides which operations run. Conditionals, loops, and recursion are handled by ordinary execution.
Parent Links and Child Links
A minimal reverse-mode engine only needs parent links:
```go
Prev []*Node
```

The backward pass starts at the output, walks through parents, and builds a topological order.
Some systems also store child links:
```go
Next []*Node
```

Child links make certain analyses easier, such as forward traversal, dead-code detection, or visualization. But they are not necessary for a minimal engine.
Parent links are enough because reverse mode starts from the output and walks backward.
Topological Order
A graph representation must support reverse traversal in a valid order.
If node b depends on node a, then b must run its backward rule before a.
```
a -> b -> output
```

The backward order is:

```
output, b, a
```

A depth-first topological sort is sufficient for a small engine.
```go
func topo(n *Node, seen map[*Node]bool, order *[]*Node) {
	if seen[n] {
		return
	}
	seen[n] = true
	for _, p := range n.Prev {
		topo(p, seen, order)
	}
	*order = append(*order, n)
}
```

The resulting order lists dependencies before dependents. Reverse iteration gives the backward order.
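The recursive sort can exhaust the call stack on very deep graphs (for example, a long chain built in a loop). An iterative variant with an explicit stack avoids that; a sketch, using only the Prev field:

```go
package main

import "fmt"

type Node struct {
	Prev []*Node
	Op   string
}

// topoIter produces the same dependencies-before-dependents order as the
// recursive version, using an explicit stack instead of recursion.
func topoIter(root *Node) []*Node {
	type frame struct {
		n    *Node
		next int // index of the next parent to visit
	}
	seen := map[*Node]bool{root: true}
	var order []*Node
	stack := []frame{{n: root}}
	for len(stack) > 0 {
		top := &stack[len(stack)-1]
		if top.next < len(top.n.Prev) {
			p := top.n.Prev[top.next]
			top.next++
			if !seen[p] {
				seen[p] = true
				stack = append(stack, frame{n: p})
			}
			continue
		}
		// All parents emitted; emit this node and pop.
		order = append(order, top.n)
		stack = stack[:len(stack)-1]
	}
	return order
}

func main() {
	a := &Node{Op: "a"}
	b := &Node{Op: "b", Prev: []*Node{a}}
	out := &Node{Op: "out", Prev: []*Node{b}}
	for _, n := range topoIter(out) {
		fmt.Print(n.Op, " ")
	}
	// prints: a b out
}
```

For a teaching engine the recursive version is clearer; the iterative version matters only when graphs get deep.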
```go
func Backward(root *Node) {
	var order []*Node
	topo(root, map[*Node]bool{}, &order)
	root.Grad = 1
	for i := len(order) - 1; i >= 0; i-- {
		if order[i].Backward != nil {
			order[i].Backward()
		}
	}
}
```

This traversal works for directed acyclic computation graphs. Most scalar expression graphs created during one forward pass are acyclic.
Shared Subexpressions
Graph representation must handle sharing.
Consider:
```go
x := Var(3)
a := Mul(x, x)
z := Add(a, a)
```

Mathematically:

```
z = a + a = 2x²   so   dz/dx = 4x
```
The node a contributes to z through two edges. During backward propagation, its gradient must receive both contributions.
The Add rule handles this naturally:
```go
func Add(a, b *Node) *Node {
	out := &Node{
		Value: a.Value + b.Value,
		Prev:  []*Node{a, b},
		Op:    "add",
	}
	out.Backward = func() {
		a.Grad += out.Grad
		b.Grad += out.Grad
	}
	return out
}
```

If a and b are the same pointer, then the same node receives two additions.
```go
a.Grad += out.Grad
a.Grad += out.Grad
```

This gives the correct factor of two.
Shared subexpressions are one reason gradients must accumulate with +=.
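A numeric check of this, reusing the minimal engine pieces from earlier sections: for z = x·x + x·x at x = 3, the derivative is 4x = 12, and both edge contributions are needed to reach it.

```go
package main

import "fmt"

type Node struct {
	Value    float64
	Grad     float64
	Prev     []*Node
	Backward func()
}

func Var(x float64) *Node { return &Node{Value: x} }

func Add(a, b *Node) *Node {
	out := &Node{Value: a.Value + b.Value, Prev: []*Node{a, b}}
	out.Backward = func() {
		a.Grad += out.Grad // if a == b, the same node is
		b.Grad += out.Grad // incremented twice, as required
	}
	return out
}

func Mul(a, b *Node) *Node {
	out := &Node{Value: a.Value * b.Value, Prev: []*Node{a, b}}
	out.Backward = func() {
		a.Grad += b.Value * out.Grad
		b.Grad += a.Value * out.Grad
	}
	return out
}

func topo(n *Node, seen map[*Node]bool, order *[]*Node) {
	if seen[n] {
		return
	}
	seen[n] = true
	for _, p := range n.Prev {
		topo(p, seen, order)
	}
	*order = append(*order, n)
}

func main() {
	x := Var(3)
	a := Mul(x, x)
	z := Add(a, a)
	var order []*Node
	topo(z, map[*Node]bool{}, &order)
	z.Grad = 1
	for i := len(order) - 1; i >= 0; i-- {
		if order[i].Backward != nil {
			order[i].Backward()
		}
	}
	fmt.Println(z.Value, x.Grad) // 18 12
}
```

The shared node a is visited once by the traversal but accumulates gradient from both of its outgoing edges.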
Avoiding Duplicate Traversal
A graph may reuse the same node many times. The topological sort must avoid visiting it repeatedly.
```go
seen map[*Node]bool
```

Without this, shared subgraphs cause repeated backward execution and incorrect gradients.
However, the backward rule itself must still account for multiple incoming edges. This distinction matters:
| Concern | Required behavior |
|---|---|
| Topological traversal | Visit each node once |
| Gradient accumulation | Add every edge contribution |
The graph traversal removes duplicate node visits. The local backward rules preserve duplicate edge contributions.
Operation Metadata
A useful node often records the operation that created it.
```go
type OpKind string

const (
	OpVar OpKind = "var"
	OpAdd OpKind = "add"
	OpMul OpKind = "mul"
	OpSin OpKind = "sin"
)
```

Then:
```go
type Node struct {
	Value    float64
	Grad     float64
	Prev     []*Node
	Op       OpKind
	Backward func()
}
```

This helps with:
- debugging
- graph printing
- testing
- visualization
- profiling
- error messages
For example:
```go
func Print(n *Node, indent string, seen map[*Node]bool) {
	if seen[n] {
		fmt.Printf("%s%s value=%g grad=%g [seen]\n", indent, n.Op, n.Value, n.Grad)
		return
	}
	seen[n] = true
	fmt.Printf("%s%s value=%g grad=%g\n", indent, n.Op, n.Value, n.Grad)
	for _, p := range n.Prev {
		Print(p, indent+"  ", seen)
	}
}
```

A graph printer is one of the best tools for validating a small AD engine.
Separating Data from Backward Logic
The simplest engine stores a closure on every node:
```go
Backward func()
```

This is easy to write, but it has tradeoffs.
| Design | Benefit | Cost |
|---|---|---|
| Closure per node | Simple local implementation | More allocation, harder serialization |
| Op enum plus operands | Compact, inspectable | Requires central backward dispatcher |
| Static IR node | Compiler friendly | More upfront design |
| Tape instruction | Efficient append-only execution | Less natural graph inspection |
A closure-based node is ideal for a teaching engine. A production engine often uses operation records or IR instructions instead.
A dispatcher-based version looks like this:

```go
type Node struct {
	Value float64
	Grad  float64
	Prev  []*Node
	Op    OpKind
}

func backward(n *Node) {
	switch n.Op {
	case OpAdd:
		a, b := n.Prev[0], n.Prev[1]
		a.Grad += n.Grad
		b.Grad += n.Grad
	case OpMul:
		a, b := n.Prev[0], n.Prev[1]
		a.Grad += b.Value * n.Grad
		b.Grad += a.Value * n.Grad
	case OpSin:
		x := n.Prev[0]
		x.Grad += math.Cos(x.Value) * n.Grad
	}
}
```

This representation makes the graph more explicit and easier to serialize.
Tensor Graphs
Scalar nodes are enough to explain reverse mode, but practical AD systems operate on tensors.
A tensor node stores:
```go
type TensorNode struct {
	Value    Tensor
	Grad     Tensor
	Shape    []int
	DType    DType
	Prev     []*TensorNode
	Op       OpKind
	Backward func()
}
```

Tensor graph representation must handle shape semantics.
For example, broadcasting introduces reduction in the backward pass.
If:
```
z = x + b
```

where x has shape [batch, features] and b has shape [features], then the backward pass for b must sum over the batch dimension.
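That reduction can be sketched directly. Assuming gradZ holds the upstream gradient with shape [batch, features], the gradient for the broadcast bias b is the column-wise sum:

```go
package main

import "fmt"

// gradBias sums the upstream gradient over the batch dimension,
// undoing the broadcast of b across the rows of x.
func gradBias(gradZ [][]float64) []float64 {
	if len(gradZ) == 0 {
		return nil
	}
	gradB := make([]float64, len(gradZ[0]))
	for _, row := range gradZ {
		for j, g := range row {
			gradB[j] += g
		}
	}
	return gradB
}

func main() {
	gradZ := [][]float64{{1, 2}, {3, 4}, {5, 6}}
	fmt.Println(gradBias(gradZ)) // [9 12]
}
```

The general rule is that every broadcast in the forward pass becomes a sum over the broadcast dimensions in the backward pass.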
The graph must preserve enough information to reconstruct this rule.
A tensor graph therefore needs metadata beyond dependency edges:
- original input shapes
- broadcasted output shape
- dtype
- device
- memory layout
- aliasing and view information
Scalar AD hides these issues. Tensor AD exposes them immediately.
Graph Lifetime
A dynamic reverse-mode graph usually has the lifetime of one forward evaluation.
```
build graph -> run backward -> release graph
```

This is natural for eager systems.
Long-lived graphs require additional care:
- clearing gradients
- avoiding stale references
- managing retained intermediates
- deciding whether backward can run multiple times
- freeing nodes after last use
A minimal engine can use a simple convention:
```go
func ZeroGrad(root *Node) {
	var order []*Node
	topo(root, map[*Node]bool{}, &order)
	for _, n := range order {
		n.Grad = 0
	}
}
```

Before another backward pass, gradients should be reset unless accumulation is intentional.
Graph Representation and Control Flow
Dynamic graph construction handles control flow by recording the path actually taken.
```go
func H(x *Node) *Node {
	if x.Value > 0 {
		return Mul(x, x)
	}
	return Mul(Const(-1), x)
}
```

For x > 0, the graph represents:

```
z = x * x
```

For x <= 0, the graph represents:

```
z = -x
```
The graph does not contain both branches unless the system explicitly traces both branches.
This behavior is simple and useful, but it means the derivative is local to the executed path. At branch boundaries, the mathematical function may be non-smooth or undefined.
Graph Representation Choices
| Requirement | Good representation |
|---|---|
| Teaching engine | Value graph with closures |
| Debuggable eager AD | Value graph with op metadata |
| Serialization | Op enum plus operands |
| Compiler optimization | Static IR |
| Low overhead reverse mode | Tape |
| Tensor compiler | Graph IR with shape and layout |
| Distributed execution | Partitionable operation graph |
There is no single best representation. The right graph representation depends on the execution model.
Minimal Correctness Invariant
A graph representation for reverse mode must preserve this invariant:
Every node knows the values and parents needed to propagate its gradient contribution.

For an operation:

```
y = f(x1, ..., xn)
```

the backward step must compute:

```
grad(xi) += (∂f/∂xi) · grad(y)
```

for each parent xi.
The graph is correct if every node can perform this local update and the engine executes nodes in reverse topological order.
That is the whole mathematical requirement. Everything else is systems design.