A loss function defines what the model is trained to improve. It translates a modeling goal into a scalar value that can be minimized by gradient-based optimization.
The choice of loss function affects the learned representation, the gradient signal, the stability of training, and the behavior of the final model. Two models with the same architecture and data can learn different solutions if they use different losses.
Loss Functions as Modeling Assumptions
A loss function encodes assumptions about the target.
Mean squared error assumes continuous targets and penalizes large errors strongly. Cross-entropy assumes categorical targets and trains the model to assign probability to the correct class. Contrastive losses assume relationships between examples, such as similarity or relevance. Multi-task losses assume that several objectives can share useful representations.
A loss is therefore both an optimization tool and a statistical statement.
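As a concrete illustration of the statistical side, minimizing MSE is equivalent to maximizing a Gaussian likelihood: the negative log-density of a Gaussian is half the squared error plus a term that does not depend on the prediction. A minimal numeric check in plain Python:

```python
import math

def gaussian_nll(y, mu, sigma=1.0):
    # Negative log-density of N(mu, sigma^2) evaluated at y.
    return 0.5 * math.log(2 * math.pi * sigma**2) + (y - mu) ** 2 / (2 * sigma**2)

y, mu = 3.0, 1.5
const = 0.5 * math.log(2 * math.pi)   # prediction-independent term
half_squared_error = 0.5 * (y - mu) ** 2

# The NLL differs from half the squared error only by a constant,
# so minimizing one minimizes the other.
assert math.isclose(gaussian_nll(y, mu) - const, half_squared_error)
```

The same reasoning links cross-entropy to a categorical likelihood.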
| Target type | Common loss | Model output |
|---|---|---|
| Continuous value | MSE, MAE, Huber | Scalar or vector |
| Binary label | Binary cross-entropy | One logit |
| Single class | Cross-entropy | One logit per class |
| Multiple labels | Binary cross-entropy | One logit per label |
| Ranking | Pairwise margin loss | Scores |
| Embedding similarity | Contrastive or triplet loss | Embeddings |
| Reconstruction | MSE or BCE | Reconstructed input |
| Sequence prediction | Token cross-entropy | Vocabulary logits |
The correct loss depends on what the output means.
Match the Loss to the Output
Loss functions expect specific output forms.
For PyTorch classification, nn.CrossEntropyLoss expects raw logits of shape [B, K] and integer class targets of shape [B].
```python
loss = torch.nn.CrossEntropyLoss()(logits, targets)
```

Do not apply softmax before this loss.
For binary classification, nn.BCEWithLogitsLoss expects raw logits and floating-point targets.
```python
loss = torch.nn.BCEWithLogitsLoss()(logits, targets.float())
```

Do not apply sigmoid before this loss.
For regression, nn.MSELoss expects predictions and targets with matching shapes.
```python
loss = torch.nn.MSELoss()(predictions, targets)
```

These conventions matter because PyTorch loss modules often combine several numerical operations internally for stability.
Classification Loss Selection
For mutually exclusive classes, use cross-entropy.
Example: an image belongs to exactly one class among cat, dog, car, or bird.
```python
logits = model(images)    # [B, K]
targets = targets.long()  # [B]
loss = torch.nn.CrossEntropyLoss()(logits, targets)
```

For binary classification, use binary cross-entropy with logits.
```python
logits = model(x).squeeze(-1)  # [B]
targets = targets.float()      # [B]
loss = torch.nn.BCEWithLogitsLoss()(logits, targets)
```

For multilabel classification, use binary cross-entropy independently for each label.
```python
logits = model(x)          # [B, K]
targets = targets.float()  # [B, K]
loss = torch.nn.BCEWithLogitsLoss()(logits, targets)
```

The key distinction is whether classes compete. Softmax makes classes compete; sigmoid treats labels independently.
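The competition can be made concrete without a framework: softmax probabilities for one example always sum to 1, so raising one class's score lowers every other class's probability, while independent sigmoids share no probability budget. A small sketch in plain Python:

```python
import math

def softmax(scores):
    # Subtract the max before exponentiating for numerical stability.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def sigmoid(s):
    return 1.0 / (1.0 + math.exp(-s))

logits = [2.0, 1.0, 0.5]

probs = softmax(logits)
assert math.isclose(sum(probs), 1.0)  # classes share one unit of probability

label_probs = [sigmoid(s) for s in logits]
assert sum(label_probs) > 1.0  # labels do not compete; no shared budget
```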
Regression Loss Selection
For continuous targets, MSE is the usual baseline. It works well when errors are roughly Gaussian and large errors should receive strong penalties.
```python
loss = torch.nn.functional.mse_loss(pred, target)
```

Mean absolute error is more robust to outliers:
```python
loss = torch.nn.functional.l1_loss(pred, target)
```

Huber loss combines the two behaviors. It is quadratic near zero and linear for large errors.
```python
loss = torch.nn.HuberLoss(delta=1.0)(pred, target)
```

A useful rule is simple: start with MSE, compare against MAE and Huber when outliers matter, and standardize targets when their numerical scale is large.
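The difference in outlier behavior is easy to see on a single residual. A plain-Python sketch of the three penalties (Huber with delta = 1.0, matching the example above):

```python
def mse(e):
    return e * e

def mae(e):
    return abs(e)

def huber(e, delta=1.0):
    # Quadratic inside [-delta, delta], linear outside.
    if abs(e) <= delta:
        return 0.5 * e * e
    return delta * (abs(e) - 0.5 * delta)

# A small residual is penalized mildly by all three.
print(mse(0.5), mae(0.5), huber(0.5))    # 0.25 0.5 0.125

# A 10x residual dominates under MSE but grows only linearly otherwise.
print(mse(10.0), mae(10.0), huber(10.0))  # 100.0 10.0 9.5
```

A single bad label contributes 100 units of MSE but only about 10 units of MAE or Huber, which is why the robust losses resist outliers.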
Sequence Loss Selection
For language modeling and token prediction, use cross-entropy over the vocabulary.
Suppose logits have shape [B, T, V] and targets have shape [B, T].
```python
B, T, V = logits.shape
loss = torch.nn.CrossEntropyLoss()(
    logits.reshape(B * T, V),
    targets.reshape(B * T),
)
```

For padded sequences, set padded targets to ignore_index, usually -100.
```python
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)
loss = loss_fn(logits.reshape(B * T, V), targets.reshape(B * T))
```

Only real target positions should contribute to the loss. Padding tokens are implementation artifacts, not learning targets.
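To see ignore_index in action, the following sketch (assuming PyTorch is available) compares a loss computed with padded targets marked as -100 against the same loss over only the real positions; the two should match exactly, including the mean's denominator:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
V = 5
logits = torch.randn(4, V)                  # 4 positions, 5-token vocabulary
targets = torch.tensor([2, 0, -100, -100])  # last two positions are padding

# Positions with target -100 are skipped entirely.
loss_padded = F.cross_entropy(logits, targets, ignore_index=-100)

# Same computation restricted to the real positions.
loss_real = F.cross_entropy(logits[:2], targets[:2])

assert torch.isclose(loss_padded, loss_real)
```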
Imbalanced Data
When classes are imbalanced, the unweighted loss may favor common classes.
For multiclass classification, pass class weights to cross-entropy:
```python
weights = torch.tensor([1.0, 2.0, 8.0], device=device)
loss_fn = torch.nn.CrossEntropyLoss(weight=weights)
loss = loss_fn(logits, targets)
```

For binary or multilabel classification, use pos_weight:
```python
pos_weight = torch.tensor([5.0], device=device)
loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
loss = loss_fn(logits, targets.float())
```

Weights change the training objective. They can improve rare-class recall, but they may hurt calibration. Always evaluate precision, recall, F1, calibration, and task-specific costs.
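The effect of pos_weight can be verified directly against the weighted BCE formula, -[w * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))]: only the positive term is scaled. A sketch assuming PyTorch:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([0.8, -1.2, 0.3])
targets = torch.tensor([1.0, 0.0, 1.0])
w = 5.0  # extra weight on positive examples

loss = F.binary_cross_entropy_with_logits(
    logits, targets, pos_weight=torch.tensor([w])
)

# Manual computation of the same weighted loss.
p = torch.sigmoid(logits)
manual = -(w * targets * torch.log(p) + (1 - targets) * torch.log(1 - p)).mean()

assert torch.isclose(loss, manual)
```

Note that the negative term and the mean's denominator are unchanged; pos_weight only amplifies the gradient from positive labels.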
Combining Losses
Many models use more than one loss. The standard form is a weighted sum:
For example:

```python
loss = (
    1.0 * classification_loss
    + 0.1 * regression_loss
    + 0.01 * regularization_loss
)
```

The weights determine the relative gradient contribution of each term. They should be treated as hyperparameters.
A good implementation logs each component separately:
```python
losses = {
    "classification": classification_loss,
    "regression": regression_loss,
    "regularization": regularization_loss,
}
loss = sum(weights[name] * value for name, value in losses.items())
```

This prevents the total loss from hiding failures in individual objectives.
Scale and Units
Loss terms with larger numerical scales can dominate optimization.
For example, a classification loss is typically of order 1, while an MSE loss on raw prices in dollars can be of order 10^9. The regression loss will dominate unless it is normalized or down-weighted. This can happen even when the regression task is not more important; it may simply have larger units.
Common fixes include target standardization, loss normalization, task weights, gradient balancing, and separate learning rates for different heads.
For regression targets, the standard fix is standardization. Train the model to predict z = (y - mu) / sigma, where mu and sigma are the mean and standard deviation of the training targets, then convert predictions back to the original scale with y_hat = mu + sigma * z_hat. This often improves both optimization and interpretability.
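A minimal sketch of that standardization round trip in plain Python; train_targets is a hypothetical list of large-scale training targets:

```python
# Hypothetical training targets with a large numerical scale.
train_targets = [250_000.0, 310_000.0, 180_000.0, 420_000.0]

n = len(train_targets)
mu = sum(train_targets) / n
sigma = (sum((y - mu) ** 2 for y in train_targets) / n) ** 0.5

def standardize(y):
    # The model is trained to predict this value instead of y.
    return (y - mu) / sigma

def unstandardize(z):
    # Model outputs are mapped back to the original units.
    return mu + sigma * z

z = standardize(310_000.0)
assert abs(unstandardize(z) - 310_000.0) < 1e-6
```

The statistics must come from the training split only, and the same mu and sigma must be reused at evaluation time.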
Reduction Modes
PyTorch losses commonly support reduction modes:
| Reduction | Meaning |
|---|---|
| "mean" | Average all loss elements |
| "sum" | Sum all loss elements |
| "none" | Return elementwise losses |
The default is usually "mean".
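The three modes are related in a simple way: "sum" equals "mean" times the number of loss elements, and "none" returns those elements unreduced. A quick check, assuming PyTorch:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(6, 4)
targets = torch.randint(0, 4, (6,))

mean_loss = F.cross_entropy(logits, targets, reduction="mean")
sum_loss = F.cross_entropy(logits, targets, reduction="sum")
per_example = F.cross_entropy(logits, targets, reduction="none")

assert per_example.shape == (6,)                 # one loss per example
assert torch.isclose(sum_loss, mean_loss * 6)    # sum = mean * N
assert torch.isclose(per_example.sum(), sum_loss)
```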
Use "none" when you need masking or per-example weights:
```python
loss_per_example = torch.nn.functional.cross_entropy(
    logits,
    targets,
    reduction="none",
)
loss = (loss_per_example * weights).mean()
```

For masked sequences:
```python
loss_per_token = torch.nn.functional.cross_entropy(
    logits.reshape(B * T, V),
    targets.reshape(B * T),
    reduction="none",
).reshape(B, T)
loss = (loss_per_token * mask).sum() / mask.sum().clamp_min(1)
```

This pattern gives precise control over which examples contribute.
Numerical Stability
Use loss functions that accept logits when available.
Prefer:

```python
torch.nn.CrossEntropyLoss()
torch.nn.BCEWithLogitsLoss()
```

over manually applying softmax, sigmoid, and log, and then computing the loss.
Unstable code:

```python
probs = torch.softmax(logits, dim=-1)
loss = -torch.log(probs[torch.arange(B), targets]).mean()
```

Stable code:

```python
loss = torch.nn.functional.cross_entropy(logits, targets)
```

The stable implementation handles large logits using numerically safe log-sum-exp computations.
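The failure mode is easy to trigger: with large logits, the manual softmax underflows small probabilities to exactly zero, so the subsequent log produces -inf, while the fused implementation stays finite. A sketch assuming PyTorch:

```python
import math
import torch
import torch.nn.functional as F

logits = torch.tensor([[1000.0, 0.0], [0.0, 1000.0]])
targets = torch.tensor([1, 0])  # deliberately pick the tiny-probability class

# Manual route: softmax underflows the target class to 0, and log(0) = -inf.
probs = torch.softmax(logits, dim=-1)
manual = -torch.log(probs[torch.arange(2), targets]).mean()
assert not math.isfinite(manual.item())

# Fused route: log-sum-exp keeps the result finite (about 1000 here).
stable = F.cross_entropy(logits, targets)
assert math.isfinite(stable.item())
```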
Loss Curves and Metrics
The training loss is not the same as the evaluation metric.
A classifier may minimize cross-entropy, but the metric may be accuracy, F1, AUROC, or calibration error. A regression model may minimize MSE, but the reported metric may be RMSE, MAE, or domain-specific cost.
A good training report separates losses and metrics:
| Quantity | Purpose |
|---|---|
| Training loss | Drives optimization |
| Validation loss | Measures objective on held-out data |
| Task metric | Measures application behavior |
| Calibration metric | Measures probability quality |
| Error analysis | Reveals failure patterns |
When the validation metric worsens while training loss improves, the model may be overfitting or optimizing the wrong objective.
Debugging Loss Problems
Loss bugs are common. The first checks should be mechanical.
Check shapes. Predictions and targets must match the loss API. Check target dtype. CrossEntropyLoss expects integer class indices. BCEWithLogitsLoss expects floating-point targets. Check whether logits or probabilities are expected. Check reduction mode. Check whether padding or missing labels are included by mistake.
Minimal diagnostic code:
```python
print("logits", logits.shape, logits.dtype)
print("targets", targets.shape, targets.dtype)
print("loss", loss.item())
```

For classification, verify that targets are within range:

```python
print(targets.min().item(), targets.max().item())
```

For K classes, targets must be between 0 and K - 1.
Loss Sanity Checks
A useful model should pass simple sanity checks.
First, train on a tiny batch. The model should be able to overfit a few examples. If it cannot, there may be a bug in the loss, data, labels, or model.
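A sketch of that tiny-batch check, assuming PyTorch; the linear model, sizes, and optimizer settings are placeholders, and any reasonable model should behave the same way:

```python
import torch

torch.manual_seed(0)
x = torch.randn(8, 16)           # one tiny fixed batch
y = torch.randint(0, 3, (8,))
model = torch.nn.Linear(16, 3)   # placeholder model
opt = torch.optim.Adam(model.parameters(), lr=1e-2)

first = None
for step in range(500):
    opt.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(x), y)
    loss.backward()
    opt.step()
    if first is None:
        first = loss.item()

# A healthy pipeline drives the tiny-batch loss far below its starting value.
assert loss.item() < first
assert loss.item() < 0.1
```

If the loss plateaus near its initial value instead, suspect the loss wiring, the labels, or a detached computation graph before blaming the architecture.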
Second, compare against a random baseline. For K-class classification, random logits usually give cross-entropy near ln K.
For example, with K = 10, the initial loss is often near ln 10, roughly 2.3.
If the initial loss is far from this without a reason, inspect the implementation.
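That baseline can be checked exactly: with all-zero (uniform) logits, cross-entropy is exactly ln K regardless of the targets. A sketch assuming PyTorch:

```python
import math
import torch
import torch.nn.functional as F

K = 10
logits = torch.zeros(32, K)            # uniform predictions over K classes
targets = torch.randint(0, K, (32,))

loss = F.cross_entropy(logits, targets)
assert math.isclose(loss.item(), math.log(K), rel_tol=1e-5)
```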
Third, verify that the loss decreases when predictions improve. Construct a small artificial example and compute the loss manually.
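One such artificial example, computed manually in plain Python: the same target scored once by a confidently correct prediction and once by a confidently wrong one. The cross-entropy here is written directly from its definition, -log softmax(logits)[target]:

```python
import math

def cross_entropy(logits, target):
    # log-sum-exp minus the target logit, with the max subtracted for stability.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(s - m) for s in logits))
    return log_z - logits[target]

target = 0
confident_right = [4.0, 0.0, 0.0]
confident_wrong = [0.0, 4.0, 0.0]

# Better predictions must yield a strictly lower loss.
assert cross_entropy(confident_right, target) < cross_entropy(confident_wrong, target)
```

If an implementation fails this ordering, the loss is wired to the wrong outputs or targets.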
Practical Guidelines
Use CrossEntropyLoss for single-label multiclass classification. Use BCEWithLogitsLoss for binary and multilabel classification. Use MSE, MAE, or Huber for regression. Use token cross-entropy for language modeling. Use contrastive or triplet losses when the output is an embedding space.
Prefer logits-based PyTorch losses for numerical stability. Normalize regression targets. Mask padding and missing labels. Log every loss component separately. Evaluate with task metrics, not only training loss.
A loss function is the contract between the problem definition and the optimizer. Design it carefully.