Conditionals
A conditional is a program construct that chooses one computation among several possible computations. In ordinary code, this is written as if, else, switch, case, pattern matching, or guarded expressions. In automatic differentiation, a conditional raises a simple question:
When a program chooses one branch, which derivative should the differentiated program compute?
For a fixed execution, the answer is direct. The derivative follows the branch that was actually executed. If the input changes enough to make the program choose a different branch, the derivative may change discontinuously, or may stop existing at the boundary between branches.
A conditional therefore turns one program into a piecewise-defined function.
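    if x > 0:
        y = x * x
    else:
        y = 3 * x

This program denotes the mathematical function

$$
f(x) = \begin{cases} x^2, & x > 0, \\ 3x, & x \le 0. \end{cases}
$$

Its derivative away from the branch boundary is

$$
f'(x) = \begin{cases} 2x, & x > 0, \\ 3, & x < 0. \end{cases}
$$

At $x = 0$, the derivative does not exist, because the left derivative is $3$ and the right derivative is $0$.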
This is the central rule for differentiating conditionals: automatic differentiation differentiates the executed path. It does not automatically reason about all paths unless the AD system has special support for symbolic or abstract control flow.
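The mismatch between the two one-sided derivatives can be checked numerically. The following is a minimal sketch, not part of any AD system: the helper `one_sided_slope` and the step size `h` are illustrative assumptions, and the difference quotients only approximate the one-sided derivatives.

```python
def f(x):
    # Same piecewise program as in the text.
    if x > 0:
        return x * x
    else:
        return 3 * x


def one_sided_slope(fn, x, h):
    # Difference quotient with a signed step h:
    # h > 0 probes from the right, h < 0 probes from the left.
    return (fn(x + h) - fn(x)) / h


h = 1e-6  # step size chosen for illustration only
right = one_sided_slope(f, 0.0, h)   # small, tends to 0 as h shrinks
left = one_sided_slope(f, 0.0, -h)   # close to 3, the slope of 3 * x
print("left slope:", left, "right slope:", right)
```

The two slopes disagree no matter how small the step is, which is the numerical signature of a missing derivative at the boundary.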
Execution Path Semantics
Consider a program $y = P(x)$, where $P$ contains a conditional:

    if c(x):
        y = f(x)
    else:
        y = g(x)

Here c(x) is the branch predicate. The denotation is

$$
P(x) = \begin{cases} f(x), & \text{if } c(x) \text{ is true}, \\ g(x), & \text{otherwise}. \end{cases}
$$

At an input $x_0$, the program executes exactly one branch. If $c(x_0)$ is true, AD computes the derivative of $f$ at $x_0$. If $c(x_0)$ is false, AD computes the derivative of $g$ at $x_0$.
This gives a local derivative of the chosen branch:

$$
P'(x_0) = \begin{cases} f'(x_0), & \text{if } c(x_0) \text{ is true}, \\ g'(x_0), & \text{otherwise}. \end{cases}
$$

This formula is valid only where the active branch remains locally stable. In other words, there must be a neighborhood around $x_0$ in which the same branch is selected. If an arbitrarily small perturbation of $x_0$ can switch the branch, then the branch boundary matters.
For predicates such as x > 0, the boundary is $x = 0$. Away from the boundary, the condition behaves as a constant with respect to small perturbations. At the boundary, the program may be continuous, discontinuous, differentiable, or non-differentiable depending on the branch formulas.
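Local stability of the branch choice can be probed, though not proved, by sampling the predicate around the input. The sketch below assumes a scalar input; the helper name `branch_is_locally_stable` and the `radius` parameter are hypothetical and do not correspond to any AD library feature.

```python
def branch_is_locally_stable(predicate, x, radius):
    # Heuristic probe: compare the predicate at x with its value at the
    # two endpoints of [x - radius, x + radius]. Only three points are
    # sampled, so a boundary crossed twice inside the interval is missed.
    here = predicate(x)
    return predicate(x - radius) == here and predicate(x + radius) == here


def is_positive(x):
    return x > 0


print(branch_is_locally_stable(is_positive, 1.0, 1e-3))  # same branch on both sides
print(branch_is_locally_stable(is_positive, 0.0, 1e-3))  # branch flips to the right of 0
```

A check like this is only a diagnostic. It says where the branch derivative can be trusted as the derivative of the whole program, not what to do at the boundary.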
Conditionals in Forward Mode
Forward mode propagates primal values and tangent values together. For a variable $x$, forward mode evaluates a pair

$$
(x, \dot{x}),
$$

where $x$ is the ordinary value and $\dot{x}$ is the tangent.
For a conditional, the primal predicate is evaluated first. The selected branch is then evaluated on both primal and tangent values.
    def f(x):
        if x > 0:
            return x * x
        else:
            return 3 * x

Forward mode transforms this conceptually into:
    def jvp_f(x, dx):
        # The branch is selected by the primal value x only.
        if x > 0:
            y = x * x
            dy = dx * x + x * dx  # product rule
            return y, dy
        else:
            y = 3 * x
            dy = 3 * dx
            return y, dy

The condition uses the primal value x, not the tangent dx. The tangent does not decide which branch is taken. It only propagates through the branch chosen by the primal execution.
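For example, at any input $x > 0$ with tangent $\mathrm{dx}$, the program takes the first branch:

$$
y = x^2, \qquad \mathrm{dy} = 2x\,\mathrm{dx}.
$$

At any input $x < 0$ with tangent $\mathrm{dx}$, the program takes the second branch:

$$
y = 3x, \qquad \mathrm{dy} = 3\,\mathrm{dx}.
$$

At $x = 0$, the program takes the else branch, because x > 0 is false. The AD result is

$$
y = 0, \qquad \mathrm{dy} = 3\,\mathrm{dx}.
$$

This value is the derivative of the executed branch, not the mathematical derivative of the full piecewise function. The full function has no derivative at $x = 0$.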
This distinction is important. AD gives a derivative of the executed computation. When the executed computation represents a non-smooth mathematical function, the returned value at a boundary may be a branch derivative, a subgradient-like convention, or simply an implementation artifact.
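The same behavior appears when forward mode is implemented by operator overloading. The following is a minimal sketch under stated assumptions: the `Dual` class is a hypothetical toy that supports only multiplication and the `>` comparison, which is just enough to run the example function unchanged.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Dual:
    primal: float
    tangent: float

    def __mul__(self, other):
        # Constants are lifted to duals with a zero tangent.
        if not isinstance(other, Dual):
            other = Dual(float(other), 0.0)
        return Dual(
            self.primal * other.primal,
            self.tangent * other.primal + self.primal * other.tangent,
        )

    __rmul__ = __mul__  # multiplication is commutative here

    def __gt__(self, other):
        # The comparison looks only at the primal value and returns a
        # plain bool, so the tangent cannot influence the branch.
        other_primal = other.primal if isinstance(other, Dual) else other
        return self.primal > other_primal


def f(x):
    if x > 0:
        return x * x
    else:
        return 3 * x


for x0 in (2.0, -2.0, 0.0):
    print(x0, f(Dual(x0, 1.0)))
```

At the input 0.0 the result carries tangent 3.0, the slope of the else branch, even though the piecewise function has no derivative there.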
Conditionals in Reverse Mode
Reverse mode first executes the primal program and records the operations needed for the backward pass. When a conditional appears, only the active branch contributes operations to the tape.
For the same function,

    def f(x):
        if x > 0:
            return x * x
        else:
            return 3 * x

a reverse-mode execution at an input $x > 0$ records multiplication by x inside the true branch. The backward pass accumulates

$$
\bar{x} = 2x\,\bar{y}.
$$

At an input $x \le 0$, the tape records multiplication by the constant 3. The backward pass accumulates

$$
\bar{x} = 3\,\bar{y}.
$$

The inactive branch contributes nothing. This matches ordinary execution: code that did not run has no recorded intermediate values and no local adjoints.
A reverse-mode transformation may be written schematically as:
    def vjp_f(x, y_bar):
        # The primal predicate picks the branch; y_bar is the incoming adjoint.
        if x > 0:
            y = x * x
            x_bar = 2 * x * y_bar
            return y, x_bar
        else:
            y = 3 * x
            x_bar = 3 * y_bar
            return y, x_bar

Real reverse-mode systems usually separate the forward recording phase from the backward replay phase. The branch decision made during the forward pass must be available during the backward pass. This can be done by storing a branch tag, storing a tape of executed operations, or compiling structured control flow into an intermediate representation that preserves the branch.
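The tape-of-executed-operations option can be sketched in a few lines. This is a toy under stated assumptions, not the design of any particular system: the `Var` class and the `value_and_grad` helper are hypothetical, only multiplication is supported, and the tape is a plain list of backward closures.

```python
class Var:
    """A value on the tape together with its accumulated adjoint."""

    def __init__(self, value, tape):
        self.value = value
        self.grad = 0.0
        self.tape = tape

    def __mul__(self, other):
        if not isinstance(other, Var):
            # Constant operand: wrapped so the rule below is uniform.
            other = Var(float(other), self.tape)
        out = Var(self.value * other.value, self.tape)
        a, b = self, other

        def backward():
            # Local adjoint rule for out = a * b.
            a.grad += b.value * out.grad
            b.grad += a.value * out.grad

        self.tape.append(backward)
        return out

    __rmul__ = __mul__

    def __gt__(self, other):
        # Branch predicates see only the primal value.
        return self.value > other


def f(x):
    if x > 0:
        return x * x
    else:
        return 3 * x


def value_and_grad(fn, x0):
    tape = []
    x = Var(x0, tape)
    y = fn(x)                       # forward pass records the active branch
    y.grad = 1.0                    # seed the output adjoint
    for backward in reversed(tape): # backward pass replays the tape
        backward()
    return y.value, x.grad, len(tape)


print(value_and_grad(f, 2.0))
print(value_and_grad(f, -2.0))
```

In both runs the tape holds a single entry, the multiplication from the branch that actually executed. No separate branch tag is needed in this design because the tape itself encodes the decision.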
Branch Predicates Are Usually Not Differentiated
Most AD systems treat branch predicates as control decisions, not differentiable computations. A predicate such as
    x > 0

returns a Boolean value. Boolean values do not have ordinary real-valued derivatives. The derivative of the comparison itself is not propagated.
This means the conditional
    if x > 0:
        y = f(x)
    else:
        y = g(x)

does not produce a term involving the derivative of x > 0. AD does not compute something like

$$
\frac{\partial}{\partial x}\,[x > 0].
$$

Instead, it treats the branch as fixed for the current execution and differentiates the selected branch.
This behavior is correct for inputs away from branch boundaries. Near boundaries, it can hide important structure. For example,
    def step(x):
        if x > 0:
            return 1
        else:
            return 0

AD returns derivative $0$ on both branches, because each branch is constant. But the function has a jump discontinuity at $x = 0$. The derivative information returned by AD does not describe the discontinuity.
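The gap between the AD result and the actual behavior of the step function can be made concrete. In the sketch below, `jvp_step` is a hand-written forward-mode transform (an assumption standing in for what an AD tool would generate), and `central_difference` is an illustrative helper.

```python
def step(x):
    if x > 0:
        return 1.0
    else:
        return 0.0


def jvp_step(x, dx):
    # Hand-written forward-mode transform of step: both branches return
    # constants, so the tangent is zero whichever branch runs.
    if x > 0:
        return 1.0, 0.0
    else:
        return 0.0, 0.0


def central_difference(fn, x, h):
    return (fn(x + h) - fn(x - h)) / (2 * h)


print("AD tangent at 0:", jvp_step(0.0, 1.0)[1])
for h in (1e-1, 1e-3, 1e-6):
    # The quotient straddles the jump and grows like 1 / (2 * h).
    print("h =", h, "difference quotient:", central_difference(step, 0.0, h))
```

The tangent is zero everywhere, while the difference quotient across the boundary grows without bound as the step shrinks. Neither number is a derivative at 0; the zero simply reflects that the executed branch is constant.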
Continuity and Differentiability at Branch Boundaries
At a branch boundary, the function may fall into several cases.
First, the function may be discontinuous:

$$
f(x) = \begin{cases} 1, & x > 0, \\ 0, & x \le 0. \end{cases}
$$

There is no derivative at the boundary.
Second, the function may be continuous but non-differentiable:

$$
f(x) = \begin{cases} x, & x > 0, \\ -x, & x \le 0. \end{cases}
$$

This is $|x|$. At $x = 0$, the left derivative is $-1$, and the right derivative is $1$. The derivative does not exist.
Third, the function may be differentiable even though it is written with a conditional:

$$
f(x) = \begin{cases} x^2, & x > 0, \\ x^2, & x \le 0. \end{cases}
$$

The branches agree in both value and derivative at the boundary. The conditional has no mathematical effect, although it may still matter operationally.
Fourth, the function may be continuous and once differentiable, but fail to have higher derivatives. For example,

$$
f(x) = \begin{cases} x^2, & x > 0, \\ 0, & x \le 0. \end{cases}
$$

At $x = 0$, the first derivative exists and equals $0$. The second derivative has a jump across the boundary.
When using AD, the programmer must distinguish the derivative of the executed branch from the derivative of the whole piecewise function. AD systems generally do not prove that branch boundaries are harmless.
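Since AD will not classify a boundary, a numerical probe is sometimes used as a sanity check. The sketch below is a heuristic under stated assumptions: the four lambdas are illustrative stand-ins for the four cases (they are chosen here, not taken from any system), the helper names are hypothetical, and the tolerances are arbitrary. It compares one-sided values and one-sided slopes, so it stops at first order and cannot see a jump in the second derivative.

```python
def one_sided(fn, x0, h=1e-6):
    # Values and difference-quotient slopes just left and right of x0.
    left_val, right_val = fn(x0 - h), fn(x0 + h)
    left_slope = (fn(x0 - h) - fn(x0 - 2 * h)) / h
    right_slope = (fn(x0 + 2 * h) - fn(x0 + h)) / h
    return left_val, right_val, left_slope, right_slope


def classify(fn, x0=0.0, tol=1e-4):
    left_val, right_val, left_slope, right_slope = one_sided(fn, x0)
    if abs(left_val - right_val) > tol:
        return "discontinuous"
    if abs(left_slope - right_slope) > tol:
        return "continuous, not differentiable"
    return "first derivative matches"


# Illustrative examples of the four cases (assumed for this sketch).
examples = {
    "step": lambda x: 1.0 if x > 0 else 0.0,
    "abs": lambda x: x if x > 0 else -x,
    "same": lambda x: x * x if x > 0 else x * x,
    "half_square": lambda x: x * x if x > 0 else 0.0,
}

for name, fn in examples.items():
    print(name, "->", classify(fn))
```

The last two examples receive the same label, which shows the limit of a first-order probe: distinguishing them requires comparing second-order behavior on each side.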
Structured Conditionals in Array Programs
In tensor programming, a conditional often appears in elementwise form:
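    y = where(x > 0, x * x, 3 * x)

This differs from an ordinary scalar if. A scalar if chooses one branch for the whole program. An elementwise where chooses per element.
Mathematically,

$$
y_i = \begin{cases} x_i^2, & x_i > 0, \\ 3x_i, & x_i \le 0. \end{cases}
$$

The derivative is also elementwise away from boundaries:

$$
\frac{\partial y_i}{\partial x_i} = \begin{cases} 2x_i, & x_i > 0, \\ 3, & x_i < 0. \end{cases}
$$
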
The Jacobian is diagonal when each output element depends only on the corresponding input element.
There is a systems-level detail here. Some array systems evaluate both branches of where before selecting values. Others compile it into predicated execution. This matters when one branch contains invalid operations.
For example:
    y = where(x > 0, sqrt(x), 0)

If sqrt(x) is evaluated for negative elements before masking, the program may produce NaNs internally. These NaNs can contaminate gradients even if the selected primal output looks valid. Robust AD code often rewrites such expressions to avoid undefined inactive computations.
A safer form is:
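    safe_x = where(x > 0, x, 1)  # placeholder 1 keeps sqrt and its derivative finite
    y = where(x > 0, sqrt(safe_x), 0)

This ensures that sqrt receives a valid input for every element. The placeholder is 1 rather than 0 because the derivative of sqrt is unbounded at 0, and multiplying an infinite derivative by a zero mask still yields NaN.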
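The contamination mechanism is plain floating-point arithmetic: a zero mask times an infinite or NaN derivative is NaN. The sketch below reproduces it for a single element in pure Python. The function names are hypothetical, `ieee_sqrt_grad` imitates the IEEE behavior of array libraries (Python's `math.sqrt` would raise instead), and the placeholder value 1.0 is an assumption chosen because sqrt has a finite derivative there.

```python
import math


def ieee_sqrt_grad(v):
    # d/dv sqrt(v) = 1 / (2 * sqrt(v)), with array-library conventions:
    # infinite at 0 and NaN for negative inputs.
    if v > 0:
        return 0.5 / math.sqrt(v)
    return math.inf if v == 0 else math.nan


def grad_naive(v):
    # Gradient of where(v > 0, sqrt(v), 0) when the sqrt branch is
    # differentiated for every element and masked afterwards.
    mask = 1.0 if v > 0 else 0.0
    return mask * ieee_sqrt_grad(v)


def grad_double_where(v, placeholder=1.0):
    # Gradient of where(v > 0, sqrt(safe_v), 0) with
    # safe_v = where(v > 0, v, placeholder). The mask appears twice:
    # once for the outer where, once for the inner one.
    mask = 1.0 if v > 0 else 0.0
    safe_v = v if v > 0 else placeholder
    return mask * ieee_sqrt_grad(safe_v) * mask


for v in (4.0, 0.0, -1.0):
    print(v, grad_naive(v), grad_double_where(v))
```

Passing `placeholder=0.0` to `grad_double_where` brings the NaN back for masked elements, because the derivative of sqrt at 0 is infinite. The placeholder has to be a point where the inactive computation and its derivative are both finite.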
Control Flow in Tracing Systems
Some AD systems trace a program by observing one execution. This works well for straight-line code, but conditionals introduce specialization.
Suppose a tracer observes:
    if x > 0:
        y = x * x
    else:
        y = 3 * x

with an input $x > 0$. The trace may contain only the true branch. If the traced program is reused for an input $x \le 0$, it may compute the wrong result unless the tracing system represents the conditional explicitly.
There are two common designs.
The first design is trace-by-execution. The system records only the operations that ran. This is simple and flexible for dynamic programs, but the resulting trace is valid only for the observed control path.
The second design is structured control flow. The system represents conditionals as explicit IR nodes, such as:
    cond(predicate, true_branch, false_branch)

Both branches are available to the compiler, and the predicate is evaluated at runtime. This supports compilation across different branch choices, but imposes stronger restrictions on branch shapes, types, and side effects.
Structured conditionals are common in staged and compiled AD systems because they allow optimization, batching, and device execution.
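The difference between the two designs can be seen in a toy tracer. Everything in the sketch below is hypothetical: the `Tracer` class, the two-instruction trace format, and the `cond` helper are simplifications written for this example, and the tracer assumes that a tracer-by-tracer product always means "multiply by the input", which holds only for this function.

```python
class Tracer:
    """Wraps a value and appends each multiplication to a shared trace."""

    def __init__(self, value, trace):
        self.value = value
        self.trace = trace

    def __mul__(self, other):
        if isinstance(other, Tracer):
            self.trace.append(("mul_input", None))
            return Tracer(self.value * other.value, self.trace)
        self.trace.append(("mul_const", other))
        return Tracer(self.value * other, self.trace)

    __rmul__ = __mul__

    def __gt__(self, other):
        # The comparison yields a plain bool and never enters the trace.
        return self.value > other


def f(x):
    if x > 0:
        return x * x
    else:
        return 3 * x


def trace(fn, example_input):
    ops = []
    fn(Tracer(example_input, ops))
    return ops


def replay(ops, x):
    value = x
    for name, const in ops:
        value = value * x if name == "mul_input" else value * const
    return value


def cond(predicate, true_branch, false_branch, x):
    # Structured alternative: both branches stay available at run time.
    return true_branch(x) if predicate(x) else false_branch(x)


ops = trace(f, 2.0)  # records only the true branch
print("trace:", ops)
print("replay at -1.0:", replay(ops, -1.0), "but f(-1.0) =", f(-1.0))
print("cond at -1.0:", cond(lambda x: x > 0, lambda x: x * x, lambda x: 3 * x, -1.0))
```

The replayed trace silently applies the true branch to a negative input, while the structured form re-evaluates the predicate. Real tracing systems typically either retrace when the branch changes or refuse to trace a data-dependent Python `if` at all.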
Type and Shape Constraints
A conditional expression must have a well-defined output type. In ordinary dynamically typed languages, branches may return different shapes or types:

    if flag:
        return scalar
    else:
        return vector

This is difficult for compiled AD. The differentiated program needs a stable representation for both primal and derivative values. For this reason, many AD systems require both branches of a structured conditional to return the same pytree structure, tensor rank, dtype, and sometimes static shape.
For example:
    cond(flag, branch_a, branch_b, x)

usually requires:

    branch_a(x).shape == branch_b(x).shape
    branch_a(x).dtype == branch_b(x).dtype

The derivative outputs must also match. If one branch returns a differentiable floating-point tensor and another branch returns an integer or Boolean, the gradient structure may be undefined or partially zero.
This constraint is not merely an implementation inconvenience. A function with branch-dependent output type does not denote a single smooth map between fixed vector spaces. Classical derivatives require a stable domain and codomain.
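The flavor of this check can be imitated in plain Python. The sketch below is an assumption-laden stand-in: `signature` and `checked_cond` are hypothetical helpers, the structural summary covers only Python scalars, lists, and tuples, and both branches are actually executed to compare their outputs, whereas compiled systems usually compare abstract shapes and dtypes without running either branch.

```python
def signature(value):
    # Hypothetical structural summary: the Python type name, plus length
    # and element signatures for lists and tuples.
    if isinstance(value, (list, tuple)):
        return (type(value).__name__, len(value),
                tuple(signature(v) for v in value))
    return (type(value).__name__,)


def checked_cond(flag, branch_a, branch_b, x):
    # Evaluate both branches so a mismatch is detected regardless of flag.
    out_a, out_b = branch_a(x), branch_b(x)
    if signature(out_a) != signature(out_b):
        raise TypeError(
            f"branch outputs differ: {signature(out_a)} vs {signature(out_b)}"
        )
    return out_a if flag else out_b


print(checked_cond(True, lambda x: x * 2.0, lambda x: x + 1.0, 3.0))

try:
    # Scalar in one branch, two-element list in the other.
    checked_cond(True, lambda x: x, lambda x: [x, x], 3.0)
except TypeError as err:
    print("rejected:", err)
```

Rejecting the mismatch up front is what lets the differentiated program allocate one tangent or adjoint layout that is valid for either branch.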
Differentiable Alternatives to Hard Branches
Some models use conditionals as hard decisions:
    if score > threshold:
        y = expert_a(x)
    else:
        y = expert_b(x)

The branch decision is non-differentiable with respect to score. If the goal is gradient-based optimization over the decision itself, a hard conditional usually blocks useful gradient flow.
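A common alternative is a soft mixture:

$$
y = \alpha\, y_a + (1 - \alpha)\, y_b,
$$

where

$$
\alpha = \sigma(\beta s) = \frac{1}{1 + e^{-\beta s}},
$$

$y_a$ and $y_b$ are the outputs of the two branches, and $s$ is a score. As $\beta$ grows, the sigmoid becomes sharper. For finite $\beta$, the function remains smooth.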
This changes the model. It no longer executes exactly one branch. It blends branches. The benefit is that gradients can flow into the gating score. The cost is extra computation and a different inductive bias.
Other alternatives include softmax routing, straight-through estimators, stochastic relaxations, and subgradient conventions. These are modeling choices, not automatic consequences of AD.
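A sigmoid-gated mixture can be sketched for scalar branch outputs as follows. The names `soft_select`, `soft_select_grad_score`, and the sharpness parameter `beta` are assumptions made for this example; the score is taken to be measured relative to the threshold, and the gradient formula is derived by hand from the sigmoid derivative rather than produced by an AD tool.

```python
import math


def sigmoid(z):
    # Adequate for the moderate |z| used here; very negative z would
    # overflow math.exp and needs a numerically stable variant.
    return 1.0 / (1.0 + math.exp(-z))


def soft_select(score, a, b, beta):
    # Blend of the two branch outputs a and b; gate -> 1 selects a.
    gate = sigmoid(beta * score)
    return gate * a + (1.0 - gate) * b


def soft_select_grad_score(score, a, b, beta):
    # d/dscore of soft_select, using sigmoid'(z) = sigmoid(z) * (1 - sigmoid(z)).
    gate = sigmoid(beta * score)
    return beta * gate * (1.0 - gate) * (a - b)


for beta in (1.0, 10.0, 50.0):
    # With a = 1 and b = 0 the output is the gate itself, which
    # approaches the hard decision as beta grows.
    print(beta, soft_select(0.5, 1.0, 0.0, beta),
          soft_select_grad_score(0.5, 1.0, 0.0, beta))
```

The gradient with respect to the score is nonzero whenever the two branch outputs differ, which is what the hard conditional lacks. As `beta` grows, that gradient concentrates near the threshold and fades elsewhere, so sharpness trades off against how far the learning signal reaches.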
Correctness Rule
For a conditional program
    if c(x):
        y = f(x)
    else:
        y = g(x)

AD computes the derivative of the active branch at the current input, provided the active branch is differentiable.
This derivative equals the derivative of the whole program only when the branch choice is locally constant around the input, or when the branches agree sufficiently at the boundary.
The practical rule is:
    Away from branch boundaries:
        AD gives the ordinary derivative.

    At branch boundaries:
        AD gives the derivative of the selected branch.
        The mathematical derivative may not exist.

    For hard decisions:
        gradients do not usually flow through the decision itself.

Conditionals therefore do not break automatic differentiation. They make the differentiated function piecewise. The AD system follows the same control path as the primal execution, and the user remains responsible for the mathematical meaning of derivatives at branch boundaries.