Cross-entropy loss is the standard loss function for classification. It measures how well a model’s predicted class distribution matches the true class label.
In regression, the target is usually a real number. In classification, the target is a class. For example, an image classifier may choose one label from a fixed set of categories such as "cat," "dog," or "car."
A neural network does not usually output the class directly. It outputs a vector of scores called logits. The logits are then converted into probabilities.
If there are $K$ classes, the model produces a vector of logits

$$z = (z_1, z_2, \dots, z_K),$$

where $z_k$ is the logit for class $k$. The softmax function converts logits into probabilities:

$$p_k = \frac{e^{z_k}}{\sum_{j=1}^{K} e^{z_j}}.$$

The probabilities satisfy $p_k > 0$ and $\sum_{k=1}^{K} p_k = 1$.
Cross-entropy penalizes the model when it assigns low probability to the correct class.
Classification as Probability Prediction
In a $K$-class classification problem, the model represents a conditional distribution

$$p(y \mid x).$$

Given an input $x$, the model estimates the probability of each class. If the true class is $c$, the ideal model assigns high probability to class $c$:

$$p(y = c \mid x) \approx 1.$$

For one example, the cross-entropy loss is

$$L = -\log p_c,$$

where $p_c$ is the predicted probability of the correct class. If the model assigns probability close to $1$ to the correct class, the loss is small: $-\log(0.99) \approx 0.01$. If the model assigns probability close to $0$ to the correct class, the loss is large: $-\log(0.01) \approx 4.6$.
The loss therefore rewards confidence when the model is correct and strongly penalizes confidence in the wrong classes.
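This penalty curve is easy to check numerically. A minimal sketch, with probability values chosen only for illustration:

```python
import torch

# Cross-entropy for one example is -log(p_correct).
for p_correct in [0.99, 0.9, 0.5, 0.1, 0.01]:
    loss = -torch.log(torch.tensor(p_correct))
    print(f"p = {p_correct:.2f} -> loss = {loss.item():.3f}")
```

The loss grows without bound as the probability assigned to the correct class approaches zero.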
One-Hot Targets
A class label can be represented as a one-hot vector. If there are $K$ classes and the correct class is $c$, the target vector $y$ has

$$y_k = \begin{cases} 1 & k = c, \\ 0 & k \neq c. \end{cases}$$

For example, with four classes and correct class $c = 2$, using zero-based indexing:

$$y = (0, 0, 1, 0).$$

If the model predicts probabilities $p = (p_0, p_1, p_2, p_3)$, then the cross-entropy is

$$L = -\sum_{k} y_k \log p_k.$$

Because only the correct class has $y_k = 1$, this reduces to

$$L = -\log p_c.$$
This is why cross-entropy with one-hot labels is often described as the negative log probability of the true class.
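A quick sketch of this reduction, with predicted probabilities chosen only for illustration:

```python
import torch

# Four classes; the correct class is 2 (zero-based), so the target is one-hot.
p = torch.tensor([0.1, 0.2, 0.6, 0.1])  # hypothetical predicted probabilities
y = torch.tensor([0.0, 0.0, 1.0, 0.0])  # one-hot target

full_sum = -(y * torch.log(p)).sum()  # -sum_k y_k log p_k
reduced = -torch.log(p[2])            # -log p_c

print(full_sum.item(), reduced.item())  # identical values
```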
Cross-Entropy for Batches
For a batch of $N$ examples, the logits $Z$ have shape `[N, K]`. The target labels $y$ have shape `[N]`. Each row of $Z$ contains the logits for one example. Each entry of $y$ contains the correct class index for one example.

The batch cross-entropy loss is

$$L = -\frac{1}{N} \sum_{i=1}^{N} \log p_{i, y_i}.$$

Here $p_{i, y_i}$ is the predicted probability assigned to the correct class for example $i$.
In PyTorch:

```python
import torch
import torch.nn as nn

logits = torch.tensor([
    [2.0, 0.5, -1.0],
    [0.1, 1.5, 0.3],
    [-0.5, 0.2, 2.0],
])
targets = torch.tensor([0, 1, 2])

loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, targets)
print(loss)
```

nn.CrossEntropyLoss expects raw logits, not softmax probabilities. It applies log_softmax internally and then computes negative log-likelihood.
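As a sanity check, the built-in loss can be reproduced by hand with log_softmax. A minimal sketch reusing the same logits and targets:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.tensor([
    [2.0, 0.5, -1.0],
    [0.1, 1.5, 0.3],
    [-0.5, 0.2, 2.0],
])
targets = torch.tensor([0, 1, 2])

builtin = nn.CrossEntropyLoss()(logits, targets)

# Manual computation: mean of -log p_{i, y_i} over the batch.
log_probs = F.log_softmax(logits, dim=-1)
manual = -log_probs[torch.arange(3), targets].mean()

print(builtin.item(), manual.item())  # the two values agree
```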
Why PyTorch Uses Logits
A common mistake is to apply softmax before passing outputs into nn.CrossEntropyLoss:

```python
probabilities = torch.softmax(logits, dim=-1)
loss = nn.CrossEntropyLoss()(probabilities, targets)  # Wrong
```

This is numerically and mathematically wrong for the standard PyTorch loss.

The correct form is:

```python
loss = nn.CrossEntropyLoss()(logits, targets)
```

PyTorch combines softmax and logarithm in a numerically stable operation. Computing softmax first can produce very small probabilities. Taking the logarithm of those values can cause numerical instability.
The stable computation uses

$$\log p_k = z_k - \log \sum_{j=1}^{K} e^{z_j}.$$
In practice, implementations subtract the maximum logit before exponentiation to avoid overflow.
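A small sketch of why the max-subtraction trick matters, using deliberately large logits:

```python
import torch

z = torch.tensor([1000.0, 999.0, 998.0])  # large logits

# Naive softmax overflows: exp(1000.) is inf in float32, and inf/inf is nan.
naive = torch.exp(z) / torch.exp(z).sum()
print(naive)  # all nan

# Subtracting the maximum logit is mathematically equivalent and stable.
z_shifted = z - z.max()
stable = torch.exp(z_shifted) / torch.exp(z_shifted).sum()
print(stable)  # valid probabilities summing to 1
```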
Gradient of Cross-Entropy with Softmax
Cross-entropy has a particularly simple gradient when combined with softmax.
Let

$$p = \mathrm{softmax}(z)$$

and let $y$ be a one-hot target vector. The loss is

$$L = -\sum_{k} y_k \log p_k.$$

The gradient with respect to the logit $z_k$ is

$$\frac{\partial L}{\partial z_k} = p_k - y_k.$$

For the correct class, this gradient is $p_c - 1$. For every incorrect class, it is $p_k$.
This has a useful interpretation. The model increases the logit for the correct class and decreases logits for incorrect classes in proportion to their predicted probabilities.
For a batch of size $N$, the averaged gradient is

$$\frac{\partial L}{\partial z_{i,k}} = \frac{1}{N}\left(p_{i,k} - y_{i,k}\right)$$

for each example $i$ and class $k$.
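The $p - y$ formula can be verified against autograd. A minimal sketch for a batch of one example:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[1.0, 2.0, 0.5]], requires_grad=True)
targets = torch.tensor([1])

loss = F.cross_entropy(logits, targets)
loss.backward()

# Analytic gradient: softmax probabilities minus the one-hot target.
p = torch.softmax(logits.detach(), dim=-1)
y = F.one_hot(targets, num_classes=3).float()

print(logits.grad)  # matches p - y
print(p - y)
```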
Binary Cross-Entropy
For binary classification, there are two common formulations.
The first uses two logits and nn.CrossEntropyLoss:

```python
logits = model(x)         # shape [B, 2]
targets = targets.long()  # shape [B]
loss = nn.CrossEntropyLoss()(logits, targets)
```

The second uses one logit and binary cross-entropy:

```python
logits = model(x)          # shape [B]
targets = targets.float()  # shape [B]
loss = nn.BCEWithLogitsLoss()(logits, targets)
```

nn.BCEWithLogitsLoss combines a sigmoid function and binary cross-entropy in one numerically stable operation.
For one logit $z$, the sigmoid probability is

$$p = \sigma(z) = \frac{1}{1 + e^{-z}}.$$

For a target $t \in \{0, 1\}$, the binary cross-entropy loss is

$$L = -\big(t \log p + (1 - t)\log(1 - p)\big).$$
Use BCEWithLogitsLoss for binary classification with one output logit. Use CrossEntropyLoss for multiclass classification with mutually exclusive classes.
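The two formulations are equivalent: softmax over two classes equals a sigmoid of the logit difference, so the losses match when the single logit is taken as $z_1 - z_0$. A sketch with random values:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
two_logits = torch.randn(5, 2)          # [B, 2]
targets = torch.tensor([0, 1, 1, 0, 1])

ce = nn.CrossEntropyLoss()(two_logits, targets)

# softmax([z0, z1])[1] == sigmoid(z1 - z0), so the losses coincide.
one_logit = two_logits[:, 1] - two_logits[:, 0]
bce = nn.BCEWithLogitsLoss()(one_logit, targets.float())

print(ce.item(), bce.item())  # the two values agree
```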
Multiclass Versus Multilabel Classification
Multiclass classification means each example belongs to exactly one class. For example, an image may be classified as one of ten digits. The model output shape is `[B, K]` and the target shape is `[B]`.

Use:

```python
loss = nn.CrossEntropyLoss()(logits, targets)
```

Multilabel classification means each example may belong to multiple classes at the same time. For example, an image may contain both "car" and "person." The model output shape is `[B, K]` and the target shape is also `[B, K]`, where each target entry is $0$ or $1$.

Use:

```python
loss = nn.BCEWithLogitsLoss()(logits, targets.float())
```

The difference is important. Softmax forces probabilities across classes to compete. Sigmoid treats each class independently.
| Task type | Output shape | Target shape | Loss |
|---|---|---|---|
| Binary, one logit | [B] or [B, 1] | [B] or [B, 1] | BCEWithLogitsLoss |
| Multiclass | [B, K] | [B] | CrossEntropyLoss |
| Multilabel | [B, K] | [B, K] | BCEWithLogitsLoss |
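The competition-versus-independence distinction in the table can be seen directly from the two activation functions, with logits chosen only for illustration:

```python
import torch

logits = torch.tensor([2.0, 2.0, -1.0])

# Softmax: the two large logits split the probability mass; the sum is 1.
softmax_p = torch.softmax(logits, dim=-1)
print(softmax_p, softmax_p.sum())

# Sigmoid: each class is scored independently; the sum is unconstrained.
sigmoid_p = torch.sigmoid(logits)
print(sigmoid_p, sigmoid_p.sum())
```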
Class Imbalance
Classification datasets often have imbalanced classes. For example, in medical screening, normal cases may be much more common than positive cases. A model trained with ordinary cross-entropy may learn to favor the majority class.
PyTorch allows class weighting:

```python
class_weights = torch.tensor([1.0, 3.0, 10.0])
loss_fn = nn.CrossEntropyLoss(weight=class_weights)
loss = loss_fn(logits, targets)
```

The weight for each class increases or decreases the loss contribution for examples of that class.
For binary or multilabel classification, BCEWithLogitsLoss provides pos_weight:

```python
pos_weight = torch.tensor([5.0])
loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
loss = loss_fn(logits, targets.float())
```

Class weights should be used carefully. They change the optimization objective. They may improve recall for rare classes, but they can also reduce calibration or increase false positives.
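One detail worth knowing: with the default mean reduction, CrossEntropyLoss normalizes by the total weight of the targets in the batch, not the batch size. A minimal sketch with made-up logits and hypothetical weights:

```python
import torch
import torch.nn as nn

logits = torch.tensor([[2.0, 0.5], [0.3, 1.2]])
targets = torch.tensor([0, 1])
weights = torch.tensor([1.0, 3.0])  # hypothetical class weights

unweighted = nn.CrossEntropyLoss()(logits, targets)
weighted = nn.CrossEntropyLoss(weight=weights)(logits, targets)

# Weighted mean: sum of w_{y_i} * l_i divided by the sum of w_{y_i}.
per_example = nn.CrossEntropyLoss(reduction="none")(logits, targets)
check = (weights[targets] * per_example).sum() / weights[targets].sum()

print(unweighted.item(), weighted.item(), check.item())
```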
Label Smoothing
Standard cross-entropy treats the target label as fully certain. For a correct class $c$, the one-hot target assigns probability $1$ to class $c$ and $0$ to every other class.
Label smoothing softens the target distribution. Instead of

$$y_c = 1,$$

it uses

$$y_c = 1 - \varepsilon$$

and distributes the remaining probability $\varepsilon$ across other classes.

For $K$ classes, a common form is

$$y_k = \begin{cases} 1 - \varepsilon + \dfrac{\varepsilon}{K} & k = c, \\[6pt] \dfrac{\varepsilon}{K} & k \neq c. \end{cases}$$
This discourages the model from becoming overly confident.
In PyTorch:
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
loss = loss_fn(logits, targets)Label smoothing can improve generalization and calibration, especially in large classification models. It may also reduce the maximum confidence of the model, so it should be evaluated against the needs of the application.
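The smoothed target can be built by hand and checked against the built-in option, assuming PyTorch's convention of mixing the one-hot target with a uniform distribution over all $K$ classes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, -1.0], [0.1, 1.5, 0.3]])
targets = torch.tensor([0, 1])
eps, K = 0.1, 3

builtin = nn.CrossEntropyLoss(label_smoothing=eps)(logits, targets)

# Manual: smoothed targets mix one-hot labels with a uniform distribution.
y = F.one_hot(targets, K).float()
y_smooth = (1 - eps) * y + eps / K
manual = -(y_smooth * F.log_softmax(logits, dim=-1)).sum(dim=-1).mean()

print(builtin.item(), manual.item())  # the two values agree
```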
Cross-Entropy and Maximum Likelihood
Cross-entropy has a probabilistic interpretation. Suppose each target label $y_i$ is sampled from a categorical distribution predicted by the model:

$$y_i \sim \mathrm{Categorical}(p_{i,1}, \dots, p_{i,K}).$$

The likelihood of the observed label is

$$p_{i, y_i}.$$

The negative log-likelihood over a dataset of $N$ examples is

$$-\sum_{i=1}^{N} \log p_{i, y_i}.$$

Dividing by $N$ gives the mean cross-entropy loss:

$$L = -\frac{1}{N} \sum_{i=1}^{N} \log p_{i, y_i}.$$
Thus, minimizing cross-entropy is equivalent to maximizing the likelihood of the observed labels under a categorical model.
Cross-Entropy and Information Theory
Cross-entropy can also be viewed as an information-theoretic quantity. If $q$ is the true data distribution and $p$ is the model distribution, the cross-entropy is

$$H(q, p) = -\sum_{k} q_k \log p_k.$$

It is related to entropy and KL divergence:

$$H(q, p) = H(q) + D_{\mathrm{KL}}(q \,\|\, p).$$

Since $H(q)$ does not depend on the model, minimizing cross-entropy also minimizes the KL divergence from the true distribution to the model distribution.

In ordinary supervised classification with one-hot targets, $q$ places all probability mass on the observed class. Cross-entropy then becomes the negative log probability of that class.
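The decomposition can be checked numerically for any pair of distributions, with values chosen only for illustration:

```python
import torch

q = torch.tensor([0.7, 0.2, 0.1])  # "true" distribution
p = torch.tensor([0.5, 0.3, 0.2])  # model distribution

cross_entropy = -(q * torch.log(p)).sum()
entropy = -(q * torch.log(q)).sum()
kl = (q * torch.log(q / p)).sum()

# H(q, p) = H(q) + KL(q || p)
print(cross_entropy.item(), (entropy + kl).item())
```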
Cross-Entropy for Segmentation
In image segmentation, the model predicts a class for each pixel. The logits often have shape `[B, K, H, W]`, where $B$ is the batch size, $K$ is the number of classes, and $H \times W$ are the image dimensions. The targets usually have shape `[B, H, W]`, where each pixel stores a class index.
In PyTorch:

```python
logits = torch.randn(4, 21, 256, 256)          # [B, K, H, W]
targets = torch.randint(0, 21, (4, 256, 256))  # [B, H, W]
loss = nn.CrossEntropyLoss()(logits, targets)
```

The same loss applies at every pixel and then averages across batch and spatial dimensions.
For binary segmentation, one may instead use BCEWithLogitsLoss with output shape [B, 1, H, W] and target shape [B, 1, H, W].
Ignore Index
Some classification targets should not contribute to the loss. This is common in segmentation and sequence modeling.
For example, padded tokens in a text batch should usually be ignored.
PyTorch provides ignore_index:
```python
loss_fn = nn.CrossEntropyLoss(ignore_index=-100)

logits = torch.randn(8, 128, 50257)          # [B, T, V]
targets = torch.randint(0, 50257, (8, 128))  # [B, T]

loss = loss_fn(
    logits.reshape(-1, 50257),
    targets.reshape(-1),
)
```

Any target equal to -100 is excluded from the loss.
For language modeling, padded positions or masked positions can be set to -100 so they do not affect training.
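A quick sketch confirming that ignored positions drop out of the average, using random logits:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(6, 5)
targets = torch.tensor([1, 4, -100, 0, -100, 2])  # two positions masked

loss = nn.CrossEntropyLoss(ignore_index=-100)(logits, targets)

# Equivalent: compute the loss only over the kept positions.
keep = targets != -100
manual = nn.CrossEntropyLoss()(logits[keep], targets[keep])

print(loss.item(), manual.item())  # the two values agree
```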
Cross-Entropy for Language Modeling
In autoregressive language modeling, the model predicts the next token. If the vocabulary size is $V$, the model produces $V$ logits, one per token in the vocabulary.
For a batch of token sequences, logits may have shape `[B, T, V]` and targets may have shape `[B, T]`. Each position predicts the next token. The loss is cross-entropy over the $V$ vocabulary classes:

$$L = -\frac{1}{BT} \sum_{i=1}^{B} \sum_{t=1}^{T} \log p_{i,t,\,y_{i,t}}.$$
In PyTorch, CrossEntropyLoss expects the class dimension before extra dimensions, or a flattened shape. A common implementation is:

```python
B, T, V = logits.shape
loss = nn.CrossEntropyLoss()(
    logits.reshape(B * T, V),
    targets.reshape(B * T),
)
```

This is the central training objective for GPT-style language models.
Perplexity
For language models, cross-entropy is often reported as perplexity.
If the average cross-entropy loss $L$ is measured using natural logarithms, perplexity is

$$\mathrm{PPL} = e^{L}.$$
Lower perplexity means the model assigns higher probability to the observed tokens.
For example:

```python
loss = torch.tensor(2.0)
perplexity = torch.exp(loss)
print(perplexity)
```

Perplexity is easier to interpret than raw cross-entropy in some language modeling settings, but it should be compared only under the same tokenization and evaluation setup.
Common PyTorch Mistakes
The most common error is passing probabilities into nn.CrossEntropyLoss. Pass raw logits.
Another common error is using floating-point one-hot labels when CrossEntropyLoss expects integer class indices. For ordinary multiclass classification, targets should usually have dtype torch.long:

```python
targets = targets.long()
```

A third error is using CrossEntropyLoss for multilabel classification. Multilabel classification should usually use BCEWithLogitsLoss.
A fourth error is putting the class dimension in the wrong place. For image classification, logits should be [B, K]. For segmentation, logits should be [B, K, H, W]. For sequence models, flattening to [B*T, V] is often the simplest approach.
Practical Guidelines
Use nn.CrossEntropyLoss for multiclass classification with mutually exclusive classes. The model should output raw logits. The targets should be integer class indices.
Use nn.BCEWithLogitsLoss for binary or multilabel classification. The model should output raw logits. The targets should be floating-point values with entries usually equal to $0$ or $1$.
For imbalanced datasets, consider class weights, positive-class weights, resampling, or specialized losses. For large classification models, consider label smoothing. For sequence models, use ignore_index to remove padding or irrelevant positions from the loss.
Cross-entropy is the default loss for classification because it is both statistically principled and computationally convenient. It corresponds to maximum likelihood training for categorical targets, works directly with probabilistic predictions, and gives a simple gradient when combined with softmax.