Rust is an attractive language for automatic differentiation because it combines low-level performance with strong static guarantees. It gives the programmer control over memory layout and allocation, while the type system tracks ownership, borrowing, lifetimes, mutation, and thread safety. These properties are useful for building AD systems that are fast, safe, and explicit about state.
Rust also makes AD harder in some places. Operator overloading is available through traits, but the language does not have C++-style implicit conversions or template specialization. Mutation is carefully controlled, which helps correctness but requires deliberate design for reverse-mode tapes and adjoint storage.
Scalar Forward Mode
Forward mode in Rust can be implemented with a dual-number type.
```rust
#[derive(Clone, Copy, Debug)]
pub struct Dual {
    pub x: f64,
    pub dx: f64,
}
```

Arithmetic is implemented through standard traits:
```rust
use std::ops::{Add, Mul};

impl Add for Dual {
    type Output = Dual;
    fn add(self, rhs: Dual) -> Dual {
        Dual {
            x: self.x + rhs.x,
            dx: self.dx + rhs.dx,
        }
    }
}

impl Mul for Dual {
    type Output = Dual;
    fn mul(self, rhs: Dual) -> Dual {
        Dual {
            x: self.x * rhs.x,
            dx: self.dx * rhs.x + self.x * rhs.dx,
        }
    }
}
```

A generic function can then be evaluated on either f64 or Dual if its operations are expressed through traits.
```rust
fn f(x: Dual) -> Dual {
    let one = Dual { x: 1.0, dx: 0.0 };
    (x + one) * x.sin()
}
```

Here `sin` must also be implemented as a method on Dual, propagating `dx: self.dx * self.x.cos()`; Rust will not implicitly fall back to `f64::sin`. With x = Dual { x: 2.0, dx: 1.0 }, the result contains both the primal value and the derivative with respect to x.
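To see the pieces fit together, here is a minimal self-contained sketch of the dual-number type above, applied to the polynomial g(x) = (x + 1) · x so that no transcendental rules are needed; the function name `g` is illustrative:

```rust
use std::ops::{Add, Mul};

// Minimal dual number: `x` is the primal value, `dx` the tangent.
#[derive(Clone, Copy, Debug, PartialEq)]
struct Dual {
    x: f64,
    dx: f64,
}

impl Add for Dual {
    type Output = Dual;
    fn add(self, rhs: Dual) -> Dual {
        Dual { x: self.x + rhs.x, dx: self.dx + rhs.dx }
    }
}

impl Mul for Dual {
    type Output = Dual;
    fn mul(self, rhs: Dual) -> Dual {
        // Product rule: (uv)' = u'v + uv'.
        Dual { x: self.x * rhs.x, dx: self.dx * rhs.x + self.x * rhs.dx }
    }
}

// g(x) = (x + 1) * x, so g'(x) = 2x + 1.
fn g(x: Dual) -> Dual {
    let one = Dual { x: 1.0, dx: 0.0 };
    (x + one) * x
}
```

Seeding `dx = 1.0` at x = 2 yields the primal g(2) = 6 together with the derivative g′(2) = 5 in a single evaluation.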
Traits as the Numeric Interface
Rust AD libraries usually define traits for differentiable scalar behavior.
```rust
pub trait Scalar:
    Copy
    + Add<Output = Self>
    + Mul<Output = Self>
{
    fn from_f64(x: f64) -> Self;
    fn sin(self) -> Self;
    fn cos(self) -> Self;
}
```

The same function can then run over multiple scalar domains:
```rust
fn model<T: Scalar>(x: T) -> T {
    (x + T::from_f64(1.0)) * x.sin()
}
```

This style is close to generic numerical programming. The user writes the function once. The AD system supplies the scalar type.
The advantage is static dispatch. The compiler can inline operations and remove many abstractions. The limitation is that all needed operations must be represented in trait bounds, and error messages can become complex.
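A condensed sketch of this pattern, implementing the Scalar trait for both f64 and a local dual type so that the same `model` runs in both domains (the Dual type here is illustrative, not a library API):

```rust
use std::ops::{Add, Mul};

pub trait Scalar: Copy + Add<Output = Self> + Mul<Output = Self> {
    fn from_f64(x: f64) -> Self;
    fn sin(self) -> Self;
    fn cos(self) -> Self;
}

// Plain evaluation: the inherent f64 methods do the work.
impl Scalar for f64 {
    fn from_f64(x: f64) -> f64 { x }
    fn sin(self) -> f64 { f64::sin(self) }
    fn cos(self) -> f64 { f64::cos(self) }
}

#[derive(Clone, Copy, Debug)]
struct Dual { x: f64, dx: f64 }

impl Add for Dual {
    type Output = Dual;
    fn add(self, r: Dual) -> Dual {
        Dual { x: self.x + r.x, dx: self.dx + r.dx }
    }
}

impl Mul for Dual {
    type Output = Dual;
    fn mul(self, r: Dual) -> Dual {
        Dual { x: self.x * r.x, dx: self.dx * r.x + self.x * r.dx }
    }
}

// Forward-mode evaluation: each method also propagates the tangent.
impl Scalar for Dual {
    fn from_f64(x: f64) -> Dual { Dual { x, dx: 0.0 } }
    fn sin(self) -> Dual { Dual { x: self.x.sin(), dx: self.dx * self.x.cos() } }
    fn cos(self) -> Dual { Dual { x: self.x.cos(), dx: -self.dx * self.x.sin() } }
}

// Written once, runs over any Scalar.
fn model<T: Scalar>(x: T) -> T {
    (x + T::from_f64(1.0)) * x.sin()
}
```

Calling `model(2.0_f64)` gives the plain value 3 · sin 2; calling `model(Dual { x: 2.0, dx: 1.0 })` additionally returns the derivative sin 2 + 3 · cos 2, with everything statically dispatched.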
Reverse Mode with Arenas
Reverse mode needs shared graph structure. Rust ownership rules make this design explicit.
A common approach is to store graph nodes in an arena and refer to them by index.
```rust
#[derive(Clone, Copy, Debug)]
pub struct Var {
    id: usize,
}

pub struct Node {
    value: f64,
    adjoint: f64,
    parents: Vec<(usize, f64)>,
}

pub struct Tape {
    nodes: Vec<Node>,
}
```

For an operation such as multiplication:
```rust
pub fn mul(tape: &mut Tape, a: Var, b: Var) -> Var {
    let av = tape.nodes[a.id].value;
    let bv = tape.nodes[b.id].value;
    let id = tape.nodes.len();
    tape.nodes.push(Node {
        value: av * bv,
        adjoint: 0.0,
        parents: vec![(a.id, bv), (b.id, av)],
    });
    Var { id }
}
```

The reverse pass walks nodes backward:
```rust
pub fn backward(tape: &mut Tape, out: Var) {
    tape.nodes[out.id].adjoint = 1.0;
    for i in (0..=out.id).rev() {
        let adj = tape.nodes[i].adjoint;
        let parents = tape.nodes[i].parents.clone();
        for (p, w) in parents {
            tape.nodes[p].adjoint += adj * w;
        }
    }
}
```

This design avoids reference cycles. It also makes lifetimes simple: variables are valid only with respect to the tape that created them.
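Assembled as methods on Tape, the arena design can be exercised end to end; the `var` constructor for leaf inputs and the `adjoint` accessor are assumed helpers, not part of the snippets above:

```rust
#[derive(Clone, Copy, Debug)]
pub struct Var { id: usize }

pub struct Node {
    value: f64,
    adjoint: f64,
    parents: Vec<(usize, f64)>, // (parent id, local partial derivative)
}

pub struct Tape { nodes: Vec<Node> }

impl Tape {
    pub fn new() -> Tape { Tape { nodes: Vec::new() } }

    // Leaf input: no parents to propagate into.
    pub fn var(&mut self, value: f64) -> Var {
        let id = self.nodes.len();
        self.nodes.push(Node { value, adjoint: 0.0, parents: vec![] });
        Var { id }
    }

    pub fn mul(&mut self, a: Var, b: Var) -> Var {
        let (av, bv) = (self.nodes[a.id].value, self.nodes[b.id].value);
        let id = self.nodes.len();
        self.nodes.push(Node {
            value: av * bv,
            adjoint: 0.0,
            parents: vec![(a.id, bv), (b.id, av)],
        });
        Var { id }
    }

    // Walk the tape backward, accumulating adjoints into parents.
    pub fn backward(&mut self, out: Var) {
        self.nodes[out.id].adjoint = 1.0;
        for i in (0..=out.id).rev() {
            let adj = self.nodes[i].adjoint;
            let parents = self.nodes[i].parents.clone();
            for (p, w) in parents {
                self.nodes[p].adjoint += adj * w;
            }
        }
    }

    pub fn adjoint(&self, v: Var) -> f64 { self.nodes[v.id].adjoint }
}
```

For z = x · y with x = 2 and y = 3, the reverse pass leaves ∂z/∂x = 3 on x and ∂z/∂y = 2 on y.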
Lifetimes and Tape Safety
Rust can encode tape ownership in the type system.
A variable may carry a lifetime tied to its tape:
```rust
pub struct Var<'t> {
    id: usize,
    tape: &'t Tape,
}
```

This prevents a variable from outliving the tape. However, reverse mode usually needs mutable access to the tape to append nodes and accumulate adjoints. That pushes many designs toward one of these patterns:
| Pattern | Description |
|---|---|
| Arena plus indices | Variables store stable node IDs |
| Interior mutability | Tape uses RefCell, Cell, or UnsafeCell |
| Context object | User passes &mut Tape explicitly |
| Thread-local tape | Easier syntax, weaker explicitness |
| Static graph builder | Build graph first, execute later |
The explicit context pattern is verbose but clear. Interior mutability improves ergonomics but moves some checking from compile time to runtime.
Mutation and Borrowing
Rust’s borrowing model helps expose unsafe AD designs. In reverse mode, gradients are accumulated into adjoint storage. This is mutation. Rust requires that mutation be exclusive or mediated through safe abstractions.
For example, this is straightforward:
```rust
let mut tape = Tape::new();
let x = tape.var(2.0);
let y = tape.var(3.0);
let z = tape.mul(x, y);
tape.backward(z);
```

The tape owns all mutable state. Variables are small handles.
Problems arise when the user wants ordinary arithmetic syntax:
```rust
let z = x * y + x.sin();
```

For this to work, x and y must know where to record operations. That usually requires shared access to a tape. Shared mutable access conflicts with Rust’s default ownership model, so libraries often use reference-counted pointers plus interior mutability:
`Rc<RefCell<Tape>>`

or thread-safe variants:

`Arc<Mutex<Tape>>`

These designs trade static simplicity for ergonomic operator syntax.
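A minimal sketch of the shared-tape pattern: each Var carries an `Rc<RefCell<Tape>>`, so implementing `std::ops::Mul` can record onto the tape through ordinary `*` syntax. The names and the `value` helper are illustrative, not a library API:

```rust
use std::cell::RefCell;
use std::ops::Mul;
use std::rc::Rc;

struct Node {
    value: f64,
    adjoint: f64,
    parents: Vec<(usize, f64)>,
}

struct Tape { nodes: Vec<Node> }

// Each variable carries a shared handle to its tape, so plain
// operator syntax knows where to record.
#[derive(Clone)]
struct Var {
    id: usize,
    tape: Rc<RefCell<Tape>>,
}

fn new_var(tape: &Rc<RefCell<Tape>>, value: f64) -> Var {
    let mut t = tape.borrow_mut();
    let id = t.nodes.len();
    t.nodes.push(Node { value, adjoint: 0.0, parents: vec![] });
    Var { id, tape: Rc::clone(tape) }
}

impl Mul for Var {
    type Output = Var;
    fn mul(self, rhs: Var) -> Var {
        // Borrow the tape at runtime; RefCell panics on conflicting borrows,
        // which is the cost of moving this check out of the compiler.
        let id = {
            let mut t = self.tape.borrow_mut();
            let (av, bv) = (t.nodes[self.id].value, t.nodes[rhs.id].value);
            let id = t.nodes.len();
            t.nodes.push(Node {
                value: av * bv,
                adjoint: 0.0,
                parents: vec![(self.id, bv), (rhs.id, av)],
            });
            id
        };
        Var { id, tape: self.tape }
    }
}

fn value(v: &Var) -> f64 {
    v.tape.borrow().nodes[v.id].value
}
```

With this, `let z = x * y;` records a multiplication node without any explicit tape argument, at the cost of runtime borrow checking.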
Tensors in Rust
Scalar AD is useful for teaching and small problems. Practical AD usually needs tensors.
A tensor object may contain:
```rust
pub struct Tensor {
    storage: Arc<Storage>,
    shape: Vec<usize>,
    strides: Vec<usize>,
    dtype: DType,
    device: Device,
    grad: Option<Arc<Storage>>,
    op: Option<OpRef>,
}
```

The AD system tracks tensor operations rather than scalar arithmetic. This reduces graph size. One node represents a whole matrix multiplication, convolution, or reduction.
For tensor reverse mode, each operation defines a backward rule:
| Operation | Backward behavior |
|---|---|
| Add | Pass adjoint to both inputs |
| Multiply | Multiply adjoint by opposite primal |
| Matmul | Use transposed matrix products |
| Sum | Broadcast adjoint to input shape |
| Reshape | Reshape adjoint back |
| Slice | Scatter adjoint into input shape |
Rust is well suited for this because tensor metadata and storage ownership can be modeled explicitly.
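Two rows of the table can be sketched concretely for flat `Vec<f64>` buffers, assuming equal shapes and no broadcasting; the function names are illustrative:

```rust
// Add: the adjoint passes through unchanged to both inputs.
fn add_backward(adj: &[f64]) -> (Vec<f64>, Vec<f64>) {
    (adj.to_vec(), adj.to_vec())
}

// Multiply: each input receives the adjoint scaled elementwise
// by the *other* input's primal value.
fn mul_backward(adj: &[f64], a: &[f64], b: &[f64]) -> (Vec<f64>, Vec<f64>) {
    let da: Vec<f64> = adj.iter().zip(b).map(|(g, bv)| g * bv).collect();
    let db: Vec<f64> = adj.iter().zip(a).map(|(g, av)| g * av).collect();
    (da, db)
}
```

Real tensor libraries apply the same rules over strided, device-resident storage, but the adjoint algebra is identical.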
Shape and Type Safety
Rust’s type system can encode some tensor information statically.
A dynamic tensor type stores shape at runtime:
```rust
pub struct Tensor {
    shape: Vec<usize>,
}
```

A statically shaped tensor type may use const generics:

```rust
pub struct Tensor<const N: usize> {
    data: [f64; N],
}
```

For matrices:

```rust
pub struct Matrix<const M: usize, const N: usize> {
    data: [[f64; N]; M],
}
```

This allows the compiler to reject shape-invalid operations:

```rust
fn matmul<const M: usize, const K: usize, const N: usize>(
    a: Matrix<M, K>,
    b: Matrix<K, N>,
) -> Matrix<M, N> {
    ...
}
```

Static shapes improve safety and performance for small fixed-size computations. Dynamic shapes remain necessary for most machine learning workloads.
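Filling in the body of the const-generic matmul makes the point concrete: a call with mismatched inner dimensions simply does not type-check, so the shape error is caught at compile time rather than at runtime.

```rust
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Matrix<const M: usize, const N: usize> {
    pub data: [[f64; N]; M],
}

// The inner dimension K must match between `a` and `b`;
// a mismatch is a compile error, not a runtime panic.
pub fn matmul<const M: usize, const K: usize, const N: usize>(
    a: Matrix<M, K>,
    b: Matrix<K, N>,
) -> Matrix<M, N> {
    let mut out = Matrix { data: [[0.0; N]; M] };
    for i in 0..M {
        for j in 0..N {
            for k in 0..K {
                out.data[i][j] += a.data[i][k] * b.data[k][j];
            }
        }
    }
    out
}
```

The const parameters M, K, and N are inferred from the array literals at the call site, so no turbofish annotations are needed in typical use.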
Source Transformation and Procedural Macros
Rust procedural macros can inspect and rewrite syntax. This suggests a source-transformation path for AD.
A function might be annotated:
```rust
#[differentiate]
fn f(x: f64) -> f64 {
    (x + 1.0) * x.sin()
}
```

The macro could generate:

```rust
fn f_grad(x: f64) -> f64 {
    ...
}
```

This approach can produce efficient derivative code. However, macro-based AD has limits:
| Issue | Explanation |
|---|---|
| Type information | Procedural macros run before full type checking |
| Borrow semantics | Generated code must preserve ownership rules |
| Trait dispatch | Actual operation may depend on resolved impls |
| Control flow | Loops and branches need careful transformation |
| External calls | Derivative rules must be available |
A deeper compiler integration would have access to more semantic information, but Rust’s stable compiler plugin surface is limited.
Compiler IR Approaches
Another route is differentiating lower-level IR, such as LLVM IR. Since Rust lowers to LLVM, compiler-level AD tools can in principle differentiate Rust code after monomorphization and optimization.
This can handle generic code after type resolution. It can also differentiate through inlined library code.
The challenge is that Rust’s high-level ownership information has been lowered into memory operations. The AD tool must reason about loads, stores, aliasing, and activity at the IR level.
Unsafe Code
High-performance Rust libraries often use unsafe internally for vectorization, custom allocators, GPU bindings, or unchecked indexing.
AD systems must treat unsafe regions carefully. The Rust type system guarantees memory safety only as long as unsafe code upholds the aliasing and lifetime invariants it promises. If unsafe code violates those assumptions, derivative code built on top of it may also be invalid.
A practical design should keep unsafe code behind small, tested primitives with explicit derivative rules.
Parallelism
Rust’s ownership model is useful for parallel AD.
Forward tensor operations can run in parallel when inputs are immutable. Reverse-mode accumulation is harder because multiple downstream paths may contribute to the same adjoint.
Common approaches include:
| Approach | Use |
|---|---|
| Topological scheduling | Process graph levels safely |
| Atomic accumulation | Allow concurrent adjoint updates |
| Thread-local gradients | Reduce later |
| Graph partitioning | Parallelize independent regions |
| Static compilation | Plan parallel reverse execution |
Rust’s Send and Sync traits make thread-safety requirements explicit. This helps library authors define which tensors, tapes, and contexts can cross threads.
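The thread-local-gradients row of the table can be sketched with scoped threads (Rust 1.63+): each worker accumulates adjoint contributions into its own buffer, and a single-threaded reduction sums the buffers at the end, avoiding any shared mutable state during the parallel phase. The partitioning of contributions into chunks is illustrative:

```rust
use std::thread;

// `contributions` holds one chunk of (gradient slot, weight) pairs per
// worker; each worker accumulates into a private buffer of length `len`.
fn parallel_accumulate(contributions: &[Vec<(usize, f64)>], len: usize) -> Vec<f64> {
    let partials: Vec<Vec<f64>> = thread::scope(|s| {
        let handles: Vec<_> = contributions
            .iter()
            .map(|chunk| {
                s.spawn(move || {
                    // Thread-local accumulation: no synchronization needed.
                    let mut local = vec![0.0; len];
                    for &(idx, w) in chunk {
                        local[idx] += w;
                    }
                    local
                })
            })
            .collect();
        handles.into_iter().map(|h| h.join().unwrap()).collect()
    });
    // Reduction phase: sum the thread-local buffers.
    let mut grad = vec![0.0; len];
    for p in partials {
        for (g, v) in grad.iter_mut().zip(p) {
            *g += v;
        }
    }
    grad
}
```

Because the closures only borrow immutable chunks and own their local buffers, the compiler verifies via Send and Sync that this scheme is data-race free.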
GPU Integration
Rust AD frameworks need access to GPU execution for large tensor workloads.
Typical strategies include:
| Strategy | Description |
|---|---|
| Bind to CUDA or ROCm | Use vendor libraries |
| Generate kernels | Emit GPU code from tensor graph |
| Use WebGPU or wgpu | Portable GPU backend |
| Call existing runtimes | Interoperate with C/C++ tensor libraries |
| Compile through MLIR or XLA | Lower graph to accelerator IR |
Rust can provide a safe frontend around unsafe GPU APIs, but the AD system still needs custom backward rules for kernels.
Error Handling
Rust uses Result for recoverable errors.
Differentiable code may fail due to:
- Shape mismatch
- Invalid device transfer
- Unsupported operation
- Singular matrix
- Non-finite values
- Missing derivative rule
A serious Rust AD library should surface these errors explicitly rather than panic in ordinary use.
```rust
pub fn backward(&mut self, out: Tensor) -> Result<(), Error> {
    ...
}
```

This fits Rust’s broader design style.
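One possible shape for such an error type, covering a few of the failure modes listed above; the enum and variant names are illustrative, not from any particular crate:

```rust
use std::fmt;

// Illustrative error type for a differentiable-computation library.
#[derive(Debug)]
pub enum AdError {
    ShapeMismatch { expected: Vec<usize>, found: Vec<usize> },
    MissingDerivativeRule(&'static str),
    NonFinite,
}

impl fmt::Display for AdError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            AdError::ShapeMismatch { expected, found } => {
                write!(f, "shape mismatch: expected {:?}, found {:?}", expected, found)
            }
            AdError::MissingDerivativeRule(op) => {
                write!(f, "no derivative rule registered for `{}`", op)
            }
            AdError::NonFinite => write!(f, "non-finite value encountered"),
        }
    }
}

impl std::error::Error for AdError {}

// A fallible elementwise op that surfaces shape errors instead of panicking.
fn checked_add(a: &[f64], b: &[f64]) -> Result<Vec<f64>, AdError> {
    if a.len() != b.len() {
        return Err(AdError::ShapeMismatch {
            expected: vec![a.len()],
            found: vec![b.len()],
        });
    }
    Ok(a.iter().zip(b).map(|(x, y)| x + y).collect())
}
```

Implementing `std::error::Error` lets these errors compose with `?` and with downstream error-handling crates.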
Representative Rust AD Directions
Rust’s AD ecosystem is smaller than Python’s, but several design directions are common.
| Direction | Description |
|---|---|
| Scalar AD crates | Dual numbers and small reverse-mode systems |
| Tensor libraries | Autograd attached to tensor operations |
| Compiler-assisted AD | Differentiation through lowered IR |
| ML frameworks | Rust-native training stacks |
| Bindings to existing systems | Safe Rust frontend over C++/Python runtimes |
Rust is especially attractive for embedded optimization, simulation engines, inference runtimes, differentiable graphics kernels, and systems where memory safety matters.
Design Guidance
A Rust AD system should choose one clear semantic center.
For a small numerical library, use explicit forward mode with dual numbers and traits. This gives simple code, static dispatch, and low overhead.
For tensor machine learning, use graph-based reverse mode over tensor primitives. Keep scalar operator overloading secondary. A graph node per tensor operation is much more efficient than a graph node per scalar operation.
For scientific computing, expose explicit contexts and avoid hidden global tapes. Users in this domain often prefer predictable allocation and clear ownership.
For compiler-level AD, integrate with existing Rust compilation stages or use LLVM-based differentiation. This path offers performance but requires careful handling of memory effects.
Rust’s main contribution to AD is discipline. It forces the implementation to specify who owns the tape, who mutates adjoints, how long intermediate values live, and which computations are thread-safe. That discipline makes ergonomic design harder, but it also eliminates classes of runtime errors that are common in dynamic AD systems.