
AD in Python


Python became the dominant language for modern machine learning and differentiable computing because it combines a simple programming model with access to high-performance native libraries. Most Python automatic differentiation systems therefore follow a hybrid architecture:

Layer | Role
Python frontend | User-facing model and control logic
Tensor runtime | Dense array execution
AD engine | Gradient propagation
Native backend | CPU/GPU kernels in C/C++/CUDA
Compiler subsystem | Graph optimization and lowering

Python itself is slow for numerical kernels. The key observation is that tensor operations execute outside the Python interpreter, in native code, so interpreter overhead is paid per operation rather than per array element. The AD system therefore differentiates tensor programs whose structure is driven by Python control flow.

Tensor-Centric Computation

Modern Python AD systems are built around tensors.

A tensor object typically contains:

Component | Meaning
Shape | Tensor dimensions
Dtype | Numeric type
Storage | Underlying memory
Device | CPU, GPU, TPU
Gradient metadata | Information for reverse mode
Graph reference | Dependency structure

A simple example:

x = tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x * x
z = y.sum()
z.backward()

After execution:

x.grad

contains:

[2, 4, 6]

The user writes imperative Python code, but the AD system records tensor operations and constructs derivative propagation rules internally.
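
For concreteness, the same computation written against PyTorch's actual API (a minimal sketch, assuming torch is installed):

import torch

# leaf tensor with gradient tracking enabled
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x * x            # elementwise square, recorded on the dynamic graph
z = y.sum()          # scalar output
z.backward()         # reverse pass populates x.grad

print(x.grad)        # tensor([2., 4., 6.])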

Dynamic Computational Graphs

Many Python systems use dynamic graphs.

A graph is built during execution. Each tensor operation creates graph nodes connecting inputs to outputs.

For example:

a = x + 1
b = sin(x)
c = a * b

produces a graph:

x
├── add ── a
├── sin ── b
└── mul(a,b) ── c

Each node stores:

Field | Purpose
Operation type | Determines derivative rule
Inputs | Parent references
Outputs | Result tensor
Saved tensors | Needed for backward pass
Backward function | Propagates adjoints
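
As a rough illustration, such a node can be modeled as a small record (the names here are invented for this sketch and do not belong to any particular framework):

from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Tuple

@dataclass
class Node:
    op: str                                  # operation type, selects the derivative rule
    inputs: List["Node"]                     # parent references
    output: Any = None                       # result tensor produced by the forward pass
    saved: Tuple[Any, ...] = ()              # tensors saved for the backward pass
    backward_fn: Optional[Callable] = None   # maps the output adjoint to input adjoints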

Dynamic graphs are flexible because they naturally support:

  • Loops
  • Branches
  • Recursion
  • Variable shapes
  • Interactive execution

This matched the needs of machine learning research, where models evolve rapidly.

Reverse Mode in Python Systems

Most Python AD frameworks optimize for scalar-loss reverse mode because neural network training requires gradients with respect to many parameters.

Suppose:

y = f(x)

where:

f : \mathbb{R}^n \rightarrow \mathbb{R}

Reverse mode computes:

\nabla f(x)

with cost that is a small constant multiple of the cost of the forward evaluation.

Internally, the reverse pass traverses the graph backward:

  1. Initialize output adjoint to 1.
  2. Visit graph nodes in reverse topological order.
  3. Apply local derivative rules.
  4. Accumulate gradients into parent tensors.

For multiplication:

z = xy

the reverse rule is:

\bar{x} \mathrel{+}= \bar{z} \, y
\bar{y} \mathrel{+}= \bar{z} \, x

The Python frontend hides this machinery, but the runtime manages graph traversal, tensor lifetimes, and adjoint accumulation.
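
A self-contained scalar sketch of that machinery (purely illustrative; real frameworks do this over tensors with native kernels):

class Value:
    """Minimal scalar reverse-mode AD node."""
    def __init__(self, data, parents=(), backward_fn=None):
        self.data = data
        self.grad = 0.0
        self.parents = parents
        self.backward_fn = backward_fn   # propagates this node's adjoint to its parents

    def __mul__(self, other):
        out = Value(self.data * other.data, parents=(self, other))
        def backward_fn():
            # reverse rule for z = x * y
            self.grad += out.grad * other.data
            other.grad += out.grad * self.data
        out.backward_fn = backward_fn
        return out

    def backward(self):
        # 1. build a reverse topological order over the recorded graph
        order, seen = [], set()
        def visit(v):
            if v not in seen:
                seen.add(v)
                for p in v.parents:
                    visit(p)
                order.append(v)
        visit(self)
        # 2. seed the output adjoint, then apply local derivative rules
        self.grad = 1.0
        for v in reversed(order):
            if v.backward_fn is not None:
                v.backward_fn()

x = Value(3.0)
y = Value(4.0)
z = x * y
z.backward()
print(x.grad, y.grad)   # 4.0 3.0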

Eager Execution

Many Python systems use eager execution.

Operations execute immediately:

x = tensor([1, 2, 3])
y = x + 1

y is computed immediately rather than deferred into a static graph.

Advantages include:

Advantage | Explanation
Debuggability | Intermediate values are visible
Natural control flow | Python semantics preserved
Interactive workflows | REPL and notebooks work naturally
Simpler mental model | Execution follows source order

The downside is optimization difficulty. Since the runtime sees operations incrementally, it has limited global visibility.

Tracing and Graph Capture

To recover optimization opportunities, many systems trace Python functions into graph representations.

Example:

def f(x):
    return sin(x) + x * x

Tracing executes the function with special tensor objects that record operations instead of performing ordinary computation.

The result is an intermediate graph:

x
├── sin
├── mul(x,x)
└── add

The graph can then be optimized, fused, compiled, or lowered to accelerators.
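
In JAX, for example, this captured representation can be inspected directly (assuming jax is installed):

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) + x * x

# tracing runs f with abstract values and records the operations it performs
print(jax.make_jaxpr(f)(1.0))   # prints a jaxpr containing sin, mul, and add primitives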

Tracing enables:

Optimization | Purpose
Kernel fusion | Reduce launch overhead
Constant folding | Eliminate redundant computation
Memory planning | Reuse buffers
Vectorization | Improve throughput
Device lowering | Generate accelerator code

This produces a hybrid execution model:

Mode | Behavior
Eager mode | Flexible interactive execution
Traced mode | Optimized graph execution

Static vs Dynamic Graph Systems

Early Python AD systems often used static graphs.

The user first defined a graph:

x = placeholder()
y = x * x

and later executed it:

session.run(y, feed_dict={x: ...})

Static graphs enabled aggressive optimization but created awkward programming models.

Dynamic systems later became dominant because they matched ordinary Python execution.

The distinction today is less strict. Modern systems often combine both:

System | Main style
PyTorch | Dynamic eager execution
TensorFlow 1.x | Static graph
TensorFlow 2.x | Eager + tracing
JAX | Functional tracing
Tinygrad | Minimal dynamic graph
MindSpore | Graph-oriented hybrid execution
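
TensorFlow 2.x illustrates the combination: code runs eagerly by default, and tf.function traces it into an optimized graph (a sketch, assuming tensorflow is installed):

import tensorflow as tf

@tf.function          # traces the Python function into a graph on first call
def f(x):
    return tf.sin(x) + x * x

x = tf.constant(2.0)
print(f(x))           # subsequent calls reuse the traced, optimized graph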

Mutation and In-Place Operations

Python tensor systems frequently support mutation:

x += y

Mutation complicates reverse mode because the old value of x may be needed during the backward pass.

Systems handle this differently.

Strategy | Explanation
Disallow unsafe mutation | Simplifies correctness
Version counters | Detect illegal overwrites
Functionalization | Rewrite mutation into pure operations
Copy-on-write | Preserve old values automatically
Tape snapshots | Save overwritten tensors

PyTorch, for example, tracks tensor versions to detect modifications that invalidate gradient computation.
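
A small sketch of how this surfaces in practice (assuming torch is installed):

import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x * 2
z = y * y            # the backward rule for this multiplication saves y
y.add_(1.0)          # in-place update bumps y's version counter
z.sum().backward()   # RuntimeError: a variable needed for gradient computation
                     # has been modified by an in-place operation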

Custom Gradient Functions

Many operations need manually defined derivatives.

A Python framework usually exposes an API:

class MyOp(Function):

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return ...

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return ...

This separates:

Phase | Role
Forward | Compute primal output
Context storage | Save required intermediates
Backward | Compute adjoint propagation
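
A concrete version of the template above, in PyTorch's autograd.Function style (a sketch, assuming torch; the operation is a simple square):

import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)     # context storage: keep x for the backward pass
        return x * x                 # forward: compute the primal output

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return 2 * x * grad_output   # backward: propagate the adjoint

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
Square.apply(x).sum().backward()
print(x.grad)                        # tensor([2., 4., 6.])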

Custom gradients are critical for:

  • Numerical stability
  • Efficient kernels
  • External libraries
  • Implicit differentiation
  • Physics simulators
  • Specialized GPU operations

Higher-Order Differentiation

Python systems increasingly support higher-order derivatives.

Example:

g = grad(f)
h = grad(g)

This requires differentiating the backward pass itself.
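
In JAX, for example, the nesting works directly (assuming jax is installed):

import jax

def f(x):
    return x ** 3

g = jax.grad(f)         # first derivative: 3x^2
h = jax.grad(g)         # second derivative: 6x, obtained by differentiating the backward pass

print(g(2.0), h(2.0))   # 12.0 12.0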

The system must ensure:

  • Reverse-mode operations are differentiable
  • Graphs remain valid through nesting
  • Saved tensors survive nested passes
  • Perturbation confusion is avoided

Higher-order AD is important for:

Application | Need
Meta-learning | Differentiate optimization
Physics | Curvature information
Scientific computing | Hessian-vector products
Implicit methods | Jacobian structure
Probabilistic inference | Laplace approximations

Python and Functional Transformations

Some Python AD systems adopt a more functional style.

Instead of mutating tensor objects:

y.backward()

they expose explicit transformations:

grad(f)
vmap(f)
jit(f)
pmap(f)

These transformations compose.

Example:

jit(grad(f))

means:

  1. Differentiate f
  2. Compile the resulting derivative program

This model treats differentiation as a pure program transformation rather than as a side effect attached to tensors.
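
A minimal composition in this style (assuming jax is installed):

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.sin(x) * x)

df = jax.jit(jax.grad(f))    # 1. differentiate f   2. compile the derivative program
print(df(jnp.arange(3.0)))   # elementwise cos(x) * x + sin(x)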

Compilation Pipelines

Modern Python AD frameworks often lower programs into compiler IRs.

The pipeline may look like:

Python
→ traced graph
→ normalized IR
→ optimized IR
→ backend lowering
→ machine code or accelerator kernels

Common IR forms include:

IR | Purpose
FX graphs | Python graph capture
HLO | Tensor compiler IR
MLIR | Multi-level compiler infrastructure
XLA graphs | Accelerator optimization
TorchScript IR | PyTorch compilation
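
For instance, PyTorch's FX can capture such a graph from ordinary Python code (a sketch, assuming a recent torch version):

import torch
import torch.fx

def f(x):
    return torch.sin(x) + x * x

gm = torch.fx.symbolic_trace(f)   # capture an FX graph by tracing f symbolically
print(gm.graph)                   # printable IR: placeholder, sin, mul, add, output nodes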

AD may operate at multiple levels:

Level | Differentiation target
Python AST | Source transformation
Runtime graph | Dynamic tracing
Tensor IR | Graph-level AD
LLVM IR | Low-level compiler differentiation

Interaction with NumPy

NumPy heavily influenced Python AD systems.

Many frameworks mimic NumPy APIs:

sin
exp
matmul
reshape
broadcast_to
sum
transpose

This allows numerical code to become differentiable with minimal changes.

However, ordinary NumPy arrays do not carry gradient metadata. Frameworks therefore provide tensor types that emulate NumPy behavior while tracking derivatives.
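
For example, jax.numpy mirrors much of the NumPy API, so NumPy-style code becomes differentiable almost unchanged (assuming jax is installed):

import jax
import jax.numpy as jnp   # near drop-in replacement for numpy in this snippet

def loss(w):
    return jnp.sum(jnp.exp(-w) + w ** 2)

w = jnp.array([0.5, 1.0])
print(jax.grad(loss)(w))  # gradient: -exp(-w) + 2w, computed elementwise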

Compatibility layers are essential for ecosystem adoption.

GPU and Accelerator Execution

Python frameworks usually execute tensor kernels outside Python.

The Python interpreter orchestrates computation, but dense operations are dispatched to:

Backend | Typical implementation
CPU | BLAS, vectorized kernels
GPU | CUDA or ROCm kernels
TPU | XLA-compiled programs
Specialized accelerators | Vendor-specific runtimes

AD systems must therefore manage:

  • Device placement
  • Gradient synchronization
  • Memory transfers
  • Kernel scheduling
  • Mixed precision

The derivative computation becomes part of a distributed runtime system.
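
A small PyTorch sketch of explicit device placement (assuming torch; falls back to CPU when no GPU is available):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

x = torch.randn(1024, 1024, device=device, requires_grad=True)
y = (x @ x).sum()      # matmul dispatched to a native kernel on the chosen device
y.backward()           # backward kernels run on the same device as the forward pass
print(x.grad.device)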

Memory Management

Reverse mode requires storing intermediates from the forward pass.

Memory costs can dominate execution.

Strategies include:

Technique | Purpose
Gradient checkpointing | Recompute instead of storing
Activation rematerialization | Trade compute for memory
Buffer reuse | Reduce allocations
Lazy gradient allocation | Allocate only when needed
Static memory planning | Optimize graph execution

Large neural networks are often constrained more by activation memory than by arithmetic throughput.
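
As one example, gradient checkpointing in PyTorch (a sketch, assuming a recent torch version):

import torch
from torch.utils.checkpoint import checkpoint

def block(x):
    # stands in for an expensive sub-network whose activations we would rather recompute
    return torch.tanh(x @ x)

x = torch.randn(512, 512, requires_grad=True)
y = checkpoint(block, x, use_reentrant=False)  # forward runs without storing intermediates
y.sum().backward()                             # block is re-executed during the backward pass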

Numerical Stability

Naive derivatives can be numerically unstable.

Examples include:

Operation | Problem
Softmax | Overflow
Logarithm | Singularities near zero
Division | Unstable denominators
Exponentials | Exploding gradients
Normalization | Small variance instability

Python AD systems therefore rely heavily on custom stable primitives.

For example:

logsumexp
softplus
cross_entropy
layer_norm

often have carefully engineered backward implementations.
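
The forward trick used in logsumexp illustrates the kind of care involved; the backward implementations receive similar treatment (a plain NumPy sketch, assuming numpy):

import numpy as np

def logsumexp(x):
    m = np.max(x)                             # subtract the max so exp() stays in range
    return m + np.log(np.sum(np.exp(x - m)))

x = np.array([1000.0, 1000.5])
print(np.log(np.sum(np.exp(x))))   # inf: exp(1000) overflows
print(logsumexp(x))                # ~1000.97: stable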

Major Python AD Systems

System | Main characteristics
PyTorch | Dynamic eager reverse mode
TensorFlow | Hybrid graph/eager system
JAX | Functional transformations and tracing
Autograd | Pure NumPy-based tracing
Tinygrad | Minimal educational framework
MindSpore | Graph-oriented execution
Chainer | Early dynamic graph system

These systems differ mainly in:

  • Graph construction strategy
  • Compilation model
  • Mutation semantics
  • Transformation interface
  • Hardware integration

Python as an AD Host Language

Python succeeded because it provided:

Feature | Importance
Simple syntax | Rapid experimentation
Scientific ecosystem | NumPy, SciPy, plotting
Dynamic execution | Flexible model definition
Native extension support | Access to optimized kernels
Interactive workflow | Notebook-based research

The AD engine itself is usually not implemented in Python: Python acts as the orchestration layer above highly optimized native runtimes and compiler systems.

Modern Python AD frameworks therefore resemble compiler toolchains hidden behind an imperative scripting interface.