Differentiable Subprograms

A differentiable subprogram is a program fragment that can participate in derivative propagation as a coherent unit. Instead of differentiating an entire application monolithically, AD systems decompose computation into smaller callable pieces with well-defined derivative behavior.

A simple example is:

def square(x):
    return x * x

Mathematically:

f(x) = x^2.

Its derivative is:

f'(x) = 2x.

An AD system can either inline the body into a larger graph or treat the function as a reusable differentiable component.

Functions as Differentiable Maps

A differentiable subprogram behaves like a map:

f : X \to Y.

Automatic differentiation constructs associated derivative maps.

Forward mode constructs:

Df : (x, \dot x) \mapsto (f(x), \dot y), \qquad \dot y = f'(x)\,\dot x.

Reverse mode constructs a backward map:

B_f : (x, \bar y) \mapsto \bar x.

The function boundary becomes part of the differentiation structure.
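
For the square function above, both maps can be written out by hand. A minimal sketch; the names jvp_square and vjp_square are illustrative, not part of any particular framework:

def jvp_square(x, dx):
    # forward map Df: carry a tangent alongside the primal value
    y = x * x
    dy = 2 * x * dx        # f'(x) * dx
    return y, dy

def vjp_square(x, y_bar):
    # backward map B_f: pull the output adjoint back to the input
    return 2 * x * y_bar   # f'(x) * y_bar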

Encapsulation

Subprograms encapsulate local computation.

Example:

def layer(x, w, b):
    return relu(matmul(x, w) + b)

A larger model may call this repeatedly:

h1 = layer(x, w1, b1)
h2 = layer(h1, w2, b2)
y  = layer(h2, w3, b3)

The AD system can:

Strategy                    Meaning
Inline                      expand the body at each call
Reuse transformed version   cache the derivative transform
Treat as primitive          use a custom derivative rule

Encapsulation allows modular differentiation.
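
As a sketch of the "reuse" strategy, the forward-mode transform of layer can be derived once and cached. This illustrative NumPy version assumes relu is max(z, 0):

import numpy as np

def jvp_layer(x, w, b, dx, dw, db):
    z = x @ w + b
    dz = dx @ w + x @ dw + db        # product and sum rules
    y = np.maximum(z, 0)             # relu
    dy = np.where(z > 0, dz, 0.0)    # relu passes tangents only where z > 0
    return y, dy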

Call Graphs

Programs with functions form a call graph.

Example:

main
 ├─> encoder
 │    ├─> attention
 │    └─> mlp
 └─> decoder
      ├─> attention
      └─> mlp

AD propagates derivatives through the same call structure.

Forward mode follows call direction. Reverse mode propagates adjoints back through return dependencies.

A reverse-mode engine must know:

Item                      Purpose
Inputs to the function    route backward derivatives
Outputs of the function   receive output adjoints
Saved intermediates       feed local backward rules
Call ordering             drive reverse traversal
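
A tape-based engine records exactly these items for each call. A minimal sketch of such a record (a hypothetical structure, not any specific library's):

from dataclasses import dataclass
from typing import Any, Callable

@dataclass
class TapeEntry:
    inputs: tuple        # input variable ids, to route backward derivatives
    outputs: tuple       # output variable ids, where adjoints accumulate
    saved: Any           # intermediates needed by the local backward rule
    backward: Callable   # maps output adjoints to input adjoints

tape: list = []          # append during forward, walk in reverse during backward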

Function Composition

Subprograms compose naturally.

If:

y = g(f(x)),

then:

\frac{dy}{dx} = \frac{dy}{df}\,\frac{df}{dx}.

In code:

u = f(x)
y = g(u)

The dependency graph becomes:

x -> f -> u -> g -> y

AD applies the chain rule across function boundaries exactly as across primitive operations.
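
In forward mode this is literal composition of the two transformed functions. A sketch with f = sin and g = square for concreteness:

import math

def jvp_f(x, dx):           # f(x) = sin(x)
    return math.sin(x), math.cos(x) * dx

def jvp_g(u, du):           # g(u) = u^2
    return u * u, 2 * u * du

def jvp_g_of_f(x, dx):
    u, du = jvp_f(x, dx)    # the tangent crosses the function boundary
    return jvp_g(u, du)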

Local Derivative Contracts

A differentiable subprogram exposes a local derivative contract.

For forward mode:

(primal_in, tangent_in)
    ->
(primal_out, tangent_out)

For reverse mode:

(primal_in, primal_out, output_adj)
    ->
input_adj

The outer system does not need to know the internal implementation if the contract is correct.
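
For example, the layer subprogram from earlier can satisfy the reverse-mode contract as follows (a NumPy sketch; it recomputes z here, though a real system would usually save it):

import numpy as np

def vjp_layer(x, w, b, y_bar):
    z = x @ w + b
    dz = np.where(z > 0, y_bar, 0.0)   # relu backward
    dx = dz @ w.T                      # matmul backward w.r.t. x
    dw = x.T @ dz                      # matmul backward w.r.t. w
    db = dz.sum(axis=0)                # bias gradient, summed over the batch
    return dx, dw, db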

This abstraction enables custom derivative rules.

Primitive Operations

Some subprograms are treated as primitives.

Example:

y = sin(x)

The AD system does not expand sine into the numerical approximation code that implements it. Instead, it applies the known derivative rule:

\frac{d}{dx}\sin(x) = \cos(x).

Similarly for:

Primitive      Derivative
exp(x)         exp(x)
log(x)         1/x
matmul(a, b)   matrix rules
conv(x, w)     convolution rules

Primitive differentiation hides implementation complexity.
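
Internally, many engines amount to a table from primitive to derivative rule. A scalar-only sketch (the dictionary layout is illustrative):

import math

VJP_RULES = {
    "sin": lambda x, y_bar: math.cos(x) * y_bar,
    "exp": lambda x, y_bar: math.exp(x) * y_bar,
    "log": lambda x, y_bar: y_bar / x,
}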

User-Defined Functions

User-defined functions can usually be differentiated automatically.

Example:

def f(x):
    a = x * x
    b = sin(a)
    return b + 1

The system builds a dependency graph for the body and differentiates it mechanically.

The forward-mode transformed version conceptually becomes:

def df(x, dx):
    a  = x * x
    da = 2 * x * dx

    b  = sin(a)
    db = cos(a) * da

    y  = b + 1
    dy = db

    return y, dy

This transformation is systematic.
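
For example, df(2.0, 1.0) returns (sin(4.0) + 1, 4.0 * cos(4.0)), which matches the analytic derivative cos(x^2) * 2x at x = 2.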

Closures

A closure captures external variables.

def make_scale(a):
    def scale(x):
        return a * x
    return scale

The inner function depends on both x and the captured variable a.

Mathematically:

f(x; a) = a x.

The derivatives are:

\frac{\partial f}{\partial x} = a, \qquad \frac{\partial f}{\partial a} = x.

AD systems must track captured values as dependencies.
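
A forward-mode transform must therefore thread a tangent for the captured a as well as for x. A sketch:

def jvp_make_scale(a, da):
    def jvp_scale(x, dx):
        # product rule: the tangent mixes contributions from x and the captured a
        return a * x, a * dx + x * da
    return jvp_scale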

Recursive Functions

Recursive subprograms define themselves in terms of earlier calls.

Example:

def f(x, n):
    if n == 0:
        return x
    return sin(f(x, n - 1))

For fixed n, the call tree expands into a finite dependency graph.

AD differentiates the expanded execution trace.

Reverse mode must preserve:

Item               Reason
Call stack         reverse traversal
Local variables    local derivatives
Return structure   adjoint propagation

Recursive AD therefore resembles stack replay.
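
Concretely, a forward-mode transform of the recursive f mirrors its call structure (sketch):

import math

def jvp_f(x, dx, n):
    if n == 0:
        return x, dx
    y, dy = jvp_f(x, dx, n - 1)       # differentiate the inner call first
    return math.sin(y), math.cos(y) * dy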

Higher-Order Functions

A higher-order function takes functions as inputs or outputs.

Example:

def apply_twice(f, x):
    return f(f(x))

If f is differentiable:

y = f(f(x)).

Then:

\frac{dy}{dx} = f'(f(x))\, f'(x).

AD systems for functional languages often treat derivative transforms themselves as higher-order functions.
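
In that spirit, the forward transform of apply_twice simply takes the transform of f as an argument. A sketch, reusing jvp_square from earlier:

def jvp_apply_twice(jvp_f, x, dx):
    u, du = jvp_f(x, dx)    # first application of f, with its tangent
    return jvp_f(u, du)     # second application; the chain rule composes

For f(x) = x^2 this yields d/dx (x^2)^2 = 4x^3, as expected.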

Differentiation as Program Transformation

A differentiable subprogram may be transformed into a new subprogram.

Original:

f : X -> Y

Forward-mode transform:

Df : (X × TX) -> (Y × TY)

Reverse-mode transform:

Rf : X -> (Y, Y* -> X*)

where TX and TY are the tangent spaces of X and Y, and X^* and Y^* are the corresponding cotangent spaces.

Thus AD itself becomes a compiler transform on callable units.
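
A sketch of the reverse-style transform for a single primitive: the transformed function returns the primal result together with a pullback closure (the same shape JAX's vjp uses):

import math

def rf_sin(x):
    y = math.sin(x)
    def pullback(y_bar):           # Y* -> X*
        return math.cos(x) * y_bar
    return y, pullback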

Custom Gradient Rules

Sometimes the default derivative is inefficient or numerically unstable.

A user may define a custom backward rule.

Example:

def stable_logsumexp(x):
    ...

The backward rule can be supplied directly:

def backward(output_adj):
    ...

Advantages include:

Benefit                    Example
Better stability           log-sum-exp
Lower memory               fused backward
Faster execution           custom kernels
Implicit differentiation   solvers
Approximate gradients      quantization

The subprogram becomes a primitive with user-defined differentiation semantics.
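
One way the elided pieces could look, as a NumPy sketch: saving the output y makes the backward rule both cheap and numerically stable, since exp(x - y) is exactly the softmax of x:

import numpy as np

def stable_logsumexp(x):
    m = x.max()
    return m + np.log(np.sum(np.exp(x - m)))

def logsumexp_backward(x, y, y_bar):
    # d logsumexp / dx_i = softmax(x)_i = exp(x_i - y)
    return np.exp(x - y) * y_bar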

Opaque External Functions

Some subprograms call external libraries.

y = cuda_kernel(x)

The AD system may not know the internal implementation.

Possible strategies:

Strategy                      Meaning
Treat as non-differentiable   stop gradient
Provide custom rule           manual backward
Trace internal ops            if supported
Use finite differences        fallback approximation

Large systems often rely heavily on custom derivative rules for external kernels.
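
The finite-difference fallback can be as simple as a central difference around the opaque call. A scalar sketch (accuracy and cost degrade quickly with input dimension):

def fd_vjp(f, x, y_bar, eps=1e-6):
    # approximate f'(x) without access to the kernel internals
    dydx = (f(x + eps) - f(x - eps)) / (2 * eps)
    return dydx * y_bar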

Nested Differentiation

A differentiable subprogram may itself invoke AD.

Example:

def gradient_norm(f, x):
    g = grad(f)(x)
    return dot(g, g)

Now AD is applied to a program that already performs differentiation.

This creates nested derivative structures:

Outer level                   Inner level
differentiate gradient_norm   differentiate f

Nested AD requires careful management of tangent and adjoint scopes.
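
Differentiating gradient_norm itself produces a Hessian-vector structure:

\nabla \|\nabla f(x)\|^2 = 2\, \nabla^2 f(x)\, \nabla f(x),

so the outer pass must differentiate through the inner gradient computation.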

Perturbation Confusion

Nested differentiation can accidentally mix derivative levels.

Example:

grad(lambda x: grad(f)(x))(x)

The inner and outer derivative computations must remain distinct.

Correct systems isolate derivative contexts so that:

Level           Meaning
Inner tangent   derivative of f
Outer tangent   derivative of the derivative

Without isolation, perturbations may interfere and produce incorrect higher-order derivatives.
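
A classic illustration, in the spirit of Siskind and Pearlmutter's example, written here with JAX's grad, which keeps the two contexts distinct:

from jax import grad

# d/dx [ x * d/dy (x + y) ] at x = 1:
# the inner derivative d/dy (x + y) is 1, so the expression reduces to x
# and the correct outer derivative is 1.0.
value = grad(lambda x: x * grad(lambda y: x + y)(1.0))(1.0)

# A confused implementation lets the perturbation on x leak into the
# inner derivative and returns 2.0 instead.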

Function Boundaries and Optimization

Subprogram boundaries influence optimization.

Inlining may expose more fusion opportunities:

y = relu(matmul(x, w) + b)

Modular boundaries may improve reuse and compilation caching.

Compiler-based systems often balance:

Goal                   Preference
Maximum optimization   aggressive inlining
Fast compilation       preserve modularity
Reusability            cached differentiated kernels
Lower memory           fused backward passes

Differentiable subprograms are therefore both semantic and optimization units.

Interface Design

A minimal differentiable interface may look like:

type Function interface {
    Forward(x Value) Value
    Backward(yAdj Value) Value
}

More realistic systems require:

Requirement           Reason
Multiple inputs       tensor programs
Multiple outputs      structured models
Saved intermediates   reverse mode
Device awareness      GPU execution
Shape metadata        tensor validation
Batched evaluation    vectorization

Production AD systems therefore build sophisticated callable abstractions around these basic ideas.

Core Idea

A differentiable subprogram is a callable computation unit with well-defined derivative behavior. Automatic differentiation propagates through function boundaries exactly as through primitive operations: by composing local derivative rules according to the dependency structure.

Subprograms provide modularity, reuse, abstraction, and optimization boundaries. They also enable custom gradients, higher-order differentiation, and compiler-level differentiation transforms.