
Weight Decay and Regularization

Training loss measures how well a model fits the training data. A model with many parameters can sometimes fit the training data too closely. It may learn noise, accidental correlations, or details that do not hold for new examples. This problem is called overfitting.

Regularization changes training so that the model is encouraged to learn simpler or more stable solutions. Weight decay is one of the most common regularization methods in deep learning.

Overfitting

A model overfits when it performs well on the training set but poorly on validation or test data.

A typical pattern is:

Training loss | Validation loss   | Interpretation
Decreases     | Decreases         | Model is learning useful structure
Decreases     | Stops improving   | Model may be reaching capacity or noise limit
Decreases     | Increases         | Model is likely overfitting

Overfitting is common when the model is large, the dataset is small, labels are noisy, or training runs for too long.

Regularization tries to reduce this gap between training performance and validation performance.

L2 Regularization

L2 regularization adds a penalty on large weights to the loss:

L_{\text{reg}}(\theta) = L(\theta) + \frac{\lambda}{2} \|\theta\|_2^2.

Here L(θ) is the original training loss, θ is the parameter vector, and λ controls the strength of regularization.

The squared L2 norm is

\|\theta\|_2^2 = \sum_j \theta_j^2.

Large parameter values increase the penalty. Training therefore prefers solutions with smaller weights when they achieve similar prediction loss.
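
As a minimal sketch of what this penalty computes, assuming model, x, y, and loss_fn are defined as in the later examples, the L2 term can be added to the loss by hand:

l2_lambda = 1e-4  # regularization strength lambda (example value)

pred = model(x)
loss = loss_fn(pred, y)

l2_penalty = torch.tensor(0.0, device=loss.device)
for param in model.parameters():
    l2_penalty = l2_penalty + (param ** 2).sum()

loss = loss + 0.5 * l2_lambda * l2_penalty  # L + (lambda / 2) * ||theta||^2

In practice this is rarely written out manually; the optimizer's weight_decay argument, shown in the next section, covers the common case.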

Weight Decay

For plain SGD, L2 regularization leads to the update

\theta \leftarrow \theta - \eta (\nabla_\theta L + \lambda\theta).

This can be rearranged as

\theta \leftarrow (1 - \eta\lambda)\theta - \eta\nabla_\theta L.

The term (1 - ηλ)θ shrinks the parameters slightly at every step. This is why the method is called weight decay.

In PyTorch:

optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.1,
    weight_decay=1e-4,
)

The optimizer applies the decay during optimizer.step().
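
To make the shrinkage explicit, here is a minimal sketch of one manual SGD step with weight decay, equivalent to the update above, assuming loss.backward() has already populated the gradients:

lr = 0.1
weight_decay = 1e-4

with torch.no_grad():
    for param in model.parameters():
        if param.grad is None:
            continue
        param.mul_(1 - lr * weight_decay)  # (1 - eta * lambda) * theta
        param.sub_(lr * param.grad)        # - eta * grad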

AdamW and Decoupled Weight Decay

For adaptive optimizers such as Adam, L2 regularization and weight decay behave differently. Adam rescales gradients using moment estimates, so adding λθ to the gradient does not produce the same clean shrinkage effect as SGD.

AdamW uses decoupled weight decay. It applies weight decay separately from the adaptive gradient update.

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,
    weight_decay=0.01,
)

AdamW is commonly preferred for transformers, diffusion models, and many modern architectures.
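
A simplified sketch of the difference, with bias correction omitted and hypothetical state tensors m and v for a single parameter, looks roughly like this:

def adamw_style_step(param, grad, m, v,
                     lr=3e-4, beta1=0.9, beta2=0.999,
                     eps=1e-8, weight_decay=0.01):
    # Plain Adam with L2 in the loss would instead add weight_decay * param
    # to grad here, so the decay would be rescaled by sqrt(v) below.
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    param = param - lr * m / (v.sqrt() + eps)   # adaptive gradient step
    param = param - lr * weight_decay * param   # decoupled decay, not rescaled
    return param, m, v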

Which Parameters Should Decay

Weight decay is usually applied to weight matrices and convolution kernels. It is often excluded from bias vectors and normalization parameters.

Biases are offset terms, and penalizing them usually gives little benefit. Normalization parameters, such as LayerNorm scale and shift, control activation statistics. Decaying them can sometimes hurt training.

A common PyTorch pattern is:

decay = []
no_decay = []

for name, param in model.named_parameters():
    if not param.requires_grad:
        continue

    if name.endswith("bias") or "norm" in name.lower():
        no_decay.append(param)
    else:
        decay.append(param)

optimizer = torch.optim.AdamW(
    [
        {"params": decay, "weight_decay": 0.01},
        {"params": no_decay, "weight_decay": 0.0},
    ],
    lr=3e-4,
)

This gives different regularization settings to different parameter groups.

L1 Regularization

L1 regularization adds a penalty on the absolute value of parameters:

L_{\text{reg}}(\theta) = L(\theta) + \lambda \|\theta\|_1.

The L1 norm is

\|\theta\|_1 = \sum_j |\theta_j|.

L1 regularization tends to encourage sparsity. Some parameter values may become exactly or nearly zero. This can be useful for feature selection or compact models, although it is less common than weight decay in large neural networks.

In PyTorch, L1 regularization is usually added manually:

l1_lambda = 1e-5

pred = model(x)
loss = loss_fn(pred, y)

l1_penalty = torch.tensor(0.0, device=loss.device)

for param in model.parameters():
    l1_penalty = l1_penalty + param.abs().sum()

loss = loss + l1_lambda * l1_penalty

Dropout

Dropout randomly sets some activations to zero during training. This prevents the model from relying too strongly on specific hidden units.

model = torch.nn.Sequential(
    torch.nn.Linear(128, 256),
    torch.nn.ReLU(),
    torch.nn.Dropout(p=0.5),
    torch.nn.Linear(256, 10),
)

During training, dropout is active. During evaluation, dropout is disabled.

model.train()  # dropout active
model.eval()   # dropout inactive

Dropout is common in fully connected networks and some transformer components. It is less central in modern large-scale models trained on very large datasets, where weight decay, data scale, and augmentation often matter more.

Early Stopping

Early stopping stops training when validation performance stops improving.

The idea is simple. Continue training while validation loss improves. Stop when validation loss has not improved for a fixed number of epochs.

best_val_loss = float("inf")
patience = 5
bad_epochs = 0

for epoch in range(num_epochs):
    train_one_epoch(model, train_loader, optimizer)
    val_loss = evaluate(model, val_loader)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        bad_epochs = 0
        best_state = {
            k: v.cpu().clone()
            for k, v in model.state_dict().items()
        }
    else:
        bad_epochs += 1

    if bad_epochs >= patience:
        break

model.load_state_dict(best_state)

Early stopping is useful when overfitting appears after many epochs. It is also useful when the best number of epochs is unknown.

Data Augmentation

Data augmentation regularizes a model by creating modified versions of training examples. For images, common augmentations include cropping, flipping, color jitter, blur, and random erasing.
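
A typical image pipeline built with torchvision transforms might look like the following sketch; the exact transforms and parameters depend on the dataset and task:

import torchvision.transforms as T

train_transform = T.Compose([
    T.RandomResizedCrop(224),        # random crop, resized to 224x224
    T.RandomHorizontalFlip(),        # random left-right flip
    T.ColorJitter(0.4, 0.4, 0.4),    # brightness, contrast, saturation jitter
    T.ToTensor(),
    T.RandomErasing(p=0.25),         # randomly erase a rectangle (on tensors)
])

# Validation data is usually only resized and cropped, not augmented.
val_transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
])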

For text, augmentation is harder because small changes can alter meaning. Common approaches include token masking, back-translation, and paraphrasing, sometimes generated with larger pretrained models.

For audio, augmentations include time shifting, noise injection, speed changes, and spectrogram masking.

Augmentation increases the diversity of training data without collecting new labels. It encourages invariance to transformations that should not change the target.

Label Smoothing

In multi-class classification, hard labels assign probability 1 to the correct class and 0 to all others. Label smoothing replaces this with a softer target distribution.

If there are K classes and the smoothing strength is ε, the correct class receives probability close to 1 - ε, while each other class receives a small nonzero probability.
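
As a small sketch, the smoothed target can be viewed as a mixture of the hard label with a uniform distribution over the K classes, which is how the label_smoothing option in PyTorch's cross-entropy is defined:

import torch

K = 5        # number of classes (example value)
eps = 0.1    # smoothing strength
target = 2   # index of the correct class (example value)

hard = torch.zeros(K)
hard[target] = 1.0

smoothed = (1 - eps) * hard + eps / K
# tensor([0.0200, 0.0200, 0.9200, 0.0200, 0.0200])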

In PyTorch:

loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=0.1)

Label smoothing can reduce overconfidence and improve calibration. It is common in image classification and sequence modeling.

Gradient Clipping as Stabilization

Gradient clipping limits the size of gradients before the optimizer update. It is mainly a stabilization method, but it can also act like a constraint on the update.

loss.backward()

torch.nn.utils.clip_grad_norm_(
    model.parameters(),
    max_norm=1.0,
)

optimizer.step()

Gradient clipping is especially useful for recurrent networks, transformers, reinforcement learning, and unstable training regimes.

Regularization and Model Capacity

Regularization does not simply make a model smaller. It changes which solutions are preferred during optimization.

Two models with the same architecture can generalize differently depending on weight decay, augmentation, dropout, label smoothing, learning rate schedule, batch size, and training duration.

Regularization should be tuned using validation performance, not training loss alone. Strong regularization may increase training loss while reducing validation loss. Excessive regularization may cause underfitting.
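
A minimal sketch of such tuning, assuming a hypothetical build_model helper plus the train_one_epoch and evaluate functions from the early-stopping example, is a small sweep over weight decay values selected by validation loss:

candidates = [0.0, 1e-5, 1e-4, 1e-3, 1e-2]
val_results = {}

for wd in candidates:
    model = build_model()  # fresh model for each setting
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=wd)

    for epoch in range(num_epochs):
        train_one_epoch(model, train_loader, optimizer)

    val_results[wd] = evaluate(model, val_loader)  # judge by validation loss

best_wd = min(val_results, key=val_results.get)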

Practical Defaults

For many PyTorch models, these are reasonable starting points:

Setting                          | Regularization choice
Linear or small MLP              | Weight decay around 1e-4, early stopping
CNN from scratch                 | Weight decay around 1e-4, data augmentation
Transformer fine-tuning          | AdamW with 0.01 weight decay, possibly dropout
Large language model pretraining | AdamW, decoupled weight decay, gradient clipping
Small noisy dataset              | Stronger augmentation, early stopping, dropout

These values are starting points. The best setting depends on dataset size, label noise, model capacity, and training budget.

Summary

Regularization reduces overfitting by changing the training objective or training process. Weight decay penalizes large parameters and is one of the most common regularizers. For Adam-style optimizers, AdamW is usually preferred because it decouples weight decay from adaptive gradient scaling.

Other regularization methods include L1 penalties, dropout, early stopping, data augmentation, label smoothing, and gradient clipping. In practice, regularization should be judged by validation performance, not by training loss alone.