Staging is the separation of a program into phases.
In an automatic differentiation system, the most common phases are:

- trace time
- compile time
- run time

A staged system does not execute every part of the user program in the same way. Some parts run immediately in the host language. Other parts are captured into an intermediate representation. That captured representation may then be differentiated, optimized, compiled, cached, and executed later.
This is one of the main reasons modern AD systems can combine a friendly programming interface with compiler-grade performance.
## The Basic Idea
Consider a function:

```python
def f(x):
    y = x * x
    return sin(y)
```

An eager system evaluates the function directly every time.
A staged system may first capture the computation:

```
x0 = input
x1 = mul x0 x0
x2 = sin x1
return x2
```

Then it may transform the captured program:
```
x0, dx0 = input
x1 = mul x0 x0
dx1 = add (mul dx0 x0), (mul x0 dx0)
x2 = sin x1
dx2 = mul (cos x1), dx1
return x2, dx2
```

Finally, it may compile this transformed program into efficient executable code.
The user sees a function call. Internally, the system has separated the call into stages.
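As a concrete flavor of this capture step, here is a short JAX sketch (one of several systems that stage this way); `jax.make_jaxpr` returns the captured intermediate representation:

```python
import jax
import jax.numpy as jnp

def f(x):
    y = x * x
    return jnp.sin(y)

# Tracing f with an example input captures the staged program
# as a jaxpr, JAX's intermediate representation. Prints roughly:
#   { lambda ; a:f32[]. let b:f32[] = mul a a; c:f32[] = sin b in (c,) }
print(jax.make_jaxpr(f)(1.0))
```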
## Why Staging Matters for AD
Automatic differentiation benefits from staging because derivative computation is often more useful as a reusable program than as a one-time side effect.
For example, grad(f) can be treated as a new function. That function can be compiled, optimized, cached, and called many times.
Without staging, the AD system may recompute transformation decisions on every call. With staging, it can produce a derivative program once and reuse it.
A typical staged AD pipeline is:

```
user function
  -> specialize
  -> trace or lower to IR
  -> differentiate
  -> optimize
  -> compile
  -> execute
```

Each arrow separates one concern from the next.
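For a sense of how this pipeline looks from user code, here is a minimal JAX sketch: the first call to a jitted gradient traces, differentiates, and compiles; later calls reuse the staged executable:

```python
import jax
import jax.numpy as jnp

def loss(w):
    return jnp.sum(jnp.sin(w) ** 2)

# grad produces the derivative program; jit stages and compiles it.
dloss = jax.jit(jax.grad(loss))

w = jnp.ones(1024)
g = dloss(w)   # first call: trace, differentiate, optimize, compile
g = dloss(w)   # later calls: reuse the compiled derivative program
```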
## Static and Dynamic Parts
Partial evaluation is the process of executing the parts of a program that are known early, while leaving the unknown parts for later.
Suppose a function has one static argument and one dynamic argument:

```python
def power(x, n):
    y = 1
    for i in range(n):
        y = y * x
    return y
```

If n is known during tracing, the loop can be unrolled.
For n = 3, the staged program is:

```
x0 = input
x1 = mul 1 x0
x2 = mul x1 x0
x3 = mul x2 x0
return x3
```

The argument n was consumed at trace time. The argument x remains a runtime input.
This is partial evaluation: run what can be run now, generate code for what must run later.
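In JAX, this split can be requested explicitly by marking n as a static argument; a minimal sketch, assuming the power function above:

```python
import functools
import jax

# static_argnums=1 makes n a trace-time value; the Python loop
# below runs during tracing and is unrolled into the staged program.
@functools.partial(jax.jit, static_argnums=1)
def power(x, n):
    y = 1.0
    for i in range(n):
        y = y * x
    return y

power(2.0, 3)  # compiles a program specialized to n = 3
power(2.0, 4)  # a new n means a new specialization and compile
```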
## Specialization
Staging usually specializes code to some properties of the input.
Common specialization parameters include:
| Property | Example |
|---|---|
| Dtype | float32, float64, int32 |
| Shape | [32, 128], [B, 768] |
| Rank | scalar, vector, matrix |
| Static arguments | loop counts, flags, algorithm choices |
| Device | CPU, GPU, TPU |
| Layout | row-major, column-major, blocked |
Specialization gives the compiler more information. It can choose kernels, allocate buffers, remove branches, and simplify derivative code.
The cost is recompilation. When a specialized property changes, the system may need a new staged program.
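As a sketch of this trade-off in JAX, retracing happens whenever a specialized property of the input changes:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return x * x

f(jnp.ones(8))                    # trace and compile for float32[8]
f(jnp.ones(8))                    # cache hit: same shape and dtype
f(jnp.ones(16))                   # new shape: trace and compile again
f(jnp.ones(8, dtype=jnp.int32))   # new dtype: trace and compile again
```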
## Example: Static Branch
Consider:

```python
def f(x, normalize):
    if normalize:
        x = x / norm(x)
    return x * x
```

If normalize is static, the staged system creates one program for normalize = True and another for normalize = False.
For True:

```
x1 = norm x0
x2 = div x0 x1
x3 = mul x2 x2
return x3
```

For False:

```
x3 = mul x0 x0
return x3
```

The branch disappears from each staged program. The derivative transform only sees the selected version.
This gives efficient code, but it also means the static argument is part of the compilation cache key.
## Example: Dynamic Branch
If the branch depends on a runtime value, it cannot be resolved during staging.
```python
def f(x):
    if sum(x) > 0:
        return x * x
    else:
        return -x
```

If sum(x) > 0 is not known until runtime, the staged program must contain an explicit control-flow operation:
```
p = gt (reduce_sum x), 0
y = cond p then_graph else_graph x
return y
```

The AD system must differentiate the cond operation as a control-flow primitive.
Forward mode differentiates the executed branch. Reverse mode must send adjoints back through the executed branch.
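As a sketch, here is the same function written against an explicit control-flow primitive, using JAX's `lax.cond` as one concrete example:

```python
import jax
import jax.numpy as jnp
from jax import lax

def f(x):
    # The predicate is a runtime value, so the branch is expressed
    # as a staged control-flow primitive, not a Python if.
    return lax.cond(jnp.sum(x) > 0,
                    lambda x: x * x,   # then branch
                    lambda x: -x,      # else branch
                    x)

# AD treats cond as a primitive and differentiates through
# whichever branch executes.
jax.grad(lambda x: jnp.sum(f(x)))(jnp.array([1.0, -2.0, 3.0]))
```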
## Partial Evaluation and AD
Partial evaluation can happen before or after differentiation.
If partial evaluation happens before AD, the derivative transform sees a simplified primal program.

```
f -> partially evaluate -> simplified f -> differentiate
```

If it happens after AD, the optimizer simplifies both primal and derivative code.

```
f -> differentiate -> derivative program -> partially evaluate
```

In practice, systems often do both. Early partial evaluation removes unnecessary structure. Later partial evaluation cleans up derivative artifacts.
For example, if a tangent seed is known to be zero, forward-mode code can remove any tangent computation that depends only on that zero seed.
## Residuals
When partial evaluation cannot compute a value immediately, but a later stage will need it, the value becomes a residual.
For reverse-mode AD, residuals are especially important.
A forward pass may compute:

```
y = sin(x)
```

The backward pass needs x (or sometimes y) to compute:

```
bar_x += cos(x) * bar_y
```

If the backward pass is staged as a separate program, the needed primal value must be saved as a residual.
A staged reverse-mode transform often produces:

```
forward:
  y = sin(x)
  return y, residuals(x)

backward:
  x = residuals.x
  bar_x = cos(x) * bar_y
  return bar_x
```

Residual management is a central compiler problem. Too many residuals cause high memory use. Too few residuals cause recomputation or incorrect gradients.
## Staged Reverse Mode
Reverse-mode staging often splits computation into two generated functions:

- primal function
- pullback function

The primal function computes the output and packages residuals.

```
(y, res) = f_primal(x)
```

The pullback function consumes output adjoints and residuals.

```
bar_x = f_pullback(res, bar_y)
```

This design is clean because it separates forward execution from backward execution.
It also matches the mathematical idea of a pullback: given a cotangent at the output, produce a cotangent at the input.
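This primal/pullback split is visible in APIs such as JAX's `vjp`; a minimal sketch:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x)

x = jnp.array([0.1, 0.2, 0.3])

# Primal pass: compute y; the returned closure holds the residuals.
y, f_pullback = jax.vjp(f, x)

# Backward pass: feed an output cotangent, get input cotangents.
bar_y = jnp.ones_like(y)
(bar_x,) = f_pullback(bar_y)   # equals cos(x) * bar_y
```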
## Compile-Time Constants
Partial evaluation can turn compile-time constants into simpler derivative code.
Consider:

```python
def f(x):
    return 3.0 * x
```

Forward mode gives:

```
y = 3.0 * x
dy = 3.0 * dx + 0.0 * x
```

After partial evaluation and simplification:

```
y = 3.0 * x
dy = 3.0 * dx
```

Reverse mode gives:

```
bar_x += 3.0 * bar_y
```

The constant has no adjoint because it is not a differentiable input.
This seems minor, but in large programs, removing inactive constants, static configuration, and zero tangents makes derivative code much smaller.
## Staging and Performance
Staging improves performance by moving work out of the runtime path.
The system can perform:

- shape inference
- type checking
- activity analysis
- AD transformation
- memory planning
- kernel selection
- fusion
- code generation

before repeated execution.
This is especially valuable for training loops, simulations, batched inference, and optimization problems where the same computation structure runs many times.
The first call may be expensive because it traces and compiles. Later calls can be fast because they reuse the staged executable.
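A rough way to observe this, sketched in JAX (absolute times depend entirely on the machine and backend):

```python
import time
import jax
import jax.numpy as jnp

f = jax.jit(lambda x: jnp.sin(x) * jnp.cos(x))
x = jnp.ones((1000, 1000))

t0 = time.perf_counter()
f(x).block_until_ready()   # first call: trace + compile + execute
t1 = time.perf_counter()
f(x).block_until_ready()   # second call: reuse the staged executable
t2 = time.perf_counter()

print(f"first call:  {t1 - t0:.4f}s (includes compilation)")
print(f"second call: {t2 - t1:.4f}s")
```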
## Staging and Cache Keys
A staged system usually caches compiled programs.
A cache key may include:

- function identity
- static argument values
- input dtypes
- input shapes
- device
- compiler options
- AD transform stack

The transform stack matters. These may produce different staged programs:
```
jit(f)
grad(f)
jit(grad(f))
grad(jit(f))
vmap(grad(f))
```

A correct system must distinguish them.
A poor cache key causes either unnecessary recompilation or, worse, reuse of a program under invalid assumptions.
## Staging Boundaries
A staging boundary is the line between host-language execution and captured computation.
Inside the boundary, operations become IR nodes. Outside the boundary, operations run normally.
For example:

```python
@jit
def f(x):
    y = x * x
    return y
```

The body of f is staged. Code calling f may remain ordinary host-language code.
The boundary determines what the AD system can see. If an operation occurs outside the boundary, the compiler cannot optimize or differentiate through it unless it is represented as an explicit primitive.
## Host Language Effects
Staging changes the meaning of side effects.
Consider:

```python
def f(x):
    log.append(x)
    return x * x
```

If f is staged, the append may happen at trace time, not on each runtime execution. This is usually not what the programmer expects.
For this reason, staged AD systems often restrict side effects inside staged functions or require effectful operations to be represented explicitly in the IR.
The same issue appears with printing, file I/O, mutation of global state, object identity, and exceptions.
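Printing makes the trace-time behavior easy to see; a JAX sketch:

```python
import jax

@jax.jit
def f(x):
    print("tracing f")   # host-language effect: runs during tracing
    return x * x

f(1.0)   # prints "tracing f" once, while the body is being traced
f(2.0)   # cache hit: compiled code runs, nothing is printed
```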
## Staging and Randomness
Randomness must not be hidden in the host language.
Bad staged design:

```python
def f(x):
    r = random()
    return x + r
```

The random value may be sampled once during tracing and baked into the compiled program.

A better design passes randomness explicitly:

```
r, key2 = random_uniform(key1)
y = x + r
return y, key2
```

The random key becomes part of the staged computation. This makes randomness reproducible and visible to differentiation and compilation.
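JAX follows this design with explicit PRNG keys; a minimal sketch of threading a key through a staged function:

```python
import jax
import jax.numpy as jnp

def f(x, key):
    key, subkey = jax.random.split(key)       # derive a fresh subkey
    r = jax.random.uniform(subkey, x.shape)   # sampling is staged, not baked in
    return x + r, key                         # return the new key for reuse

key = jax.random.PRNGKey(0)
y, key = jax.jit(f)(jnp.zeros(3), key)        # randomness depends on the key input
```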
## Shape Polymorphic Staging
Basic staging specializes to exact shapes. Shape polymorphic staging allows some dimensions to remain symbolic.
Instead of compiling for:

```
tensor<f32, [32, 768]>
```

the system compiles for:

```
tensor<f32, [B, 768]>
```

This avoids recompilation for different batch sizes.
Shape polymorphism is useful but hard. The compiler must generate code that handles symbolic dimensions, dynamic memory sizes, and shape-dependent operations. AD rules must also preserve symbolic shapes correctly.
## Multi-Stage AD
AD transformations can be nested with staging transformations.
Examples:

```
jit(grad(f))
grad(jit(f))
vmap(jit(grad(f)))
jit(vmap(grad(f)))
```

These are not always equivalent operationally.
jit(grad(f)) differentiates first, then stages the gradient function.
grad(jit(f)) stages the primal function, then differentiates through the staged call if the system supports it.
Mathematically, both may compute the same derivative. Operationally, they can differ in compilation boundaries, residual storage, cache behavior, and optimization opportunities.
A well-designed AD system defines transformation composition precisely.
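As one concrete composition, a common JAX pattern computes per-example gradients by nesting vmap over grad and staging the whole thing; a sketch assuming a scalar per-example loss:

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    return jnp.sum((w * x) ** 2)   # scalar loss for a single example

# Differentiate w.r.t. w, map over a batch of x, then stage it all.
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0)))

w = jnp.ones(4)
xs = jnp.arange(12.0).reshape(3, 4)   # batch of 3 examples
g = per_example_grads(w, xs)          # shape (3, 4): one gradient per example
```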
## Staging and Custom Primitives
Custom primitives need staging rules.
A primitive may provide:
abstract evaluation rule
lowering rule
JVP rule
VJP rule
batching rule
effect ruleThe abstract evaluation rule tells the staging system the output type and shape. The lowering rule tells the compiler how to implement the primitive. The JVP and VJP rules tell AD how to differentiate it.
Without these rules, the primitive becomes a black box. The system may be able to execute it eagerly but unable to stage or differentiate it.
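Some of these hooks are exposed directly to users. For example, JAX's `custom_vjp` lets a function supply its own forward and pullback rules; a minimal sketch:

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def my_sin(x):
    return jnp.sin(x)

def my_sin_fwd(x):
    return jnp.sin(x), x           # primal output plus residual

def my_sin_bwd(res, bar_y):
    x = res
    return (jnp.cos(x) * bar_y,)   # one cotangent per input

my_sin.defvjp(my_sin_fwd, my_sin_bwd)

jax.grad(my_sin)(0.5)              # reverse mode uses the custom pullback
```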
## Failure Modes
Staging introduces predictable failure modes.
A common failure is trying to use a staged value where the host language requires a concrete value.
```python
def f(x):
    return [0] * int(x)
```

If x is dynamic, int(x) cannot be known at trace time.
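In JAX, staging this function raises a trace-time error, since the traced value cannot be converted to a concrete Python int (the exact exception type varies by version):

```python
import jax

@jax.jit
def f(x):
    return [0] * int(x)   # int() demands a concrete value; x is a tracer

f(3)  # fails at trace time with a tracer-concretization error
```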
Another failure is accidental static capture.
```python
scale = 2

@jit
def f(x):
    return scale * x
```

If scale changes later, the compiled function may still use the old value unless the system tracks it as part of the cache key or treats it as an explicit input.
A third failure is excessive recompilation due to changing shapes or static arguments.
These failures are not bugs in the idea of staging. They are consequences of phase separation.
## Design Rules
A staged AD system should make the following explicit:
- Which values are static?
- Which values are dynamic?
- Which operations are captured?
- Which effects are allowed?
- What determines the cache key?
- Which control-flow constructs are staged?
- How are residuals represented?
- How do AD, batching, and compilation compose?

Users can reason about staged systems when these rules are stable and visible.
## Summary
Staging separates program construction from program execution. Partial evaluation executes known parts early and leaves unknown parts as generated code.
For automatic differentiation, this separation is powerful. It lets the system capture numerical programs, transform them into derivative programs, optimize them, compile them, and reuse them.
The price is a stricter execution model. Some values exist at trace time, others at run time. Some arguments are static, others dynamic. Side effects and randomness must be explicit. Control flow depending on dynamic values must be represented in the staged IR.
Staging makes AD less like a runtime trick and more like a compiler pipeline.