Graph Representation

A graph representation makes the structure of a differentiated computation explicit. In reverse mode, this structure is required because the backward pass must know which intermediate values depend on which earlier values. In compiler-based systems, the graph may also serve as an optimization target.

A computation graph is a directed graph whose nodes represent values or operations, and whose edges represent data dependencies.

For an expression such as:

z = x y + \sin(x)

the graph has inputs x and y, intermediate nodes for multiplication and sine, and an output node for addition.

x ──┬───────────┐
    │           v
    │   y ──> mul ──┐
    │               ├──> add ──> z
    └──────> sin ───┘

The graph records more than arithmetic structure. It records the dependency order needed to propagate gradients correctly.

Values, Operations, and Edges

There are two common graph styles.

Style           | Node means       | Edge means
----------------|------------------|--------------------------------------------
Value graph     | A computed value | Dependency from parent value
Operation graph | An operation     | Tensor or scalar flowing between operations

A small educational engine usually uses a value graph. Each node stores:

  • the computed value
  • the accumulated gradient
  • pointers to parent nodes
  • a backward function

type Node struct {
    Value float64
    Grad  float64

    Prev []*Node

    Op string

    Backward func()
}

The Op field is optional but useful for debugging and visualization.

func Var(x float64) *Node {
    return &Node{
        Value: x,
        Op:    "var",
    }
}

For a production system, nodes usually carry more metadata: shape, dtype, device, layout, aliasing information, and source-location information.

Local Graph Construction

Each primitive operation creates a new node.

func Add(a, b *Node) *Node {
    out := &Node{
        Value: a.Value + b.Value,
        Prev:  []*Node{a, b},
        Op:    "add",
    }

    out.Backward = func() {
        a.Grad += out.Grad
        b.Grad += out.Grad
    }

    return out
}

The graph is built incrementally as the user program runs.

x := Var(2)
y := Var(3)

z := Add(Mul(x, y), Sin(x))

This creates the graph dynamically. The program itself decides which operations run. Conditionals, loops, and recursion are handled by ordinary execution.
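The Mul, Sin, and Const constructors used above follow the same pattern as Add. As a sketch (Const is assumed here to be a leaf like Var whose accumulated gradient is simply ignored):

```go
package main

import (
	"fmt"
	"math"
)

// Node as defined earlier in this section.
type Node struct {
	Value    float64
	Grad     float64
	Prev     []*Node
	Op       string
	Backward func()
}

func Mul(a, b *Node) *Node {
	out := &Node{Value: a.Value * b.Value, Prev: []*Node{a, b}, Op: "mul"}
	out.Backward = func() {
		// d(ab)/da = b, d(ab)/db = a
		a.Grad += b.Value * out.Grad
		b.Grad += a.Value * out.Grad
	}
	return out
}

func Sin(a *Node) *Node {
	out := &Node{Value: math.Sin(a.Value), Prev: []*Node{a}, Op: "sin"}
	out.Backward = func() {
		// d(sin a)/da = cos(a)
		a.Grad += math.Cos(a.Value) * out.Grad
	}
	return out
}

// Const is a leaf node; any gradient it receives is discarded.
func Const(x float64) *Node {
	return &Node{Value: x, Op: "const"}
}

func main() {
	m := Mul(&Node{Value: 2}, &Node{Value: 3})
	fmt.Println(m.Value) // 6
}
```

Note that Mul stays correct even when both arguments are the same node, because each `+=` runs separately.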

Parent Links and Child Links

A minimal reverse-mode engine only needs parent links:

Prev []*Node

The backward pass starts at the output, walks through parents, and builds a topological order.

Some systems also store child links:

Next []*Node

Child links make certain analyses easier, such as forward traversal, dead-code detection, or visualization. But they are not necessary for a minimal engine.

Parent links are enough because reverse mode starts from the output and walks backward.

Topological Order

A graph representation must support reverse traversal in a valid order.

If node b depends on node a, then b must run its backward rule before a.

a -> b -> output

The backward order is:

output, b, a

A depth-first topological sort is sufficient for a small engine.

func topo(n *Node, seen map[*Node]bool, order *[]*Node) {
    if seen[n] {
        return
    }

    seen[n] = true

    for _, p := range n.Prev {
        topo(p, seen, order)
    }

    *order = append(*order, n)
}

The resulting order lists dependencies before dependents. Reverse iteration gives the backward order.

func Backward(root *Node) {
    var order []*Node

    topo(root, map[*Node]bool{}, &order)

    root.Grad = 1

    for i := len(order) - 1; i >= 0; i-- {
        if order[i].Backward != nil {
            order[i].Backward()
        }
    }
}

This traversal works for directed acyclic computation graphs. Any graph built incrementally during a single forward pass is acyclic, because each new node can only point to nodes that already exist.

Shared Subexpressions

Graph representation must handle sharing.

Consider:

x := Var(3)
a := Mul(x, x)
z := Add(a, a)

Mathematically:

z = 2x^2

The node a contributes to z through two edges. During backward propagation, its gradient must receive both contributions.

The Add rule handles this naturally:

out.Backward = func() {
    a.Grad += out.Grad
    b.Grad += out.Grad
}

If a and b are the same pointer, then the same node receives two additions.

a.Grad += out.Grad
a.Grad += out.Grad

This gives the correct factor of two.

Shared subexpressions are one reason gradients must accumulate with +=.
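For z = 2x^2 at x = 3, the expected gradient is dz/dx = 4x = 12. A compact check that the shared node a = x·x receives both edge contributions (Node, Add, Mul, and the backward driver as sketched in this section):

```go
package main

import "fmt"

type Node struct {
	Value, Grad float64
	Prev        []*Node
	Backward    func()
}

func Var(x float64) *Node { return &Node{Value: x} }

func Add(a, b *Node) *Node {
	out := &Node{Value: a.Value + b.Value, Prev: []*Node{a, b}}
	out.Backward = func() { a.Grad += out.Grad; b.Grad += out.Grad }
	return out
}

func Mul(a, b *Node) *Node {
	out := &Node{Value: a.Value * b.Value, Prev: []*Node{a, b}}
	out.Backward = func() {
		a.Grad += b.Value * out.Grad
		b.Grad += a.Value * out.Grad
	}
	return out
}

func topo(n *Node, seen map[*Node]bool, order *[]*Node) {
	if seen[n] {
		return
	}
	seen[n] = true
	for _, p := range n.Prev {
		topo(p, seen, order)
	}
	*order = append(*order, n)
}

func Backward(root *Node) {
	var order []*Node
	topo(root, map[*Node]bool{}, &order)
	root.Grad = 1
	for i := len(order) - 1; i >= 0; i-- {
		if order[i].Backward != nil {
			order[i].Backward()
		}
	}
}

func main() {
	x := Var(3)
	a := Mul(x, x) // shared subexpression
	z := Add(a, a) // z = 2x^2
	Backward(z)
	fmt.Println(a.Grad) // 2: two edges from z into a
	fmt.Println(x.Grad) // 12: dz/dx = 4x at x = 3
}
```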

Avoiding Duplicate Traversal

A graph may reuse the same node many times. The topological sort must avoid visiting it repeatedly.

seen map[*Node]bool

Without this, shared subgraphs cause repeated backward execution and incorrect gradients.

However, the backward rule itself must still account for multiple incoming edges. This distinction matters:

Concern               | Required behavior
----------------------|-----------------------------
Topological traversal | Visit each node once
Gradient accumulation | Add every edge contribution

The graph traversal removes duplicate node visits. The local backward rules preserve duplicate edge contributions.

Operation Metadata

A useful node often records the operation that created it.

type OpKind string

const (
    OpVar OpKind = "var"
    OpAdd OpKind = "add"
    OpMul OpKind = "mul"
    OpSin OpKind = "sin"
)

Then:

type Node struct {
    Value float64
    Grad  float64
    Prev  []*Node
    Op    OpKind

    Backward func()
}

This helps with:

  • debugging
  • graph printing
  • testing
  • visualization
  • profiling
  • error messages

For example:

func Print(n *Node, indent string, seen map[*Node]bool) {
    if seen[n] {
        fmt.Printf("%s%s value=%g grad=%g [seen]\n", indent, n.Op, n.Value, n.Grad)
        return
    }

    seen[n] = true

    fmt.Printf("%s%s value=%g grad=%g\n", indent, n.Op, n.Value, n.Grad)

    for _, p := range n.Prev {
        Print(p, indent+"  ", seen)
    }
}

A graph printer is one of the best tools for validating a small AD engine.

Separating Data from Backward Logic

The simplest engine stores a closure on every node:

Backward func()

This is easy to write, but it has tradeoffs.

Design                | Benefit                         | Cost
----------------------|---------------------------------|---------------------------------------
Closure per node      | Simple local implementation     | More allocation, harder serialization
Op enum plus operands | Compact, inspectable            | Requires central backward dispatcher
Static IR node        | Compiler friendly               | More upfront design
Tape instruction      | Efficient append-only execution | Less natural graph inspection

A closure-based node is ideal for a teaching engine. A production engine often uses operation records or IR instructions instead.

A dispatcher-based version looks like this:

type Node struct {
    Value float64
    Grad  float64
    Prev  []*Node
    Op    OpKind
}

func backward(n *Node) {
    switch n.Op {
    case OpAdd:
        a, b := n.Prev[0], n.Prev[1]
        a.Grad += n.Grad
        b.Grad += n.Grad

    case OpMul:
        a, b := n.Prev[0], n.Prev[1]
        a.Grad += b.Value * n.Grad
        b.Grad += a.Value * n.Grad

    case OpSin:
        x := n.Prev[0]
        x.Grad += math.Cos(x.Value) * n.Grad
    }
}

This representation makes the graph more explicit and easier to serialize.

Tensor Graphs

Scalar nodes are enough to explain reverse mode, but practical AD systems operate on tensors.

A tensor node stores:

type TensorNode struct {
    Value Tensor
    Grad  Tensor

    Shape []int
    DType DType

    Prev []*TensorNode
    Op   OpKind

    Backward func()
}

Tensor graph representation must handle shape semantics.

For example, broadcasting introduces reduction in the backward pass.

If:

z = x + b

where x has shape [batch, features] and b has shape [features], then the backward pass for b must sum over the batch dimension.

The graph must preserve enough information to reconstruct this rule.
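The backward rule for this broadcast add can be sketched on plain slices. Assuming x has shape [batch, features] and b has shape [features], b's gradient sums the upstream gradient over the batch dimension:

```go
package main

import "fmt"

// gradB reduces the upstream gradient of z = x + b over the batch
// dimension, because b (shape [features]) was broadcast across rows.
func gradB(gradZ [][]float64, features int) []float64 {
	g := make([]float64, features)
	for _, row := range gradZ {
		for j, v := range row {
			g[j] += v
		}
	}
	return g
}

func main() {
	// Upstream gradient with shape [batch=3, features=2], all ones.
	gz := [][]float64{{1, 1}, {1, 1}, {1, 1}}
	fmt.Println(gradB(gz, 2)) // [3 3]
}
```

The gradient of x, by contrast, keeps the full [batch, features] shape unchanged.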

A tensor graph therefore needs metadata beyond dependency edges:

  • original input shapes
  • broadcasted output shape
  • dtype
  • device
  • memory layout
  • aliasing and view information

Scalar AD hides these issues. Tensor AD exposes them immediately.

Graph Lifetime

A dynamic reverse-mode graph usually has the lifetime of one forward evaluation.

build graph -> run backward -> release graph

This is natural for eager systems.

Long-lived graphs require additional care:

  • clearing gradients
  • avoiding stale references
  • managing retained intermediates
  • deciding whether backward can run multiple times
  • freeing nodes after last use

A minimal engine can use a simple convention:

func ZeroGrad(root *Node) {
    var order []*Node

    topo(root, map[*Node]bool{}, &order)

    for _, n := range order {
        n.Grad = 0
    }
}

Before another backward pass, gradients should be reset unless accumulation is intentional.
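A quick demonstration of why this matters: running backward twice without resetting accumulates on top of the first pass, while ZeroGrad restores a clean slate. (Node, Mul, topo, Backward, and ZeroGrad as sketched in this section.)

```go
package main

import "fmt"

type Node struct {
	Value, Grad float64
	Prev        []*Node
	Backward    func()
}

func Var(x float64) *Node { return &Node{Value: x} }

func Mul(a, b *Node) *Node {
	out := &Node{Value: a.Value * b.Value, Prev: []*Node{a, b}}
	out.Backward = func() {
		a.Grad += b.Value * out.Grad
		b.Grad += a.Value * out.Grad
	}
	return out
}

func topo(n *Node, seen map[*Node]bool, order *[]*Node) {
	if seen[n] {
		return
	}
	seen[n] = true
	for _, p := range n.Prev {
		topo(p, seen, order)
	}
	*order = append(*order, n)
}

func Backward(root *Node) {
	var order []*Node
	topo(root, map[*Node]bool{}, &order)
	root.Grad = 1
	for i := len(order) - 1; i >= 0; i-- {
		if order[i].Backward != nil {
			order[i].Backward()
		}
	}
}

func ZeroGrad(root *Node) {
	var order []*Node
	topo(root, map[*Node]bool{}, &order)
	for _, n := range order {
		n.Grad = 0
	}
}

func main() {
	x := Var(3)
	z := Mul(x, x) // z = x^2, so dz/dx = 2x = 6

	Backward(z)
	fmt.Println(x.Grad) // 6

	// A second pass without resetting accumulates on the stale gradient.
	Backward(z)
	fmt.Println(x.Grad) // 12

	// Resetting first gives a fresh, correct gradient.
	ZeroGrad(z)
	Backward(z)
	fmt.Println(x.Grad) // 6
}
```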

Graph Representation and Control Flow

Dynamic graph construction handles control flow by recording the path actually taken.

func H(x *Node) *Node {
    if x.Value > 0 {
        return Mul(x, x)
    }

    return Mul(Const(-1), x)
}

For x > 0, the graph represents:

x^2

For x <= 0, the graph represents:

-x

The graph does not contain both branches unless the system explicitly traces both branches.

This behavior is simple and useful, but it means the derivative is local to the executed path. At branch boundaries, the mathematical function may be non-smooth or undefined.
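Evaluating H on each side of the branch shows the path-local derivative. For x > 0 the recorded graph is x^2 with derivative 2x; for x <= 0 it is -x with derivative -1. A check, with the engine pieces as sketched in this section:

```go
package main

import "fmt"

type Node struct {
	Value, Grad float64
	Prev        []*Node
	Backward    func()
}

func Var(x float64) *Node   { return &Node{Value: x} }
func Const(x float64) *Node { return &Node{Value: x} }

func Mul(a, b *Node) *Node {
	out := &Node{Value: a.Value * b.Value, Prev: []*Node{a, b}}	
	out.Backward = func() {
		a.Grad += b.Value * out.Grad
		b.Grad += a.Value * out.Grad
	}
	return out
}

func topo(n *Node, seen map[*Node]bool, order *[]*Node) {
	if seen[n] {
		return
	}
	seen[n] = true
	for _, p := range n.Prev {
		topo(p, seen, order)
	}
	*order = append(*order, n)
}

func Backward(root *Node) {
	var order []*Node
	topo(root, map[*Node]bool{}, &order)
	root.Grad = 1
	for i := len(order) - 1; i >= 0; i-- {
		if order[i].Backward != nil {
			order[i].Backward()
		}
	}
}

// H records only the branch actually taken.
func H(x *Node) *Node {
	if x.Value > 0 {
		return Mul(x, x)
	}
	return Mul(Const(-1), x)
}

func main() {
	p := Var(2)
	Backward(H(p))
	fmt.Println(p.Grad) // 4: the graph is x^2, derivative 2x at x = 2

	n := Var(-1)
	Backward(H(n))
	fmt.Println(n.Grad) // -1: the graph is -x, derivative -1
}
```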

Graph Representation Choices

Requirement               | Good representation
--------------------------|----------------------------------
Teaching engine           | Value graph with closures
Debuggable eager AD       | Value graph with op metadata
Serialization             | Op enum plus operands
Compiler optimization     | Static IR
Low overhead reverse mode | Tape
Tensor compiler           | Graph IR with shape and layout
Distributed execution     | Partitionable operation graph

There is no single best representation. The right graph representation depends on the execution model.

Minimal Correctness Invariant

A graph representation for reverse mode must preserve this invariant:

Every node knows the values and parents needed to propagate its gradient contribution.

For an operation:

y = g(x_1, \dots, x_k)

the backward step must compute:

\bar{x}_i \mathrel{+}= \bar{y} \, \frac{\partial g}{\partial x_i}

for each parent x_i.

The graph is correct if every node can perform this local update and the engine executes nodes in reverse topological order.

That is the whole mathematical requirement. Everything else is systems design.