# Differentiating Stateful Systems

A stateful system is a program whose output depends not only on its explicit inputs, but also on stored state. The state may live in variables, objects, arrays, files, random number generators, caches, mutable buffers, optimizers, simulators, databases, or external services.

In mathematical notation, a pure function has the form:

$$
y = f(x)
$$

A stateful computation has the form:

$$
(y, s') = F(x, s)
$$

where $s$ is the old state and $s'$ is the new state.

This distinction matters for automatic differentiation. AD differentiates computations. If the computation reads or writes state, the derivative must account for how state participates in the execution.

## Explicit State

The cleanest way to differentiate a stateful system is to make the state explicit.

Instead of:

```text
counter = 0

def f(x):
    global counter
    counter += 1
    return x * counter
```

write:

```text
def f(x, counter):
    counter2 = counter + 1
    y = x * counter2
    return y, counter2
```

Now the computation is an ordinary function:

$$
(y, c') = F(x, c)
$$

and AD can differentiate it with respect to $x$, $c$, or both.

If $c$ is treated as an integer counter, then gradients do not flow through it. If $c$ is a floating-point state variable, then it may have a derivative.

The key rule is simple: differentiable state must appear in the computational graph.
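As a minimal sketch, assuming hand-rolled forward-mode AD with dual numbers (the `Dual` class below is illustrative, not any library's API), the explicit-state `f` can be differentiated with respect to `x` or `counter` by seeding that input with derivative 1. Passing the counter as a plain integer shows the other case: no tangent is attached to it, so no gradient flows through it.

```python
from dataclasses import dataclass


@dataclass
class Dual:
    """Forward-mode dual number: a value paired with its derivative (tangent)."""
    val: float
    der: float = 0.0

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + other.val, self.der + other.der)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # Product rule: (u * v)' = u' * v + u * v'
        return Dual(self.val * other.val, self.der * other.val + self.val * other.der)

    __rmul__ = __mul__


def f(x, counter):
    """Explicit-state version: returns (output, new_state)."""
    counter2 = counter + 1
    y = x * counter2
    return y, counter2


x, c = 3.0, 4.0
# Seed x: dy/dx = counter + 1
dy_dx = f(Dual(x, 1.0), Dual(c, 0.0))[0].der
# Seed the floating-point state: dy/dc = x
dy_dc = f(Dual(x, 0.0), Dual(c, 1.0))[0].der
# Integer counter passed as a plain int: it carries no tangent
y_int, c_next = f(Dual(x, 1.0), 4)
```

Here $y = x(c+1)$, so the seeded passes recover $\partial y/\partial x = c+1$ and $\partial y/\partial c = x$.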

## Hidden State

Hidden state creates ambiguity.

```text
class Accumulator:
    def __init__(self):
        self.total = 0

    def add(self, x):
        self.total += x
        return self.total
```

The output of `add(x)` depends on all previous calls. The method call is not a pure function of `x`.

A sequence of calls:

```text
a = acc.add(x1)
b = acc.add(x2)
c = acc.add(x3)
```

produces:

$$
a = x_1
$$

$$
b = x_1 + x_2
$$

$$
c = x_1 + x_2 + x_3
$$

The derivative of $c$ with respect to each input is:

$$
\frac{\partial c}{\partial x_1} =
\frac{\partial c}{\partial x_2} =
\frac{\partial c}{\partial x_3} =
1.
$$

An AD system can compute this only if the state updates are visible to the differentiated program. If state changes occur outside the trace, the gradient may be incomplete or wrong.
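Rewriting the accumulator in state-passing style makes every update visible. A minimal sketch, again assuming an illustrative forward-mode `Dual` type (only addition is needed here), seeds one input at a time and recovers $\partial c/\partial x_i = 1$ for each $i$:

```python
from dataclasses import dataclass


@dataclass
class Dual:
    """Forward-mode dual number (addition only, which is all the accumulator needs)."""
    val: float
    der: float = 0.0

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + other.val, self.der + other.der)

    __radd__ = __add__


def add(total, x):
    """State-passing accumulator step: (output, new_state) = F(x, state)."""
    new_total = total + x
    return new_total, new_total


def grads_of_last_output(xs):
    """d(last output)/d(x_i) for each i, one forward pass per seeded input."""
    grads = []
    for i in range(len(xs)):
        total = Dual(0.0)  # initial state, carries no derivative
        for j, xj in enumerate(xs):
            out, total = add(total, Dual(xj, 1.0 if j == i else 0.0))
        grads.append(out.der)
    return grads


grads = grads_of_last_output([0.5, 2.0, -1.0])
```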

## Mutation and Versioning

Mutation overwrites values.

```text
x = x * x
x = x + 1
```

In source code, the name `x` is reused. Mathematically, these are different values:

$$
x_1 = x_0^2
$$

$$
x_2 = x_1 + 1
$$

Reverse mode needs this distinction. The backward rule for $x_1=x_0^2$ needs $x_0$, even though the variable name has later been reused.

A compiler typically converts mutable code into static single assignment (SSA) form, in which every value is assigned exactly once:

```text
x0 = input
x1 = x0 * x0
x2 = x1 + 1
```

SSA makes data dependencies explicit. It turns stateful-looking local mutation into a graph of immutable values.
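The SSA form also tells reverse mode what to save. In this minimal hand-written sketch (the `forward`/`backward` split is illustrative, not a particular system's API), the forward pass keeps `x0` because the backward rule for `x1 = x0 * x0` needs it; reading the reused name `x` after mutation would supply `x2` instead and give the wrong gradient.

```python
def forward(x0):
    """SSA forward pass; returns the output and the values the backward pass needs."""
    x1 = x0 * x0
    x2 = x1 + 1
    saved = {"x0": x0}  # d(x1)/d(x0) = 2 * x0, so x0 must survive
    return x2, saved


def backward(saved, g_x2=1.0):
    """Reverse pass: propagate the cotangent of x2 back to x0."""
    g_x1 = g_x2 * 1.0                # x2 = x1 + 1
    g_x0 = g_x1 * 2.0 * saved["x0"]  # x1 = x0 * x0
    return g_x0


def mutating_version(x):
    """The original code: the name x is overwritten, so x0 is gone afterwards."""
    x = x * x
    x = x + 1
    return x


y, saved = forward(3.0)
grad = backward(saved)               # correct: 2 * 3 = 6
wrong = 2.0 * mutating_version(3.0)  # using the overwritten x: 2 * 10 = 20
```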

Array mutation is harder because only part of a value may change:

```text
a[i] = a[i] + x
```

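Reverse mode needs the index `i` and the aliasing relationship between `a` and any other views of the same storage. For this additive update the old value of `a[i]` is not needed, but a nonlinear update such as `a[i] = a[i] * x` would also require saving the old value before it is overwritten.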
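One way to sidestep in-place overwrites is a functional update: build a new array instead of mutating the old one. A minimal sketch on Python lists (the function names are hypothetical), with the hand-written reverse rule; for this additive update the rule uses only the index `i`, never the old value of `a[i]`:

```python
def update_add(a, i, x):
    """Functional form of a[i] = a[i] + x: returns a new list, leaves a untouched."""
    a2 = list(a)
    a2[i] = a[i] + x
    return a2


def update_add_vjp(i, g_a2):
    """Reverse rule: map the cotangent of a2 to cotangents of (a, x)."""
    g_a = list(g_a2)  # every element of a passes through to a2 with slope 1
    g_x = g_a2[i]     # x contributes only to position i
    return g_a, g_x


a = [1.0, 2.0, 3.0]
a2 = update_add(a, 1, 10.0)
g_a, g_x = update_add_vjp(1, [0.1, 0.2, 0.3])
```

Because `a` survives unchanged, any backward rule that saved it still sees the forward value.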

## Aliasing

Aliasing occurs when two names refer to the same storage.

```text
a = zeros(10)
b = a
b[0] = 1
```

Now `a[0]` has changed too.

Aliasing complicates AD because the dataflow graph cannot be read from variable names alone. The system must reason about storage identity.

Consider:

```text
y = a[0] * a[0]
b = a
b[0] = 3
```

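The backward pass for `y` needs the value of `a[0]` before mutation, since the derivative of `a[0] * a[0]` is `2 * a[0]`. Because `b` aliases `a`, the write to `b[0]` overwrites that value, so the system must either preserve a copy or detect the overwrite; otherwise the backward pass silently uses 3.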

Common strategies include:

| Strategy | Method |
|---|---|
| Prohibit unsafe mutation | Reject programs that mutate saved values |
| Version counters | Detect mutation after saving |
| Copy-on-write | Preserve old storage when needed |
| Functional updates | Rewrite updates as new values |
| Alias analysis | Prove mutation cannot affect saved values |

The correctness invariant is that every backward rule must see the forward values it depends on.
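The version-counter strategy can be sketched in a few lines. This is a minimal illustration, assuming a hypothetical `Buffer` type whose in-place writes bump a counter and a hypothetical `SavedValue` wrapper that records the counter at save time. PyTorch's in-place checks for saved tensors work on a similar principle, but these classes are not its API. Replaying the aliasing example, the backward pass for `y = a[0] * a[0]` detects the write through `b` instead of silently reading 3:

```python
class Buffer:
    """Mutable storage with a version counter incremented by every in-place write."""

    def __init__(self, data):
        self.data = list(data)
        self.version = 0

    def __getitem__(self, i):
        return self.data[i]

    def __setitem__(self, i, value):
        self.data[i] = value
        self.version += 1


class SavedValue:
    """Holds a reference for the backward pass and remembers its version at save time."""

    def __init__(self, buf):
        self.buf = buf
        self.saved_version = buf.version

    def unpack(self):
        if self.buf.version != self.saved_version:
            raise RuntimeError("a value saved for backward was modified in place")
        return self.buf


def square_first(a):
    """Forward for y = a[0] * a[0]; saves a for the backward pass."""
    return a[0] * a[0], SavedValue(a)


def square_first_backward(saved):
    return 2.0 * saved.unpack()[0]  # dy/da[0] = 2 * a[0]


a = Buffer([2.0] + [0.0] * 9)
y, saved = square_first(a)
b = a       # alias: b and a name the same storage
b[0] = 3.0  # the write through b bumps the shared version counter
```

Calling `square_first_backward(saved)` now raises instead of returning `6.0`; without the write it would return `4.0`.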

## Stateful Layers

Machine learning models often contain stateful layers. Batch normalization is a common example.

During training, a batch normalization layer computes batch statistics and updates running estimates:

```text
mean = batch_mean(x)
var  = batch_var(x)

running_mean = momentum * running_mean + (1 - momentum) * mean
running_var  = momentum * running_var  + (1 - momentum) * var

y = scale * normalize(x, mean, var) + bias
```

There are two different computations here.

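The output `y` is differentiable with respect to `x`, `scale`, and `bias`. Gradients with respect to `x` also flow through `mean` and `var`, because the batch statistics are computed from `x`.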

The running statistics are state updates used for later inference. They are usually not optimized by gradient descent.

So an AD system must distinguish:

```text
differentiable parameters
differentiable activations
non-differentiated state updates
```

This distinction is often represented by marking some values as trainable parameters, some as buffers, and some as temporary activations.
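A minimal pure-Python sketch of the training-mode computation, assuming the momentum convention used above and the biased batch variance throughout (some frameworks use the unbiased estimate for the running variance). The point is the return signature: `y` is the differentiable output, while the updated running statistics come back as new buffer state that the training loop stores but never differentiates.

```python
import math


def batch_norm_train(x, scale, bias, running_mean, running_var,
                     momentum=0.9, eps=1e-5):
    """Returns (y, new_buffers). Only y lies on the differentiated path."""
    n = len(x)
    mean = sum(x) / n
    var = sum((xi - mean) ** 2 for xi in x) / n

    # Differentiable path: depends on x, scale, bias (and on x through mean and var)
    y = [scale * (xi - mean) / math.sqrt(var + eps) + bias for xi in x]

    # Buffer updates: new state for inference, not trained by gradient descent
    new_running_mean = momentum * running_mean + (1 - momentum) * mean
    new_running_var = momentum * running_var + (1 - momentum) * var
    return y, (new_running_mean, new_running_var)


y, (rm, rv) = batch_norm_train([1.0, 2.0, 3.0, 4.0], scale=1.0, bias=0.0,
                               running_mean=0.0, running_var=1.0)
```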

## Optimizer State

Optimizers also carry state.

For momentum SGD:

```text
v = beta * v + grad
theta = theta - lr * v
```

The state $v$ affects future parameters.

In ordinary training, gradients are computed through the model loss with respect to $\theta$, but not through the optimizer update itself. The update is performed outside the differentiated objective.

In meta-learning, however, one may differentiate through optimization:

$$
\theta_{t+1} =
\theta_t - \eta \nabla_\theta L(\theta_t)
$$

Now the optimizer state and update rule are part of the computation graph. This requires higher-order differentiation and careful treatment of state across training steps.

The same code therefore has different derivative meaning depending on whether optimizer state is inside or outside the differentiated region.
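To make the meta-learning case concrete, here is a minimal sketch that unrolls two momentum-SGD steps with the velocity `v` passed explicitly as state, then differentiates the final loss with respect to the learning rate using forward-mode dual numbers. It assumes a toy quadratic loss $L(\theta)=\tfrac12\theta^2$ whose gradient $\theta$ is written analytically; with a general loss, differentiating through $\nabla_\theta L$ is where second derivatives enter. The `Dual` class and helper names are illustrative, not a library API.

```python
from dataclasses import dataclass


@dataclass
class Dual:
    """Forward-mode dual number: value and derivative with respect to one seeded input."""
    val: float
    der: float = 0.0

    @staticmethod
    def lift(v):
        return v if isinstance(v, Dual) else Dual(v)

    def __add__(self, o):
        o = Dual.lift(o)
        return Dual(self.val + o.val, self.der + o.der)

    __radd__ = __add__

    def __sub__(self, o):
        o = Dual.lift(o)
        return Dual(self.val - o.val, self.der - o.der)

    def __rsub__(self, o):
        return Dual.lift(o) - self

    def __mul__(self, o):
        o = Dual.lift(o)
        return Dual(self.val * o.val, self.der * o.val + self.val * o.der)

    __rmul__ = __mul__


def loss(theta):
    return 0.5 * theta * theta


def grad_loss(theta):
    return theta  # analytic gradient of the toy quadratic loss


def momentum_step(theta, v, lr, beta):
    """One momentum-SGD step with explicit optimizer state v."""
    v2 = beta * v + grad_loss(theta)
    theta2 = theta - lr * v2
    return theta2, v2


def final_loss(lr, theta0=2.0, beta=0.9, steps=2):
    theta, v = theta0, 0.0
    for _ in range(steps):
        theta, v = momentum_step(theta, v, lr, beta)
    return loss(theta)


hypergrad = final_loss(Dual(0.1, 1.0)).der  # dL(theta_2)/d(lr)
```

Seeding `lr` turns the whole unrolled training loop, optimizer state included, into part of the differentiated function.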

## Random State

Random number generators are stateful.

```text
eps = randn()
y = x + eps
```

A random draw updates the generator state. The sample itself is usually treated as non-differentiable with respect to the generator state.

For reparameterized randomness, the sample is represented as:

$$
z = \mu + \sigma \epsilon
$$

where $\epsilon \sim \mathcal{N}(0, 1)$ is sampled independently of $\mu$ and $\sigma$ and treated as a constant during differentiation.

Then:

$$
\frac{\partial z}{\partial \mu}=1
$$

$$
\frac{\partial z}{\partial \sigma}=\epsilon
$$

A clean AD design passes random keys or generator state explicitly:

```text
eps, key2 = normal(key)
z = mu + sigma * eps
return z, key2
```

This makes randomness reproducible and makes state transitions visible.
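A minimal sketch of key passing in plain Python, assuming a hypothetical `normal(key)` helper that seeds `random.Random` from an integer key and returns a sample plus a fresh key. Systems such as JAX use dedicated splittable, counter-based generators instead; this only mimics the interface. Because the key is an explicit input, the same key reproduces the same sample, and $\epsilon$ is a fixed number during differentiation, so the reparameterization derivatives can be read off directly:

```python
import random


def normal(key):
    """Hypothetical key-passing sampler: the same key always yields the same (eps, next_key)."""
    rng = random.Random(key)
    eps = rng.gauss(0.0, 1.0)
    next_key = rng.getrandbits(64)
    return eps, next_key


def reparam_sample(mu, sigma, key):
    """z = mu + sigma * eps, with eps held fixed; returns z, its derivatives, and the new key."""
    eps, key2 = normal(key)
    z = mu + sigma * eps
    grads = {"dz_dmu": 1.0, "dz_dsigma": eps}
    return z, grads, key2


z1, g1, k1 = reparam_sample(0.5, 2.0, key=42)
z2, g2, k2 = reparam_sample(0.5, 2.0, key=42)  # same key: identical draw
z3, _, _ = reparam_sample(0.5, 2.0, key=k1)    # threaded key: the next draw
```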

## Stateful Simulators

Physical simulators are stateful by nature. A simulation step updates position, velocity, constraints, and sometimes internal solver state.

```text
state = step(state, control)
```

Mathematically:

$$
s_{t+1}=F(s_t,u_t)
$$

A rollout is:

$$
s_T = F(\cdots F(F(s_0,u_0),u_1)\cdots,u_{T-1})
$$

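Differentiating a rollout applies the same rules as loop differentiation. By the chain rule, the derivative of the final state with respect to the initial state is the product of the per-step state-transition Jacobians:

$$
\frac{\partial s_T}{\partial s_0} =
\frac{\partial s_T}{\partial s_{T-1}} \cdots \frac{\partial s_1}{\partial s_0}
$$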

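Reverse mode multiplies these Jacobians from the end of the trajectory back to the start, and each factor $\partial s_{t+1}/\partial s_t$ is evaluated at $s_t$. The backward pass must therefore either store the forward states or recompute them, for example from periodic checkpoints.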
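A minimal sketch of reverse mode through a rollout, assuming a hypothetical scalar transition $F(s,u)=s+\Delta t\,(u-s^3)$. The forward rollout stores every state, because the local Jacobian $\partial s_{t+1}/\partial s_t = 1-3\Delta t\,s_t^2$ must be evaluated at $s_t$; the backward sweep then multiplies these Jacobians in reverse order while collecting $\partial s_T/\partial u_t$:

```python
def step(s, u, dt=0.1):
    """Hypothetical transition s_{t+1} = s_t + dt * (u_t - s_t**3)."""
    return s + dt * (u - s ** 3)


def rollout(s0, us, dt=0.1):
    """Forward pass; keeps every state for the backward sweep."""
    states = [s0]
    for u in us:
        states.append(step(states[-1], u, dt))
    return states


def rollout_grad(s0, us, dt=0.1):
    """Gradient of s_T with respect to s0 and every control u_t."""
    states = rollout(s0, us, dt)
    lam = 1.0  # adjoint: d s_T / d s_{t+1}, starting at d s_T / d s_T
    g_us = [0.0] * len(us)
    for t in reversed(range(len(us))):
        g_us[t] = lam * dt                             # d s_{t+1} / d u_t = dt
        lam = lam * (1.0 - 3.0 * dt * states[t] ** 2)  # d s_{t+1} / d s_t at stored s_t
    return states[-1], lam, g_us


s_T, g_s0, g_us = rollout_grad(0.5, [0.2, -0.1, 0.3, 0.0])
```

Replacing the stored `states` list with recomputation from a few saved checkpoints would trade memory for extra forward steps.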

Stateful simulators add complications:

```text
contact caches
adaptive step sizes
solver warm starts
constraint activation
collision histories
```

Some of these are differentiable state. Some are discrete execution aids. The derivative is meaningful only after deciding which parts of the simulator state belong to the mathematical model.

## External State

Programs may read from files, databases, clocks, network services, or global registries.

```text
scale = read_config("scale")
y = scale * x
```

If `scale` is external and not represented as an AD input, gradients with respect to it cannot be computed.

External state also harms reproducibility. The same input may produce different outputs if the external state changes.

For differentiable programming, external state should be handled in one of three ways:

```text
make it an explicit input
treat it as a constant
exclude it from differentiated regions
```

Silent dependence on external mutable state is a common source of incorrect gradients.
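A minimal sketch of the contrast, assuming a hypothetical `read_config` that looks up a module-level dictionary standing in for a file or service. The hidden version changes its output when the external value changes; the explicit version is a pure function whose derivatives with respect to both `x` and `scale` are well defined (written by hand here, since $y = \text{scale}\cdot x$):

```python
CONFIG = {"scale": 2.0}  # hypothetical stand-in for a file, database, or service


def read_config(name):
    return CONFIG[name]


def f_hidden(x):
    """Silent dependence: reads mutable external state on every call."""
    return read_config("scale") * x


def f_explicit(x, scale):
    """Explicit input: y = scale * x, so dy/dx = scale and dy/dscale = x."""
    y = scale * x
    return y, {"dy_dx": scale, "dy_dscale": x}


# Treat it as a constant: read once, outside the differentiated region,
# and pass it in without requesting a gradient for it.
scale_const = read_config("scale")
y, grads = f_explicit(3.0, scale_const)

before = f_hidden(3.0)
CONFIG["scale"] = 5.0  # the external state changes between calls
after = f_hidden(3.0)  # same input, different output
```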

## Caches and Memoization

Caches improve performance but can obscure data dependencies.

```text
if key in cache:
    return cache[key]

y = expensive(x)
cache[key] = y
return y
```

If the cache key does not fully capture the differentiable inputs, the program may return a stale value. The gradient then corresponds to the cached computation, not the intended computation.

In AD systems, cached values should:

```text
store derivative history safely
detach intentionally
be recomputed inside the traced region
be keyed by all relevant inputs
```

A cache that stores graph-attached values may also retain large computation graphs and cause memory growth.
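A minimal sketch of the stale-cache failure, assuming a hypothetical `expensive(x) = x**3` and a central-difference probe in place of a real AD system. When the cache key ignores `x`, both probe points return the same cached value and the estimated derivative collapses to zero; keying by all relevant inputs restores the expected $3x^2$:

```python
def expensive(x):
    return x ** 3  # hypothetical costly computation


bad_cache, good_cache = {}, {}


def cached_bad(key, x):
    """Cache keyed only by a label: ignores the differentiable input x."""
    if key in bad_cache:
        return bad_cache[key]
    y = expensive(x)
    bad_cache[key] = y
    return y


def cached_good(key, x):
    """Cache keyed by every input that affects the result."""
    full_key = (key, x)
    if full_key in good_cache:
        return good_cache[full_key]
    y = expensive(x)
    good_cache[full_key] = y
    return y


def central_diff(fn, x, h=1e-4):
    return (fn("f", x + h) - fn("f", x - h)) / (2 * h)


d_bad = central_diff(cached_bad, 2.0)    # second probe hits the stale entry: 0.0
d_good = central_diff(cached_good, 2.0)  # close to 3 * 2**2 = 12
```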

## State and Purity

Pure functions are easiest to differentiate because dependencies are explicit.

A pure function has no hidden reads or writes:

```text
y = f(x)
```

A state-passing version remains pure:

```text
y, state2 = f(x, state1)
```

This style is common in differentiable programming systems that target compilation, parallelism, and transformation. Explicit state makes it possible to apply AD, vectorization, checkpointing, and staging without guessing about hidden effects.

The cost is verbosity. The benefit is correctness.

## Correctness Rule

For a stateful computation:

```text
(y, s_next) = F(x, s)
```

AD can compute correct derivatives only for dependencies represented inside the differentiated computation.

```text
Visible differentiable state:
    participates in the graph.

Visible non-differentiable state:
    controls execution but receives no gradient.

Hidden mutable state:
    risks incomplete or incorrect gradients.

External state:
    must be explicit, constant, or excluded.
```

State does not prevent automatic differentiation. It requires a precise boundary around what counts as part of the function being differentiated.

