# Custom Gradients

A custom gradient gives the user direct control over the backward rule of an operation. The forward computation still produces an ordinary value, but the derivative no longer has to be inferred from the primitive operations used internally.

This is useful when the default AD result is mathematically correct but computationally poor, numerically unstable, too memory-hungry, or simply not the derivative the user wants the system to use.

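From the engine's point of view, a custom gradient turns a composite function into a primitive operation: the tape records one instruction with a user-supplied backward rule instead of tracing the function's internals.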

## Why Custom Gradients Exist

Suppose we define:

$$
y = \log(1 + e^x)
$$

This is the softplus function. A naive implementation is:

```go
func Softplus(t *Tape, x Slot) Slot {
    one := t.Const(1)
    return t.Log(t.Add(one, t.Exp(x)))
}
```

The derivative is:

$$
\frac{dy}{dx} = \frac{e^x}{1 + e^x}
$$

which is the sigmoid function.

The naive expression works for moderate values of `x`, but it overflows when `x` is large: in `float64`, `exp(x)` becomes `+Inf` once `x` exceeds roughly 709.78. The forward value is then `+Inf`, and the derivative formula evaluates `Inf / Inf`, which is NaN. A custom gradient lets us implement both the forward and backward paths in a stable form.
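The failure is easy to reproduce outside any AD engine. The following sketch evaluates the two formulas exactly as written; the function names `naiveSoftplus` and `naiveSoftplusGrad` are illustrative and not part of the tape API.

```go
package main

import (
    "fmt"
    "math"
)

// naiveSoftplus evaluates log(1 + exp(x)) exactly as written.
func naiveSoftplus(x float64) float64 {
    return math.Log(1 + math.Exp(x))
}

// naiveSoftplusGrad evaluates exp(x) / (1 + exp(x)) exactly as written.
func naiveSoftplusGrad(x float64) float64 {
    e := math.Exp(x)
    return e / (1 + e)
}

func main() {
    // For x = 1000, exp(x) overflows to +Inf, so the value is +Inf
    // and the gradient is Inf/Inf = NaN.
    for _, x := range []float64{1, 100, 1000} {
        fmt.Println(x, naiveSoftplus(x), naiveSoftplusGrad(x))
    }
}
```

The mathematically correct answers at `x = 1000` are a value of about 1000 and a gradient of about 1, so both outputs are lost to overflow rather than to any limitation of the function itself.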

## A Minimal Custom Operation

In a tape-based engine, a custom operation can be represented as an instruction whose backward rule is supplied by the user.

```go
type CustomBackward func(t *Tape, ins Instr)

type Instr struct {
    Op  OpKind
    Out Slot
    A   Slot
    B   Slot

    Custom CustomBackward
}
```

Then the backward dispatcher includes:

```go
case OpCustom:
    ins.Custom(t, ins)
```

A small helper can register a unary custom operation:

```go
func (t *Tape) CustomUnary(
    x Slot,
    value float64,
    backward CustomBackward,
) Slot {
    out := t.alloc(value, t.RequiresGrad[x])

    if t.RequiresGrad[out] {
        t.Instrs = append(t.Instrs, Instr{
            Op:     OpCustom,
            Out:    out,
            A:      x,
            Custom: backward,
        })
    }

    return out
}
```

This design is intentionally small. It gives users a way to attach a local vector-Jacobian product to an operation.
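To see the pieces working together, here is a minimal runnable sketch. The layout of `Tape` (parallel `Values`, `Grads`, and `RequiresGrad` slices), the `alloc` and `Var` helpers, and the `Backward` loop are assumptions made for this example; a real engine would have more operation kinds and bookkeeping.

```go
package main

import "fmt"

type Slot int
type OpKind int

const OpCustom OpKind = iota

type CustomBackward func(t *Tape, ins Instr)

type Instr struct {
    Op     OpKind
    Out, A Slot
    Custom CustomBackward
}

// Tape is a hypothetical minimal layout: one entry per slot in each slice.
type Tape struct {
    Values       []float64
    Grads        []float64
    RequiresGrad []bool
    Instrs       []Instr
}

func (t *Tape) alloc(v float64, req bool) Slot {
    t.Values = append(t.Values, v)
    t.Grads = append(t.Grads, 0)
    t.RequiresGrad = append(t.RequiresGrad, req)
    return Slot(len(t.Values) - 1)
}

func (t *Tape) Var(v float64) Slot { return t.alloc(v, true) }

func (t *Tape) CustomUnary(x Slot, value float64, backward CustomBackward) Slot {
    out := t.alloc(value, t.RequiresGrad[x])
    if t.RequiresGrad[out] {
        t.Instrs = append(t.Instrs, Instr{Op: OpCustom, Out: out, A: x, Custom: backward})
    }
    return out
}

// Backward seeds the output adjoint with 1 and replays the tape in reverse.
func (t *Tape) Backward(y Slot) {
    t.Grads[y] = 1
    for i := len(t.Instrs) - 1; i >= 0; i-- {
        ins := t.Instrs[i]
        switch ins.Op {
        case OpCustom:
            ins.Custom(t, ins)
        }
    }
}

func main() {
    var t Tape
    x := t.Var(3)

    // Cube with a hand-written backward rule: d(x^3)/dx = 3x^2.
    xv := t.Values[x]
    y := t.CustomUnary(x, xv*xv*xv, func(t *Tape, ins Instr) {
        v := t.Values[ins.A]
        t.Grads[ins.A] += 3 * v * v * t.Grads[ins.Out]
    })

    t.Backward(y)
    fmt.Println(t.Values[y], t.Grads[x]) // prints "27 27"
}
```

Note that the closure accumulates with `+=` rather than assigning, so a slot used by several instructions still receives the sum of all its contributions.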

## Stable Softplus

A stable forward implementation can use:

```go
func stableSoftplus(x float64) float64 {
    if x > 0 {
        return x + math.Log1p(math.Exp(-x))
    }

    return math.Log1p(math.Exp(x))
}
```
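The positive branch follows from factoring $e^x$ out of the argument of the logarithm:

$$
\log(1 + e^x) = \log\left(e^x \left(1 + e^{-x}\right)\right) = x + \log\left(1 + e^{-x}\right)
$$

For $x > 0$ the exponent $-x$ is negative, so $e^{-x} < 1$ and cannot overflow. For $x \le 0$ the original form is already safe because $e^x \le 1$. In both branches `math.Log1p` preserves precision when its argument is very small.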

The custom AD operation:

```go
func (t *Tape) Softplus(x Slot) Slot {
    value := stableSoftplus(t.Values[x])

    return t.CustomUnary(x, value, func(t *Tape, ins Instr) {
        xval := t.Values[ins.A]

        var sig float64
        if xval >= 0 {
            z := math.Exp(-xval)
            sig = 1 / (1 + z)
        } else {
            z := math.Exp(xval)
            sig = z / (1 + z)
        }

        t.Grads[ins.A] += sig * t.Grads[ins.Out]
    })
}
```

The backward rule directly computes the stable derivative. It avoids relying on the internal expression `log(1 + exp(x))`.

## Custom Gradient as VJP

In reverse mode, a custom gradient is usually a custom vector-Jacobian product.

For a function:

$$
y = f(x)
$$

the backward rule receives an output adjoint:

$$
\bar{y}
$$

and must add to the input adjoint:

$$
\bar{x} \mathrel{+}= J_f(x)^T \bar{y}
$$

For scalar functions, this reduces to:

$$
\bar{x} \mathrel{+}= f'(x)\bar{y}
$$

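For tensor functions, the rule must also respect shapes: each input adjoint has the same shape as its input, so the rule has to undo any broadcasting or reduction performed in the forward pass.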

The custom backward rule should therefore be understood as a local VJP implementation, not merely a derivative formula.
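A small non-scalar case makes the distinction concrete. For the linear map $y = Ax$ the Jacobian is $A$, so the VJP is $\bar{x} \mathrel{+}= A^T \bar{y}$. The sketch below uses plain slices rather than any tensor type from the engine, and `matVecVJP` is a name chosen for this example.

```go
package main

import "fmt"

// matVecVJP accumulates xbar += Aᵀ·ybar for y = A·x,
// where A has len(ybar) rows and len(xbar) columns.
// The transpose is never materialized; the loop just swaps index roles.
func matVecVJP(A [][]float64, ybar, xbar []float64) {
    for i, row := range A {
        for j, a := range row {
            xbar[j] += a * ybar[i]
        }
    }
}

func main() {
    A := [][]float64{
        {1, 2, 3},
        {4, 5, 6},
    }
    ybar := []float64{1, 10}   // output adjoint, length 2
    xbar := make([]float64, 3) // input adjoint, length 3

    matVecVJP(A, ybar, xbar)
    fmt.Println(xbar) // prints "[41 52 63]"
}
```

The output adjoint has the shape of `y` and the input adjoint has the shape of `x`; the rule maps one to the other without ever forming a full Jacobian for anything larger than the operation itself.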

## Multiple Inputs

A custom operation may have several inputs.

Example:

$$
z = f(x, y)
$$

The backward rule must accumulate a contribution for each input:

$$
\bar{x} \mathrel{+}= \frac{\partial z}{\partial x}\bar{z}
$$

$$
\bar{y} \mathrel{+}= \frac{\partial z}{\partial y}\bar{z}
$$

A minimal binary helper:

```go
func (t *Tape) CustomBinary(
    a, b Slot,
    value float64,
    backward CustomBackward,
) Slot {
    out := t.alloc(value, t.RequiresGrad[a] || t.RequiresGrad[b])

    if t.RequiresGrad[out] {
        t.Instrs = append(t.Instrs, Instr{
            Op:     OpCustom,
            Out:    out,
            A:      a,
            B:      b,
            Custom: backward,
        })
    }

    return out
}
```

Each input that requires gradients receives its own contribution.
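As a concrete two-input rule, consider $z = \sqrt{x^2 + y^2}$, whose partials are $x/z$ and $y/z$. The sketch below computes the two contributions as a standalone function; `hypotVJP` is a hypothetical name, and inside a tape closure the two results would be added to `t.Grads[ins.A]` and `t.Grads[ins.B]`.

```go
package main

import (
    "fmt"
    "math"
)

// hypotVJP returns the adjoint contributions for z = sqrt(x² + y²):
//
//	x contribution = (x / z) * zbar
//	y contribution = (y / z) * zbar
//
// At the origin the derivative is undefined; returning zeros there is a
// convention chosen for this sketch, not a mathematical fact.
func hypotVJP(x, y, zbar float64) (dx, dy float64) {
    z := math.Hypot(x, y)
    if z == 0 {
        return 0, 0
    }
    return x / z * zbar, y / z * zbar
}

func main() {
    dx, dy := hypotVJP(3, 4, 2)
    fmt.Println(dx, dy) // approximately 1.2 and 1.6
}
```

Composing `sqrt`, `add`, and `mul` primitives would give the same derivative away from the origin, but the composed backward pass divides by zero at `(0, 0)`. A custom rule is one place to decide what should happen there.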

## Saved Values

A custom backward rule often needs values computed during the forward pass. Recomputing them may be expensive or numerically risky.

The instruction can store saved scalars:

```go
type Instr struct {
    Op  OpKind
    Out Slot
    A   Slot
    B   Slot

    Saved []float64

    Custom CustomBackward
}
```

For softplus, we may save the sigmoid value:

```go
func (t *Tape) Softplus(x Slot) Slot {
    xval := t.Values[x]
    value := stableSoftplus(xval)
    sig := stableSigmoid(xval)

    return t.CustomUnarySaved(
        x,
        value,
        []float64{sig},
        func(t *Tape, ins Instr) {
            t.Grads[ins.A] += ins.Saved[0] * t.Grads[ins.Out]
        },
    )
}
```

This trades memory for stable and cheaper backward execution.

The operator author must decide which values to save. A useful rule:

```text
Save values that are expensive, unstable, or impossible to reconstruct safely.
```
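The saved-value version of `Softplus` above relies on `stableSigmoid` and `CustomUnarySaved`, neither of which is defined in this section. One plausible shape for them is sketched below, on a stripped-down tape whose layout is an assumption made for the example (it omits `Op` and `B` for brevity).

```go
package main

import (
    "fmt"
    "math"
)

type Slot int

type CustomBackward func(t *Tape, ins Instr)

type Instr struct {
    Out, A Slot
    Saved  []float64
    Custom CustomBackward
}

// Tape is a hypothetical minimal layout with one entry per slot.
type Tape struct {
    Values, Grads []float64
    RequiresGrad  []bool
    Instrs        []Instr
}

func (t *Tape) alloc(v float64, req bool) Slot {
    t.Values = append(t.Values, v)
    t.Grads = append(t.Grads, 0)
    t.RequiresGrad = append(t.RequiresGrad, req)
    return Slot(len(t.Values) - 1)
}

func (t *Tape) Var(v float64) Slot { return t.alloc(v, true) }

// stableSigmoid only ever exponentiates a non-positive number.
func stableSigmoid(x float64) float64 {
    if x >= 0 {
        return 1 / (1 + math.Exp(-x))
    }
    z := math.Exp(x)
    return z / (1 + z)
}

func stableSoftplus(x float64) float64 {
    if x > 0 {
        return x + math.Log1p(math.Exp(-x))
    }
    return math.Log1p(math.Exp(x))
}

// CustomUnarySaved mirrors CustomUnary but also stores saved scalars.
func (t *Tape) CustomUnarySaved(x Slot, value float64, saved []float64, backward CustomBackward) Slot {
    out := t.alloc(value, t.RequiresGrad[x])
    if t.RequiresGrad[out] {
        t.Instrs = append(t.Instrs, Instr{
            Out:    out,
            A:      x,
            Saved:  append([]float64(nil), saved...), // copy so callers cannot mutate it later
            Custom: backward,
        })
    }
    return out
}

func (t *Tape) Softplus(x Slot) Slot {
    xval := t.Values[x]
    return t.CustomUnarySaved(x, stableSoftplus(xval), []float64{stableSigmoid(xval)},
        func(t *Tape, ins Instr) {
            t.Grads[ins.A] += ins.Saved[0] * t.Grads[ins.Out]
        })
}

// Backward seeds the output adjoint with 1 and replays the tape in reverse.
func (t *Tape) Backward(y Slot) {
    t.Grads[y] = 1
    for i := len(t.Instrs) - 1; i >= 0; i-- {
        t.Instrs[i].Custom(t, t.Instrs[i])
    }
}

func main() {
    var t Tape
    x := t.Var(1000)
    y := t.Softplus(x)
    t.Backward(y)
    fmt.Println(t.Values[y], t.Grads[x]) // prints "1000 1"
}
```

Because nothing is stored when the input does not require gradients, the memory cost of `Saved` is only paid on paths that will actually be differentiated.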

## Stop Gradient

A special custom gradient is the zero gradient.

Forward:

$$
y = x
$$

Backward:

$$
\bar{x} \mathrel{+}= 0
$$

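This is usually called `stop_gradient` (as in JAX and TensorFlow) or `detach` (as in PyTorch). A `no_grad` scope is a related but different mechanism: it disables recording for a whole block of code rather than cutting the dependency on a single value.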

```go
func (t *Tape) StopGradient(x Slot) Slot {
    return t.alloc(t.Values[x], false)
}
```

The returned slot has the same value but does not record gradient dependencies.

This is useful when the program needs a value for computation but wants to prevent derivative flow through that value.
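As a small worked example, write $\text{sg}(\cdot)$ for stop-gradient (a shorthand used only here) and consider

$$
y = x \cdot \text{sg}(x)
$$

The forward value is $x^2$, but the second factor is treated as a constant in the backward pass, so

$$
\bar{x} \mathrel{+}= \text{sg}(x)\,\bar{y} = x\,\bar{y}
$$

whereas differentiating $x^2$ directly would give $2x\,\bar{y}$.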

## Straight-Through Estimators

Some operations are non-differentiable but still used inside models. A common example is rounding.

Forward:

$$
y = \text{round}(x)
$$

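The true derivative is zero almost everywhere and undefined at the half-integers, where `round` jumps. A gradient of zero carries no training signal, so a straight-through estimator chooses a surrogate backward rule that treats the operation as the identity: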

$$
\bar{x} \mathrel{+}= \bar{y}
$$

Implementation:

```go
func (t *Tape) RoundSTE(x Slot) Slot {
    value := math.Round(t.Values[x])

    return t.CustomUnary(x, value, func(t *Tape, ins Instr) {
        t.Grads[ins.A] += t.Grads[ins.Out]
    })
}
```

This is not the mathematical derivative of `round`. It is an optimization heuristic.

A good AD engine should make this explicit. Custom gradients can encode surrogate gradients, but the user should know when the backward rule differs from the true derivative.
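To illustrate why the surrogate is useful, the sketch below minimizes $(\text{round}(w) - 3)^2$ by gradient descent. It is written without a tape: `roundSTE` and `train` are hypothetical helpers, and the loss, starting point, and learning rate are arbitrary choices for the example.

```go
package main

import (
    "fmt"
    "math"
)

// roundSTE returns round(x) together with a surrogate VJP that passes the
// output adjoint through unchanged. This is NOT the derivative of round.
func roundSTE(x float64) (y float64, vjp func(ybar float64) float64) {
    return math.Round(x), func(ybar float64) float64 { return ybar }
}

// train runs plain gradient descent on L(w) = (round(w) - target)^2,
// using the straight-through estimator for the rounding step.
func train(w, target, lr float64, steps int) float64 {
    for i := 0; i < steps; i++ {
        q, vjp := roundSTE(w)
        qbar := 2 * (q - target) // dL/dq
        w -= lr * vjp(qbar)      // surrogate dL/dw
    }
    return w
}

func main() {
    w := train(0.25, 3, 0.25, 10)
    fmt.Println(w, math.Round(w)) // prints "2.75 3"
}
```

With the true derivative of `round`, the update would be zero at every step and `w` would stay at 0.25. With the estimator, `w` moves until the rounded value reaches the target, at which point the loss gradient itself is zero.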

## Custom Gradient Validation

Custom gradients are powerful because they bypass the default derivative machinery. They are dangerous for the same reason.

Each custom gradient should be tested.

For scalar functions, finite difference checks are simple:

```go
func finiteDiff(f func(float64) float64, x float64) float64 {
    eps := 1e-6
    return (f(x+eps) - f(x-eps)) / (2 * eps)
}
```

Then compare with AD:

```go
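// checkGrad prints the AD gradient of f at xval next to a central
// finite-difference estimate so the two can be compared.
func checkGrad(f func(*Tape, Slot) Slot, xval float64) {
    var t Tape

    x := t.Var(xval)
    y := f(&t, x)

    t.Backward(y)

    // Re-evaluate f on a fresh tape for each perturbed input.
    // Only the forward value is needed here.
    forward := func(v float64) float64 {
        var tt Tape
        yy := f(&tt, tt.Var(v))
        return tt.Values[yy]
    }

    fmt.Println("finite diff:", finiteDiff(forward, xval))
    fmt.Println("AD:", t.Grads[x])
}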
```

Finite differences are approximate, so the comparison should use tolerances.
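One common way to express such a tolerance is a combined absolute and relative bound. The sketch below applies it to softplus without a tape; `closeEnough` and the tolerance values are illustrative choices, not a prescribed standard.

```go
package main

import (
    "fmt"
    "math"
)

func finiteDiff(f func(float64) float64, x float64) float64 {
    eps := 1e-6
    return (f(x+eps) - f(x-eps)) / (2 * eps)
}

// closeEnough reports whether got is within atol + rtol*|want| of want.
// Any comparison involving NaN is false, so a NaN gradient fails the check.
func closeEnough(got, want, rtol, atol float64) bool {
    return math.Abs(got-want) <= atol+rtol*math.Abs(want)
}

func softplus(x float64) float64 {
    if x > 0 {
        return x + math.Log1p(math.Exp(-x))
    }
    return math.Log1p(math.Exp(x))
}

func sigmoid(x float64) float64 {
    if x >= 0 {
        return 1 / (1 + math.Exp(-x))
    }
    z := math.Exp(x)
    return z / (1 + z)
}

func main() {
    // Probe several regimes: saturated left, near zero, saturated right.
    for _, x := range []float64{-30, -1, 0, 2, 30} {
        fd := finiteDiff(softplus, x)
        an := sigmoid(x)
        fmt.Printf("x=%g fd=%.8g analytic=%.8g ok=%v\n",
            x, fd, an, closeEnough(an, fd, 1e-4, 1e-6))
    }
}
```

Testing at several points matters: a backward rule can be correct near zero and still wrong in a saturated or branch-switching region.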

## API Design

A practical custom-gradient API should avoid exposing too much of the engine internals. A minimal shape:

```go
type VJPFunc func(outGrad Tensor, saved SavedValues) []Tensor

type CustomOp struct {
    Name string

    Forward func(inputs []Tensor) (output Tensor, saved SavedValues)
    Backward VJPFunc
}
```

The engine calls `Forward` during evaluation and stores `saved`. During backward, it calls `Backward`.

A scalar tape system can get by with the simpler closure-based design shown earlier. For tensor systems, separating `Forward`, the saved state, and `Backward` is essential.
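To make the interface concrete, the sketch below fills in the undefined types with the simplest possible stand-ins: `Tensor` as a flat `[]float64` and `SavedValues` as a slice of tensors. Both are assumptions for this example; a real engine would carry shapes, dtypes, and device information.

```go
package main

import (
    "fmt"
    "math"
)

// Stand-in types, assumed for this sketch only.
type Tensor []float64
type SavedValues []Tensor

type VJPFunc func(outGrad Tensor, saved SavedValues) []Tensor

type CustomOp struct {
    Name string

    Forward  func(inputs []Tensor) (output Tensor, saved SavedValues)
    Backward VJPFunc
}

// Softplus is an elementwise op that saves the sigmoid for the backward pass.
var Softplus = CustomOp{
    Name: "softplus",
    Forward: func(inputs []Tensor) (Tensor, SavedValues) {
        x := inputs[0]
        out := make(Tensor, len(x))
        sig := make(Tensor, len(x))
        for i, v := range x {
            if v > 0 {
                z := math.Exp(-v)
                out[i] = v + math.Log1p(z)
                sig[i] = 1 / (1 + z)
            } else {
                z := math.Exp(v)
                out[i] = math.Log1p(z)
                sig[i] = z / (1 + z)
            }
        }
        return out, SavedValues{sig}
    },
    Backward: func(outGrad Tensor, saved SavedValues) []Tensor {
        sig := saved[0]
        g := make(Tensor, len(sig))
        for i := range g {
            g[i] = sig[i] * outGrad[i]
        }
        return []Tensor{g} // one gradient per input, in input order
    },
}

func main() {
    y, saved := Softplus.Forward([]Tensor{{-1000, 0, 1000}})
    grads := Softplus.Backward(Tensor{1, 1, 1}, saved)
    fmt.Println(y, grads[0])
}
```

In this shape `Backward` returns gradients instead of mutating engine state, which leaves accumulation into adjoint buffers to the engine and keeps the operator free of tape internals.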

## Correctness Contract

A custom gradient must satisfy the same invariant as a built-in operator:

$$
\bar{x}_i \mathrel{+}= J_i^T \bar{y}
$$

where $J_i$ is the Jacobian of the output with respect to input $x_i$.

If the operation intentionally uses a surrogate derivative, the implementation should name it accordingly. `RoundSTE` is clearer than `Round`, because it signals that the derivative is an estimator.

## When to Use Custom Gradients

Custom gradients are appropriate when:
- the default derivative is numerically unstable
- recomputation would be expensive
- a fused operation needs a fused backward pass
- the forward implementation uses non-differentiable internals
- the mathematically desired derivative differs from the implementation derivative
- memory can be reduced by saving a smaller sufficient state

They should be avoided when ordinary composition already gives a simple and stable derivative.

The best default is to implement primitives correctly, compose them normally, and reserve custom gradients for places where the ordinary derivative path is measurably worse or semantically wrong.

