An automatic differentiation engine becomes useful only after it supports a sufficiently rich set of primitive operations. The collection of these primitives is the operator library.
The operator library defines:
- which computations are differentiable
- how derivatives propagate
- what tensor semantics exist
- how shapes and dtypes behave
- which kernels execute on which devices
In a minimal engine, the operator library may contain fewer than ten scalar operations. In a production tensor system, it may contain thousands of operators.
Primitive Operations
A primitive operation is an operation whose derivative rule is implemented directly.
Examples:
- addition
- multiplication
- sine
- exponential
- matrix multiplication
- convolution
- reduction
- reshape
Everything else can be expressed as compositions of primitives.
For example, the cube

    y = x^3

does not require a dedicated cube operator:

    x2 = mul(x, x)
    x3 = mul(x2, x)

The derivative emerges automatically from composition.
The operator library therefore defines the algebraic basis of the AD system.
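A scalar sketch of how the cube derivative emerges from two applications of the local mul rule (cubeGrad is an illustrative helper, not engine API):

```go
package main

import "fmt"

// cubeGrad differentiates x^3 using only the local rule for mul,
// applied twice: x2 = x*x, x3 = x2*x.
func cubeGrad(x float64) float64 {
	x2 := x * x // forward

	// Reverse pass, seeding d(x3)/d(x3) = 1.
	gradX3 := 1.0
	// x3 = mul(x2, x): gradient flows to both operands.
	gradX2 := gradX3 * x
	gradX := gradX3 * x2
	// x2 = mul(x, x): both operands are x, so grads accumulate.
	gradX += gradX2 * x
	gradX += gradX2 * x
	return gradX
}

func main() {
	fmt.Println(cubeGrad(2)) // analytically 3x^2 = 12
}
```

No cube-specific rule was written; the chain rule falls out of the two mul rules.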
Minimal Scalar Operator Set
A small scalar engine can begin with:
| Operator | Meaning |
|---|---|
| add | Addition |
| sub | Subtraction |
| mul | Multiplication |
| div | Division |
| sin | Sine |
| cos | Cosine |
| exp | Exponential |
| log | Natural logarithm |
A tape representation:

```go
type OpKind uint8

const (
	OpAdd OpKind = iota
	OpSub
	OpMul
	OpDiv
	OpSin
	OpCos
	OpExp
	OpLog
)
```

Each operation needs:
- forward evaluation
- backward rule
Addition:
```go
func (t *Tape) Add(a, b Slot) Slot {
	out := t.alloc(
		t.Values[a]+t.Values[b],
		t.RequiresGrad[a] || t.RequiresGrad[b],
	)
	if t.RequiresGrad[out] {
		t.Instrs = append(t.Instrs, Instr{
			Op:  OpAdd,
			Out: out,
			A:   a,
			B:   b,
		})
	}
	return out
}
```

Backward:

```go
case OpAdd:
	t.Grads[ins.A] += t.Grads[ins.Out]
	t.Grads[ins.B] += t.Grads[ins.Out]
```

This pattern repeats for every primitive.
Local Derivative Rules
Every operator implements a local Jacobian action.

For scalar operations, if the forward rule is

    y = f(x)

the backward rule computes:

    grad_x += f'(x) * grad_y

The operator library therefore contains executable forms of derivative identities.

Multiplication:

    y = a * b

gives:

    grad_a += b * grad_y
    grad_b += a * grad_y

Division:

    y = a / b

gives:

    grad_a += grad_y / b
    grad_b -= grad_y * a / (b * b)

The engine itself does not understand calculus globally. Each primitive contributes one local rule.
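As tape backward cases, these rules would mirror the OpAdd case above. Here is a standalone scalar sketch (mulBackward and divBackward are illustrative helpers, not engine API):

```go
package main

import "fmt"

// mulBackward applies the product rule: y = a*b,
// so dy/da = b and dy/db = a, each scaled by the upstream gradient.
func mulBackward(a, b, gradOut float64) (gradA, gradB float64) {
	return b * gradOut, a * gradOut
}

// divBackward applies the quotient rule: y = a/b,
// so dy/da = 1/b and dy/db = -a/b^2.
func divBackward(a, b, gradOut float64) (gradA, gradB float64) {
	return gradOut / b, -gradOut * a / (b * b)
}

func main() {
	ga, gb := mulBackward(3, 4, 1)
	fmt.Println(ga, gb) // 4 3
	ga, gb = divBackward(3, 4, 1)
	fmt.Println(ga, gb) // 0.25 -0.1875
}
```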
Forward and Backward Registration
A useful operator abstraction separates operation metadata from execution.
```go
type Operator struct {
	Name     string
	Forward  func(*Tape, []Slot) Slot
	Backward func(*Tape, Instr)
}
```

Registration:

```go
var Ops = map[OpKind]Operator{}
```

Example:

```go
Ops[OpAdd] = Operator{
	Name: "add",
	Backward: func(t *Tape, ins Instr) {
		t.Grads[ins.A] += t.Grads[ins.Out]
		t.Grads[ins.B] += t.Grads[ins.Out]
	},
}
```

This structure:
- centralizes derivative rules
- simplifies debugging
- supports reflection and serialization
- decouples graph representation from derivative logic
A closure-per-node design is simpler for teaching. An operator table scales better.
Tensor Operators
Scalar operators are easy because shapes are trivial. Tensor operators require shape semantics.
Tensor addition:

    [a, b, c] + [x, y, z]

propagates gradients elementwise.

Matrix multiplication:

    Y = A @ B

has backward rules:

    grad_A = grad_Y @ B^T
    grad_B = A^T @ grad_Y
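The matmul backward rules (grad_A = grad_Y @ B^T and grad_B = A^T @ grad_Y) can be checked with a small row-major sketch; matmul and transpose here are local helpers, not part of the tape engine above:

```go
package main

import "fmt"

// matmul computes C = A @ B for row-major flat matrices (A: m×k, B: k×n).
func matmul(a, b []float64, m, k, n int) []float64 {
	c := make([]float64, m*n)
	for i := 0; i < m; i++ {
		for p := 0; p < k; p++ {
			for j := 0; j < n; j++ {
				c[i*n+j] += a[i*k+p] * b[p*n+j]
			}
		}
	}
	return c
}

// transpose returns the n×m transpose of a row-major m×n matrix.
func transpose(a []float64, m, n int) []float64 {
	t := make([]float64, m*n)
	for i := 0; i < m; i++ {
		for j := 0; j < n; j++ {
			t[j*m+i] = a[i*n+j]
		}
	}
	return t
}

func main() {
	// Y = A @ B with A: 1×2, B: 2×2, and upstream grad_Y = [1, 1].
	a := []float64{1, 2}
	b := []float64{3, 4, 5, 6}
	gradY := []float64{1, 1}
	gradA := matmul(gradY, transpose(b, 2, 2), 1, 2, 2) // grad_Y @ B^T
	gradB := matmul(transpose(a, 1, 2), gradY, 2, 1, 2) // A^T @ grad_Y
	fmt.Println(gradA) // [7 11]
	fmt.Println(gradB) // [1 1 2 2]
}
```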
A tensor operator therefore needs:
- shape information
- layout information
- dtype information
- broadcasting rules
- reduction semantics
A tensor node:
```go
type TensorNode struct {
	Value Tensor
	Grad  Tensor
	Shape []int
	DType DType
	Prev  []*TensorNode
	Op    OpKind
}
```

The operator library becomes partly a tensor algebra system.
Broadcasting Semantics
Broadcasting complicates backward propagation.
Example:

    x shape: [batch, features]
    b shape: [features]
    y = x + b

Forward broadcasting duplicates b conceptually across the batch dimension.

Backward propagation must reverse that duplication:

    grad_b = reduce_sum(grad_y, axis=batch)

The backward rule therefore depends on:
- original input shapes
- broadcasted output shape
The operator library must preserve this metadata.
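Reversing the duplication for the bias example above can be sketched as follows, assuming row-major flat storage (unbroadcastBias is an illustrative helper):

```go
package main

import "fmt"

// unbroadcastBias reduces grad_y of shape [batch, features] back to
// b's original shape [features] by summing over the batch axis.
func unbroadcastBias(gradY []float64, batch, features int) []float64 {
	gradB := make([]float64, features)
	for i := 0; i < batch; i++ {
		for j := 0; j < features; j++ {
			gradB[j] += gradY[i*features+j]
		}
	}
	return gradB
}

func main() {
	// batch=2, features=3; grad_y is all ones,
	// so each bias element accumulates a gradient of 2.
	fmt.Println(unbroadcastBias([]float64{1, 1, 1, 1, 1, 1}, 2, 3))
}
```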
A minimal tensor engine should store original operand shapes inside the instruction.
```go
type Instr struct {
	Op     OpKind
	Out    Slot
	A      Slot
	B      Slot
	ShapeA []int
	ShapeB []int
}
```

Shape Transformations
Shape-changing operators often have trivial derivatives mathematically but important systems behavior.
Reshape:

    y = reshape(x)

Backward:

    grad_x = reshape(grad_y, original_shape)

Transpose:

    y = transpose(x)

Backward:

    grad_x = transpose(grad_y)

Slicing:

    y = x[2:5]

Backward scatters the gradient into the original tensor shape.
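A sketch of the slice backward (scatter) for a 1-D tensor, assuming a flat []float64 representation (sliceBackward is an illustrative helper):

```go
package main

import "fmt"

// sliceBackward scatters gradY (the gradient of y = x[lo:lo+len(gradY)])
// into a zero tensor with x's original length n.
func sliceBackward(gradY []float64, lo, n int) []float64 {
	gradX := make([]float64, n)
	copy(gradX[lo:lo+len(gradY)], gradY)
	return gradX
}

func main() {
	// y = x[2:5] on a length-6 x; upstream gradient is all ones.
	// Positions 2..4 receive gradient; the rest stay zero.
	fmt.Println(sliceBackward([]float64{1, 1, 1}, 2, 6)) // [0 0 1 1 1 0]
}
```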
These operators demonstrate an important principle:

The backward rule must reverse the data movement semantics of the forward rule.

Reduction Operators
Reductions shrink dimensions.
Example:

    y = sum(x)

Backward:

    grad_x = broadcast(grad_y)

Mean:

    y = mean(x)

Backward:

    grad_x = broadcast(grad_y / n)

Maximum is harder.

    y = max(x)

Backward must identify which elements achieved the maximum.
The forward pass therefore often stores an argmax mask or index set.
Operator design is tightly connected to saved intermediate state.
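A 1-D sketch of this coupling: forward saves the argmax, and backward routes the gradient through it (maxForward and maxBackward are illustrative helpers):

```go
package main

import "fmt"

// maxForward returns the maximum and the index that achieved it.
// Saving the argmax during forward is what makes backward cheap.
func maxForward(x []float64) (float64, int) {
	best, arg := x[0], 0
	for i, v := range x[1:] {
		if v > best {
			best, arg = v, i+1
		}
	}
	return best, arg
}

// maxBackward routes the upstream gradient to the argmax position only.
func maxBackward(gradY float64, arg, n int) []float64 {
	gradX := make([]float64, n)
	gradX[arg] = gradY
	return gradX
}

func main() {
	x := []float64{1, 5, 3}
	_, arg := maxForward(x)
	fmt.Println(maxBackward(1, arg, len(x))) // [0 1 0]
}
```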
Numerical Stability
The operator library must consider stable derivatives.
Naive softmax:

    softmax(x)_i = exp(x_i) / sum_j exp(x_j)

is numerically unstable for large values.

Stable implementation:

    m = max(x)

then:

    softmax(x)_i = exp(x_i - m) / sum_j exp(x_j - m)

The backward rule should use the stable forward result rather than recomputing unstable exponentials.
Many practical AD bugs are numerical rather than symbolic.
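A minimal sketch of the stable forward pass, assuming 1-D []float64 input (stableSoftmax is an illustrative helper):

```go
package main

import (
	"fmt"
	"math"
)

// stableSoftmax subtracts the maximum before exponentiating,
// so no intermediate exp can overflow.
func stableSoftmax(x []float64) []float64 {
	m := x[0]
	for _, v := range x {
		if v > m {
			m = v
		}
	}
	out := make([]float64, len(x))
	var sum float64
	for i, v := range x {
		out[i] = math.Exp(v - m)
		sum += out[i]
	}
	for i := range out {
		out[i] /= sum
	}
	return out
}

func main() {
	// Naive exp(1000) overflows float64; the stable form does not.
	fmt.Println(stableSoftmax([]float64{1000, 1000})) // [0.5 0.5]
}
```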
Fused Operators
A fused operator combines several operations into one primitive.
Example:

    y = relu(matmul(x, w) + b)

Possible representations:
| Strategy | Representation |
|---|---|
| Unfused | matmul → add → relu |
| Fused | single kernel |
Fusion can reduce:
- memory traffic
- temporary allocations
- kernel launch overhead
But fusion complicates:
- debugging
- intermediate inspection
- operator reuse
A minimal engine should start unfused. Fusion belongs in a compiler or optimization layer.
Operator Purity
Differentiable operators should ideally behave as pure functions:
    same inputs -> same outputs

Purity simplifies:
- caching
- graph replay
- checkpointing
- common subexpression elimination
- compiler optimization
Stateful operators complicate reverse mode.
Random number generation is a common example.
    y = dropout(x)

Backward requires replaying the same random mask used during forward.
The operator must therefore:
- save the mask
- or save the RNG state
The operator library defines the system’s state semantics as much as its derivative semantics.
Custom Operators
Users often need operators outside the built-in library.
A useful engine exposes registration:
```go
func Register(
	name string,
	forward func(...Tensor) Tensor,
	backward func(...Tensor) []Tensor,
)
```

The user supplies both:
- primal computation
- gradient computation
This is important for:
- optimized kernels
- domain-specific operations
- interoperability
- experimental layers
A custom operator is essentially a manually supplied local derivative rule.
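A minimal scalar stand-in for such a registration API (all names here are illustrative, not the engine's API):

```go
package main

import "fmt"

// CustomOp bundles a user-supplied primal computation with its
// manually supplied local derivative rule.
type CustomOp struct {
	Forward  func(x float64) float64
	Backward func(x, gradY float64) float64
}

var registry = map[string]CustomOp{}

// Register installs a custom operator under a name.
func Register(name string, op CustomOp) { registry[name] = op }

func main() {
	// A user-supplied "square" operator with its derivative rule.
	Register("square", CustomOp{
		Forward:  func(x float64) float64 { return x * x },
		Backward: func(x, gradY float64) float64 { return 2 * x * gradY },
	})
	op := registry["square"]
	fmt.Println(op.Forward(3), op.Backward(3, 1)) // 9 6
}
```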
Operator Validation
Every operator should be testable independently.
A common strategy:
- choose random inputs
- compute AD gradient
- compare against finite differences
Finite difference approximation:

    f'(x) ≈ (f(x + eps) - f(x - eps)) / (2 * eps)
Example test:

```go
func CheckGrad(
	f func(float64) float64,
	g func(float64) float64,
	x float64,
) bool {
	eps := 1e-6
	fd := (f(x+eps) - f(x-eps)) / (2 * eps)
	ad := g(x)
	return math.Abs(fd-ad) <= 1e-4*(1+math.Abs(fd))
}
```

This catches:
- sign errors
- missing reductions
- broadcasting mistakes
- shape mismatches
- unstable formulas
Operator testing is one of the highest leverage activities in AD implementation.
Minimal Operator Interface
A compact operator interface for a small tensor engine:
```go
type Operator interface {
	Forward(*Tape, []Slot) Slot
	Backward(*Tape, Instr)
}
```

Or more explicit:

```go
type Operator struct {
	Name     string
	Arity    int
	Forward  func(*Tape, []Slot) Slot
	Backward func(*Tape, Instr)
}
```

This separates:
- graph structure
- storage
- execution
- differentiation logic
The operator library becomes the semantic core of the AD engine.
Minimal Correctness Invariant
Every operator must satisfy:

Its backward rule computes the transpose Jacobian action of the forward rule.

If:

    y = f(x)

then the backward rule computes:

    grad_x = J^T @ grad_y    where J = ∂f/∂x

This is the central invariant of reverse mode.
Everything else in the operator library:
- tensor shapes
- broadcasting
- memory reuse
- device kernels
- fusion
- dtype promotion
exists to make this invariant efficient and practical at scale.