# Operator Overloading

Operator overloading implements automatic differentiation by changing the meaning of ordinary arithmetic operations for special numeric objects.

Instead of rewriting the program, the AD system gives numbers extra derivative state. When the user writes:

```python
y = sin(x * x)
```

the operations `*` and `sin` execute derivative-aware methods. The program looks unchanged, but the values flowing through it are no longer plain scalars. They are AD values.

A forward-mode value may contain:

```text
value   primal value
dot     tangent value
```

A reverse-mode value may contain:

```text
value       primal value
parents     inputs that produced this value
backward    local adjoint rule
grad        accumulated adjoint
```

The arithmetic syntax stays familiar. The differentiation logic lives inside the numeric type.

### Forward Mode by Overloading

Forward mode is the simpler case.

Define a dual value:

```python
class Dual:
    def __init__(self, value, dot):
        self.value = value
        self.dot = dot
```

Then overload arithmetic:

```python
def add(x, y):
    return Dual(
        x.value + y.value,
        x.dot + y.dot,
    )

def mul(x, y):
    return Dual(
        x.value * y.value,
        x.dot * y.value + x.value * y.dot,
    )

# Bind the rules to the operators so that `x + y` and `x * y`
# on Dual values dispatch to them.
Dual.__add__ = add
Dual.__mul__ = mul
```

For elementary functions:

```python
import math

def sin(x):
    return Dual(
        math.sin(x.value),
        math.cos(x.value) * x.dot,
    )
```

Running the original program on `Dual` values computes both the primal result and the directional derivative.

```python
x = Dual(3.0, 1.0)
y = sin(x * x)

print(y.value)
print(y.dot)
```

The tangent seed `1.0` means “differentiate with respect to this input.” For a multivariable function, different seeds compute different Jacobian columns or Jacobian-vector products.
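As a sketch of the seeding idea, assuming the `Dual` class and `sin` wrapper above (repeated here so the example is self-contained), two runs with different seeds recover the two columns of a 2×2 Jacobian. The function `f` is an illustrative example, not from the text:

```python
import math

class Dual:
    def __init__(self, value, dot):
        self.value = value
        self.dot = dot

    def __mul__(self, other):
        return Dual(self.value * other.value,
                    self.dot * other.value + self.value * other.dot)

def sin(x):
    return Dual(math.sin(x.value), math.cos(x.value) * x.dot)

def f(x1, x2):
    # f : R^2 -> R^2,  f(x1, x2) = (x1 * x2, sin(x1))
    return x1 * x2, sin(x1)

# Seed x1 with tangent 1: one run yields the first Jacobian column.
y1, y2 = f(Dual(3.0, 1.0), Dual(2.0, 0.0))
col1 = (y1.dot, y2.dot)   # (x2, cos(x1))

# Seed x2 with tangent 1: a second run yields the second column.
y1, y2 = f(Dual(3.0, 0.0), Dual(2.0, 1.0))
col2 = (y1.dot, y2.dot)   # (x1, 0)

print(col1, col2)
```

Each seeded run costs roughly one primal evaluation, so a full Jacobian of an n-input function takes n forward passes.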

### Reverse Mode by Overloading

Reverse mode overloads operations to build a dynamic computation graph.

Each operation creates a node. The node stores its primal value and a local backward rule.

```python
class Var:
    def __init__(self, value):
        self.value = value
        self.grad = 0.0
        self.parents = []
        self.backward = lambda: None  # leaves have no local rule
```

Multiplication may create a result node like this:

```python
def mul(x, y):
    out = Var(x.value * y.value)

    def backward():
        x.grad += y.value * out.grad
        y.grad += x.value * out.grad

    out.parents = [x, y]
    out.backward = backward
    return out

Var.__mul__ = mul  # lets the user write `x * x`
```

For a function:

```python
x = Var(3.0)
a = x * x
y = sin(a)
```

the program builds a graph during execution:

```text
x -> mul -> a -> sin -> y
```

Seeding `y.grad = 1` and then running each node's local `backward` rule in reverse topological order propagates the adjoints:

```text
y.grad = 1
a.grad += cos(a.value) * y.grad
x.grad += x.value * a.grad
x.grad += x.value * a.grad
```

The two identical updates to `x.grad` are intentional: `x` appears as both operands of `x * x`, so it receives one contribution per use. Their sum, `2 · x.value · a.grad`, is the gradient of `y` with respect to `x`.
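The traversal itself is usually a small driver outside the nodes. Below is a minimal self-contained sketch; the `backprop` helper and the default no-op `backward` on leaves are conveniences of this sketch, not part of the text's `Var` definition:

```python
import math

class Var:
    def __init__(self, value):
        self.value = value
        self.grad = 0.0
        self.parents = []
        self.backward = lambda: None  # leaves have no local rule

def mul(x, y):
    out = Var(x.value * y.value)
    out.parents = [x, y]
    def backward():
        x.grad += y.value * out.grad
        y.grad += x.value * out.grad
    out.backward = backward
    return out

def sin(x):
    out = Var(math.sin(x.value))
    out.parents = [x]
    def backward():
        x.grad += math.cos(x.value) * out.grad
    out.backward = backward
    return out

def backprop(y):
    # Depth-first pass collects nodes in topological order,
    # then the local rules run in reverse.
    order, seen = [], set()
    def visit(node):
        if id(node) not in seen:
            seen.add(id(node))
            for parent in node.parents:
                visit(parent)
            order.append(node)
    visit(y)
    y.grad = 1.0
    for node in reversed(order):
        node.backward()

x = Var(3.0)
y = sin(mul(x, x))
backprop(y)
print(x.grad)   # 2 * x * cos(x^2) at x = 3
```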

### Dynamic Graphs

Operator overloading naturally supports dynamic execution.

The graph is built from the operations that actually run. If the program has a branch:

```python
def f(x):
    if x.value > 0:
        return x * x
    else:
        return -x
```

then only the selected branch enters the computation graph.

This makes operator overloading suitable for languages and workloads where control flow depends on runtime values. It is a good fit for Python-style programming, interactive exploration, and models whose structure changes between calls.

The derivative follows the executed path. At branch boundaries, the derivative is the derivative of the selected branch, not a symbolic derivative of all possible branches.
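A forward-mode sketch makes this concrete: the `Dual` class below (with an added `__neg__` rule for the `-x` branch, an assumption of this sketch) differentiates whichever branch actually executes:

```python
class Dual:
    def __init__(self, value, dot):
        self.value = value
        self.dot = dot

    def __mul__(self, other):
        return Dual(self.value * other.value,
                    self.dot * other.value + self.value * other.dot)

    def __neg__(self):
        return Dual(-self.value, -self.dot)

def f(x):
    # The branch taken depends on the runtime primal value.
    if x.value > 0:
        return x * x
    else:
        return -x

print(f(Dual(3.0, 1.0)).dot)    # positive branch: d/dx x^2 = 2x = 6.0
print(f(Dual(-3.0, 1.0)).dot)   # negative branch: d/dx (-x) = -1.0
```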

### Why It Is Simple to Implement

Operator overloading requires less compiler infrastructure than source transformation.

The system needs:

| Component | Role |
|---|---|
| AD numeric type | Stores primal and derivative state |
| Overloaded operators | Implement local derivative rules |
| Elementary function wrappers | Differentiate `sin`, `exp`, `log`, etc. |
| Graph storage | Needed for reverse mode |
| Backward traversal | Propagates adjoints |
| Gradient accumulation | Sums contributions from multiple paths |

The host language parser, type checker, optimizer, and runtime remain unchanged.

This is why many small AD systems use operator overloading first. A minimal reverse-mode engine can be written in a few hundred lines. A minimal forward-mode engine can be smaller.

### Advantages

Operator overloading is easy to embed into an existing language. It preserves the user’s programming style. Ordinary control flow works without special syntax. Loops, recursion, and conditionals build exactly the graph used by the primal computation.

It is also highly compositional. Once a type supports addition, multiplication, and elementary functions, user-defined functions composed from those operations become differentiable automatically.

For research code, this is often the fastest path to a working AD system.

### Limitations

Operator overloading has less static visibility than source transformation.

The system sees operations as they execute, not as a whole program before execution. This limits ahead-of-time optimization. It may allocate many small graph nodes. It may pay runtime dispatch overhead on each primitive operation. It may also struggle with low-level memory planning.

Reverse-mode overloading can also create large dynamic tapes. Every differentiable operation may allocate a node and save primal values needed for backward propagation.

For scalar-heavy code, this overhead can dominate runtime.

For tensor-heavy code, the overhead is often acceptable because each graph node may represent a large tensor operation.

### Aliasing and Mutation

Mutation remains difficult.

Consider:

```python
x = Var(1.0)
y = x * x
x.value = 3.0   # mutate x after the graph was recorded
y.grad = 1.0
y.backward()
```

The backward rule for `y = x * x` must use the value of `x` at the time the multiplication happened, not the later mutated value.

Therefore reverse-mode systems usually save needed primal values inside the graph node:

```python
x_value = x.value   # snapshot taken when the node is created
y_value = y.value
```

Then the backward rule closes over those saved values.

This is safe but increases memory use. For tensors, saved activations can dominate memory cost.
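A mutation-safe `mul` can take these snapshots when the node is recorded. The following is a self-contained sketch of that pattern; the snapshot variable names are illustrative:

```python
class Var:
    def __init__(self, value):
        self.value = value
        self.grad = 0.0
        self.backward = lambda: None

def mul(x, y):
    x_value = x.value   # snapshots taken at record time
    y_value = y.value
    out = Var(x_value * y_value)
    def backward():
        # Reads the snapshots, not the live (possibly mutated) attributes.
        x.grad += y_value * out.grad
        y.grad += x_value * out.grad
    out.backward = backward
    return out

Var.__mul__ = mul

x = Var(1.0)
y = x * x
x.value = 3.0    # mutation after recording
y.grad = 1.0
y.backward()
print(x.grad)    # 2.0, computed from the recorded value 1.0
```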

### Type Coverage

Operator overloading works only when operations dispatch through overloadable interfaces.

It handles:

```text
x + y
x * y
sin(x)
matmul(a, b)
```

if those operations call overloadable methods or registered kernels.

It may fail for:

```text
foreign library calls
raw pointer arithmetic
in-place mutation inside native kernels
non-overloadable builtins
control flow hidden inside opaque compiled code
```

To differentiate through such operations, the system needs custom derivative rules. These rules are often called custom gradients, custom VJPs, custom JVPs, or primitive adjoints.
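As a forward-mode sketch, a custom JVP wraps the opaque call and supplies the derivative by hand. `opaque_softplus` here is a hypothetical stand-in for a foreign function the overloading mechanism cannot see into:

```python
import math

class Dual:
    def __init__(self, value, dot):
        self.value = value
        self.dot = dot

def opaque_softplus(v):
    # Hypothetical foreign call: takes and returns plain floats,
    # so Dual values cannot flow through it.
    return math.log(1.0 + math.exp(v))

def softplus(x):
    # Custom JVP: the derivative of log(1 + e^v) is the sigmoid of v.
    value = opaque_softplus(x.value)
    sigmoid = 1.0 / (1.0 + math.exp(-x.value))
    return Dual(value, sigmoid * x.dot)

y = softplus(Dual(0.0, 1.0))
print(y.value, y.dot)   # log(2) and 0.5
```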

### Forward vs Reverse Overloading

Forward-mode overloading computes derivatives immediately. There is no separate backward pass.

Reverse-mode overloading records first and differentiates later.

| Property | Forward overloading | Reverse overloading |
|---|---|---|
| Derivative state | Tangents | Adjoints |
| Runtime structure | No graph required | Dynamic graph required |
| Best for | Few inputs, many outputs | Many inputs, few outputs |
| Memory cost | Low | Can be high |
| Control flow | Natural | Natural, but must record graph |
| Typical result | JVP | VJP or gradient |

The programming interface can look similar, but the execution model differs substantially.
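The contrast can be seen in one self-contained sketch that differentiates `sin(x * x)` at `x = 3` both ways. The global `tape` list is an assumption of this sketch, standing in for per-node graph storage:

```python
import math

# Forward mode: the tangent travels with the value.
class Dual:
    def __init__(self, value, dot):
        self.value, self.dot = value, dot
    def __mul__(self, other):
        return Dual(self.value * other.value,
                    self.dot * other.value + self.value * other.dot)

def dsin(x):
    return Dual(math.sin(x.value), math.cos(x.value) * x.dot)

# Reverse mode: operations append rules to a tape, replayed later.
tape = []

class Var:
    def __init__(self, value):
        self.value = value
        self.grad = 0.0
    def __mul__(self, other):
        out = Var(self.value * other.value)
        def backward():
            self.grad += other.value * out.grad
            other.grad += self.value * out.grad
        tape.append(backward)
        return out

def vsin(x):
    out = Var(math.sin(x.value))
    def backward():
        x.grad += math.cos(x.value) * out.grad
    tape.append(backward)
    return out

# Same user-level expression under both models.
d = Dual(3.0, 1.0)
jvp = dsin(d * d).dot             # available as soon as the forward run ends

x = Var(3.0)
y = vsin(x * x)
y.grad = 1.0                      # seed the output, then replay in reverse
for rule in reversed(tape):
    rule()
vjp = x.grad

print(jvp, vjp)                   # both equal 2 * 3 * cos(9)
```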

### Operator Overloading in Practice

Many popular AD systems use operator overloading or a close variant.

Small educational systems often implement reverse mode this way because it exposes the core idea clearly. Dynamic deep learning systems also rely on this model: run ordinary code, record tensor operations, then backpropagate through the recorded graph.

The implementation becomes more complex at production scale. Real systems need tensor kernels, broadcasting rules, device placement, mixed precision, memory reuse, custom operators, distributed execution, and graph-level optimization.

Still, the conceptual core remains simple: replace numbers with derivative-aware numbers, and let the program run.

### Summary

Operator overloading treats automatic differentiation as a runtime interpretation strategy.

The user writes ordinary code. The AD system supplies special values whose operations propagate tangents or record adjoint rules. The result is a flexible and compact implementation model, especially effective for dynamic programs.

Its main cost is runtime overhead and limited whole-program optimization. Where source transformation behaves like a compiler pass, operator overloading behaves like an execution-time semantic extension of the host language.

