Layer normalization is a normalization method that normalizes features within each individual example. Unlike batch normalization, it does not use statistics from other examples in the same mini-batch. This makes it especially useful for sequence models, transformers, recurrent networks, and settings where batch sizes are small or variable.
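For an input vector $x = (x_1, \dots, x_D) \in \mathbb{R}^D$, layer normalization computes the mean and variance across the feature dimension of that same vector:
$$\mu = \frac{1}{D}\sum_{i=1}^{D} x_i, \qquad \sigma^2 = \frac{1}{D}\sum_{i=1}^{D} (x_i - \mu)^2.$$
It then normalizes each feature:
$$\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}},$$
where $\epsilon$ is a small constant for numerical stability. Finally, it applies learned scale and shift parameters:
$$y_i = \gamma_i \hat{x}_i + \beta_i.$$
Here $\gamma$ and $\beta$ have the same shape as the normalized feature dimension.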
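The sketch below connects these steps to PyTorch. It computes the mean, the biased variance, the normalized values, and the scale and shift by hand for a small batch of vectors, then compares the result with nn.LayerNorm. It assumes PyTorch's conventions: the variance divides by D, and the default eps is 1e-5. The helper name `manual_layer_norm` and the specific gamma/beta values are illustrative only.

```python
import torch
from torch import nn

torch.manual_seed(0)

def manual_layer_norm(x, gamma, beta, eps=1e-5):
    # Mean and biased (divide-by-D) variance over the last (feature) dimension.
    mu = x.mean(dim=-1, keepdim=True)
    var = ((x - mu) ** 2).mean(dim=-1, keepdim=True)
    # Normalize each feature, then apply the learned scale and shift.
    x_hat = (x - mu) / torch.sqrt(var + eps)
    return gamma * x_hat + beta

D = 8
x = torch.randn(4, D)

ln = nn.LayerNorm(D, eps=1e-5)
# Non-trivial gamma (weight) and beta (bias) so the affine step is exercised.
with torch.no_grad():
    ln.weight.copy_(torch.linspace(0.5, 2.0, D))
    ln.bias.copy_(torch.linspace(-1.0, 1.0, D))

y_manual = manual_layer_norm(x, ln.weight, ln.bias)
y_torch = ln(x)
print(torch.allclose(y_manual, y_torch, atol=1e-5))  # True
```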
Batch Normalization Versus Layer Normalization
Batch normalization computes statistics across the batch. Layer normalization computes statistics inside each example.
For an input $X \in \mathbb{R}^{B \times D}$ with $B$ examples and $D$ features, batch normalization computes one mean and variance per feature across the $B$ examples. Layer normalization computes one mean and variance per example across the $D$ features.
| Method | Statistics computed over | Depends on batch contents | Common use |
|---|---|---|---|
| Batch normalization | Batch axis | Yes | CNNs, MLPs |
| Layer normalization | Feature axis | No | Transformers, RNNs, LLMs |
This difference matters during inference. Batch normalization behaves differently in training and evaluation modes. During training it normalizes with the statistics of the current mini-batch. During inference it uses running statistics accumulated during training. Layer normalization computes its statistics from the current example in both training and inference, so it has no train/eval statistical mismatch.
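A small sketch makes both differences concrete for an input of shape [B, D]. First, it checks which axis each layer normalizes over. Second, it checks whether the output changes when the layer is switched from training to evaluation mode. It uses nn.BatchNorm1d with default settings for the comparison. The sizes and the scale/offset applied to the input are arbitrary choices.

```python
import torch
from torch import nn

torch.manual_seed(0)
B, D = 16, 8
x = torch.randn(B, D) * 3 + 1  # arbitrary scale and offset

bn = nn.BatchNorm1d(D)
ln = nn.LayerNorm(D)

# Training mode: batch norm normalizes each feature (column) over the batch,
# layer norm normalizes each example (row) over its features.
bn.train()
ln.train()
y_bn_train = bn(x)
y_ln_train = ln(x)
bn_col_means = y_bn_train.mean(dim=0)  # close to 0 for every feature
ln_row_means = y_ln_train.mean(dim=1)  # close to 0 for every example

# Evaluation mode: batch norm switches to its running estimates,
# while layer norm repeats exactly the same per-example computation.
bn.eval()
ln.eval()
y_bn_eval = bn(x)
y_ln_eval = ln(x)
print(torch.allclose(y_ln_train, y_ln_eval))  # True
print(torch.allclose(y_bn_train, y_bn_eval))  # False
```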
In PyTorch:
```python
import torch
from torch import nn

x = torch.randn(32, 128)
ln = nn.LayerNorm(128)
y = ln(x)
print(y.shape)  # torch.Size([32, 128])
```
Each of the 32 examples is normalized independently across its 128 features.
Layer Normalization Shapes
nn.LayerNorm takes the shape of the trailing dimensions to normalize. A common transformer layout is a tensor of shape [B, T, D]: batch size, sequence length, and embedding dimension. To normalize each token embedding independently, use:
```python
x = torch.randn(16, 128, 768)
ln = nn.LayerNorm(768)
y = ln(x)
print(y.shape)  # torch.Size([16, 128, 768])
```
This normalizes over the last dimension only. Each token vector of length 768 gets its own mean and variance.
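One way to confirm that LayerNorm(768) treats every token separately is to check the output's statistics along the last dimension. The check below assumes the default initialization (weight of ones, bias of zeros). Under that assumption, each token vector should come out with a mean close to 0 and a biased variance close to 1, whatever the scale of the input.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Arbitrary scale and offset so the normalization has something to undo.
x = torch.randn(16, 128, 768) * 5 + 2
ln = nn.LayerNorm(768)  # default init: weight = 1, bias = 0
y = ln(x)

# Statistics over the last dimension: one value per (example, token) pair.
token_mean = y.mean(dim=-1)                                     # [16, 128]
token_var = ((y - token_mean.unsqueeze(-1)) ** 2).mean(dim=-1)  # [16, 128], biased
print(token_mean.shape)  # torch.Size([16, 128])
print(token_mean.abs().max().item(), (token_var - 1).abs().max().item())
```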
Layer normalization can also normalize over multiple trailing dimensions. For an image-like tensor:
```python
x = torch.randn(8, 3, 32, 32)
ln = nn.LayerNorm([3, 32, 32])
y = ln(x)
print(y.shape)  # torch.Size([8, 3, 32, 32])
```
This normalizes each image over all channels and spatial positions. However, this layout is less common in CNNs, where batch normalization or group normalization is usually preferred.
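With normalized_shape=[3, 32, 32], the statistics are taken jointly over the last three dimensions, so each image gets a single mean and variance. The sketch below checks this against a manual computation over dims (1, 2, 3), assuming the default affine initialization. It also shows that the learned weight then has shape [3, 32, 32], with one scale per normalized element.

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(8, 3, 32, 32)
ln = nn.LayerNorm([3, 32, 32])
y = ln(x)

# One mean and one (biased) variance per image, over channels and spatial positions.
dims = (1, 2, 3)
mu = x.mean(dim=dims, keepdim=True)                 # shape [8, 1, 1, 1]
var = ((x - mu) ** 2).mean(dim=dims, keepdim=True)
y_manual = (x - mu) / torch.sqrt(var + ln.eps)

print(ln.weight.shape)                         # torch.Size([3, 32, 32])
print(torch.allclose(y, y_manual, atol=1e-5))  # True
```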
Learned Scale and Shift
Layer normalization has learnable affine parameters by default. For nn.LayerNorm(768), PyTorch creates:
```python
ln = nn.LayerNorm(768)
print(ln.weight.shape)  # torch.Size([768])
print(ln.bias.shape)    # torch.Size([768])
```
The parameter weight corresponds to $\gamma$. The parameter bias corresponds to $\beta$.
They are initialized to ones and zeros, respectively:
```python
print(ln.weight[:5])  # all ones
print(ln.bias[:5])    # all zeros
```
This means the layer initially returns normalized activations without additional learned scaling or shifting.
You can disable the affine parameters:
```python
ln = nn.LayerNorm(768, elementwise_affine=False)
```
This is less common. Most neural networks keep the affine parameters because they allow the model to recover useful feature scales after normalization.
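With elementwise_affine=False, the module has no trainable parameters, and its output is just the normalized activations. A quick check, assuming PyTorch's behavior of registering the disabled weight and bias as None:

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(4, 768)

ln_plain = nn.LayerNorm(768, elementwise_affine=False)
ln_affine = nn.LayerNorm(768)  # default: weight = 1, bias = 0 at initialization

# No learnable parameters are created when the affine step is disabled.
n_plain = sum(p.numel() for p in ln_plain.parameters())
n_affine = sum(p.numel() for p in ln_affine.parameters())
print(n_plain, n_affine)        # 0 1536
print(ln_plain.weight is None)  # True

# At initialization both layers produce the same output; they only diverge
# once the affine layer's weight and bias are trained away from 1 and 0.
print(torch.allclose(ln_plain(x), ln_affine(x)))  # True
```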
Layer Normalization in Transformers
Layer normalization is a standard component of transformer blocks. A transformer block usually contains attention, a feedforward network, residual connections, and normalization.
There are two common layouts: post-normalization and pre-normalization.
Post-normalization applies layer normalization after the residual addition:
```python
x = x + attention(x)
x = layer_norm(x)
```
Pre-normalization applies layer normalization before the sublayer:
```python
x = x + attention(layer_norm(x))
```
Modern large transformers usually prefer pre-normalization because it improves gradient flow in deep networks.
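The two layouts differ only in where the normalization sits relative to the residual addition. The minimal sketch below uses a single nn.Linear as a stand-in for the attention sublayer. This is an assumption made to keep the example small; any shape-preserving sublayer would do. The class names are hypothetical wrappers, not PyTorch modules.

```python
import torch
from torch import nn

class PostNormResidual(nn.Module):
    """x -> LayerNorm(x + sublayer(x)): normalization after the residual add."""
    def __init__(self, dim, sublayer):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.norm(x + self.sublayer(x))

class PreNormResidual(nn.Module):
    """x -> x + sublayer(LayerNorm(x)): normalization on the branch input only."""
    def __init__(self, dim, sublayer):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return x + self.sublayer(self.norm(x))

torch.manual_seed(0)
dim = 64
x = torch.randn(2, 10, dim) * 10  # a residual stream with a large scale

post = PostNormResidual(dim, nn.Linear(dim, dim))
pre = PreNormResidual(dim, nn.Linear(dim, dim))
y_post = post(x)
y_pre = pre(x)

# Post-norm rescales the whole stream to roughly unit scale; pre-norm leaves the
# identity path untouched, so the output keeps roughly the scale of the input.
print(y_post.std().item(), y_pre.std().item())
```

In the post-norm layout, every block's output passes through the normalization. In the pre-norm layout, the direct `x + ...` path carries the input through unchanged.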
A simplified pre-normalization transformer block:
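```python
class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, hidden_dim):
        super().__init__()
        # Pre-norm for the attention sublayer.
        self.ln1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=num_heads,
            batch_first=True,  # inputs are [B, T, D]
        )
        # Pre-norm for the feedforward sublayer.
        self.ln2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        # Attention sublayer: normalize, self-attend, add back to the residual stream.
        h = self.ln1(x)
        h, _ = self.attn(h, h, h)  # attention weights are discarded
        x = x + h
        # Feedforward sublayer: normalize, transform, add back to the residual stream.
        h = self.ln2(x)
        h = self.ffn(h)
        x = x + h
        return x
```
The input and output have the same shape, [B, T, D]. This is necessary because residual connections add tensors elementwise.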
Why Layer Normalization Fits Sequence Models
Sequence models often use variable-length inputs, padding masks, and autoregressive decoding. These properties make batch statistics awkward.
In language modeling, a model may generate one token at a time. During inference, the batch size may be 1, and sequence lengths may change at each step. Batch normalization would depend heavily on the composition of the batch. Layer normalization avoids this issue because each token representation is normalized independently.
For a transformer hidden state $H \in \mathbb{R}^{B \times T \times D}$, LayerNorm(D) normalizes each vector $h_{b,t} \in \mathbb{R}^{D}$ independently. It does not mix information across batch elements or across time steps.
This is important for autoregressive models. The representation of one token should not depend on unrelated examples in the same batch.
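One direct way to see this independence is to normalize a single sequence on its own, then normalize it again inside a larger batch of unrelated sequences. The sketch below compares LayerNorm with BatchNorm1d in training mode. For BatchNorm1d it assumes PyTorch's [B, C, L] input layout, so the [B, T, D] tensor is transposed to put the feature dimension second.

```python
import torch
from torch import nn

torch.manual_seed(0)
B, T, D = 4, 5, 16
batch = torch.randn(B, T, D)
single = batch[:1]  # the first sequence alone, shape [1, T, D]

ln = nn.LayerNorm(D)
ln_single = ln(single)
ln_batch = ln(batch)[:1]
# LayerNorm: the first sequence's output is the same alone or in the batch.
print(torch.allclose(ln_single, ln_batch))  # True

bn = nn.BatchNorm1d(D)  # expects [B, C, L], so move D to dim 1
bn.train()
bn_single = bn(single.transpose(1, 2)).transpose(1, 2)
bn_batch = bn(batch.transpose(1, 2)).transpose(1, 2)[:1]
# BatchNorm in training mode: the statistics depend on the other sequences.
print(torch.allclose(bn_single, bn_batch))  # False
```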
Numerical Stability
Layer normalization includes a small constant $\epsilon$ in the denominator:
$$\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}.$$
In PyTorch:
```python
ln = nn.LayerNorm(768, eps=1e-5)  # 1e-5 is also the default
```
The default value is usually adequate. In low-precision training, especially with float16, numerical stability may require careful treatment. Many large-model systems use bfloat16, mixed precision, or fused layer normalization kernels for better performance and stability.
A common practical rule is to leave eps unchanged unless there is a clear numerical problem.
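The small constant eps matters most on a degenerate input. For a constant vector, the variance is exactly zero, so normalizing without eps divides 0 by 0. With eps, the normalized values are simply 0, and the output reduces to the bias (zeros at initialization). A small check, using an arbitrary constant value of 2.0:

```python
import torch
from torch import nn

x = torch.full((1, 768), 2.0)  # constant features: the variance is exactly 0

mu = x.mean(dim=-1, keepdim=True)
var = ((x - mu) ** 2).mean(dim=-1, keepdim=True)
no_eps = (x - mu) / torch.sqrt(var)  # 0 / 0 -> NaN

ln = nn.LayerNorm(768)  # default eps = 1e-5
with_eps = ln(x)

print(torch.isnan(no_eps).all().item())  # True
print(torch.allclose(with_eps, torch.zeros_like(with_eps), atol=1e-4))  # True
```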
Layer Normalization and Residual Streams
In transformers, the residual stream is the main path through which information flows across layers. Layer normalization controls the scale of this stream.
Consider a pre-normalized residual block:
$$x_{l+1} = x_l + F(\mathrm{LN}(x_l)).$$
The normalization is applied before the transformation $F$. The residual path itself remains direct. This layout helps gradients flow through the identity connection while keeping the transformed branch numerically controlled.
This is one reason pre-normalized transformers can be trained at substantial depth.
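Part of what keeps the branch controlled is that layer normalization is almost invariant to the overall scale of its input. If the residual stream is multiplied by a constant c > 0, the normalized input to the branch stays essentially unchanged; only the eps term breaks exact invariance. The check below assumes the default LayerNorm initialization and uses an arbitrary factor of 100 for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)
D = 512
ln = nn.LayerNorm(D)
x = torch.randn(2, 8, D)  # a residual stream with shape [B, T, D]

# Normalized branch input for the original stream and for a 100x larger one.
branch_in_small = ln(x)
branch_in_large = ln(100.0 * x)

print(torch.allclose(branch_in_small, branch_in_large, atol=1e-4))  # True
```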
RMSNorm
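RMSNorm is a related normalization method used in several modern language models. It removes the mean-centering step and normalizes by the root mean square magnitude:
$$\mathrm{RMS}(x) = \sqrt{\frac{1}{D}\sum_{i=1}^{D} x_i^2 + \epsilon}.$$
Then
$$y_i = \gamma_i \frac{x_i}{\mathrm{RMS}(x)}.$$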
RMSNorm is simpler than layer normalization and can be faster. It preserves the basic idea of controlling feature scale within each token representation.
A minimal PyTorch implementation:
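```python
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-8):
        super().__init__()
        # Learned per-feature scale (gamma), initialized to 1; there is no bias.
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        # Root mean square over the last (feature) dimension; no mean subtraction.
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * x / rms
```
RMSNorm usually has a learned scale parameter but no bias parameter.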
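Because RMSNorm skips mean-centering, it coincides with layer normalization (scale only, no bias) exactly when each vector already has zero mean. The sketch below checks this with a functional `rms_norm` that performs the same computation as the RMSNorm class above. The function names are illustrative. Here eps is set to 1e-5 in both functions for a fair comparison, instead of the class's 1e-8 default.

```python
import torch
import torch.nn.functional as F

def rms_norm(x, weight, eps=1e-5):
    # Same computation as the RMSNorm class: divide by the root mean square
    # over the last dimension, with no mean subtraction and no bias.
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return weight * x / rms

def layer_norm_no_bias(x, weight, eps=1e-5):
    # Standard layer norm with a scale but no shift.
    return F.layer_norm(x, (x.shape[-1],), weight=weight, bias=None, eps=eps)

torch.manual_seed(0)
D = 64
weight = torch.ones(D)
x = torch.randn(4, D) + 3.0                    # features with a clearly non-zero mean
x_centered = x - x.mean(dim=-1, keepdim=True)  # the same vectors with the mean removed

same_when_centered = torch.allclose(
    rms_norm(x_centered, weight), layer_norm_no_bias(x_centered, weight), atol=1e-5
)
same_when_offset = torch.allclose(
    rms_norm(x, weight), layer_norm_no_bias(x, weight), atol=1e-5
)
print(same_when_centered)  # True
print(same_when_offset)    # False: RMSNorm keeps the offset that layer norm removes
```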
Common PyTorch Mistakes
A common mistake is passing the wrong normalized shape.
For a tensor with shape [B, T, D], this is correct:
```python
ln = nn.LayerNorm(D)
```
This normalizes the last axis.
This is usually wrong:
```python
ln = nn.LayerNorm(T)
```
That would normalize the sequence-length axis only if the sequence length were the last dimension. With shape [B, T, D], the last dimension is $D$, not $T$, so PyTorch rejects the input with a shape error whenever $T \neq D$.
Another mistake is applying layer normalization to a channel-first image tensor with only the channel count:
```python
x = torch.randn(8, 64, 32, 32)
ln = nn.LayerNorm(64)
```
This does not match the last dimension, which is 32. For channel-first images, use BatchNorm2d, GroupNorm, or explicitly permute the tensor to channel-last layout before applying LayerNorm(64).
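nn.LayerNorm checks that the trailing dimensions of its input match normalized_shape, so the two mistakes above raise an error rather than running. The sketch below triggers both. It catches RuntimeError, which is what PyTorch raises for this check; the exact message text may vary between versions. The helper `raises_shape_error` is hypothetical.

```python
import torch
from torch import nn

def raises_shape_error(module, x):
    """Return True if applying module to x fails PyTorch's shape check."""
    try:
        module(x)
    except RuntimeError:
        return True
    return False

B, T, D = 2, 10, 16
tokens = torch.randn(B, T, D)
images = torch.randn(8, 64, 32, 32)  # channel-first [B, C, H, W]

print(raises_shape_error(nn.LayerNorm(D), tokens))   # False: D is the last dimension
print(raises_shape_error(nn.LayerNorm(T), tokens))   # True: T != D
print(raises_shape_error(nn.LayerNorm(64), images))  # True: the last dimension is 32
print(raises_shape_error(nn.LayerNorm(32), images))  # False: runs, but over the width axis
```

The last line shows the limit of this check: a mismatch is only caught when the sizes differ. If the wrong axis happens to have the expected size, the layer runs without complaint and normalizes the wrong dimension.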
Example with channel-last layout:
```python
x = torch.randn(8, 64, 32, 32)
x = x.permute(0, 2, 3, 1)  # [B, H, W, C]
ln = nn.LayerNorm(64)
y = ln(x)
y = y.permute(0, 3, 1, 2)  # [B, C, H, W]
```
Practical Rules
Use layer normalization when the model processes sequences, tokens, embeddings, or hidden states with shape [B, T, D].
Use nn.LayerNorm(D) for transformer embeddings and hidden states.
Prefer pre-normalization for deep transformer-style architectures.
Use batch normalization or group normalization for most convolutional networks unless the architecture specifically uses layer normalization.
Keep affine parameters enabled unless you have a specific reason to remove them.
Check that the dimension passed to LayerNorm matches the trailing dimension or trailing dimensions of the tensor.
Summary
Layer normalization normalizes each example independently across its feature dimensions. It avoids dependence on batch statistics and behaves the same way during training and inference.
This makes it a natural fit for transformers, recurrent networks, and language models. In PyTorch, nn.LayerNorm(D) is usually applied to tensors whose last dimension is the feature or embedding dimension. Modern transformer blocks often combine pre-normalization, residual connections, attention, and feedforward networks to maintain stable gradient flow across many layers.