# Chapter 9. Differentiation of Control Flow

## Conditionals

A conditional is a program construct that chooses one computation among several possible computations. In ordinary code, this is written as `if`, `else`, `switch`, `case`, pattern matching, or guarded expressions. In automatic differentiation, a conditional raises a simple question:

When a program chooses one branch, which derivative should the differentiated program compute?

For a fixed execution, the answer is straightforward: the derivative follows the branch that was actually executed. If the input changes enough to make the program choose a different branch, the derivative may change discontinuously, or may cease to exist at the boundary between branches.

A conditional therefore turns one program into a piecewise-defined function.

```text
if x > 0:
    y = x * x
else:
    y = 3 * x
```

This program denotes the mathematical function

$$
f(x) =
\begin{cases}
x^2, & x > 0, \\
3x, & x \le 0.
\end{cases}
$$

Its derivative away from the branch boundary is

$$
f'(x) =
\begin{cases}
2x, & x > 0, \\
3, & x < 0.
\end{cases}
$$

At $x = 0$, the derivative does not exist, because the left derivative is $3$ and the right derivative is $0$.

This is the central rule for differentiating conditionals: automatic differentiation differentiates the executed path. It does not automatically reason about all paths unless the AD system has special support for symbolic or abstract control flow.
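
For concreteness, here is a minimal sketch of this rule in JAX (one AD system among many, used here only for illustration). `jax.grad` follows plain Python control flow when the function is not compiled, so it differentiates whichever branch actually runs:

```python
import jax

def f(x):
    if x > 0:               # ordinary Python control flow
        return x * x
    else:
        return 3.0 * x

print(jax.grad(f)(2.0))    # 4.0: derivative of the x * x branch
print(jax.grad(f)(-2.0))   # 3.0: derivative of the 3x branch
print(jax.grad(f)(0.0))    # 3.0: the else branch executes at 0, even though
                           # the mathematical derivative does not exist there
```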

## Execution Path Semantics

Consider a program

$$
y = P(x)
$$

where $P$ contains a conditional:

```text
if c(x):
    y = f(x)
else:
    y = g(x)
```

Here `c(x)` is the branch predicate. The denotation is

$$
P(x) =
\begin{cases}
f(x), & c(x) = \text{true}, \\
g(x), & c(x) = \text{false}.
\end{cases}
$$

At an input $x_0$, the program executes exactly one branch. If $c(x_0)$ is true, AD computes the derivative of $f$ at $x_0$. If $c(x_0)$ is false, AD computes the derivative of $g$ at $x_0$.

This gives a local derivative of the chosen branch:

$$
DP(x_0) =
\begin{cases}
Df(x_0), & c(x_0) = \text{true}, \\
Dg(x_0), & c(x_0) = \text{false}.
\end{cases}
$$

This formula is valid only where the active branch remains locally stable. In other words, there must be a neighborhood around $x_0$ in which the same branch is selected. If an arbitrarily small perturbation of $x_0$ can switch the branch, then the branch boundary matters.

For predicates such as `x > 0`, the boundary is $x = 0$. Away from the boundary, the condition behaves as a constant with respect to small perturbations. At the boundary, the program may be continuous, discontinuous, differentiable, or non-differentiable depending on the branch formulas.

## Conditionals in Forward Mode

Forward mode propagates primal values and tangent values together. For a variable $x$, forward mode evaluates a pair

$$
(x, \dot{x})
$$

where $x$ is the ordinary value and $\dot{x}$ is the tangent.

For a conditional, the primal predicate is evaluated first. The selected branch is then evaluated on both primal and tangent values.

```text
def f(x):
    if x > 0:
        return x * x
    else:
        return 3 * x
```

Forward mode transforms this conceptually into:

```text
def jvp_f(x, dx):
    if x > 0:
        y = x * x
        dy = dx * x + x * dx
        return y, dy
    else:
        y = 3 * x
        dy = 3 * dx
        return y, dy
```

The condition uses the primal value `x`, not the tangent `dx`. The tangent does not decide which branch is taken. It only propagates through the branch chosen by the primal execution.

For example, at $x = 2$ with $\dot{x} = 1$, the program takes the first branch:

$$
y = 4,\qquad \dot{y} = 4.
$$

At $x = -2$ with $\dot{x} = 1$, the program takes the second branch:

$$
y = -6,\qquad \dot{y} = 3.
$$

At $x = 0$, the program takes the `else` branch, because `x > 0` is false. The AD result is

$$
y = 0,\qquad \dot{y} = 3.
$$

This value is the derivative of the executed branch, not the mathematical derivative of the full piecewise function. The full function has no derivative at $0$.

This distinction is important. AD gives a derivative of the executed computation. When the executed computation represents a non-smooth mathematical function, the returned value at a boundary may be a branch derivative, a subgradient-like convention, or simply an implementation artifact.
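
These values can be checked with a forward-mode sketch; `jax.jvp` evaluates the primal and tangent together, mirroring the `jvp_f` transformation above:

```python
import jax

def f(x):
    if x > 0:
        return x * x
    else:
        return 3.0 * x

print(jax.jvp(f, (2.0,), (1.0,)))   # (4.0, 4.0)
print(jax.jvp(f, (-2.0,), (1.0,)))  # (-6.0, 3.0)
print(jax.jvp(f, (0.0,), (1.0,)))   # (0.0, 3.0): tangent of the executed branch
```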

## Conditionals in Reverse Mode

Reverse mode first executes the primal program and records the operations needed for the backward pass. When a conditional appears, only the active branch contributes operations to the tape.

For the same function,

```text
def f(x):
    if x > 0:
        return x * x
    else:
        return 3 * x
```

a reverse-mode execution at $x = 2$ records the multiplication `x * x` from the true branch. The backward pass accumulates

$$
\bar{x} \mathrel{+}= 2x\bar{y}.
$$

At $x = -2$, the tape records multiplication by the constant `3`. The backward pass accumulates

$$
\bar{x} \mathrel{+}= 3\bar{y}.
$$

The inactive branch contributes nothing. This matches ordinary execution: code that did not run has no recorded intermediate values and no local adjoints.

A reverse-mode transformation may be written schematically as:

```text
def vjp_f(x, y_bar):
    if x > 0:
        y = x * x
        x_bar = 2 * x * y_bar
        return y, x_bar
    else:
        y = 3 * x
        x_bar = 3 * y_bar
        return y, x_bar
```

Real reverse-mode systems usually separate the forward recording phase from the backward replay phase. The branch decision made during the forward pass must be available during the backward pass. This can be done by storing a branch tag, storing a tape of executed operations, or compiling structured control flow into an intermediate representation that preserves the branch.
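
As a concrete sketch, `jax.vjp` exposes exactly this separation: the forward call records the active branch and returns a backward function that replays it.

```python
import jax

def f(x):
    if x > 0:
        return x * x
    else:
        return 3.0 * x

y, f_vjp = jax.vjp(f, 2.0)    # forward pass records the true branch only
print(y, f_vjp(1.0))          # y == 4.0, x_bar == 2 * x * y_bar == 4.0

y, f_vjp = jax.vjp(f, -2.0)   # forward pass records the false branch only
print(y, f_vjp(1.0))          # y == -6.0, x_bar == 3 * y_bar == 3.0
```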

## Branch Predicates Are Usually Not Differentiated

Most AD systems treat branch predicates as control decisions, not differentiable computations. A predicate such as

```text
x > 0
```

returns a Boolean value. Boolean values do not have ordinary real-valued derivatives. The derivative of the comparison itself is not propagated.

This means the conditional

```text
if x > 0:
    y = f(x)
else:
    y = g(x)
```

does not produce a term involving the derivative of `x > 0`. AD does not compute something like

$$
\frac{d}{dx}\mathbf{1}_{x>0}.
$$

Instead, it treats the branch as fixed for the current execution and differentiates the selected branch.

This behavior is correct for inputs away from branch boundaries. Near boundaries, it can hide important structure. For example,

```text
def step(x):
    if x > 0:
        return 1
    else:
        return 0
```

AD returns derivative $0$ on both branches, because each branch is constant. But the function has a jump discontinuity at $0$. The derivative information returned by AD does not describe the discontinuity.
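
This is easy to observe; a small sketch (with floating-point branch outputs so that a gradient is defined at all):

```python
import jax

def step(x):
    if x > 0:
        return 1.0
    else:
        return 0.0

print(jax.grad(step)(2.0))    # 0.0: the executed branch is constant
print(jax.grad(step)(-2.0))   # 0.0: likewise; the jump at 0 is invisible to AD
```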

## Continuity and Differentiability at Branch Boundaries

At a branch boundary, the function may fall into several cases.

First, the function may be discontinuous:

$$
f(x) =
\begin{cases}
1, & x > 0, \\
0, & x \le 0.
\end{cases}
$$

There is no derivative at the boundary.

Second, the function may be continuous but non-differentiable:

$$
f(x) =
\begin{cases}
x, & x > 0, \\
-x, & x \le 0.
\end{cases}
$$

This is $f(x)=|x|$. At $0$, the left derivative is $-1$, and the right derivative is $1$. The derivative does not exist.

Third, the function may be differentiable even though it is written with a conditional:

$$
f(x) =
\begin{cases}
x^2, & x > 0, \\
x^2, & x \le 0.
\end{cases}
$$

The branches agree in both value and derivative at the boundary. The conditional has no mathematical effect, although it may still matter operationally.

Fourth, the function may be continuous and once differentiable, but fail to have higher derivatives. For example,

$$
f(x) =
\begin{cases}
x^2, & x > 0, \\
0, & x \le 0.
\end{cases}
$$

At $0$, the first derivative exists and equals $0$. The second derivative has a jump across the boundary.

When using AD, the programmer must distinguish the derivative of the executed branch from the derivative of the whole piecewise function. AD systems generally do not prove that branch boundaries are harmless.
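
The absolute-value case makes the executed-branch rule visible at the boundary. In the sketch below, the conditional form of $|x|$ yields $-1$ at $x = 0$ because the `else` branch runs there; the value returned at such a point is a property of the execution, not a mathematical derivative:

```python
import jax

def abs_cond(x):
    if x > 0:
        return x
    else:
        return -x

print(jax.grad(abs_cond)(0.0))  # -1.0: derivative of the executed -x branch;
                                # |x| has no derivative at 0
```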

## Structured Conditionals in Array Programs

In tensor programming, a conditional often appears in elementwise form:

```text
y = where(x > 0, x * x, 3 * x)
```

This differs from an ordinary scalar `if`. A scalar `if` chooses one branch for the whole program. An elementwise `where` chooses per element.

Mathematically,

$$
y_i =
\begin{cases}
x_i^2, & x_i > 0, \\
3x_i, & x_i \le 0.
\end{cases}
$$

The derivative is also elementwise away from boundaries:

$$
\frac{\partial y_i}{\partial x_i} =
\begin{cases}
2x_i, & x_i > 0, \\
3, & x_i < 0.
\end{cases}
$$

The Jacobian is diagonal when each output element depends only on the corresponding input element.
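
A sketch of the elementwise case, using `jnp.where` and `jax.jacobian` to make the diagonal structure explicit:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.where(x > 0, x * x, 3.0 * x)

x = jnp.array([2.0, -2.0])
print(jax.jacobian(f)(x))
# [[4. 0.]
#  [0. 3.]]   one branch derivative per element, zeros elsewhere
```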

There is a systems-level detail here. Some array systems evaluate both branches of `where` before selecting values. Others compile it into predicated execution. This matters when one branch contains invalid operations.

For example:

```text
y = where(x > 0, sqrt(x), 0)
```

If `sqrt(x)` is evaluated for negative elements before masking, the program may produce NaNs internally. These NaNs can contaminate gradients even if the selected primal output looks valid. Robust AD code often rewrites such expressions to avoid undefined inactive computations.

A safer form substitutes a harmless placeholder into the inactive elements before the unsafe operation:

```text
safe_x = where(x > 0, x, 1)
y = where(x > 0, sqrt(safe_x), 0)
```

This ensures that `sqrt` receives a valid input for every element. The placeholder must be safe for the derivative as well as for the value: `sqrt(0)` is defined, but its derivative at `0` is infinite, so a placeholder of `0` would repair the primal and still produce NaN in the backward pass. A strictly positive placeholder avoids both problems.
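
The difference shows up directly in the gradients. In the sketch below, the masked-out `sqrt` still poisons the backward pass of the naive form, because the masking multiplies its NaN derivative by zero:

```python
import jax
import jax.numpy as jnp

def unsafe(x):
    # sqrt is evaluated on every element, including negative ones
    return jnp.sum(jnp.where(x > 0, jnp.sqrt(x), 0.0))

def safe(x):
    safe_x = jnp.where(x > 0, x, 1.0)   # strictly positive placeholder
    return jnp.sum(jnp.where(x > 0, jnp.sqrt(safe_x), 0.0))

x = jnp.array([4.0, -1.0])
print(jax.grad(unsafe)(x))  # [0.25  nan]: 0 * nan is nan in the backward pass
print(jax.grad(safe)(x))    # [0.25 0.  ]
```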

## Control Flow in Tracing Systems

Some AD systems trace a program by observing one execution. This works well for straight-line code, but conditionals introduce specialization.

Suppose a tracer observes:

```text
if x > 0:
    y = x * x
else:
    y = 3 * x
```

with $x = 2$. The trace may contain only the true branch. If the traced program is reused for $x = -2$, it may compute the wrong result unless the tracing system represents the conditional explicitly.

There are two common designs.

The first design is trace-by-execution. The system records only the operations that ran. This is simple and flexible for dynamic programs, but the resulting trace is valid only for the observed control path.

The second design is structured control flow. The system represents conditionals as explicit IR nodes, such as:

```text
cond(predicate, true_branch, false_branch)
```

Both branches are available to the compiler, and the predicate is evaluated at runtime. This supports compilation across different branch choices, but imposes stronger restrictions on branch shapes, types, and side effects.

Structured conditionals are common in staged and compiled AD systems because they allow optimization, batching, and device execution.
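
JAX's `lax.cond` is one concrete instance of this design. Both branches are staged into the compiled program, the predicate is evaluated at runtime, and the same compiled gradient function is valid for either branch:

```python
import jax
from jax import lax

def f(x):
    return lax.cond(x > 0, lambda x: x * x, lambda x: 3.0 * x, x)

g = jax.jit(jax.grad(f))  # compiles once; valid across branch choices
print(g(2.0))   # 4.0
print(g(-2.0))  # 3.0
```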

## Type and Shape Constraints

A conditional expression must have a well-defined output type. In dynamically typed languages, branches may return different shapes or types:

```text
if flag:
    return scalar
else:
    return vector
```

This is difficult for compiled AD. The differentiated program needs a stable representation for both primal and derivative values. For this reason, many AD systems require both branches of a structured conditional to return the same pytree structure, tensor rank, dtype, and sometimes static shape.

For example:

```text
cond(flag, branch_a, branch_b, x)
```

usually requires:

```text
branch_a(x).shape == branch_b(x).shape
branch_a(x).dtype == branch_b(x).dtype
```

The derivative outputs must also match. If one branch returns a differentiable floating-point tensor and another branch returns an integer or Boolean, the gradient structure may be undefined or partially zero.

This constraint is not merely an implementation inconvenience. A function with branch-dependent output type does not denote a single smooth map between fixed vector spaces. Classical derivatives require a stable domain and codomain.
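
A sketch of the constraint, again with `lax.cond` as the structured conditional; the mismatched version is left commented out because it is rejected at trace time:

```python
import jax.numpy as jnp
from jax import lax

# Accepted: both branches return a scalar of the same dtype.
y = lax.cond(True, lambda x: x * 2.0, lambda x: x - 1.0, 3.0)

# Rejected when traced: branch outputs differ in shape.
# lax.cond(True, lambda x: x, lambda x: jnp.stack([x, x]), 3.0)
```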

## Differentiable Alternatives to Hard Branches

Some models use conditionals as hard decisions:

```text
if score > threshold:
    use expert_a()
else:
    use expert_b()
```

The branch decision is non-differentiable with respect to `score`. If the goal is gradient-based optimization over the decision itself, a hard conditional usually blocks useful gradient flow.

A common alternative is a soft mixture:

$$
y = \alpha f(x) + (1-\alpha)g(x)
$$

where

$$
\alpha = \sigma(k\,s)
$$

and $s$ is a score. As $k$ grows, the sigmoid becomes sharper. For finite $k$, the function remains smooth.

This changes the model. It no longer executes exactly one branch. It blends branches. The benefit is that gradients can flow into the gating score. The cost is extra computation and a different inductive bias.
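
A sketch of the soft mixture, with hypothetical branch functions `expert_a` and `expert_b` standing in for the two computations:

```python
import jax

def expert_a(x):   # hypothetical branches, for illustration only
    return x * x

def expert_b(x):
    return 3.0 * x

def soft_select(score, x, k=10.0):
    alpha = jax.nn.sigmoid(k * score)   # smooth gate in (0, 1)
    return alpha * expert_a(x) + (1.0 - alpha) * expert_b(x)

# Gradients now flow into the gating score as well as into x.
print(jax.grad(soft_select, argnums=(0, 1))(0.1, 2.0))
```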

Other alternatives include softmax routing, straight-through estimators, stochastic relaxations, and subgradient conventions. These are modeling choices, not automatic consequences of AD.

## Correctness Rule

For a conditional program

```text
if c(x):
    y = f(x)
else:
    y = g(x)
```

AD computes the derivative of the active branch at the current input, provided the active branch is differentiable.

This derivative equals the derivative of the whole program only when the branch choice is locally constant around the input, or when the branches agree sufficiently at the boundary.

The practical rule is:

```text
Away from branch boundaries:
    AD gives the ordinary derivative.

At branch boundaries:
    AD gives the derivative of the selected branch.
    The mathematical derivative may not exist.

For hard decisions:
    gradients do not usually flow through the decision itself.
```

Conditionals therefore do not break automatic differentiation. They make the differentiated function piecewise. The AD system follows the same control path as the primal execution, and the user remains responsible for the mathematical meaning of derivatives at branch boundaries.

