Text classification assigns one or more labels to a piece of text. The input may be a sentence, paragraph, document, review, message, query, or conversation turn. The output is a class label, a probability distribution over labels, or a set of active labels.
Common examples include sentiment analysis, spam detection, topic classification, intent detection, toxicity detection, language identification, document routing, and product category prediction.
A text classifier has the same basic structure as other neural classifiers:
The tokenizer converts text into token IDs. The embedding layer converts token IDs into vectors. The encoder computes contextual representations. The pooling step converts a variable-length sequence into a fixed-size vector. The classifier maps that vector to class scores.
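As a sketch in code, under the assumption of a transformer encoder layer and mean pooling (the sizes here are illustrative, and a real tokenizer would produce the token IDs):

```python
import torch
import torch.nn as nn

vocab_size, embed_dim, num_classes = 1000, 64, 3   # illustrative sizes
embedding = nn.Embedding(vocab_size, embed_dim)
encoder = nn.TransformerEncoderLayer(embed_dim, nhead=4, batch_first=True)
classifier = nn.Linear(embed_dim, num_classes)

input_ids = torch.randint(0, vocab_size, (2, 10))  # [B, T], stands in for tokenizer output
hidden = encoder(embedding(input_ids))             # [B, T, D] contextual representations
pooled = hidden.mean(dim=1)                        # [B, D] fixed-size vector
logits = classifier(pooled)                        # [B, K] class scores
```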
Single-Label Classification
In single-label classification, each input belongs to exactly one class.
For example, a support message may belong to one intent:
| Text | Label |
|---|---|
| I forgot my password | account_access |
| Where is my order? | shipping_status |
| Please cancel my subscription | billing |
If there are $K$ classes, the model returns a vector of $K$ logits $z \in \mathbb{R}^K$. The predicted class is the index with the largest logit:

$$\hat{y} = \operatorname{argmax}_{k} \, z_k$$
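A minimal sketch of this prediction step, with a random tensor standing in for model output:

```python
import torch

logits = torch.randn(4, 3)             # [batch_size, num_classes]
probs = torch.softmax(logits, dim=-1)  # probabilities, rows sum to 1
preds = logits.argmax(dim=-1)          # [batch_size] predicted class IDs
```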
During training, the usual loss is cross-entropy loss. In PyTorch, class labels are integer IDs with shape [batch_size].
```python
import torch
import torch.nn as nn

logits = torch.randn(32, 5)          # [batch_size, num_classes]
labels = torch.randint(0, 5, (32,))  # [batch_size]

loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, labels)
print(loss.shape)  # torch.Size([])
```

nn.CrossEntropyLoss expects raw logits, not probabilities. It internally applies log-softmax and negative log-likelihood.
Multi-Label Classification
In multi-label classification, an input may have several labels at once.
For example, a news article may be labeled with several topics:
| Text | Labels |
|---|---|
| New battery plant opens in Germany | technology, energy, business |
| Central bank raises interest rates | finance, policy |
The model still returns $K$ logits, but each class is treated as an independent binary decision. The targets are binary vectors:

$$y \in \{0, 1\}^K, \quad y_k = 1 \text{ if label } k \text{ is active}$$
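For instance, multi-hot target vectors can be built from lists of active label indices (the label indices here are illustrative):

```python
import torch

num_labels = 6
label_ids = [[0, 2, 4], [1, 3]]   # active label indices per example

targets = torch.zeros(len(label_ids), num_labels)
for i, ids in enumerate(label_ids):
    targets[i, ids] = 1.0
# targets: [2, 6] multi-hot float tensor
```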
In PyTorch, the usual loss is binary cross-entropy with logits:
```python
logits = torch.randn(32, 6)                     # [batch_size, num_labels]
targets = torch.randint(0, 2, (32, 6)).float()  # [batch_size, num_labels]

loss_fn = nn.BCEWithLogitsLoss()
loss = loss_fn(logits, targets)
```

At inference time, logits are passed through a sigmoid function:

```python
probs = torch.sigmoid(logits)
preds = probs > 0.5
```

The threshold does not have to be 0.5. For imbalanced datasets, each label may need its own threshold, as sketched below.
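A minimal sketch of per-label thresholds (the threshold values here are illustrative; in practice they would be tuned on validation data):

```python
# one threshold per label, broadcast against probs from above: [B, num_labels]
thresholds = torch.tensor([0.5, 0.3, 0.7, 0.5, 0.4, 0.6])
preds = probs > thresholds
```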
A Bag-of-Embeddings Classifier
The simplest neural text classifier embeds each token, averages the embeddings, and applies a linear classifier.
This model ignores word order, but it is fast and often strong on small classification tasks.
```python
import torch
import torch.nn as nn

class BagOfEmbeddingsClassifier(nn.Module):
    def __init__(self, vocab_size, embedding_dim, num_classes, padding_idx=0):
        super().__init__()
        self.padding_idx = padding_idx
        self.embedding = nn.Embedding(
            vocab_size,
            embedding_dim,
            padding_idx=padding_idx,
        )
        self.classifier = nn.Linear(embedding_dim, num_classes)

    def forward(self, input_ids):
        # input_ids: [B, T]
        x = self.embedding(input_ids)
        # x: [B, T, D]
        mask = input_ids != self.padding_idx
        # mask: [B, T]
        mask = mask.unsqueeze(-1)
        # mask: [B, T, 1]
        x = x * mask
        lengths = mask.sum(dim=1).clamp(min=1)
        # lengths: [B, 1]
        pooled = x.sum(dim=1) / lengths
        # pooled: [B, D]
        logits = self.classifier(pooled)
        # logits: [B, K]
        return logits
```

The important shape transformation is:

```
[B, T] -> [B, T, D] -> [B, D] -> [B, K]
```

The model begins with token IDs, computes one vector per token, averages across the sequence axis, and produces one vector of logits per example.
Pooling Sequence Representations
A classifier needs a fixed-size representation of the input text. Sequence encoders produce one vector per token, so the model must pool the sequence.
Common pooling methods include mean pooling, max pooling, first-token pooling, last-token pooling, and attention pooling.
| Pooling method | Definition | Common use |
|---|---|---|
| Mean pooling | Average token vectors | Sentence embeddings, simple classifiers |
| Max pooling | Take maximum over time | CNN and RNN classifiers |
| First-token pooling | Use first hidden state | BERT-style [CLS] classifiers |
| Last-token pooling | Use final hidden state | GPT-style classifiers |
| Attention pooling | Learn weighted average | Document classification |
Suppose the encoder output is $H \in \mathbb{R}^{B \times T \times D}$, one $D$-dimensional vector per token. Mean pooling produces

$$\bar{h} = \frac{1}{T} \sum_{t=1}^{T} h_t$$

With an attention mask, padding positions should be excluded:
```python
def mean_pool(hidden_states, attention_mask):
    # hidden_states: [B, T, D]
    # attention_mask: [B, T], 1 for real tokens, 0 for padding
    mask = attention_mask.unsqueeze(-1)
    hidden_states = hidden_states * mask
    lengths = mask.sum(dim=1).clamp(min=1)
    pooled = hidden_states.sum(dim=1) / lengths
    return pooled
```

Pooling choice matters. Mean pooling works well when all tokens contribute to the label. First-token pooling works well when the model has been trained to place sentence-level information in a special token. Last-token pooling is natural for decoder-only language models because the final position has attended to all previous tokens.
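Attention pooling, listed in the table above but not shown in code, can be sketched as a small module that learns to score each token representation (a minimal sketch; the single-layer scoring function is an assumption, and richer variants exist):

```python
import torch
import torch.nn as nn

class AttentionPooling(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.score = nn.Linear(hidden_dim, 1)

    def forward(self, hidden_states, attention_mask):
        # hidden_states: [B, T, D], attention_mask: [B, T]
        scores = self.score(hidden_states).squeeze(-1)            # [B, T]
        scores = scores.masked_fill(attention_mask == 0, float("-inf"))
        weights = torch.softmax(scores, dim=1)                    # [B, T]
        pooled = (hidden_states * weights.unsqueeze(-1)).sum(dim=1)
        return pooled                                             # [B, D]
```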
CNN Text Classifiers
A convolutional text classifier applies one-dimensional convolutions over token embeddings.
If the input embedding tensor has shape [B, T, D], a 1D convolution in PyTorch expects [B, D, T], so we transpose the tensor before applying nn.Conv1d.
```python
class CNNTextClassifier(nn.Module):
    def __init__(self, vocab_size, embedding_dim, num_classes, padding_idx=0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.conv3 = nn.Conv1d(embedding_dim, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv1d(embedding_dim, 128, kernel_size=5, padding=2)
        self.activation = nn.ReLU()
        self.classifier = nn.Linear(256, num_classes)

    def forward(self, input_ids):
        # input_ids: [B, T]
        x = self.embedding(input_ids)
        # x: [B, T, D]
        x = x.transpose(1, 2)
        # x: [B, D, T]
        h3 = self.activation(self.conv3(x))
        h5 = self.activation(self.conv5(x))
        # h3, h5: [B, 128, T]
        h3 = h3.max(dim=2).values
        h5 = h5.max(dim=2).values
        # h3, h5: [B, 128]
        h = torch.cat([h3, h5], dim=1)
        # h: [B, 256]
        logits = self.classifier(h)
        # logits: [B, K]
        return logits
```

CNN classifiers are useful when local phrases strongly determine the label. For example, short phrases such as free money, reset password, or cancel subscription may be predictive.
RNN Text Classifiers
An RNN classifier processes the token sequence from left to right. An LSTM or GRU can build a hidden state that summarizes the sequence.
```python
class LSTMTextClassifier(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_classes, padding_idx=0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            batch_first=True,
            bidirectional=True,
        )
        self.classifier = nn.Linear(2 * hidden_dim, num_classes)

    def forward(self, input_ids):
        # input_ids: [B, T]
        x = self.embedding(input_ids)
        # x: [B, T, D]
        output, (h_n, c_n) = self.lstm(x)
        # h_n: [2, B, H] for one bidirectional layer
        # note: with right padding, the forward state also processes padding
        # tokens; see the packing sketch below
        forward_last = h_n[0]
        backward_last = h_n[1]
        h = torch.cat([forward_last, backward_last], dim=1)
        # h: [B, 2H]
        logits = self.classifier(h)
        # logits: [B, K]
        return logits
```

RNN classifiers are less dominant than transformers in modern NLP, but they remain useful when models must be small, fast, or deployed under strict memory limits.
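The plain forward pass above lets the forward direction run over padding tokens. One common remedy (a sketch, assuming right padding and an attention mask) is to pack the sequence before the LSTM:

```python
import torch

def lstm_pooled(lstm, x, attention_mask):
    # x: [B, T, D], attention_mask: [B, T]
    lengths = attention_mask.sum(dim=1).cpu()   # real sequence lengths
    packed = torch.nn.utils.rnn.pack_padded_sequence(
        x, lengths, batch_first=True, enforce_sorted=False
    )
    _, (h_n, _) = lstm(packed)
    # h_n: [2, B, H]; concatenate final forward and backward states
    return torch.cat([h_n[0], h_n[1]], dim=1)   # [B, 2H]
```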
Transformer Text Classifiers
A transformer classifier uses self-attention to compute contextual token representations. In practice, many classifiers fine-tune a pretrained transformer.
A BERT-style classifier commonly prepends a special [CLS] token. The final hidden state at that position is used as the sentence representation.
The shape flow is:
```
input_ids:      [B, T]
attention_mask: [B, T]
hidden_states:  [B, T, D]
cls_state:      [B, D]
logits:         [B, K]
```

A simplified classifier head looks like this:
```python
class TransformerClassifierHead(nn.Module):
    def __init__(self, hidden_dim, num_classes):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, hidden_states):
        # hidden_states: [B, T, D]
        cls_state = hidden_states[:, 0, :]
        # cls_state: [B, D]
        h = self.norm(cls_state)
        h = self.dropout(h)
        logits = self.classifier(h)
        # logits: [B, K]
        return logits
```

Decoder-only models, such as GPT-style models, often use the hidden state of the final non-padding token instead of the first token.
```python
def last_token_pool(hidden_states, attention_mask):
    # hidden_states: [B, T, D]
    # attention_mask: [B, T]; assumes right padding
    lengths = attention_mask.sum(dim=1) - 1
    # lengths: [B], index of the last real token per example
    batch_idx = torch.arange(hidden_states.size(0), device=hidden_states.device)
    pooled = hidden_states[batch_idx, lengths]
    return pooled
```

The classifier head can then map pooled to logits.
Training Loop
A standard PyTorch training loop for text classification looks like this:
```python
def train_one_epoch(model, dataloader, optimizer, loss_fn, device):
    model.train()
    total_loss = 0.0
    total_examples = 0
    for batch in dataloader:
        input_ids = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)

        logits = model(input_ids)
        loss = loss_fn(logits, labels)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        batch_size = input_ids.size(0)
        total_loss += loss.item() * batch_size
        total_examples += batch_size
    return total_loss / total_examples
```

For models that use attention masks:
```python
logits = model(
    input_ids=batch["input_ids"].to(device),
    attention_mask=batch["attention_mask"].to(device),
)
```

The loss function depends on the task:
| Task type | Output shape | Target shape | Loss |
|---|---|---|---|
| Single-label | [B, K] | [B] | CrossEntropyLoss |
| Binary | [B] or [B, 1] | [B] or [B, 1] | BCEWithLogitsLoss |
| Multi-label | [B, K] | [B, K] | BCEWithLogitsLoss |
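The binary row of the table corresponds to a model with a single output unit. A minimal sketch of that case, with random tensors standing in for model outputs and labels:

```python
import torch
import torch.nn as nn

logits = torch.randn(32, 1).squeeze(-1)       # [B, 1] -> [B]
targets = torch.randint(0, 2, (32,)).float()  # [B], values 0.0 or 1.0
loss = nn.BCEWithLogitsLoss()(logits, targets)
```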
Evaluation Metrics
Accuracy is appropriate when classes are balanced and each mistake has similar cost. Many text classification problems violate these conditions. Spam detection, fraud detection, toxicity detection, and rare intent detection often have imbalanced labels.
Common metrics include precision, recall, F1 score, ROC AUC, PR AUC, and confusion matrices.
| Metric | Meaning | Useful when |
|---|---|---|
| Accuracy | Fraction of correct predictions | Classes are balanced |
| Precision | Fraction of predicted positives that are correct | False positives are costly |
| Recall | Fraction of actual positives found | False negatives are costly |
| F1 | Harmonic mean of precision and recall | Need balance between precision and recall |
| ROC AUC | Ranking quality across thresholds | Binary classification |
| PR AUC | Precision-recall tradeoff | Rare positive class |
For single-label classification:
```python
def accuracy(logits, labels):
    preds = logits.argmax(dim=1)
    return (preds == labels).float().mean()
```

For multi-label classification:
```python
def multilabel_predictions(logits, threshold=0.5):
    probs = torch.sigmoid(logits)
    return probs >= threshold
```

Choosing the threshold is part of model design. A threshold that maximizes validation F1 may differ from a threshold that minimizes business cost.
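A simple sweep for a single shared threshold that maximizes micro-F1 (a sketch; `val_logits` and `val_targets` are assumed to be collected from a validation set):

```python
import torch

def best_f1_threshold(val_logits, val_targets, steps=50):
    probs = torch.sigmoid(val_logits)
    best_t, best_f1 = 0.5, 0.0
    for t in torch.linspace(0.05, 0.95, steps):
        preds = (probs >= t).float()
        tp = (preds * val_targets).sum()
        precision = tp / preds.sum().clamp(min=1)
        recall = tp / val_targets.sum().clamp(min=1)
        f1 = 2 * precision * recall / (precision + recall).clamp(min=1e-8)
        if f1 > best_f1:
            best_t, best_f1 = float(t), float(f1)
    return best_t
```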
Class Imbalance
Text datasets often contain class imbalance. Some labels occur frequently, while others are rare.
There are several standard responses.
One approach is weighted loss. In single-label classification, CrossEntropyLoss accepts class weights:
```python
class_counts = torch.tensor([1000, 200, 50], dtype=torch.float)
class_weights = 1.0 / class_counts
class_weights = class_weights / class_weights.sum() * len(class_counts)
loss_fn = nn.CrossEntropyLoss(weight=class_weights)
```

For multi-label classification, BCEWithLogitsLoss accepts pos_weight:
```python
positive_counts = torch.tensor([1000, 200, 50], dtype=torch.float)
negative_counts = torch.tensor([9000, 9800, 9950], dtype=torch.float)
pos_weight = negative_counts / positive_counts
loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
```

Another approach is sampling. A training loader can oversample rare classes or undersample common classes, as sketched below. This changes the training distribution, so validation should still be performed on the natural distribution.
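A sketch of oversampling with PyTorch's WeightedRandomSampler (assuming `train_labels` is a tensor of class IDs and `train_dataset` is the corresponding dataset):

```python
import torch
from torch.utils.data import WeightedRandomSampler, DataLoader

class_counts = torch.bincount(train_labels)
sample_weights = 1.0 / class_counts[train_labels]   # rare classes weighted up
sampler = WeightedRandomSampler(
    sample_weights, num_samples=len(train_labels), replacement=True
)
loader = DataLoader(train_dataset, batch_size=32, sampler=sampler)
```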
A third approach is threshold tuning. For rare labels, a default threshold of 0.5 may be too conservative.
Calibration
A classifier produces scores. In many applications, we want these scores to behave like probabilities.
A model is calibrated when examples assigned probability 0.8 are correct about 80 percent of the time.
Neural classifiers are often miscalibrated. They may be overconfident, especially after extensive fine-tuning. Calibration can be improved with temperature scaling.
For logits $z$, temperature scaling computes

$$\operatorname{softmax}(z / T)$$

where $T > 0$ is learned on a validation set. Larger $T$ makes the distribution softer. Smaller $T$ makes it sharper.
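A minimal sketch of fitting the temperature by gradient descent (assuming `val_logits` and `val_labels` have already been collected from a validation set):

```python
import torch
import torch.nn as nn

def fit_temperature(val_logits, val_labels, lr=0.01, steps=200):
    log_t = torch.zeros(1, requires_grad=True)   # T = exp(log_t) stays positive
    optimizer = torch.optim.Adam([log_t], lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(steps):
        optimizer.zero_grad()
        loss = loss_fn(val_logits / log_t.exp(), val_labels)
        loss.backward()
        optimizer.step()
    return log_t.exp().item()
```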
Calibration matters when probabilities are used for ranking, triage, deferral, or risk estimation.
Long Documents
Long document classification creates additional problems. Many transformer models have a maximum context length. A document may exceed that length.
Common strategies include truncation, sliding windows, hierarchical models, retrieval, and sparse attention.
| Strategy | Method | Tradeoff |
|---|---|---|
| Truncation | Keep first or last tokens | Simple but may discard evidence |
| Sliding window | Classify chunks and aggregate | More compute |
| Hierarchical model | Encode chunks, then encode chunk vectors | More complex |
| Retrieval | Select relevant passages | Depends on retrieval quality |
| Long-context model | Use model with larger context window | Higher memory cost |
For document classification, mean pooling over chunks is often a strong baseline. More advanced systems may use attention pooling over chunk representations.
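A sketch of the sliding-window strategy with mean pooling over chunk logits (the chunk size and stride are arbitrary, and `encode_chunk` stands in for any per-chunk classifier):

```python
import torch

def classify_long_document(token_ids, encode_chunk, chunk_size=512, stride=256):
    # token_ids: [T] IDs for one document; encode_chunk returns [K] logits
    chunk_logits = []
    for start in range(0, max(len(token_ids) - chunk_size, 0) + 1, stride):
        chunk = token_ids[start : start + chunk_size]
        chunk_logits.append(encode_chunk(chunk))
    return torch.stack(chunk_logits).mean(dim=0)   # [K], mean over chunks
```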
Error Analysis
Text classification errors should be analyzed with examples, not only aggregate metrics.
Useful questions include:
| Question | Purpose |
|---|---|
| Which classes are confused? | Detect overlapping labels |
| Which examples have high confidence but wrong predictions? | Detect systematic failures |
| Which examples are low confidence? | Find ambiguous data |
| Are rare classes ignored? | Diagnose imbalance |
| Are labels noisy? | Improve dataset quality |
| Are certain groups or dialects harmed? | Check fairness and robustness |
A confusion matrix is often the first diagnostic tool for single-label classification. For multi-label classification, per-label precision, recall, and threshold curves are more informative.
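A confusion matrix can be accumulated directly from predictions (a minimal sketch):

```python
import torch

def confusion_matrix(preds, labels, num_classes):
    # preds, labels: [N] integer class IDs
    cm = torch.zeros(num_classes, num_classes, dtype=torch.long)
    for p, t in zip(preds.tolist(), labels.tolist()):
        cm[t, p] += 1   # rows: true class, columns: predicted class
    return cm
```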
Common Errors
A frequent error is applying softmax before CrossEntropyLoss. In PyTorch, CrossEntropyLoss expects logits; passing probabilities applies softmax twice, which silently degrades training.
Correct:

```python
loss = nn.CrossEntropyLoss()(logits, labels)
```

Incorrect:
```python
probs = torch.softmax(logits, dim=1)
loss = nn.CrossEntropyLoss()(probs, labels)
```

Another common error is using CrossEntropyLoss for multi-label classification. Multi-label targets need BCEWithLogitsLoss.
A third error is averaging embeddings without masking padding. Padding tokens should not contribute to the pooled representation.
A fourth error is evaluating on tokenized or truncated text without checking how much content was removed. Silent truncation can create misleading results.
Summary
Text classification maps text to labels. A neural classifier usually consists of a tokenizer, embedding layer, encoder, pooling step, and classifier head.
Single-label classification uses one class per example and is usually trained with cross-entropy loss. Multi-label classification allows several active labels and is usually trained with binary cross-entropy with logits.
Simple bag-of-embeddings models are fast and useful baselines. CNNs capture local phrases. RNNs summarize sequences. Transformers provide contextual representations and are the dominant architecture for high-accuracy NLP systems.