Stable Training in Deep Networks

Stable training means that a model can make steady progress without numerical collapse, uncontrolled gradients, or large oscillations in the loss. Deep networks are sensitive systems. A small problem in scale, initialization, normalization, data preprocessing, optimizer settings, or precision can compound across many layers.

This section combines the ideas from the previous sections into a practical view of stable PyTorch training.

What Stability Means

A training run is stable when several quantities remain within useful ranges:

| Quantity | Stable behavior |
| --- | --- |
| Loss | Decreases over time, with tolerable noise |
| Activations | Neither collapse to zero nor grow without bound |
| Gradients | Finite and large enough to learn |
| Parameters | Change gradually rather than jumping wildly |
| Learning rate | Large enough for progress, small enough for control |
| Validation metric | Improves or degrades slowly, rather than randomly |

Instability usually appears as one of the following:

| Symptom | Common cause |
| --- | --- |
| Loss becomes nan | Too large learning rate, invalid operation, overflow |
| Loss explodes | Large gradients, bad initialization, bad data scale |
| Loss does not move | Vanishing gradients, learning rate too small, frozen parameters |
| Training accuracy rises but validation drops | Overfitting, data leakage, poor regularization |
| Gradients are zero | Detached graph, saturated activations, wrong loss |
| GPU memory grows each step | Stored computation graphs, missing detach() |

A stable training run does not need to be perfectly smooth. Mini-batch training is noisy by design. The question is whether the noise remains bounded and useful.

Start with Correct Data Scale

Deep learning models assume inputs have reasonable numerical scale. If input values are too large, early activations may become too large. If input values are poorly centered, optimization may become slower.

For images, a common preprocessing step is to convert pixel values from integers in [0, 255] to floating-point values in [0, 1], then normalize by dataset mean and standard deviation.

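import torch
from torchvision import transforms

transform = transforms.Compose([
    # Converts a PIL image or uint8 array to a float tensor in [0, 1].
    transforms.ToTensor(),
    # These are the ImageNet channel statistics. For another dataset,
    # compute the per-channel mean and std on its training split.
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])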

For tabular data, standardization is common:

$$x' = \frac{x - \mu}{\sigma + \epsilon}.$$

The mean $\mu$ and standard deviation $\sigma$ should be computed on the training set only. Validation and test data should use the same training-set statistics.
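The standardization rule can be sketched in a few lines. This is a minimal illustration, not a full preprocessing pipeline: the helper names `fit_standardizer` and `standardize`, the synthetic data, and the value of `eps` are all assumptions made for the example.

```python
import torch

def fit_standardizer(train_x):
    """Compute per-feature mean and std from the training split only."""
    mean = train_x.mean(dim=0, keepdim=True)
    std = ((train_x - mean) ** 2).mean(dim=0, keepdim=True).sqrt()
    return mean, std

def standardize(x, mean, std, eps=1e-8):
    """Apply x' = (x - mean) / (std + eps) with precomputed statistics."""
    return (x - mean) / (std + eps)

torch.manual_seed(0)
# Hypothetical tabular data: three features on very different scales.
scales = torch.tensor([1.0, 100.0, 0.01])
offsets = torch.tensor([0.0, 500.0, -3.0])
train_x = torch.randn(1000, 3) * scales + offsets
val_x = torch.randn(200, 3) * scales + offsets

mean, std = fit_standardizer(train_x)
train_z = standardize(train_x, mean, std)
val_z = standardize(val_x, mean, std)  # reuses the training statistics
```

After this step every training feature has roughly zero mean and unit standard deviation. The validation features are close to that but not exact, which is expected because their statistics were never used.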

For token IDs, do not normalize the integer IDs themselves. Token IDs are indices into an embedding table. The embedding vectors are learned.

Use Initialization That Matches the Architecture

Initialization controls the starting scale of activations and gradients.

For ReLU-based MLPs and CNNs, Kaiming initialization is a strong default:

from torch import nn

def init_relu_network(module):
    if isinstance(module, (nn.Linear, nn.Conv2d)):
        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
        if module.bias is not None:
            nn.init.zeros_(module.bias)

For tanh-based networks, Xavier initialization is usually better:

def init_tanh_network(module):
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

For transformers, use the architecture’s intended initialization when possible. Many transformer implementations use small normal initialization for linear and embedding weights.

def init_transformer_style(module):
    if isinstance(module, (nn.Linear, nn.Embedding)):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)

    if isinstance(module, nn.Linear) and module.bias is not None:
        nn.init.zeros_(module.bias)

    if isinstance(module, nn.LayerNorm):
        nn.init.ones_(module.weight)
        nn.init.zeros_(module.bias)

The best initialization rule depends on architecture, activation function, depth, residual scaling, and normalization.
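One way to see why the starting scale matters is to push random inputs through a deep ReLU stack and measure the output magnitude. The sketch below compares Kaiming initialization with a deliberately small normal initialization. The width, depth, batch size, and helper names are illustrative assumptions.

```python
import torch
from torch import nn

def make_relu_mlp(width, depth):
    layers = []
    for _ in range(depth):
        layers += [nn.Linear(width, width), nn.ReLU()]
    return nn.Sequential(*layers)

def init_kaiming(module):
    if isinstance(module, nn.Linear):
        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
        nn.init.zeros_(module.bias)

def init_too_small(module):
    # Deliberately poor choice, included only for comparison.
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, mean=0.0, std=0.01)
        nn.init.zeros_(module.bias)

def output_rms(init_fn, width=256, depth=10):
    """Root-mean-square of the final activations for unit-scale inputs."""
    torch.manual_seed(0)
    model = make_relu_mlp(width, depth)
    model.apply(init_fn)  # apply() visits every submodule recursively
    x = torch.randn(512, width)
    with torch.no_grad():
        out = model(x)
    return out.pow(2).mean().sqrt().item()

rms_kaiming = output_rms(init_kaiming)
rms_small = output_rms(init_too_small)
print(f"kaiming: {rms_kaiming:.3g}, small normal: {rms_small:.3g}")
```

With Kaiming initialization the output magnitude stays near the input magnitude, while the small initialization shrinks it by many orders of magnitude after ten layers. Gradients flowing backward through the same weights shrink in a similar way.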

Choose Normalization Deliberately

Normalization reduces uncontrolled changes in activation scale.

| Architecture | Common normalization |
| --- | --- |
| CNN with moderate or large batches | BatchNorm |
| CNN with small batches | GroupNorm |
| Transformer | LayerNorm or RMSNorm |
| Style transfer model | InstanceNorm |
| RNN or sequence model | LayerNorm |

Use normalization as an architectural choice, not as a patch applied after the model fails.

For a transformer block, a pre-normalization layout is usually stable:

class PreNormMLPBlock(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        return x + self.ffn(self.norm(x))

For a convolutional block with small batches, group normalization is often safer than batch normalization:

class ConvGNBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            # out_channels must be divisible by num_groups.
            nn.GroupNorm(num_groups=8, num_channels=out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.net(x)
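The reason group normalization tolerates small batches is that its statistics are computed per sample, while batch normalization in training mode computes them across the batch. The following sketch checks this directly. The tensor shapes and variable names are arbitrary choices for the illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(4, 16, 8, 8)  # batch of 4 feature maps with 16 channels

gn = nn.GroupNorm(num_groups=8, num_channels=16)
bn = nn.BatchNorm2d(16)  # modules start in training mode

# Normalize the first sample as part of the batch, then on its own.
gn_in_batch = gn(x)[:1]
gn_alone = gn(x[:1])
bn_in_batch = bn(x)[:1]
bn_alone = bn(x[:1])

gn_matches = torch.allclose(gn_in_batch, gn_alone, atol=1e-5)
bn_matches = torch.allclose(bn_in_batch, bn_alone, atol=1e-5)
print("GroupNorm independent of batch:", gn_matches)
print("BatchNorm independent of batch:", bn_matches)
```

GroupNorm returns the same result for a sample regardless of what else is in the batch. BatchNorm does not, so its output becomes noisy when the batch is small.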

Use Residual Connections for Depth

Residual connections help deep networks optimize by giving gradients a direct path through the model.

A residual block computes

$$x_{l+1} = x_l + F_l(x_l).$$

This update form helps the network learn small corrections rather than entirely new representations at each layer.

For very deep networks, consider residual scaling:

class ScaledResidualBlock(nn.Module):
    def __init__(self, dim, hidden_dim, scale=0.1):
        super().__init__()
        self.scale = scale
        self.f = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        return x + self.scale * self.f(x)

Residual scaling is useful when many residual additions cause activation magnitudes to grow with depth.
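A rough way to observe this effect is to stack many randomly initialized residual blocks and compare the output norm to the input norm. The sketch below does that for two scale values. The block definition mirrors the one above, and the depth, width, and helper name `norm_growth` are assumptions for the example. Exact numbers depend on the initialization and the seed.

```python
import torch
from torch import nn

class ResidualFFN(nn.Module):
    def __init__(self, dim, hidden_dim, scale):
        super().__init__()
        self.scale = scale
        self.f = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        return x + self.scale * self.f(x)

def norm_growth(scale, depth=32, dim=64):
    """Ratio of mean output norm to mean input norm at initialization."""
    torch.manual_seed(0)
    blocks = nn.Sequential(
        *[ResidualFFN(dim, 4 * dim, scale) for _ in range(depth)]
    )
    x = torch.randn(128, dim)
    with torch.no_grad():
        out = blocks(x)
    return (out.norm(dim=-1).mean() / x.norm(dim=-1).mean()).item()

growth_unscaled = norm_growth(scale=1.0)
growth_scaled = norm_growth(scale=0.1)
print(f"scale=1.0: {growth_unscaled:.3f}, scale=0.1: {growth_scaled:.3f}")
```

With the smaller scale, the residual stream stays close to its input magnitude across all 32 blocks, which is the behavior residual scaling is meant to provide.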

Select a Conservative Learning Rate First

The learning rate is one of the most common causes of instability. A rate that is too high can make the loss explode. A rate that is too low can make training appear broken.

A conservative starting point:

| Optimizer | Typical first learning rate |
| --- | --- |
| SGD with momentum | 1e-2 to 1e-1 |
| Adam | 1e-3 |
| AdamW for transformers | 1e-4 to 3e-4 |
| Fine-tuning pretrained models | 1e-5 to 5e-5 |

These are only starting points. The right value depends on batch size, model size, optimizer, normalization, and task.

A useful first response to exploding loss is to reduce the learning rate by a factor of 3 or 10.

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,
    weight_decay=0.01,
)
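The effect of an overly large rate is easiest to see on a toy problem. The sketch below runs plain gradient descent on a one-dimensional quadratic, which is an assumption chosen for clarity and not a model of real network training. The function name and constants are illustrative.

```python
def run_gradient_descent(lr, curvature=8.0, steps=50, w0=1.0):
    """Minimize f(w) = 0.5 * curvature * w**2 with plain gradient descent."""
    w = w0
    for _ in range(steps):
        grad = curvature * w
        w = w - lr * grad
    return abs(w)

# Each step multiplies w by (1 - lr * curvature), so the iterates
# diverge once lr exceeds 2 / curvature = 0.25 for this function.
final_too_high = run_gradient_descent(lr=0.3)       # factor -1.4 per step
final_reduced = run_gradient_descent(lr=0.1)        # lr divided by 3
final_reduced_more = run_gradient_descent(lr=0.03)  # lr divided by 10

print(final_too_high, final_reduced, final_reduced_more)
```

Dividing the rate by 3 turns a diverging run into a converging one. Dividing by 10 also converges, but more slowly, which is why the reduction should be no larger than needed.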

Use Learning Rate Warmup

Large models, transformers, and mixed-precision training often benefit from learning rate warmup. Warmup starts with a small learning rate and increases it gradually during early training.

This avoids large destructive updates before the model reaches a reasonable activation and gradient regime.

A simple warmup schedule:

def warmup_lr(step, base_lr, warmup_steps):
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    return base_lr

With a PyTorch optimizer:

base_lr = 3e-4
warmup_steps = 1000

for step, batch in enumerate(loader):
    lr = warmup_lr(step, base_lr, warmup_steps)
    for group in optimizer.param_groups:
        group["lr"] = lr

    loss = training_step(batch)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

Warmup is especially common for transformer pretraining and fine-tuning.
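The same linear warmup can also be expressed with PyTorch's built-in `torch.optim.lr_scheduler.LambdaLR`, which multiplies the optimizer's base learning rate by a user-supplied factor. This is a minimal sketch. The helper name `make_warmup_scheduler`, the dummy parameter, and the short warmup length are assumptions for the example.

```python
import torch

def make_warmup_scheduler(optimizer, warmup_steps):
    """Linear warmup to the base learning rate, then constant."""
    def lr_lambda(step):
        return min(1.0, (step + 1) / warmup_steps)
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# A single dummy parameter stands in for a real model.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=3e-4)
scheduler = make_warmup_scheduler(optimizer, warmup_steps=10)

lrs = []
for step in range(15):
    lrs.append(optimizer.param_groups[0]["lr"])
    param.grad = torch.ones_like(param)  # placeholder for loss.backward()
    optimizer.step()
    scheduler.step()  # call after optimizer.step()

print(lrs[0], lrs[9], lrs[14])
```

The scheduler keeps the schedule out of the training loop body, and its state can be saved and restored with `scheduler.state_dict()` when resuming a run.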

Clip Gradients When Needed

Gradient clipping limits exploding gradients before the optimizer step.

loss.backward()

torch.nn.utils.clip_grad_norm_(
    model.parameters(),
    max_norm=1.0,
)

optimizer.step()

Gradient clipping is common in recurrent networks, transformers, reinforcement learning, and unstable generative models.

It should not be used to hide every problem. If clipping is active on almost every step with very large unclipped norms, the learning rate, initialization, or architecture may still be wrong.
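To tell whether clipping is active, it helps to use the return value of `clip_grad_norm_`, which is the total gradient norm measured before clipping. The sketch below uses a single hand-built gradient so the numbers are easy to check. The variable names are illustrative.

```python
import torch

param = torch.nn.Parameter(torch.zeros(2))
param.grad = torch.tensor([3.0, 4.0])  # gradient with L2 norm 5

max_norm = 1.0
# Returns the norm of the gradients as they were before clipping.
total_norm = torch.nn.utils.clip_grad_norm_([param], max_norm=max_norm)

was_clipped = total_norm.item() > max_norm
clipped_norm = param.grad.norm().item()
print(f"before: {total_norm.item():.3f}, after: {clipped_norm:.3f}")
```

Logging `total_norm` and the fraction of steps where `was_clipped` is true gives a direct signal for the situation described above: occasional clipping is normal, while clipping on nearly every step points to a deeper problem.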

Use Mixed Precision Carefully

Mixed precision can speed up training and reduce memory use, but it introduces numerical risks. In PyTorch, automatic mixed precision is usually handled with torch.autocast and GradScaler.

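# torch.amp.GradScaler and torch.autocast replace the deprecated
# torch.cuda.amp equivalents. On PyTorch older than 2.3, use
# torch.cuda.amp.GradScaler() instead.
scaler = torch.amp.GradScaler("cuda")

for x, y in loader:
    optimizer.zero_grad(set_to_none=True)

    # Run the forward pass and the loss under autocast.
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        logits = model(x)
        loss = loss_fn(logits, y)

    # Scale the loss so small float16 gradients do not underflow.
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)

    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

    # step() skips the update if the gradients contain inf or nan.
    scaler.step(optimizer)
    scaler.update()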

The call to scaler.unscale_(optimizer) is important when clipping gradients. It converts scaled gradients back to their real scale before clipping.

On hardware that supports it, bfloat16 is often more stable than float16 because it has a wider exponent range.
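The difference in range and precision can be read directly from `torch.finfo`. The short sketch below also converts one value that float16 cannot hold. The chosen value of 70000 is an arbitrary example.

```python
import torch

fp16 = torch.finfo(torch.float16)
bf16 = torch.finfo(torch.bfloat16)

print(f"float16:  max={fp16.max:.4g}, eps={fp16.eps:.4g}")
print(f"bfloat16: max={bf16.max:.4g}, eps={bf16.eps:.4g}")

# 70000 exceeds the float16 range, so the conversion overflows to inf.
value = torch.tensor(70000.0)
as_fp16 = value.to(torch.float16)
as_bf16 = value.to(torch.bfloat16)
print(as_fp16, as_bf16)
```

bfloat16 trades precision for range: it has fewer mantissa bits than float16, so individual values are coarser, but overflow is far less likely. This is also why bfloat16 training usually does not need loss scaling.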

Check for Non-Finite Values

A direct stability check is to detect non-finite losses, activations, gradients, and parameters.

def assert_finite_tensor(name, x):
    if not torch.isfinite(x).all():
        raise RuntimeError(f"{name} contains non-finite values")

During training:

logits = model(x)
assert_finite_tensor("logits", logits)

loss = loss_fn(logits, y)
assert_finite_tensor("loss", loss)

loss.backward()

for name, param in model.named_parameters():
    if param.grad is not None:
        assert_finite_tensor(f"{name}.grad", param.grad)

This makes failures local. Instead of discovering nan after many steps, you can stop at the operation where the problem first appears.
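The same check can be attached to every submodule with forward hooks, so the error names the first layer that produces a bad value. This is a debugging sketch under stated assumptions: `add_finite_checks` and the `UnsafeLog` module are hypothetical names, and the toy model exists only to trigger the failure.

```python
import torch
from torch import nn

def add_finite_checks(model):
    """Raise as soon as any submodule outputs a non-finite tensor."""
    handles = []

    def make_hook(name):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
                raise RuntimeError(f"non-finite output in module '{name}'")
        return hook

    for name, module in model.named_modules():
        if name:  # skip the root module, whose name is the empty string
            handles.append(module.register_forward_hook(make_hook(name)))
    return handles

class UnsafeLog(nn.Module):
    def forward(self, x):
        return torch.log(x)  # log(0) is -inf

model = nn.Sequential(nn.ReLU(), UnsafeLog(), nn.Linear(4, 2))
handles = add_finite_checks(model)

try:
    model(torch.tensor([[-1.0, 0.5, 2.0, 3.0]]))
    first_failure = None
except RuntimeError as err:
    first_failure = str(err)

for handle in handles:
    handle.remove()  # detach the hooks once debugging is done

print(first_failure)
```

The ReLU maps the negative input to zero, the logarithm turns that into `-inf`, and the hook stops the forward pass at that module instead of letting the value reach the loss.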

Track Activation and Gradient Statistics

Logging tensor statistics gives visibility into training dynamics.

def parameter_report(model):
    """Collect per-parameter weight and gradient statistics for logging."""
    rows = []

    for name, param in model.named_parameters():
        grad_norm = None
        if param.grad is not None:
            grad_norm = param.grad.norm().item()

        # detach() reads the values without autograd tracking.
        values = param.detach()
        rows.append({
            "name": name,
            "shape": tuple(param.shape),
            "param_mean": values.mean().item(),
            # std() is nan for single-element parameters.
            "param_std": values.std().item(),
            "param_norm": values.norm().item(),
            "grad_norm": grad_norm,
        })

    return rows

Forward hooks can inspect activations:

def add_activation_hooks(model):
    hooks = []

    def hook(name):
        def fn(module, inputs, output):
            if isinstance(output, torch.Tensor):
                print(
                    name,
                    "mean=", output.mean().item(),
                    "std=", output.std().item(),
                    "max=", output.abs().max().item(),
                )
        return fn

    for name, module in model.named_modules():
        if len(list(module.children())) == 0:
            hooks.append(module.register_forward_hook(hook(name)))

    return hooks

Use hooks for debugging, not for every production training run. They add overhead and can produce large logs.

Avoid Accidental Graph Retention

A common PyTorch stability problem is retaining computation graphs unintentionally. This usually appears as GPU memory growing every training step.

Bad pattern:

losses = []

for batch in loader:
    loss = training_step(batch)
    losses.append(loss)  # Keeps graph alive

Correct pattern:

losses = []

for batch in loader:
    loss = training_step(batch)
    losses.append(loss.item())

Another common issue is accumulating tensors without detaching them:

running_loss = running_loss + loss

Use scalar values instead:

running_loss += loss.item()

When storing model outputs for later analysis, detach them first:

saved_logits.append(logits.detach().cpu())
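The difference between these forms is visible on the tensor itself: a loss tensor carries a `grad_fn` that references the graph, while `.item()` and `.detach()` do not. The small model and random data below are placeholders for the illustration.

```python
import torch
from torch import nn

model = nn.Linear(4, 1)
x, y = torch.randn(8, 4), torch.randn(8, 1)
loss = nn.functional.mse_loss(model(x), y)

# The loss tensor references the graph that produced it.
keeps_graph = loss.grad_fn is not None

# item() returns a plain Python float with no link to the graph.
as_float = loss.item()

# detach() returns a tensor that shares storage but has no graph.
detached = loss.detach()

print(keeps_graph, type(as_float).__name__, detached.grad_fn)
```

Storing `as_float` or `detached` lets the graph be freed after each step. Storing `loss` keeps every intermediate activation of that step alive for as long as the list exists.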

Use zero_grad Correctly

Gradients accumulate in PyTorch by default. This is useful for gradient accumulation, but it is a common source of bugs.

A normal training step should clear old gradients:

optimizer.zero_grad(set_to_none=True)

logits = model(x)
loss = loss_fn(logits, y)

loss.backward()
optimizer.step()

Using set_to_none=True can reduce memory use and make it easier to detect parameters that did not receive gradients.

When doing gradient accumulation, clear gradients only at accumulation boundaries:

optimizer.zero_grad(set_to_none=True)

for step, batch in enumerate(loader):
    loss = training_step(batch)
    loss = loss / accumulation_steps
    loss.backward()

    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

Dividing the loss by accumulation_steps keeps gradient scale comparable to a larger batch.
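This equivalence can be checked numerically. The sketch below compares the gradient from one full batch with the gradient accumulated over four micro-batches. It assumes a loss with mean reduction and micro-batches of equal size, and the model and data are placeholders.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(4, 1)
loss_fn = nn.MSELoss()  # mean reduction over the batch
x, y = torch.randn(8, 4), torch.randn(8, 1)

# Reference: gradient from a single batch of 8 samples.
model.zero_grad(set_to_none=True)
loss_fn(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Accumulate over 4 micro-batches of 2 samples each.
accumulation_steps = 4
model.zero_grad(set_to_none=True)
for x_micro, y_micro in zip(x.chunk(accumulation_steps), y.chunk(accumulation_steps)):
    loss = loss_fn(model(x_micro), y_micro) / accumulation_steps
    loss.backward()  # adds to the existing .grad
accumulated_grad = model.weight.grad.clone()

max_difference = (full_grad - accumulated_grad).abs().max().item()
print(max_difference)
```

The two gradients agree up to floating-point error. Without the division, the accumulated gradient would be four times larger, which has the same effect as an unintended increase in learning rate for SGD.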

Validate the Training Loop on a Tiny Batch

Before training at scale, test whether the model can overfit a tiny batch. This is one of the best debugging checks.

Take one small batch and train on it repeatedly:

x_small, y_small = next(iter(loader))

for step in range(500):
    optimizer.zero_grad(set_to_none=True)

    logits = model(x_small)
    loss = loss_fn(logits, y_small)

    loss.backward()
    optimizer.step()

    if step % 50 == 0:
        print(step, loss.item())

A sufficiently expressive model should drive the training loss very low on a tiny batch. If it cannot, check the model, loss function, labels, optimizer, data preprocessing, and gradient flow.

This test does not prove the model generalizes. It proves that the training loop can learn at all.

Check Loss and Output Compatibility

Many unstable training runs come from a mismatch between model outputs and loss function.

For classification with nn.CrossEntropyLoss, the model should output raw logits, not softmax probabilities.

Correct:

logits = model(x)
loss = nn.CrossEntropyLoss()(logits, labels)

Incorrect:

probs = torch.softmax(model(x), dim=-1)
loss = nn.CrossEntropyLoss()(probs, labels)

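CrossEntropyLoss applies log-softmax internally. Passing probabilities applies softmax twice: the inputs are squeezed into [0, 1], the gradients shrink, and the loss can no longer approach zero even for a perfect prediction.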
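A small numerical check makes the cost of the double softmax concrete. The logits below describe a confident, correct three-class prediction. The values are chosen only for illustration.

```python
import math
import torch
from torch import nn

loss_fn = nn.CrossEntropyLoss()
labels = torch.tensor([0])
logits = torch.tensor([[20.0, 0.0, 0.0]])  # confident and correct

# Correct usage: raw logits.
loss_from_logits = loss_fn(logits, labels).item()

# Incorrect usage: probabilities are treated as logits and softmaxed again.
probs = torch.softmax(logits, dim=-1)
loss_from_probs = loss_fn(probs, labels).item()

# With inputs confined to [0, 1], the best case for 3 classes is
# the input [1, 0, 0], which gives a loss of log(1 + 2 / e).
floor = math.log(1.0 + 2.0 / math.e)

print(loss_from_logits, loss_from_probs, floor)
```

The correct call reports a loss near zero. The incorrect call cannot go below about 0.55 for three classes, so training appears to plateau even when the predictions are right.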

For binary classification with nn.BCEWithLogitsLoss, the model should output raw logits.

Correct:

logits = model(x).squeeze(-1)
loss = nn.BCEWithLogitsLoss()(logits, targets.float())

Incorrect:

probs = torch.sigmoid(model(x)).squeeze(-1)
loss = nn.BCEWithLogitsLoss()(probs, targets.float())

Use BCELoss only when you explicitly pass probabilities. In most cases, BCEWithLogitsLoss is more numerically stable.

Handle Masks Carefully

Masks are common in sequence models and attention. Bad masking can produce nan.

For attention, a mask may set invalid logits to a large negative value before softmax:

scores = scores.masked_fill(mask == 0, -1e9)
weights = torch.softmax(scores, dim=-1)

In float16, -1e9 lies outside the representable range, since the largest finite magnitude is 65504. A hard-coded constant can therefore fail or overflow. A dtype-aware value is safer:

large_negative = torch.finfo(scores.dtype).min
scores = scores.masked_fill(mask == 0, large_negative)
weights = torch.softmax(scores, dim=-1)

Also ensure that each row has at least one valid position. A softmax over a fully masked row returns nan when the fill value is -inf, and a meaningless uniform distribution when the fill value is finite.
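The fully masked case can be reproduced and guarded against in a few lines. This is one possible guard, not the only one: it zeroes the attention weights of rows that have no valid position. The tensors and the name `weights_safe` are assumptions for the example.

```python
import torch

scores = torch.tensor([[1.0, 2.0, 3.0],
                       [0.5, 0.5, 0.5]])
mask = torch.tensor([[1, 1, 0],
                     [0, 0, 0]])  # the second row has no valid position

# Filling with -inf gives nan for the fully masked row.
weights_inf = torch.softmax(
    scores.masked_fill(mask == 0, float("-inf")), dim=-1
)

# Filling with the dtype minimum gives a uniform row instead.
fill = torch.finfo(scores.dtype).min
weights_finite = torch.softmax(scores.masked_fill(mask == 0, fill), dim=-1)

# Guard: rows without any valid position get all-zero weights.
has_valid = (mask != 0).any(dim=-1, keepdim=True)
weights_safe = torch.where(
    has_valid, weights_finite, torch.zeros_like(weights_finite)
)

print(weights_inf)
print(weights_finite)
print(weights_safe)
```

Whether zero weights are the right output for an empty row depends on the model. The important part is that the case is handled explicitly instead of silently producing nan or uniform attention over padding.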

Recommended Debug Order

When training is unstable, debug in this order:

| Step | Check |
| --- | --- |
| 1 | Verify input shapes, label shapes, and dtypes |
| 2 | Check output and loss compatibility |
| 3 | Overfit one tiny batch |
| 4 | Log loss, gradient norms, and activation statistics |
| 5 | Lower the learning rate |
| 6 | Check initialization and normalization |
| 7 | Add gradient clipping |
| 8 | Disable mixed precision temporarily |
| 9 | Check for accidental graph retention |
| 10 | Inspect data for invalid values or bad labels |

This order catches the most common mistakes before changing the architecture.

Summary

Stable training comes from controlling numerical scale across the whole training system. Initialization sets the starting point. Normalization controls activation scale. Residual connections improve gradient flow. Learning rate schedules control update size. Gradient clipping limits rare large updates. Mixed precision requires care. The training loop must avoid graph retention, wrong loss usage, and incorrect gradient accumulation.

A practical PyTorch workflow is simple: start with correct data scale, use architecture-appropriate initialization and normalization, choose a conservative learning rate, verify the training loop on a tiny batch, and inspect gradients when behavior looks wrong.