# Differentiable Subprograms

A differentiable subprogram is a program fragment that can participate in derivative propagation as a coherent unit. Instead of differentiating an entire application monolithically, AD systems decompose computation into smaller callable pieces with well-defined derivative behavior.

A simple example is:

```text
def square(x):
    return x * x
```

Mathematically:

$$
f(x) = x^2.
$$

Its derivative is:

$$
f'(x) = 2x.
$$

An AD system can either inline the body into a larger graph or treat the function as a reusable differentiable component.

## Functions as Differentiable Maps

A differentiable subprogram behaves like a map:

$$
f : X \to Y.
$$

Automatic differentiation constructs associated derivative maps.

Forward mode constructs a tangent map:

$$
Df : (x, \dot x) \mapsto (y, \dot y),
\qquad y = f(x), \quad \dot y = f'(x)\,\dot x.
$$

Reverse mode constructs a backward map:

$$
B_f : (x, \bar y) \mapsto \bar x,
\qquad \bar x = f'(x)^{\top}\,\bar y.
$$

The function boundary becomes part of the differentiation structure.
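As a concrete sketch, both maps for `square` can be written as ordinary functions (plain Python; the names `square_jvp` and `square_vjp` are chosen here for illustration):

```python
def square(x):
    return x * x

# Forward (tangent) map: (x, dx) -> (f(x), f'(x) * dx).
def square_jvp(x, dx):
    return x * x, 2.0 * x * dx

# Backward map: (x, y_adj) -> f'(x) * y_adj.
def square_vjp(x, y_adj):
    return 2.0 * x * y_adj

y, dy = square_jvp(3.0, 1.0)   # (9.0, 6.0)
x_adj = square_vjp(3.0, 1.0)   # 6.0
```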

## Encapsulation

Subprograms encapsulate local computation.

Example:

```text
def layer(x, w, b):
    return relu(matmul(x, w) + b)
```

A larger model may call this repeatedly:

```text
h1 = layer(x, w1, b1)
h2 = layer(h1, w2, b2)
y  = layer(h2, w3, b3)
```

The AD system can:

| Strategy | Meaning |
|---|---|
| Inline | expand the body each call |
| Reuse transformed version | cache derivative transform |
| Treat as primitive | use custom derivative rule |

Encapsulation allows modular differentiation.
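One way to realize the reuse strategy from the table above is to transform each distinct function once and cache the result. A minimal sketch, using a finite-difference transform as a stand-in for a real derivative transform:

```python
_cache = {}

def differentiated(fn, transform):
    # Transform each function object at most once, then reuse it.
    if fn not in _cache:
        _cache[fn] = transform(fn)
    return _cache[fn]

def fd_derivative(fn, h=1e-6):
    # Stand-in transform: numerical derivative of a scalar function.
    return lambda x: (fn(x + h) - fn(x - h)) / (2 * h)

def square(x):
    return x * x

dsquare = differentiated(square, fd_derivative)
print(dsquare(3.0))                                        # ~6.0
assert differentiated(square, fd_derivative) is dsquare    # cached
```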

## Call Graphs

Programs with functions form a call graph.

Example:

```text
main
 ├─> encoder
 │    ├─> attention
 │    └─> mlp
 └─> decoder
      ├─> attention
      └─> mlp
```

AD propagates derivatives through the same call structure.

Forward mode follows call direction. Reverse mode propagates adjoints back through return dependencies.

A reverse-mode engine must know:

| Item | Purpose |
|---|---|
| Inputs to the function | arguments to the local backward rules |
| Outputs from the function | where output adjoints arrive |
| Saved intermediates | values reused by the backward rules |
| Call ordering | order of the reverse traversal |
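A reverse-mode engine commonly records exactly this information on a tape during the forward pass. A minimal sketch of one tape entry (field names are illustrative):

```python
from dataclasses import dataclass
from typing import Callable

@dataclass
class TapeEntry:
    inputs: tuple        # values the call consumed
    outputs: tuple       # values the call produced
    saved: dict          # intermediates kept for the backward rule
    backward: Callable   # maps output adjoints to input adjoints

# Entries are appended in call order during the forward pass;
# the engine then walks the list in reverse to propagate adjoints.
tape: list = []
```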

## Function Composition

Subprograms compose naturally.

If:

$$
y = g(f(x)),
$$

then:

$$
\frac{dy}{dx} =
\frac{dy}{df}
\frac{df}{dx}.
$$

In code:

```text
u = f(x)
y = g(u)
```

The dependency graph becomes:

```text
x -> f -> u -> g -> y
```

AD applies the chain rule across function boundaries exactly as across primitive operations.
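For instance, with $f(x) = x^2$ and $g(u) = \sin(u)$, the chain rule gives $dy/dx = \cos(x^2) \cdot 2x$. A quick numerical check in plain Python:

```python
import math

def f(x):
    return x * x

def g(u):
    return math.sin(u)

x, h = 0.5, 1e-6
analytic = math.cos(x * x) * 2 * x                 # chain rule
numeric = (g(f(x + h)) - g(f(x - h))) / (2 * h)    # finite difference
print(analytic, numeric)                           # both ~0.9689
```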

## Local Derivative Contracts

A differentiable subprogram exposes a local derivative contract.

For forward mode:

```text
(primal_in, tangent_in)
    ->
(primal_out, tangent_out)
```

For reverse mode:

```text
(primal_in, primal_out, output_adj)
    ->
input_adj
```

The outer system does not need to know the internal implementation if the contract is correct.

This abstraction enables custom derivative rules.
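The two contracts can also be checked against each other: for any tangent `dx` and output adjoint `y_adj`, the inner products `tangent_out * y_adj` and `dx * input_adj` must agree. A sketch of this consistency test for scalar signatures (function names are illustrative):

```python
def square_jvp(x, dx):      # forward contract for f(x) = x * x
    return x * x, 2.0 * x * dx

def square_vjp(x, y_adj):   # reverse contract for f(x) = x * x
    return 2.0 * x * y_adj

def check_contracts(jvp, vjp, x, dx, y_adj):
    # <tangent_out, y_adj> must equal <dx, input_adj>.
    _, dy = jvp(x, dx)
    x_adj = vjp(x, y_adj)
    assert abs(dy * y_adj - dx * x_adj) < 1e-9

check_contracts(square_jvp, square_vjp, x=3.0, dx=0.7, y_adj=1.3)
```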

## Primitive Operations

Some subprograms are treated as primitives.

Example:

```text
y = sin(x)
```

The AD system does not differentiate through the numerical approximation that implements sine. Instead, it uses the known derivative rule:

$$
\frac{d}{dx}\sin(x) = \cos(x).
$$

Similarly for:

| Primitive | Derivative |
|---|---|
| `exp(x)` | `exp(x)` |
| `log(x)` | `1/x` |
| `matmul(a,b)` | matrix rules |
| `conv(x,w)` | convolution rules |

Primitive differentiation hides implementation complexity.
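In code, primitive rules are often just a lookup table from operation name to local derivative. A sketch with a few scalar primitives:

```python
import math

# Local derivative of each primitive, evaluated at the primal input.
DERIVATIVE_RULES = {
    "sin": lambda x: math.cos(x),
    "exp": lambda x: math.exp(x),
    "log": lambda x: 1.0 / x,
}

# The engine consults the table instead of expanding the implementation.
x = 2.0
local_derivative = DERIVATIVE_RULES["sin"](x)   # cos(2.0)
```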

## User-Defined Functions

User-defined functions can usually be differentiated automatically.

Example:

```text
def f(x):
    a = x * x
    b = sin(a)
    return b + 1
```

The system builds a dependency graph for the body and differentiates it mechanically.

The forward-mode transformed version conceptually becomes:

```text
def df(x, dx):
    a  = x * x
    da = 2 * x * dx

    b  = sin(a)
    db = cos(a) * da

    y  = b + 1
    dy = db

    return y, dy
```

This transformation is systematic.
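A runnable version of the transform can be checked against a finite difference (using `sin`/`cos` from `math`; the tolerance here is illustrative):

```python
import math

def df(x, dx):
    a, da = x * x, 2 * x * dx
    b, db = math.sin(a), math.cos(a) * da
    return b + 1, db

def f(x):
    return math.sin(x * x) + 1

x, h = 1.3, 1e-6
_, dy = df(x, 1.0)
numeric = (f(x + h) - f(x - h)) / (2 * h)
assert abs(dy - numeric) < 1e-5
```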

## Closures

A closure captures external variables.

```text
def make_scale(a):
    def scale(x):
        return a * x
    return scale
```

The inner function depends on both its argument `x` and the captured variable `a`.

Mathematically:

$$
f(x; a) = ax.
$$

The derivatives are:

$$
\frac{\partial f}{\partial x} = a,
\qquad
\frac{\partial f}{\partial a} = x.
$$

AD systems must track captured values as dependencies.
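Both partials can be checked numerically: perturbing the argument recovers $a$, while perturbing the captured value recovers $x$ (a plain-Python sketch):

```python
def make_scale(a):
    def scale(x):
        return a * x
    return scale

a, x, h = 2.0, 3.0, 1e-6

# Partial with respect to the argument x: should be ~a.
d_dx = (make_scale(a)(x + h) - make_scale(a)(x - h)) / (2 * h)

# Partial with respect to the captured a: should be ~x.
d_da = (make_scale(a + h)(x) - make_scale(a - h)(x)) / (2 * h)

print(d_dx, d_da)   # ~2.0, ~3.0
```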

## Recursive Functions

A recursive subprogram is defined in terms of calls to itself.

Example:

```text
def f(x, n):
    if n == 0:
        return x
    return sin(f(x, n - 1))
```

For fixed `n`, the call tree expands into a finite dependency graph.

AD differentiates the expanded execution trace.

Reverse mode must preserve:

| Item | Reason |
|---|---|
| Call stack | reverse traversal |
| Local variables | local derivatives |
| Return structure | adjoint propagation |

Recursive AD therefore resembles stack replay.
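A sketch of this stack replay for the recursive `f` above, with the recursion unrolled into an explicit loop: the forward pass saves each input to `sin`, and the backward pass multiplies the corresponding `cos` factors in reverse order.

```python
import math

def f_and_grad(x, n):
    # Forward pass: save each intermediate (the call-stack contents).
    saved, y = [], x
    for _ in range(n):
        saved.append(y)
        y = math.sin(y)
    # Backward pass: replay the stack in reverse, applying local rules.
    x_adj = 1.0
    for a in reversed(saved):
        x_adj *= math.cos(a)
    return y, x_adj

print(f_and_grad(0.5, 3))   # primal value and df/dx
```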

## Higher-Order Functions

A higher-order function takes functions as inputs or outputs.

Example:

```text
def apply_twice(f, x):
    return f(f(x))
```

If `f` is differentiable:

$$
y = f(f(x)).
$$

Then:

$$
\frac{dy}{dx} =
f'(f(x))f'(x).
$$

AD systems for functional languages often treat derivative transforms themselves as higher-order functions.
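A numerical sketch with a hypothetical finite-difference `grad` transform, checking the chain-rule result for `apply_twice` with $f = \sin$:

```python
import math

def grad(f, h=1e-6):
    # Hypothetical scalar derivative transform via finite differences.
    return lambda x: (f(x + h) - f(x - h)) / (2 * h)

def apply_twice(f, x):
    return f(f(x))

x = 0.3
numeric = grad(lambda t: apply_twice(math.sin, t))(x)
analytic = math.cos(math.sin(x)) * math.cos(x)   # f'(f(x)) * f'(x)
print(numeric, analytic)
```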

## Differentiation as Program Transformation

A differentiable subprogram may be transformed into a new subprogram.

Original:

```text
f : X -> Y
```

Forward-mode transform:

```text
Df : (X × TX) -> (Y × TY)
```

Reverse-mode transform:

```text
Rf : X -> (Y, Y* -> X*)
```

where $TX$ and $TY$ are tangent spaces, and $X^*$, $Y^*$ are cotangent spaces.

Thus AD itself becomes a compiler transform on callable units.
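In the reverse-mode shape, the transformed subprogram returns the primal output paired with a pullback closure of type $Y^* \to X^*$. A sketch for `square` (names illustrative):

```python
def square_transformed(x):
    y = x * x
    def pullback(y_adj):        # the Y* -> X* map, closed over x
        return 2.0 * x * y_adj
    return y, pullback

y, pullback = square_transformed(3.0)
x_adj = pullback(1.0)           # 6.0
```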

## Custom Gradient Rules

Sometimes the default derivative is inefficient or numerically unstable.

A user may define a custom backward rule.

Example:

```text
def stable_logsumexp(x):
    ...
```

The backward rule can be supplied directly:

```text
def backward(output_adj):
    ...
```

Advantages include:

| Benefit | Example |
|---|---|
| Better stability | log-sum-exp |
| Lower memory | fused backward |
| Faster execution | custom kernels |
| Implicit differentiation | solvers |
| Approximate gradients | quantization |

The subprogram becomes a primitive with user-defined differentiation semantics.
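As a concrete instance of the stability row, here is a sketch of a numerically stable log-sum-exp with its well-known softmax backward rule (plain Python over lists; how the rule is registered with the AD system is left abstract):

```python
import math

def stable_logsumexp(xs):
    m = max(xs)   # shift by the max so exp never overflows
    return m + math.log(sum(math.exp(x - m) for x in xs))

def stable_logsumexp_backward(xs, output_adj):
    # d logsumexp / dx_i = softmax(x)_i, scaled by the output adjoint.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [output_adj * e / total for e in exps]

print(stable_logsumexp([1.0, 2.0, 3.0]))
print(stable_logsumexp_backward([1.0, 2.0, 3.0], 1.0))
```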

## Opaque External Functions

Some subprograms call external libraries.

```text
y = cuda_kernel(x)
```

The AD system may not know the internal implementation.

Possible strategies:

| Strategy | Meaning |
|---|---|
| Treat as non-differentiable | stop gradient |
| Provide custom rule | manual backward |
| Trace internal ops | if supported |
| Use finite differences | fallback approximation |

Large systems often rely heavily on custom derivative rules for external kernels.
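The finite-difference fallback, for instance, can be sketched as a wrapper that attaches an approximate backward rule to a scalar opaque function (everything here is illustrative; a real system would handle tensors, batching, and step-size selection):

```python
def with_fd_backward(opaque_fn, h=1e-6):
    # Attach an approximate backward rule to an opaque scalar function.
    def backward(x, y_adj):
        slope = (opaque_fn(x + h) - opaque_fn(x - h)) / (2 * h)
        return slope * y_adj
    return opaque_fn, backward

fn, bwd = with_fd_backward(lambda x: x ** 3)   # stand-in for cuda_kernel
print(fn(2.0), bwd(2.0, 1.0))                  # 8.0, ~12.0
```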

## Nested Differentiation

A differentiable subprogram may itself invoke AD.

Example:

```text
def gradient_norm(f, x):
    g = grad(f)(x)
    return dot(g, g)
```

Now AD is applied to a program that already performs differentiation.

This creates nested derivative structures:

| Outer level | Inner level |
|---|---|
| differentiate `gradient_norm` | differentiate `f` |

Nested AD requires careful management of tangent and adjoint scopes.
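For a scalar sketch: with $f(x) = x^3$, the inner gradient is $3x^2$, so `gradient_norm` evaluates $9x^4$ and its derivative is $36x^3$. Using a hypothetical finite-difference `grad` (with a coarser step, since nested finite differences amplify rounding noise):

```python
def grad(f, h=1e-4):
    # Hypothetical derivative transform; coarse step for nesting.
    return lambda x: (f(x + h) - f(x - h)) / (2 * h)

def gradient_norm(f, x):
    g = grad(f)(x)    # inner differentiation
    return g * g      # dot(g, g) in the scalar case

def f(x):
    return x ** 3

x = 1.5
outer = grad(lambda t: gradient_norm(f, t))(x)   # differentiate the differentiator
print(outer, 36 * x ** 3)                        # both ~121.5
```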

## Perturbation Confusion

Nested differentiation can accidentally mix derivative levels.

Example:

```text
grad(lambda x: grad(f)(x))(x)
```

The inner and outer derivative computations must remain distinct.

Correct systems isolate derivative contexts so that:

| Level | Meaning |
|---|---|
| Inner tangent | derivative of `f` |
| Outer tangent | derivative of derivative |

Without isolation, perturbations from the two levels interfere and produce incorrect higher-order derivatives. In the classic example $\frac{d}{dx}\left[ x \cdot \frac{d}{dy}(x + y) \right]$, the inner derivative should be $1$; if the outer perturbation on $x$ leaks into the inner pass, the inner derivative evaluates to $2$ and the final result is wrong.

## Function Boundaries and Optimization

Subprogram boundaries influence optimization.

Inlining may expose fusion opportunities:

```text
y = relu(matmul(x, w) + b)
```

Once the call boundary is removed, the matrix multiply, bias addition, and activation can be fused into a single kernel. Modular boundaries, by contrast, improve reuse and compilation caching.

Compiler-based systems often balance:

| Goal | Preference |
|---|---|
| Maximum optimization | aggressive inlining |
| Fast compilation | preserve modularity |
| Reusability | cached differentiated kernels |
| Lower memory | fused backward passes |

Differentiable subprograms are therefore both semantic and optimization units.

## Interface Design

A minimal differentiable interface may look like:

```go
type Function interface {
    Forward(x Value) Value     // primal evaluation
    Backward(yAdj Value) Value // map an output adjoint to an input adjoint
}
```

More realistic systems require:

| Requirement | Reason |
|---|---|
| Multiple inputs | tensor programs |
| Multiple outputs | structured models |
| Saved intermediates | reverse mode |
| Device awareness | GPU execution |
| Shape metadata | tensor validation |
| Batched evaluation | vectorization |

Production AD systems therefore build sophisticated callable abstractions around these basic ideas.

## Core Idea

A differentiable subprogram is a callable computation unit with well-defined derivative behavior. Automatic differentiation propagates through function boundaries exactly as through primitive operations: by composing local derivative rules according to the dependency structure.

Subprograms provide modularity, reuse, abstraction, and optimization boundaries. They also enable custom gradients, higher-order differentiation, and compiler-level differentiation transforms.

