Differentiable Programming

Differentiable programming treats differentiation as a general programming-language feature. A program can contain numerical kernels, control flow, data structures, solvers, simulations, and model code. If the program denotes a differentiable computation, the system should be able to compute derivatives of that program.

This is broader than neural network training. A differentiable program may include a physics simulator, a renderer, a database query, an optimizer, or a probabilistic model. The central idea is that gradients should be available across the whole program, not only inside a fixed tensor graph.

From Automatic Differentiation to Differentiable Programming

Automatic differentiation began as a technique for computing derivatives of numerical programs. Differentiable programming extends that idea into a design principle for languages and systems.

The core operator is a program transformation:

\mathrm{grad} : (X \rightarrow \mathbb{R}) \rightarrow (X \rightarrow X)

Given a scalar-output function, grad produces another function that returns its gradient.

In code, this appears as:

def loss(params, batch):
    pred = model(params, batch.x)  # any ordinary function of the parameters
    return mse(pred, batch.y)      # scalar objective

g = grad(loss)(params, batch)      # gradient of loss with respect to params

The important point is not the syntax. The important point is that loss can be an ordinary program. It can call functions, branch, loop, allocate intermediate values, and invoke libraries.

Differentiation as a First-Class Transformation

In a differentiable programming system, transformations such as grad, vjp, jvp, jacfwd, and jacrev are first-class operations.

Transformation     Meaning
grad(f)            Gradient of scalar-output function
jvp(f)             Jacobian-vector product
vjp(f)             Vector-Jacobian product
jacfwd(f)          Jacobian using forward mode
jacrev(f)          Jacobian using reverse mode
hessian(f)         Second derivative matrix
value_and_grad(f)  Primal value and gradient together

This style changes how numerical software is written. Instead of manually deriving update rules, the programmer writes the objective and asks the system for derivatives.
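A sketch of that workflow in JAX (the shapes and names here are arbitrary placeholders):

import jax
import jax.numpy as jnp

def loss(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

w, x, y = jnp.zeros(3), jnp.ones((8, 3)), jnp.ones(8)

val, g = jax.value_and_grad(loss)(w, x, y)  # primal value and gradient together
H = jax.hessian(loss)(w, x, y)              # 3x3 second-derivative matrix

No update rule was derived by hand; the objective alone determines the gradient code.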

Programs Beyond Tensor Graphs

Early deep learning systems often represented models as static tensor graphs. A graph contained operations such as matrix multiplication, convolution, addition, and activation functions.

Differentiable programming relaxes this model. The differentiable object is the program itself.

For example:

def f(x):
    y = 0.0
    for i in range(10):
        if x[i] > 0:
            y = y + x[i] * x[i]
        else:
            y = y - x[i]
    return y

This program contains a loop and a branch. A differentiable programming system differentiates the executed computation. The derivative depends on the path taken through the program.

This makes differentiation useful for real numerical code, where fixed graphs are often too restrictive.
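A JAX-flavored sketch, restating the function so the snippet is self-contained: under eager grad, Python control flow runs against concrete values, so the derivative follows the branch actually taken.

import jax
import jax.numpy as jnp

def f(x):
    y = 0.0
    for i in range(10):
        if x[i] > 0:             # resolved with the concrete value of x[i]
            y = y + x[i] * x[i]
        else:
            y = y - x[i]
    return y

g = jax.grad(f)(jnp.linspace(-1.0, 1.0, 10))
# g[i] == 2 * x[i] on the positive branch and -1.0 on the other

Under jit the same branch would need structured control flow such as lax.cond, because the predicate is no longer concrete at trace time.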

The Language Boundary

A central difficulty is deciding which parts of the language are differentiable.

Arithmetic over real-valued arrays is differentiable. Integer indexing, mutation, allocation, I/O, hashing, sorting, and discrete control decisions need more care.

Construct                  Differentiation issue
Floating-point arithmetic  Usually differentiable, except at singularities
Branching                  Differentiates the selected branch
Loops                      Differentiates the executed iterations
Integer indexing           Index choice is usually non-differentiable
Sorting                    Piecewise differentiable, discontinuous at ties
Mutation                   Requires correct adjoint semantics
I/O                        Usually outside the derivative computation
Randomness                 Needs reparameterization or score estimators
Allocation                 Affects runtime, not usually the mathematical derivative

A differentiable programming language must make these boundaries explicit. Silent behavior is dangerous. A system should reject invalid differentiation, define a subgradient convention, or require a custom rule.
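JAX, for example, defines the derivative of rounding as zero almost everywhere, so gradients vanish silently rather than fail:

import jax
import jax.numpy as jnp

print(jax.grad(jnp.round)(2.3))                   # 0.0: round is piecewise constant
print(jax.grad(lambda x: jnp.round(x) * x)(2.3))  # 2.0: only the smooth factor contributes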

Custom Derivative Rules

Not every useful operation should be differentiated by expanding its implementation. Some operations have better derivative formulas than their source code suggests.

For example, a numerically stable logsumexp implementation may contain branching and shifts:

def logsumexp(x):
    m = max(x)                       # shift by the max for numerical stability
    return m + log(sum(exp(x - m)))

Differentiating the implementation directly may work, but a custom derivative rule is often clearer and more stable.

A custom rule gives the AD system a local derivative definition:

Operation               Custom derivative reason
logsumexp               Numerical stability
Matrix inverse          Avoid differentiating solver internals naively
Cholesky factorization  Structured derivative
ODE solver              Use adjoint method or sensitivity equation
Optimization solver     Use implicit differentiation
Sampling operation      Use reparameterization

Custom derivatives are part of the language interface. They allow library authors to expose mathematically correct and efficient differentiation behavior.
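In JAX this interface is jax.custom_jvp and jax.custom_vjp. A sketch for logsumexp, whose derivative is the softmax:

import jax
import jax.numpy as jnp

@jax.custom_jvp
def logsumexp(x):
    m = jnp.max(x)
    return m + jnp.log(jnp.sum(jnp.exp(x - m)))

@logsumexp.defjvp
def logsumexp_jvp(primals, tangents):
    (x,), (dx,) = primals, tangents
    # d logsumexp / dx = softmax(x), so the JVP is <softmax(x), dx>
    return logsumexp(x), jnp.vdot(jax.nn.softmax(x), dx)

print(jax.grad(logsumexp)(jnp.array([1.0, 2.0, 3.0])))  # the softmax of the input

The AD system never looks inside the max-shifted implementation; it uses the supplied rule in both forward and reverse mode.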

Differentiating Through Libraries

A real program uses libraries. A differentiable programming system must decide how derivative information crosses library boundaries.

There are several strategies.

Strategy                     Description
Operator overloading         Library functions execute on derivative-aware values
Tracing                      Runtime records primitive operations
Source transformation        Compiler rewrites source code
Compiler IR differentiation  AD operates on lowered intermediate representation
Custom primitive rules       Libraries expose derivative rules manually

Each strategy has a different boundary. Operator overloading works well when all operations dispatch through overloaded types. Tracing works when the runtime can observe the operations. Source transformation works when source code is available and transformable. Compiler IR differentiation works when the program has been lowered into an analyzable representation.
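A minimal sketch of the operator-overloading strategy in plain Python (Dual and derivative are illustrative names, not a real library):

class Dual:
    """Forward-mode value: primal and derivative carried together."""
    def __init__(self, val, dot=0.0):
        self.val, self.dot = val, dot

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + other.val, self.dot + other.dot)
    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val * other.val,
                    self.val * other.dot + self.dot * other.val)
    __rmul__ = __mul__

def derivative(f, x):
    return f(Dual(x, 1.0)).dot

print(derivative(lambda x: 3 * x * x + x, 2.0))  # 13.0

Any library code that reaches only these overloaded operators differentiates without modification; code that escapes to non-overloaded paths silently drops out of the derivative.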

Differentiable Programming and Compilation

Differentiable programming is closely connected to compiler design.

A compiler-based system can perform several transformations together:

  1. Normalize the original program.
  2. Differentiate the normalized representation.
  3. Optimize the primal and derivative code.
  4. Fuse kernels.
  5. Plan memory.
  6. Lower to CPU, GPU, TPU, or accelerator code.

This is important because naive differentiated programs are often inefficient. Reverse mode can store too many intermediates. Higher-order AD can duplicate computation. Tensor programs can produce many small kernels.

A practical differentiable compiler performs AD and optimization as a single pipeline.
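JAX illustrates the composition: jax.grad produces the derivative program and jax.jit hands both primal and adjoint code to XLA for fusion and lowering (a sketch):

import jax
import jax.numpy as jnp

def loss(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

grad_fn = jax.jit(jax.grad(loss))  # differentiate, then compile as one pipeline
print(grad_fn(jnp.zeros(3), jnp.ones((8, 3)), jnp.ones(8)))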

The Role of Types

Types help define which programs are differentiable.

A language may distinguish between:

Type             Differentiation role
Float            Differentiable scalar
Vector Float     Differentiable array
Int              Usually non-differentiable
Bool             Usually non-differentiable
String           Non-differentiable
Array Int Float  Differentiable values with discrete indices
Function         Differentiable only under constraints

A typed system can reject invalid uses early:

@differentiable
func f(_ x: Float) -> Float {
    return x * x
}

The annotation marks a function as participating in differentiation. The compiler can then check whether all operations inside the function support derivatives.

Shape types and effect systems extend this further. Shape types ensure tensor dimensions match. Effect systems track mutation, randomness, I/O, and other behavior that affects derivative semantics.

Mutation and State

Mutation is one of the hardest issues in differentiable programming.

Consider:

def f(x):
    y = x              # y aliases x; no copy is made
    y[0] = y[0] * 2    # in-place update mutates the shared storage
    return sum(y)

The assignment changes the value of y. If y aliases x, the mutation also changes x. Reverse mode must reconstruct the correct sequence of states and propagate adjoints through each update.

There are several implementation choices.

Approach                      Behavior
Disallow mutation             Simplest semantics
Functionalize mutation        Rewrite updates into immutable values
Tape mutations                Record old values for reverse pass
Use linear types              Ensure values have unique ownership
Define array update adjoints  Treat mutation as scatter and gather

Functionalization is common. The system rewrites mutation into pure operations, making the derivative transformation easier to define.
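JAX takes the functionalization route: in-place syntax becomes a pure .at[...] update that returns a new array, so the adjoint is an ordinary scatter (a sketch):

import jax
import jax.numpy as jnp

def f(x):
    y = x.at[0].set(x[0] * 2.0)  # pure update: x itself is never mutated
    return jnp.sum(y)

print(jax.grad(f)(jnp.array([1.0, 2.0, 3.0])))  # [2. 1. 1.]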

Control Flow and Dynamic Programs

Differentiable programming must support control flow because scientific and machine learning code is full of it.

A loop is differentiated through the iterations that actually execute:

def fixed_point(x):
    y = x
    while norm(g(y) - y) > 1e-6:  # iterate some map g until convergence
        y = g(y)
    return y

Differentiating this program by unrolling the loop gives the derivative of the algorithm, not necessarily the derivative of the mathematical fixed point. These can differ.

This distinction matters.

Target                            Meaning
Differentiate the algorithm       Derivative of the finite executed computation
Differentiate the solution        Derivative of the mathematical object computed
Differentiate the implementation  Derivative of the exact source-level operations

For solvers, implicit differentiation may be preferable to differentiating every iteration.
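For a fixed point y = g(x, y), the implicit function theorem gives dy/dx = (∂g/∂x) / (1 − ∂g/∂y) in the scalar case, so only one derivative of g is needed, not one per iteration. A sketch, assuming g is a contraction:

import jax
import jax.numpy as jnp

def g(x, y):
    return jnp.cos(y) * x  # a contraction in y for |x| < 1

def solve(x):
    y = 0.0
    while abs(g(x, y) - y) > 1e-6:  # plain iteration, never differentiated
        y = g(x, y)
    return y

def fixed_point_grad(x):
    y = solve(x)
    dg_dx = jax.grad(g, argnums=0)(x, y)
    dg_dy = jax.grad(g, argnums=1)(x, y)
    return dg_dx / (1.0 - dg_dy)    # implicit function theorem

print(fixed_point_grad(0.5))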

Differentiable Programming in Practice

Modern systems implement differentiable programming with different tradeoffs.

System        Main style
PyTorch       Dynamic tracing with eager execution
TensorFlow    Graph tracing and compiler paths
JAX           Pure functional transformations over traced programs
Julia Zygote  Source-to-source AD
Enzyme        Compiler IR-level AD
Swift AD      Language-integrated typed AD
Taichi        Differentiable simulation DSL
Dr.Jit        Differentiable rendering and simulation kernels

The field has no single dominant architecture. The right design depends on the host language, execution target, and expected workloads.

Differentiable Programming Beyond Machine Learning

Differentiable programming is useful wherever a program contains parameters that should be optimized.

Examples include:

Domain     Differentiable program
Rendering  Image formation pipeline
Physics    Simulator with tunable parameters
Robotics   Controller and dynamics model
Finance    Pricing model and risk objective
Biology    Kinetic model or molecular simulation
Databases  Learned cost model or differentiable query component
Compilers  Autotuning objective
Control    Trajectory optimizer

The common pattern is objective-driven computation. A program computes a loss, error, likelihood, reward, or constraint violation. AD provides the derivative needed to improve parameters.

Design Requirements

A serious differentiable programming system should provide:

Requirement                   Reason
Correct derivative semantics  Users must know what is being differentiated
Efficient reverse mode        Scalar losses over many parameters are common
Forward mode support          Needed for JVPs, Jacobians, and implicit methods
Higher-order derivatives      Needed for curvature and meta-optimization
Custom derivative rules       Needed for stability and performance
Control-flow support          Real programs branch and loop
Mutation model                State must have defined adjoint behavior
Compiler optimization         Naive AD produces inefficient code
Debugging tools               Gradients need inspection and testing

Differentiable programming becomes a systems problem, not only a calculus problem.

Failure Modes

Differentiable programming systems fail in characteristic ways.

Failure mode                    Example
Wrong derivative target         Differentiating solver iterations instead of the solved equation
Memory blowup                   Reverse mode stores every intermediate
Silent zero gradients           Discrete operations cut gradient flow
Numerically unstable gradients  Naive rules amplify floating-point error
Excessive recompilation         Dynamic shapes or branches trigger new traces
Perturbation confusion          Nested AD mixes derivative levels
Invalid custom rules            User-supplied adjoints violate the true derivative

These failures are often subtle. Good systems expose diagnostics, gradient checks, and explicit boundaries between differentiable and non-differentiable code.
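A standard diagnostic is a finite-difference gradient check; a sketch in JAX (gradient_check is an illustrative helper, not a library function):

import jax
import jax.numpy as jnp

def gradient_check(f, x, eps=1e-3):
    """Compare the AD gradient against central finite differences."""
    g = jax.grad(f)(x)
    fd = jnp.array([(f(x.at[i].add(eps)) - f(x.at[i].add(-eps))) / (2 * eps)
                    for i in range(x.shape[0])])
    return jnp.max(jnp.abs(g - fd))

f = lambda x: jnp.sum(jnp.sin(x) ** 2)
print(gradient_check(f, jnp.linspace(0.0, 1.0, 5)))  # small if the rule is consistent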

Summary

Differentiable programming generalizes automatic differentiation from numerical kernels to whole programs. It asks the language and compiler to treat derivatives as ordinary program transformations.

The core challenge is semantic clarity. A user should know whether the system differentiates the mathematical function, the algorithm, the implementation, or a custom abstraction exposed by a library. Once that boundary is clear, the remaining problems are compiler and runtime engineering: representation, memory, optimization, dispatch, and hardware lowering.