# Differential Equations
Differential equations are one of the main reasons automatic differentiation matters in scientific computing. Many scientific models are not written as closed-form functions. They are written as equations describing how a system changes over time, space, or both.
A simple ordinary differential equation has the form

$$ \frac{dy}{dt} = f(y, t, \theta), $$

where $y(t)$ is the state, $t$ is time, and $\theta$ is a vector of parameters. Given an initial condition

$$ y(t_0) = y_0, $$

a numerical solver computes an approximate trajectory

$$ y_0, y_1, \dots, y_N \approx y(t_0), y(t_1), \dots, y(t_N). $$

In many applications we want more than the trajectory itself. We also want derivatives of some quantity derived from it. For example, we may define a loss

$$ L(\theta) = \ell\big(y(T; \theta)\big) $$

and ask for

$$ \frac{dL}{d\theta}. $$
This is the central problem: differentiate through the solution of a differential equation.
## The Solver as a Program
A numerical ODE solver is a program. Automatic differentiation applies to programs, so in principle we can differentiate the solver directly.
For explicit Euler, one step is

$$ y_{n+1} = y_n + h\, f(y_n, t_n, \theta). $$

The full solver is the repeated composition

$$ y_N = \Phi_N \circ \Phi_{N-1} \circ \cdots \circ \Phi_1(y_0), $$

where each $\Phi_n$ is one numerical step.

Forward mode propagates sensitivities alongside the state. If

$$ s_n = \frac{\partial y_n}{\partial \theta}, $$

then Euler gives

$$ s_{n+1} = s_n + h \left( \frac{\partial f}{\partial y}(y_n, t_n, \theta)\, s_n + \frac{\partial f}{\partial \theta}(y_n, t_n, \theta) \right). $$
This is called the sensitivity equation in discrete form. It is often the most direct way to compute parameter derivatives when the number of parameters is small.
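As a concrete sketch (not a production implementation), discrete forward sensitivity propagation for explicit Euler might look like the following; the model `f` and its Jacobians are supplied by the user:

```python
import numpy as np

def euler_with_sensitivity(f, df_dy, df_dtheta, y0, theta, t0, t1, n_steps):
    """Explicit Euler that carries s_n = dy_n/dtheta alongside the state."""
    h = (t1 - t0) / n_steps
    y = np.array(y0, dtype=float)
    s = np.zeros((y.size, theta.size))  # s_0 = 0 when y0 does not depend on theta
    t = t0
    for _ in range(n_steps):
        # Update the sensitivity first, using the state at the current step:
        # s_{n+1} = s_n + h * (df/dy(y_n) @ s_n + df/dtheta(y_n))
        s = s + h * (df_dy(y, t, theta) @ s + df_dtheta(y, t, theta))
        y = y + h * f(y, t, theta)      # y_{n+1} = y_n + h * f(y_n, t_n, theta)
        t += h
    return y, s

# Example: exponential decay dy/dt = -theta[0] * y
f      = lambda y, t, th: -th[0] * y
df_dy  = lambda y, t, th: np.array([[-th[0]]])
df_dth = lambda y, t, th: np.array([[-y[0]]])
yN, sN = euler_with_sensitivity(f, df_dy, df_dth, [1.0], np.array([0.5]),
                                0.0, 1.0, 1000)
# sN approximates d/dtheta exp(-theta*t) at t=1, i.e. -exp(-0.5) ≈ -0.607
```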
## Continuous Sensitivity Equations
Instead of differentiating the numerical solver, we can differentiate the differential equation itself.
Starting from

$$ \frac{dy}{dt} = f(y, t, \theta), $$

differentiate both sides with respect to $\theta$. Let

$$ S(t) = \frac{\partial y(t)}{\partial \theta}. $$

Then

$$ \frac{dS}{dt} = \frac{\partial f}{\partial y}\, S + \frac{\partial f}{\partial \theta}. $$

The original state and its sensitivity can be solved together:

$$ \frac{d}{dt} \begin{pmatrix} y \\ S \end{pmatrix} = \begin{pmatrix} f(y, t, \theta) \\ \frac{\partial f}{\partial y}\, S + \frac{\partial f}{\partial \theta} \end{pmatrix}, \qquad S(t_0) = \frac{\partial y_0}{\partial \theta}. $$

This method is useful when the state dimension is moderate and the parameter dimension is not too large. Its cost grows with the number of parameters, because $S$ has one column per parameter.
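Under those definitions, here is a minimal sketch with SciPy's `solve_ivp`: the state and the flattened sensitivity matrix travel as one augmented vector. The logistic model is only an example:

```python
import numpy as np
from scipy.integrate import solve_ivp

# Illustrative model: logistic growth dy/dt = r*y*(1 - y/K), theta = (r, K).
def f(y, t, th):
    r, K = th
    return np.array([r * y[0] * (1.0 - y[0] / K)])

def df_dy(y, t, th):
    r, K = th
    return np.array([[r * (1.0 - 2.0 * y[0] / K)]])

def df_dtheta(y, t, th):
    r, K = th
    return np.array([[y[0] * (1.0 - y[0] / K), r * y[0]**2 / K**2]])

n, p = 1, 2                          # state and parameter dimensions
theta = np.array([1.0, 2.0])

def augmented_rhs(t, z):
    """Stack y (length n) and S flattened (length n*p) into one ODE system."""
    y, S = z[:n], z[n:].reshape(n, p)
    dS = df_dy(y, t, theta) @ S + df_dtheta(y, t, theta)   # dS/dt
    return np.concatenate([f(y, t, theta), dS.ravel()])

z0 = np.concatenate([[0.1], np.zeros(n * p)])   # S(t0) = dy0/dtheta = 0
sol = solve_ivp(augmented_rhs, (0.0, 5.0), z0, rtol=1e-8, atol=1e-10)
S_T = sol.y[n:, -1].reshape(n, p)               # dy(T)/dtheta
```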
## Reverse Mode and Adjoint Equations
When the number of parameters is large and the loss is scalar, reverse mode is usually preferable. This is common in optimization and machine learning.
Suppose the loss depends on the final state:

$$ L = \ell\big(y(T)\big). $$

The continuous adjoint variable is

$$ a(t) = \frac{\partial L}{\partial y(t)}. $$

For the ODE

$$ \frac{dy}{dt} = f(y, t, \theta), $$

the adjoint equation runs backward in time:

$$ \frac{da}{dt} = -\left( \frac{\partial f}{\partial y} \right)^{\!\top} a. $$

The terminal condition is

$$ a(T) = \frac{\partial \ell}{\partial y(T)}. $$

The parameter gradient is accumulated as

$$ \frac{dL}{d\theta} = \int_{t_0}^{T} a(t)^\top\, \frac{\partial f}{\partial \theta}\big(y(t), t, \theta\big)\, dt. $$
This is the continuous analogue of reverse-mode AD. Instead of storing every intermediate operation in a tape, the adjoint method solves another differential equation backward.
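A minimal continuous-adjoint sketch under these formulas, assuming a model simple enough that its Jacobians can be written inline; the forward solve keeps a dense interpolant so $y(t)$ is available during the backward pass:

```python
import numpy as np
from scipy.integrate import solve_ivp

# Illustrative model: dy/dt = -theta * y with loss l(y(T)) = 0.5 * y(T)^2.
theta, t0, T = 0.7, 0.0, 2.0
y0 = np.array([1.5])
f = lambda t, y: -theta * y

# Forward pass with a dense interpolant so y(t) is available going backward.
fwd = solve_ivp(f, (t0, T), y0, dense_output=True, rtol=1e-10, atol=1e-12)
yT = fwd.y[:, -1]

def backward_rhs(t, z):
    """z = [a, g]: adjoint a(t) and running gradient accumulator g(t).

    da/dt = -(df/dy)^T a   and   dg/dt = -a^T df/dtheta,
    integrated from T down to t0 so that g(t0) = dL/dtheta.
    """
    a = z[:1]
    y = fwd.sol(t)                 # interpolated forward state
    df_dy = -theta                 # Jacobians of f = -theta*y, written inline
    df_dtheta = -y[0]
    return np.concatenate([-df_dy * a, [-a[0] * df_dtheta]])

zT = np.concatenate([yT, [0.0]])   # a(T) = dl/dy(T) = y(T); g(T) = 0
bwd = solve_ivp(backward_rhs, (T, t0), zT, rtol=1e-10, atol=1e-12)
grad = bwd.y[1, -1]                # dL/dtheta; analytically -T * y(T)^2 here
```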
## Discrete Adjoint vs Continuous Adjoint
There are two different objects that are often confused.
A discrete adjoint differentiates the numerical solver exactly. If the solver takes steps

$$ y_{n+1} = \Phi_n(y_n, \theta), $$

then reverse-mode AD propagates adjoints backward through those exact steps.
A continuous adjoint differentiates the mathematical ODE first, then discretizes the adjoint equation.
These methods can give different gradients, because differentiation and discretization do not always commute.
| Method | What is differentiated | Main advantage | Main risk |
|---|---|---|---|
| Discrete adjoint | The numerical solver | Gradient matches computed solution | May require storing solver states |
| Continuous adjoint | The continuous ODE | Lower memory in some systems | Gradient may differ from solver gradient |
| Forward sensitivity | State equation plus sensitivity equation | Simple and stable for few parameters | Cost grows with parameter count |
For machine learning, the discrete adjoint is often the safer interpretation: it gives the derivative of the actual computation performed. For scientific modeling, the continuous adjoint may be more natural when the differential equation is the primary object.
## Differentiating Through Adaptive Solvers
Many ODE solvers use adaptive step sizes. The solver chooses the step size $h_n$ based on an error estimate. This creates extra complications.
A simplified adaptive solver does something like the following sketch, with a crude step-doubling error estimate standing in for a real embedded Runge-Kutta pair (names and constants are illustrative):
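```python
import numpy as np

def adaptive_step(f, y, t, h, theta, tol=1e-6):
    """One accept/reject cycle of a toy adaptive solver.

    Error estimate: compare one full Euler step against two half steps.
    Real codes use embedded Runge-Kutta pairs, but the control logic is similar.
    """
    full = y + h * f(y, t, theta)
    half = y + 0.5 * h * f(y, t, theta)
    two_half = half + 0.5 * h * f(half, t + 0.5 * h, theta)
    err = np.max(np.abs(two_half - full))    # estimate depends on y and theta
    if err < tol:
        # Accept: advance, and grow the step based on the error ratio.
        h_new = h * min(2.0, 0.9 * np.sqrt(tol / (err + 1e-16)))
        return two_half, t + h, h_new
    # Reject: stay put and retry with a smaller step.
    return y, t, 0.5 * h
```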
The step size now depends on the state and parameters. A fully discrete derivative must account for this dependence.
However, many practical AD systems treat solver control decisions as fixed during differentiation. That means the gradient is computed through the accepted numerical steps, but not through the logic that chose those steps. This is usually acceptable when step-size decisions are stable under small perturbations. It can fail near discontinuous accept/reject boundaries.
Adaptive solvers therefore introduce three derivative layers:
| Layer | Meaning |
|---|---|
| Model derivative | Derivative of $f(y, t, \theta)$ |
| Solver derivative | Derivative through numerical updates |
| Control derivative | Derivative through step-size and branch decisions |
The first two are standard. The third is delicate because solver control flow is piecewise constant or discontinuous.
## Stiff Differential Equations
A stiff ODE contains dynamics on very different time scales. Explicit methods may require extremely small steps for stability. Implicit solvers are often used instead.
An implicit Euler step is defined by

$$ y_{n+1} = y_n + h\, f(y_{n+1}, t_{n+1}, \theta). $$

Here $y_{n+1}$ is defined implicitly. The solver usually finds it with Newton iteration.
Differentiating this step can be done in two ways.
The first way is to differentiate through every Newton iteration. This is simple if the solver is implemented in an AD system, but it can be expensive and may expose numerical details that are not mathematically relevant.
The second way is implicit differentiation. Define

$$ g(y_{n+1}, y_n, \theta) = y_{n+1} - y_n - h\, f(y_{n+1}, t_{n+1}, \theta). $$

Since

$$ g(y_{n+1}, y_n, \theta) = 0 $$

at the converged step, we differentiate implicitly:

$$ \frac{\partial g}{\partial y_{n+1}}\, \frac{\partial y_{n+1}}{\partial \theta} + \frac{\partial g}{\partial \theta} = 0. $$

Thus

$$ \frac{\partial y_{n+1}}{\partial \theta} = -\left( \frac{\partial g}{\partial y_{n+1}} \right)^{-1} \frac{\partial g}{\partial \theta} = \left( I - h\, \frac{\partial f}{\partial y} \right)^{-1} h\, \frac{\partial f}{\partial \theta}. $$
This avoids differentiating through the internal iterations of the nonlinear solver. It treats the converged implicit step as the mathematical operation.
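A sketch of this route: Newton iteration finds the step, and a single extra linear solve at the converged point yields the sensitivity; the Newton iterates themselves never enter the derivative. The helper names are ours:

```python
import numpy as np

def implicit_euler_step_with_grad(f, df_dy, df_dtheta, y, t, h, theta,
                                  newton_iters=20, tol=1e-12):
    """One implicit Euler step plus dy_{n+1}/dtheta via implicit differentiation.

    Newton's method converges g(y1) = y1 - y - h*f(y1, t+h, theta) = 0;
    the derivative then needs only one extra solve with the final Jacobian,
    not differentiation through the iterations.
    """
    t1, n = t + h, y.size
    y1 = y.copy()                                  # initial Newton guess
    for _ in range(newton_iters):
        g = y1 - y - h * f(y1, t1, theta)
        J = np.eye(n) - h * df_dy(y1, t1, theta)   # dg/dy1
        delta = np.linalg.solve(J, -g)
        y1 = y1 + delta
        if np.linalg.norm(delta) < tol:
            break
    # dy1/dtheta = (I - h df/dy)^{-1} (h df/dtheta), evaluated at the solution.
    J = np.eye(n) - h * df_dy(y1, t1, theta)
    dy1_dtheta = np.linalg.solve(J, h * df_dtheta(y1, t1, theta))
    return y1, dy1_dtheta
```

To chain sensitivities across many steps, the same solve also carries the incoming sensitivity: differentiating the step while tracking $s_n = \partial y_n / \partial \theta$ gives $s_{n+1} = (I - h\,\partial f/\partial y)^{-1}\,(s_n + h\,\partial f/\partial \theta)$.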
## PDEs and Discretized Systems
Partial differential equations introduce derivatives over both time and space. A PDE such as the heat equation

$$ \frac{\partial u}{\partial t} = \alpha\, \frac{\partial^2 u}{\partial x^2} $$

is usually discretized into a large system of ODEs. For example, after spatial discretization on a grid, we may get

$$ \frac{dy}{dt} = f(y, \theta), $$

where $y$ collects the values of $u$ at the grid points.
Automatic differentiation then applies to the discretized program: mesh construction, finite difference stencils, finite element assembly, linear solves, time stepping, and loss evaluation.
The challenge is scale. PDE solvers may contain millions or billions of state variables. Reverse-mode AD over the entire solver can be memory-heavy. Efficient systems use structure:
| Structure | AD consequence |
|---|---|
| Sparse matrices | Sparse Jacobian and adjoint operations |
| Local stencils | Local derivative propagation |
| Linear solves | Custom VJP rules using transpose solves |
| Mesh hierarchy | Multilevel adjoint methods |
| Time stepping | Checkpointing and recomputation |
The key is to avoid materializing dense Jacobians. Most scientific AD systems compute products such as

$$ \frac{\partial f}{\partial y}\, v $$

or

$$ \left( \frac{\partial f}{\partial y} \right)^{\!\top} v $$

without constructing $\frac{\partial f}{\partial y}$ explicitly.
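As an illustration of local stencils, here is a sketch of a matrix-free 1-D Laplacian: because the operator is linear, its Jacobian-vector product is just another application of the stencil, and no $n \times n$ matrix is ever formed (grid and boundary choices are illustrative):

```python
import numpy as np

def laplacian_apply(u, dx):
    """Matrix-free 1-D Laplacian stencil with zero Dirichlet boundaries."""
    out = np.zeros_like(u)
    out[1:-1] = (u[2:] - 2.0 * u[1:-1] + u[:-2]) / dx**2
    return out

# For the linear map f(u) = laplacian_apply(u, dx), the JVP (df/du) @ v is
# just laplacian_apply(v, dx): the stencil is its own Jacobian action.
# Because this stencil is symmetric, the VJP is the same operation.
n, dx = 1_000_000, 1e-3
v = np.random.rand(n)
jvp = laplacian_apply(v, dx)   # no n-by-n matrix is ever materialized
```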
## Differentiating Linear Solves
Linear solves appear constantly in differential equation solvers. Suppose

$$ A\, x = b, $$

or equivalently

$$ x = A^{-1} b. $$

For a perturbation,

$$ dA\, x + A\, dx = db. $$

Therefore

$$ dx = A^{-1} \left( db - dA\, x \right). $$

In reverse mode, if the adjoint of $x$ is $\bar{x}$, then the adjoints of $A$ and $b$ are computed using the transpose solve

$$ A^\top \lambda = \bar{x}, $$

followed by

$$ \bar{b} = \lambda $$

and

$$ \bar{A} = -\lambda\, x^\top. $$
This rule is central for differentiating implicit time steps, finite element methods, optimization layers, and constrained physical systems.
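A NumPy sketch of exactly this rule, with a finite-difference check on $\bar{b}$; the function name and test setup are ours, not a library API:

```python
import numpy as np

def solve_vjp(A, b, xbar):
    """Reverse-mode rule for x = solve(A, b), given the incoming adjoint xbar.

    One transpose solve A^T lam = xbar gives both adjoints:
    bbar = lam and Abar = -lam x^T.
    """
    x = np.linalg.solve(A, b)
    lam = np.linalg.solve(A.T, xbar)
    return -np.outer(lam, x), lam          # (Abar, bbar)

# Finite-difference check of bbar for L(x) = w @ x (so xbar = w).
rng = np.random.default_rng(0)
A = rng.normal(size=(4, 4)) + 4.0 * np.eye(4)
b, w = rng.normal(size=4), rng.normal(size=4)
Abar, bbar = solve_vjp(A, b, w)
eps = 1e-6
e0 = np.zeros(4); e0[0] = 1.0
fd = (w @ np.linalg.solve(A, b + eps * e0) - w @ np.linalg.solve(A, b)) / eps
assert np.isclose(fd, bbar[0], rtol=1e-4)
```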
## Memory and Checkpointing
Reverse-mode differentiation through a long time integration requires access to previous states. A naive implementation stores every state:

$$ y_0, y_1, \dots, y_N. $$

This has memory cost $O(N)$ in the number of steps.
Checkpointing reduces memory by storing only selected states and recomputing missing states during the backward pass. This trades extra computation for lower memory.
| Strategy | Memory | Extra compute |
|---|---|---|
| Store all states | High | Low |
| Store no states | Low | High |
| Checkpointing | Medium | Medium |
| Revolve-style schedules | Near optimal | Controlled |
For long simulations, checkpointing is often necessary. Without it, reverse-mode AD may be unusable even when the mathematical gradient is well-defined.
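A sketch of uniform checkpointing (simpler than a Revolve-style schedule), assuming hypothetical `step` and `step_vjp` functions for one forward step and its adjoint:

```python
import numpy as np

def backward_with_checkpoints(step, step_vjp, y0, N, k):
    """Reverse sweep storing only every k-th state.

    step(y)        : one forward step, y_{n+1} = Phi(y_n)
    step_vjp(y, a) : adjoint of one step, a_n = (dPhi/dy)^T a_{n+1} at state y_n
    Memory is O(N/k + k) states instead of O(N).
    """
    # Forward pass: keep checkpoints y_0, y_k, y_2k, ...
    checkpoints = [y0]
    y = y0
    for n in range(N):
        y = step(y)
        if (n + 1) % k == 0:
            checkpoints.append(y)
    a = np.ones_like(y)                # seed adjoint dL/dy_N (illustrative)
    # Backward pass: recompute each segment from its checkpoint, then reverse it.
    for seg in range(len(checkpoints) - 1, -1, -1):
        start = seg * k
        length = min(k, N - start)
        states = [checkpoints[seg]]
        for _ in range(length):        # recompute the segment's states
            states.append(step(states[-1]))
        for n in range(length - 1, -1, -1):
            a = step_vjp(states[n], a)
    return a                           # dL/dy_0
```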
## Practical Design Rule
A robust differentiable differential equation system usually separates four layers:
| Layer | Responsibility |
|---|---|
| Model | Defines $f(y, t, \theta)$ |
| Solver | Advances the state |
| Linear algebra | Solves sparse or dense systems |
| AD rules | Define JVPs and VJPs for each primitive |
This separation matters. Differentiating through every low-level operation is general, but often inefficient. Scientific solvers usually need custom derivative rules for linear solves, nonlinear solves, interpolation, event handling, and time stepping.
## Example: Parameter Estimation
Suppose a physical system is modeled by

$$ \frac{dy}{dt} = f(y, t, \theta), \qquad y(t_0) = y_0. $$

We observe measurements

$$ \hat{y}_i \approx y(t_i) \quad \text{for } i = 1, \dots, m, $$

and define

$$ L(\theta) = \sum_{i=1}^{m} \left\| y(t_i; \theta) - \hat{y}_i \right\|^2. $$

To fit $\theta$, we need $\nabla_\theta L$. AD provides this gradient by differentiating the solve procedure. An optimizer then updates $\theta$:

$$ \theta \leftarrow \theta - \eta\, \nabla_\theta L(\theta). $$
This gives a standard inverse-problem loop (a code sketch follows the list):
- solve the differential equation,
- compare the solution with observations,
- differentiate the loss,
- update the parameters.
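Putting the loop together, here is a self-contained sketch: explicit Euler with discrete forward sensitivities (fine here, since there is one parameter), synthetic data, and plain gradient descent:

```python
import numpy as np

# Illustrative model: dy/dt = -theta * y, one parameter, synthetic data.
def simulate(theta, y0, h, N):
    """Euler trajectory plus the discrete sensitivity s_n = dy_n/dtheta."""
    ys, ss = [y0], [0.0]
    y, s = y0, 0.0
    for _ in range(N):
        s = s + h * (-theta * s - y)      # s' = (df/dy) s + df/dtheta
        y = y + h * (-theta * y)
        ys.append(y); ss.append(s)
    return np.array(ys), np.array(ss)

true_theta, y0, h, N = 1.3, 2.0, 0.01, 500
rng = np.random.default_rng(1)
obs, _ = simulate(true_theta, y0, h, N)
obs = obs + 0.01 * rng.normal(size=obs.shape)    # noisy measurements

theta, lr = 0.5, 0.8                             # initial guess, step size
for _ in range(300):
    ys, ss = simulate(theta, y0, h, N)
    r = ys - obs
    grad = 2.0 * np.mean(r * ss)                 # mean keeps the scale N-free
    theta -= lr * grad                           # gradient-descent update
print(theta)                                     # approaches true_theta = 1.3
```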
The same pattern appears in system identification, climate modeling, robotics, molecular simulation, pharmacokinetics, and neural differential equations.
## Summary
Differential equations turn AD from a local derivative tool into a method for differentiating entire simulations. The main issue is not whether derivatives exist. The main issue is how to compute them with acceptable accuracy, memory use, and runtime.
Forward sensitivity methods are simple and reliable when the parameter dimension is small. Reverse and adjoint methods are better when a scalar loss depends on many parameters. Discrete adjoints differentiate the actual numerical computation. Continuous adjoints differentiate the mathematical model and then solve the adjoint equation.
For production scientific computing, efficient AD requires more than applying reverse mode blindly. It needs solver-aware derivative rules, sparse linear algebra, checkpointing, and careful treatment of adaptive control flow.