# Staging and Partial Evaluation

Staging is the separation of a program into phases that execute at different times.

In an automatic differentiation system, the most common phases are:

```text
trace time
compile time
run time
```

A staged system does not execute every part of the user program in the same way. Some parts run immediately in the host language. Other parts are captured into an intermediate representation. That captured representation may then be differentiated, optimized, compiled, cached, and executed later.

This is one of the main reasons modern AD systems can combine a friendly programming interface with compiler-grade performance.

### The Basic Idea

Consider a function:

```python
from math import sin

def f(x):
    y = x * x
    return sin(y)
```

An eager system evaluates the function directly every time.

A staged system may first capture the computation:

```text
x0 = input
x1 = mul x0 x0
x2 = sin x1
return x2
```

Then it may transform the captured program:

```text
x0, dx0 = input
x1 = mul x0 x0
dx1 = add (mul dx0 x0), (mul x0 dx0)
x2 = sin x1
dx2 = mul (cos x1), dx1
return x2, dx2
```

Finally, it may compile this transformed program into efficient executable code.

The user sees a function call. Internally, the system has separated the call into stages.
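
The capture step can be sketched with a toy tracer. This is a hypothetical design, not any real system's API; the staged `sin` primitive is passed in explicitly to keep the sketch self-contained.

```python
def trace(f):
    """Toy tracer: run f once on a symbolic value, recording each
    operation onto a tape instead of executing it."""
    tape = []

    class Tracer:
        def __init__(self, name):
            self.name = name

        def __mul__(self, other):
            name = f"x{len(tape) + 1}"
            tape.append(f"{name} = mul {self.name} {other.name}")
            return Tracer(name)

    def sin(t):
        # A staged primitive: emits an IR node rather than computing a value.
        name = f"x{len(tape) + 1}"
        tape.append(f"{name} = sin {t.name}")
        return Tracer(name)

    out = f(Tracer("x0"), sin)
    return ["x0 = input"] + tape + [f"return {out.name}"]

def f(x, sin):
    y = x * x
    return sin(y)

print("\n".join(trace(f)))
```

Running `f` once on the tracer produces exactly the captured program shown above: the Python interpreter executes the body, but every operation lands on the tape instead of computing a number.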

### Why Staging Matters for AD

Automatic differentiation benefits from staging because derivative computation is often more useful as a reusable program than as a one-time side effect.

For example:

```text
grad(f)
```

can be treated as a new function. That function can be compiled, optimized, cached, and called many times.

Without staging, the AD system may recompute transformation decisions on every call. With staging, it can produce a derivative program once and reuse it.

A typical staged AD pipeline is:

```text
user function
    -> specialize
    -> trace or lower to IR
    -> differentiate
    -> optimize
    -> compile
    -> execute
```

Each arrow separates one concern from the next.

### Static and Dynamic Parts

Partial evaluation is the process of executing the parts of a program that are known early, while leaving the unknown parts for later.

Suppose a function has one static argument and one dynamic argument:

```python
def power(x, n):
    y = 1
    for i in range(n):
        y = y * x
    return y
```

If `n` is known during tracing, the loop can be unrolled.

For `n = 3`, the staged program is:

```text
x0 = input
x1 = mul 1 x0
x2 = mul x1 x0
x3 = mul x2 x0
return x3
```

The argument `n` was consumed at trace time. The argument `x` remains a runtime input.

This is partial evaluation: run what can be run now, generate code for what must run later.
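
The unrolling can be sketched with a toy trace function. `trace_power` is a hypothetical stand-in for a real tracer running `power` with a concrete `n`.

```python
def trace_power(n):
    """Partial evaluation sketch: n is static, x is dynamic. The Python
    loop runs now, at trace time; only the muls reach the staged program."""
    tape = ["x0 = input"]
    y = "1"
    for _ in range(n):               # executes immediately: n is concrete
        name = f"x{len(tape)}"
        tape.append(f"{name} = mul {y} x0")
        y = name
    tape.append(f"return {y}")
    return tape

print("\n".join(trace_power(3)))
```

The loop counter never appears in the output; it was consumed while the staged program was being built.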

### Specialization

Staging usually specializes code to some properties of the input.

Common specialization parameters include:

| Property | Example |
|---|---|
| Dtype | `float32`, `float64`, `int32` |
| Shape | `[32, 128]`, `[B, 768]` |
| Rank | scalar, vector, matrix |
| Static arguments | loop counts, flags, algorithm choices |
| Device | CPU, GPU, TPU |
| Layout | row-major, column-major, blocked |

Specialization gives the compiler more information. It can choose kernels, allocate buffers, remove branches, and simplify derivative code.

The cost is recompilation. When a specialized property changes, the system may need a new staged program.

### Example: Static Branch

Consider:

```python
def f(x, normalize):
    if normalize:
        x = x / norm(x)
    return x * x
```

If `normalize` is static, the staged system creates one program for `normalize = True` and another for `normalize = False`.

For `True`:

```text
x1 = norm x0
x2 = div x0 x1
x3 = mul x2 x2
return x3
```

For `False`:

```text
x3 = mul x0 x0
return x3
```

The branch disappears from each staged program. The derivative transform only sees the selected version.

This gives efficient code, but it also means the static argument is part of the compilation cache key.

### Example: Dynamic Branch

If the branch depends on a runtime value, it cannot be resolved during staging.

```python
def f(x):
    if sum(x) > 0:
        return x * x
    else:
        return -x
```

If `sum(x) > 0` is not known until runtime, the staged program must contain an explicit control-flow operation:

```text
p = gt (reduce_sum x), 0
y = cond p then_graph else_graph x
return y
```

The AD system must differentiate the `cond` operation as a control-flow primitive.

Forward mode differentiates the executed branch. Reverse mode must send adjoints back through the executed branch.
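
A minimal sketch of forward mode through a `cond` primitive, assuming each branch supplies its own primal/tangent rule (all names here are illustrative, not a real system's API):

```python
def cond_jvp(pred, then_jvp, else_jvp, x, dx):
    """Forward mode on a staged cond: the runtime predicate selects a
    branch, and only that branch's primal and tangent computation runs."""
    branch = then_jvp if pred else else_jvp
    return branch(x, dx)

# JVP rules for the two branch graphs of the example above.
then_jvp = lambda x, dx: (x * x, 2.0 * x * dx)   # d(x*x) = 2x dx
else_jvp = lambda x, dx: (-x, -dx)               # d(-x)  = -dx

y, dy = cond_jvp(True, then_jvp, else_jvp, 3.0, 1.0)
```

The predicate itself carries no derivative; only the selected branch is differentiated.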

### Partial Evaluation and AD

Partial evaluation can happen before or after differentiation.

If partial evaluation happens before AD, the derivative transform sees a simplified primal program.

```text
f -> partially evaluate -> simplified f -> differentiate
```

If it happens after AD, the optimizer simplifies both primal and derivative code.

```text
f -> differentiate -> derivative program -> partially evaluate
```

In practice, systems often do both. Early partial evaluation removes unnecessary structure. Later partial evaluation cleans up derivative artifacts.

For example, if a tangent seed is known to be zero, forward-mode code can remove any tangent computation that depends only on that zero seed.
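
The zero-seed cleanup can be sketched as a tiny constant-folding pass over symbolic expressions. The representation and names are illustrative, not any real IR.

```python
ZERO = "0"   # symbolic known-zero tangent

def simplify_mul(a, b):
    """A product with a known-zero factor folds to zero at compile time."""
    return ZERO if a == ZERO or b == ZERO else ("mul", a, b)

def simplify_add(a, b):
    """A sum with a known-zero term folds to the other term."""
    if a == ZERO:
        return b
    if b == ZERO:
        return a
    return ("add", a, b)

# Tangent of y = 3.0 * x where the constant's tangent seed is zero:
# dy = 3.0 * dx + dconst * x collapses to dy = 3.0 * dx.
dy = simplify_add(simplify_mul("3.0", "dx"), simplify_mul(ZERO, "x"))
```

Applied over a whole derivative program, this kind of folding removes entire chains of computation that depend only on zero seeds.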

### Residuals

When partial evaluation cannot compute a value immediately, but a later stage will need it, the value becomes a residual.

For reverse-mode AD, residuals are especially important.

A forward pass may compute:

```text
y = sin(x)
```

The backward pass needs `x` or sometimes `y` to compute:

```text
bar_x += cos(x) * bar_y
```

If the backward pass is staged as a separate program, the needed primal value must be saved as a residual.

A staged reverse-mode transform often produces:

```text
forward:
    y = sin(x)
    return y, residuals(x)

backward:
    x = residuals.x
    bar_x = cos(x) * bar_y
    return bar_x
```

Residual management is a central compiler problem. Too many residuals cause high memory use. Too few residuals cause recomputation or incorrect gradients.

### Staged Reverse Mode

Reverse-mode staging often splits computation into two generated functions:

```text
primal function
pullback function
```

The primal function computes the output and packages residuals.

```text
(y, res) = f_primal(x)
```

The pullback function consumes output adjoints and residuals.

```text
bar_x = f_pullback(res, bar_y)
```

This design is clean because it separates forward execution from backward execution.

It also matches the mathematical idea of a pullback: given a cotangent at the output, produce a cotangent at the input.
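
In plain Python, the primal/pullback split for `sin` might look like this sketch (function names are illustrative):

```python
import math

def sin_primal(x):
    """Forward pass: compute the output and package x as a residual."""
    y = math.sin(x)
    res = (x,)                      # residuals the pullback will need
    return y, res

def sin_pullback(res, bar_y):
    """Backward pass: map an output cotangent to an input cotangent."""
    (x,) = res
    return math.cos(x) * bar_y

y, res = sin_primal(0.0)
bar_x = sin_pullback(res, 1.0)      # cos(0) * 1 = 1.0
```

Because the pullback is an ordinary function of residuals and adjoints, it can itself be staged, optimized, and compiled.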

### Compile-Time Constants

Partial evaluation can turn compile-time constants into simpler derivative code.

Consider:

```python
def f(x):
    return 3.0 * x
```

Forward mode gives:

```text
y = 3.0 * x
dy = 3.0 * dx + 0.0 * x
```

After partial evaluation and simplification:

```text
y = 3.0 * x
dy = 3.0 * dx
```

Reverse mode gives:

```text
bar_x += 3.0 * bar_y
```

The constant has no adjoint because it is not a differentiable input.

This seems minor, but in large programs, removing inactive constants, static configuration, and zero tangents makes derivative code much smaller.

### Staging and Performance

Staging improves performance by moving work out of the runtime path.

The system can perform:

```text
shape inference
type checking
activity analysis
AD transformation
memory planning
kernel selection
fusion
code generation
```

before repeated execution.

This is especially valuable for training loops, simulations, batched inference, and optimization problems where the same computation structure runs many times.

The first call may be expensive because it traces and compiles. Later calls can be fast because they reuse the staged executable.

### Staging and Cache Keys

A staged system usually caches compiled programs.

A cache key may include:

```text
function identity
static argument values
input dtypes
input shapes
device
compiler options
AD transform stack
```

The transform stack matters. These may produce different staged programs:

```text
jit(f)
grad(f)
jit(grad(f))
grad(jit(f))
vmap(grad(f))
```

A correct system must distinguish them.

A poor cache key causes either unnecessary recompilation or, worse, reuse of a program under invalid assumptions.
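
A compilation cache along these lines might be sketched as follows; the key fields are illustrative, and the string value stands in for a compiled executable.

```python
compile_cache = {}

def cache_key(fn, static_args, shapes, dtypes, device, transforms):
    """Hypothetical cache key: every property the program was specialized
    to must appear here, or a stale program may be reused."""
    return (fn.__name__, static_args, shapes, dtypes, device, transforms)

def compiled(fn, **props):
    key = cache_key(fn, **props)
    if key not in compile_cache:
        compile_cache[key] = f"<executable for {key}>"  # stand-in for compilation
    return compile_cache[key]

def f(x):
    return x * x

spec = dict(static_args=(), shapes=((32, 128),), dtypes=("float32",),
            device="cpu")
a = compiled(f, transforms=("jit",), **spec)
b = compiled(f, transforms=("jit", "grad"), **spec)  # different transform stack
```

The two calls differ only in the transform stack, yet they must produce distinct cache entries, because the staged programs are different.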

### Staging Boundaries

A staging boundary is the line between host-language execution and captured computation.

Inside the boundary, operations become IR nodes. Outside the boundary, operations run normally.

For example:

```python
@jit
def f(x):
    y = x * x
    return y
```

The body of `f` is staged. Code calling `f` may remain ordinary host-language code.

The boundary determines what the AD system can see. If an operation occurs outside the boundary, the compiler cannot optimize or differentiate through it unless it is represented as an explicit primitive.

### Host Language Effects

Staging changes the meaning of side effects.

Consider:

```python
def f(x):
    log.append(x)
    return x * x
```

If `f` is staged, the append may run once at trace time rather than on each runtime execution. This is usually not what the programmer expects.

For this reason, staged AD systems often restrict side effects inside staged functions or require effectful operations to be represented explicitly in the IR.

The same issue appears with printing, file I/O, mutation of global state, object identity, and exceptions.
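
The trace-time effect can be demonstrated with a deliberately naive staging decorator. This toy replays the first call's result, whereas a real system replays captured operations, but the timing of the host effect is the same.

```python
def jit_once(f):
    """Toy staging: execute the Python body once, on first call, and
    replay the captured result afterwards."""
    cache = {}
    def wrapper(x):
        if "prog" not in cache:
            cache["prog"] = f(x)   # "trace time": host effects fire only here
        return cache["prog"]
    return wrapper

log = []

@jit_once
def f(x):
    log.append(x)   # host-language side effect, invisible to the staged program
    return x * x

f(2); f(3); f(4)
# log holds a single entry: the append ran at trace time, not per call.
```

The append is not part of the captured program, so it does not repeat, which is why staged systems restrict or reify such effects.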

### Staging and Randomness

Randomness must not be hidden in the host language.

Bad staged design:

```python
def f(x):
    r = random()
    return x + r
```

The random value may be sampled once during tracing and baked into the compiled program.

A better design passes randomness explicitly:

```text
r, key2 = random_uniform(key1)
y = x + r
return y, key2
```

The random key becomes part of the staged computation. This makes randomness reproducible and visible to differentiation and compilation.
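
Explicit key threading can be sketched with a hypothetical counter-based generator; the hashing scheme here is purely illustrative, not a recommendation for a production PRNG.

```python
import hashlib
import struct

def random_uniform(key):
    """Hypothetical splittable PRNG: derive a uniform sample and a fresh
    key deterministically from the input key, so randomness is explicit
    data flow rather than hidden host-language state."""
    digest = hashlib.sha256(struct.pack("<Q", key)).digest()
    r = struct.unpack("<Q", digest[:8])[0] / 2.0**64   # uniform in [0, 1)
    new_key = struct.unpack("<Q", digest[8:16])[0]
    return r, new_key

def f(x, key):
    r, key2 = random_uniform(key)
    return x + r, key2

y1, k1 = f(0.0, 42)
y2, k2 = f(0.0, 42)   # same key, same sample: fully reproducible
```

Because the key flows through the program as ordinary data, the staged IR sees the sampling operation instead of a baked-in constant.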

### Shape-Polymorphic Staging

Basic staging specializes to exact shapes. Shape-polymorphic staging allows some dimensions to remain symbolic.

Instead of compiling for:

```text
tensor<f32, [32, 768]>
```

the system compiles for:

```text
tensor<f32, [B, 768]>
```

This avoids recompilation for different batch sizes.

Shape polymorphism is useful but hard. The compiler must generate code that handles symbolic dimensions, dynamic memory sizes, and shape-dependent operations. AD rules must also preserve symbolic shapes correctly.
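
A shape rule that tolerates symbolic dimensions can be sketched like this, using a plain string for the symbolic batch dimension; real systems use dedicated dimension-variable objects.

```python
def matmul_shape(a_shape, b_shape):
    """(m, k) @ (k, n) -> (m, n); any dimension may be concrete or symbolic."""
    m, k1 = a_shape
    k2, n = b_shape
    if k1 != k2:
        raise ValueError("contracting dimensions must match")
    return (m, n)

B = "B"                                   # symbolic batch dimension
out = matmul_shape((B, 768), (768, 512))  # batch stays symbolic
```

The rule checks what it can (the contracted dimensions) and propagates the symbol it cannot resolve, which is exactly what a shape-polymorphic AD rule must do.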

### Multi-Stage AD

AD transformations can be nested with staging transformations.

Examples:

```text
jit(grad(f))
grad(jit(f))
vmap(jit(grad(f)))
jit(vmap(grad(f)))
```

These are not always equivalent operationally.

`jit(grad(f))` differentiates first, then stages the gradient function.

`grad(jit(f))` stages the primal function, then differentiates through the staged call if the system supports it.

Mathematically, both may compute the same derivative. Operationally, they can differ in compilation boundaries, residual storage, cache behavior, and optimization opportunities.

A well-designed AD system defines transformation composition precisely.

### Staging and Custom Primitives

Custom primitives need staging rules.

A primitive may provide:

```text
abstract evaluation rule
lowering rule
JVP rule
VJP rule
batching rule
effect rule
```

The abstract evaluation rule tells the staging system the output type and shape. The lowering rule tells the compiler how to implement the primitive. The JVP and VJP rules tell AD how to differentiate it.

Without these rules, the primitive becomes a black box. The system may be able to execute it eagerly but unable to stage or differentiate it.
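
A primitive registry along these lines might be sketched as follows; the structure and names are hypothetical.

```python
import math

# Per-subsystem rule tables: a primitive is usable by a subsystem only if
# the corresponding rule has been registered for it.
rules = {"abstract_eval": {}, "jvp": {}}

def register(kind, name, rule):
    rules[kind][name] = rule

# Rules for a "sin" primitive.
register("abstract_eval", "sin", lambda shape, dtype: (shape, dtype))
register("jvp", "sin", lambda x, dx: (math.sin(x), math.cos(x) * dx))

def abstract_eval(name, shape, dtype):
    """Staging queries this for output types; unknown primitives cannot stage."""
    if name not in rules["abstract_eval"]:
        raise NotImplementedError(f"cannot stage unknown primitive {name!r}")
    return rules["abstract_eval"][name](shape, dtype)

out_shape, out_dtype = abstract_eval("sin", (32, 128), "float32")
y, dy = rules["jvp"]["sin"](0.0, 1.0)   # primal 0.0, tangent cos(0) * 1
```

A primitive missing from a table is exactly the black box the text describes: the system can neither stage it nor differentiate it, even if it can run it eagerly.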

### Failure Modes

Staging introduces predictable failure modes.

A common failure is trying to use a staged value where the host language requires a concrete value.

```python
def f(x):
    return [0] * int(x)
```

If `x` is dynamic, `int(x)` cannot be known at trace time.

Another failure is accidental static capture.

```python
scale = 2

@jit
def f(x):
    return scale * x
```

If `scale` changes later, the compiled function may still use the old value unless the system tracks it as part of the cache key or treats it as an explicit input.

A third failure is excessive recompilation due to changing shapes or static arguments.

These failures are not bugs in the idea of staging. They are consequences of phase separation.

### Design Rules

A staged AD system should make the following explicit:

```text
Which values are static?
Which values are dynamic?
Which operations are captured?
Which effects are allowed?
What determines the cache key?
Which control-flow constructs are staged?
How are residuals represented?
How do AD, batching, and compilation compose?
```

Users can reason about staged systems when these rules are stable and visible.

### Summary

Staging separates program construction from program execution. Partial evaluation executes known parts early and leaves unknown parts as generated code.

For automatic differentiation, this separation is powerful. It lets the system capture numerical programs, transform them into derivative programs, optimize them, compile them, and reuse them.

The price is a stricter execution model. Some values exist at trace time, others at run time. Some arguments are static, others dynamic. Side effects and randomness must be explicit. Control flow depending on dynamic values must be represented in the staged IR.

Staging makes AD less like a runtime trick and more like a compiler pipeline.

