# Neural Network Training

Neural network training is the repeated application of three operations: evaluate a model, differentiate a scalar loss, and update parameters. Automatic differentiation supplies the middle operation. The training system supplies data movement, batching, numerical policy, optimizer state, checkpointing, logging, and distributed execution.

A minimal training step is:

```text
batch = next_batch(data)

pred = model(batch.x)
loss = loss_fn(pred, batch.y)

grad = gradient(loss, model.parameters)
optimizer.update(model.parameters, grad)
```

This loop looks simple, but it hides several contracts. The model must produce differentiable tensor operations. The loss must reduce the batch output to a scalar or to a scalar-equivalent objective. The AD engine must retain enough forward information to run the backward pass. The optimizer must update parameter storage without corrupting the graph of the current computation.
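These contracts can be made concrete in a minimal, self-contained sketch. A hand-derived gradient for a one-dimensional linear model stands in for the AD engine here; `forward`, `grad`, and `train_step` are illustrative names, not a framework API.

```python
# Minimal sketch of one training step for y_hat = w * x + b with mean
# squared error. The hand-derived gradient stands in for the AD engine.

def forward(w, b, xs):
    return [w * x + b for x in xs]

def mse_loss(preds, ys):
    return sum((p - y) ** 2 for p, y in zip(preds, ys)) / len(ys)

def grad(w, b, xs, ys):
    # d/dw mean((w*x + b - y)^2) = mean(2*(w*x + b - y)*x), similarly for b
    n = len(ys)
    gw = sum(2 * (w * x + b - y) * x for x, y in zip(xs, ys)) / n
    gb = sum(2 * (w * x + b - y) for x, y in zip(xs, ys)) / n
    return gw, gb

def train_step(w, b, xs, ys, lr=0.1):
    gw, gb = grad(w, b, xs, ys)
    return w - lr * gw, b - lr * gb

# One step on data generated by y = 2x + 1 reduces the loss.
xs = [0.0, 1.0, 2.0, 3.0]
ys = [2 * x + 1 for x in xs]
w, b = 0.0, 0.0
loss_before = mse_loss(forward(w, b, xs), ys)
w, b = train_step(w, b, xs, ys)
loss_after = mse_loss(forward(w, b, xs), ys)
```

Replacing the hand-written `grad` with an AD engine changes nothing else about the loop; that substitution is the whole point of the contract.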

## Model, Loss, and Optimizer

A training system separates three objects.

| Object | Role |
|---|---|
| Model | Maps inputs to predictions |
| Loss | Converts predictions and targets into a scalar objective |
| Optimizer | Converts gradients into parameter updates |

For supervised learning, the model is usually a parameterized function:

$$
\hat{y} = f_\theta(x).
$$

The loss compares $\hat{y}$ with the target $y$:

$$
L_B(\theta) =
\frac{1}{|B|}
\sum_{(x_i,y_i)\in B}
\ell(f_\theta(x_i), y_i).
$$

Reverse mode AD computes:

$$
\nabla_\theta L_B(\theta).
$$

The optimizer then applies an update. For plain gradient descent:

$$
\theta \leftarrow \theta - \eta \nabla_\theta L_B(\theta).
$$

For Adam-like optimizers, the update also uses running moment estimates. The derivative computation remains the same. Only the update rule changes.
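The two update rules can be sketched side by side. The gradient `g` is the same input to both; the hyperparameter names (`lr`, `beta1`, `beta2`, `eps`) follow common convention.

```python
import math

# Two update rules applied to the same scalar gradient. The derivative
# computation is identical; only the update differs.

def sgd_update(theta, g, lr=0.01):
    return theta - lr * g

class AdamState:
    def __init__(self):
        self.m = 0.0   # first moment: running mean of gradients
        self.v = 0.0   # second moment: running mean of squared gradients
        self.t = 0     # step counter

def adam_update(theta, g, state, lr=0.001,
                beta1=0.9, beta2=0.999, eps=1e-8):
    state.t += 1
    state.m = beta1 * state.m + (1 - beta1) * g
    state.v = beta2 * state.v + (1 - beta2) * g * g
    m_hat = state.m / (1 - beta1 ** state.t)   # bias correction
    v_hat = state.v / (1 - beta2 ** state.t)
    return theta - lr * m_hat / (math.sqrt(v_hat) + eps)
```

The `AdamState` object is exactly the per-parameter auxiliary state discussed below: two moment buffers plus a counter.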

## Training Step Semantics

A training step has a precise order.

First, the forward pass computes predictions and saves residuals needed for differentiation. Second, the loss function constructs a scalar objective. Third, reverse mode traverses the recorded computation and accumulates gradients into parameter buffers. Fourth, the optimizer mutates the parameter tensors. Fifth, gradient buffers are reset before the next step.

Written explicitly:

```text
zero_grad(parameters)            # convention A: clear before the forward pass

pred = model(x)
loss = loss_fn(pred, y)

backward(loss)

optimizer_step(parameters)

zero_or_clear_grad(parameters)   # convention B: clear after the update
```

Some systems clear gradients before the forward pass; others clear them after the optimizer step. Both conventions are valid as long as any accumulation across steps is intentional and applied consistently.

The dangerous case is silent accumulation across unrelated batches: the update then mixes in stale gradient contributions from earlier steps.

## Parameter State

A trainable parameter has at least two pieces of state:

```text
value: current tensor
grad: accumulated derivative of current objective with respect to value
```

An optimizer may add more state:

```text
momentum: moving average of gradients
variance: moving average of squared gradients
step: update counter
```

For SGD without momentum, optimizer state is small. For Adam, two auxiliary tensors are commonly stored for each parameter. In large models, optimizer state may require more memory than the model weights.

The training checkpoint must therefore include more than model weights. A complete restartable checkpoint usually stores:

| State | Needed for |
|---|---|
| Model parameters | Forward computation |
| Optimizer buffers | Continuing the same optimization trajectory |
| Learning-rate scheduler state | Correct step-dependent schedule |
| Random number generator state | Reproducibility |
| Data sampler state | Exact batch order |
| Mixed-precision scaler | Stable reduced-precision training |

Saving only weights is enough for inference. It is often insufficient for resuming training exactly.
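A restartable checkpoint can be assembled as a plain dictionary along these lines. The field names are hypothetical, chosen to mirror the table above rather than any particular framework's API.

```python
import random

# Hypothetical sketch of a restartable checkpoint. Each field corresponds
# to a row of the table above; the layout is illustrative, not a standard.

def make_checkpoint(step, params, opt_state, sched_state, sampler_pos):
    return {
        "step": step,
        "model": dict(params),          # parameter tensors
        "optimizer": dict(opt_state),   # moment buffers, step counters
        "scheduler": dict(sched_state), # e.g. current learning rate
        "rng": random.getstate(),       # RNG state for reproducibility
        "sampler": sampler_pos,         # position in the data order
    }

ckpt = make_checkpoint(
    step=100,
    params={"w": 1.7, "b": 0.8},
    opt_state={"m_w": 0.1, "v_w": 0.01, "t": 100},
    sched_state={"lr": 0.001},
    sampler_pos=100 * 32,   # e.g. examples consumed so far
)
```

Dropping any one field still permits inference, but resuming from `ckpt` without, say, the optimizer buffers puts training on a different trajectory.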

## Loss Scaling and Reduction

Batch losses must be reduced carefully. Two common reductions are mean and sum.

Mean reduction:

$$
L_B =
\frac{1}{|B|}
\sum_{i\in B}\ell_i.
$$

Sum reduction:

$$
L_B =
\sum_{i\in B}\ell_i.
$$

With mean reduction, gradient magnitude is roughly independent of batch size. With sum reduction, gradient magnitude scales with batch size. This affects the learning rate.

The distinction matters for gradient accumulation. Suppose we split a batch into $K$ micro-batches. To match the average gradient of the full batch, each micro-batch loss is usually divided by $K$:

```text
zero_grad()

for micro_batch in micro_batches:
    loss = loss_fn(model(micro_batch.x), micro_batch.y)
    loss = loss / K
    backward(loss)

optimizer_step()
```

Without this division, the accumulated gradient is $K$ times larger than intended.
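The recipe can be checked numerically with a linear model whose mean-squared-error gradient has a simple closed form.

```python
# Numerical check of the accumulation recipe: summing per-micro-batch
# mean gradients divided by K reproduces the full-batch mean gradient.
# Model: y_hat = w * x, loss = mean((w*x - y)^2), so
# dL/dw = mean(2 * (w*x - y) * x).

def mean_grad(w, pairs):
    return sum(2 * (w * x - y) * x for x, y in pairs) / len(pairs)

w = 0.5
batch = [(1.0, 2.0), (2.0, 3.0), (3.0, 5.0), (4.0, 9.0)]

full = mean_grad(w, batch)

# Split into K equal micro-batches; divide each gradient by K.
K = 2
micro_batches = [batch[:2], batch[2:]]
accum = 0.0
for mb in micro_batches:
    accum += mean_grad(w, mb) / K
# accum now equals the full-batch mean gradient; omitting the / K
# would make it K times larger.
```

The match is exact here because the micro-batches have equal size; with unequal sizes the weighting must account for the per-micro-batch element counts.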

## Training and Evaluation Modes

Many neural network modules behave differently during training and evaluation.

Dropout randomly masks activations during training. During evaluation, it usually uses the full activation without sampling.

Batch normalization uses batch statistics during training and stored running statistics during evaluation.

Other modules may also switch behavior: stochastic depth, data-dependent routing, quantization observers, and certain regularizers.

This mode switch changes the executed program. Since AD differentiates the executed program, the training mode must be set correctly before the forward pass.

A common pattern is:

```text
model.train()
loss = training_step(batch)

model.eval()
metric = validation_step(batch)
```

Validation usually disables gradient recording. This saves memory and avoids building unnecessary graphs.
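The dropout mode switch described above can be sketched as inverted dropout, assuming the common convention of scaling kept activations by $1/(1-p)$ at training time so that evaluation uses the activations unchanged.

```python
import random

# Sketch of inverted dropout. In training mode, each activation is kept
# with probability 1 - p and scaled by 1 / (1 - p); in eval mode the
# function is the identity. This is one common convention, not the only one.

def dropout(xs, p, training, rng=random):
    if not training or p == 0.0:
        return list(xs)          # eval mode: full activation, no sampling
    keep = 1.0 - p
    return [x / keep if rng.random() < keep else 0.0 for x in xs]

xs = [1.0, 2.0, 3.0, 4.0]
eval_out = dropout(xs, p=0.5, training=False)   # identical to xs
train_out = dropout(xs, p=0.5, training=True)   # randomly zeroed, rescaled
```

Because `training` changes which branch executes, it changes the program the AD engine records, which is exactly why the mode must be set before the forward pass.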

## Regularization

Training often adds regularization terms to the loss.

For L2 regularization:

$$
L_{\text{total}}(\theta) =
L_{\text{data}}(\theta)
+
\lambda \|\theta\|_2^2.
$$

AD differentiates the combined objective:

$$
\nabla_\theta L_{\text{total}} =
\nabla_\theta L_{\text{data}}
+
2\lambda \theta.
$$

Weight decay is closely related but may be implemented inside the optimizer rather than as a loss term. In simple SGD, L2 regularization and weight decay can coincide. In adaptive optimizers, decoupled weight decay behaves differently from adding an L2 penalty to the loss.

This is a training-system detail, not an AD detail. AD differentiates the loss it is given. If regularization is implemented in the optimizer, AD never sees it.
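The contrast can be sketched with plain SGD, where the two formulations coincide. Under an adaptive optimizer such as Adam they would not, because the L2 gradient passes through the adaptive rescaling while decoupled decay does not; this sketch shows only the SGD case.

```python
# L2-in-the-loss versus decoupled weight decay, for a single scalar
# parameter under plain SGD. The factor 2 * lam mirrors the gradient of
# lam * theta**2 above; decoupled implementations often fold constants
# into the decay coefficient differently.

def sgd_l2_step(theta, g_data, lr, lam):
    # The penalty enters the gradient: g = g_data + 2 * lam * theta.
    return theta - lr * (g_data + 2 * lam * theta)

def sgd_decoupled_step(theta, g_data, lr, lam):
    # The decay is applied directly to the weights, outside the gradient.
    return theta - lr * g_data - lr * 2 * lam * theta

theta, g, lr, lam = 1.0, 0.1, 0.01, 0.01
a = sgd_l2_step(theta, g, lr, lam)
b = sgd_decoupled_step(theta, g, lr, lam)
# Under plain SGD the two updates are identical.
```

Replace the plain SGD step with an Adam-style step and the two paths diverge, because only `sgd_l2_step`'s penalty term would be divided by the adaptive denominator.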

## Mixed Precision Training

Large neural networks are often trained with reduced precision arithmetic. Common formats include float16, bfloat16, and float32.

Reduced precision improves throughput and lowers memory traffic, but it introduces numerical risks. Gradients may underflow. Accumulations may lose precision. Some operations, such as normalization or softmax, may need higher precision internally.

Mixed precision training usually uses a policy like:

| Quantity | Typical precision |
|---|---|
| Model weights | float16 or bfloat16 for compute |
| Master weights | float32 in some systems |
| Activations | float16 or bfloat16 |
| Gradient accumulation | float16, bfloat16, or float32 |
| Optimizer state | often float32 |

Loss scaling is used mainly with float16. The loss is multiplied by a scale factor before backward propagation. Gradients are then unscaled before the optimizer step.

```text
scaled_loss = loss * scale
backward(scaled_loss)

unscale_gradients(parameters, scale)

if gradients_are_finite:
    optimizer_step()
else:
    skip_step_and_reduce_scale()
```

The derivative relation is simple: multiplying the loss by $s$ multiplies all gradients by $s$. Unscaling restores the intended gradient values. The purpose is to keep small gradients representable during backward computation.
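The scale/unscale identity can be checked with a toy backward function of known gradient standing in for the AD engine.

```python
# Backward is linear in the loss: scaling the loss by s scales every
# gradient by s, and dividing afterwards restores the intended values.
# A hand-written gradient for loss = coeff * x**2 stands in for AD.

def backward(loss_coeff, x):
    # gradient of loss_coeff * x**2 with respect to x
    return loss_coeff * 2 * x

x = 3.0
scale = 1024.0                       # powers of two keep the division exact

true_grad = backward(1.0, x)         # gradient of the unscaled loss
scaled_grad = backward(scale, x)     # gradient of scale * loss
unscaled = scaled_grad / scale       # restores the intended gradient
```

In float64 this round trip is trivially exact; the point of the scheme is that in float16 the *intermediate* `scaled_grad` values stay above the underflow threshold where the unscaled ones would not.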

## Gradient Clipping

Gradient clipping limits the size of an update signal. It is common in recurrent networks, transformers, and unstable training regimes.

Global norm clipping first collects every parameter gradient into a single vector:

$$
g =
(\nabla_{\theta_1}L,\ldots,\nabla_{\theta_k}L),
$$

then rescales if needed:

$$
g \leftarrow g \cdot \min\left(1, \frac{c}{\|g\|}\right).
$$

Here $c$ is the clipping threshold.

Clipping is applied after backward and before the optimizer step:

```text
backward(loss)
clip_grad_norm(parameters, threshold)
optimizer_step()
```

AD computes the unclipped gradient. Clipping is an optimizer-side or training-loop operation. It changes the update, not the derivative of the model loss.
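A minimal sketch of global-norm clipping over a flat list of per-parameter gradients, following the formula above:

```python
import math

# Global-norm clipping: rescale the whole gradient vector by
# min(1, c / ||g||), leaving its direction unchanged.

def clip_grad_norm(grads, c):
    total = math.sqrt(sum(g * g for g in grads))
    factor = min(1.0, c / total) if total > 0 else 1.0
    return [g * factor for g in grads], total

clipped, norm = clip_grad_norm([3.0, 4.0], c=1.0)   # ||g|| = 5 before clipping
# clipped now has norm c; a gradient already below c is returned unchanged.
```

Note that clipping rescales the concatenated vector as a whole; clipping each parameter's gradient independently would change the update direction, not just its magnitude.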

## Data Pipeline

Training speed often depends on the data pipeline as much as the model.

The pipeline may perform:

| Stage | Examples |
|---|---|
| Storage read | local files, object storage, database reads |
| Decode | image decode, text tokenization, audio loading |
| Transform | crop, resize, normalize, augment |
| Batch | pad, pack, bucket, collate |
| Transfer | CPU to GPU, host to accelerator memory |

If the accelerator waits for data, the AD engine and optimizer are idle. High utilization requires overlapping data loading with computation.

Most data preprocessing is outside the differentiated graph. For example, random image crop or text tokenization usually has no useful derivative with respect to model parameters. The batch tensors are treated as inputs.

## Validation and Metrics

Training loss is optimized directly. Metrics are often separate.

For classification, the loss may be cross-entropy, while the metric may be accuracy. Accuracy contains an argmax and an equality comparison, both unsuitable for gradient-based optimization. This is acceptable because metrics are measured, not differentiated.
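A sketch of such a metric makes the point concrete: the argmax and the equality test are piecewise constant, so they carry no useful gradient, and the metric is simply evaluated.

```python
# Accuracy over per-example class scores. Both the argmax and the
# equality comparison are piecewise constant in the scores, so this
# function is measured, never differentiated.

def accuracy(scores, targets):
    correct = 0
    for row, t in zip(scores, targets):
        pred = max(range(len(row)), key=row.__getitem__)   # argmax
        correct += (pred == t)
    return correct / len(targets)

acc = accuracy([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]], [1, 0, 0])
```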

A validation step usually runs as:

```text
disable_gradient_recording()

pred = model(x)
loss = loss_fn(pred, y)
metric = compute_metric(pred, y)
```

Disabling gradient recording reduces memory because no backward residuals are stored.

Validation also uses evaluation mode so modules such as dropout and batch normalization behave deterministically.

## Checkpointing

Checkpointing has two meanings in training.

Training checkpoints save state to durable storage so a run can resume after interruption.

Activation checkpointing saves memory during backpropagation by discarding selected forward activations and recomputing them during the backward pass.

These are different mechanisms. Both matter.

A training checkpoint is about fault tolerance and experiment continuity. An activation checkpoint is about memory-time tradeoff inside AD execution.

Activation checkpointing changes runtime cost but preserves the mathematical gradient, assuming recomputation is deterministic and numerically consistent enough.
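The recomputation idea can be illustrated with a toy two-function chain, with hand-written derivatives standing in for the AD engine.

```python
# Toy activation checkpointing for the chain f(g(x)). Instead of keeping
# g(x) for the backward pass, keep only x and recompute g(x) when the
# backward pass needs it. Derivatives are written by hand here.

def g(x): return x * x           # g'(x) = 2x
def g_prime(x): return 2 * x
def f(u): return 3 * u           # f'(u) = 3
def f_prime(u): return 3.0

def forward_checkpointed(x):
    # Save only the input x; the intermediate g(x) is discarded.
    return f(g(x)), x

def backward_checkpointed(saved_x):
    u = g(saved_x)                          # recompute the activation
    return f_prime(u) * g_prime(saved_x)    # chain rule, as usual

y, saved = forward_checkpointed(2.0)
grad = backward_checkpointed(saved)         # d/dx of 3*x**2 is 6x
```

The gradient is mathematically identical to the non-checkpointed version; only the extra call to `g` during the backward pass distinguishes the two, which is the memory-time tradeoff.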

## Minimal Training Loop

A compact training loop can be written as:

```text
for step in range(num_steps):
    model.train()

    batch = next(train_loader)

    optimizer.zero_grad()

    pred = model(batch.x)
    loss = loss_fn(pred, batch.y)

    loss.backward()

    clip_gradients_if_needed(model.parameters)

    optimizer.step()

    if step % validation_interval == 0:
        model.eval()
        with no_grad():
            validate(model, validation_loader)

    if step % checkpoint_interval == 0:
        save_checkpoint(model, optimizer, scheduler, step)
```

This loop exposes the main interface points. The model builds the differentiable computation. The AD engine computes gradients. The optimizer mutates parameters. The surrounding system handles data, evaluation, logging, and persistence.

## Common Training Bugs

Several bugs occur frequently.

Gradients are accumulated unintentionally because gradient buffers were not cleared.

The model remains in evaluation mode during training, disabling dropout and using running batch-normalization statistics instead of batch statistics.

The model remains in training mode during validation, making metrics noisy or biased.

The loss reduction changes when the batch size changes, causing effective learning-rate changes.

The optimizer state is not restored from checkpoint, so resumed training follows a different trajectory.

Reduced precision causes overflow or underflow, especially in softmax, normalization, or gradient accumulation.

A parameter is detached from the graph, so it receives no gradient.

A tensor is modified in-place while still needed by the backward pass.

The loss is converted to a plain scalar too early, breaking the differentiable graph.

These bugs are often systems bugs around AD rather than failures of the derivative rules themselves.

## The AD Contract in Training

Neural network training relies on a clean contract:

```text
given:
    a scalar loss produced by a differentiable program

AD provides:
    gradients with respect to selected parameters

the optimizer provides:
    parameter updates using those gradients

the training system provides:
    data, scheduling, precision, state, and recovery
```

The quality of training depends on all parts. Correct AD gives the right local derivative of the executed program. Effective training also needs stable losses, suitable optimization, good data flow, controlled randomness, and reliable state management.

