# AD in Python

## AD in Python

Python became the dominant language for modern machine learning and differentiable computing because it combines a simple programming model with access to high-performance native libraries. Most Python automatic differentiation systems therefore follow a hybrid architecture:

| Layer | Role |
|---|---|
| Python frontend | User-facing model and control logic |
| Tensor runtime | Dense array execution |
| AD engine | Gradient propagation |
| Native backend | CPU/GPU kernels in C/C++/CUDA |
| Compiler subsystem | Graph optimization and lowering |

Interpreted Python is too slow for numerical inner loops, so tensor operations execute outside the interpreter, in native kernels. An AD system for Python therefore differentiates tensor programs whose control flow is driven by Python code.

### Tensor-Centric Computation

Modern Python AD systems are built around tensors.

A tensor object typically contains:

| Component | Meaning |
|---|---|
| Shape | Tensor dimensions |
| Dtype | Numeric type |
| Storage | Underlying memory |
| Device | CPU, GPU, TPU |
| Gradient metadata | Information for reverse mode |
| Graph reference | Dependency structure |

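A simple example, written against PyTorch's API:

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x * x      # elementwise square, recorded for the backward pass
z = y.sum()    # scalar output, so backward() needs no explicit seed
z.backward()   # fills in x.grad
```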

After execution:

```python
x.grad
```

contains the gradient $\partial z / \partial x_i = 2x_i$, that is:

$$
[2, 4, 6]
$$

The user writes imperative Python code, but the AD system records tensor operations and constructs derivative propagation rules internally.
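
To make this bookkeeping concrete, here is a minimal pure-Python sketch of such a tensor. The `Tensor` class, its fields, and the list-based storage are hypothetical simplifications for illustration; real frameworks keep storage in native buffers and support far more operations.

```python
class Tensor:
    """Toy 1-D tensor that records how it was produced (hypothetical)."""

    def __init__(self, data, requires_grad=False, parents=(), backward_fn=None):
        self.data = list(data)            # storage: a plain Python list here
        self.shape = (len(self.data),)    # shape metadata
        self.requires_grad = requires_grad
        self.grad = None                  # filled in by backward()
        self._parents = parents           # graph reference: inputs of this op
        self._backward_fn = backward_fn   # output gradient -> parent gradients

    def __mul__(self, other):
        out = Tensor([a * b for a, b in zip(self.data, other.data)],
                     requires_grad=self.requires_grad or other.requires_grad,
                     parents=(self, other))
        # d(a*b)/da = b and d(a*b)/db = a, applied elementwise
        out._backward_fn = lambda g: (
            [gi * b for gi, b in zip(g, other.data)],
            [gi * a for gi, a in zip(g, self.data)],
        )
        return out

    def sum(self):
        out = Tensor([sum(self.data)], requires_grad=self.requires_grad,
                     parents=(self,))
        # every input element contributes with derivative 1
        out._backward_fn = lambda g: ([g[0]] * len(self.data),)
        return out

    def backward(self):
        # Seed the scalar output with adjoint 1 and push gradients to parents.
        # Recursing over every path is correct but can revisit shared nodes;
        # real engines traverse in reverse topological order instead.
        self._propagate([1.0])

    def _propagate(self, g):
        if self._backward_fn is None:     # leaf: accumulate into .grad
            if self.requires_grad:
                self.grad = g if self.grad is None else [
                    a + b for a, b in zip(self.grad, g)]
            return
        for parent, pg in zip(self._parents, self._backward_fn(g)):
            parent._propagate(pg)


x = Tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x * x
z = y.sum()
z.backward()
print(x.grad)  # [2.0, 4.0, 6.0]
```

Because `x` appears twice in `x * x`, its gradient arrives along two paths and is accumulated, which is why the leaf adds into `.grad` rather than overwriting it.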

### Dynamic Computational Graphs

Many Python systems use dynamic graphs.

A graph is built during execution. Each tensor operation creates graph nodes connecting inputs to outputs.

For example:

```python
import torch

x = torch.tensor(0.5, requires_grad=True)
a = x + 1
b = torch.sin(x)
c = a * b
```

produces a graph:

```text
x ─┬─ add(x, 1) ── a ─┐
   │                  ├─ mul(a, b) ── c
   └─ sin(x) ───── b ─┘
```

Each node stores:

| Field | Purpose |
|---|---|
| Operation type | Determines derivative rule |
| Inputs | Parent references |
| Outputs | Result tensor |
| Saved tensors | Needed for backward pass |
| Backward function | Propagates adjoints |
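
The fields in the table map directly onto a small record type. The sketch below is a hypothetical scalar version (the `Node` class, `run_example`, and the string variable names are illustrative assumptions): it appends one node per operation while evaluating the example above, then replays the tape backward. Because nodes are appended in execution order, the tape is already topologically sorted, so reversing it gives a valid backward order.

```python
import math
from dataclasses import dataclass
from typing import Callable


@dataclass
class Node:
    op: str              # operation type: selects the derivative rule
    inputs: list         # parent variable names
    output: str          # result variable name
    saved: dict          # values saved for the backward pass
    backward: Callable   # (output adjoint, saved) -> list of input adjoints


def run_example(x):
    tape, values = [], {"x": x}
    # a = x + 1: nothing to save, the local derivative is 1
    values["a"] = x + 1
    tape.append(Node("add", ["x"], "a", {}, lambda g, s: [g]))
    # b = sin(x): save x, since d sin(x)/dx = cos(x)
    values["b"] = math.sin(x)
    tape.append(Node("sin", ["x"], "b", {"x": x},
                     lambda g, s: [g * math.cos(s["x"])]))
    # c = a * b: save both factors for the product rule
    values["c"] = values["a"] * values["b"]
    tape.append(Node("mul", ["a", "b"], "c",
                     {"a": values["a"], "b": values["b"]},
                     lambda g, s: [g * s["b"], g * s["a"]]))
    return values, tape


def backward(tape, output):
    adjoints = {output: 1.0}
    for node in reversed(tape):
        g = adjoints.get(node.output, 0.0)
        for name, contrib in zip(node.inputs, node.backward(g, node.saved)):
            adjoints[name] = adjoints.get(name, 0.0) + contrib  # accumulate
    return adjoints


values, tape = run_example(0.5)
adjoints = backward(tape, "c")
print(adjoints["x"])  # equals sin(0.5) + 1.5 * cos(0.5)
```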

Dynamic graphs are flexible because they naturally support:

- Loops
- Branches
- Recursion
- Variable shapes
- Interactive execution

This matched the needs of machine learning research, where models evolve rapidly.

### Reverse Mode in Python Systems

Most Python AD frameworks optimize for scalar-loss reverse mode because neural network training requires gradients with respect to many parameters.

Suppose:

```python
y = f(x)
```

where:

$$
f : \mathbb{R}^n \rightarrow \mathbb{R}
$$

Reverse mode computes:

$$
\nabla f(x)
$$

with cost bounded by a small constant multiple of the cost of evaluating $f$ itself, independent of $n$.

Internally, the reverse pass traverses the graph backward:

1. Initialize output adjoint to `1`.
2. Visit graph nodes in reverse topological order.
3. Apply local derivative rules.
4. Accumulate gradients into parent tensors.
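
The four steps can be sketched as a small scalar engine. This is a hypothetical illustration in the style of minimal educational engines, not any framework's code: each `Var` stores its parents with the local derivative, a depth-first search produces a topological order, and the backward pass walks that order in reverse. Building the graph inside a Python loop also shows the dynamic-graph property: the graph's shape depends on how many iterations actually run.

```python
class Var:
    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents  # tuples (parent Var, d self / d parent)
        self.grad = 0.0

    def __add__(self, other):
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        return Var(self.value * other.value,
                   ((self, other.value), (other, self.value)))


def backward(output):
    # Depth-first search appends a node only after all of its parents,
    # so `order` is a topological order of the graph.
    order, seen = [], set()

    def visit(v):
        if id(v) not in seen:
            seen.add(id(v))
            for parent, _ in v.parents:
                visit(parent)
            order.append(v)

    visit(output)
    output.grad = 1.0                      # step 1: seed the output adjoint
    for v in reversed(order):              # step 2: reverse topological order
        for parent, local in v.parents:
            parent.grad += v.grad * local  # steps 3-4: local rule, accumulate


# y = x^5 built by a Python loop: the graph has four multiplication nodes
x = Var(2.0)
y = x
for _ in range(4):
    y = y * x
backward(y)
print(y.value, x.grad)  # 32.0 80.0
```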

For multiplication:

$$
z = xy
$$

the reverse rules are:

$$
\bar{x} \mathrel{+}= \bar{z}\, y
$$

$$
\bar{y} \mathrel{+}= \bar{z}\, x
$$

$$
z = xy
$$

The Python frontend hides this machinery, but the runtime manages graph traversal, tensor lifetimes, and adjoint accumulation.
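
The multiplication rule is easy to check numerically. The sketch below (all function names are hypothetical) takes $z = xy$ followed by an assumed downstream loss $\sin(z)$, so that $\bar{z} = \cos(z)$, and compares the reverse-mode adjoints with central finite differences.

```python
import math


def loss(x, y):
    z = x * y
    return math.sin(z)


def reverse_adjoints(x, y):
    z = x * y              # forward pass
    z_bar = math.cos(z)    # adjoint of z for loss = sin(z)
    x_bar = z_bar * y      # x_bar += z_bar * y (starting from 0)
    y_bar = z_bar * x      # y_bar += z_bar * x (starting from 0)
    return x_bar, y_bar


def finite_difference(x, y, h=1e-6):
    dx = (loss(x + h, y) - loss(x - h, y)) / (2 * h)
    dy = (loss(x, y + h) - loss(x, y - h)) / (2 * h)
    return dx, dy


x_bar, y_bar = reverse_adjoints(0.3, 1.7)
dx, dy = finite_difference(0.3, 1.7)
print(abs(x_bar - dx) < 1e-6, abs(y_bar - dy) < 1e-6)  # True True
```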

### Eager Execution

Many Python systems use eager execution.

Operations execute immediately:

```python
import torch

x = torch.tensor([1, 2, 3])
y = x + 1   # runs now: y already holds tensor([2, 3, 4])
```

`y` is computed immediately rather than deferred into a static graph.

Advantages include:

| Advantage | Explanation |
|---|---|
| Debuggability | Intermediate values are visible |
| Natural control flow | Python semantics preserved |
| Interactive workflows | REPL and notebooks work naturally |
| Simpler mental model | Execution follows source order |

The cost is lost optimization opportunity. Because the runtime sees one operation at a time, it cannot fuse kernels, plan memory, or reorder work across the whole program.

### Tracing and Graph Capture

To recover optimization opportunities, many systems trace Python functions into graph representations.

Example:

```python
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) + x * x
```

Tracing executes the function with special tensor objects that record operations instead of performing ordinary computation.

The result is an intermediate graph:

```text
x ─┬─ sin(x) ──── t0 ─┐
   │                  ├─ add(t0, t1) ── y
   └─ mul(x, x) ─ t1 ─┘
```

The graph can then be optimized, fused, compiled, or lowered to accelerators.
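
A tracer can be sketched in a few lines of Python. In this hypothetical illustration (the `Tracer` class and the tuple-based `(output, op, inputs)` graph format are assumptions, not any framework's API), calling `f` with a `Tracer` records each operation into a list instead of computing numbers; the recorded list is the intermediate graph. Because `f` calls a module-level `sin` that dispatches on its argument type, the same source still runs eagerly on plain floats.

```python
import math


class Tracer:
    def __init__(self, graph, name):
        self.graph, self.name = graph, name

    def _emit(self, op, *args):
        # Record the operation and return a new symbolic value for its result
        out = Tracer(self.graph, f"t{len(self.graph)}")
        self.graph.append((out.name, op, tuple(a.name for a in args)))
        return out

    def __add__(self, other):
        return self._emit("add", self, other)

    def __mul__(self, other):
        return self._emit("mul", self, other)


def sin(v):
    # Record an op for tracers; compute eagerly for ordinary numbers
    return v._emit("sin", v) if isinstance(v, Tracer) else math.sin(v)


def f(x):
    return sin(x) + x * x


def trace(fn):
    graph = []
    out = fn(Tracer(graph, "x"))
    return graph, out.name


graph, result = trace(f)
for node in graph:
    print(node)
# ('t0', 'sin', ('x',))
# ('t1', 'mul', ('x', 'x'))
# ('t2', 'add', ('t0', 't1'))
print(f(0.5))  # eager evaluation still works on floats
```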

Tracing enables:

| Optimization | Purpose |
|---|---|
| Kernel fusion | Reduce kernel launches and memory traffic |
| Constant folding | Precompute subexpressions whose inputs are known before run time |
| Memory planning | Reuse buffers |
| Vectorization | Improve throughput |
| Device lowering | Generate accelerator code |
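
As one example of a graph-level pass, constant folding evaluates nodes whose inputs are all known at trace time. The sketch below works on a hypothetical tuple-based IR `(output, op, inputs)`, where each input is either a variable name or a numeric literal; it is illustrative only and ignores the dtypes, shapes, and side effects a real compiler must respect.

```python
import math

OPS = {"add": lambda a, b: a + b, "mul": lambda a, b: a * b, "sin": math.sin}


def constant_fold(graph):
    known = {}    # output name -> constant value
    folded = []   # nodes that still depend on run-time inputs
    for out, op, inputs in graph:
        # Substitute inputs that are already known constants
        args = tuple(known.get(a, a) if isinstance(a, str) else a
                     for a in inputs)
        if all(not isinstance(a, str) for a in args):
            known[out] = OPS[op](*args)       # all inputs constant: evaluate now
        else:
            folded.append((out, op, args))    # keep node, constants inlined
    return folded, known


graph = [
    ("c0", "mul", (2.0, 3.0)),    # 2 * 3 is known before run time
    ("c1", "add", ("c0", 1.0)),   # depends only on c0, so it folds too
    ("y", "mul", ("x", "c1")),    # depends on the run-time input x
]
folded, known = constant_fold(graph)
print(folded)  # [('y', 'mul', ('x', 7.0))]
```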

This produces a hybrid execution model:

| Mode | Behavior |
|---|---|
| Eager mode | Flexible interactive execution |
| Traced mode | Optimized graph execution |

### Static vs Dynamic Graph Systems

Early Python AD systems often used static graphs.

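The user first defined a graph (shown here with TensorFlow's 1.x-compatible API):

```python
import tensorflow as tf
tf.compat.v1.disable_eager_execution()  # TF 2.x: restore graph mode
x = tf.compat.v1.placeholder(tf.float32)
y = x * x  # adds a graph node; computes nothing yet
```

and later executed it:

```python
with tf.compat.v1.Session() as session:
    session.run(y, feed_dict={x: [1.0, 2.0, 3.0]})
```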

Static graphs enabled aggressive optimization but created awkward programming models.
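
The define-then-run style can be mimicked in plain Python. The sketch below is a hypothetical miniature (the `Placeholder`, `Op`, and `Session` names echo TensorFlow 1.x terminology but are not its API): building expressions only creates graph nodes, and no arithmetic happens until `Session.run` receives concrete values through `feed_dict`. Because the whole graph exists before execution, a real system could optimize it at that point.

```python
class Node:
    def __mul__(self, other):
        return Op("mul", self, other)

    def __add__(self, other):
        return Op("add", self, other)


class Placeholder(Node):
    pass


class Op(Node):
    def __init__(self, kind, a, b):
        self.kind, self.a, self.b = kind, a, b


class Session:
    def run(self, fetch, feed_dict):
        cache = {}

        def evaluate(node):
            if node in cache:             # evaluate each node at most once
                return cache[node]
            if isinstance(node, Placeholder):
                value = feed_dict[node]
            else:
                a, b = evaluate(node.a), evaluate(node.b)
                value = a * b if node.kind == "mul" else a + b
            cache[node] = value
            return value

        return evaluate(fetch)


x = Placeholder()
y = x * x        # builds a graph node; nothing is computed yet
session = Session()
print(session.run(y, feed_dict={x: 3.0}))  # 9.0
```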

Dynamic systems later became dominant because they matched ordinary Python execution.

The distinction today is less strict. Modern systems often combine both:

| System | Main style |
|---|---|
| PyTorch | Dynamic eager execution |
| TensorFlow 1.x | Static graph |
| TensorFlow 2.x | Eager + tracing |
| JAX | Functional tracing |
| Tinygrad | Minimal dynamic graph |
| MindSpore | Graph-oriented hybrid execution |

### Mutation and In-Place Operations

Python tensor systems frequently support mutation:

```python
x += y
```

Mutation complicates reverse mode because the old value of `x` may be needed during the backward pass.

Systems handle this differently.

| Strategy | Explanation |
|---|---|
| Disallow unsafe mutation | Simplifies correctness |
| Version counters | Detect illegal overwrites |
| Functionalization | Rewrite mutation into pure operations |
| Copy-on-write | Preserve old values automatically |
| Tape snapshots | Save overwritten tensors |

PyTorch, for example, tracks tensor versions to detect modifications that invalidate gradient computation.
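
The version-counter idea can be sketched without any framework. In this hypothetical illustration (class and function names are assumptions, not PyTorch internals), every in-place update bumps a counter, each saved value remembers the version it saw, and the backward rule refuses to run if the versions no longer match. PyTorch reports the analogous situation as a `RuntimeError` at backward time.

```python
import math


class Buffer:
    def __init__(self, value):
        self.value = value
        self.version = 0

    def add_(self, amount):
        # In-place update: mutate the value and bump the version counter
        self.value += amount
        self.version += 1
        return self


class SavedValue:
    def __init__(self, buf):
        self.buf = buf
        self.saved_version = buf.version  # version seen at save time

    def unpack(self):
        if self.buf.version != self.saved_version:
            raise RuntimeError(
                "a value needed for the backward pass was modified in place")
        return self.buf.value


def exp_forward(x):
    y = Buffer(math.exp(x.value))
    # exp's derivative is exp(x) itself, so the *output* is saved
    return y, SavedValue(y)


def exp_backward(saved, grad_out):
    return grad_out * saved.unpack()


x = Buffer(1.0)
y, saved = exp_forward(x)
print(exp_backward(saved, 1.0))  # 2.718281828459045

y.add_(1.0)                      # overwrite the saved output in place
try:
    exp_backward(saved, 1.0)
except RuntimeError as err:
    print("error:", err)
```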

### Custom Gradient Functions

Many operations need manually defined derivatives.

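A Python framework usually exposes an API such as PyTorch's `torch.autograd.Function`:

```python
from torch.autograd import Function

class MyOp(Function):

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)   # stash inputs the backward rule needs
        return ...                 # primal output

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors     # retrieve what forward saved
        return ...                 # one gradient per forward input
```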

This separates:

| Phase | Role |
|---|---|
| Forward | Compute primal output |
| Context storage | Save required intermediates |
| Backward | Compute adjoint propagation |

Custom gradients are critical for:

- Numerical stability
- Efficient kernels
- External libraries
- Implicit differentiation
- Physics simulators
- Specialized GPU operations
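
The forward/context/backward split can be illustrated without a framework. The sketch below is hypothetical: `Context` and `Sigmoid` imitate the shape of the API above, but the code is plain Python on floats. The backward rule reuses the saved output through $\sigma'(x) = \sigma(x)(1 - \sigma(x))$, and the forward pass branches so that `math.exp` is never called on a large positive argument, which is one reason hand-written rules are preferred.

```python
import math


class Context:
    """Stores whatever the forward pass wants the backward pass to see."""

    def save_for_backward(self, *values):
        self.saved = values


class Sigmoid:
    @staticmethod
    def forward(ctx, x):
        # Stable logistic function: exp is only called on non-positive inputs
        if x >= 0:
            y = 1.0 / (1.0 + math.exp(-x))
        else:
            e = math.exp(x)
            y = e / (1.0 + e)
        ctx.save_for_backward(y)   # save the output, not the input
        return y

    @staticmethod
    def backward(ctx, grad_output):
        (y,) = ctx.saved
        return grad_output * y * (1.0 - y)


ctx = Context()
y = Sigmoid.forward(ctx, -1000.0)      # naive 1 / (1 + exp(1000)) overflows
print(y, Sigmoid.backward(ctx, 1.0))   # 0.0 0.0

# Compare the hand-written rule with a central finite difference
ctx = Context()
x, h = 0.4, 1e-6
Sigmoid.forward(ctx, x)
numeric = (Sigmoid.forward(Context(), x + h)
           - Sigmoid.forward(Context(), x - h)) / (2 * h)
print(abs(Sigmoid.backward(ctx, 1.0) - numeric) < 1e-8)  # True
```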

### Higher-Order Differentiation

Python systems increasingly support higher-order derivatives.

Example:

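```python
from jax import grad

g = grad(f)   # first derivative of a scalar function f
h = grad(g)   # second derivative: differentiates the derivative program
```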

This requires differentiating the backward pass itself.

The system must ensure:

- Reverse-mode operations are differentiable
- Graphs remain valid through nesting
- Saved tensors survive nested passes
- Perturbation confusion is avoided
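
One way to make reverse mode differentiable is to have the backward pass build new graph nodes instead of producing plain numbers. In the hypothetical sketch below, local derivatives are stored as `Var` objects and all adjoint arithmetic uses `Var` operations, so the gradient returned by `grad` is itself a graph that can be differentiated again. This mirrors, in miniature, what options such as PyTorch's `create_graph=True` enable.

```python
class Var:
    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents  # tuples (parent Var, local derivative as a Var)

    def __add__(self, other):
        one = Var(1.0)
        return Var(self.value + other.value, ((self, one), (other, one)))

    def __mul__(self, other):
        # The local derivatives are the factors themselves, kept in the graph
        return Var(self.value * other.value, ((self, other), (other, self)))


def grad(y, x):
    """Return dy/dx as a Var whose own graph can be differentiated again."""
    order, seen = [], set()

    def visit(v):
        if id(v) not in seen:
            seen.add(id(v))
            for parent, _ in v.parents:
                visit(parent)
            order.append(v)

    visit(y)
    adjoint = {id(y): Var(1.0)}
    for v in reversed(order):
        g = adjoint.get(id(v))
        if g is None:
            continue
        for parent, local in v.parents:
            contrib = g * local            # a new, differentiable node
            key = id(parent)
            adjoint[key] = adjoint[key] + contrib if key in adjoint else contrib
    return adjoint.get(id(x), Var(0.0))


x = Var(2.0)
y = x * x * x        # x^3
d1 = grad(y, x)      # 3x^2 -> 12
d2 = grad(d1, x)     # 6x   -> 12
d3 = grad(d2, x)     # 6
print(y.value, d1.value, d2.value, d3.value)  # 8.0 12.0 12.0 6.0
```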

Higher-order AD is important for:

| Application | Need |
|---|---|
| Meta-learning | Differentiate optimization |
| Physics | Curvature information |
| Scientific computing | Hessian-vector products |
| Implicit methods | Jacobian structure |
| Probabilistic inference | Laplace approximations |

### Python and Functional Transformations

Some Python AD systems adopt a more functional style.

Instead of calling a method that writes gradients into tensor objects:

```python
y.backward()
```

they expose explicit transformations:

```python
grad(f)
vmap(f)
jit(f)
pmap(f)
```

These transformations compose.

Example:

```python
jit(grad(f))
```

means:

1. Differentiate `f`
2. Compile the resulting derivative program

This model treats differentiation as a pure program transformation rather than as a side effect attached to tensors.
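
The transformation style can be imitated in pure Python with higher-order functions. In the hypothetical sketch below, `grad` is implemented with forward-mode dual numbers (a stand-in chosen for brevity; JAX's `grad` uses reverse mode) and `vmap` is a plain loop (JAX's `vmap` instead rewrites the program to operate on a batch axis). The point is only that each transformation takes a function and returns a function, so they compose. This sketch does not support nesting `grad` inside `grad`; doing that safely requires tagging perturbations to avoid perturbation confusion.

```python
import math


class Dual:
    """Number a + b*eps with eps^2 = 0; `tangent` carries the derivative."""

    def __init__(self, primal, tangent=0.0):
        self.primal, self.tangent = primal, tangent

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.primal + other.primal, self.tangent + other.tangent)

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.primal * other.primal,
                    self.tangent * other.primal + self.primal * other.tangent)

    __radd__ = __add__
    __rmul__ = __mul__


def sin(v):
    if isinstance(v, Dual):
        return Dual(math.sin(v.primal), math.cos(v.primal) * v.tangent)
    return math.sin(v)


def grad(f):
    # Returns a new function computing df/dx for a scalar function f
    return lambda x: f(Dual(x, 1.0)).tangent


def vmap(f):
    # Returns a new function applying f to each element of a list
    return lambda xs: [f(x) for x in xs]


def f(x):
    return sin(x) + x * x


df_batched = vmap(grad(f))      # transformations compose
print(df_batched([0.0, 1.0]))   # cos(x) + 2x at each point
```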

### Compilation Pipelines

Modern Python AD frameworks often lower programs into compiler IRs.

The pipeline may look like:

```text
Python
→ traced graph
→ normalized IR
→ optimized IR
→ backend lowering
→ machine code or accelerator kernels
```

Common IR forms include:

| IR | Purpose |
|---|---|
| FX graphs | Python graph capture |
| HLO | Tensor compiler IR |
| MLIR | Multi-level compiler infrastructure |
| XLA graphs | Accelerator optimization |
| TorchScript IR | PyTorch compilation |

AD may operate at multiple levels:

| Level | Differentiation target |
|---|---|
| Python AST | Source transformation |
| Runtime graph | Dynamic tracing |
| Tensor IR | Graph-level AD |
| LLVM IR | Low-level compiler differentiation |

### Interaction with NumPy

NumPy heavily influenced Python AD systems.

Many frameworks mimic NumPy APIs:

```python
sin
exp
matmul
reshape
broadcast_to
sum
transpose
```

This allows numerical code to become differentiable with minimal changes.

However, ordinary NumPy arrays do not carry gradient metadata. Frameworks therefore provide tensor types that emulate NumPy behavior while tracking derivatives.

Compatibility layers are essential for ecosystem adoption.

### GPU and Accelerator Execution

Python frameworks usually execute tensor kernels outside Python.

The Python interpreter orchestrates computation, but dense operations are dispatched to:

| Backend | Typical implementation |
|---|---|
| CPU | BLAS, vectorized kernels |
| GPU | CUDA or ROCm kernels |
| TPU | XLA-compiled programs |
| Specialized accelerators | Vendor-specific runtimes |

AD systems must therefore manage:

- Device placement
- Gradient synchronization
- Memory transfers
- Kernel scheduling
- Mixed precision

In multi-device training, the derivative computation also becomes part of a distributed runtime system, since gradients must be synchronized across devices.

### Memory Management

Reverse mode requires storing intermediates from the forward pass.

Memory costs can dominate execution.

Strategies include:

| Technique | Purpose |
|---|---|
| Gradient checkpointing | Recompute instead of storing |
| Activation rematerialization | Trade compute for memory |
| Buffer reuse | Reduce allocations |
| Lazy gradient allocation | Allocate only when needed |
| Static memory planning | Optimize graph execution |

Large neural networks are often constrained more by activation memory than by arithmetic throughput.
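
Gradient checkpointing can be sketched on a chain of scalar functions. In this hypothetical illustration (the function names and the `(f, df)` pair representation are assumptions), the checkpointed version stores only every `every`-th activation during the forward pass and recomputes each segment just before backpropagating through it. It returns the same gradient as storing everything while holding roughly `n / every + every` activations instead of `n + 1`, at the cost of about one extra forward pass.

```python
import math

# Each layer is a pair (f, df): the function and its derivative.
LAYERS = [(math.sin, math.cos),
          (math.tanh, lambda h: 1.0 - math.tanh(h) ** 2)] * 5


def grad_store_all(x, layers):
    acts = [x]                        # keep every activation
    for f, _ in layers:
        acts.append(f(acts[-1]))
    g = 1.0
    for i in reversed(range(len(layers))):
        g *= layers[i][1](acts[i])    # chain rule, last layer first
    return g


def grad_checkpointed(x, layers, every):
    n = len(layers)
    checkpoints = {0: x}              # keep only every `every`-th activation
    h = x
    for i, (f, _) in enumerate(layers):
        h = f(h)
        if (i + 1) % every == 0:
            checkpoints[i + 1] = h
    g = 1.0
    for start in reversed(range(0, n, every)):
        end = min(start + every, n)
        seg = [checkpoints[start]]    # recompute this segment's activations
        for i in range(start, end - 1):
            seg.append(layers[i][0](seg[-1]))
        for i in reversed(range(start, end)):
            g *= layers[i][1](seg[i - start])
    return g


print(math.isclose(grad_checkpointed(0.7, LAYERS, 3),
                   grad_store_all(0.7, LAYERS)))  # True
```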

### Numerical Stability

Naive derivatives can be numerically unstable.

Examples include:

| Operation | Problem |
|---|---|
| Softmax | Overflow |
| Logarithm | Singularities near zero |
| Division | Unstable denominators |
| Exponentials | Exploding gradients |
| Normalization | Small variance instability |

Python AD systems therefore rely heavily on custom stable primitives.

For example:

```python
logsumexp
softplus
cross_entropy
layer_norm
```

often have carefully engineered backward implementations.
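
A standard example is `logsumexp`. Evaluating $\log \sum_i e^{x_i}$ directly overflows for large inputs; subtracting the maximum first, $m + \log \sum_i e^{x_i - m}$ with $m = \max_i x_i$, gives the same value without overflow. Its gradient is the softmax, which can reuse the same shift. The pure-Python sketch below (function names are illustrative) shows both.

```python
import math


def logsumexp_naive(xs):
    # math.exp(1000.0) raises OverflowError
    return math.log(sum(math.exp(v) for v in xs))


def logsumexp(xs):
    m = max(xs)                      # shift by the maximum
    return m + math.log(sum(math.exp(v - m) for v in xs))


def logsumexp_grad(xs):
    # d/dx_i logsumexp(x) = softmax(x)_i, computed with the same shift
    m = max(xs)
    exps = [math.exp(v - m) for v in xs]
    total = sum(exps)
    return [e / total for e in exps]


xs = [1000.0, 1000.0]
print(logsumexp(xs))        # 1000 + log(2)
print(logsumexp_grad(xs))   # [0.5, 0.5]
try:
    logsumexp_naive(xs)
except OverflowError:
    print("naive version overflowed")
```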

### Major Python AD Systems

| System | Main characteristics |
|---|---|
| PyTorch | Dynamic eager reverse mode |
| TensorFlow | Hybrid graph/eager system |
| JAX | Functional transformations and tracing |
| Autograd | Pure NumPy-based tracing |
| Tinygrad | Minimal educational framework |
| MindSpore | Graph-oriented execution |
| Chainer | Early dynamic graph system |

These systems differ mainly in:

- Graph construction strategy
- Compilation model
- Mutation semantics
- Transformation interface
- Hardware integration

### Python as an AD Host Language

Python succeeded because it provided:

| Feature | Importance |
|---|---|
| Simple syntax | Rapid experimentation |
| Scientific ecosystem | NumPy, SciPy, plotting |
| Dynamic execution | Flexible model definition |
| Native extension support | Access to optimized kernels |
| Interactive workflow | Notebook-based research |

The AD engine is usually not written primarily in Python. Python acts as the orchestration layer above highly optimized native runtimes and compiler systems.

Modern Python AD frameworks therefore resemble compiler toolchains hidden behind an imperative scripting interface.

