
Differentiating Stateful Systems


A stateful system is a program whose output depends not only on its explicit inputs, but also on stored state. The state may live in variables, objects, arrays, files, random number generators, caches, mutable buffers, optimizers, simulators, databases, or external services.

In pure mathematical notation, a function has the form:

$$y = f(x)$$

A stateful computation has the form:

$$(y, s') = F(x, s)$$

where $s$ is the old state and $s'$ is the new state.

This distinction matters for automatic differentiation. AD differentiates computations. If the computation reads or writes state, the derivative must account for how state participates in the execution.

Explicit State

The cleanest way to differentiate a stateful system is to make the state explicit.

Instead of:

counter = 0

def f(x):
    global counter
    counter += 1
    return x * counter

write:

def f(x, counter):
    counter2 = counter + 1
    y = x * counter2
    return y, counter2

Now the computation is an ordinary function:

$$(y, c') = F(x, c)$$

and AD can differentiate it with respect to $x$, $c$, or both.

If $c$ is treated as an integer counter, no gradient flows through it. If $c$ is a floating-point state variable, it can have a derivative; here $\partial y/\partial x = c + 1$ and $\partial y/\partial c = x$.
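A minimal sketch, assuming hand-written forward-mode dual numbers rather than any particular AD library: the hypothetical Dual class carries a value and one tangent, and seeding one tangent at a time recovers $\partial y/\partial x = c + 1$ and $\partial y/\partial c = x$ for the state-passing f.

```python
class Dual:
    """Hypothetical forward-mode dual number: a value plus one tangent."""

    def __init__(self, val, tan=0.0):
        self.val, self.tan = val, tan

    @staticmethod
    def lift(o):
        # Treat plain numbers as constants (zero tangent).
        return o if isinstance(o, Dual) else Dual(o)

    def __add__(self, o):
        o = Dual.lift(o)
        return Dual(self.val + o.val, self.tan + o.tan)

    __radd__ = __add__

    def __mul__(self, o):
        o = Dual.lift(o)
        # Product rule: (a * b)' = a' * b + a * b'
        return Dual(self.val * o.val, self.tan * o.val + self.val * o.tan)

    __rmul__ = __mul__


def f(x, counter):
    # State-passing version from the text: the new counter is returned, not stored.
    counter2 = counter + 1
    y = x * counter2
    return y, counter2


# Seed x's tangent: dy/dx = counter + 1.
y, _ = f(Dual(5.0, 1.0), Dual(2.0, 0.0))
print(y.val, y.tan)  # 15.0 3.0

# Seed counter's tangent (meaningful only for a floating-point state): dy/dc = x.
y, _ = f(Dual(5.0, 0.0), Dual(2.0, 1.0))
print(y.val, y.tan)  # 15.0 5.0
```

An integer counter would simply never be seeded with a tangent, which is the "no gradient flows through it" case.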

The key rule is simple: differentiable state must appear in the computational graph.

Hidden State

Hidden state creates ambiguity.

class Accumulator:
    def __init__(self):
        self.total = 0

    def add(self, x):
        self.total += x
        return self.total

The output of add(x) depends on all previous calls. The method call is not a pure function of x.

A sequence of calls:

acc = Accumulator()
a = acc.add(x1)
b = acc.add(x2)
c = acc.add(x3)

produces:

$$a = x_1, \qquad b = x_1 + x_2, \qquad c = x_1 + x_2 + x_3$$

The derivative of $c$ with respect to each input is:

$$\frac{\partial c}{\partial x_1} = \frac{\partial c}{\partial x_2} = \frac{\partial c}{\partial x_3} = 1.$$

An AD system can compute this only if the state updates are visible to the differentiated program. If state changes occur outside the trace, the gradient may be incomplete or wrong.
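A small sketch of both cases, again assuming a hand-rolled forward-mode Dual type (hypothetical, not a library API). The state-passing add threads the running total through the computation, so every input's tangent reaches $c$. The hypothetical LeakyAccumulator stores only the numeric value of its total, mimicking a state update outside the trace, and loses the contributions of $x_1$ and $x_2$.

```python
class Dual:
    """Hypothetical forward-mode dual number (value, tangent); addition only."""

    def __init__(self, val, tan=0.0):
        self.val, self.tan = val, tan

    def __add__(self, o):
        o = o if isinstance(o, Dual) else Dual(o)
        return Dual(self.val + o.val, self.tan + o.tan)

    __radd__ = __add__


def add(total, x):
    # Visible state: returns (output, new_state).
    total2 = total + x
    return total2, total2


def run_visible(xs):
    total = 0.0
    for x in xs:
        out, total = add(total, x)
    return out  # c after the last call


class LeakyAccumulator:
    # Hidden state kept as a plain float: the update happens "outside the trace".
    def __init__(self):
        self.total = 0.0

    def add(self, x):
        self.total = self.total + x.val  # tangents of earlier inputs are discarded
        return Dual(self.total, x.tan)


def run_leaky(xs):
    acc = LeakyAccumulator()
    for x in xs:
        out = acc.add(x)
    return out


def seeded(xs, k):
    # Tangent 1 on input k, 0 elsewhere: the result's tangent is dc/dx_k.
    return [Dual(v, 1.0 if j == k else 0.0) for j, v in enumerate(xs)]


xs = [1.0, 2.0, 3.0]
print([run_visible(seeded(xs, k)).tan for k in range(3)])  # [1.0, 1.0, 1.0]
print([run_leaky(seeded(xs, k)).tan for k in range(3)])    # [0.0, 0.0, 1.0]
```

Both versions return the same value, 6.0, which is why incomplete gradients from hidden state are easy to miss.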

Mutation and Versioning

Mutation overwrites values.

x = x * x
x = x + 1

In source code, the name x is reused. Mathematically, these are different values:

$$x_1 = x_0^2, \qquad x_2 = x_1 + 1$$

Reverse mode needs this distinction. The backward rule for $x_1 = x_0^2$ is $\bar{x}_0 = 2x_0\,\bar{x}_1$, so it needs $x_0$ even though the name x has since been reused.

A compiler typically converts mutable code into static single assignment (SSA) form:

x0 = input
x1 = x0 * x0
x2 = x1 + 1

SSA makes data dependencies explicit. It turns stateful-looking local mutation into a graph of immutable values.
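A toy reverse-mode tape over the SSA program illustrates the point. This is a sketch, not any real system's IR: each SSA name keeps its own value in env, so the backward rule for x1 can still read x0 after later assignments. Had the forward pass stored everything under a single mutable name x, the square rule would have read 10.0 instead of 3.0 and returned 20.0 instead of the correct 6.0.

```python
def forward(x):
    # One entry per SSA value; nothing is overwritten.
    env = {"x0": x}
    tape = []
    env["x1"] = env["x0"] * env["x0"]
    tape.append(("x1", "square", "x0"))
    env["x2"] = env["x1"] + 1
    tape.append(("x2", "add_const", "x1"))
    return env, tape


def backward(env, tape, out):
    # Walk the tape in reverse, accumulating adjoints per SSA name.
    adj = {out: 1.0}
    for name, op, inp in reversed(tape):
        g = adj.get(name, 0.0)
        if op == "square":
            # d(x0^2)/dx0 = 2 * x0: uses the saved forward value of x0
            adj[inp] = adj.get(inp, 0.0) + 2.0 * env[inp] * g
        elif op == "add_const":
            adj[inp] = adj.get(inp, 0.0) + g
    return adj


env, tape = forward(3.0)
print(env["x2"], backward(env, tape, "x2")["x0"])  # 10.0 6.0
```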

Array mutation is harder because only part of a value may change:

a[i] = a[i] + x

Reverse mode may need the old value of a[i], the index i, and the aliasing relationship between a and other views.
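For the additive update shown, the backward rule needs only the index: the adjoint of x is the incoming adjoint at position i, and every entry of a passes its adjoint through unchanged. The sketch below uses a multiplicative variant, a[i] = a[i] * x, because its backward rule also needs the overwritten value. The functions mul_at and mul_at_backward are hypothetical helpers; the forward returns a new list instead of mutating and saves exactly the residuals the backward pass reads.

```python
def mul_at(a, i, x):
    """Functional version of a[i] = a[i] * x: returns a new list, leaves a intact."""
    a2 = list(a)
    a2[i] = a[i] * x
    saved = (i, a[i], x)  # residuals: the index, the OLD a[i], and x
    return a2, saved


def mul_at_backward(saved, grad_a2):
    i, old_ai, x = saved
    grad_a = list(grad_a2)        # untouched entries pass their adjoint through
    grad_a[i] = grad_a2[i] * x    # d(a[i] * x) / d a[i] = x
    grad_x = grad_a2[i] * old_ai  # d(a[i] * x) / d x = old a[i]
    return grad_a, grad_x


a = [1.0, 2.0, 3.0]
a2, saved = mul_at(a, 1, 5.0)
print(a2)                                       # [1.0, 10.0, 3.0]
print(mul_at_backward(saved, [1.0, 1.0, 1.0]))  # ([1.0, 5.0, 1.0], 2.0)
```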

Aliasing

Aliasing occurs when two names refer to the same storage.

a = zeros(10)
b = a
b[0] = 1

Now a[0] has changed too.

Aliasing complicates AD because the dataflow graph cannot be read from variable names alone. The system must reason about storage identity.

Consider:

y = a[0] * a[0]
b = a
b[0] = 3

The backward pass for y needs the value of a[0] before mutation. If aliasing allows b[0] to overwrite it, the saved value must be protected.
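A plain-Python sketch of this hazard, using lists as stand-ins for tensors. Saving a reference to a for the backward pass is unsafe once an alias writes through b; saving a copy, which is the idea behind copy-on-write, preserves the forward value. The names saved_ref and saved_copy are illustrative only.

```python
a = [2.0, 0.0]
y = a[0] * a[0]        # forward: y = 4.0; backward needs dy/da[0] = 2 * a[0]

saved_ref = a          # saves a reference to the same storage
saved_copy = list(a)   # saves a snapshot of the forward value

b = a                  # alias: b and a name the same list
b[0] = 3.0             # mutation through the alias also changes a[0]

grad_from_ref = 2.0 * saved_ref[0]    # wrong: uses the overwritten value
grad_from_copy = 2.0 * saved_copy[0]  # correct: uses the forward value 2.0
print(y, grad_from_ref, grad_from_copy)  # 4.0 6.0 4.0
```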

Common strategies include:

| Strategy | How it works |
|---|---|
| Prohibit unsafe mutation | Reject programs that mutate saved values |
| Version counters | Detect mutation after saving |
| Copy-on-write | Preserve old storage when needed |
| Functional updates | Rewrite updates as new values |
| Alias analysis | Prove mutation cannot affect saved values |

The correctness invariant is that every backward rule must see the forward values it depends on.
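A minimal sketch of the version-counter strategy, under simplifying assumptions: Buffer and SavedValue are hypothetical classes, a Buffer increments its counter on every write, and a SavedValue remembers the counter at save time and refuses to unpack if it has changed. Because an alias shares the same Buffer object, a write through the alias is detected too.

```python
class Buffer:
    """Hypothetical mutable storage with a version counter."""

    def __init__(self, data):
        self.data = list(data)
        self.version = 0

    def __getitem__(self, i):
        return self.data[i]

    def __setitem__(self, i, v):
        self.data[i] = v
        self.version += 1  # every in-place write bumps the version


class SavedValue:
    """Records the buffer and its version when saved for backward."""

    def __init__(self, buf):
        self.buf = buf
        self.saved_version = buf.version

    def unpack(self):
        if self.buf.version != self.saved_version:
            raise RuntimeError("value saved for backward was modified in place")
        return self.buf


def square_first(a):
    return a[0] * a[0], SavedValue(a)


def square_first_backward(saved, grad_y):
    a = saved.unpack()
    return 2.0 * a[0] * grad_y


a = Buffer([2.0, 0.0])
y, saved = square_first(a)
print(square_first_backward(saved, 1.0))  # 4.0

y, saved = square_first(a)
b = a          # alias of the same Buffer
b[0] = 3.0     # bumps the shared version counter
try:
    square_first_backward(saved, 1.0)
except RuntimeError as err:
    print("detected:", err)
```

The check turns a silently wrong gradient into a loud error; it does not recover the old value, which is what copy-on-write or functional updates provide.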

Stateful Layers

Machine learning models often contain stateful layers. Batch normalization is a common example.

During training, a batch normalization layer computes batch statistics and updates running estimates:

mean = batch_mean(x)
var  = batch_var(x)

running_mean = momentum * running_mean + (1 - momentum) * mean
running_var  = momentum * running_var  + (1 - momentum) * var

y = scale * normalize(x, mean, var) + bias

There are two different computations here.

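The output y is differentiable with respect to x, scale, and bias. The gradient with respect to x also flows through mean and var, since both are computed from the batch.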

The running statistics are state updates used for later inference. They are usually not optimized by gradient descent.

So an AD system must distinguish:

differentiable parameters
differentiable activations
non-differentiated state updates

This distinction is often represented by marking some values as trainable parameters, some as buffers, and some as temporary activations.
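One way to make that separation explicit is to pass parameters and buffers as distinct inputs and return the updated buffers as a separate output. The sketch below is a hypothetical pure-Python training-mode batch norm over a batch of scalars; it assumes the biased batch variance for both normalization and the running update, and the momentum convention shown above. Only y would be differentiated; new_buffers is state for later inference.

```python
import math


def batchnorm_train(x, params, buffers, momentum=0.9, eps=1e-5):
    """Hypothetical training-mode batch norm on a 1-D batch of scalars.

    params  (trainable): {"scale", "bias"}                 -> receive gradients
    buffers (state):     {"running_mean", "running_var"}   -> updated, not optimized
    """
    n = len(x)
    mean = sum(x) / n
    var = sum((xi - mean) ** 2 for xi in x) / n  # biased batch variance
    x_hat = [(xi - mean) / math.sqrt(var + eps) for xi in x]
    y = [params["scale"] * h + params["bias"] for h in x_hat]

    # State update: a separate output, outside the differentiated objective.
    new_buffers = {
        "running_mean": momentum * buffers["running_mean"] + (1 - momentum) * mean,
        "running_var": momentum * buffers["running_var"] + (1 - momentum) * var,
    }
    return y, new_buffers


y, new_buffers = batchnorm_train(
    [1.0, 2.0, 3.0, 4.0],
    params={"scale": 1.0, "bias": 0.0},
    buffers={"running_mean": 0.0, "running_var": 1.0},
)
print(round(new_buffers["running_mean"], 6), round(new_buffers["running_var"], 6))  # 0.25 1.025
```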

Optimizer State

Optimizers also carry state.

For momentum SGD:

v = beta * v + grad
theta = theta - lr * v

The state $v$ affects future parameters.

In ordinary training, gradients are computed through the model loss with respect to $\theta$, but not through the optimizer update itself. The update is performed outside the differentiated objective.

In meta-learning, however, one may differentiate through optimization:

$$\theta_{t+1} = \theta_t - \eta \nabla_\theta L(\theta_t)$$

Now the optimizer state and update rule are part of the computation graph. This requires higher-order differentiation and careful treatment of state across training steps.

The same code therefore has different derivative meaning depending on whether optimizer state is inside or outside the differentiated region.
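A minimal sketch of differentiating through optimization, assuming a toy objective $L(\theta) = \theta^2/2$ (so $\nabla_\theta L = \theta$) and forward-mode dual numbers in place of a real AD system. The optimizer state (theta, v) is threaded explicitly through the hypothetical sgd_momentum_step, and the learning rate carries a tangent, so the result is $\partial \theta_T / \partial \eta$. A central finite difference serves as a check. Because this toy gradient is linear in $\theta$, no second derivatives of $L$ appear; for a general loss the gradient computation itself would have to be differentiated.

```python
class Dual:
    """Hypothetical forward-mode dual number supporting +, -, *."""

    def __init__(self, val, tan=0.0):
        self.val, self.tan = val, tan

    @staticmethod
    def lift(o):
        return o if isinstance(o, Dual) else Dual(o)

    def __add__(self, o):
        o = Dual.lift(o)
        return Dual(self.val + o.val, self.tan + o.tan)

    __radd__ = __add__

    def __sub__(self, o):
        o = Dual.lift(o)
        return Dual(self.val - o.val, self.tan - o.tan)

    def __rsub__(self, o):
        return Dual.lift(o) - self

    def __mul__(self, o):
        o = Dual.lift(o)
        return Dual(self.val * o.val, self.tan * o.val + self.val * o.tan)

    __rmul__ = __mul__


def grad_L(theta):
    # Toy loss L(theta) = theta**2 / 2, so its gradient is theta.
    return theta


def sgd_momentum_step(theta, v, lr, beta):
    # The update from the text, written in state-passing form.
    v = beta * v + grad_L(theta)
    theta = theta - lr * v
    return theta, v


def unroll(theta0, lr, beta=0.9, steps=3):
    theta, v = theta0, 0.0
    for _ in range(steps):
        theta, v = sgd_momentum_step(theta, v, lr, beta)
    return theta


meta_grad = unroll(1.0, Dual(0.1, 1.0)).tan  # d theta_3 / d lr via forward mode

h = 1e-6
fd = (unroll(1.0, 0.1 + h) - unroll(1.0, 0.1 - h)) / (2 * h)
print(meta_grad, fd)  # both approximately -4.68
```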

Random State

Random number generators are stateful.

eps = randn()
y = x + eps

A random draw updates the generator state. The sample itself is usually treated as non-differentiable with respect to the generator state.

For reparameterized randomness, the sample is represented as:

$$z = \mu + \sigma \epsilon$$

where $\epsilon$ is sampled independently and treated as fixed during differentiation.

Then:

$$\frac{\partial z}{\partial \mu} = 1, \qquad \frac{\partial z}{\partial \sigma} = \epsilon$$

A clean AD design passes random keys or generator state explicitly:

eps, key2 = normal(key)
z = mu + sigma * eps
return z, key2

This makes randomness reproducible and makes state transitions visible.
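A minimal sketch of key passing using only Python's standard random module. The helper normal(key) is hypothetical: it seeds a fresh random.Random from the key, draws $\epsilon$, and derives a next key from the same generator. It illustrates explicit state threading and is not a statistically vetted key-splitting scheme. With the key fixed, $z$ is reproducible and its reparameterized derivatives are $\partial z/\partial \mu = 1$ and $\partial z/\partial \sigma = \epsilon$.

```python
import random


def normal(key):
    """Hypothetical key-passing sampler: the same key always gives the same draw."""
    rng = random.Random(key)        # fresh generator seeded only by the key
    eps = rng.gauss(0.0, 1.0)
    next_key = rng.getrandbits(64)  # new key for the next draw
    return eps, next_key


def sample(mu, sigma, key):
    eps, key2 = normal(key)
    z = mu + sigma * eps
    # Reparameterization: eps is a fixed constant when differentiating.
    grads = {"mu": 1.0, "sigma": eps}
    return z, grads, key2


z1, g1, key2 = sample(0.5, 2.0, key=42)
z2, g2, _ = sample(0.5, 2.0, key=42)    # same key -> same sample
z3, g3, _ = sample(0.5, 2.0, key=key2)  # threaded key -> the next draw
print(z1 == z2)  # True
```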

Stateful Simulators

Physical simulators are stateful by nature. A simulation step updates position, velocity, constraints, and sometimes internal solver state.

state = step(state, control)

Mathematically:

$$s_{t+1} = F(s_t, u_t)$$

A rollout is:

$$s_T = F(F(\cdots F(s_0, u_0), u_1), \ldots, u_{T-1})$$

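Differentiating a rollout follows the same rules as differentiating a loop: the state-transition Jacobians are multiplied through time,

$$\frac{\partial s_T}{\partial s_0} = \frac{\partial F}{\partial s}(s_{T-1}, u_{T-1}) \cdots \frac{\partial F}{\partial s}(s_1, u_1)\,\frac{\partial F}{\partial s}(s_0, u_0).$$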

Reverse mode requires either storing states or recomputing them.
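A sketch of the store-states strategy on a hypothetical scalar system $F(s, u) = s + \Delta t\,(u - s^3)$, chosen only because its Jacobian $\partial F/\partial s = 1 - 3\Delta t\, s^2$ depends on the state. The forward rollout keeps every $s_t$; the reverse sweep multiplies the transition Jacobians backward through time to obtain $\partial s_T/\partial s_0$ and the sensitivity to each control $u_t$. Finite differences provide a check.

```python
def step(s, u, dt=0.1):
    # Hypothetical dynamics: s_{t+1} = s_t + dt * (u_t - s_t**3)
    return s + dt * (u - s ** 3)


def rollout(s0, us, dt=0.1):
    states = [s0]  # every state is stored for the backward pass
    for u in us:
        states.append(step(states[-1], u, dt))
    return states


def rollout_grad(s0, us, dt=0.1):
    """Reverse sweep: returns s_T, ds_T/ds_0, and [ds_T/du_t for each t]."""
    states = rollout(s0, us, dt)
    lam = 1.0  # adjoint ds_T / ds_{t+1}, starting with t + 1 = T
    grad_u = [0.0] * len(us)
    for t in reversed(range(len(us))):
        grad_u[t] = lam * dt                           # dF/du = dt
        lam = lam * (1.0 - 3.0 * dt * states[t] ** 2)  # dF/ds at the stored s_t
    return states[-1], lam, grad_u


s0, us = 0.5, [1.0, -0.5, 0.2, 0.0]
sT, ds0, dus = rollout_grad(s0, us)

h = 1e-6
fd_s0 = (rollout(s0 + h, us)[-1] - rollout(s0 - h, us)[-1]) / (2 * h)
print(ds0, fd_s0)  # should agree closely
```

Recomputation would instead rerun parts of rollout during the reverse sweep, trading time for the memory used by states.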

Stateful simulators add complications:

contact caches
adaptive step sizes
solver warm starts
constraint activation
collision histories

Some of these are differentiable state. Some are discrete execution aids. The derivative is meaningful only after deciding which parts of the simulator state belong to the mathematical model.

External State

Programs may read from files, databases, clocks, network services, or global registries.

scale = read_config("scale")
y = scale * x

If scale is external and not represented as an AD input, gradients with respect to it cannot be computed.

External state also harms reproducibility. The same input may produce different outputs if the external state changes.

For differentiable programming, external state should be handled in one of three ways:

make it an explicit input
treat it as a constant
exclude it from differentiated regions
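A small sketch of the first two options. CONFIG is a hypothetical stand-in for read_config, and a central finite difference stands in for AD: with scale as an explicit argument, the derivative of y with respect to scale (equal to x) is computable; in the closed-over version, scale is a constant and no such derivative is requested. The third option corresponds to calling read_config outside whatever function is differentiated.

```python
CONFIG = {"scale": 2.5}  # hypothetical stand-in for read_config("scale")


def f_explicit(x, scale):
    # Option 1: the external value is passed in as an ordinary input.
    return scale * x


SCALE = CONFIG["scale"]  # read once, outside the differentiated function


def f_const(x):
    # Option 2: the value is a constant captured by the function.
    return SCALE * x


x, h = 4.0, 1e-6
d_scale = (f_explicit(x, SCALE + h) - f_explicit(x, SCALE - h)) / (2 * h)
print(f_const(x), round(d_scale, 6))  # 10.0 4.0
```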

Silent dependence on external mutable state is a common source of incorrect gradients.

Caches and Memoization

Caches improve performance but can obscure data dependencies.

cache = {}

def cached_expensive(x, key):
    if key in cache:
        return cache[key]  # stale if key does not capture every input
    y = expensive(x)
    cache[key] = y
    return y

If the cache key does not fully capture the differentiable inputs, the program may return a stale value. The gradient then corresponds to the cached computation, not the intended computation.

In AD systems, cached values should either:

store derivative history safely
detach intentionally
be recomputed inside the traced region
be keyed by all relevant inputs

A cache that stores graph-attached values may also retain large computation graphs and cause memory growth.
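A sketch of the stale-value failure and the fix, with a hypothetical expensive(x, w) where w plays the role of a differentiable parameter. The first cache keys on x alone, so after w changes it returns a value computed with the old w, and any gradient taken through it would belong to the old computation. The second keys on every input that affects the result.

```python
def expensive(x, w):
    # Hypothetical costly function of input x and parameter w.
    return w * x * x


bad_cache = {}


def cached_bad(x, w):
    key = x  # omits w, so results can go stale
    if key not in bad_cache:
        bad_cache[key] = expensive(x, w)
    return bad_cache[key]


good_cache = {}


def cached_good(x, w):
    key = (x, w)  # keyed by all relevant inputs
    if key not in good_cache:
        good_cache[key] = expensive(x, w)
    return good_cache[key]


print(cached_bad(3.0, 1.0), cached_bad(3.0, 2.0))    # 9.0 9.0 (second is stale)
print(cached_good(3.0, 1.0), cached_good(3.0, 2.0))  # 9.0 18.0
```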

State and Purity

Pure functions are easiest to differentiate because dependencies are explicit.

A pure function has no hidden reads or writes:

y = f(x)

A state-passing version remains pure:

y, state2 = f(x, state1)

This style is common in differentiable programming systems that target compilation, parallelism, and transformation. Explicit state makes it possible to apply AD, vectorization, checkpointing, and staging without guessing about hidden effects.

The cost is verbosity. The benefit is correctness.

Correctness Rule

For a stateful computation:

(y, s_next) = F(x, s)

AD can compute correct derivatives only for dependencies represented inside the differentiated computation.

Visible differentiable state:
    participates in the graph.

Visible non-differentiable state:
    controls execution but receives no gradient.

Hidden mutable state:
    risks incomplete or incorrect gradients.

External state:
    must be explicit, constant, or excluded.

State does not prevent automatic differentiation. It requires a precise boundary around what counts as part of the function being differentiated.