# Batch Normalization

Batch normalization is a layer that normalizes activations using statistics computed from a mini-batch. It was introduced to make deep networks easier to train, especially convolutional and feedforward networks. The basic idea is simple: keep intermediate activations in a controlled numerical range, then let the model learn how much scale and shift it wants.

A neural network layer often produces pre-activations

$$
z = Wx + b.
$$

If the distribution of $z$ changes too much during training, later layers must constantly adapt to new input scales. Batch normalization reduces this problem by normalizing activations before they are passed onward.

### Normalizing a Mini-Batch

Suppose a layer produces a batch of activations

$$
X \in \mathbb{R}^{B \times D},
$$

where $B$ is the batch size and $D$ is the feature dimension. For each feature dimension $j$, batch normalization computes the mini-batch mean

$$
\mu_j = \frac{1}{B}\sum_{i=1}^{B} X_{ij}
$$

and variance

$$
\sigma_j^2 = \frac{1}{B}\sum_{i=1}^{B}(X_{ij}-\mu_j)^2.
$$

It then normalizes each activation:

$$
\hat{X}_{ij} = \frac{X_{ij}-\mu_j}{\sqrt{\sigma_j^2+\epsilon}}.
$$

The small constant $\epsilon$ prevents division by zero.

After normalization, each feature has approximately zero mean and unit variance within the mini-batch.
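The three formulas above can be checked by hand. The following is a minimal sketch, assuming a small random batch (the sizes, the seed, and the variable names are illustrative choices, not part of any API). It also compares the result with `nn.BatchNorm1d` in training mode, whose learnable scale and shift start at 1 and 0 and therefore do not alter the normalized values.

```python
import torch
from torch import nn

torch.manual_seed(0)

B, D = 32, 4
x = torch.randn(B, D) * 3.0 + 5.0  # features with mean near 5 and std near 3
eps = 1e-5

# Per-feature statistics over the batch axis (biased variance, divisor B).
mu = x.mean(dim=0)
var = x.var(dim=0, unbiased=False)

# Normalize every activation with its feature's statistics.
x_hat = (x - mu) / torch.sqrt(var + eps)

print(x_hat.mean(dim=0))                 # close to 0 for every feature
print(x_hat.var(dim=0, unbiased=False))  # close to 1 for every feature

# Cross-check against the PyTorch layer in training mode.
bn = nn.BatchNorm1d(D, eps=eps)
bn.train()
matches = torch.allclose(bn(x), x_hat, atol=1e-4)
print(matches)
```

The variance is slightly below 1 rather than exactly 1 because $\epsilon$ is added inside the square root.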

### Learned Scale and Shift

Normalization alone may remove useful information. For this reason, batch normalization includes two learnable parameters per feature:

$$
Y_{ij} = \gamma_j \hat{X}_{ij} + \beta_j.
$$

The parameter $\gamma_j$ learns the output scale. The parameter $\beta_j$ learns the output shift.

If the model needs the normalized output unchanged, it can learn $\gamma_j = 1$ and $\beta_j = 0$. If it needs a different scale or offset, it can learn that too.
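To make the role of $\gamma$ and $\beta$ concrete, here is a small sketch with hand-picked values rather than learned ones (an assumption made for illustration). It shows that $\gamma_j = 1$, $\beta_j = 0$ leaves the normalized activations unchanged, while $\gamma_j = \sqrt{\sigma_j^2 + \epsilon}$, $\beta_j = \mu_j$ undoes the normalization entirely.

```python
import torch

torch.manual_seed(0)

x = torch.randn(16, 3) * 2.0 + 1.0
eps = 1e-5

mu = x.mean(dim=0)
var = x.var(dim=0, unbiased=False)
x_hat = (x - mu) / torch.sqrt(var + eps)

# gamma = 1, beta = 0: the output is the normalized activation itself.
y_identity = 1.0 * x_hat + 0.0

# gamma = sqrt(var + eps), beta = mu: the affine step undoes the normalization.
gamma = torch.sqrt(var + eps)
beta = mu
y_restored = gamma * x_hat + beta

unchanged = torch.equal(y_identity, x_hat)
restored = torch.allclose(y_restored, x, atol=1e-5)
print(unchanged, restored)
```

So the affine step does not restrict what the layer can represent; it only changes how the scale and offset are parameterized.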

In PyTorch, these parameters are stored as `weight` and `bias`:

```python
import torch
from torch import nn

bn = nn.BatchNorm1d(128)

print(bn.weight.shape)  # torch.Size([128])
print(bn.bias.shape)    # torch.Size([128])
```

By default, `weight` starts at ones and `bias` starts at zeros.

### BatchNorm1d

`BatchNorm1d` is commonly used for vectors and sequence-like feature tensors.

For a matrix input with shape

```python
[B, D]
```

batch normalization computes statistics over the batch axis for each feature.

```python
x = torch.randn(32, 128)

bn = nn.BatchNorm1d(128)
y = bn(x)

print(y.shape)  # torch.Size([32, 128])
```

For sequence-like data with shape

```python
[B, D, T]
```

`BatchNorm1d` normalizes each channel over both the batch and time dimensions.

```python
x = torch.randn(32, 128, 50)

bn = nn.BatchNorm1d(128)
y = bn(x)

print(y.shape)  # torch.Size([32, 128, 50])
```

Here `128` is the channel or feature dimension.
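As a rough check of which axes are pooled, the sketch below recomputes the normalization by hand over dimensions 0 and 2 and compares it with the layer output in training mode. The tolerance and the seed are arbitrary choices.

```python
import torch
from torch import nn

torch.manual_seed(0)

x = torch.randn(32, 128, 50)  # [B, D, T]
bn = nn.BatchNorm1d(128)
bn.train()
y = bn(x)

# One mean and one biased variance per channel, pooled over batch and time.
mu = x.mean(dim=(0, 2), keepdim=True)                  # shape [1, 128, 1]
var = x.var(dim=(0, 2), unbiased=False, keepdim=True)  # shape [1, 128, 1]
y_manual = (x - mu) / torch.sqrt(var + bn.eps)

same = torch.allclose(y, y_manual, atol=1e-4)
print(mu.shape)
print(same)
```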

### BatchNorm2d

`BatchNorm2d` is used in convolutional networks. Its expected input shape is

```python
[B, C, H, W]
```

where $B$ is batch size, $C$ is number of channels, $H$ is height, and $W$ is width.

For each channel $c$, batch normalization computes mean and variance over the batch and spatial dimensions:

$$
\mu_c =
\frac{1}{BHW}
\sum_{b=1}^{B}
\sum_{h=1}^{H}
\sum_{w=1}^{W}
X_{bchw}.
$$

In PyTorch:

```python
x = torch.randn(32, 64, 56, 56)

bn = nn.BatchNorm2d(64)
y = bn(x)

print(y.shape)  # torch.Size([32, 64, 56, 56])
```

`BatchNorm2d(64)` means there are 64 channels. The layer learns one scale and one shift parameter per channel.
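The per-channel pooling can be verified in the same manual way. This is a sketch under the assumption of a smaller random input than the one above, chosen only to keep the check quick; the variance uses the same axes as the mean.

```python
import torch
from torch import nn

torch.manual_seed(0)

x = torch.randn(8, 64, 14, 14)  # [B, C, H, W]
bn = nn.BatchNorm2d(64)
bn.train()
y = bn(x)

# One mean and one biased variance per channel, pooled over B, H, and W.
mu = x.mean(dim=(0, 2, 3), keepdim=True)
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
y_manual = (x - mu) / torch.sqrt(var + bn.eps)

same = torch.allclose(y, y_manual, atol=1e-4)
print(mu.shape)         # torch.Size([1, 64, 1, 1])
print(bn.weight.shape)  # torch.Size([64])
print(same)
```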

### Training Mode and Evaluation Mode

Batch normalization behaves differently during training and inference.

During training, it uses statistics from the current mini-batch. During evaluation, it uses running estimates accumulated during training.

This distinction is essential.

```python
model.train()
# BatchNorm uses current batch statistics.

model.eval()
# BatchNorm uses running mean and running variance.
```

The running statistics are stored in buffers:

```python
bn = nn.BatchNorm1d(128)

print(bn.running_mean.shape)  # torch.Size([128])
print(bn.running_var.shape)   # torch.Size([128])
```

These are not learned by gradient descent. They are updated during training using a moving average.
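One way to see the difference between learned parameters and tracked statistics is to list what the module registers. The sketch below relies only on the standard `named_parameters()`, `named_buffers()`, and `state_dict()` methods of `nn.Module`.

```python
import torch
from torch import nn

bn = nn.BatchNorm1d(128)

param_names = [name for name, _ in bn.named_parameters()]
buffer_names = [name for name, _ in bn.named_buffers()]

print(param_names)   # ['weight', 'bias']
print(buffer_names)  # ['running_mean', 'running_var', 'num_batches_tracked']

# Buffers receive no gradients, but they are saved with the model.
print(bn.running_mean.requires_grad)
print("running_mean" in bn.state_dict())
```

Because the buffers are part of `state_dict()`, a saved checkpoint carries the running statistics along with the weights.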

### Running Statistics

During training, PyTorch updates running statistics approximately as

$$
\text{running\_mean}
\leftarrow
(1-m)\text{running\_mean} + m\mu_{\text{batch}},
$$

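where $m$ is the momentum parameter and $\mu_{\text{batch}}$ is the mean of the current mini-batch. The running variance follows the same rule, except that PyTorch feeds it the unbiased batch variance (divisor $B-1$) rather than the biased variance used for normalization.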

Example:

```python
bn = nn.BatchNorm1d(128, momentum=0.1)
```

A smaller momentum updates running statistics more slowly. A larger momentum makes them follow recent batches more closely.
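The update rule can be checked after a single training-mode forward pass. This sketch assumes a fresh layer, whose buffers start at `running_mean = 0` and `running_var = 1`; the input sizes and seed are arbitrary.

```python
import torch
from torch import nn

torch.manual_seed(0)

m = 0.1
bn = nn.BatchNorm1d(4, momentum=m)
bn.train()

x = torch.randn(64, 4) * 2.0 + 3.0
bn(x)  # one training-mode forward pass updates the buffers

# Moving-average update starting from running_mean = 0.
expected_mean = (1 - m) * torch.zeros(4) + m * x.mean(dim=0)
# running_var starts at 1 and tracks the unbiased batch variance (divisor B - 1).
expected_var = (1 - m) * torch.ones(4) + m * x.var(dim=0, unbiased=True)

mean_ok = torch.allclose(bn.running_mean, expected_mean, atol=1e-5)
var_ok = torch.allclose(bn.running_var, expected_var, atol=1e-5)
print(mean_ok, var_ok)
```

With `momentum=0.1`, a single batch moves the running mean only one tenth of the way toward the batch mean, which is why the estimates need many batches to settle.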

The running statistics are used when the model is placed in evaluation mode. If these estimates are poor, inference quality may degrade.

This can happen when batch sizes are very small, data distribution changes, or the training loop forgets to call `model.train()` and `model.eval()` at the correct times.

### Batch Normalization in an MLP

A common MLP pattern is

```python
Linear -> BatchNorm -> ReLU
```

Example:

```python
class MLP(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),

            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),

            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x):
        return self.net(x)
```

The batch normalization layer is usually placed after the linear transformation and before the activation function. Some architectures place it after the activation. The first pattern is more common in classical CNN and MLP designs.

### Batch Normalization in a CNN

In convolutional networks, the common block is

```python
Conv2d -> BatchNorm2d -> ReLU
```

Example:

```python
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.block(x)
```

The convolution often uses `bias=False` when followed by batch normalization. A convolution bias adds the same constant to every activation in a channel, and batch normalization subtracts the per-channel batch mean, so that constant cancels during training. The bias is therefore redundant: the batch normalization layer already has its own learned shift parameter.
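The cancellation of the convolution bias can be observed directly. The following sketch uses a hypothetical small convolution with its bias enabled, then zeroes that bias in place and compares the two batch-normalized outputs in training mode.

```python
import torch
from torch import nn

torch.manual_seed(0)

conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=True)
bn = nn.BatchNorm2d(8)
bn.train()

x = torch.randn(16, 3, 10, 10)

with torch.no_grad():
    y_with_bias = bn(conv(x))
    # Remove the bias and run the same weights again.
    conv.bias.zero_()
    y_without_bias = bn(conv(x))

same = torch.allclose(y_with_bias, y_without_bias, atol=1e-4)
print(same)
```

Dropping the bias saves a few parameters and avoids training a value that has no effect on the normalized output.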

### Benefits of Batch Normalization

Batch normalization often improves training in several ways.

It stabilizes activation scale across layers. This helps gradients remain in a useful range.

It allows larger learning rates in many networks. Since activations are controlled, optimization may tolerate more aggressive updates.

It can reduce sensitivity to initialization. Weight initialization still matters, but batch normalization makes the network less fragile.

It adds mild regularization. Because each example is normalized using statistics from the mini-batch, the output for one example depends slightly on other examples in the same batch. This introduces noise during training.
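The dependence on batch companions is easy to demonstrate. In this sketch the same example is placed in two different batches (both made up for illustration); its output differs in training mode and agrees in evaluation mode.

```python
import torch
from torch import nn

torch.manual_seed(0)

bn = nn.BatchNorm1d(4)
example = torch.randn(1, 4)
batch_a = torch.cat([example, torch.randn(7, 4)])
batch_b = torch.cat([example, torch.randn(7, 4) * 3.0])  # different companions

bn.train()
out_a = bn(batch_a)[0]
out_b = bn(batch_b)[0]
same_in_train = torch.allclose(out_a, out_b)

bn.eval()
out_a_eval = bn(batch_a)[0]
out_b_eval = bn(batch_b)[0]
same_in_eval = torch.allclose(out_a_eval, out_b_eval)

print(same_in_train)  # False: batch statistics differ between the two batches
print(same_in_eval)   # True: running statistics do not depend on the batch
```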

Batch normalization became especially important in deep convolutional networks such as residual networks.

### Limitations of Batch Normalization

Batch normalization depends on mini-batch statistics. This creates several limitations.

Small batches produce noisy mean and variance estimates. If the batch size is very small, batch normalization may perform poorly.
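A quick simulation gives a feel for how noisy small-batch statistics are. This is only a sketch on standard normal data; the helper name `batch_mean_spread` and the number of trials are arbitrary choices made here.

```python
import torch

torch.manual_seed(0)

def batch_mean_spread(batch_size, trials=2000):
    # Draw many mini-batches from N(0, 1) and measure how much their means vary.
    batches = torch.randn(trials, batch_size)
    return batches.mean(dim=1).std().item()

spread_small = batch_mean_spread(2)
spread_large = batch_mean_spread(256)

print(f"B=2:   std of batch mean = {spread_small:.3f}")  # roughly 1/sqrt(2)
print(f"B=256: std of batch mean = {spread_large:.3f}")  # roughly 1/sqrt(256)
```

The spread of the batch mean shrinks like $1/\sqrt{B}$, so a batch of 2 gives estimates about eleven times noisier than a batch of 256.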

Batch normalization can be awkward for variable-length sequence modeling. Padding, masking, and autoregressive inference make batch statistics less natural.

Training and inference use different statistics. A model can behave differently between `train()` and `eval()` mode. This is a common source of bugs.

Batch normalization couples examples within a batch. This can be undesirable in some settings, especially when examples should be processed independently.

For these reasons, transformer architectures usually prefer layer normalization over batch normalization.
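The contrast with layer normalization can be illustrated briefly. The sketch below uses `nn.LayerNorm`, which normalizes each example over its own features, and checks that an example's output does not change when other examples are added to the batch.

```python
import torch
from torch import nn

torch.manual_seed(0)

ln = nn.LayerNorm(8)
example = torch.randn(1, 8)
others = torch.randn(5, 8) * 10.0

alone = ln(example)[0]
in_batch = ln(torch.cat([example, others]))[0]

independent = torch.allclose(alone, in_batch, atol=1e-5)
print(independent)
```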

### Common PyTorch Mistakes

A frequent error is evaluating a model without calling `model.eval()`. The correct pattern switches modes before the forward pass:

```python
model.eval()
with torch.no_grad():
    y = model(x)
```

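Without `model.eval()`, batch normalization keeps using current batch statistics, so predictions may vary depending on which examples appear in the same batch, and the running statistics keep being updated with evaluation data. Note that `torch.no_grad()` only disables gradient tracking; it does not change the batch normalization mode.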

Another common error is forgetting to return to training mode after validation. Before the next training step, call:

```python
model.train()
```
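Put together, the two mode switches fit into a loop like the one sketched below. The model, optimizer, loss, and random data are hypothetical stand-ins used only to show where `train()` and `eval()` go.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy model and random data, used only to show the mode switches.
model = nn.Sequential(
    nn.Linear(10, 16), nn.BatchNorm1d(16), nn.ReLU(), nn.Linear(16, 1)
)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()
x_train, y_train = torch.randn(64, 10), torch.randn(64, 1)
x_val, y_val = torch.randn(32, 10), torch.randn(32, 1)

for epoch in range(3):
    model.train()  # batch statistics are used and running buffers are updated
    optimizer.zero_grad()
    loss = loss_fn(model(x_train), y_train)
    loss.backward()
    optimizer.step()

    model.eval()  # running statistics are used and buffers stay fixed
    with torch.no_grad():
        val_loss = loss_fn(model(x_val), y_val).item()
    print(epoch, model.training, val_loss)

model.train()  # restore training mode before any further training
```

Calling `model.train()` at the top of every epoch, rather than once before the loop, makes the loop robust to the validation step having switched modes.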

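A third error is using batch normalization with batch size 1 during training. With a single example there is no batch variance to estimate: for a `[1, D]` input, `BatchNorm1d` in training mode raises an error instead of normalizing. Convolutional layers still run, because statistics are pooled over spatial positions, but the estimates come from a single image and can be poor.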
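The batch-size-1 case can be reproduced in a few lines. This sketch assumes a recent PyTorch release, where a training-mode forward pass of `BatchNorm1d` on a `[1, D]` input raises a `ValueError`; evaluation mode accepts the same input because it relies on running statistics.

```python
import torch
from torch import nn

bn = nn.BatchNorm1d(128)
x = torch.randn(1, 128)  # a single example

bn.train()
try:
    bn(x)
    raised = False
except ValueError as err:
    raised = True
    print(err)  # the message asks for more than 1 value per channel

bn.eval()
y = bn(x)  # evaluation mode uses running statistics, so one example is fine
print(raised)
print(y.shape)  # torch.Size([1, 128])
```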

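A fourth error is applying `BatchNorm1d` to the wrong axis. `BatchNorm1d` always treats dimension 1 as the channel dimension, so for input shape `[B, T, D]` it would normalize over `T` rather than `D`. `BatchNorm1d(D)` then fails with a size mismatch when `T != D`, or silently normalizes the wrong axis when `T == D`. One common fix is to transpose: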

```python
x = torch.randn(32, 100, 128)  # [B, T, D]

bn = nn.BatchNorm1d(128)

y = bn(x.transpose(1, 2)).transpose(1, 2)

print(y.shape)  # torch.Size([32, 100, 128])
```

Here the tensor is temporarily changed from `[B, T, D]` to `[B, D, T]`, which matches what `BatchNorm1d` expects.
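An equivalent way to think about the transpose is that every `(batch, time)` position becomes one row of a `[B*T, D]` matrix. The sketch below checks that the two formulations agree in training mode, using two fresh layers so that neither call affects the other.

```python
import torch
from torch import nn

torch.manual_seed(0)

B, T, D = 32, 100, 128
x = torch.randn(B, T, D)

bn_a = nn.BatchNorm1d(D)
bn_b = nn.BatchNorm1d(D)

# Option 1: move D to the channel position and back.
y_transpose = bn_a(x.transpose(1, 2)).transpose(1, 2)
# Option 2: flatten batch and time into one axis, then restore the shape.
y_flatten = bn_b(x.reshape(B * T, D)).reshape(B, T, D)

same = torch.allclose(y_transpose, y_flatten, atol=1e-4)
print(same)
```

Either form pools statistics over all `B * T` positions for each of the `D` features, so padded positions, if present, would be included in the statistics as well.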

### Batch Normalization and Dropout

Batch normalization and dropout can be used together, but they interact.

Batch normalization already adds noise through mini-batch statistics. Dropout adds additional noise by randomly zeroing activations. In some networks, using both can help. In others, heavy dropout after batch normalization can hurt.

A common practical rule is to start with batch normalization and weight decay. Add dropout only if validation performance suggests overfitting.

For CNNs, dropout is often less important when batch normalization and data augmentation are used. For MLPs on tabular data, dropout may still be useful.

### Summary

Batch normalization normalizes activations using mini-batch mean and variance. It then applies learned scale and shift parameters. In PyTorch, `BatchNorm1d` is used for vector and sequence-like feature tensors, while `BatchNorm2d` is used for convolutional feature maps.

Batch normalization often improves optimization, stabilizes training, and reduces sensitivity to initialization. Its main weakness is dependence on batch statistics. Small batches, train/eval mode mistakes, and sequence-modeling settings can make it less suitable.

The essential PyTorch rule is simple: use `model.train()` during training, use `model.eval()` during inference, and make sure the normalized feature axis matches the layer you selected.

