# Vanishing and Exploding Gradients

Deep networks train by sending information in two directions. The forward pass sends activations from the input layer to the output layer. The backward pass sends gradients from the loss back to earlier layers. Stable training requires both signals to remain numerically useful.

When gradients become extremely small as they move backward through the network, we say gradients vanish. When gradients become extremely large, we say gradients explode. Both problems make optimization difficult.

### The Chain Rule Across Many Layers

Consider a deep network built from layers

$$
h_1 = f_1(h_0), \quad
h_2 = f_2(h_1), \quad \ldots, \quad
h_L = f_L(h_{L-1}).
$$

The loss is

$$
\mathcal{L} = \ell(h_L, y).
$$

To update early layers, backpropagation applies the chain rule repeatedly:

$$
\frac{\partial \mathcal{L}}{\partial h_0} =
\frac{\partial \mathcal{L}}{\partial h_L}
\frac{\partial h_L}{\partial h_{L-1}}
\frac{\partial h_{L-1}}{\partial h_{L-2}}
\cdots
\frac{\partial h_1}{\partial h_0}.
$$

Each factor is a local derivative. In vector form, these factors are Jacobian matrices. A deep network multiplies many such terms. If their typical scale is less than 1, the product can shrink rapidly. If their typical scale is greater than 1, the product can grow rapidly.

This is the root of vanishing and exploding gradients.

### Vanishing Gradients

Vanishing gradients occur when gradients become too small for earlier layers to learn effectively.

Suppose each layer roughly multiplies the gradient by $0.5$. After 20 layers, the scale becomes approximately

$$
0.5^{20} \approx 9.5 \times 10^{-7}.
$$

A gradient of this size produces very small parameter updates. The final layers may train, but the early layers barely move.
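The arithmetic is easy to verify directly:

```python
# Multiplying a gradient scale by 0.5 across 20 layers.
scale = 1.0
for _ in range(20):
    scale *= 0.5

print(scale)  # 9.5367431640625e-07
```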

This was a major difficulty in early deep networks, especially with sigmoid and tanh activations. The sigmoid function is

$$
\sigma(x)=\frac{1}{1+e^{-x}}.
$$

Its derivative is at most $0.25$. Therefore, repeated multiplication through sigmoid layers can shrink gradients quickly.

```python
import torch

x = torch.linspace(-8, 8, 1000)
sigmoid = torch.sigmoid(x)
derivative = sigmoid * (1 - sigmoid)

print(derivative.max())  # tensor close to 0.25
```

The issue becomes worse when activations saturate. For very positive or very negative inputs, sigmoid outputs are close to 1 or 0. In those regions, the derivative is close to zero.
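Evaluating the derivative $\sigma'(x) = \sigma(x)(1 - \sigma(x))$ at a few points makes the saturation effect concrete:

```python
import torch

x = torch.tensor([-10.0, 0.0, 10.0])
s = torch.sigmoid(x)
d = s * (1 - s)

print(d)  # near zero at the saturated ends, 0.25 at x = 0
```

At $x = \pm 10$ the derivative is below $10^{-4}$, so a single saturated layer already removes almost all of the gradient signal.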

### Exploding Gradients

Exploding gradients occur when gradients become too large.

Suppose each layer roughly multiplies the gradient by $1.5$. After 50 layers, the scale becomes

$$
1.5^{50} \approx 6.4 \times 10^8.
$$

Such gradients can cause huge parameter updates. The loss may become unstable, parameters may become `inf` or `nan`, and training may fail.

Exploding gradients are common in recurrent neural networks because the same transition is applied repeatedly across time. A sequence of length 100 behaves like a 100-layer computation graph. If the recurrent dynamics amplify gradients, the gradient can grow rapidly.
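A toy sketch of this effect, using a transition Jacobian that simply scales every component by $1.5$ (the factor, width, and sequence length are illustrative choices):

```python
import torch

hidden, steps = 4, 100

# Toy transition Jacobian that amplifies every component by 1.5,
# standing in for unstable recurrent dynamics.
W = 1.5 * torch.eye(hidden)

grad = torch.ones(hidden)
for _ in range(steps):
    grad = W.T @ grad  # one backward step through time

print(grad.norm())  # grows like 1.5 ** 100, roughly 4e17 times its starting value
```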

### Why Depth Makes the Problem Severe

A shallow network contains only a few multiplications in the backward path. A deep network contains many.

For a simplified scalar model,

$$
h_L = w_L w_{L-1} \cdots w_1 h_0.
$$

The derivative with respect to an early hidden value contains a product of many weights:

$$
\frac{\partial h_L}{\partial h_0} =
w_L w_{L-1} \cdots w_1.
$$

If the weights have magnitude less than 1 on average, the product tends to vanish. If they have magnitude greater than 1 on average, the product tends to explode.

Real networks use matrices and nonlinearities, but the same principle applies. Training stability depends on the singular values of layer Jacobians. If these singular values are usually below 1, gradients shrink. If they are usually above 1, gradients grow.
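For a single linear layer $y = Wx$, the Jacobian is just $W$, so its singular values can be inspected directly. A small sketch (the scale $0.05$ is an arbitrary choice that keeps every singular value below 1):

```python
import torch

torch.manual_seed(0)

W = torch.randn(64, 64) * 0.05
sv = torch.linalg.svdvals(W)
print(sv.max(), sv.min())  # all singular values are below 1 at this scale

# Composing 20 identical layers: the product's largest singular value
# is bounded by the per-layer maximum raised to the 20th power.
J = torch.linalg.matrix_power(W, 20)
print(torch.linalg.svdvals(J).max())
```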

### Symptoms in PyTorch

Vanishing gradients often appear as slow or stalled learning. The loss may still decrease because the last layers adapt, while early layers receive tiny gradients and barely change.

Exploding gradients often appear as sudden loss spikes, `nan` values, or unstable parameter norms.

You can inspect gradient norms during training:

```python
def grad_stats(model):
    for name, param in model.named_parameters():
        if param.grad is None:
            continue

        print(
            name,
            "grad_norm=",
            param.grad.norm().item(),
            "param_norm=",
            param.detach().norm().item(),
        )
```

Use it after `loss.backward()`:

```python
optimizer.zero_grad()
loss = model_loss(model, batch)
loss.backward()

grad_stats(model)

optimizer.step()
```

A typical warning sign is that early-layer gradient norms are much smaller than later-layer gradient norms. For exploding gradients, norms may become extremely large or non-finite.

You can also detect non-finite gradients:

```python
for name, param in model.named_parameters():
    if param.grad is not None:
        if not torch.isfinite(param.grad).all():
            print("Non-finite gradient:", name)
```
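Putting these checks together, a deliberately deep sigmoid MLP makes the symptom easy to reproduce (the depth and width here are arbitrary illustration choices):

```python
import torch

torch.manual_seed(0)

# 20 sigmoid layers: deep enough for vanishing to be obvious.
layers = []
for _ in range(20):
    layers += [torch.nn.Linear(32, 32), torch.nn.Sigmoid()]
model = torch.nn.Sequential(*layers)

loss = model(torch.randn(8, 32)).pow(2).mean()
loss.backward()

first = model[0].weight.grad.norm().item()
last = model[-2].weight.grad.norm().item()
print(f"first layer grad norm: {first:.3e}")
print(f"last layer grad norm:  {last:.3e}")
```

The first layer's gradient norm comes out many orders of magnitude smaller than the last layer's, which is exactly the pattern to look for in `grad_stats` output.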

### Activation Functions and Gradient Flow

Activation functions affect gradient flow through their derivatives.

Sigmoid and tanh can saturate. When their inputs have large magnitude, their derivatives become small. This can make gradients vanish.

ReLU helps because its derivative is 1 for positive inputs:

$$
\operatorname{ReLU}(x) = \max(0,x).
$$

For active units, gradients pass through without shrinkage from the activation derivative. This is one reason ReLU made deep networks easier to train.

However, ReLU has its own failure mode. For negative inputs, the derivative is zero. A unit that outputs zero for most inputs may stop learning. This is sometimes called a dead ReLU.
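The zero-gradient region is visible on a single tensor:

```python
import torch

x = torch.tensor([-2.0, 3.0], requires_grad=True)
torch.relu(x).sum().backward()

print(x.grad)  # tensor([0., 1.]): no gradient reaches the negative input
```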

Leaky ReLU reduces this problem by allowing a small slope for negative inputs:

$$
\operatorname{LeakyReLU}(x) =
\begin{cases}
x, & x > 0, \\
\alpha x, & x \le 0.
\end{cases}
$$

In PyTorch:

```python
relu = torch.nn.ReLU()
leaky_relu = torch.nn.LeakyReLU(negative_slope=0.01)
gelu = torch.nn.GELU()
```

Modern transformer models commonly use GELU or related smooth activations. These work well with residual connections and normalization.

### Initialization and Gradient Scale

Initialization controls the initial scale of activations and gradients. If weights start too small, gradients may vanish. If weights start too large, gradients may explode.

For ReLU networks, Kaiming initialization is usually appropriate:

```python
layer = torch.nn.Linear(256, 256)
torch.nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")
torch.nn.init.zeros_(layer.bias)
```

For tanh-like networks, Xavier initialization is often appropriate:

```python
layer = torch.nn.Linear(256, 256)
torch.nn.init.xavier_uniform_(layer.weight)
torch.nn.init.zeros_(layer.bias)
```

Initialization cannot solve every stability problem, but it sets the network in a reasonable numerical regime before training begins.
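A sketch of why the initial scale matters, comparing the forward activation scale under a deliberately tiny initialization against Kaiming; the depth, width, and `std=0.01` are arbitrary illustration choices:

```python
import torch

torch.manual_seed(0)

def final_activation_std(init):
    """Push noise through 30 ReLU layers and report the output scale."""
    x = torch.randn(256, 256)
    for _ in range(30):
        layer = torch.nn.Linear(256, 256, bias=False)
        if init == "kaiming":
            torch.nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")
        else:
            torch.nn.init.normal_(layer.weight, std=0.01)
        x = torch.relu(layer(x))
    return x.std().item()

print(final_activation_std("small"))    # collapses toward zero
print(final_activation_std("kaiming"))  # stays at a healthy scale
```

By symmetry of the forward and backward passes, activations that collapse in the forward direction signal gradients that will collapse in the backward direction.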

### Normalization Layers

Normalization layers reduce uncontrolled changes in activation scale.

Batch normalization normalizes activations using batch statistics. Layer normalization normalizes activations across the feature dimension for each example. Transformers usually use layer normalization because sequence batches often have variable lengths and because layer normalization works naturally with autoregressive inference.

A simple MLP block with layer normalization:

```python
class MLPBlock(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.LayerNorm(dim),
            torch.nn.Linear(dim, 4 * dim),
            torch.nn.GELU(),
            torch.nn.Linear(4 * dim, dim),
        )

    def forward(self, x):
        return self.net(x)
```

Normalization helps keep activations in a range where gradients remain useful.
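A one-off sketch of what `LayerNorm` does to a badly scaled tensor (the scale and offset here are arbitrary):

```python
import torch

torch.manual_seed(0)

x = 50.0 * torch.randn(4, 64) + 10.0  # activations with drifting scale and offset
y = torch.nn.LayerNorm(64)(x)

print(x.mean().item(), x.std().item())
print(y.mean().item(), y.std().item())  # roughly 0 and 1 after normalization
```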

### Residual Connections

Residual connections give gradients a shorter path through the network.

Instead of learning

$$
h_{l+1} = F(h_l),
$$

a residual block learns

$$
h_{l+1} = h_l + F(h_l).
$$

The identity path allows information and gradients to move through many layers more easily. During backpropagation, gradients can flow through the skip connection even when the residual branch has poor conditioning.
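A toy comparison using an elementwise branch $F(h) = 0.5\tanh(h)$ that contracts its input (the depth and factor are arbitrary illustration choices):

```python
import torch

torch.manual_seed(0)

def input_grad_norm(residual, depth=30, dim=32):
    x = torch.randn(dim, requires_grad=True)
    h = x
    for _ in range(depth):
        branch = 0.5 * torch.tanh(h)  # a contracting residual branch
        h = h + branch if residual else branch
    h.sum().backward()
    return x.grad.norm().item()

print(input_grad_norm(residual=False))  # decays toward zero with depth
print(input_grad_norm(residual=True))   # the skip path keeps the gradient alive
```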

In PyTorch:

```python
class ResidualBlock(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.f = torch.nn.Sequential(
            torch.nn.LayerNorm(dim),
            torch.nn.Linear(dim, 4 * dim),
            torch.nn.GELU(),
            torch.nn.Linear(4 * dim, dim),
        )

    def forward(self, x):
        return x + self.f(x)
```

Residual connections are a central reason very deep CNNs and transformers can be trained reliably.

### Gradient Clipping

Gradient clipping is a direct method for controlling exploding gradients. It limits the norm or value of gradients before the optimizer step.

Norm clipping rescales gradients when their total norm exceeds a threshold:

```python
optimizer.zero_grad()
loss = model_loss(model, batch)
loss.backward()

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

optimizer.step()
```

This is common in RNNs, transformers, reinforcement learning, and other unstable training regimes.

Value clipping clamps each gradient entry:

```python
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0)
```

Norm clipping is usually preferred because it preserves the direction of the full gradient vector while limiting its magnitude.
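A quick check of the rescaling; note that `clip_grad_norm_` returns the total norm measured before clipping:

```python
import torch

torch.manual_seed(0)

layer = torch.nn.Linear(10, 10)
# Scale the loss to force a large gradient.
loss = 1000.0 * layer(torch.randn(4, 10)).pow(2).mean()
loss.backward()

before = torch.nn.utils.clip_grad_norm_(layer.parameters(), max_norm=1.0)
after = torch.sqrt(sum(p.grad.pow(2).sum() for p in layer.parameters()))

print("norm before clipping:", before.item())
print("norm after clipping: ", after.item())  # at most max_norm
```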

### Stable Training Checklist

When gradients vanish or explode, check the following:

| Problem | Common fix |
|---|---|
| Activations shrink across layers | Use better initialization, normalization, residual connections |
| Activations grow across layers | Reduce learning rate, improve initialization, use normalization |
| Gradients vanish in early layers | Use ReLU-like activations, residual connections, normalization |
| Gradients explode | Use gradient clipping, reduce learning rate |
| Loss becomes `nan` | Check learning rate, dtype, invalid operations, gradient norms |
| ReLU units die | Lower learning rate, use Leaky ReLU or GELU |
| RNN training unstable | Use LSTM or GRU, clip gradients, shorten unroll length |

A practical debugging step is to log activation and gradient statistics layer by layer. Most stability problems leave a numerical trace before they fully break training.

### Summary

Vanishing and exploding gradients arise from repeated multiplication by layer derivatives during backpropagation. In deep networks, small deviations in scale can compound across many layers.

Vanishing gradients make early layers learn slowly. Exploding gradients make optimization unstable. Initialization, activation choice, normalization, residual connections, learning rate control, and gradient clipping are the main tools for managing these problems.

Stable deep learning requires preserving useful signal in both the forward pass and the backward pass.

