Group and Instance Normalization

Batch normalization and layer normalization are the two most common normalization layers, but they do not cover every setting well. Batch normalization works well when batch statistics are reliable. Layer normalization works well for token and sequence models. Group normalization and instance normalization are useful when we want normalization that works independently of batch size while still respecting channel structure.

These methods are most common in computer vision, especially when the tensor layout is

X \in \mathbb{R}^{B \times C \times H \times W}.

Here B is the batch size, C is the number of channels, H is the height, and W is the width.

The Normalization Axis

The main difference between normalization methods is the set of elements used to compute mean and variance.

Method | Mean and variance computed over | Typical use
------ | ------------------------------- | -----------
BatchNorm | Batch and spatial axes, per channel | CNNs with large batches
LayerNorm | Feature axes, per example | Transformers, RNNs
InstanceNorm | Spatial axes, per sample and channel | Style transfer, image generation
GroupNorm | Groups of channels and spatial axes, per sample | CNNs with small batches

For image tensors, the choice controls how much information is shared when computing normalization statistics.

Batch normalization couples examples in a batch. Group normalization and instance normalization do not.
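
The difference is easy to see with plain tensor reductions. A minimal sketch (the shapes and the group count are illustrative, not prescriptive):

import torch

x = torch.randn(8, 64, 32, 32)  # (B, C, H, W)

# BatchNorm: one mean per channel, computed over batch and spatial axes.
bn_mean = x.mean(dim=(0, 2, 3))          # shape (C,)

# LayerNorm over (C, H, W): one mean per example.
ln_mean = x.mean(dim=(1, 2, 3))          # shape (B,)

# InstanceNorm: one mean per example and channel.
in_mean = x.mean(dim=(2, 3))             # shape (B, C)

# GroupNorm with G groups: one mean per example and group.
G = 8
gn_mean = x.view(8, G, -1).mean(dim=2)   # shape (B, G)

print(bn_mean.shape, ln_mean.shape, in_mean.shape, gn_mean.shape)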

Instance Normalization

Instance normalization normalizes each channel of each example independently. For an image tensor

X \in \mathbb{R}^{B \times C \times H \times W},

instance normalization computes statistics over the spatial dimensions H and W, separately for each batch element b and channel c:

\mu_{bc} = \frac{1}{HW} \sum_{h=1}^{H} \sum_{w=1}^{W} X_{bchw}.

The variance is

\sigma_{bc}^{2} = \frac{1}{HW} \sum_{h=1}^{H} \sum_{w=1}^{W} (X_{bchw}-\mu_{bc})^2.

Then each value is normalized as

\hat{X}_{bchw} = \frac{X_{bchw}-\mu_{bc}}{\sqrt{\sigma_{bc}^{2}+\epsilon}}.

With affine parameters, the output is

Y_{bchw} = \gamma_c \hat{X}_{bchw} + \beta_c.

The scale γ_c and shift β_c are usually learned per channel.

In PyTorch:

import torch
from torch import nn

x = torch.randn(8, 64, 32, 32)

norm = nn.InstanceNorm2d(64, affine=True)
y = norm(x)

print(y.shape)  # torch.Size([8, 64, 32, 32])

The argument 64 is the number of channels.
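
The layer matches the formulas above. A minimal check, reusing x from the snippet above and disabling the affine parameters so that only the normalization step is compared:

norm = nn.InstanceNorm2d(64, affine=False)

mean = x.mean(dim=(2, 3), keepdim=True)                 # mu_bc
var = x.var(dim=(2, 3), unbiased=False, keepdim=True)   # sigma_bc^2
manual = (x - mean) / torch.sqrt(var + norm.eps)

print(torch.allclose(norm(x), manual, atol=1e-5))  # True, up to numerical tolerance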

When Instance Normalization Helps

Instance normalization is common in style transfer and image generation. It removes instance-specific contrast and brightness statistics from each channel. This can make the model focus more on spatial structure and style-independent content.

For example, two images may have different lighting or contrast. Instance normalization normalizes each image separately, so those global appearance differences are reduced.
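
A small sketch of this effect, with affine parameters disabled so only the normalization step is involved: rescaling and shifting an image changes its global appearance but leaves the instance-normalized output essentially unchanged.

norm = nn.InstanceNorm2d(64, affine=False)

x = torch.randn(2, 64, 32, 32)
x_bright = 1.7 * x + 0.5  # same content, different contrast and brightness

print(torch.allclose(norm(x), norm(x_bright), atol=1e-4))  # True, up to numerical tolerance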

This property is useful in some generative image models. It can also be harmful in tasks where absolute intensity matters, such as some medical imaging or scientific imaging tasks.

Instance normalization is usually less common than batch normalization for ordinary image classification.

Group Normalization

Group normalization sits between layer normalization and instance normalization. It divides channels into groups, then normalizes over each group and the spatial dimensions.

Suppose

X \in \mathbb{R}^{B \times C \times H \times W}

and the channels are divided into G groups. Each group contains

\frac{C}{G}

channels.

For each sample b and group g, group normalization computes the mean and variance over the channels in that group and over all spatial positions.

If S_g is the set of channel indices in group g, then

\mu_{bg} = \frac{1}{|S_g|HW} \sum_{c\in S_g} \sum_{h=1}^{H} \sum_{w=1}^{W} X_{bchw}.
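
The variance is computed over the same elements, exactly as in the instance normalization case:

\sigma_{bg}^{2} = \frac{1}{|S_g|HW} \sum_{c\in S_g} \sum_{h=1}^{H} \sum_{w=1}^{W} (X_{bchw}-\mu_{bg})^2.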

Then all values in that group are normalized using μ_bg and σ_bg².

In PyTorch:

x = torch.randn(8, 64, 32, 32)

norm = nn.GroupNorm(num_groups=8, num_channels=64)
y = norm(x)

print(y.shape)  # torch.Size([8, 64, 32, 32])

Here 64 channels are split into 8 groups, so each group has 8 channels.
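
The same statistics can be reproduced by folding each group into one axis. A sketch, reusing x and norm from the snippet above and relying on the affine parameters being at their default initialization (scale one, shift zero):

B, C, H, W = x.shape
G = 8

xg = x.view(B, G, -1)                                  # (B, G, C/G * H * W)
mean = xg.mean(dim=2, keepdim=True)
var = xg.var(dim=2, unbiased=False, keepdim=True)
manual = ((xg - mean) / torch.sqrt(var + norm.eps)).view(B, C, H, W)

print(torch.allclose(norm(x), manual, atol=1e-5))  # True, up to numerical tolerance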

Group Count

The number of groups changes the behavior.

num_groups | Behavior
---------- | --------
1 | Similar to layer normalization over channels and spatial dimensions
C | Similar to instance normalization
Between 1 and C | True group normalization

For example, if C = 64:

nn.GroupNorm(num_groups=1, num_channels=64)

normalizes all channels together for each sample.

nn.GroupNorm(num_groups=64, num_channels=64)

normalizes each channel separately for each sample, similar to instance normalization.
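
Both relationships can be checked numerically. A sketch, assuming all affine parameters are still at their default initialization so the affine step is the identity:

x = torch.randn(8, 64, 32, 32)

# num_groups=1 matches layer normalization over (C, H, W).
gn_one = nn.GroupNorm(num_groups=1, num_channels=64)
layer_norm = nn.LayerNorm([64, 32, 32])
print(torch.allclose(gn_one(x), layer_norm(x), atol=1e-5))      # True

# num_groups=C matches instance normalization.
gn_full = nn.GroupNorm(num_groups=64, num_channels=64)
instance_norm = nn.InstanceNorm2d(64)
print(torch.allclose(gn_full(x), instance_norm(x), atol=1e-5))  # True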

A common practical choice is 16 or 32 groups when the number of channels allows it.

norm = nn.GroupNorm(num_groups=32, num_channels=128)

The number of channels must be divisible by the number of groups.

Why Group Normalization Handles Small Batches

Batch normalization depends on statistics from the mini-batch. If the batch size is small, those statistics are noisy. In object detection, segmentation, high-resolution image training, and memory-limited workloads, batch sizes may be small.

Group normalization avoids this problem because it computes statistics within each example. It behaves the same way regardless of the other images in the batch.
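
A quick sketch of this batch independence: slicing the input before or after group normalization gives the same result, while batch normalization in training mode depends on the rest of the batch.

x = torch.randn(8, 64, 32, 32)

group_norm = nn.GroupNorm(8, 64)
batch_norm = nn.BatchNorm2d(64).train()

print(torch.allclose(group_norm(x)[:2], group_norm(x[:2]), atol=1e-6))  # True
print(torch.allclose(batch_norm(x)[:2], batch_norm(x[:2]), atol=1e-6))  # False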

This makes it useful when:

Setting | Reason
------- | ------
Small-batch CNN training | BatchNorm statistics are noisy
Object detection | High memory use limits batch size
Semantic segmentation | Large images reduce batch size
Multi-device training with tiny local batches | Per-device BatchNorm becomes unstable
Style and generation models | Per-example normalization is often preferred

Group normalization also has no running mean or running variance. It behaves the same in training and evaluation mode.
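
A minimal check of the identical training and evaluation behavior:

group_norm = nn.GroupNorm(8, 64)
x = torch.randn(4, 64, 16, 16)

group_norm.train()
y_train = group_norm(x)
group_norm.eval()
y_eval = group_norm(x)

print(torch.allclose(y_train, y_eval))  # True: no running statistics are involved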

Group Normalization in CNN Blocks

A convolutional block with group normalization often looks like this:

class ConvGNBlock(nn.Module):
    def __init__(self, in_channels, out_channels, num_groups=32):
        super().__init__()

        groups = min(num_groups, out_channels)

        # Ensure the group count divides the channel count.
        while out_channels % groups != 0:
            groups -= 1

        self.block = nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            nn.GroupNorm(groups, out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.block(x)
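
A short usage example (the shapes are illustrative):

block = ConvGNBlock(in_channels=3, out_channels=64)
out = block(torch.randn(4, 3, 32, 32))
print(out.shape)  # torch.Size([4, 64, 32, 32])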

Unlike batch normalization, group normalization does not subtract a per-channel batch mean, so a convolution bias is not cancelled exactly. It is still largely redundant when the normalization immediately follows the convolution, because GroupNorm applies its own per-channel shift, and many implementations keep bias=False in this pattern. Practice varies.

Instance Normalization in Image Models

An instance-normalized convolutional block:

class ConvINBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.block = nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                padding=1,
            ),
            nn.InstanceNorm2d(out_channels, affine=True),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.block(x)

For style-transfer networks, instance normalization is often used after convolution and before activation.

In generative image models, normalization choice strongly affects output texture, contrast, and style. This is why normalization should be treated as part of the architecture, not as a neutral implementation detail.

Training and Evaluation Behavior

Group normalization and instance normalization usually do not rely on running batch statistics.

For GroupNorm, the behavior is the same in training and evaluation mode. It always computes statistics from the current input.

For InstanceNorm2d, PyTorch defaults to track_running_stats=False, so it also uses current input statistics in both modes:

norm = nn.InstanceNorm2d(64, affine=True)
print(norm.track_running_stats)  # False

You can enable running statistics:

norm = nn.InstanceNorm2d(
    64,
    affine=True,
    track_running_stats=True,
)

This is less common. When enabled, the layer behaves more like batch normalization with separate training and inference statistics.
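
A brief sketch of the two modes, continuing with the layer defined above:

x = torch.randn(8, 64, 32, 32)

norm.train()
_ = norm(x)   # normalizes with current-input statistics and updates the running buffers
norm.eval()
y = norm(x)   # normalizes with the accumulated running statistics instead

print(norm.running_mean.shape)  # torch.Size([64])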

Comparing Normalization Layers in PyTorch

A compact comparison:

x = torch.randn(8, 64, 32, 32)

batch_norm = nn.BatchNorm2d(64)
group_norm = nn.GroupNorm(8, 64)
instance_norm = nn.InstanceNorm2d(64, affine=True)

print(batch_norm(x).shape)    # torch.Size([8, 64, 32, 32])
print(group_norm(x).shape)    # torch.Size([8, 64, 32, 32])
print(instance_norm(x).shape) # torch.Size([8, 64, 32, 32])

All three preserve shape. They differ only in how they compute normalization statistics.
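
One way to see the difference is to check where each output has an approximately zero mean, reusing the layers and x from above. The reductions below mirror the table at the start of this section:

y_bn = batch_norm(x)
y_gn = group_norm(x)
y_in = instance_norm(x)

print(y_bn.mean(dim=(0, 2, 3)).abs().max())          # per channel: close to zero
print(y_gn.view(8, 8, -1).mean(dim=2).abs().max())   # per sample and group: close to zero
print(y_in.mean(dim=(2, 3)).abs().max())             # per sample and channel: close to zero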

Practical Selection Rules

Use batch normalization when training CNNs with sufficiently large batches and stable batch composition.

Use group normalization when training CNNs with small batches, large images, detection systems, or segmentation systems.

Use instance normalization when per-image appearance statistics should be removed or controlled, especially in style transfer and some generative models.

Use layer normalization for transformers, token embeddings, recurrent networks, and most modern language models.

Model family | Common normalization
------------ | --------------------
Image classification CNN | BatchNorm
Small-batch CNN | GroupNorm
Object detection | GroupNorm or frozen BatchNorm
Semantic segmentation | GroupNorm or SyncBatchNorm
Style transfer | InstanceNorm
Transformer | LayerNorm or RMSNorm
Large language model | LayerNorm or RMSNorm

Common PyTorch Mistakes

The first common mistake is choosing a group count that does not divide the number of channels:

nn.GroupNorm(num_groups=7, num_channels=64)

This is invalid because 64 is not divisible by 7, and PyTorch raises a ValueError when the layer is constructed.
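
A small sketch of this failure (the exact error message may vary between versions):

try:
    nn.GroupNorm(num_groups=7, num_channels=64)
except ValueError as err:
    print(err)  # num_channels must be divisible by num_groups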

The second common mistake is using instance normalization for classification without checking whether it removes useful signal. If global contrast or intensity helps the task, instance normalization may hurt performance.

The third common mistake is assuming all normalization layers use running statistics. GroupNorm does not. LayerNorm does not. BatchNorm does. InstanceNorm does so only when track_running_stats=True.

The fourth common mistake is using batch normalization with very small per-device batches. If distributed training uses a global batch size of 64 across 8 GPUs, each GPU may only see 8 examples. Ordinary batch normalization computes statistics per device unless synchronized batch normalization is used.
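
If batch normalization must be kept in that setting, PyTorch can convert existing layers to synchronized batch normalization. A sketch on a toy model; actually running the converted model requires an initialized distributed process group:

model = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(),
)

# Replaces every BatchNorm layer with SyncBatchNorm, which pools batch
# statistics across processes during distributed training.
sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
print(sync_model)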

Summary

Group normalization and instance normalization provide alternatives to batch normalization for image-like tensors.

Instance normalization normalizes each channel of each sample over spatial dimensions. It is useful when per-instance appearance statistics should be reduced, such as in style transfer.

Group normalization divides channels into groups and normalizes each group within each sample. It works well for CNNs when batch sizes are small or batch statistics are unreliable.

The main design question is simple: which axes should provide the mean and variance? Once that is clear, the choice among batch, layer, group, and instance normalization becomes easier.