Stable training means that a model can make steady progress without numerical collapse, uncontrolled gradients, or large oscillations in the loss. Deep networks are sensitive systems. A small problem in scale, initialization, normalization, data preprocessing, optimizer settings, or precision can compound across many layers.
This section combines the ideas from the previous sections into a practical view of stable PyTorch training.
What Stability Means
A training run is stable when several quantities remain within useful ranges:
| Quantity | Stable behavior |
|---|---|
| Loss | Decreases over time, with tolerable noise |
| Activations | Neither collapse to zero nor grow without bound |
| Gradients | Finite and large enough to learn |
| Parameters | Change gradually rather than jumping wildly |
| Learning rate | Large enough for progress, small enough for control |
| Validation metric | Improves or degrades slowly, rather than randomly |
Instability usually appears as one of the following:
| Symptom | Common cause |
|---|---|
| Loss becomes nan | Too large learning rate, invalid operation, overflow |
| Loss explodes | Large gradients, bad initialization, bad data scale |
| Loss does not move | Vanishing gradients, learning rate too small, frozen parameters |
| Training accuracy rises but validation drops | Overfitting, data leakage, poor regularization |
| Gradients are zero | Detached graph, saturated activations, wrong loss |
| GPU memory grows each step | Stored computation graphs, missing detach() |
A stable training run does not need to be perfectly smooth. Mini-batch training is noisy by design. The question is whether the noise remains bounded and useful.
Start with Correct Data Scale
Deep learning models assume inputs have a reasonable numerical scale. If input values are too large, early activations can grow too large as well. If input values are poorly centered, optimization may be slower.
For images, a common preprocessing step is to convert pixel values from integers in [0, 255] to floating-point values in [0, 1], then normalize by dataset mean and standard deviation.
```python
import torch
from torchvision import transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])
```

For tabular data, standardization is common:
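A minimal sketch of that standardization; the tensors here are placeholders standing in for real train and validation features:

```python
import torch

# Placeholder tensors standing in for real train / validation features.
x_train = torch.randn(1024, 20)
x_val = torch.randn(256, 20)

# Compute statistics on the training split only.
feature_mean = x_train.mean(dim=0)
feature_std = x_train.std(dim=0).clamp_min(1e-8)  # avoid division by zero

# Apply the same training-set statistics to every split.
x_train = (x_train - feature_mean) / feature_std
x_val = (x_val - feature_mean) / feature_std
```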
The mean and standard deviation should be computed on the training set only. Validation and test data should use the same training-set statistics.
For token IDs, do not normalize the integer IDs themselves. Token IDs are indices into an embedding table. The embedding vectors are learned.
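For example, a minimal sketch of an embedding lookup; the vocabulary size and embedding dimension are illustrative:

```python
import torch
from torch import nn

embedding = nn.Embedding(num_embeddings=50_000, embedding_dim=512)
token_ids = torch.tensor([[12, 7, 431]])   # integer indices, not scaled
token_vectors = embedding(token_ids)       # shape (1, 3, 512), learned values
```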
Use Initialization That Matches the Architecture
Initialization controls the starting scale of activations and gradients.
For ReLU-based MLPs and CNNs, Kaiming initialization is a strong default:
```python
from torch import nn

def init_relu_network(module):
    if isinstance(module, (nn.Linear, nn.Conv2d)):
        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
        if module.bias is not None:
            nn.init.zeros_(module.bias)
```

For tanh-based networks, Xavier initialization is usually better:
```python
def init_tanh_network(module):
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
```

For transformers, use the architecture's intended initialization when possible. Many transformer implementations use small normal initialization for linear and embedding weights.
```python
def init_transformer_style(module):
    if isinstance(module, (nn.Linear, nn.Embedding)):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
    if isinstance(module, nn.Linear) and module.bias is not None:
        nn.init.zeros_(module.bias)
    if isinstance(module, nn.LayerNorm):
        nn.init.ones_(module.weight)
        nn.init.zeros_(module.bias)
```

The best initialization rule depends on architecture, activation function, depth, residual scaling, and normalization.
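To apply one of these rules, pass the function to `Module.apply`, which visits every submodule. A minimal sketch, assuming `model` is an existing `nn.Module`:

```python
model.apply(init_relu_network)  # runs the init function on every submodule
```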
Choose Normalization Deliberately
Normalization reduces uncontrolled changes in activation scale.
| Architecture | Common normalization |
|---|---|
| CNN with moderate or large batches | BatchNorm |
| CNN with small batches | GroupNorm |
| Transformer | LayerNorm or RMSNorm |
| Style transfer model | InstanceNorm |
| RNN or sequence model | LayerNorm |
Use normalization as an architectural choice, not as a patch applied after the model fails.
For a transformer block, a pre-normalization layout is usually stable:
```python
class PreNormMLPBlock(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        return x + self.ffn(self.norm(x))
```

For a convolutional block with small batches, group normalization is often safer than batch normalization:
```python
class ConvGNBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.GroupNorm(num_groups=8, num_channels=out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.net(x)
```

Use Residual Connections for Depth
Residual connections help deep networks optimize by giving gradients a direct path through the model.
A residual block computes `y = x + f(x)`, where `f` is the transformation learned by the block.
This update form helps the network learn small corrections rather than entirely new representations at each layer.
For very deep networks, consider residual scaling:
```python
class ScaledResidualBlock(nn.Module):
    def __init__(self, dim, hidden_dim, scale=0.1):
        super().__init__()
        self.scale = scale
        self.f = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        return x + self.scale * self.f(x)
```

Residual scaling is useful when many residual additions cause activation magnitudes to grow with depth.
Select a Conservative Learning Rate First
The learning rate is one of the most common causes of instability. A rate that is too high can make the loss explode. A rate that is too low can make training appear broken.
A conservative starting point:
| Optimizer | Typical first learning rate |
|---|---|
| SGD with momentum | 1e-2 to 1e-1 |
| Adam | 1e-3 |
| AdamW for transformers | 1e-4 to 3e-4 |
| Fine-tuning pretrained models | 1e-5 to 5e-5 |
These are only starting points. The right value depends on batch size, model size, optimizer, normalization, and task.
A useful first response to exploding loss is to reduce the learning rate by a factor of 3 or 10.
```python
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,
    weight_decay=0.01,
)
```

Use Learning Rate Warmup
Large models, transformers, and mixed-precision training often benefit from learning rate warmup. Warmup starts with a small learning rate and increases it gradually during early training.
This avoids large destructive updates before the model reaches a reasonable activation and gradient regime.
A simple warmup schedule:
```python
def warmup_lr(step, base_lr, warmup_steps):
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    return base_lr
```

With a PyTorch optimizer:
```python
base_lr = 3e-4
warmup_steps = 1000

for step, batch in enumerate(loader):
    lr = warmup_lr(step, base_lr, warmup_steps)
    for group in optimizer.param_groups:
        group["lr"] = lr
    loss = training_step(batch)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
```

Warmup is especially common for transformer pretraining and fine-tuning.
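The same schedule can also be expressed with PyTorch's built-in scheduler machinery. A minimal sketch using `torch.optim.lr_scheduler.LambdaLR`; the step count is a placeholder:

```python
from torch.optim.lr_scheduler import LambdaLR

warmup_steps = 1000

def lr_lambda(step):
    # Multiplicative factor applied to the optimizer's base learning rate.
    return min(1.0, (step + 1) / warmup_steps)

scheduler = LambdaLR(optimizer, lr_lambda)
# Call scheduler.step() once per optimizer step.
```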
Clip Gradients When Needed
Gradient clipping limits exploding gradients before the optimizer step.
```python
loss.backward()
torch.nn.utils.clip_grad_norm_(
    model.parameters(),
    max_norm=1.0,
)
optimizer.step()
```

Gradient clipping is common in recurrent networks, transformers, reinforcement learning, and unstable generative models.
It should not be used to hide every problem. If clipping is active on almost every step with very large unclipped norms, the learning rate, initialization, or architecture may still be wrong.
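`clip_grad_norm_` returns the total norm measured before clipping, so it can double as a cheap monitor. A minimal sketch; the threshold is illustrative:

```python
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
grad_norm = total_norm.item()  # norm before clipping
if grad_norm > 10.0:           # illustrative threshold
    print(f"large gradient norm before clipping: {grad_norm:.2f}")
```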
Use Mixed Precision Carefully
Mixed precision can speed up training and reduce memory use, but it introduces numerical risks. In PyTorch, automatic mixed precision is usually handled with torch.autocast and GradScaler.
```python
scaler = torch.cuda.amp.GradScaler()

for x, y in loader:
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.amp.autocast():
        logits = model(x)
        loss = loss_fn(logits, y)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()
```

The call to scaler.unscale_(optimizer) is important when clipping gradients. It converts scaled gradients back to their real scale before clipping.
On hardware that supports it, bfloat16 is often more stable than float16 because it has a wider exponent range.
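A minimal bfloat16 sketch, assuming the GPU supports it; with bfloat16 the gradient scaler is usually unnecessary because overflow is far less likely:

```python
for x, y in loader:
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        logits = model(x)
        loss = loss_fn(logits, y)
    loss.backward()
    optimizer.step()
```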
Check for Non-Finite Values
A direct stability check is to detect non-finite losses, activations, gradients, and parameters.
```python
def assert_finite_tensor(name, x):
    if not torch.isfinite(x).all():
        raise RuntimeError(f"{name} contains non-finite values")
```

During training:
```python
logits = model(x)
assert_finite_tensor("logits", logits)

loss = loss_fn(logits, y)
assert_finite_tensor("loss", loss)

loss.backward()
for name, param in model.named_parameters():
    if param.grad is not None:
        assert_finite_tensor(f"{name}.grad", param.grad)
```

This makes failures local. Instead of discovering nan after many steps, you can stop at the operation where the problem first appears.
Track Activation and Gradient Statistics
Logging tensor statistics gives visibility into training dynamics.
```python
def parameter_report(model):
    rows = []
    for name, param in model.named_parameters():
        grad_norm = None
        if param.grad is not None:
            grad_norm = param.grad.norm().item()
        rows.append({
            "name": name,
            "shape": tuple(param.shape),
            "param_mean": param.data.mean().item(),
            "param_std": param.data.std().item(),
            "param_norm": param.data.norm().item(),
            "grad_norm": grad_norm,
        })
    return rows
```

Forward hooks can inspect activations:
```python
def add_activation_hooks(model):
    hooks = []

    def hook(name):
        def fn(module, inputs, output):
            if isinstance(output, torch.Tensor):
                print(
                    name,
                    "mean=", output.mean().item(),
                    "std=", output.std().item(),
                    "max=", output.abs().max().item(),
                )
        return fn

    for name, module in model.named_modules():
        if len(list(module.children())) == 0:
            hooks.append(module.register_forward_hook(hook(name)))
    return hooks
```

Use hooks for debugging, not for every production training run. They add overhead and can produce large logs.
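When the inspection is done, remove the hooks through the handles returned above:

```python
hooks = add_activation_hooks(model)
# ... run a few forward passes and inspect the printed statistics ...
for handle in hooks:
    handle.remove()
```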
Avoid Accidental Graph Retention
A common PyTorch stability problem is retaining computation graphs unintentionally. This usually appears as GPU memory growing every training step.
Bad pattern:
```python
losses = []
for batch in loader:
    loss = training_step(batch)
    losses.append(loss)  # keeps the graph alive
```

Correct pattern:
```python
losses = []
for batch in loader:
    loss = training_step(batch)
    losses.append(loss.item())
```

Another common issue is accumulating tensors without detaching them:
```python
running_loss = running_loss + loss
```

Use scalar values instead:
```python
running_loss += loss.item()
```

When storing model outputs for later analysis, detach them first:
```python
saved_logits.append(logits.detach().cpu())
```

Use zero_grad Correctly
Gradients accumulate in PyTorch by default. This is useful for gradient accumulation, but it is a common source of bugs.
A normal training step should clear old gradients:
```python
optimizer.zero_grad(set_to_none=True)
logits = model(x)
loss = loss_fn(logits, y)
loss.backward()
optimizer.step()
```

Using set_to_none=True can reduce memory use and make it easier to detect parameters that did not receive gradients.
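For example, a minimal sketch of that check after backward, assuming gradients were cleared with set_to_none=True:

```python
loss.backward()
for name, param in model.named_parameters():
    if param.requires_grad and param.grad is None:
        print(f"{name} received no gradient this step")
```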
When doing gradient accumulation, clear gradients only at accumulation boundaries:
```python
optimizer.zero_grad(set_to_none=True)
for step, batch in enumerate(loader):
    loss = training_step(batch)
    loss = loss / accumulation_steps
    loss.backward()
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
```

Dividing the loss by accumulation_steps keeps gradient scale comparable to a larger batch.
Validate the Training Loop on a Tiny Batch
Before training at scale, test whether the model can overfit a tiny batch. This is one of the best debugging checks.
Take one small batch and train on it repeatedly:
```python
x_small, y_small = next(iter(loader))

for step in range(500):
    optimizer.zero_grad(set_to_none=True)
    logits = model(x_small)
    loss = loss_fn(logits, y_small)
    loss.backward()
    optimizer.step()
    if step % 50 == 0:
        print(step, loss.item())
```

A sufficiently expressive model should drive the training loss very low on a tiny batch. If it cannot, check the model, loss function, labels, optimizer, data preprocessing, and gradient flow.
This test does not prove the model generalizes. It proves that the training loop can learn at all.
Check Loss and Output Compatibility
Many unstable training runs come from a mismatch between model outputs and loss function.
For classification with nn.CrossEntropyLoss, the model should output raw logits, not softmax probabilities.
Correct:
```python
logits = model(x)
loss = nn.CrossEntropyLoss()(logits, labels)
```

Incorrect:
```python
probs = torch.softmax(model(x), dim=-1)
loss = nn.CrossEntropyLoss()(probs, labels)
```

CrossEntropyLoss internally applies log-softmax. Passing probabilities can make optimization worse.
For binary classification with nn.BCEWithLogitsLoss, the model should output raw logits.
Correct:
```python
logits = model(x).squeeze(-1)
loss = nn.BCEWithLogitsLoss()(logits, targets.float())
```

Incorrect:
```python
probs = torch.sigmoid(model(x)).squeeze(-1)
loss = nn.BCEWithLogitsLoss()(probs, targets.float())
```

Use BCELoss only when you explicitly pass probabilities. In most cases, BCEWithLogitsLoss is more numerically stable.
Handle Masks Carefully
Masks are common in sequence models and attention. Bad masking can produce nan.
For attention, a mask may set invalid logits to a large negative value before softmax:
```python
scores = scores.masked_fill(mask == 0, -1e9)
weights = torch.softmax(scores, dim=-1)
```

In lower precision, very large constants can cause problems. A dtype-aware value is safer:
```python
large_negative = torch.finfo(scores.dtype).min
scores = scores.masked_fill(mask == 0, large_negative)
weights = torch.softmax(scores, dim=-1)
```

Also ensure that each row has at least one valid position. A softmax over all masked positions can produce invalid values.
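A minimal sketch of that check, assuming `mask` has shape (batch, seq_len) with 1 marking valid positions:

```python
valid_per_row = mask.sum(dim=-1)
if (valid_per_row == 0).any():
    raise ValueError("at least one sequence is fully masked")
```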
Recommended Debug Order
When training is unstable, debug in this order:
| Step | Check |
|---|---|
| 1 | Verify input shapes, label shapes, and dtypes |
| 2 | Check output and loss compatibility |
| 3 | Overfit one tiny batch |
| 4 | Log loss, gradient norms, and activation statistics |
| 5 | Lower the learning rate |
| 6 | Check initialization and normalization |
| 7 | Add gradient clipping |
| 8 | Disable mixed precision temporarily |
| 9 | Check for accidental graph retention |
| 10 | Inspect data for invalid values or bad labels |
This order catches the most common mistakes before changing the architecture.
Summary
Stable training comes from controlling numerical scale across the whole training system. Initialization sets the starting point. Normalization controls activation scale. Residual connections improve gradient flow. Learning rate schedules control update size. Gradient clipping limits rare large updates. Mixed precision requires care. The training loop must avoid graph retention, wrong loss usage, and incorrect gradient accumulation.
A practical PyTorch workflow is simple: start with correct data scale, use architecture-appropriate initialization and normalization, choose a conservative learning rate, verify the training loop on a tiny batch, and inspect gradients when behavior looks wrong.