# Stable Training in Deep Networks

Stable training means that a model can make steady progress without numerical collapse, uncontrolled gradients, or large oscillations in the loss. Deep networks are sensitive systems. A small problem in scale, initialization, normalization, data preprocessing, optimizer settings, or precision can compound across many layers.

This section combines the ideas from the previous sections into a practical view of stable PyTorch training.

### What Stability Means

A training run is stable when several quantities remain within useful ranges:

| Quantity | Stable behavior |
|---|---|
| Loss | Decreases over time, with tolerable noise |
| Activations | Neither collapse to zero nor grow without bound |
| Gradients | Finite and large enough to learn |
| Parameters | Change gradually rather than jumping wildly |
| Learning rate | Large enough for progress, small enough for control |
| Validation metric | Improves or degrades slowly, rather than randomly |

Instability usually appears as one of the following:

| Symptom | Common cause |
|---|---|
| Loss becomes `nan` | Too large learning rate, invalid operation, overflow |
| Loss explodes | Large gradients, bad initialization, bad data scale |
| Loss does not move | Vanishing gradients, learning rate too small, frozen parameters |
| Training accuracy rises but validation drops | Overfitting, data leakage, poor regularization |
| Gradients are zero | Detached graph, saturated activations, wrong loss |
| GPU memory grows each step | Stored computation graphs, missing `detach()` |

A stable training run does not need to be perfectly smooth. Mini-batch training is noisy by design. The question is whether the noise remains bounded and useful.

### Start with Correct Data Scale

Deep learning models assume inputs have a reasonable numerical scale. If input values are too large, activations in the early layers inherit that scale and can grow further with depth. If inputs are poorly centered, optimization tends to be slower.

For images, a common preprocessing step is to convert pixel values from integers in `[0, 255]` to floating-point values in `[0, 1]`, then normalize by dataset mean and standard deviation.

```python
import torch
from torchvision import transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])
```

For tabular data, standardization is common:

$$
x' = \frac{x - \mu}{\sigma + \epsilon}.
$$

The mean $\mu$ and standard deviation $\sigma$ should be computed on the training set only. Validation and test data should use the same training-set statistics.
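
As a sketch of that rule (function and variable names here are illustrative), the statistics can be fitted once on the training tensor and reused on any other split:

```python
import torch

def fit_standardizer(x_train, eps=1e-8):
    # Compute per-feature statistics on the training set only.
    mu = x_train.mean(dim=0)
    sigma = x_train.std(dim=0)
    return mu, sigma + eps

def standardize(x, mu, sigma):
    # Apply the same training-set statistics to any split.
    return (x - mu) / sigma

x_train = torch.randn(100, 4) * 5.0 + 2.0
mu, sigma = fit_standardizer(x_train)

x_train_std = standardize(x_train, mu, sigma)
x_val_std = standardize(torch.randn(20, 4), mu, sigma)  # same statistics
```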

For token IDs, do not normalize the integer IDs themselves. Token IDs are indices into an embedding table. The embedding vectors are learned.
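
A minimal illustration of the intended path, with a hypothetical vocabulary size and embedding width:

```python
import torch
from torch import nn

vocab_size, embed_dim = 1000, 16
embedding = nn.Embedding(vocab_size, embed_dim)

# Token IDs stay as integer indices; no normalization is applied to them.
token_ids = torch.tensor([[3, 17, 256], [42, 0, 999]])
vectors = embedding(token_ids)  # learned vectors, shape (2, 3, 16)
```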

### Use Initialization That Matches the Architecture

Initialization controls the starting scale of activations and gradients.

For ReLU-based MLPs and CNNs, Kaiming initialization is a strong default:

```python
from torch import nn

def init_relu_network(module):
    if isinstance(module, (nn.Linear, nn.Conv2d)):
        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
        if module.bias is not None:
            nn.init.zeros_(module.bias)
```

For tanh-based networks, Xavier initialization is usually better:

```python
def init_tanh_network(module):
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
```

For transformers, use the architecture’s intended initialization when possible. Many transformer implementations use small normal initialization for linear and embedding weights.

```python
def init_transformer_style(module):
    if isinstance(module, (nn.Linear, nn.Embedding)):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)

    if isinstance(module, nn.Linear) and module.bias is not None:
        nn.init.zeros_(module.bias)

    if isinstance(module, nn.LayerNorm):
        nn.init.ones_(module.weight)
        nn.init.zeros_(module.bias)
```

The best initialization rule depends on architecture, activation function, depth, residual scaling, and normalization.
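
The per-module init functions above can be applied to a whole model with `Module.apply`, which visits every submodule recursively. A self-contained sketch repeating the ReLU rule:

```python
import torch
from torch import nn

def init_relu_network(module):
    # Same rule as above: Kaiming init for ReLU layers, zero biases.
    if isinstance(module, (nn.Linear, nn.Conv2d)):
        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
        if module.bias is not None:
            nn.init.zeros_(module.bias)

model = nn.Sequential(
    nn.Linear(32, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
)
model.apply(init_relu_network)  # visits every submodule recursively
```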

### Choose Normalization Deliberately

Normalization reduces uncontrolled changes in activation scale.

| Architecture | Common normalization |
|---|---|
| CNN with moderate or large batches | BatchNorm |
| CNN with small batches | GroupNorm |
| Transformer | LayerNorm or RMSNorm |
| Style transfer model | InstanceNorm |
| RNN or sequence model | LayerNorm |

Use normalization as an architectural choice, not as a patch applied after the model fails.

For a transformer block, a pre-normalization layout is usually stable:

```python
class PreNormMLPBlock(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        return x + self.ffn(self.norm(x))
```

For a convolutional block with small batches, group normalization is often safer than batch normalization:

```python
class ConvGNBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.GroupNorm(num_groups=8, num_channels=out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.net(x)
```

### Use Residual Connections for Depth

Residual connections help deep networks optimize by giving gradients a direct path through the model.

A residual block computes

$$
x_{l+1} = x_l + F_l(x_l).
$$

This update form helps the network learn small corrections rather than entirely new representations at each layer.

For very deep networks, consider residual scaling:

```python
class ScaledResidualBlock(nn.Module):
    def __init__(self, dim, hidden_dim, scale=0.1):
        super().__init__()
        self.scale = scale
        self.f = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        return x + self.scale * self.f(x)
```

Residual scaling is useful when many residual additions cause activation magnitudes to grow with depth.

### Select a Conservative Learning Rate First

The learning rate is one of the most common causes of instability. A rate that is too high can make the loss explode. A rate that is too low can make training appear broken.

A conservative starting point:

| Optimizer | Typical first learning rate |
|---|---:|
| SGD with momentum | `1e-2` to `1e-1` |
| Adam | `1e-3` |
| AdamW for transformers | `1e-4` to `3e-4` |
| Fine-tuning pretrained models | `1e-5` to `5e-5` |

These are only starting points. The right value depends on batch size, model size, optimizer, normalization, and task.

A useful first response to exploding loss is to reduce the learning rate by a factor of 3 or 10.
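
A sketch of that adjustment on a live optimizer (the factor of 10 here is illustrative):

```python
import torch
from torch import nn

model = nn.Linear(8, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# Reduce the learning rate in place, e.g. after seeing the loss explode.
for group in optimizer.param_groups:
    group["lr"] /= 10
```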

```python
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,
    weight_decay=0.01,
)
```

### Use Learning Rate Warmup

Large models, transformers, and mixed-precision training often benefit from learning rate warmup. Warmup starts with a small learning rate and increases it gradually during early training.

This avoids large destructive updates before the model reaches a reasonable activation and gradient regime.

A simple warmup schedule:

```python
def warmup_lr(step, base_lr, warmup_steps):
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    return base_lr
```

With a PyTorch optimizer:

```python
base_lr = 3e-4
warmup_steps = 1000

for step, batch in enumerate(loader):
    lr = warmup_lr(step, base_lr, warmup_steps)
    for group in optimizer.param_groups:
        group["lr"] = lr

    loss = training_step(batch)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
```

Warmup is especially common for transformer pretraining and fine-tuning.

### Clip Gradients When Needed

Gradient clipping limits exploding gradients before the optimizer step.

```python
loss.backward()

torch.nn.utils.clip_grad_norm_(
    model.parameters(),
    max_norm=1.0,
)

optimizer.step()
```

Gradient clipping is common in recurrent networks, transformers, reinforcement learning, and unstable generative models.

It should not be used to hide every problem. If clipping is active on almost every step with very large unclipped norms, the learning rate, initialization, or architecture may still be wrong.
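
`clip_grad_norm_` returns the total gradient norm computed before clipping, which makes this easy to monitor. A minimal sketch:

```python
import torch
from torch import nn

model = nn.Linear(8, 2)
loss = model(torch.randn(4, 8)).pow(2).mean()
loss.backward()

max_norm = 1.0
# The return value is the norm of the gradients before clipping.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

# If this is True on almost every step with large total norms, revisit
# the learning rate, initialization, or architecture instead.
clipped = total_norm.item() > max_norm
```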

### Use Mixed Precision Carefully

Mixed precision can speed up training and reduce memory use, but it introduces numerical risks. In PyTorch, automatic mixed precision is usually handled with `torch.autocast` and `torch.amp.GradScaler` (the older `torch.cuda.amp` spellings still work but are deprecated).

```python
scaler = torch.amp.GradScaler("cuda")

for x, y in loader:
    optimizer.zero_grad(set_to_none=True)

    with torch.autocast(device_type="cuda", dtype=torch.float16):
        logits = model(x)
        loss = loss_fn(logits, y)

    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)

    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

    scaler.step(optimizer)
    scaler.update()
```

The call to `scaler.unscale_(optimizer)` is important when clipping gradients. It converts scaled gradients back to their real scale before clipping.

On hardware that supports it, `bfloat16` is often more stable than `float16` because it has a wider exponent range.
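
Because `bfloat16` keeps `float32`'s exponent range, loss scaling is usually unnecessary, so the `GradScaler` can be dropped. A sketch on a build with bfloat16 support (`device_type="cpu"` is used here only so the sketch runs without a GPU):

```python
import torch
from torch import nn

model = nn.Linear(16, 4)
x, y = torch.randn(8, 16), torch.randn(8, 4)

# bfloat16 autocast without a GradScaler: the wide exponent range
# makes underflow of gradients much less likely than with float16.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    loss = nn.functional.mse_loss(model(x), y)

loss.backward()
```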

### Check for Non-Finite Values

A direct stability check is to detect non-finite losses, activations, gradients, and parameters.

```python
def assert_finite_tensor(name, x):
    if not torch.isfinite(x).all():
        raise RuntimeError(f"{name} contains non-finite values")
```

During training:

```python
logits = model(x)
assert_finite_tensor("logits", logits)

loss = loss_fn(logits, y)
assert_finite_tensor("loss", loss)

loss.backward()

for name, param in model.named_parameters():
    if param.grad is not None:
        assert_finite_tensor(f"{name}.grad", param.grad)
```

This makes failures local. Instead of discovering `nan` after many steps, you can stop at the operation where the problem first appears.

### Track Activation and Gradient Statistics

Logging tensor statistics gives visibility into training dynamics.

```python
def parameter_report(model):
    rows = []

    for name, param in model.named_parameters():
        grad_norm = None
        if param.grad is not None:
            grad_norm = param.grad.norm().item()

        rows.append({
            "name": name,
            "shape": tuple(param.shape),
            "param_mean": param.data.mean().item(),
            "param_std": param.data.std().item(),
            "param_norm": param.data.norm().item(),
            "grad_norm": grad_norm,
        })

    return rows
```

Forward hooks can inspect activations:

```python
def add_activation_hooks(model):
    hooks = []

    def hook(name):
        def fn(module, inputs, output):
            if isinstance(output, torch.Tensor):
                print(
                    name,
                    "mean=", output.mean().item(),
                    "std=", output.std().item(),
                    "max=", output.abs().max().item(),
                )
        return fn

    for name, module in model.named_modules():
        if len(list(module.children())) == 0:
            hooks.append(module.register_forward_hook(hook(name)))

    return hooks
```

Use hooks for debugging, not for every production training run. They add overhead and can produce large logs.
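
The handles returned by `register_forward_hook` can be removed once the investigation is over. A minimal sketch, using a list instead of `print` so the effect is visible:

```python
import torch
from torch import nn

model = nn.Linear(4, 2)
calls = []

def log_stats(module, inputs, output):
    # In real debugging this might print or log activation statistics.
    calls.append(output.std().item())

handle = model.register_forward_hook(log_stats)
model(torch.randn(3, 4))  # hook fires

handle.remove()  # stop the overhead once debugging is done
model(torch.randn(3, 4))  # hook no longer fires
```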

### Avoid Accidental Graph Retention

A common PyTorch stability problem is retaining computation graphs unintentionally. This usually appears as GPU memory growing every training step.

Bad pattern:

```python
losses = []

for batch in loader:
    loss = training_step(batch)
    losses.append(loss)  # Keeps graph alive
```

Correct pattern:

```python
losses = []

for batch in loader:
    loss = training_step(batch)
    losses.append(loss.item())
```

Another common issue is accumulating tensors without detaching them:

```python
running_loss = running_loss + loss
```

Use scalar values instead:

```python
running_loss += loss.item()
```

When storing model outputs for later analysis, detach them first:

```python
saved_logits.append(logits.detach().cpu())
```

### Use `zero_grad` Correctly

Gradients accumulate in PyTorch by default. This is useful for gradient accumulation, but it is a common source of bugs.

A normal training step should clear old gradients:

```python
optimizer.zero_grad(set_to_none=True)

logits = model(x)
loss = loss_fn(logits, y)

loss.backward()
optimizer.step()
```

Using `set_to_none=True` can reduce memory use and make it easier to detect parameters that did not receive gradients.
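
With `set_to_none=True`, a parameter whose `.grad` is still `None` after `backward()` received no gradient at all. A sketch of that check, using a deliberately unused submodule (names are illustrative):

```python
import torch
from torch import nn

class TwoHeads(nn.Module):
    def __init__(self):
        super().__init__()
        self.used = nn.Linear(4, 2)
        self.unused = nn.Linear(4, 2)  # never called in forward

    def forward(self, x):
        return self.used(x)

model = TwoHeads()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer.zero_grad(set_to_none=True)

model(torch.randn(3, 4)).sum().backward()

# Parameters that received no gradient stay None.
missing = [name for name, p in model.named_parameters() if p.grad is None]
```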

When doing gradient accumulation, clear gradients only at accumulation boundaries:

```python
optimizer.zero_grad(set_to_none=True)

for step, batch in enumerate(loader):
    loss = training_step(batch)
    loss = loss / accumulation_steps
    loss.backward()

    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
```

Dividing the loss by `accumulation_steps` keeps gradient scale comparable to a larger batch.

### Validate the Training Loop on a Tiny Batch

Before training at scale, test whether the model can overfit a tiny batch. This is one of the best debugging checks.

Take one small batch and train on it repeatedly:

```python
x_small, y_small = next(iter(loader))

for step in range(500):
    optimizer.zero_grad(set_to_none=True)

    logits = model(x_small)
    loss = loss_fn(logits, y_small)

    loss.backward()
    optimizer.step()

    if step % 50 == 0:
        print(step, loss.item())
```

A sufficiently expressive model should drive the training loss very low on a tiny batch. If it cannot, check the model, loss function, labels, optimizer, data preprocessing, and gradient flow.

This test does not prove the model generalizes. It proves that the training loop can learn at all.

### Check Loss and Output Compatibility

Many unstable training runs come from a mismatch between model outputs and loss function.

For classification with `nn.CrossEntropyLoss`, the model should output raw logits, not softmax probabilities.

Correct:

```python
logits = model(x)
loss = nn.CrossEntropyLoss()(logits, labels)
```

Incorrect:

```python
probs = torch.softmax(model(x), dim=-1)
loss = nn.CrossEntropyLoss()(probs, labels)
```

`CrossEntropyLoss` internally applies log-softmax. Passing probabilities can make optimization worse.

For binary classification with `nn.BCEWithLogitsLoss`, the model should output raw logits.

Correct:

```python
logits = model(x).squeeze(-1)
loss = nn.BCEWithLogitsLoss()(logits, targets.float())
```

Incorrect:

```python
probs = torch.sigmoid(model(x)).squeeze(-1)
loss = nn.BCEWithLogitsLoss()(probs, targets.float())
```

Use `BCELoss` only when you explicitly pass probabilities. In most cases, `BCEWithLogitsLoss` is more numerically stable.

### Handle Masks Carefully

Masks are common in sequence models and attention. Bad masking can produce `nan`.

For attention, a mask may set invalid logits to a large negative value before softmax:

```python
scores = scores.masked_fill(mask == 0, -1e9)
weights = torch.softmax(scores, dim=-1)
```

In lower precision, very large constants can cause problems. A dtype-aware value is safer:

```python
large_negative = torch.finfo(scores.dtype).min
scores = scores.masked_fill(mask == 0, large_negative)
weights = torch.softmax(scores, dim=-1)
```

Also ensure that each row has at least one valid position. A softmax over all masked positions can produce invalid values.
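
One defensive pattern, sketched here as a hypothetical helper, is to verify that every row keeps at least one valid position before the softmax:

```python
import torch

def checked_masked_softmax(scores, mask):
    mask = mask.bool()
    # A row with no valid positions would softmax to nan.
    if not mask.any(dim=-1).all():
        raise ValueError("found a fully masked row")
    large_negative = torch.finfo(scores.dtype).min
    scores = scores.masked_fill(~mask, large_negative)
    return torch.softmax(scores, dim=-1)

scores = torch.randn(2, 4)
mask = torch.tensor([[1, 1, 0, 0], [1, 0, 1, 1]])
weights = checked_masked_softmax(scores, mask)
```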

### Recommended Debug Order

When training is unstable, debug in this order:

| Step | Check |
|---:|---|
| 1 | Verify input shapes, label shapes, and dtypes |
| 2 | Check output and loss compatibility |
| 3 | Overfit one tiny batch |
| 4 | Log loss, gradient norms, and activation statistics |
| 5 | Lower the learning rate |
| 6 | Check initialization and normalization |
| 7 | Add gradient clipping |
| 8 | Disable mixed precision temporarily |
| 9 | Check for accidental graph retention |
| 10 | Inspect data for invalid values or bad labels |

This order catches the most common mistakes before changing the architecture.

### Summary

Stable training comes from controlling numerical scale across the whole training system. Initialization sets the starting point. Normalization controls activation scale. Residual connections improve gradient flow. Learning rate schedules control update size. Gradient clipping limits rare large updates. Mixed precision requires care. The training loop must avoid graph retention, wrong loss usage, and incorrect gradient accumulation.

A practical PyTorch workflow is simple: start with correct data scale, use architecture-appropriate initialization and normalization, choose a conservative learning rate, verify the training loop on a tiny batch, and inspect gradients when behavior looks wrong.

