# JAX

## JAX

JAX is an automatic differentiation and array programming system for Python. It combines NumPy-like syntax with composable program transformations. Its core transformations include differentiation, vectorization, just-in-time compilation, and parallelization.

The key design idea is that JAX treats numerical Python functions as transformable objects. A function can be differentiated with `grad`, batched with `vmap`, compiled with `jit`, and distributed with `pmap` or related sharding APIs.

## Functional Numerical Programs

JAX works best with pure functions. A function should take inputs explicitly and return outputs explicitly.

```python
import jax
import jax.numpy as jnp

def f(x):
    return x * x + jnp.sin(x)

df = jax.grad(f)

print(df(2.0))
```

Here `jax.grad(f)` returns a new function that computes the derivative of `f`.

This differs from PyTorch's usual `.backward()` style, where calling `.backward()` on a loss accumulates gradients into the `.grad` attribute of each leaf tensor. JAX instead returns transformed functions. The gradient is an ordinary return value, not an implicit side effect.

## Reverse Mode with `grad`

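For scalar-output functions, `jax.grad` uses reverse-mode automatic differentiation. By default it differentiates with respect to the first positional argument; the `argnums` parameter selects other arguments.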

```python
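import jax
import jax.numpy as jnp

def loss(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

params = {"w": jnp.zeros(3), "b": jnp.array(0.0)}  # illustrative values
x, y = jnp.ones((4, 3)), jnp.ones(4)               # 4 examples, 3 features
grads = jax.grad(loss)(params, x, y)  # gradient w.r.t. params (argnums=0)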
```

The output `grads` has the same tree structure as `params`, the argument being differentiated. If `params` is a nested dictionary, a tuple, or a custom class registered as a PyTree node (for example, a dataclass), the gradient follows the same structure.

This is a central JAX idea: arrays are leaves, and Python containers form PyTrees. Transformations preserve that structure.
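A small sketch of this structure preservation, using a hypothetical two-layer parameter tree (the keys `layer1`/`layer2` and all shapes are illustrative assumptions):

```python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    # Two-layer model over a nested dict of parameters.
    h = jnp.tanh(x @ params["layer1"]["w"] + params["layer1"]["b"])
    pred = h @ params["layer2"]["w"]
    return jnp.mean((pred - y) ** 2)

params = {
    "layer1": {"w": jnp.ones((3, 4)), "b": jnp.zeros(4)},
    "layer2": {"w": jnp.ones(4)},
}
x = jnp.ones((5, 3))
y = jnp.zeros(5)

grads = jax.grad(loss)(params, x, y)

# The gradient tree has exactly the same nesting as the parameter tree.
same_structure = (
    jax.tree_util.tree_structure(grads) == jax.tree_util.tree_structure(params)
)
# Each gradient leaf has the shape of its matching parameter leaf
# (dict leaves are visited in sorted key order: layer1.b, layer1.w, layer2.w).
leaf_shapes = [g.shape for g in jax.tree_util.tree_leaves(grads)]
print(same_structure, leaf_shapes)
```

Because the structures match, optimizers can walk `params` and `grads` together leaf by leaf.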

## Forward Mode with `jvp`

JAX also exposes forward mode through Jacobian-vector products.

For

$$
y = f(x),
$$

a JVP computes

$$
(f(x), J_f(x)v).
$$

```python
y, tangent = jax.jvp(f, (x,), (v,))
```

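Each JVP propagates one tangent direction $v$ alongside the ordinary evaluation of $f$, so recovering a full Jacobian takes one pass per input dimension. Forward mode is therefore efficient when the number of input directions is small. It is also useful for Hessian-vector products when composed with reverse mode.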
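A minimal sketch of both uses. The functions `f` and `g` and the helper `hvp` are illustrative, not part of JAX; `hvp` follows the forward-over-reverse pattern, taking a JVP of the gradient function so the Hessian is never materialized:

```python
import jax
import jax.numpy as jnp

def f(x):
    # A map from R^3 to R^2.
    return jnp.stack([x[0] * x[1], jnp.sin(x[2])])

x = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([1.0, 0.0, 0.0])

# One forward pass yields f(x) and J_f(x) @ v, i.e. one column of the Jacobian.
y, jv = jax.jvp(f, (x,), (v,))

def g(x):
    # Scalar function; its Hessian is diag(6 * x).
    return jnp.sum(x ** 3)

def hvp(fun, x, v):
    # Forward-over-reverse: differentiate grad(fun) along v to get H(x) @ v.
    return jax.jvp(jax.grad(fun), (x,), (v,))[1]

print(jv)            # first Jacobian column: [x[1], 0] = [2, 0]
print(hvp(g, x, v))  # 6 * x * v = [6, 0, 0]
```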

## Reverse Mode with `vjp`

Vector-Jacobian products are exposed through `vjp`.

For

$$
y = f(x),
$$

a VJP computes

$$
u^T J_f(x).
$$

```python
y, pullback = jax.vjp(f, x)
# pullback returns a tuple with one cotangent per positional input of f.
(grad_x,) = pullback(u)
```

The `pullback` is a function. It maps an output cotangent back to input cotangents.

This API exposes the algebra of reverse mode directly. Instead of hiding the backward pass behind `backward()`, JAX represents it as a callable object.
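To make the algebra concrete, the sketch below uses an illustrative linear map, so that $J_f(x) = A$ everywhere and the pullback of $u$ is exactly $u^T A$. Pulling back each basis vector recovers one row of the Jacobian, which is the idea behind reverse-mode `jacrev`:

```python
import jax
import jax.numpy as jnp

A = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])

def f(x):
    # Linear map R^3 -> R^2, so J_f(x) = A for every x.
    return A @ x

x = jnp.array([1.0, 1.0, 1.0])
y, pullback = jax.vjp(f, x)

u = jnp.array([1.0, -1.0])   # output cotangent
(x_bar,) = pullback(u)       # one cotangent per positional input of f
print(x_bar)                 # u^T A = [-3, -3, -3]

# Pull back every basis vector of the output space at once with vmap:
# each result is one row of the Jacobian.
rows = jax.vmap(lambda e: pullback(e)[0])(jnp.eye(2))
print(jnp.allclose(rows, jax.jacrev(f)(x)))
```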

## Composable Transformations

JAX transformations can be composed.

```python
fast_grad = jax.jit(jax.grad(loss))
```

This means: first transform `loss` into a gradient function, then compile that gradient function.

Vectorization works similarly:

```python
batched_f = jax.vmap(f)
```

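`vmap` transforms a function written for a single example into one that operates on a batch, without explicit loops. By default it maps over the leading axis of every argument; `in_axes` selects other axes or marks arguments as unbatched.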
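A short sketch of batching a per-example function. The model `predict` and its data are illustrative assumptions; `in_axes=(None, 0)` shares the weights across the batch while mapping over the rows of `xs`:

```python
import jax
import jax.numpy as jnp

def predict(w, x):
    # Written for a single example: x has shape (features,).
    return jnp.tanh(jnp.dot(w, x))

w = jnp.array([0.5, -0.25, 1.0])
xs = jnp.arange(12.0).reshape(4, 3)   # batch of 4 examples, 3 features each

# Share w (None), map over the leading axis of xs (0).
batched_predict = jax.vmap(predict, in_axes=(None, 0))
ys = batched_predict(w, xs)

# Equivalent explicit Python loop, for comparison.
ys_loop = jnp.stack([predict(w, x) for x in xs])
print(ys.shape)  # (4,)
```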

A common pattern is:

```python
step = jax.jit(train_step)
```

where `train_step` computes a loss, gradients, and updated parameters.
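One possible shape for such a `train_step`, sketched under stated assumptions: the linear-regression `loss` from above, plain gradient descent with a fixed, illustrative learning rate `lr`, and toy data where `y = 2x`. Real projects typically delegate the update to an optimizer library.

```python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

def train_step(params, x, y, lr=0.1):
    # value_and_grad returns the loss and its gradient from one traced function.
    value, grads = jax.value_and_grad(loss)(params, x, y)
    # Functional SGD update: build a new parameter tree, leaf by leaf.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, value

step = jax.jit(train_step)

x = jnp.array([[1.0], [2.0], [3.0]])
y = jnp.array([2.0, 4.0, 6.0])                  # toy data: y = 2x
params = {"w": jnp.zeros(1), "b": jnp.array(0.0)}

for _ in range(200):
    params, value = step(params, x, y)          # compiled once, reused afterwards
print(value, params["w"])
```

Returning both the new parameters and the loss keeps the step pure, which is what lets `jit` cache it.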

The power of JAX comes from this small set of orthogonal transformations:

| Transformation | Meaning |
|---|---|
| `grad` | reverse-mode derivative of a scalar-output function |
| `jvp` | forward-mode Jacobian-vector product |
| `vjp` | reverse-mode vector-Jacobian product |
| `jacfwd` | full Jacobian using forward mode |
| `jacrev` | full Jacobian using reverse mode |
| `vmap` | automatic batching |
| `jit` | compilation |
| `pmap` / sharding APIs | parallel execution |
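As a quick check on the two Jacobian entries in the table, the sketch below (with an illustrative function from $\mathbb{R}^3$ to $\mathbb{R}^2$) shows that `jacfwd` and `jacrev` compute the same matrix by different routes; they differ only in cost, which depends on the input and output dimensions:

```python
import jax
import jax.numpy as jnp

def f(x):
    # R^3 -> R^2
    return jnp.stack([x[0] * x[1] * x[2], jnp.sum(x ** 2)])

x = jnp.array([1.0, 2.0, 3.0])

J_fwd = jax.jacfwd(f)(x)   # assembled from JVPs, one per input direction (3)
J_rev = jax.jacrev(f)(x)   # assembled from VJPs, one per output direction (2)

print(J_fwd.shape)         # (2, 3)
print(jnp.allclose(J_fwd, J_rev))
```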

## Tracing and JAXPR

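JAX implements transformations by tracing Python functions. During tracing, JAX calls the function with tracer objects in place of arrays; under `jit`, these tracers are abstract, carrying only shape and dtype rather than concrete data. Instead of just computing results, JAX records each primitive operation applied to the tracers, building a symbolic representation of the computation.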

This internal representation is called JAXPR. It is a small functional intermediate language containing primitive operations, variables, constants, and equations.

In simplified form, a data flow such as

```text
x -> sin x -> add -> y
```

becomes a graph-like expression over JAX primitives.

Transformations such as `grad`, `vmap`, and `jit` operate on this representation or on traces that produce it. This makes JAX closer to a compiler system than to a simple runtime tape.
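The recorded program can be inspected with `jax.make_jaxpr`. This sketch traces the `f` from the start of the section and lists the primitives in its equations; the printed text format varies across JAX versions, so only the primitive names are worth relying on. Tracing `jax.grad(f)` the same way shows that a transformed function is itself just another JAXPR (it now contains `cos`, the derivative of `sin`):

```python
import jax
import jax.numpy as jnp

def f(x):
    return x * x + jnp.sin(x)

# Trace f with a float scalar and return the recorded program (a ClosedJaxpr).
jaxpr = jax.make_jaxpr(f)(2.0)
print(jaxpr)

# Each equation applies one primitive to variables or constants.
prims = [eqn.primitive.name for eqn in jaxpr.jaxpr.eqns]
print(prims)

# Transformations produce new programs over the same primitive set.
grad_jaxpr = jax.make_jaxpr(jax.grad(f))(2.0)
print(grad_jaxpr)
```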

## XLA Compilation

JAX uses XLA as its compilation backend. When a function is wrapped with `jax.jit`, JAX traces it to a JAXPR, lowers that to XLA's input representation (StableHLO), and compiles it for CPU, GPU, or TPU.

```python
@jax.jit
def f(x):
    return jnp.tanh(x @ x)
```

The compiled version can fuse operations, optimize memory layout, and target accelerators.

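Compilation has consequences. Compiled executables are cached by the shapes and dtypes of the arguments (and by the values of any arguments marked static), so calling with a new shape triggers a new trace and compilation. Python side effects, such as `print` or updating a global, happen during tracing, not during every compiled execution.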
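Both effects are visible in a small sketch that counts traces with a Python side effect. The global counter `trace_count` is purely illustrative; it increments only when JAX traces, not when the cached executable runs:

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def f(x):
    global trace_count
    trace_count += 1          # Python side effect: runs only while tracing
    return jnp.tanh(x @ x)

f(jnp.ones(3))   # first call: trace and compile
f(jnp.ones(3))   # same shape and dtype: reuses the cached executable
f(jnp.ones(4))   # new shape: traces and compiles again
print(trace_count)  # 2
```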

## Pure State Updates

JAX arrays are immutable. State updates are expressed functionally.

Instead of mutating parameters:

```python
params["w"] -= lr * grads["w"]
```

JAX code commonly returns a new parameter tree:

```python
new_params = jax.tree.map(
    lambda p, g: p - lr * g,
    params,
    grads,
)
```

This style makes transformations easier. Mutation complicates differentiation, batching, compilation, and parallel execution. JAX avoids much of that complexity by preferring explicit state passing.
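The same principle applies to individual arrays. Item assignment such as `x[1] = 3.0` raises a `TypeError` on a JAX array; the indexed-update syntax `x.at[...]` returns a modified copy instead (under `jit`, XLA can often perform such updates in place). A minimal sketch:

```python
import jax.numpy as jnp

x = jnp.zeros(5)

# x[1] = 3.0 would fail: JAX arrays are immutable.
y = x.at[1].set(3.0)     # new array with index 1 replaced
z = y.at[2:4].add(1.0)   # indexed updates compose functionally

print(x)  # unchanged: still all zeros
print(z)  # values 0, 3, 1, 1, 0
```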

## Randomness

JAX treats randomness as explicit state. Random keys are passed into functions.

```python
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)

x = jax.random.normal(subkey, (10,))
```

This differs from global random state in NumPy or PyTorch. Explicit keys make random programs reproducible and compatible with transformations such as `vmap` and `jit`.
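A brief sketch of both properties. The function `noisy` and the noise scale `0.1` are illustrative assumptions; the point is that one key per batch element can be produced with `split` and handed to a `vmap`ped, `jit`ted function like any other array:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# The same key always produces the same numbers.
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))

# Splitting yields independent keys; here, one per batch element.
keys = jax.random.split(key, 4)

def noisy(k, x):
    # Each call receives its own key explicitly; there is no hidden global state.
    return x + 0.1 * jax.random.normal(k, x.shape)

xs = jnp.zeros((4, 2))
out = jax.jit(jax.vmap(noisy))(keys, xs)
print(jnp.array_equal(a, b), out.shape)  # True (4, 2)
```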

## Custom Derivatives

JAX supports custom derivative rules with `custom_jvp` and `custom_vjp`.

A custom JVP defines forward-mode behavior. A custom VJP defines reverse-mode behavior.

```python
@jax.custom_vjp
def f(x):
    return jnp.sqrt(x)

def f_fwd(x):
    y = f(x)
    return y, y

def f_bwd(saved, g):
    y = saved
    return (g / (2 * y),)

f.defvjp(f_fwd, f_bwd)
```

Custom rules are used for numerical stability, external primitives, implicit differentiation, and operations whose default derivative would be inefficient.
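For the numerical-stability case, a `custom_jvp` sketch modeled on the `log1pexp` example in the JAX documentation. Differentiating `log1p(exp(x))` automatically produces a product of the form `inf * 0` in float32 for large `x`, which yields `nan`; the custom rule supplies an algebraically equal but stable derivative:

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def log1pexp(x):
    # log(1 + e^x). The default derivative multiplies e^x (inf for large x)
    # by 1 / (1 + e^x) (0 for large x), giving nan.
    return jnp.log1p(jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    (x,) = primals
    (x_dot,) = tangents
    ans = log1pexp(x)
    # Stable form of the same derivative, linear in the tangent x_dot.
    ans_dot = (1.0 - 1.0 / (1.0 + jnp.exp(x))) * x_dot
    return ans, ans_dot

print(jax.grad(log1pexp)(100.0))  # 1.0 rather than nan
```

Because the rule is a JVP, JAX can still derive reverse mode from it, so `jax.grad` works without a separate VJP definition.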

## Strengths

JAX has a compact and rigorous transformation model. Its APIs expose the core AD objects directly: gradients, JVPs, VJPs, Jacobians, and Hessians.

Its functional style makes programs easier to transform. Explicit inputs, explicit outputs, immutable arrays, and explicit random keys reduce hidden state.

The combination of `grad`, `vmap`, and `jit` is especially powerful. Users can write simple scalar or single-example code, then derive batched and compiled versions mechanically.
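Per-example gradients are a typical case. The sketch below (with an illustrative squared-error `example_loss` and toy data) writes the loss for one example, then composes `grad`, `vmap`, and `jit` to get a compiled function returning one gradient per example:

```python
import jax
import jax.numpy as jnp

def example_loss(w, x, y):
    # Squared error for ONE example: x has shape (features,), y is a scalar.
    return (jnp.dot(w, x) - y) ** 2

# grad -> derivative w.r.t. w for a single example
# vmap -> map over examples while sharing w (in_axes=None for w)
# jit  -> compile the resulting batched gradient function
per_example_grads = jax.jit(
    jax.vmap(jax.grad(example_loss), in_axes=(None, 0, 0))
)

w = jnp.array([1.0, -1.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([0.0, 0.0, 1.0])

g = per_example_grads(w, xs, ys)
print(g.shape)  # (3, 2): one gradient per example
```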

JAX is also strong for higher-order differentiation. Since transformations compose, expressions such as `jax.grad(jax.grad(f))` or `jax.jvp(jax.grad(f), ...)` are natural.
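A short sketch of higher-order derivatives built purely by composition. The function `g` is illustrative; `jax.hessian` is itself defined in JAX as forward mode over reverse mode:

```python
import jax
import jax.numpy as jnp

d1 = jax.grad(jnp.sin)    # cos
d2 = jax.grad(d1)         # -sin
d3 = jax.grad(d2)         # -cos

x = 0.5
print(d1(x), d2(x), d3(x))

def g(v):
    # Scalar function of a vector: v0^2 * v1.
    return v[0] ** 2 * v[1]

# Second derivatives: [[2*v1, 2*v0], [2*v0, 0]].
H = jax.hessian(g)(jnp.array([1.0, 2.0]))
print(H)  # [[4, 2], [2, 0]]
```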

## Limitations

JAX differentiates JAX operations, not arbitrary Python execution. Python control flow runs at trace time, so it works when it depends only on static information such as shapes, Python constants, or static arguments. Branching on the value of a traced array inside `jit` raises an error and requires JAX control-flow primitives such as `jax.lax.cond`, `jax.lax.scan`, and `jax.lax.while_loop`.
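A minimal sketch of two of these primitives under `jit`. The functions `relu_or_square` and `running_sum` are illustrative; `lax.cond` selects a branch from a traced predicate, and `lax.scan` threads a carry along the leading axis of its input:

```python
import jax
import jax.numpy as jnp

@jax.jit
def relu_or_square(x):
    # `if x > 0:` would fail here: x is an abstract tracer under jit, so its
    # truth value is unknown at trace time. lax.cond stages both branches.
    return jax.lax.cond(x > 0, lambda v: v, lambda v: v * v, x)

@jax.jit
def running_sum(xs):
    def body(carry, x):
        carry = carry + x
        return carry, carry          # (new carry, per-step output)
    _, ys = jax.lax.scan(body, jnp.zeros((), xs.dtype), xs)
    return ys

print(relu_or_square(3.0), relu_or_square(-2.0))  # 3.0 4.0
print(running_sum(jnp.arange(4.0)))               # 0, 1, 3, 6
```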

Compilation can be expensive. Programs with many input shapes or frequent retracing may perform poorly.

The functional style is clean but sometimes intrusive. Existing Python programs with mutation, object state, dynamic data structures, and side effects often need redesign before they work well in JAX.

Debugging compiled and transformed code can be harder than debugging eager Python. A function may behave differently during tracing than during execution, especially when Python side effects are involved.
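One common source of confusion is `print` inside a jitted function. A sketch contrasting it with `jax.debug.print`, which is staged into the compiled program and reports runtime values (the function body is illustrative):

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    y = jnp.sin(x) * 2.0
    # Built-in print runs only at trace time and shows an abstract tracer.
    print("trace-time y:", y)
    # jax.debug.print is part of the compiled program and shows concrete values.
    jax.debug.print("runtime y: {}", y)
    return y

f(1.0)   # traces: both lines print
f(2.0)   # cached executable: only the jax.debug.print line appears
```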

## Historical Role

JAX represents automatic differentiation as one member of a broader family of program transformations. Earlier systems often centered on one AD mode or one execution model. JAX made transformation composition the primary interface.

Its contribution is architectural: differentiation, batching, compilation, and parallelism are treated as related transformations over numerical programs. This made JAX influential not only as a machine learning framework, but also as a model for differentiable programming language design.

