Batch Normalization

Batch normalization is a layer that normalizes activations using statistics computed from a mini-batch. It was introduced to make deep networks easier to train, especially convolutional and feedforward networks. The basic idea is simple: keep intermediate activations in a controlled numerical range, then let the model learn how much scale and shift it wants.

A neural network layer often produces pre-activations

z = Wx + b.

If the distribution of z changes too much during training, later layers must constantly adapt to new input scales. Batch normalization reduces this problem by normalizing activations before they are passed onward.

Normalizing a Mini-Batch

Suppose a layer produces a batch of activations

X \in \mathbb{R}^{B \times D},

where B is the batch size and D is the feature dimension. For each feature dimension j, batch normalization computes the mini-batch mean

\mu_j = \frac{1}{B}\sum_{i=1}^{B} X_{ij}

and variance

\sigma_j^2 = \frac{1}{B}\sum_{i=1}^{B}(X_{ij}-\mu_j)^2.

It then normalizes each activation:

\hat{X}_{ij} = \frac{X_{ij}-\mu_j}{\sqrt{\sigma_j^2+\epsilon}}.

The small constant \epsilon prevents division by zero.

After normalization, each feature has approximately zero mean and unit variance within the mini-batch.
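The formulas above can be checked directly against PyTorch. A minimal sketch, assuming default settings (where the learned scale starts at 1 and the shift at 0, so the layer output equals the normalized activations):

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(32, 8)  # [B, D]

# Manual per-feature normalization over the batch axis.
mu = x.mean(dim=0)                  # shape [8]
var = x.var(dim=0, unbiased=False)  # biased variance, as in the formula
eps = 1e-5
x_hat = (x - mu) / torch.sqrt(var + eps)

# With default weight=1 and bias=0, BatchNorm1d in training mode
# should produce the same result.
bn = nn.BatchNorm1d(8)
bn.train()
y = bn(x)

print(torch.allclose(x_hat, y, atol=1e-5))  # True
```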

Learned Scale and Shift

Normalization alone may remove useful information. For this reason, batch normalization includes two learnable parameters per feature:

Y_{ij} = \gamma_j \hat{X}_{ij} + \beta_j.

The parameter \gamma_j learns the output scale. The parameter \beta_j learns the output shift.

If the model needs the normalized output unchanged, it can learn \gamma_j = 1 and \beta_j = 0. If it needs a different scale or offset, it can learn that too.

In PyTorch, these parameters are stored as weight and bias:

import torch
from torch import nn

bn = nn.BatchNorm1d(128)

print(bn.weight.shape)  # torch.Size([128])
print(bn.bias.shape)    # torch.Size([128])

By default, weight starts at ones and bias starts at zeros.

BatchNorm1d

BatchNorm1d is commonly used for vectors and sequence-like feature tensors.

For a matrix input with shape

[B, D]

batch normalization computes statistics over the batch axis for each feature.

x = torch.randn(32, 128)

bn = nn.BatchNorm1d(128)
y = bn(x)

print(y.shape)  # torch.Size([32, 128])

For sequence-like data with shape

[B, D, T]

BatchNorm1d normalizes each channel over both the batch and time dimensions.

x = torch.randn(32, 128, 50)

bn = nn.BatchNorm1d(128)
y = bn(x)

print(y.shape)  # torch.Size([32, 128, 50])

Here 128 is the channel or feature dimension.

BatchNorm2d

BatchNorm2d is used in convolutional networks. Its expected input shape is

[B, C, H, W]

where B is the batch size, C is the number of channels, H is the height, and W is the width.

For each channel c, batch normalization computes the mean and variance over the batch and spatial dimensions:

\mu_c = \frac{1}{BHW} \sum_{b=1}^{B} \sum_{h=1}^{H} \sum_{w=1}^{W} X_{bchw}.

In PyTorch:

x = torch.randn(32, 64, 56, 56)

bn = nn.BatchNorm2d(64)
y = bn(x)

print(y.shape)  # torch.Size([32, 64, 56, 56])

BatchNorm2d(64) means there are 64 channels. The layer learns one scale and one shift parameter per channel.
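This can be verified on the output: in training mode, each channel of the result should have roughly zero mean and unit variance over the batch and spatial dimensions. A small sketch with arbitrary sizes:

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(8, 4, 16, 16)  # [B, C, H, W]

bn = nn.BatchNorm2d(4)
bn.train()
y = bn(x)

# Each channel is normalized over the batch and spatial dims,
# so the per-channel mean is ~0 and the variance is ~1.
per_channel_mean = y.mean(dim=(0, 2, 3))
per_channel_var = y.var(dim=(0, 2, 3), unbiased=False)

print(per_channel_mean.abs().max() < 1e-4)       # True
print((per_channel_var - 1).abs().max() < 1e-2)  # True
```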

Training Mode and Evaluation Mode

Batch normalization behaves differently during training and inference.

During training, it uses statistics from the current mini-batch. During evaluation, it uses running estimates accumulated during training.

This distinction is essential.

model.train()
# BatchNorm uses current batch statistics.

model.eval()
# BatchNorm uses running mean and running variance.

The running statistics are stored in buffers:

bn = nn.BatchNorm1d(128)

print(bn.running_mean.shape)  # torch.Size([128])
print(bn.running_var.shape)   # torch.Size([128])

These are not learned by gradient descent. They are updated during training using a moving average.

Running Statistics

During training, PyTorch updates running statistics approximately as

\text{running\_mean} \leftarrow (1-m)\,\text{running\_mean} + m\,\mu_{\text{batch}},

where m is the momentum parameter.

Example:

bn = nn.BatchNorm1d(128, momentum=0.1)

A smaller momentum updates running statistics more slowly. A larger momentum makes them follow recent batches more closely.
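One step of this moving average can be sketched as follows, assuming a freshly initialized layer (whose running mean starts at zero):

```python
import torch
from torch import nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4, momentum=0.1)
x = torch.randn(64, 4) * 2.0 + 3.0  # data with mean ~3

# A fresh layer starts with running_mean = 0 and running_var = 1.
print(bn.running_mean)  # tensor([0., 0., 0., 0.])

bn.train()
_ = bn(x)

# One step of the moving average:
# running_mean <- (1 - 0.1) * 0 + 0.1 * batch_mean
expected = 0.1 * x.mean(dim=0)
print(torch.allclose(bn.running_mean, expected, atol=1e-5))  # True
```

Note that the running variance is updated the same way, but with the unbiased batch variance.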

The running statistics are used when the model is placed in evaluation mode. If these estimates are poor, inference quality may degrade.

This can happen when batch sizes are very small, data distribution changes, or the training loop forgets to call model.train() and model.eval() at the correct times.

Batch Normalization in an MLP

A common MLP pattern is

Linear -> BatchNorm -> ReLU

Example:

class MLP(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),

            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),

            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x):
        return self.net(x)

The batch normalization layer is usually placed after the linear transformation and before the activation function. Some architectures place it after the activation. The first pattern is more common in classical CNN and MLP designs.
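A quick shape check of this pattern, sketched here as a single Sequential with hypothetical sizes:

```python
import torch
from torch import nn

# Linear -> BatchNorm -> ReLU block, with hypothetical dimensions.
net = nn.Sequential(
    nn.Linear(20, 64),
    nn.BatchNorm1d(64),
    nn.ReLU(),
    nn.Linear(64, 10),
)

x = torch.randn(32, 20)  # a batch of 32 examples
y = net(x)
print(y.shape)  # torch.Size([32, 10])
```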

Batch Normalization in a CNN

In convolutional networks, the common block is

Conv2d -> BatchNorm2d -> ReLU

Example:

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.block(x)

The convolution often uses bias=False when followed by batch normalization. The reason is that batch normalization subtracts the per-channel mean, which cancels a constant bias exactly, and the batch normalization layer already has its own learned shift parameter.
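This redundancy can be illustrated directly: adding a constant per-channel offset (what a convolution bias would contribute) does not change the batch-normalized output in training mode. A minimal sketch:

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(8, 16, 32, 32)      # stand-in for a conv output
bias = torch.randn(1, 16, 1, 1)     # a per-channel constant offset

bn = nn.BatchNorm2d(16)
bn.train()

# Batch normalization subtracts the per-channel mean, so a constant
# per-channel offset cancels out.
y1 = bn(x)
y2 = bn(x + bias)

print(torch.allclose(y1, y2, atol=1e-4))  # True
```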

Benefits of Batch Normalization

Batch normalization often improves training in several ways.

It stabilizes activation scale across layers. This helps gradients remain in a useful range.

It allows larger learning rates in many networks. Since activations are controlled, optimization may tolerate more aggressive updates.

It can reduce sensitivity to initialization. Weight initialization still matters, but batch normalization makes the network less fragile.

It adds mild regularization. Because each example is normalized using statistics from the mini-batch, the output for one example depends slightly on other examples in the same batch. This introduces noise during training.

Batch normalization became especially important in deep convolutional networks such as residual networks.

Limitations of Batch Normalization

Batch normalization depends on mini-batch statistics. This creates several limitations.

Small batches produce noisy mean and variance estimates. If the batch size is very small, batch normalization may perform poorly.

Batch normalization can be awkward for variable-length sequence modeling. Padding, masking, and autoregressive inference make batch statistics less natural.

Training and inference use different statistics. A model can behave differently between train() and eval() mode. This is a common source of bugs.

Batch normalization couples examples within a batch. This can be undesirable in some settings, especially when examples should be processed independently.

For these reasons, transformer architectures usually prefer layer normalization over batch normalization.

Common PyTorch Mistakes

A frequent error is evaluating a model without calling model.eval(). The correct inference pattern is:

model.eval()
with torch.no_grad():
    y = model(x)

Without model.eval(), batch normalization keeps using current batch statistics. Predictions may vary depending on which examples appear in the same batch.
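The difference between the two modes is easy to observe on data whose statistics are far from the running estimates. A sketch, assuming a freshly initialized layer:

```python
import torch
from torch import nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)
x = torch.randn(16, 4) + 5.0  # data whose mean is far from 0

bn.train()
y_train = bn(x)   # normalized with this batch's statistics

bn.eval()
y_eval = bn(x)    # normalized with the running statistics

# After only one training step the running stats are still close to
# their initial values (mean 0, var 1), which do not match the data.
print(y_train.mean().abs() < 1e-4)  # True: batch stats center the output
print(y_eval.mean().abs() > 1.0)    # True: running stats do not
```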

Another common error is forgetting to return to training mode after validation:

model.train()

A third error is using batch normalization with batch size 1 during training. With only one example, batch statistics can be poor, especially for fully connected layers.
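For a [B, D] input, PyTorch in fact refuses to run BatchNorm1d in training mode on a single example, since the per-feature variance of one value is undefined. A sketch:

```python
import torch
from torch import nn

bn = nn.BatchNorm1d(8)
bn.train()

x = torch.randn(1, 8)  # a single example

# In training mode, PyTorch raises a ValueError for a batch of one.
try:
    bn(x)
except ValueError:
    print("training-mode BatchNorm1d rejects a batch of one")

# In eval mode the running statistics are used, so a batch of 1 works.
bn.eval()
y = bn(x)
print(y.shape)  # torch.Size([1, 8])
```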

A fourth error is applying BatchNorm1d to the wrong axis. BatchNorm1d expects the channel dimension in position 1, so for an input of shape [B, T, D] it does not treat D as the feature dimension. One common fix is to transpose:

x = torch.randn(32, 100, 128)  # [B, T, D]

bn = nn.BatchNorm1d(128)

y = bn(x.transpose(1, 2)).transpose(1, 2)

print(y.shape)  # torch.Size([32, 100, 128])

Here the tensor is temporarily changed from [B, T, D] to [B, D, T], which matches what BatchNorm1d expects.

Batch Normalization and Dropout

Batch normalization and dropout can be used together, but they interact.

Batch normalization already adds noise through mini-batch statistics. Dropout adds additional noise by randomly zeroing activations. In some networks, using both can help. In others, heavy dropout after batch normalization can hurt.

A common practical rule is to start with batch normalization and weight decay. Add dropout only if validation performance suggests overfitting.

For CNNs, dropout is often less important when batch normalization and data augmentation are used. For MLPs on tabular data, dropout may still be useful.

Summary

Batch normalization normalizes activations using mini-batch mean and variance. It then applies learned scale and shift parameters. In PyTorch, BatchNorm1d is used for vector and sequence-like feature tensors, while BatchNorm2d is used for convolutional feature maps.

Batch normalization often improves optimization, stabilizes training, and reduces sensitivity to initialization. Its main weakness is dependence on batch statistics. Small batches, train/eval mode mistakes, and sequence-modeling settings can make it less suitable.

The essential PyTorch rule is simple: use model.train() during training, use model.eval() during inference, and make sure the normalized feature axis matches the layer you selected.