Deep networks train by sending information in two directions. The forward pass sends activations from the input layer to the output layer. The backward pass sends gradients from the loss back to earlier layers. Stable training requires both signals to remain numerically useful.
When gradients become extremely small as they move backward through the network, we say gradients vanish. When gradients become extremely large, we say gradients explode. Both problems make optimization difficult.
The Chain Rule Across Many Layers
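Consider a deep network built from layers
$$h_0 = x, \qquad h_k = f_k(h_{k-1}; \theta_k), \quad k = 1, \dots, L.$$
The loss is
$$\mathcal{L} = \ell(h_L, y).$$
To update early layers, backpropagation applies the chain rule repeatedly:
$$\frac{\partial \mathcal{L}}{\partial \theta_1} = \frac{\partial \mathcal{L}}{\partial h_L}\,\frac{\partial h_L}{\partial h_{L-1}}\cdots\frac{\partial h_2}{\partial h_1}\,\frac{\partial h_1}{\partial \theta_1}.$$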
Each factor is a local derivative. In vector form, these factors are Jacobian matrices. A deep network multiplies many such terms. If their typical scale is less than 1, the product can shrink rapidly. If their typical scale is greater than 1, the product can grow rapidly.
This is the root of vanishing and exploding gradients.
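The product of Jacobians can be checked numerically. The following is a minimal sketch that assumes an illustrative stack of `tanh` layers, $h_k = \tanh(W_k h_{k-1} + b_k)$. The sizes, the seed, and the sum-of-outputs loss are arbitrary choices. The sketch computes $\partial \mathcal{L} / \partial x$ once with autograd. It then computes the same gradient again by multiplying the per-layer Jacobians from `torch.autograd.functional.jacobian`, printing the gradient norm after each factor.

```python
import torch

torch.manual_seed(0)
dim, depth = 8, 6

# Illustrative network: h_k = tanh(W_k h_{k-1} + b_k), with h_0 = x.
layers = [torch.nn.Linear(dim, dim) for _ in range(depth)]

x = torch.randn(dim, requires_grad=True)
hs = [x]
for layer in layers:
    hs.append(torch.tanh(layer(hs[-1])))

loss = hs[-1].sum()  # simple scalar loss, so dL/dh_L is a vector of ones
loss.backward()      # autograd result lands in x.grad

# Rebuild dL/dx by multiplying per-layer Jacobians from the output backward.
g = torch.ones(dim)  # dL/dh_L
for k in reversed(range(depth)):
    # J has shape (dim, dim) and holds dh_{k+1}/dh_k for layers[k].
    J = torch.autograd.functional.jacobian(
        lambda h, layer=layers[k]: torch.tanh(layer(h)), hs[k].detach()
    )
    g = g @ J
    print(f"after Jacobian of layer {k + 1}: grad norm = {g.norm().item():.3e}")

print("matches autograd:", torch.allclose(g, x.grad, atol=1e-5))
```

Each printed norm shows how the backward signal changes as one more Jacobian is multiplied in. This is the quantity that vanishes or explodes in deeper networks.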
Vanishing Gradients
Vanishing gradients occur when gradients become too small for earlier layers to learn effectively.
Suppose each layer roughly multiplies the gradient by a factor smaller than 1, say $0.5$. After 20 layers, the scale becomes approximately
$$0.5^{20} \approx 9.5 \times 10^{-7}.$$
A gradient of this size produces very small parameter updates. The final layers may train, but the early layers barely move.
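This was a major difficulty in early deep networks, especially with sigmoid and tanh activations. The sigmoid function is
$$\sigma(x) = \frac{1}{1 + e^{-x}}.$$
Its derivative is
$$\sigma'(x) = \sigma(x)\bigl(1 - \sigma(x)\bigr),$$
which is at most $1/4$, attained at $x = 0$. Therefore, repeated multiplication through sigmoid layers can shrink gradients quickly.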
```python
import torch

x = torch.linspace(-8, 8, 1000)
sigmoid = torch.sigmoid(x)
derivative = sigmoid * (1 - sigmoid)
print(derivative.max())  # tensor close to 0.25
```

The issue becomes worse when activations saturate. For very positive or very negative inputs, sigmoid outputs are close to 1 or 0. In those regions, the derivative is close to zero.
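The $1/4$ bound compounds with depth. The following minimal sketch assumes a deliberately artificial setup: a scalar input passed through repeated `torch.sigmoid` calls with no weights at all. By the chain rule, the input gradient is a product of sigmoid derivatives, so it can never exceed $0.25^{\text{depth}}$.

```python
import torch

def sigmoid_chain_grad(depth, x0=0.0):
    # Apply sigmoid `depth` times to a scalar and return d(output)/d(input).
    x = torch.tensor(x0, requires_grad=True)
    h = x
    for _ in range(depth):
        h = torch.sigmoid(h)
    h.backward()
    return x.grad.item()

for depth in (1, 5, 10, 20):
    # The gradient is a product of `depth` factors, each at most 0.25.
    print(depth, sigmoid_chain_grad(depth), 0.25 ** depth)
```

Starting at $x_0 = 0$ gives the largest possible first factor. Every later sigmoid receives an input near $0.5$ or above, where its derivative is already below $0.25$.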
Exploding Gradients
Exploding gradients occur when gradients become too large.
Suppose each layer roughly multiplies the gradient by a factor larger than 1, say $1.5$. After 50 layers, the scale becomes
$$1.5^{50} \approx 6.4 \times 10^{8}.$$
Such gradients can cause huge parameter updates. The loss may become unstable, parameters may become inf or nan, and training may fail.
Exploding gradients are common in recurrent neural networks because the same transition is applied repeatedly across time. A sequence of length 100 behaves like a 100-layer computation graph. If the recurrent dynamics amplify gradients, the gradient can grow rapidly.
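The following minimal sketch shows this effect under a simplifying assumption: a purely linear recurrence $h_t = W h_{t-1}$ with $W = s\,Q$, where $Q$ is a random orthogonal matrix. Real RNNs also add inputs and nonlinearities. Every singular value of this $W$ equals $s$, so the gradient of $\sum_i (h_T)_i$ with respect to $h_0$ has norm exactly $s^T \sqrt{\text{dim}}$.

```python
import math
import torch

def recurrent_grad_norm(scale, steps, dim=16, seed=0):
    # Unroll h_t = W h_{t-1} for `steps` time steps, with W = scale * Q and Q orthogonal.
    gen = torch.Generator().manual_seed(seed)
    q, _ = torch.linalg.qr(torch.randn(dim, dim, generator=gen))
    w = scale * q
    h0 = torch.randn(dim, generator=gen, requires_grad=True)
    h = h0
    for _ in range(steps):
        h = w @ h
    h.sum().backward()  # gradient of sum(h_T) with respect to h_0
    return h0.grad.norm().item()

dim = 16
for scale in (0.9, 1.0, 1.1):
    # Divide by sqrt(dim) = ||ones||, the gradient norm of sum(h_T) at T = 0.
    ratio = recurrent_grad_norm(scale, steps=100, dim=dim) / math.sqrt(dim)
    print(f"scale={scale}: gradient scale {ratio:.3e} vs scale**100 = {scale ** 100:.3e}")
```

A per-step factor only 10% away from 1 changes the gradient by roughly four orders of magnitude over 100 steps, in either direction.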
Why Depth Makes the Problem Severe
A shallow network contains only a few multiplications in the backward path. A deep network contains many.
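For a simplified scalar model,
$$h_k = w_k h_{k-1}, \quad k = 1, \dots, L.$$
The derivative with respect to an early hidden value contains a product of many weights:
$$\frac{\partial h_L}{\partial h_1} = \prod_{k=2}^{L} w_k = w_L w_{L-1} \cdots w_2.$$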
If the weights have magnitude less than 1 on average, the product tends to vanish. If they have magnitude greater than 1 on average, the product tends to explode.
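Because the product's magnitude equals $\exp\bigl(\sum_k \log |w_k|\bigr)$, "on average" is best read as the average of $\log |w_k|$, that is, the geometric mean. The following minimal sketch assumes illustrative weights drawn uniformly from $[0, 2]$. Their arithmetic mean is 1, but the expected value of $\log w_k$ is $\ln 2 - 1 \approx -0.31$. The product, which autograd returns as the gradient, therefore typically shrinks with depth.

```python
import torch

torch.manual_seed(0)
depth = 50
# Illustrative weights for the scalar chain h_k = w_k * h_{k-1}; their arithmetic mean is 1.
w = torch.empty(depth).uniform_(0.0, 2.0)

h0 = torch.tensor(1.0, requires_grad=True)
h = h0
for wk in w:
    h = wk * h
h.backward()  # h0.grad now holds dh_L/dh_0 = w_1 * ... * w_L

print("autograd gradient: ", h0.grad.item())
print("product of weights:", torch.prod(w).item())
print("arithmetic mean:   ", w.mean().item())
print("geometric mean:    ", w.log().mean().exp().item())
```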
Real networks use matrices and nonlinearities, but the same principle applies. Training stability depends on the singular values of layer Jacobians. If these singular values are usually below 1, gradients shrink. If they are usually above 1, gradients grow.
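The following minimal sketch illustrates the singular-value view. It assumes purely linear layers, so each Jacobian is just the weight matrix $W_k$. The entries are i.i.d. Gaussian with variance $\text{gain}^2/\text{dim}$, which gives $\|W_k v\| \approx \text{gain}\cdot\|v\|$ for a typical vector $v$. The `gain` values, sizes, and seed are arbitrary. The sketch reports the largest singular value of the product of 20 such Jacobians.

```python
import torch

def jacobian_product_svals(gain, depth=20, dim=64, seed=0):
    # Linear layers h_k = W_k h_{k-1}: each layer Jacobian is W_k itself.
    gen = torch.Generator().manual_seed(seed)
    prod = torch.eye(dim)
    for _ in range(depth):
        w = torch.randn(dim, dim, generator=gen) * gain / dim ** 0.5
        prod = w @ prod  # dh_k/dh_0 = W_k @ dh_{k-1}/dh_0
    return torch.linalg.svdvals(prod)  # sorted in descending order

for gain in (0.5, 1.0, 2.0):
    s = jacobian_product_svals(gain)
    print(f"gain={gain}: largest singular value of the product = {s[0].item():.2e}")
```

Because the same seed is reused, the three products differ only by the factor $\text{gain}^{20}$. Halving or doubling the per-layer scale therefore moves the result by about six orders of magnitude.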
Symptoms in PyTorch
Vanishing gradients often appear as slow or stalled learning. The last layers may keep learning and reduce the loss somewhat, while early layers receive tiny gradients and barely change.
Exploding gradients often appear as sudden loss spikes, nan values, or unstable parameter norms.
You can inspect gradient norms during training:
```python
def grad_stats(model):
    # Print gradient and parameter norms for every parameter that received a gradient.
    for name, param in model.named_parameters():
        if param.grad is None:
            continue
        print(
            name,
            "grad_norm=",
            param.grad.norm().item(),
            "param_norm=",
            param.data.norm().item(),
        )
```

Use it after loss.backward():
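```python
optimizer.zero_grad()
loss = model_loss(model, batch)  # model_loss: your own function that computes the loss for a batch
loss.backward()
grad_stats(model)  # inspect gradients before the optimizer changes the parameters
optimizer.step()
```

A typical warning sign is that early-layer gradient norms are much smaller than later-layer gradient norms. For exploding gradients, norms may become extremely large or non-finite.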
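The following self-contained sketch runs this diagnostic end to end. It assumes a hypothetical 8-layer sigmoid MLP with PyTorch's default initialization and random synthetic data, chosen only to make the warning sign visible. It prints the gradient norm of each weight matrix from the first layer to the last.

```python
import torch

torch.manual_seed(0)
depth, dim = 8, 32

# Hypothetical deep sigmoid MLP used only to illustrate the diagnostic.
layers = []
for _ in range(depth):
    layers += [torch.nn.Linear(dim, dim), torch.nn.Sigmoid()]
model = torch.nn.Sequential(*layers, torch.nn.Linear(dim, 1))

x = torch.randn(64, dim)  # synthetic inputs
y = torch.randn(64, 1)    # synthetic targets
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()

# Gradient norm of each weight matrix, from the first layer to the last.
weight_grad_norms = {
    name: param.grad.norm().item()
    for name, param in model.named_parameters()
    if name.endswith("weight")
}
for name, norm in weight_grad_norms.items():
    print(f"{name:>9s}  grad_norm={norm:.2e}")
```

In a network like this, the norms typically fall by a roughly constant factor per layer as you move toward the input. That factor is the geometric decay described earlier.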
You can also detect non-finite gradients:
```python
for name, param in model.named_parameters():
    if param.grad is not None:
        if not torch.isfinite(param.grad).all():
            print("Non-finite gradient:", name)
```

Activation Functions and Gradient Flow
Activation functions affect gradient flow through their derivatives.
Sigmoid and tanh can saturate. When their inputs have large magnitude, their derivatives become small. This can make gradients vanish.
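ReLU helps because its derivative is 1 for positive inputs:
$$\mathrm{ReLU}(x) = \max(0, x), \qquad \mathrm{ReLU}'(x) = \begin{cases} 1, & x > 0, \\ 0, & x < 0. \end{cases}$$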
For active units, gradients pass through without shrinkage from the activation derivative. This is one reason ReLU made deep networks easier to train.
However, ReLU has its own failure mode. For negative inputs, the derivative is zero. A unit that outputs zero for most inputs may stop learning. This is sometimes called a dead ReLU.
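The following minimal sketch shows a dead unit. It uses a deliberately artificial setup: the bias of a small layer is set to $-100$ so that every pre-activation is negative. All outputs are then zero, and so are all gradients reaching the layer's parameters.

```python
import torch

torch.manual_seed(0)
layer = torch.nn.Linear(4, 3)
with torch.no_grad():
    # Artificial setup: a large negative bias pushes every pre-activation below zero.
    layer.bias.fill_(-100.0)

x = torch.randn(32, 4)
out = torch.relu(layer(x))
out.sum().backward()

print("max output:", out.abs().max().item())                     # every unit outputs zero
print("max weight grad:", layer.weight.grad.abs().max().item())  # so no gradient reaches the weights
```

With zero gradient, an optimizer such as plain SGD never moves these weights, so the units stay inactive.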
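Leaky ReLU reduces this problem by allowing a small slope for negative inputs:
$$\mathrm{LeakyReLU}(x) = \begin{cases} x, & x > 0, \\ \alpha x, & x \le 0, \end{cases}$$
where $\alpha$ is a small constant such as $0.01$.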
In PyTorch:
```python
relu = torch.nn.ReLU()
leaky_relu = torch.nn.LeakyReLU(negative_slope=0.01)
gelu = torch.nn.GELU()
```

Modern transformer models commonly use GELU or related smooth activations. These work well with residual connections and normalization.
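The following sketch uses autograd to compare the local slopes these activations contribute to the backward pass. The two probe inputs, $-2$ and $3$, are arbitrary.

```python
import torch

slopes = {}
for name, act in [
    ("ReLU", torch.nn.ReLU()),
    ("LeakyReLU", torch.nn.LeakyReLU(negative_slope=0.01)),
    ("GELU", torch.nn.GELU()),
]:
    x = torch.tensor([-2.0, 3.0], requires_grad=True)
    act(x).sum().backward()  # x.grad[i] is the activation's derivative at x[i]
    slopes[name] = x.grad.clone()
    print(f"{name:>9s}: slope at -2 = {x.grad[0].item():+.4f}, slope at 3 = {x.grad[1].item():+.4f}")
```

ReLU passes nothing back for the negative input. Leaky ReLU passes back exactly its `negative_slope`. GELU gives a small nonzero slope there.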
Initialization and Gradient Scale
Initialization controls the initial scale of activations and gradients. If weights start too small, gradients may vanish. If weights start too large, gradients may explode.
For ReLU networks, Kaiming initialization is usually appropriate:
```python
layer = torch.nn.Linear(256, 256)
torch.nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")
torch.nn.init.zeros_(layer.bias)
```

For tanh-like networks, Xavier initialization is often appropriate:
```python
layer = torch.nn.Linear(256, 256)
torch.nn.init.xavier_uniform_(layer.weight)
torch.nn.init.zeros_(layer.bias)
```

Initialization cannot solve every stability problem, but it sets the network in a reasonable numerical regime before training begins.
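The following minimal sketch shows why the initial scale matters. It assumes a hypothetical 20-layer stack of `Linear` + ReLU layers with zero biases. Kaiming initialization (weight std $\sqrt{2/\text{fan\_in}}$) is compared with normal initializations whose std is half or double that value. The sketch reports the standard deviation of the final activations for random inputs.

```python
import torch

def final_activation_std(init_fn, depth=20, dim=256, batch=512, seed=0):
    # Push random inputs through `depth` Linear+ReLU layers and return the output std.
    torch.manual_seed(seed)
    h = torch.randn(batch, dim)
    with torch.no_grad():
        for _ in range(depth):
            layer = torch.nn.Linear(dim, dim)
            init_fn(layer.weight)
            torch.nn.init.zeros_(layer.bias)
            h = torch.relu(layer(h))
    return h.std().item()

def kaiming(w):
    torch.nn.init.kaiming_normal_(w, nonlinearity="relu")  # std = sqrt(2 / fan_in)

def too_small(w):
    torch.nn.init.normal_(w, std=0.5 * (2.0 / w.shape[1]) ** 0.5)  # half the Kaiming std

def too_large(w):
    torch.nn.init.normal_(w, std=2.0 * (2.0 / w.shape[1]) ** 0.5)  # double the Kaiming std

for name, fn in [("too small", too_small), ("kaiming", kaiming), ("too large", too_large)]:
    print(f"{name:>9s}: final activation std = {final_activation_std(fn):.3e}")
```

ReLU layers without bias are positively homogeneous. A per-layer scale error of $2\times$ therefore compounds to about $2^{20} \approx 10^6$ after 20 layers, in either direction.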
Normalization Layers
Normalization layers reduce uncontrolled changes in activation scale.
Batch normalization normalizes each feature using the mean and variance computed over the current mini-batch. Layer normalization normalizes across the feature dimension of each example independently. Transformers usually use layer normalization because sequence batches often have variable lengths and because layer normalization, which does not depend on batch statistics, works naturally with autoregressive inference.
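The following minimal sketch illustrates the difference in batch dependence. It uses random illustrative data, `torch.nn.LayerNorm`, and `torch.nn.BatchNorm1d` in its default training mode. LayerNorm's output for one example is unchanged when that example is normalized alone. BatchNorm's output for the same example changes when the other examples in the batch change.

```python
import torch

torch.manual_seed(0)
dim = 16
layer_norm = torch.nn.LayerNorm(dim)
batch_norm = torch.nn.BatchNorm1d(dim)  # modules start in training mode, so batch statistics are used

x = torch.randn(8, dim)

# LayerNorm: each row is normalized on its own.
ln_full = layer_norm(x)[0]
ln_single = layer_norm(x[:1])[0]
print("LayerNorm unchanged for example 0:", torch.allclose(ln_full, ln_single))

# BatchNorm: rescaling the other rows changes the statistics used for row 0.
x_other = x.clone()
x_other[1:] *= 10.0
bn_full = batch_norm(x)[0]
bn_other = batch_norm(x_other)[0]
print("BatchNorm unchanged for example 0:", torch.allclose(bn_full, bn_other))
```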
A simple MLP block with layer normalization:
```python
class MLPBlock(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.LayerNorm(dim),
            torch.nn.Linear(dim, 4 * dim),
            torch.nn.GELU(),
            torch.nn.Linear(4 * dim, dim),
        )

    def forward(self, x):
        return self.net(x)
```

Normalization helps keep activations in a range where gradients remain useful.
Residual Connections
Residual connections give gradients a shorter path through the network.
Instead of learning
$$y = f(x),$$
a residual block learns
$$y = x + f(x).$$
The identity path allows information and gradients to move through many layers more easily. During backpropagation, the block's Jacobian is
$$\frac{\partial y}{\partial x} = I + \frac{\partial f}{\partial x},$$
so gradients can flow through the identity term even when the residual branch is poorly conditioned.
In PyTorch:
```python
class ResidualBlock(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.f = torch.nn.Sequential(
            torch.nn.LayerNorm(dim),
            torch.nn.Linear(dim, 4 * dim),
            torch.nn.GELU(),
            torch.nn.Linear(4 * dim, dim),
        )

    def forward(self, x):
        return x + self.f(x)
```

Residual connections are a central reason very deep CNNs and transformers can be trained reliably.
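The following minimal sketch compares gradient flow with and without skip connections. It assumes a simplified residual branch, $f(h) = \tanh(Wh + b)$, rather than the LayerNorm/GELU block above. It also assumes an illustrative depth of 50 with default PyTorch initialization. It measures the norm of the gradient at the input for a plain stack and for a residual stack.

```python
import torch

def input_grad_norm(residual, depth=50, dim=64, seed=0):
    # Norm of d(sum of outputs)/d(input) for a stack of tanh layers,
    # either plain (h = f(h)) or residual (h = h + f(h)).
    torch.manual_seed(seed)
    layers = [torch.nn.Linear(dim, dim) for _ in range(depth)]
    x = torch.randn(1, dim, requires_grad=True)
    h = x
    for layer in layers:
        fh = torch.tanh(layer(h))
        h = h + fh if residual else fh
    h.sum().backward()
    return x.grad.norm().item()

plain = input_grad_norm(residual=False)
residual = input_grad_norm(residual=True)
print(f"plain stack:    {plain:.3e}")
print(f"residual stack: {residual:.3e}")
```

In the plain stack, each layer contributes a factor $\partial f / \partial h$ that typically shrinks the signal. In the residual stack, each factor is $I + \partial f / \partial h$, so the identity term keeps the gradient from collapsing.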
Gradient Clipping
Gradient clipping is a direct method for controlling exploding gradients. It limits the norm or value of gradients before the optimizer step.
Norm clipping rescales gradients when their total norm exceeds a threshold:
```python
optimizer.zero_grad()
loss = model_loss(model, batch)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
```

This is common in RNNs, transformers, reinforcement learning, and other unstable training regimes.
Value clipping clamps each gradient entry:
```python
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0)
```

Norm clipping is usually preferred because it preserves the direction of the full gradient vector while limiting its magnitude.
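The difference between the two methods is easy to see on a hand-set gradient. The following sketch uses two illustrative parameters whose gradient is set manually to $[3, 4]$, which has norm 5. It applies `clip_grad_norm_` to one parameter and `clip_grad_value_` to the other, both with threshold 1.

```python
import torch

# Two copies of a parameter with the same hand-set gradient [3, 4] (norm 5).
p_norm = torch.nn.Parameter(torch.zeros(2))
p_norm.grad = torch.tensor([3.0, 4.0])
p_value = torch.nn.Parameter(torch.zeros(2))
p_value.grad = torch.tensor([3.0, 4.0])

total = torch.nn.utils.clip_grad_norm_([p_norm], max_norm=1.0)  # returns the norm before clipping
torch.nn.utils.clip_grad_value_([p_value], clip_value=1.0)

print(total)         # total gradient norm before clipping
print(p_norm.grad)   # rescaled to (almost exactly) norm 1, same direction as [3, 4]
print(p_value.grad)  # each entry clamped to at most 1, so the direction changes
```

Because `clip_grad_norm_` returns the pre-clipping norm, logging its return value at each step is a cheap way to watch for gradient spikes.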
Stable Training Checklist
When gradients vanish or explode, check the following:
| Problem | Common fix |
|---|---|
| Activations shrink across layers | Use better initialization, normalization, residual connections |
| Activations grow across layers | Reduce learning rate, improve initialization, use normalization |
| Gradients vanish in early layers | Use ReLU-like activations, residual connections, normalization |
| Gradients explode | Use gradient clipping, reduce learning rate |
| Loss becomes nan | Check learning rate, dtype, invalid operations, gradient norms |
| ReLU units die | Lower learning rate, use Leaky ReLU or GELU |
| RNN training unstable | Use LSTM or GRU, clip gradients, shorten unroll length |
A practical debugging step is to log activation and gradient statistics layer by layer. Most stability problems leave a numerical trace before they fully break training.
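The following minimal sketch shows one way to log those statistics. It uses PyTorch forward hooks (`register_forward_hook`) to record the mean and standard deviation of every `Linear` layer's output, then prints them next to each parameter's gradient norm. The model, the data, and the helper name `attach_activation_logger` are hypothetical.

```python
import torch

def attach_activation_logger(model, stats):
    # Register a forward hook on every Linear layer that stores (mean, std) of its output.
    handles = []
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            def hook(mod, inputs, output, name=name):
                out = output.detach()
                stats[name] = (out.mean().item(), out.std().item())
            handles.append(module.register_forward_hook(hook))
    return handles

torch.manual_seed(0)
# Hypothetical model and data, used only to show the logging pattern.
model = torch.nn.Sequential(
    torch.nn.Linear(16, 32), torch.nn.Tanh(),
    torch.nn.Linear(32, 32), torch.nn.Tanh(),
    torch.nn.Linear(32, 1),
)
activation_stats = {}
handles = attach_activation_logger(model, activation_stats)

loss = model(torch.randn(64, 16)).pow(2).mean()
loss.backward()

for name, param in model.named_parameters():
    layer_name = name.rsplit(".", 1)[0]  # "0.weight" -> "0"
    mean, std = activation_stats[layer_name]
    print(f"{name:>8s} act_mean={mean:+.3f} act_std={std:.3f} "
          f"grad_norm={param.grad.norm().item():.3e}")

for handle in handles:
    handle.remove()  # detach the hooks once logging is no longer needed
```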
Summary
Vanishing and exploding gradients arise from repeated multiplication by layer derivatives during backpropagation. In deep networks, small deviations in scale can compound across many layers.
Vanishing gradients make early layers learn slowly. Exploding gradients make optimization unstable. Initialization, activation choice, normalization, residual connections, learning rate control, and gradient clipping are the main tools for managing these problems.
Stable deep learning requires preserving useful signal in both the forward pass and the backward pass.