A differentiable subprogram is a program fragment that can participate in derivative propagation as a coherent unit. Instead of differentiating an entire application monolithically, AD systems decompose computation into smaller callable pieces with well-defined derivative behavior.
A simple example is:
```python
def square(x):
    return x * x
```

Mathematically:

$$f(x) = x^2$$

Its derivative is:

$$f'(x) = 2x$$
An AD system can either inline the body into a larger graph or treat the function as a reusable differentiable component.
Functions as Differentiable Maps
A differentiable subprogram behaves like a map:

$$f : \mathbb{R}^n \to \mathbb{R}^m$$

Automatic differentiation constructs associated derivative maps.

Forward mode constructs a tangent map (a Jacobian-vector product):

$$(x, \dot{x}) \mapsto (f(x),\ J_f(x)\,\dot{x})$$

Reverse mode constructs a backward map (a vector-Jacobian product):

$$(x, \bar{y}) \mapsto J_f(x)^{\top}\,\bar{y}$$
The function boundary becomes part of the differentiation structure.
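To make the two maps concrete, here is a minimal hand-written sketch (plain Python; the names square_jvp and square_vjp are invented for this illustration):

```python
def square(x):
    return x * x

# Forward mode: carry a tangent alongside the primal value.
def square_jvp(x, dx):
    return x * x, 2 * x * dx      # (primal_out, tangent_out)

# Reverse mode: map an output adjoint back to an input adjoint.
def square_vjp(x, y_adj):
    return 2 * x * y_adj

y, dy = square_jvp(3.0, 1.0)      # (9.0, 6.0)
x_adj = square_vjp(3.0, 1.0)      # 6.0
```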
Encapsulation
Subprograms encapsulate local computation.
Example:
```python
def layer(x, w, b):
    return relu(matmul(x, w) + b)
```

A larger model may call this repeatedly:

```python
h1 = layer(x, w1, b1)
h2 = layer(h1, w2, b2)
y = layer(h2, w3, b3)
```

The AD system can:
| Strategy | Meaning |
|---|---|
| Inline | expand the body each call |
| Reuse transformed version | cache derivative transform |
| Treat as primitive | use custom derivative rule |
Encapsulation allows modular differentiation.
Call Graphs
Programs with functions form a call graph.
Example:
```
main
├─> encoder
│   ├─> attention
│   └─> mlp
└─> decoder
    ├─> attention
    └─> mlp
```

AD propagates derivatives through the same call structure.
Forward mode follows call direction. Reverse mode propagates adjoints back through return dependencies.
A reverse-mode engine must know:
| Item | Purpose |
|---|---|
| Inputs to each call | compute backward derivatives |
| Outputs from each call | receive output adjoints |
| Saved intermediates | evaluate local backward rules |
| Call ordering | drive reverse traversal |
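A minimal tape sketch (plain Python, invented for illustration) shows why these four items suffice: record each call with its saved values, inputs, and output, then walk the tape backwards:

```python
import math

# A global tape of (saved_values, backward_rule, input_names, output_name).
tape = []

def record(saved, backward_rule, input_names, output_name):
    tape.append((saved, backward_rule, input_names, output_name))

# Example primal with recording: y = sin(x * x)
def traced_f(x):
    a = x * x
    record((x,), lambda saved, g: (2 * saved[0] * g,), ["x"], "a")
    y = math.sin(a)
    record((a,), lambda saved, g: (math.cos(saved[0]) * g,), ["a"], "y")
    return y

def backward(output_name):
    adjoints = {output_name: 1.0}
    for saved, rule, input_names, out_name in reversed(tape):  # reverse call order
        g = adjoints.get(out_name, 0.0)
        for name, contrib in zip(input_names, rule(saved, g)):
            adjoints[name] = adjoints.get(name, 0.0) + contrib
    return adjoints

y = traced_f(1.5)
grads = backward("y")   # grads["x"] == 2 * 1.5 * math.cos(1.5 ** 2)
```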
Function Composition
Subprograms compose naturally.
If:

$$u = f(x), \qquad y = g(u)$$

then:

$$\frac{dy}{dx} = \frac{dy}{du} \cdot \frac{du}{dx}$$

In code:

```python
u = f(x)
y = g(u)
```

The dependency graph becomes:

```
x -> f -> u -> g -> y
```

AD applies the chain rule across function boundaries exactly as across primitive operations.
Local Derivative Contracts
A differentiable subprogram exposes a local derivative contract.
For forward mode:

```
(primal_in, tangent_in) -> (primal_out, tangent_out)
```

For reverse mode:

```
(primal_in, primal_out, output_adj) -> input_adj
```

The outer system does not need to know the internal implementation if the contract is correct.
This abstraction enables custom derivative rules.
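A minimal sketch (plain Python) of the forward-mode contract, composed across two subprograms that satisfy it:

```python
import math

# Each subprogram satisfies the same contract:
# (primal_in, tangent_in) -> (primal_out, tangent_out)
def square_fwd(x, dx):
    return x * x, 2 * x * dx

def sin_fwd(x, dx):
    return math.sin(x), math.cos(x) * dx

# Composition only needs the contract, not the internals.
def composed_fwd(x, dx):
    u, du = square_fwd(x, dx)
    return sin_fwd(u, du)

y, dy = composed_fwd(1.0, 1.0)   # d/dx sin(x^2) = cos(x^2) * 2x
```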
Primitive Operations
Some subprograms are treated as primitives.
Example:
```python
y = sin(x)
```

The AD system does not expand the implementation of sine into its numerical approximation code. Instead, it uses the known derivative rule:

$$\frac{dy}{dx} = \cos(x)$$
Similarly for:
| Primitive | Derivative |
|---|---|
| exp(x) | exp(x) |
| log(x) | 1/x |
| matmul(a,b) | matrix product rules |
| conv(x,w) | convolution rules |
Primitive differentiation hides implementation complexity.
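One common realization (a sketch, not any particular library's API) is a table mapping primitive names to their local derivative rules:

```python
import math

# Each rule maps the primal input to the local derivative factor.
PRIMITIVE_DERIVATIVES = {
    "sin": lambda x: math.cos(x),
    "exp": lambda x: math.exp(x),
    "log": lambda x: 1.0 / x,
}

def local_derivative(op, x):
    return PRIMITIVE_DERIVATIVES[op](x)
```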
User-Defined Functions
User-defined functions can usually be differentiated automatically.
Example:
```python
from math import sin

def f(x):
    a = x * x
    b = sin(a)
    return b + 1
```

The system builds a dependency graph for the body and differentiates it mechanically.
Forward-mode transformed version conceptually becomes:
```python
from math import sin, cos

def df(x, dx):
    a = x * x
    da = 2 * x * dx
    b = sin(a)
    db = cos(a) * da
    y = b + 1
    dy = db
    return y, dy
```

This transformation is systematic.
Closures
A closure captures external variables.
```python
def make_scale(a):
    def scale(x):
        return a * x
    return scale
```

The inner function depends on both x and the captured variable a.
Mathematically:

$$\text{scale}(x) = a \cdot x$$

The derivatives are:

$$\frac{\partial (a x)}{\partial x} = a, \qquad \frac{\partial (a x)}{\partial a} = x$$
AD systems must track captured values as dependencies.
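As one concrete illustration (JAX shown here; jax.grad differentiates with respect to explicit arguments, so the captured value is exposed as an argument when its gradient is needed):

```python
import jax

def make_scale(a):
    def scale(x):
        return a * x
    return scale

scale2 = make_scale(2.0)
jax.grad(scale2)(3.0)                          # d/dx (2 * x) -> 2.0
jax.grad(lambda a: make_scale(a)(3.0))(2.0)    # d/da (a * 3) -> 3.0
```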
Recursive Functions
Recursive subprograms define themselves in terms of earlier calls.
Example:
```python
from math import sin

def f(x, n):
    if n == 0:
        return x
    return sin(f(x, n - 1))
```

For fixed n, the call tree expands into a finite dependency graph.
AD differentiates the expanded execution trace.
Reverse mode must preserve:
| Item | Reason |
|---|---|
| Call stack | reverse traversal |
| Local variables | local derivatives |
| Return structure | adjoint propagation |
Recursive AD therefore resembles stack replay.
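For illustration, under a tracing system such as JAX, a fixed recursion depth simply unrolls into a finite graph (a sketch, using jax.numpy for sin):

```python
import jax
import jax.numpy as jnp

def f(x, n):
    if n == 0:
        return x
    return jnp.sin(f(x, n - 1))

# n is an ordinary Python int, so tracing unrolls the recursion.
df = jax.grad(lambda x: f(x, 3))
df(0.5)   # derivative of sin(sin(sin(x))) at 0.5
```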
Higher-Order Functions
A higher-order function takes functions as inputs or outputs.
Example:
```python
def apply_twice(f, x):
    return f(f(x))
```

If f is differentiable, then by the chain rule:

$$\frac{d}{dx}\, f(f(x)) = f'(f(x)) \cdot f'(x)$$
AD systems for functional languages often treat derivative transforms themselves as higher-order functions.
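A sketch in JAX, where grad is itself a higher-order function: it consumes a function and returns its derivative function:

```python
import jax
import jax.numpy as jnp

def apply_twice(f, x):
    return f(f(x))

d = jax.grad(lambda x: apply_twice(jnp.sin, x))
d(0.7)   # cos(sin(0.7)) * cos(0.7), matching the chain rule above
```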
Differentiation as Program Transformation
A differentiable subprogram may be transformed into a new subprogram.
Original:
```
f : X -> Y
```

Forward-mode transform:

```
Df : (X × TX) -> (Y × TY)
```

Reverse-mode transform:

```
Rf : X -> (Y, Y* -> X*)
```

where TX and TY are tangent spaces, and X* and Y* are cotangent spaces.
Thus AD itself becomes a compiler transform on callable units.
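A minimal sketch of the reverse-mode transform as a plain Python function: it returns the primal output together with a backward map from output cotangents to input cotangents:

```python
import math

# Rf : X -> (Y, Y* -> X*) for f(x) = sin(x)
def rsin(x):
    y = math.sin(x)
    def pullback(y_adj):
        return math.cos(x) * y_adj   # Y* -> X*
    return y, pullback

y, pullback = rsin(1.0)
x_adj = pullback(1.0)   # cos(1.0)
```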
Custom Gradient Rules
Sometimes the default derivative is inefficient or numerically unstable.
A user may define a custom backward rule.
Example:
```python
def stable_logsumexp(x):
    ...
```

The backward rule can be supplied directly:

```python
def backward(output_adj):
    ...
```

Advantages include:
| Benefit | Example |
|---|---|
| Better stability | log-sum-exp |
| Lower memory | fused backward |
| Faster execution | custom kernels |
| Implicit differentiation | solvers |
| Approximate gradients | quantization |
The subprogram becomes a primitive with user-defined differentiation semantics.
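As one concrete realization (JAX's custom_vjp shown here as an example; other frameworks expose analogous hooks), a numerically stable log-sum-exp with a hand-written backward rule might look like:

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def stable_logsumexp(x):
    m = jnp.max(x)
    return m + jnp.log(jnp.sum(jnp.exp(x - m)))

def _fwd(x):
    y = stable_logsumexp(x)
    return y, (x, y)                  # save primal input and output

def _bwd(saved, y_adj):
    x, y = saved
    # d/dx logsumexp(x) = softmax(x) = exp(x - logsumexp(x))
    return (y_adj * jnp.exp(x - y),)

stable_logsumexp.defvjp(_fwd, _bwd)

g = jax.grad(stable_logsumexp)(jnp.array([1.0, 2.0, 3.0]))  # softmax values
```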
Opaque External Functions
Some subprograms call external libraries.
```python
y = cuda_kernel(x)
```

The AD system may not know the internal implementation.
Possible strategies:
| Strategy | Meaning |
|---|---|
| Treat as non-differentiable | stop gradient |
| Provide custom rule | manual backward |
| Trace internal ops | if supported |
| Use finite differences | fallback approximation |
Large systems often rely heavily on custom derivative rules for external kernels.
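A sketch of the finite-difference fallback for a scalar opaque function (a crude approximation, suitable only when no analytic rule exists):

```python
import math

def finite_diff_vjp(f, x, y_adj, eps=1e-6):
    # Central difference approximates f'(x); scale by the output adjoint.
    return y_adj * (f(x + eps) - f(x - eps)) / (2 * eps)

finite_diff_vjp(math.sin, 1.0, 1.0)   # approximately cos(1.0)
```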
Nested Differentiation
A differentiable subprogram may itself invoke AD.
Example:
```python
def gradient_norm(f, x):
    g = grad(f)(x)
    return dot(g, g)
```

Now AD is applied to a program that already performs differentiation.
This creates nested derivative structures:
| Outer level | Inner level |
|---|---|
| differentiate gradient_norm | differentiate f |
Nested AD requires careful management of tangent and adjoint scopes.
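For instance, in JAX the nesting is expressed directly (a sketch; jax.grad manages the nested tangent and adjoint scopes):

```python
import jax
import jax.numpy as jnp

def gradient_norm(f, x):
    g = jax.grad(f)(x)
    return jnp.dot(g, g)

f = lambda x: jnp.sum(jnp.sin(x))
x = jnp.array([0.1, 0.2, 0.3])
jax.grad(lambda x: gradient_norm(f, x))(x)   # differentiates through grad(f)
```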
Perturbation Confusion
Nested differentiation can accidentally mix derivative levels.
Example:
```python
grad(lambda x: grad(f)(x))(x)
```

The inner and outer derivative computations must remain distinct.
Correct systems isolate derivative contexts so that:
| Level | Meaning |
|---|---|
| Inner tangent | derivative of f |
| Outer tangent | derivative of derivative |
Without isolation, perturbations may interfere and produce incorrect higher-order derivatives.
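The classic test case (shown here in JAX, which tags perturbations to keep the levels distinct):

```python
import jax

# Inner derivative: d(x + y)/dy = 1, so the outer function reduces to x,
# whose derivative is 1. A system that confuses perturbation levels
# incorrectly reports 2.
val = jax.grad(lambda x: x * jax.grad(lambda y: x + y)(1.0))(1.0)
# val == 1.0
```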
Function Boundaries and Optimization
Subprogram boundaries influence optimization.
Inlining may expose more fusion opportunities:
```python
y = relu(matmul(x, w) + b)
```

Modular boundaries may improve reuse and compilation caching.
Compiler-based systems often balance:
| Goal | Preference |
|---|---|
| Maximum optimization | aggressive inlining |
| Fast compilation | preserve modularity |
| Reusability | cached differentiated kernels |
| Lower memory | fused backward passes |
Differentiable subprograms are therefore both semantic and optimization units.
Interface Design
A minimal differentiable interface may look like:
```go
type Function interface {
    Forward(x Value) Value
    Backward(yAdj Value) Value
}
```

More realistic systems require:
| Requirement | Reason |
|---|---|
| Multiple inputs | tensor programs |
| Multiple outputs | structured models |
| Saved intermediates | reverse mode |
| Device awareness | GPU execution |
| Shape metadata | tensor validation |
| Batched evaluation | vectorization |
Production AD systems therefore build sophisticated callable abstractions around these basic ideas.
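A slightly richer sketch (hypothetical Python, names invented here) that accommodates multiple inputs and outputs plus saved intermediates:

```python
from dataclasses import dataclass
from typing import Any, Callable, Sequence

@dataclass
class DifferentiableFn:
    # inputs -> (outputs, residuals saved for the backward pass)
    forward: Callable[..., tuple[Sequence[Any], Any]]
    # (residuals, output adjoints) -> input adjoints
    backward: Callable[[Any, Sequence[Any]], Sequence[Any]]

square = DifferentiableFn(
    forward=lambda x: ((x * x,), x),
    backward=lambda x, out_adjs: (2 * x * out_adjs[0],),
)

outs, residual = square.forward(3.0)       # ((9.0,), 3.0)
in_adjs = square.backward(residual, (1.0,))  # (6.0,)
```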
Core Idea
A differentiable subprogram is a callable computation unit with well-defined derivative behavior. Automatic differentiation propagates through function boundaries exactly as through primitive operations: by composing local derivative rules according to the dependency structure.
Subprograms provide modularity, reuse, abstraction, and optimization boundaries. They also enable custom gradients, higher-order differentiation, and compiler-level differentiation transforms.