Differentiable programming treats differentiation as a general programming-language feature. A program can contain numerical kernels, control flow, data structures, solvers, simulations, and model code. If the program denotes a differentiable computation, the system should be able to compute derivatives of that program.
This is broader than neural network training. A differentiable program may include a physics simulator, a renderer, a database query, an optimizer, or a probabilistic model. The central idea is that gradients should be available across the whole program, not only inside a fixed tensor graph.
From Automatic Differentiation to Differentiable Programming
Automatic differentiation began as a technique for computing derivatives of numerical programs. Differentiable programming extends that idea into a design principle for languages and systems.
The core operator is a program transformation:
Given a scalar-output function, grad produces another function that returns its gradient.
In code, this appears as:
```python
def loss(params, batch):
    pred = model(params, batch.x)
    return mse(pred, batch.y)

g = grad(loss)(params, batch)
```

The important point is not the syntax. The important point is that loss can be an ordinary program. It can call functions, branch, loop, allocate intermediate values, and invoke libraries.
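A runnable version of this sketch, here using JAX; the tiny linear model and mse below are illustrative stand-ins, not part of the original example:

```python
import jax
import jax.numpy as jnp

def model(params, x):
    w, b = params
    return x @ w + b            # stand-in linear model

def mse(pred, target):
    return jnp.mean((pred - target) ** 2)

def loss(params, x, y):
    return mse(model(params, x), y)

params = (jnp.ones((3, 1)), jnp.zeros(1))
x, y = jnp.ones((8, 3)), jnp.zeros((8, 1))

# grad differentiates with respect to the first argument by default,
# returning gradients with the same tuple structure as params.
g = jax.grad(loss)(params, x, y)
```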
Differentiation as a First-Class Transformation
In a differentiable programming system, transformations such as grad, vjp, jvp, jacfwd, and jacrev are first-class operations.
| Transformation | Meaning |
|---|---|
| grad(f) | Gradient of scalar-output function |
| jvp(f) | Jacobian-vector product |
| vjp(f) | Vector-Jacobian product |
| jacfwd(f) | Jacobian using forward mode |
| jacrev(f) | Jacobian using reverse mode |
| hessian(f) | Second-derivative matrix |
| value_and_grad(f) | Primal value and gradient together |
This style changes how numerical software is written. Instead of manually deriving update rules, the programmer writes the objective and asks the system for derivatives.
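For concreteness, a sketch of how these transformations are invoked in JAX, on an arbitrary stand-in function:

```python
import jax
import jax.numpy as jnp

f = lambda x: jnp.sum(jnp.sin(x) ** 2)   # scalar-output function
x = jnp.arange(3.0)
v = jnp.ones(3)

g = jax.grad(f)(x)                   # gradient, shape (3,)
val, g2 = jax.value_and_grad(f)(x)   # primal value and gradient together
_, tangent = jax.jvp(f, (x,), (v,))  # forward mode: J @ v
y, pullback = jax.vjp(f, x)          # reverse mode
(cotangent,) = pullback(1.0)         # v^T @ J for a scalar output
J_fwd = jax.jacfwd(f)(x)             # Jacobian via forward mode
J_rev = jax.jacrev(f)(x)             # Jacobian via reverse mode
H = jax.hessian(f)(x)                # second-derivative matrix, shape (3, 3)
```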
Programs Beyond Tensor Graphs
Early deep learning systems often represented models as static tensor graphs. A graph contained operations such as matrix multiplication, convolution, addition, and activation functions.
Differentiable programming relaxes this model. The differentiable object is the program itself.
For example:
```python
def f(x):
    y = 0.0
    for i in range(10):
        if x[i] > 0:
            y = y + x[i] * x[i]
        else:
            y = y - x[i]
    return y
```

This program contains a loop and a branch. A differentiable programming system differentiates the executed computation. The derivative depends on the path taken through the program.
This makes differentiation useful for real numerical code, where fixed graphs are often too restrictive.
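A hedged sketch of this path dependence, assuming JAX's eager grad, which resolves Python branches on concrete values while differentiating:

```python
import jax

def f(x):
    if x > 0:          # Python branch resolved on the concrete input
        return x * x   # derivative 2x on this path
    else:
        return -x      # derivative -1 on this path

print(jax.grad(f)(3.0))   # 6.0: the positive path was taken
print(jax.grad(f)(-2.0))  # -1.0: the negative path was taken
```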
The Language Boundary
A central difficulty is deciding which parts of the language are differentiable.
Arithmetic over real-valued arrays is differentiable. Integer indexing, mutation, allocation, I/O, hashing, sorting, and discrete control decisions need more care.
| Construct | Differentiation issue |
|---|---|
| Floating point arithmetic | Usually differentiable except at singularities |
| Branching | Differentiates the selected branch |
| Loops | Differentiates the executed iterations |
| Integer indexing | Index choice is usually non-differentiable |
| Sorting | Piecewise differentiable, discontinuous at ties |
| Mutation | Requires correct adjoint semantics |
| I/O | Usually outside derivative computation |
| Randomness | Needs reparameterization or score estimators |
| Allocation | Affects runtime, not usually mathematical derivative |
A differentiable programming language must make these boundaries explicit. Silent behavior is dangerous. A system should reject invalid differentiation, define a subgradient convention, or require a custom rule.
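One way the boundary shows up in practice: a discrete operation such as rounding is piecewise constant, so its derivative is zero almost everywhere and gradient flow is cut without any error. A small JAX illustration:

```python
import jax
import jax.numpy as jnp

print(jax.grad(jnp.sin)(1.0))                  # cos(1.0), a useful gradient
print(jax.grad(lambda x: jnp.round(x))(1.5))   # 0.0: silently zero
```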
Custom Derivative Rules
Not every useful operation should be differentiated by expanding its implementation. Some operations have better derivative formulas than their source code suggests.
For example, a numerically stable logsumexp implementation may contain branching and shifts:
```python
import numpy as np

def logsumexp(x):
    m = np.max(x)
    return m + np.log(np.sum(np.exp(x - m)))
```

Differentiating the implementation directly may work, but a custom derivative rule is often clearer and more stable.
A custom rule gives the AD system a local derivative definition:
| Operation | Custom derivative reason |
|---|---|
| logsumexp | Numerical stability |
| Matrix inverse | Avoid differentiating solver internals naively |
| Cholesky factorization | Structured derivative |
| ODE solver | Use adjoint method or sensitivity equation |
| Optimization solver | Use implicit differentiation |
| Sampling operation | Use reparameterization |
Custom derivatives are part of the language interface. They allow library authors to expose mathematically correct and efficient differentiation behavior.
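As a sketch of what such a rule can look like, here is logsumexp with a custom JVP in JAX, using the identity that the derivative of logsumexp is the softmax; this is one possible interface, not the only one:

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def logsumexp(x):
    m = jnp.max(x)
    return m + jnp.log(jnp.sum(jnp.exp(x - m)))

@logsumexp.defjvp
def logsumexp_jvp(primals, tangents):
    (x,), (dx,) = primals, tangents
    y = logsumexp(x)
    softmax = jnp.exp(x - y)          # exp(x) / sum(exp(x)), computed stably
    return y, jnp.sum(softmax * dx)   # directional derivative along dx
```

With this rule in place, jax.grad(logsumexp) uses the softmax formula directly instead of differentiating through the max and the shift.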
Differentiating Through Libraries
A real program uses libraries. A differentiable programming system must decide how derivative information crosses library boundaries.
There are several strategies.
| Strategy | Description |
|---|---|
| Operator overloading | Library functions execute on derivative-aware values |
| Tracing | Runtime records primitive operations |
| Source transformation | Compiler rewrites source code |
| Compiler IR differentiation | AD operates on lowered intermediate representation |
| Custom primitive rules | Libraries expose derivative rules manually |
Each strategy has a different boundary. Operator overloading works well when all operations dispatch through overloaded types. Tracing works when the runtime can observe the operations. Source transformation works when source code is available and transformable. Compiler IR differentiation works when the program has been lowered into an analyzable representation.
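The operator-overloading strategy is simple enough to sketch directly: a forward-mode dual number whose arithmetic carries derivatives through any code that dispatches on these operators. A minimal illustration, not a production design:

```python
class Dual:
    """A value paired with its derivative; arithmetic applies the chain rule."""
    def __init__(self, val, dot=0.0):
        self.val, self.dot = val, dot

    def _wrap(self, other):
        return other if isinstance(other, Dual) else Dual(other)

    def __add__(self, other):
        other = self._wrap(other)
        return Dual(self.val + other.val, self.dot + other.dot)
    __radd__ = __add__

    def __mul__(self, other):
        other = self._wrap(other)
        return Dual(self.val * other.val,
                    self.dot * other.val + self.val * other.dot)
    __rmul__ = __mul__

def poly(x):                 # ordinary code, written without AD in mind
    return 3 * x * x + 2 * x + 1

y = poly(Dual(2.0, 1.0))     # seed dx/dx = 1
print(y.val, y.dot)          # 17.0 14.0, since d/dx (3x^2 + 2x + 1) = 6x + 2
```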
Differentiable Programming and Compilation
Differentiable programming is closely connected to compiler design.
A compiler-based system can perform several transformations together:
- Normalize the original program.
- Differentiate the normalized representation.
- Optimize the primal and derivative code.
- Fuse kernels.
- Plan memory.
- Lower to CPU, GPU, TPU, or accelerator code.
This is important because naive differentiated programs are often inefficient. Reverse mode can store too many intermediates. Higher-order AD can duplicate computation. Tensor programs can produce many small kernels.
A practical differentiable compiler performs AD and optimization as a single pipeline.
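In JAX this pipeline is visible as composition of transformations: differentiate first, then hand the whole differentiated program to the compiler as one unit. A minimal sketch:

```python
import jax
import jax.numpy as jnp

def f(w, x):
    return jnp.sum(jnp.tanh(x @ w) ** 2)

step = jax.jit(jax.grad(f))   # primal and derivative code compiled together

w, x = jnp.ones((4, 4)), jnp.ones((8, 4))
g = step(w, x)                # first call compiles; later calls reuse the kernel
```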
The Role of Types
Types help define which programs are differentiable.
A language may distinguish between:
| Type | Differentiation role |
|---|---|
| Float | Differentiable scalar |
| Vector Float | Differentiable array |
| Int | Usually non-differentiable |
| Bool | Usually non-differentiable |
| String | Non-differentiable |
| Array Int Float | Differentiable values with discrete indices |
| Function | Differentiable only under constraints |
A typed system can reject invalid uses early:
```swift
@differentiable
func f(_ x: Float) -> Float {
    return x * x
}
```

The annotation marks a function as participating in differentiation. The compiler can then check whether all operations inside the function support derivatives.
Shape types and effect systems extend this further. Shape types ensure tensor dimensions match. Effect systems track mutation, randomness, I/O, and other behavior that affects derivative semantics.
Mutation and State
Mutation is one of the hardest issues in differentiable programming.
Consider:
```python
def f(x):
    y = x
    y[0] = y[0] * 2
    return sum(y)
```

The assignment changes the value of y. If y aliases x, the mutation also changes x. Reverse mode must reconstruct the correct sequence of states and propagate adjoints through each update.
There are several implementation choices.
| Approach | Behavior |
|---|---|
| Disallow mutation | Simplest semantics |
| Functionalize mutation | Rewrite updates into immutable values |
| Tape mutations | Record old values for reverse pass |
| Use linear types | Ensure values have unique ownership |
| Define array update adjoints | Treat mutation as scatter and gather |
Functionalization is common. The system rewrites mutation into pure operations, making the derivative transformation easier to define.
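In a system with functional array updates, the earlier example can be written so the update has a well-defined adjoint. A sketch using JAX's pure .at[].set update:

```python
import jax
import jax.numpy as jnp

def f(x):
    y = x.at[0].set(x[0] * 2.0)   # pure update: returns a new array, no aliasing
    return jnp.sum(y)

print(jax.grad(f)(jnp.ones(3)))   # [2. 1. 1.]: the doubled slot gets adjoint 2
```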
Control Flow and Dynamic Programs
Differentiable programming must support control flow because scientific and machine learning code is full of it.
A loop is differentiated through the iterations that actually run:

```python
def fixed_point(x):
    y = x
    while norm(g(y) - y) > 1e-6:   # g is the iteration map, norm a vector norm
        y = g(y)
    return y
```

Differentiating this program by unrolling the loop gives the derivative of the algorithm, not necessarily the derivative of the mathematical fixed point. These can differ.
This distinction matters.
| Target | Meaning |
|---|---|
| Differentiate the algorithm | Derivative of the finite executed computation |
| Differentiate the solution | Derivative of the mathematical object computed |
| Differentiate the implementation | Derivative of the exact source-level operations |
For solvers, implicit differentiation may be preferable to differentiating every iteration.
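A minimal sketch of the contrast, for a scalar fixed point y = g(y, theta): the implicit function theorem gives dy/dtheta = (dg/dtheta) / (1 - dg/dy) at the solution, with no need to differentiate the iterations. The specific g below is an illustrative stand-in:

```python
import jax

def g(y, theta):
    return 0.5 * y + theta        # contraction with fixed point y* = 2 * theta

def solve(theta, iters=50):
    y = 0.0
    for _ in range(iters):        # plain iteration; never differentiated
        y = g(y, theta)
    return y

theta = 1.5
y_star = solve(theta)
dg_dy = jax.grad(g, argnums=0)(y_star, theta)
dg_dtheta = jax.grad(g, argnums=1)(y_star, theta)
print(dg_dtheta / (1.0 - dg_dy))  # 2.0, matching d(2 * theta)/dtheta
```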
Differentiable Programming in Practice
Modern systems implement differentiable programming with different tradeoffs.
| System | Main style |
|---|---|
| PyTorch | Dynamic tracing with eager execution |
| TensorFlow | Graph tracing and compiler paths |
| JAX | Pure functional transformations over traced programs |
| Julia Zygote | Source-to-source AD |
| Enzyme | Compiler IR-level AD |
| Swift AD | Language-integrated typed AD |
| Taichi | Differentiable simulation DSL |
| Dr.Jit | Differentiable rendering and simulation kernels |
The field has no single dominant architecture. The right design depends on the host language, execution target, and expected workloads.
Differentiable Programming Beyond Machine Learning
Differentiable programming is useful wherever a program contains parameters that should be optimized.
Examples include:
| Domain | Differentiable program |
|---|---|
| Rendering | Image formation pipeline |
| Physics | Simulator with tunable parameters |
| Robotics | Controller and dynamics model |
| Finance | Pricing model and risk objective |
| Biology | Kinetic model or molecular simulation |
| Databases | Learned cost model or differentiable query component |
| Compilers | Autotuning objective |
| Control | Trajectory optimizer |
The common pattern is objective-driven computation. A program computes a loss, error, likelihood, reward, or constraint violation. AD provides the derivative needed to improve parameters.
Design Requirements
A serious differentiable programming system should provide:
| Requirement | Reason |
|---|---|
| Correct derivative semantics | Users must know what is being differentiated |
| Efficient reverse mode | Scalar losses over many parameters are common |
| Forward mode support | Needed for JVPs, Jacobians, and implicit methods |
| Higher-order derivatives | Needed for curvature and meta-optimization |
| Custom derivative rules | Needed for stability and performance |
| Control-flow support | Real programs branch and loop |
| Mutation model | State must have defined adjoint behavior |
| Compiler optimization | Naive AD produces inefficient code |
| Debugging tools | Gradients need inspection and testing |
Differentiable programming becomes a systems problem, not only a calculus problem.
Failure Modes
Differentiable programming systems fail in characteristic ways.
| Failure mode | Example |
|---|---|
| Wrong derivative target | Differentiating solver iterations instead of solved equation |
| Memory blowup | Reverse mode stores every intermediate |
| Silent zero gradients | Discrete operations cut gradient flow |
| Numerically unstable gradients | Naive rules amplify floating point error |
| Excessive recompilation | Dynamic shapes or branches trigger new traces |
| Perturbation confusion | Nested AD mixes derivative levels |
| Invalid custom rules | User-supplied adjoints violate the true derivative |
These failures are often subtle. Good systems expose diagnostics, gradient checks, and explicit boundaries between differentiable and non-differentiable code.
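A basic diagnostic is a finite-difference gradient check, comparing the AD gradient against a central-difference estimate. A rough sketch; the step size and tolerance depend on working precision:

```python
import jax
import jax.numpy as jnp

def grad_check(f, x, eps=1e-3):
    g = jax.grad(f)(x)
    fd = jnp.array([
        (f(x.at[i].add(eps)) - f(x.at[i].add(-eps))) / (2 * eps)
        for i in range(x.shape[0])
    ])
    return jnp.max(jnp.abs(g - fd))   # should be small for a correct gradient

f = lambda x: jnp.sum(jnp.sin(x) * x)
print(grad_check(f, jnp.arange(1.0, 4.0)))
```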
Summary
Differentiable programming generalizes automatic differentiation from numerical kernels to whole programs. It asks the language and compiler to treat derivatives as ordinary program transformations.
The core challenge is semantic clarity. A user should know whether the system differentiates the mathematical function, the algorithm, the implementation, or a custom abstraction exposed by a library. Once that boundary is clear, the remaining problems are compiler and runtime engineering: representation, memory, optimization, dispatch, and hardware lowering.