Stochastic depth regularizes deep residual networks by randomly skipping residual branches during training. It is also commonly called drop path. Unlike ordinary dropout, which drops individual activations, stochastic depth drops a whole computation path.
A standard residual block computes

$$y = x + F(x),$$

where $x$ is the input and $F(x)$ is the residual branch. With stochastic depth, the branch is randomly kept or removed:

$$y = x + b \cdot F(x),$$

where $b \in \{0, 1\}$ is sampled from a Bernoulli distribution.
Residual Branches
Residual architectures are built around the idea that a layer should learn a correction to its input rather than a completely new representation.
The block

$$y = x + F(x)$$

contains two paths. The identity path sends $x$ forward unchanged. The residual path computes $F(x)$. If the residual path is useful, it modifies the representation. If not, the block can behave close to the identity function.
This structure makes residual networks suitable for stochastic depth. The model can skip $F(x)$ while preserving a valid output through the identity path.
Training Rule
Let $p_\ell$ be the survival probability for block $\ell$. During training, sample

$$b_\ell \sim \mathrm{Bernoulli}(p_\ell).$$

The block output is

$$y = x + \frac{b_\ell}{p_\ell}\, F(x).$$

The scaling factor $1/p_\ell$ keeps the expected residual contribution unchanged:

$$\mathbb{E}\!\left[\frac{b_\ell}{p_\ell}\, F(x)\right] = F(x).$$

At inference time, all branches are used:

$$y = x + F(x).$$
So stochastic depth increases training noise but adds no inference cost.
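The unbiasedness of the scaling can be checked numerically. The sketch below (the survival probability and tensor sizes are arbitrary choices) averages the scaled Bernoulli mask over many draws and recovers the unscaled branch output:

```python
import torch

torch.manual_seed(0)

p = 0.8                          # survival probability
branch_out = torch.randn(4)      # stand-in for F(x)

# Average (b / p) * F(x) over many independent Bernoulli draws of b.
masks = torch.bernoulli(torch.full((100_000, 1), p))
estimate = (masks / p * branch_out).mean(dim=0)

print(branch_out)  # the original branch output
print(estimate)    # close to the values above: E[(b / p) * F(x)] = F(x)
```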
Why It Regularizes
Stochastic depth trains many subnetworks inside one large model. In each training step, a different subset of residual branches is active. The model cannot rely on every block always being present.
This discourages fragile co-adaptation between neighboring layers. It also forces later blocks to handle representations that may have passed through different computational paths.
In very deep networks, stochastic depth can also make optimization easier. Some updates pass through shallower effective networks, giving gradients shorter paths from output to earlier layers.
Layer-Dependent Drop Rates
Stochastic depth is usually stronger in deeper layers than in earlier layers. Early layers often learn basic features, so dropping them too often can harm training.
A common schedule increases drop probability linearly with depth:

$$p_\ell^{\mathrm{drop}} = \frac{\ell}{L - 1}\, p_{\max},$$

where $L$ is the number of residual blocks, $\ell \in \{0, \dots, L-1\}$ is the block index, and $p_{\max}$ is the maximum drop probability.

The survival probability is

$$p_\ell = 1 - p_\ell^{\mathrm{drop}}.$$

For example, if a model has 10 residual blocks and $p_{\max} = 0.2$, the last block has drop probability 0.2, while early blocks have smaller drop probabilities.
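A quick worked version of this example, computed directly from the formula above:

```python
num_blocks = 10
max_drop = 0.2

drop_probs = [max_drop * i / (num_blocks - 1) for i in range(num_blocks)]
# [0.0, 0.022, 0.044, 0.067, 0.089, 0.111, 0.133, 0.156, 0.178, 0.2]  (rounded)
```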
PyTorch Implementation
A minimal implementation drops whole residual branches per sample:
```python
import torch
from torch import nn


class DropPath(nn.Module):
    """Randomly drops the whole residual branch, one decision per sample."""

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Identity during evaluation, or when dropping is disabled.
        if self.drop_prob == 0.0 or not self.training:
            return x
        keep_prob = 1.0 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over all other dimensions.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = torch.empty(shape, dtype=x.dtype, device=x.device)
        mask.bernoulli_(keep_prob)
        # Divide by keep_prob so the expected output matches the branch output.
        return x * mask / keep_prob
```

The mask shape is important. For an input of shape [B, C, H, W], the mask has shape [B, 1, 1, 1]. This means each sample keeps or drops the whole branch, rather than dropping independent pixels.
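A quick sanity check of this per-sample behavior (the batch size and drop probability are arbitrary, and the printed values vary from run to run):

```python
dp = DropPath(drop_prob=0.5)
dp.train()

x = torch.ones(4, 3, 8, 8)              # batch of 4 samples
out = dp(x)

# Each sample is either zeroed out entirely (dropped) or scaled by
# 1 / keep_prob = 2.0 (kept); pixels within a sample are never mixed.
print(out.reshape(4, -1).amax(dim=1))   # e.g. tensor([2., 0., 2., 2.])
```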
A residual MLP block can use it as follows:
```python
class ResidualMLPBlock(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, drop_prob: float):
        super().__init__()
        # Pre-norm residual branch: LayerNorm -> Linear -> GELU -> Linear.
        self.branch = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )
        self.drop_path = DropPath(drop_prob)

    def forward(self, x):
        # Only the branch passes through DropPath; the identity path is untouched.
        return x + self.drop_path(self.branch(x))
```

During training, the residual branch is randomly removed. During evaluation, DropPath returns the branch unchanged.
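A short usage check (the dimensions are illustrative): in evaluation mode the output is deterministic, while in training mode the branch is skipped for roughly half of the samples at drop_prob=0.5.

```python
block = ResidualMLPBlock(dim=16, hidden_dim=32, drop_prob=0.5)
x = torch.randn(8, 16)

block.eval()
y_eval = block(x)    # branch always applied, no masking or scaling

block.train()
y_train = block(x)   # for dropped samples, the output equals the input x
```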
Stochastic Depth in Vision Transformers
Vision transformers use residual branches around both attention and MLP sublayers:

$$x \leftarrow x + \mathrm{Attn}(\mathrm{LN}(x)), \qquad x \leftarrow x + \mathrm{MLP}(\mathrm{LN}(x)).$$
Stochastic depth can be inserted on each residual branch:
```python
class TransformerBlock(nn.Module):
    def __init__(self, dim, attention, mlp, drop_prob):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attention = attention
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = mlp
        # Independent DropPath modules for the attention and MLP branches.
        self.drop_path1 = DropPath(drop_prob)
        self.drop_path2 = DropPath(drop_prob)

    def forward(self, x):
        x = x + self.drop_path1(self.attention(self.norm1(x)))
        x = x + self.drop_path2(self.mlp(self.norm2(x)))
        return x
```

This pattern is common in ViT, Swin Transformer, ConvNeXt, and similar residual architectures.
Stochastic Depth in Convolutional Networks
In convolutional residual networks, stochastic depth is applied to convolutional residual branches.
```python
class ConvResidualBlock(nn.Module):
    def __init__(self, channels: int, drop_prob: float):
        super().__init__()
        # Pre-activation residual branch: BN -> ReLU -> Conv, twice.
        self.branch = nn.Sequential(
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
        )
        self.drop_path = DropPath(drop_prob)

    def forward(self, x):
        return x + self.drop_path(self.branch(x))
```

This direct form assumes that input and output shapes match. If the block changes spatial resolution or channel count, the identity path must include a projection. In such cases, care is needed so that dropping the residual branch still leaves a valid tensor shape. One way to handle this is sketched below.
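In this sketch (the class name and layer arrangement are illustrative, not taken from a specific architecture), a 1x1 convolution projection stays on the identity path, so the sum remains shape-consistent even when the residual branch is dropped:

```python
class ProjectedConvResidualBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, stride: int, drop_prob: float):
        super().__init__()
        self.branch = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        )
        # The projection is part of the identity path and is never dropped,
        # so the output shape is valid even when the branch is skipped.
        self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
        self.drop_path = DropPath(drop_prob)

    def forward(self, x):
        return self.shortcut(x) + self.drop_path(self.branch(x))
```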
Built-In Utilities
The PyTorch ecosystem provides stochastic depth implementations. In torchvision:
```python
from torchvision.ops import StochasticDepth

drop_path = StochasticDepth(p=0.1, mode="row")
```

The argument p is the drop probability. The mode controls how the random mask is sampled.
| Mode | Meaning |
|---|---|
| "row" | Drop independently for each sample |
| "batch" | Drop the branch for the whole batch |
Most modern vision models use per-sample dropping, corresponding to "row".
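As a sketch (the block structure here is illustrative), torchvision's StochasticDepth can stand in for the custom DropPath module defined earlier:

```python
from torch import nn
from torchvision.ops import StochasticDepth


class TorchvisionResidualBlock(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, drop_prob: float):
        super().__init__()
        self.branch = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )
        # "row" drops the branch independently for each sample.
        self.drop_path = StochasticDepth(p=drop_prob, mode="row")

    def forward(self, x):
        return x + self.drop_path(self.branch(x))
```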
Choosing Drop Probability
The drop probability should increase with model depth and regularization need.
| Setting | Typical maximum drop probability |
|---|---|
| Small model | 0.0 to 0.1 |
| Medium model | 0.1 to 0.2 |
| Large model | 0.2 to 0.4 |
| Very large vision model | 0.4 or higher, with validation |
Too little stochastic depth may have no visible effect. Too much can cause underfitting or unstable optimization.
A practical schedule:
```python
def drop_path_rate(block_index, num_blocks, max_drop):
    # Linear schedule: 0 for the first block, max_drop for the last.
    if num_blocks <= 1:
        return 0.0
    return max_drop * block_index / (num_blocks - 1)
```

This gives the first block a drop probability of 0 and the last block a drop probability of max_drop.
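A sketch of how the schedule is usually wired in (the block count and dimensions are illustrative), giving each block its own rate at construction time:

```python
num_blocks = 12
max_drop = 0.2

blocks = nn.ModuleList(
    ResidualMLPBlock(
        dim=256,
        hidden_dim=1024,
        drop_prob=drop_path_rate(i, num_blocks, max_drop),
    )
    for i in range(num_blocks)
)
```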
Relationship to Dropout
Dropout and stochastic depth inject different kinds of noise.
| Method | What is dropped | Granularity |
|---|---|---|
| Dropout | Activation entries | Element-level |
| Dropout2d | Feature channels | Channel-level |
| Stochastic depth | Residual branches | Block-level |
Dropout is useful in fully connected layers, attention projections, and classifier heads. Stochastic depth is useful when the architecture has residual branches.
They can be used together, but the total regularization strength must be controlled. A model with heavy augmentation, label smoothing, weight decay, dropout, and stochastic depth may underfit.
Failure Modes
Stochastic depth has several common failure modes.
If early layers are dropped too often, the model may fail to learn stable low-level representations. Use a depth-dependent schedule that starts near zero.
If the drop probability is too high, the network may underfit. Training loss remains high, and validation performance does not improve.
If the branch changes shape, naive dropping can cause shape errors. Stochastic depth is easiest when the branch output has the same shape as the identity input.
If validation is run in training mode, stochastic depth remains active and evaluation becomes noisy. Always call:

```python
model.eval()
```

before validation or inference.
Practical Guidance
Use stochastic depth for residual networks, vision transformers, ConvNeXt-style architectures, and other models with explicit skip connections.
Start with a small maximum drop probability such as 0.1. Increase it only when validation performance improves.
Use lower drop rates for shallow models and higher rates for very deep models.
Apply stochastic depth during training only. It should disappear at inference time.
Summary
Stochastic depth randomly skips residual branches during training. It trains a family of shallower subnetworks inside one deep model and reduces co-adaptation between layers.
It is especially useful in modern residual architectures such as ResNets, ConvNeXt models, and vision transformers.
At inference time, all branches are active, so stochastic depth improves regularization without adding inference cost.