AD in Rust

Rust is an attractive language for automatic differentiation because it combines low-level performance with strong static guarantees. It gives the programmer control over memory layout and allocation, while the type system tracks ownership, borrowing, lifetimes, mutation, and thread safety. These properties are useful for building AD systems that are fast, safe, and explicit about state.

Rust also makes AD harder in some places. Operator overloading is available through traits, but the language does not have C++-style implicit conversions or template specialization. Mutation is carefully controlled, which helps correctness but requires deliberate design for reverse-mode tapes and adjoint storage.

Scalar Forward Mode

Forward mode in Rust can be implemented with a dual-number type.

#[derive(Clone, Copy, Debug)]
pub struct Dual {
    pub x: f64,
    pub dx: f64,
}

Arithmetic is implemented through standard traits:

use std::ops::{Add, Mul};

impl Add for Dual {
    type Output = Dual;

    fn add(self, rhs: Dual) -> Dual {
        Dual {
            x: self.x + rhs.x,
            dx: self.dx + rhs.dx,
        }
    }
}

impl Mul for Dual {
    type Output = Dual;

    fn mul(self, rhs: Dual) -> Dual {
        Dual {
            x: self.x * rhs.x,
            dx: self.dx * rhs.x + self.x * rhs.dx,
        }
    }
}

A function written against Dual then propagates derivatives automatically; the next section shows how to make the same code generic over f64 and Dual by expressing its operations through traits.

impl Dual {
    fn sin(self) -> Dual {
        // chain rule: (sin u)' = cos(u) * u'
        Dual { x: self.x.sin(), dx: self.dx * self.x.cos() }
    }
}

fn f(x: Dual) -> Dual {
    let one = Dual { x: 1.0, dx: 0.0 };
    (x + one) * x.sin()
}

With x = Dual { x: 2.0, dx: 1.0 }, the result contains both the primal value and the derivative with respect to x.
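
Putting the pieces together, a minimal runnable sketch checks the propagated derivative against the analytic one, f'(x) = sin x + (x + 1) cos x:

```rust
use std::ops::{Add, Mul};

#[derive(Clone, Copy, Debug)]
pub struct Dual { pub x: f64, pub dx: f64 }

impl Add for Dual {
    type Output = Dual;
    fn add(self, rhs: Dual) -> Dual {
        Dual { x: self.x + rhs.x, dx: self.dx + rhs.dx }
    }
}

impl Mul for Dual {
    type Output = Dual;
    fn mul(self, rhs: Dual) -> Dual {
        // product rule: (uv)' = u'v + uv'
        Dual { x: self.x * rhs.x, dx: self.dx * rhs.x + self.x * rhs.dx }
    }
}

impl Dual {
    fn sin(self) -> Dual {
        // chain rule: (sin u)' = cos(u) * u'
        Dual { x: self.x.sin(), dx: self.dx * self.x.cos() }
    }
}

fn f(x: Dual) -> Dual {
    let one = Dual { x: 1.0, dx: 0.0 };
    (x + one) * x.sin()
}

fn main() {
    // seed dx = 1.0 to differentiate with respect to x
    let y = f(Dual { x: 2.0, dx: 1.0 });
    let expect_val = 3.0 * 2.0_f64.sin();
    let expect_der = 2.0_f64.sin() + 3.0 * 2.0_f64.cos();
    assert!((y.x - expect_val).abs() < 1e-12);
    assert!((y.dx - expect_der).abs() < 1e-12);
}
```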

Traits as the Numeric Interface

Rust AD libraries usually define traits for differentiable scalar behavior.

pub trait Scalar:
    Copy
    + Add<Output = Self>
    + Mul<Output = Self>
{
    fn from_f64(x: f64) -> Self;
    fn sin(self) -> Self;
    fn cos(self) -> Self;
}

The same function can then run over multiple scalar domains:

fn model<T: Scalar>(x: T) -> T {
    (x + T::from_f64(1.0)) * x.sin()
}

This style is close to generic numerical programming. The user writes the function once. The AD system supplies the scalar type.

The advantage is static dispatch. The compiler can inline operations and remove many abstractions. The limitation is that all needed operations must be represented in trait bounds, and error messages can become complex.
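
A self-contained sketch of what implementing this trait for plain f64 might look like (a Dual impl would follow the same pattern, applying the chain rule inside sin and cos):

```rust
use std::ops::{Add, Mul};

pub trait Scalar: Copy + Add<Output = Self> + Mul<Output = Self> {
    fn from_f64(x: f64) -> Self;
    fn sin(self) -> Self;
    fn cos(self) -> Self;
}

impl Scalar for f64 {
    fn from_f64(x: f64) -> Self { x }
    // self.sin() resolves to the inherent f64::sin, which takes
    // precedence over the trait method, so this is not recursive
    fn sin(self) -> Self { self.sin() }
    fn cos(self) -> Self { self.cos() }
}

fn model<T: Scalar>(x: T) -> T {
    (x + T::from_f64(1.0)) * x.sin()
}

fn main() {
    // the same generic body runs on plain f64
    let y = model(2.0_f64);
    assert!((y - 3.0 * 2.0_f64.sin()).abs() < 1e-12);
}
```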

Reverse Mode with Arenas

Reverse mode needs shared graph structure. Rust ownership rules make this design explicit.

A common approach is to store graph nodes in an arena and refer to them by index.

#[derive(Clone, Copy, Debug)]
pub struct Var {
    id: usize,
}

pub struct Node {
    value: f64,
    adjoint: f64,
    parents: Vec<(usize, f64)>,
}

pub struct Tape {
    nodes: Vec<Node>,
}

For an operation such as multiplication:

pub fn mul(tape: &mut Tape, a: Var, b: Var) -> Var {
    let av = tape.nodes[a.id].value;
    let bv = tape.nodes[b.id].value;

    let id = tape.nodes.len();

    tape.nodes.push(Node {
        value: av * bv,
        adjoint: 0.0,
        // local partials of a*b: d/da = b, d/db = a
        parents: vec![(a.id, bv), (b.id, av)],
    });

    Var { id }
}

The reverse pass walks nodes backward:

pub fn backward(tape: &mut Tape, out: Var) {
    tape.nodes[out.id].adjoint = 1.0;

    for i in (0..=out.id).rev() {
        let adj = tape.nodes[i].adjoint;
        // clone ends the borrow of tape.nodes[i] before mutating parents
        let parents = tape.nodes[i].parents.clone();

        for (p, w) in parents {
            tape.nodes[p].adjoint += adj * w;
        }
    }
}

This design avoids reference cycles. It also makes lifetimes simple: variables are valid only with respect to the tape that created them.
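
A complete run of this arena design, with a small `var` constructor and `grad` accessor added for inputs (those two names are illustrative, not from the text above):

```rust
#[derive(Clone, Copy, Debug)]
pub struct Var { id: usize }

pub struct Node {
    value: f64,
    adjoint: f64,
    parents: Vec<(usize, f64)>,
}

pub struct Tape { nodes: Vec<Node> }

impl Tape {
    pub fn new() -> Tape { Tape { nodes: Vec::new() } }

    // leaf input: no parents
    pub fn var(&mut self, value: f64) -> Var {
        let id = self.nodes.len();
        self.nodes.push(Node { value, adjoint: 0.0, parents: Vec::new() });
        Var { id }
    }

    pub fn mul(&mut self, a: Var, b: Var) -> Var {
        let (av, bv) = (self.nodes[a.id].value, self.nodes[b.id].value);
        let id = self.nodes.len();
        // local partials of a*b: d/da = b, d/db = a
        self.nodes.push(Node {
            value: av * bv,
            adjoint: 0.0,
            parents: vec![(a.id, bv), (b.id, av)],
        });
        Var { id }
    }

    pub fn backward(&mut self, out: Var) {
        self.nodes[out.id].adjoint = 1.0;
        for i in (0..=out.id).rev() {
            let adj = self.nodes[i].adjoint;
            let parents = self.nodes[i].parents.clone();
            for (p, w) in parents {
                self.nodes[p].adjoint += adj * w;
            }
        }
    }

    pub fn grad(&self, v: Var) -> f64 { self.nodes[v.id].adjoint }
}

fn main() {
    let mut tape = Tape::new();
    let x = tape.var(2.0);
    let y = tape.var(3.0);
    let z = tape.mul(x, y);
    tape.backward(z);
    assert_eq!(tape.grad(x), 3.0); // dz/dx = y
    assert_eq!(tape.grad(y), 2.0); // dz/dy = x
}
```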

Lifetimes and Tape Safety

Rust can encode tape ownership in the type system.

A variable may carry a lifetime tied to its tape:

pub struct Var<'t> {
    id: usize,
    tape: &'t Tape,
}

This prevents a variable from outliving the tape. However, reverse mode usually needs mutable access to the tape to append nodes and accumulate adjoints. That pushes many designs toward one of these patterns:

Pattern              | Description
---------------------|---------------------------------------
Arena plus indices   | Variables store stable node IDs
Interior mutability  | Tape uses RefCell, Cell, or UnsafeCell
Context object       | User passes &mut Tape explicitly
Thread-local tape    | Easier syntax, weaker explicitness
Static graph builder | Build graph first, execute later

The explicit context pattern is verbose but clear. Interior mutability improves ergonomics but moves some checking from compile time to runtime.

Mutation and Borrowing

Rust’s borrowing model helps expose unsafe AD designs. In reverse mode, gradients are accumulated into adjoint storage. This is mutation. Rust requires that mutation be exclusive or mediated through safe abstractions.

For example, this is straightforward:

let mut tape = Tape::new();
let x = tape.var(2.0);
let y = tape.var(3.0);
let z = tape.mul(x, y);
tape.backward(z);

The tape owns all mutable state. Variables are small handles.

Problems arise when the user wants ordinary arithmetic syntax:

let z = x * y + x.sin();

For this to work, x and y must know where to record operations. That usually requires shared access to a tape. Shared mutable access conflicts with Rust’s default ownership model, so libraries often use reference-counted pointers plus interior mutability:

Rc<RefCell<Tape>>

or thread-safe variants:

Arc<Mutex<Tape>>

These designs trade static simplicity for ergonomic operator syntax.
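
A sketch of the shared-tape pattern for single-threaded use, assuming a reference-counted tape with the RefCell held inside it (same idea as Rc<RefCell<Tape>>; all names here are illustrative):

```rust
use std::cell::RefCell;
use std::ops::Mul;
use std::rc::Rc;

struct Node { adjoint: f64, parents: Vec<(usize, f64)> }

#[derive(Default)]
struct Tape { nodes: RefCell<Vec<Node>> }

impl Tape {
    fn push(&self, parents: Vec<(usize, f64)>) -> usize {
        let mut nodes = self.nodes.borrow_mut();
        nodes.push(Node { adjoint: 0.0, parents });
        nodes.len() - 1
    }
}

// each variable carries a handle to the tape it records onto
#[derive(Clone)]
struct Var { tape: Rc<Tape>, id: usize, value: f64 }

fn var(tape: &Rc<Tape>, value: f64) -> Var {
    Var { tape: Rc::clone(tape), id: tape.push(Vec::new()), value }
}

impl Mul for Var {
    type Output = Var;
    fn mul(self, rhs: Var) -> Var {
        // operator syntax works because both operands share the tape
        let id = self.tape.push(vec![(self.id, rhs.value), (rhs.id, self.value)]);
        Var { tape: self.tape.clone(), id, value: self.value * rhs.value }
    }
}

fn backward(tape: &Tape, out: &Var) -> Vec<f64> {
    let mut nodes = tape.nodes.borrow_mut();
    nodes[out.id].adjoint = 1.0;
    for i in (0..nodes.len()).rev() {
        let adj = nodes[i].adjoint;
        let parents = nodes[i].parents.clone();
        for (p, w) in parents {
            nodes[p].adjoint += adj * w;
        }
    }
    nodes.iter().map(|n| n.adjoint).collect()
}

fn main() {
    let tape = Rc::new(Tape::default());
    let x = var(&tape, 2.0);
    let y = var(&tape, 3.0);
    let z = x.clone() * y.clone();
    let grads = backward(&tape, &z);
    assert_eq!(grads[x.id], 3.0); // dz/dx = y
    assert_eq!(grads[y.id], 2.0); // dz/dy = x
}
```

The borrow checks moved into RefCell happen at runtime: an aliasing bug that the explicit &mut Tape design would reject at compile time becomes a panic here.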

Tensors in Rust

Scalar AD is useful for teaching and small problems. Practical AD usually needs tensors.

A tensor object may contain:

pub struct Tensor {
    storage: Arc<Storage>,
    shape: Vec<usize>,
    strides: Vec<usize>,
    dtype: DType,
    device: Device,
    grad: Option<Arc<Storage>>,
    op: Option<OpRef>,
}

The AD system tracks tensor operations rather than scalar arithmetic. This reduces graph size. One node represents a whole matrix multiplication, convolution, or reduction.

For tensor reverse mode, each operation defines a backward rule:

Operation | Backward behavior
----------|-----------------------------------------
Add       | Pass adjoint to both inputs
Multiply  | Multiply adjoint by the opposite primal
Matmul    | Use transposed matrix products
Sum       | Broadcast adjoint to input shape
Reshape   | Reshape adjoint back
Slice     | Scatter adjoint into input shape

Rust is well suited for this because tensor metadata and storage ownership can be modeled explicitly.
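
Two rows of that table, sketched with a plain Vec<f64> standing in for a tensor: sum broadcasts its scalar adjoint back to the input shape, and elementwise multiply routes the adjoint times the opposite primal to each input.

```rust
// forward: y = sum(x)
fn sum_forward(x: &[f64]) -> f64 {
    x.iter().sum()
}

// backward for sum: dy/dx_i = 1, so the adjoint is broadcast to input shape
fn sum_backward(adjoint: f64, input_len: usize) -> Vec<f64> {
    vec![adjoint; input_len]
}

// backward for elementwise multiply: each input gets adjoint * opposite primal
fn mul_backward(adjoint: &[f64], a: &[f64], b: &[f64]) -> (Vec<f64>, Vec<f64>) {
    let da = adjoint.iter().zip(b).map(|(g, bv)| g * bv).collect();
    let db = adjoint.iter().zip(a).map(|(g, av)| g * av).collect();
    (da, db)
}

fn main() {
    let a = vec![1.0, 2.0];
    let b = vec![3.0, 4.0];
    let prod: Vec<f64> = a.iter().zip(&b).map(|(x, y)| x * y).collect();
    let y = sum_forward(&prod);
    assert_eq!(y, 11.0);

    // reverse pass: through sum first, then through multiply
    let d_prod = sum_backward(1.0, prod.len());
    let (da, db) = mul_backward(&d_prod, &a, &b);
    assert_eq!(da, vec![3.0, 4.0]); // dy/da = b
    assert_eq!(db, vec![1.0, 2.0]); // dy/db = a
}
```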

Shape and Type Safety

Rust’s type system can encode some tensor information statically.

A dynamic tensor type stores shape at runtime:

pub struct Tensor {
    shape: Vec<usize>,
}

A statically shaped tensor type may use const generics:

pub struct Tensor<const N: usize> {
    data: [f64; N],
}

For matrices:

pub struct Matrix<const M: usize, const N: usize> {
    data: [[f64; N]; M],
}

This allows the compiler to reject shape-invalid operations:

fn matmul<const M: usize, const K: usize, const N: usize>(
    a: Matrix<M, K>,
    b: Matrix<K, N>,
) -> Matrix<M, N> {
    ...
}

Static shapes improve safety and performance for small fixed-size computations. Dynamic shapes remain necessary for most machine learning workloads.
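
A concrete version of the shape-checked matmul signature, with the elided body filled in as a straightforward triple loop over f64 (an illustrative sketch, not a tuned kernel):

```rust
pub struct Matrix<const M: usize, const N: usize> {
    data: [[f64; N]; M],
}

fn matmul<const M: usize, const K: usize, const N: usize>(
    a: Matrix<M, K>,
    b: Matrix<K, N>,
) -> Matrix<M, N> {
    let mut out = [[0.0; N]; M];
    for i in 0..M {
        for k in 0..K {
            for j in 0..N {
                out[i][j] += a.data[i][k] * b.data[k][j];
            }
        }
    }
    Matrix { data: out }
}

fn main() {
    let a = Matrix::<2, 3> { data: [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] };
    let b = Matrix::<3, 1> { data: [[1.0], [1.0], [1.0]] };
    // passing a Matrix<1, 3> for b here would fail to compile: K must match
    let c = matmul(a, b);
    assert_eq!(c.data, [[6.0], [15.0]]);
}
```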

Source Transformation and Procedural Macros

Rust procedural macros can inspect and rewrite syntax. This suggests a source-transformation path for AD.

A function might be annotated:

#[differentiate]
fn f(x: f64) -> f64 {
    (x + 1.0) * x.sin()
}

The macro could generate:

fn f_grad(x: f64) -> f64 {
    ...
}
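
For this particular f, the symbolic derivative is f'(x) = sin x + (x + 1) cos x, so a hand-written version of the kind of code such a macro might emit looks like this (checked against a finite difference):

```rust
fn f(x: f64) -> f64 {
    (x + 1.0) * x.sin()
}

// product rule plus chain rule:
// d/dx [(x + 1) * sin x] = sin x + (x + 1) * cos x
fn f_grad(x: f64) -> f64 {
    x.sin() + (x + 1.0) * x.cos()
}

fn main() {
    // sanity check against a central finite difference
    let (x, h) = (2.0, 1e-6);
    let fd = (f(x + h) - f(x - h)) / (2.0 * h);
    assert!((f_grad(x) - fd).abs() < 1e-5);
}
```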

This approach can produce efficient derivative code. However, macro-based AD has limits:

Issue            | Explanation
-----------------|------------------------------------------------
Type information | Procedural macros run before full type checking
Borrow semantics | Generated code must preserve ownership rules
Trait dispatch   | Actual operation may depend on resolved impls
Control flow     | Loops and branches need careful transformation
External calls   | Derivative rules must be available

A deeper compiler integration would have access to more semantic information, but Rust’s stable compiler plugin surface is limited.

Compiler IR Approaches

Another route is differentiating lower-level IR, such as LLVM IR. Since Rust lowers to LLVM, compiler-level AD tools can in principle differentiate Rust code after monomorphization and optimization.

This can handle generic code after type resolution. It can also differentiate through inlined library code.

The challenge is that Rust’s high-level ownership information has been lowered into memory operations. The AD tool must reason about loads, stores, aliasing, and activity at the IR level.

Unsafe Code

High-performance Rust libraries often use unsafe internally for vectorization, custom allocators, GPU bindings, or unchecked indexing.

AD systems must treat unsafe regions carefully. Rust's memory-safety guarantees hold only as long as unsafe code upholds the invariants the compiler cannot check. If unsafe code violates aliasing or lifetime assumptions, derivative code built on top of it may be invalid as well.

A practical design should keep unsafe code behind small, tested primitives with explicit derivative rules.

Parallelism

Rust’s ownership model is useful for parallel AD.

Forward tensor operations can run in parallel when inputs are immutable. Reverse-mode accumulation is harder because multiple downstream paths may contribute to the same adjoint.

Common approaches include:

Approach               | Use
-----------------------|---------------------------------
Topological scheduling | Process graph levels safely
Atomic accumulation    | Allow concurrent adjoint updates
Thread-local gradients | Reduce later
Graph partitioning     | Parallelize independent regions
Static compilation     | Plan parallel reverse execution

Rust’s Send and Sync traits make thread-safety requirements explicit. This helps library authors define which tensors, tapes, and contexts can cross threads.
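
A minimal sketch of the thread-local-gradients pattern for a toy loss, sum of x_i squared, whose gradient is 2 x_i. Each thread fills its own buffer with no shared mutable state, and the reduce step runs after all threads join. The helper name is illustrative, and input length is assumed divisible by the thread count.

```rust
use std::thread;

// split the input into disjoint chunks, differentiate each chunk on its own
// thread into a thread-local buffer, then reduce (here: concatenate)
fn parallel_grad(xs: &[f64], n_threads: usize) -> Vec<f64> {
    let chunk = xs.len() / n_threads; // assumes exact division
    let mut handles = Vec::new();
    for t in 0..n_threads {
        let slice: Vec<f64> = xs[t * chunk..(t + 1) * chunk].to_vec();
        handles.push(thread::spawn(move || {
            // thread-local gradient buffer: d/dx_i sum(x^2) = 2 x_i
            slice.iter().map(|x| 2.0 * x).collect::<Vec<f64>>()
        }));
    }
    // reduce step: chunks are disjoint, so concatenation is the reduction
    let mut grad = Vec::new();
    for h in handles {
        grad.extend(h.join().unwrap());
    }
    grad
}

fn main() {
    let xs: Vec<f64> = (0..8).map(|i| i as f64).collect();
    let grad = parallel_grad(&xs, 2);
    assert_eq!(grad, (0..8).map(|i| 2.0 * i as f64).collect::<Vec<f64>>());
}
```

When downstream paths overlap, the per-thread buffers must be summed elementwise instead of concatenated; the ownership story is the same.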

GPU Integration

Rust AD frameworks need access to GPU execution for large tensor workloads.

Typical strategies include:

Strategy                    | Description
----------------------------|------------------------------------------
Bind to CUDA or ROCm        | Use vendor libraries
Generate kernels            | Emit GPU code from the tensor graph
Use WebGPU or wgpu          | Portable GPU backend
Call existing runtimes      | Interoperate with C/C++ tensor libraries
Compile through MLIR or XLA | Lower the graph to accelerator IR

Rust can provide a safe frontend around unsafe GPU APIs, but the AD system still needs custom backward rules for kernels.

Error Handling

Rust uses Result for recoverable errors.

Differentiable code may fail due to:

  • Shape mismatch
  • Invalid device transfer
  • Unsupported operation
  • Singular matrix
  • Non-finite values
  • Missing derivative rule

A serious Rust AD library should surface these errors explicitly rather than panic in ordinary use.

pub fn backward(&mut self, out: Tensor) -> Result<(), Error> {
    ...
}

This fits Rust’s broader design style.
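
A sketch of this style with a hypothetical Error type: a shape-checked elementwise add that returns Err on mismatch instead of panicking.

```rust
#[derive(Debug, PartialEq)]
enum Error {
    ShapeMismatch { left: usize, right: usize },
}

// elementwise add that surfaces shape errors through Result
fn add(a: &[f64], b: &[f64]) -> Result<Vec<f64>, Error> {
    if a.len() != b.len() {
        return Err(Error::ShapeMismatch { left: a.len(), right: b.len() });
    }
    Ok(a.iter().zip(b).map(|(x, y)| x + y).collect())
}

fn main() {
    assert_eq!(add(&[1.0, 2.0], &[3.0, 4.0]), Ok(vec![4.0, 6.0]));
    // the mismatch is data the caller can inspect, not a panic message
    assert_eq!(
        add(&[1.0], &[1.0, 2.0]),
        Err(Error::ShapeMismatch { left: 1, right: 2 })
    );
}
```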

Representative Rust AD Directions

Rust’s AD ecosystem is smaller than Python’s, but several design directions are common.

Direction                    | Description
-----------------------------|---------------------------------------------
Scalar AD crates             | Dual numbers and small reverse-mode systems
Tensor libraries             | Autograd attached to tensor operations
Compiler-assisted AD         | Differentiation through lowered IR
ML frameworks                | Rust-native training stacks
Bindings to existing systems | Safe Rust frontend over C++/Python runtimes

Rust is especially attractive for embedded optimization, simulation engines, inference runtimes, differentiable graphics kernels, and systems where memory safety matters.

Design Guidance

A Rust AD system should choose one clear semantic center.

For a small numerical library, use explicit forward mode with dual numbers and traits. This gives simple code, static dispatch, and low overhead.

For tensor machine learning, use graph-based reverse mode over tensor primitives. Keep scalar operator overloading secondary. A graph node per tensor operation is much more efficient than a graph node per scalar operation.

For scientific computing, expose explicit contexts and avoid hidden global tapes. Users in this domain often prefer predictable allocation and clear ownership.

For compiler-level AD, integrate with existing Rust compilation stages or use LLVM-based differentiation. This path offers performance but requires careful handling of memory effects.

Rust’s main contribution to AD is discipline. It forces the implementation to specify who owns the tape, who mutates adjoints, when intermediate values live, and which computations are thread-safe. That discipline makes ergonomic design harder, but it also reduces classes of runtime errors that are common in dynamic AD systems.