Stochastic Depth

Stochastic depth regularizes deep residual networks by randomly skipping residual branches during training. It is also commonly called drop path. Unlike ordinary dropout, which drops individual activations, stochastic depth drops a whole computation path.

A standard residual block computes

$$y = x + F(x),$$

where $x$ is the input and $F(x)$ is the residual branch. With stochastic depth, the branch is randomly kept or removed:

$$y = x + bF(x),$$

where $b \in \{0, 1\}$ is sampled from a Bernoulli distribution.

Residual Branches

Residual architectures are built around the idea that a layer should learn a correction to its input rather than a completely new representation.

The block

$$y = x + F(x)$$

contains two paths. The identity path sends $x$ forward unchanged. The residual path computes $F(x)$. If the residual path is useful, it modifies the representation. If not, the block can behave close to the identity function.

This structure makes residual networks suitable for stochastic depth. The model can skip $F(x)$ while preserving a valid output through the identity path.

Training Rule

Let $p_l$ be the survival probability for block $l$. During training, sample

$$b_l \sim \mathrm{Bernoulli}(p_l).$$

The block output is

$$y = x + \frac{b_l}{p_l}F_l(x).$$

The scaling factor $1/p_l$ keeps the expected residual contribution unchanged:

$$\mathbb{E}\left[\frac{b_l}{p_l}F_l(x)\right] = F_l(x).$$

At inference time, all branches are used:

$$y = x + F_l(x).$$

So stochastic depth increases training noise but adds no inference cost.
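
The expectation identity above can be checked numerically. The following is a minimal sketch using only the standard library, assuming a scalar stand-in for the branch output; the helper names `stochastic_branch` and `mean_branch_output` are hypothetical and chosen for illustration.

```python
import random


def stochastic_branch(f_value: float, keep_prob: float, rng: random.Random) -> float:
    """Return (b / p) * F for a single draw b ~ Bernoulli(p)."""
    b = 1.0 if rng.random() < keep_prob else 0.0
    return b / keep_prob * f_value


def mean_branch_output(f_value: float, keep_prob: float,
                       num_draws: int, seed: int = 0) -> float:
    """Average the scaled branch output over many independent draws."""
    rng = random.Random(seed)
    total = sum(stochastic_branch(f_value, keep_prob, rng) for _ in range(num_draws))
    return total / num_draws


# With F = 3.0 and p = 0.8, each draw is either 0.0 or 3.0 / 0.8,
# and the long-run average should be close to F itself.
estimate = mean_branch_output(f_value=3.0, keep_prob=0.8, num_draws=200_000)
print(f"empirical mean: {estimate:.3f}")
```

Without the division by `keep_prob`, the same average would settle near `keep_prob * f_value`, which is why an unscaled variant needs a matching correction at inference time.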

Why It Regularizes

Stochastic depth trains many subnetworks inside one large model. In each training step, a different subset of residual branches is active. The model cannot rely on every block always being present.

This discourages fragile co-adaptation between neighboring layers. It also forces later blocks to handle representations that may have passed through different computational paths.

In very deep networks, stochastic depth can also make optimization easier. Some updates pass through shallower effective networks, giving gradients shorter paths from output to earlier layers.

Layer-Dependent Drop Rates

Stochastic depth is usually stronger in deeper layers than in earlier layers. Early layers often learn basic features, so dropping them too often can harm training.

A common schedule increases drop probability linearly with depth:

$$d_l = d_{\max}\frac{l}{L},$$

where $L$ is the number of residual blocks, $l \in \{1, \dots, L\}$ is the block index, and $d_{\max}$ is the maximum drop probability.

The survival probability is

$$p_l = 1 - d_l.$$

For example, if a model has 10 residual blocks and $d_{\max} = 0.2$, then $d_l = 0.02\,l$: the first block has drop probability 0.02, the fifth 0.1, and the last 0.2.
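
The linear schedule can be tabulated in a few lines. This sketch assumes the one-based block index used in the formula above; `linear_drop_probs` is a hypothetical helper name. It also reports the expected number of active blocks per training step, which by linearity of expectation is the sum of the survival probabilities.

```python
def linear_drop_probs(num_blocks: int, max_drop: float) -> list[float]:
    """d_l = max_drop * l / L for the one-based block index l = 1..L."""
    return [max_drop * l / num_blocks for l in range(1, num_blocks + 1)]


drop_probs = linear_drop_probs(num_blocks=10, max_drop=0.2)
survival_probs = [1.0 - d for d in drop_probs]

for l, (d, p) in enumerate(zip(drop_probs, survival_probs), start=1):
    print(f"block {l:2d}: drop={d:.2f}  survival={p:.2f}")

# Expected number of residual branches that are active in one training step.
expected_active_blocks = sum(survival_probs)
print(f"expected active blocks: {expected_active_blocks:.2f}")  # 8.90
```

So a 10-block model with a maximum drop probability of 0.2 trains, on average, as if it had about 8.9 active residual branches per sample.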

PyTorch Implementation

A minimal implementation drops whole residual branches per sample:

import torch
from torch import nn

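class DropPath(nn.Module):
    """Drop a whole residual branch per sample (stochastic depth)."""

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        # drop_prob == 1.0 would divide by zero in forward().
        if not 0.0 <= drop_prob < 1.0:
            raise ValueError(f"drop_prob must be in [0, 1), got {drop_prob}")
        self.drop_prob = drop_prob

    def forward(self, x):
        # Identity at inference time, or when nothing is dropped.
        if self.drop_prob == 0.0 or not self.training:
            return x

        keep_prob = 1.0 - self.drop_prob

        # One Bernoulli draw per sample, broadcast over all other dimensions.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = torch.empty(shape, dtype=x.dtype, device=x.device)
        mask.bernoulli_(keep_prob)

        # Rescale by 1 / keep_prob so the expected output equals x.
        return x * mask / keep_prob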

The mask shape is important. For an input of shape [B, C, H, W], the mask has shape [B, 1, 1, 1]. This means each sample keeps or drops the whole branch, rather than dropping independent pixels.
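
The per-sample behaviour can be seen directly on a tensor of ones. The sketch below uses a functional stand-in named `drop_path` (a hypothetical helper that mirrors the masking logic above, without the training-mode check) and prints the distinct values found in each sample.

```python
import torch


def drop_path(x: torch.Tensor, drop_prob: float) -> torch.Tensor:
    """Functional per-sample drop path; always active, for illustration only."""
    keep_prob = 1.0 - drop_prob
    # Shape [B, 1, ..., 1]: one keep/drop decision per sample.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = torch.empty(shape, dtype=x.dtype, device=x.device).bernoulli_(keep_prob)
    return x * mask / keep_prob


torch.manual_seed(0)
x = torch.ones(8, 3, 4, 4)          # [B, C, H, W]
y = drop_path(x, drop_prob=0.5)

# Flatten everything except the batch dimension and inspect each sample.
per_sample = y.flatten(1)
for i, row in enumerate(per_sample):
    # Each sample holds a single value: 0.0 (dropped) or 2.0 (kept and rescaled).
    print(i, row.unique().tolist())
```

If the mask had the full shape `[B, C, H, W]` instead, each sample would contain a mixture of zeros and rescaled values, which is element-wise dropout rather than stochastic depth.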

A residual MLP block can use it as follows:

class ResidualMLPBlock(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, drop_prob: float):
        super().__init__()

        self.branch = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

        self.drop_path = DropPath(drop_prob)

    def forward(self, x):
        return x + self.drop_path(self.branch(x))

During training, the residual branch is randomly removed for each sample. During evaluation, DropPath returns the branch output unchanged.
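
The difference between the two modes can be checked on a small block. This is a sketch under the assumption that the `DropPath` and `ResidualMLPBlock` definitions match the ones above (repeated here in compact form so the example runs on its own); the dimensions and the drop probability of 0.5 are arbitrary illustration values.

```python
import torch
from torch import nn


class DropPath(nn.Module):
    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0.0 or not self.training:
            return x
        keep_prob = 1.0 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = torch.empty(shape, dtype=x.dtype, device=x.device).bernoulli_(keep_prob)
        return x * mask / keep_prob


class ResidualMLPBlock(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, drop_prob: float):
        super().__init__()
        self.branch = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )
        self.drop_path = DropPath(drop_prob)

    def forward(self, x):
        return x + self.drop_path(self.branch(x))


torch.manual_seed(0)
block = ResidualMLPBlock(dim=16, hidden_dim=32, drop_prob=0.5)
x = torch.randn(4, 16)

# Evaluation mode: the branch is always applied, with no rescaling.
block.eval()
with torch.no_grad():
    eval_out = block(x)
    reference = x + block.branch(x)
print("eval matches x + F(x):", torch.allclose(eval_out, reference))

# Training mode: each sample's residual is either 0 or F(x) / keep_prob.
block.train()
with torch.no_grad():
    residual = block(x) - x
    full = block.branch(x)
for i in range(x.shape[0]):
    dropped = torch.allclose(residual[i], torch.zeros_like(residual[i]), atol=1e-5)
    print(f"sample {i}: {'dropped' if dropped else 'kept'}")
```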

Stochastic Depth in Vision Transformers

Vision transformers use residual branches around both attention and MLP sublayers:

$$x \leftarrow x + \operatorname{Attention}(\operatorname{Norm}(x)),$$

$$x \leftarrow x + \operatorname{MLP}(\operatorname{Norm}(x)).$$

Stochastic depth can be inserted on each residual branch:

class TransformerBlock(nn.Module):
    def __init__(self, dim, attention, mlp, drop_prob):
        super().__init__()

        self.norm1 = nn.LayerNorm(dim)
        self.attention = attention
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = mlp

        self.drop_path1 = DropPath(drop_prob)
        self.drop_path2 = DropPath(drop_prob)

    def forward(self, x):
        x = x + self.drop_path1(self.attention(self.norm1(x)))
        x = x + self.drop_path2(self.mlp(self.norm2(x)))
        return x

This pattern is common in ViT, Swin Transformer, ConvNeXt, and similar residual architectures.

Stochastic Depth in Convolutional Networks

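In convolutional residual networks, stochastic depth is applied to the convolutional residual branch. The example below uses a pre-activation block (normalization and activation before each convolution):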

class ConvResidualBlock(nn.Module):
    def __init__(self, channels: int, drop_prob: float):
        super().__init__()

        self.branch = nn.Sequential(
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
        )

        self.drop_path = DropPath(drop_prob)

    def forward(self, x):
        return x + self.drop_path(self.branch(x))

This direct form assumes that input and output shapes match. If the block changes spatial resolution or channel count, the identity path must include a projection. In such cases, care is needed so that dropping the residual branch still leaves a valid tensor shape.
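
One way to handle a shape change is to put a projection on the identity path and apply drop path to the residual branch only. The sketch below assumes a strided 1x1 convolution as the projection; the class name `DownsampleResidualBlock` and its parameters are hypothetical. Because the shortcut is never dropped, the output shape is valid whether or not the branch survives.

```python
import torch
from torch import nn


class DropPath(nn.Module):
    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0.0 or not self.training:
            return x
        keep_prob = 1.0 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = torch.empty(shape, dtype=x.dtype, device=x.device).bernoulli_(keep_prob)
        return x * mask / keep_prob


class DownsampleResidualBlock(nn.Module):
    """Residual block that changes channel count and spatial resolution."""

    def __init__(self, in_channels: int, out_channels: int, stride: int, drop_prob: float):
        super().__init__()
        self.branch = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        )
        # Projection shortcut: matches the branch's output shape and is always kept.
        self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
        self.drop_path = DropPath(drop_prob)

    def forward(self, x):
        # Only the residual branch is subject to stochastic depth.
        return self.shortcut(x) + self.drop_path(self.branch(x))


torch.manual_seed(0)
block = DownsampleResidualBlock(in_channels=16, out_channels=32, stride=2, drop_prob=0.2)
block.train()
out = block(torch.randn(2, 16, 32, 32))
print(tuple(out.shape))  # (2, 32, 16, 16)
```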

Built-In Utilities

The PyTorch ecosystem provides stochastic depth implementations. In torchvision:

from torchvision.ops import StochasticDepth

drop_path = StochasticDepth(p=0.1, mode="row")

The argument p is the drop probability. The mode controls how the random mask is sampled.

| Mode | Meaning |
| --- | --- |
| "row" | Drop independently for each sample |
| "batch" | Drop the branch for the whole batch |

Most modern vision models use per-sample dropping, corresponding to "row".
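
The difference between the two modes comes down to the shape of the random mask. The following is a hand-written sketch of that idea, not torchvision's implementation; the function name `stochastic_depth_mask` is hypothetical.

```python
import torch


def stochastic_depth_mask(x: torch.Tensor, drop_prob: float, mode: str) -> torch.Tensor:
    """Sample a keep/drop mask that broadcasts against x."""
    keep_prob = 1.0 - drop_prob
    if mode == "row":
        # One decision per sample: shape [B, 1, ..., 1].
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    elif mode == "batch":
        # A single decision shared by the whole batch: shape [1, 1, ..., 1].
        shape = (1,) * x.ndim
    else:
        raise ValueError(f"mode must be 'row' or 'batch', got {mode!r}")
    return torch.empty(shape, dtype=x.dtype, device=x.device).bernoulli_(keep_prob)


x = torch.ones(4, 3, 8, 8)
row_mask = stochastic_depth_mask(x, drop_prob=0.1, mode="row")
batch_mask = stochastic_depth_mask(x, drop_prob=0.1, mode="batch")

print(tuple(row_mask.shape))    # (4, 1, 1, 1)
print(tuple(batch_mask.shape))  # (1, 1, 1, 1)
```

With per-sample masks, different samples in one batch see different subnetworks in the same step, whereas a batch-level mask removes or keeps the branch for every sample at once.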

Choosing Drop Probability

The drop probability should increase with model depth and regularization need.

| Setting | Typical maximum drop probability |
| --- | --- |
| Small model | 0.0 to 0.1 |
| Medium model | 0.1 to 0.2 |
| Large model | 0.2 to 0.4 |
| Very large vision model | 0.4 or higher, with validation |

Too little stochastic depth may have no visible effect. Too much can cause underfitting or unstable optimization.

A practical schedule:

def drop_path_rate(block_index, num_blocks, max_drop):
    if num_blocks <= 1:
        return 0.0
    return max_drop * block_index / (num_blocks - 1)

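With a zero-based block_index, this gives the first block a drop probability of exactly 0 and the last block exactly max_drop. It differs slightly from the one-based formula in the section on layer-dependent drop rates, where the first block already has a small non-zero drop probability.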

Relationship to Dropout

Dropout and stochastic depth inject different kinds of noise.

| Method | What is dropped | Granularity |
| --- | --- | --- |
| Dropout | Activation entries | Element-level |
| Dropout2d | Feature channels | Channel-level |
| Stochastic depth | Residual branches | Block-level |

Dropout is useful in fully connected layers, attention projections, and classifier heads. Stochastic depth is useful when the architecture has residual branches.

They can be used together, but the total regularization strength must be controlled. A model with heavy augmentation, label smoothing, weight decay, dropout, and stochastic depth may underfit.

Failure Modes

Stochastic depth has several common failure modes.

If early layers are dropped too often, the model may fail to learn stable low-level representations. Use a depth-dependent schedule that starts near zero.

If the drop probability is too high, the network may underfit. Training loss remains high, and validation performance does not improve.

If the branch changes shape, naive dropping can cause shape errors. Stochastic depth is easiest when the branch output has the same shape as the identity input.

If validation is run in training mode, stochastic depth remains active and evaluation becomes noisy. Always call:

model.eval()

before validation or inference.

Practical Guidance

Use stochastic depth for residual networks, vision transformers, ConvNeXt-style architectures, and other models with explicit skip connections.

Start with a small maximum drop probability such as 0.1. Increase it only when validation performance improves.

Use lower drop rates for shallow models and higher rates for very deep models.

Apply stochastic depth during training only. It should disappear at inference time.

Summary

Stochastic depth randomly skips residual branches during training. It trains a family of shallower subnetworks inside one deep model and reduces co-adaptation between layers.

It is especially useful in modern residual architectures such as ResNets, ConvNeXt models, and vision transformers.

At inference time, all branches are active, so stochastic depth improves regularization without adding inference cost.