Choosing and Combining Loss Functions

A loss function defines what the model is trained to improve. It translates a modeling goal into a scalar value that can be minimized by gradient-based optimization.

The choice of loss function affects the learned representation, the gradient signal, the stability of training, and the behavior of the final model. Two models with the same architecture and data can learn different solutions if they use different losses.

Loss Functions as Modeling Assumptions

A loss function encodes assumptions about the target.

Mean squared error assumes continuous targets and penalizes large errors strongly. Cross-entropy assumes categorical targets and trains the model to assign probability to the correct class. Contrastive losses assume relationships between examples, such as similarity or relevance. Multi-task losses assume that several objectives can share useful representations.

A loss is therefore both an optimization tool and a statistical statement.

Target type            Common loss                   Model output
Continuous value       MSE, MAE, Huber               Scalar or vector
Binary label           Binary cross-entropy          One logit
Single class           Cross-entropy                 One logit per class
Multiple labels        Binary cross-entropy          One logit per label
Ranking                Pairwise margin loss          Scores
Embedding similarity   Contrastive or triplet loss   Embeddings
Reconstruction         MSE or BCE                    Reconstructed input
Sequence prediction    Token cross-entropy           Vocabulary logits

The correct loss depends on what the output means.

Match the Loss to the Output

Loss functions expect specific output forms.

For PyTorch classification, nn.CrossEntropyLoss expects raw logits of shape [B, K] and integer class targets of shape [B].

loss = torch.nn.CrossEntropyLoss()(logits, targets)

Do not apply softmax before this loss.

For binary classification, nn.BCEWithLogitsLoss expects raw logits and floating-point targets.

loss = torch.nn.BCEWithLogitsLoss()(logits, targets.float())

Do not apply sigmoid before this loss.

For regression, nn.MSELoss expects predictions and targets with matching shapes.

loss = torch.nn.MSELoss()(predictions, targets)

These conventions matter because PyTorch loss modules often combine several numerical operations internally for stability.

Classification Loss Selection

For mutually exclusive classes, use cross-entropy.

Example: an image belongs to exactly one class among cat, dog, car, or bird.

logits = model(images)        # [B, K]
targets = targets.long()      # [B]

loss = torch.nn.CrossEntropyLoss()(logits, targets)

For binary classification, use binary cross-entropy with logits.

logits = model(x).squeeze(-1) # [B]
targets = targets.float()     # [B]

loss = torch.nn.BCEWithLogitsLoss()(logits, targets)

For multilabel classification, use binary cross-entropy independently for each label.

logits = model(x)             # [B, K]
targets = targets.float()     # [B, K]

loss = torch.nn.BCEWithLogitsLoss()(logits, targets)

The key distinction is whether classes compete. Softmax makes classes compete. Sigmoid treats labels independently.
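A small sketch with a single three-class example makes the difference concrete:

import torch

logits = torch.tensor([2.0, 1.0, 0.5])

# Softmax couples the outputs: the probabilities sum to 1,
# so raising one class pushes the others down.
print(torch.softmax(logits, dim=-1))   # sums to 1

# Sigmoid treats each output independently: each value is a
# separate probability, and they do not need to sum to 1.
print(torch.sigmoid(logits))           # three independent probabilities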

Regression Loss Selection

For continuous targets, MSE is the usual baseline. It works well when errors are roughly Gaussian and large errors should receive strong penalties.

loss = torch.nn.functional.mse_loss(pred, target)

Mean absolute error is more robust to outliers:

loss = torch.nn.functional.l1_loss(pred, target)

Huber loss combines the two behaviors. It is quadratic near zero and linear for large errors.

loss = torch.nn.HuberLoss(delta=1.0)(pred, target)

A useful rule is simple: start with MSE, compare against MAE and Huber when outliers matter, and standardize targets when their numerical scale is large.
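As a rough illustration of the outlier behavior, consider a small made-up batch where one target is an outlier; the three losses respond very differently to that single point:

import torch
import torch.nn.functional as F

pred = torch.tensor([1.0, 2.0, 3.0, 4.0])
target = torch.tensor([1.1, 1.9, 3.2, 50.0])     # the last target is an outlier

print(F.mse_loss(pred, target))                  # dominated by the squared outlier error
print(F.l1_loss(pred, target))                   # grows only linearly with the outlier
print(F.huber_loss(pred, target, delta=1.0))     # quadratic near zero, linear far out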

Sequence Loss Selection

For language modeling and token prediction, use cross-entropy over the vocabulary.

Suppose logits have shape [B, T, V] and targets have shape [B, T].

B, T, V = logits.shape

loss = torch.nn.CrossEntropyLoss()(
    logits.reshape(B * T, V),
    targets.reshape(B * T),
)

For padded sequences, set the padded target positions to the ignore_index value, usually -100.

loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)
loss = loss_fn(logits.reshape(B * T, V), targets.reshape(B * T))

Only real target positions should contribute to the loss. Padding tokens are implementation artifacts, not learning targets.
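One common way to mark those positions, assuming a hypothetical pad_mask of shape [B, T] with 1 for real tokens and 0 for padding, is to overwrite padded targets with the ignore index before computing the loss:

# pad_mask: [B, T], 1 for real tokens and 0 for padding (assumed to exist)
targets = targets.masked_fill(pad_mask == 0, -100)

loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)
loss = loss_fn(logits.reshape(B * T, V), targets.reshape(B * T))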

Imbalanced Data

When classes are imbalanced, the unweighted loss may favor common classes.

For multiclass classification, pass class weights to cross-entropy:

weights = torch.tensor([1.0, 2.0, 8.0], device=device)

loss_fn = torch.nn.CrossEntropyLoss(weight=weights)
loss = loss_fn(logits, targets)

For binary or multilabel classification, use pos_weight:

pos_weight = torch.tensor([5.0], device=device)

loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
loss = loss_fn(logits, targets.float())

Weights change the training objective. They can improve rare-class recall, but they may hurt calibration. Always evaluate precision, recall, F1, calibration, and task-specific costs.
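One common heuristic for choosing the weights, sketched below, is inverse class frequency computed from the training labels; train_targets, num_classes, and device are assumed to exist, and the exact scheme is a design choice rather than a fixed rule:

import torch

# counts[k] = number of training examples labeled with class k
counts = torch.bincount(train_targets, minlength=num_classes).float()

# Inverse-frequency ("balanced") weights: rare classes get larger weights.
weights = counts.sum() / (num_classes * counts.clamp_min(1.0))

loss_fn = torch.nn.CrossEntropyLoss(weight=weights.to(device))
loss = loss_fn(logits, targets)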

Combining Losses

Many models use more than one loss. The standard form is a weighted sum:

L = \lambda_1 L_1 + \lambda_2 L_2 + \cdots + \lambda_k L_k.

For example:

loss = (
    1.0 * classification_loss
    + 0.1 * regression_loss
    + 0.01 * regularization_loss
)

The weights determine the relative gradient contribution of each term. They should be treated as hyperparameters.

A good implementation logs each component separately:

losses = {
    "classification": classification_loss,
    "regression": regression_loss,
    "regularization": regularization_loss,
}

weights = {"classification": 1.0, "regression": 0.1, "regularization": 0.01}

loss = sum(weights[name] * value for name, value in losses.items())

This prevents the total loss from hiding failures in individual objectives.
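A minimal logging sketch along these lines, assuming a step counter, records the detached scalar value of each component so a stalled objective stays visible even when the total keeps falling:

# Detach each component so logging does not keep the computation graph alive.
log_values = {name: value.detach().item() for name, value in losses.items()}
log_values["total"] = loss.detach().item()

print(f"step {step}: " + ", ".join(f"{k}={v:.4f}" for k, v in log_values.items()))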

Scale and Units

Loss terms with larger numerical scales can dominate optimization.

For example:

L_{\text{cls}} = 0.7, \qquad L_{\text{reg}} = 2000.

The regression loss will dominate unless it is normalized or down-weighted.

This can happen even when the regression task is not more important. It may only have larger units.

Common fixes include target standardization, loss normalization, task weights, gradient balancing, and separate learning rates for different heads.
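For the last option, separate learning rates can be set through optimizer parameter groups; a minimal sketch, assuming a model with hypothetical backbone, cls_head, and reg_head submodules:

optimizer = torch.optim.AdamW([
    {"params": model.backbone.parameters(), "lr": 1e-4},   # shared backbone
    {"params": model.cls_head.parameters(), "lr": 1e-3},   # classification head
    {"params": model.reg_head.parameters(), "lr": 1e-3},   # regression head
])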

For regression targets:

y' = \frac{y - \mu}{\sigma}.

Train the model to predict y', then convert predictions back to the original scale:

\hat{y} = \sigma \hat{y}' + \mu.

This often improves both optimization and interpretability.
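A minimal sketch of this pattern, assuming a hypothetical train_targets tensor holding all training targets, with statistics computed on the training split only:

# Statistics computed on the training split only.
mu = train_targets.mean()
sigma = train_targets.std().clamp_min(1e-8)

# Train against standardized targets.
loss = torch.nn.functional.mse_loss(pred, (target - mu) / sigma)

# Convert standardized predictions back to the original scale for reporting.
pred_original = sigma * pred + mu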

Reduction Modes

PyTorch losses commonly support reduction modes:

Reduction   Meaning
"mean"      Average all loss elements
"sum"       Sum all loss elements
"none"      Return elementwise losses

The default is usually "mean".

Use "none" when you need masking or per-example weights:

loss_per_example = torch.nn.functional.cross_entropy(
    logits,
    targets,
    reduction="none",
)

loss = (loss_per_example * weights).mean()

For masked sequences:

loss_per_token = torch.nn.functional.cross_entropy(
    logits.reshape(B * T, V),
    targets.reshape(B * T),
    reduction="none",
).reshape(B, T)

loss = (loss_per_token * mask).sum() / mask.sum().clamp_min(1)

This pattern gives precise control over which examples contribute.

Numerical Stability

Use loss functions that accept logits when available.

Prefer:

torch.nn.CrossEntropyLoss()
torch.nn.BCEWithLogitsLoss()

over manually applying softmax or sigmoid, taking the log, and computing the loss by hand.

Unstable code:

probs = torch.softmax(logits, dim=-1)
loss = -torch.log(probs[torch.arange(B), targets]).mean()

Stable code:

loss = torch.nn.functional.cross_entropy(logits, targets)

The stable implementation handles large logits using numerically safe log-sum-exp computations.
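A quick demonstration of the failure mode, using one extreme logit value:

import torch

logits = torch.tensor([[1000.0, 0.0, 0.0]])
targets = torch.tensor([1])

# Manual version: the target-class probability underflows to zero,
# so the log becomes -inf and the loss is infinite.
probs = torch.softmax(logits, dim=-1)
print(-torch.log(probs[0, targets[0]]))                     # inf

# Logit-based version stays finite thanks to log-sum-exp.
print(torch.nn.functional.cross_entropy(logits, targets))   # roughly 1000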

Loss Curves and Metrics

The training loss is not the same as the evaluation metric.

A classifier may minimize cross-entropy, but the metric may be accuracy, F1, AUROC, or calibration error. A regression model may minimize MSE, but the reported metric may be RMSE, MAE, or domain-specific cost.

A good training report separates losses and metrics:

Quantity             Purpose
Training loss        Drives optimization
Validation loss      Measures objective on held-out data
Task metric          Measures application behavior
Calibration metric   Measures probability quality
Error analysis       Reveals failure patterns

When the validation metric worsens while training loss improves, the model may be overfitting or optimizing the wrong objective.
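A minimal evaluation sketch that reports both kinds of quantities side by side, assuming a classification model and a hypothetical val_loader:

model.eval()
total_loss, correct, count = 0.0, 0, 0

with torch.no_grad():
    for x, y in val_loader:
        logits = model(x)
        total_loss += torch.nn.functional.cross_entropy(logits, y, reduction="sum").item()
        correct += (logits.argmax(dim=-1) == y).sum().item()
        count += y.numel()

print("val loss:", total_loss / count, "val accuracy:", correct / count)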

Debugging Loss Problems

Loss bugs are common. The first checks should be mechanical.

- Check shapes: predictions and targets must match the loss API.
- Check the target dtype: CrossEntropyLoss expects integer class indices, while BCEWithLogitsLoss expects floating-point targets.
- Check whether the loss expects logits or probabilities.
- Check the reduction mode.
- Check whether padding or missing labels are included by mistake.

Minimal diagnostic code:

print("logits", logits.shape, logits.dtype)
print("targets", targets.shape, targets.dtype)
print("loss", loss.item())

For classification, verify that targets are within range:

print(targets.min().item(), targets.max().item())

For K classes, targets must be between 0 and K - 1.

Loss Sanity Checks

A useful model should pass simple sanity checks.

First, train on a tiny batch. The model should be able to overfit a few examples. If it cannot, there may be a bug in the loss, data, labels, or model.
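A sketch of this check, assuming a classification model and a hypothetical train_loader; after a few hundred steps the loss should fall close to zero:

x_small, y_small = next(iter(train_loader))
x_small, y_small = x_small[:8], y_small[:8]

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for step in range(300):
    optimizer.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(x_small), y_small)
    loss.backward()
    optimizer.step()

print("tiny-batch loss:", loss.item())   # should approach zero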

Second, compare against a random baseline. For K-class classification, random logits usually give cross-entropy near

\log K.

For example, with K = 10, the initial loss is often near

\log 10 \approx 2.30.

If the initial loss is far from this without a reason, inspect the implementation.
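This baseline is easy to check directly; a sketch with small, untrained-scale logits:

import math
import torch

K = 10
logits = 0.01 * torch.randn(32, K)       # small logits, like an untrained output head
targets = torch.randint(0, K, (32,))

loss = torch.nn.functional.cross_entropy(logits, targets)
print(loss.item(), "vs log K =", math.log(K))   # both near 2.30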

Third, verify that the loss decreases when predictions improve. Construct a small artificial example and compute the loss manually.
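For instance, a hand-built three-class example where one set of logits is clearly better than the other:

import torch

targets = torch.tensor([0, 1])

weak_logits = torch.tensor([[0.1, 0.0, 0.0], [0.0, 0.1, 0.0]])
confident_logits = torch.tensor([[5.0, 0.0, 0.0], [0.0, 5.0, 0.0]])

weak = torch.nn.functional.cross_entropy(weak_logits, targets)
confident = torch.nn.functional.cross_entropy(confident_logits, targets)

print(weak.item(), ">", confident.item())   # better predictions give a lower loss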

Practical Guidelines

- Use CrossEntropyLoss for single-label multiclass classification.
- Use BCEWithLogitsLoss for binary and multilabel classification.
- Use MSE, MAE, or Huber for regression.
- Use token cross-entropy for language modeling.
- Use contrastive or triplet losses when the output is an embedding space.

- Prefer logits-based PyTorch losses for numerical stability.
- Normalize regression targets.
- Mask padding and missing labels.
- Log every loss component separately.
- Evaluate with task metrics, not only training loss.

A loss function is the contract between the problem definition and the optimizer. Design it carefully.