Text Classification

Text classification is the task of assigning one or more labels to a piece of text. The input may be a sentence, a paragraph, a document, a conversation, or a search query. The output is a class label, a set of labels, or a probability distribution over labels.

Examples include sentiment analysis, spam detection, topic classification, intent detection, toxicity detection, language identification, legal document tagging, product review classification, and support ticket routing.

A text classifier learns a function

$$f_\theta: \mathcal{X} \rightarrow \mathcal{Y},$$

where $\mathcal{X}$ is the space of text inputs and $\mathcal{Y}$ is the set of labels. The parameters $\theta$ are learned from labeled examples.

For a single-label classification problem with $K$ classes, the label space is

$$\mathcal{Y} = \{1,2,\ldots,K\}.$$

For binary sentiment classification, for example,

$$\mathcal{Y} = \{\text{negative}, \text{positive}\}.$$

For topic classification, the label space may be

$$\mathcal{Y} = \{\text{sports}, \text{finance}, \text{science}, \text{politics}, \text{technology}\}.$$

In modern deep learning, text classification is usually solved by combining four components:

  1. A tokenizer
  2. An embedding layer or pretrained language model
  3. A text encoder
  4. A classification head

The tokenizer converts raw text into token IDs. The embedding layer maps each ID to a vector, and the encoder turns that sequence of vectors into contextual representations. The classification head maps the resulting representation to class logits.

From Text to Tensors

Neural networks operate on tensors, not strings. A text input must first be converted into a numerical representation.

Consider the sentence:

The movie was surprisingly good.

A tokenizer may split this into tokens:

["The", "movie", "was", "surprisingly", "good", "."]

Each token is then mapped to an integer ID:

[101, 1996, 3185, 2001, 10889, 2204, 1012, 102]

The exact IDs depend on the tokenizer and vocabulary. Transformer tokenizers often add special tokens, which is why the list has eight IDs for six tokens. For BERT-style models, [CLS] is prepended (ID 101 here) and [SEP] is appended (ID 102 here).
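To make the text-to-ID mapping concrete, here is a toy word-level encoder in plain Python. It is only a sketch. Real BERT tokenizers use WordPiece subwords and a vocabulary of roughly 30,000 entries. The hypothetical `VOCAB` below holds only the IDs from the example above, and lowercasing stands in for an uncased tokenizer.

```python
import re

# Hypothetical toy vocabulary; the IDs are copied from the example above.
VOCAB = {
    "[PAD]": 0, "[UNK]": 100, "[CLS]": 101, "[SEP]": 102,
    "the": 1996, "movie": 3185, "was": 2001,
    "surprisingly": 10889, "good": 2204, ".": 1012,
}

def tokenize(text: str) -> list[str]:
    # Lowercase (like an uncased tokenizer), then split into words and punctuation.
    return re.findall(r"\w+|[^\w\s]", text.lower())

def encode(text: str, vocab: dict[str, int] = VOCAB) -> list[int]:
    # Wrap the tokens in BERT-style special tokens; unknown words map to [UNK].
    tokens = ["[CLS]"] + tokenize(text) + ["[SEP]"]
    return [vocab.get(tok, vocab["[UNK]"]) for tok in tokens]

print(tokenize("The movie was surprisingly good."))
print(encode("The movie was surprisingly good."))
```

A word missing from the vocabulary, such as "bad" here, becomes `[UNK]`. Subword tokenizers avoid most of these cases by splitting rare words into known pieces.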

For a batch of $B$ examples, the token IDs are stored as a tensor

$$X \in \mathbb{Z}^{B \times T},$$

where $T$ is the maximum sequence length in the batch.

A typical batch contains two tensors:

```python
input_ids      # shape: [B, T]
attention_mask # shape: [B, T]
```

The input_ids tensor stores token IDs. The attention_mask tensor marks which positions are real tokens and which positions are padding.

For example:

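```python
import torch

# Two sequences padded to length 8 with pad ID 0.
input_ids = torch.tensor([
    [101, 1996, 3185, 2001, 2204, 102, 0, 0],
    [101, 2023, 2003, 2919, 102, 0, 0, 0],
])

# 1 = real token, 0 = padding.
attention_mask = torch.tensor([
    [1, 1, 1, 1, 1, 1, 0, 0],
    [1, 1, 1, 1, 1, 0, 0, 0],
])
```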

The zeros in attention_mask tell the model to ignore padding tokens.
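Tensors like these can be built from variable-length ID lists with a small padding helper. The sketch below uses plain Python lists and assumes the padding ID is 0, as in the example. The name `pad_batch` is hypothetical. Wrapping its outputs in `torch.tensor(...)` gives tensors of shape `[B, T]`.

```python
from typing import List, Optional, Tuple

def pad_batch(
    sequences: List[List[int]], pad_id: int = 0, max_len: Optional[int] = None
) -> Tuple[List[List[int]], List[List[int]]]:
    # By default, pad to the longest sequence in the batch (dynamic padding).
    longest = max(len(seq) for seq in sequences)
    target = longest if max_len is None else max_len
    if target < longest:
        raise ValueError("max_len is shorter than a sequence; truncate first")
    input_ids, attention_mask = [], []
    for seq in sequences:
        n_pad = target - len(seq)
        input_ids.append(seq + [pad_id] * n_pad)
        attention_mask.append([1] * len(seq) + [0] * n_pad)  # 1 = real, 0 = padding
    return input_ids, attention_mask

ids, mask = pad_batch(
    [[101, 1996, 3185, 2001, 2204, 102], [101, 2023, 2003, 2919, 102]],
    max_len=8,
)
```

Padding each batch only to its own longest sequence, rather than to a fixed length, usually wastes less computation.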

Single-Label Classification

In single-label classification, each input belongs to exactly one class. Sentiment classification with three classes is a typical example:

$$\mathcal{Y} = \{\text{negative}, \text{neutral}, \text{positive}\}.$$

The model produces one logit per class:

$$z = f_\theta(x) \in \mathbb{R}^K.$$

The logits are raw scores. They are converted into probabilities with the softmax function:

$$p(y=k \mid x) = \frac{\exp(z_k)}{\sum_{j=1}^{K}\exp(z_j)}.$$

The predicted class is the class with the highest probability:

$$\hat{y} = \arg\max_k p(y=k \mid x).$$
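Here is a plain-Python sketch of the softmax and argmax formulas above. In PyTorch you would normally call `torch.softmax` and `torch.argmax` instead. Subtracting the largest logit before exponentiating is a standard trick: the shift cancels in the ratio, so the probabilities are unchanged, but large logits no longer overflow.

```python
import math

def softmax(logits):
    # Shift by the max logit for numerical stability; the shift cancels out.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def predict(logits):
    probs = softmax(logits)
    # Index of the highest probability (equivalently, of the highest logit).
    label = max(range(len(probs)), key=probs.__getitem__)
    return label, probs

label, probs = predict([2.1, 0.3, -1.2])
```

Softmax preserves the order of the logits, so the predicted class can be read directly from them. Probabilities are only needed when the confidence itself matters.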

In PyTorch, nn.CrossEntropyLoss combines log-softmax and negative log-likelihood loss in one numerically stable operation. The model should therefore return raw logits, not probabilities.

```python
import torch
import torch.nn as nn

logits = torch.tensor([
    [2.1, 0.3, -1.2],
    [-0.5, 1.7, 0.2],
])

labels = torch.tensor([0, 1])

loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, labels)

print(loss)
```

Here logits has shape [B, K], and labels has shape [B] and holds integer class indices (dtype torch.long).
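Writing the loss out by hand shows what `nn.CrossEntropyLoss` computes with its default mean reduction. For each example, take the negative log-probability of its true class, then average over the batch. This plain-Python sketch uses the log-sum-exp form for stability. It should match the PyTorch value for the same logits up to floating-point error.

```python
import math

def log_softmax(logits):
    # log p_k = z_k - log(sum_j exp(z_j)), using the max-shift trick.
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_norm for z in logits]

def cross_entropy(batch_logits, labels):
    # Mean negative log-likelihood of the true class over the batch.
    losses = [-log_softmax(z)[y] for z, y in zip(batch_logits, labels)]
    return sum(losses) / len(losses)

loss = cross_entropy([[2.1, 0.3, -1.2], [-0.5, 1.7, 0.2]], [0, 1])
print(round(loss, 4))
```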

Multi-Label Classification

In multi-label classification, each input may have several labels at the same time.

For example, a news article may be labeled as both:

["technology", "finance"]

The label vector is binary:

$$y \in \{0,1\}^K.$$

The model again produces logits

$$z \in \mathbb{R}^K.$$

But instead of softmax, each class is treated as an independent binary classification problem. The sigmoid function converts each logit into a probability:

$$p_k = \sigma(z_k) = \frac{1}{1+\exp(-z_k)}.$$

In PyTorch, multi-label classification commonly uses nn.BCEWithLogitsLoss.

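```python
import torch
import torch.nn as nn

logits = torch.tensor([
    [2.0, -1.0, 0.5],
    [-0.2, 1.4, 3.1],
])

# Multi-hot targets: a float tensor with the same shape as logits, [B, K].
labels = torch.tensor([
    [1.0, 0.0, 1.0],
    [0.0, 1.0, 1.0],
])

loss_fn = nn.BCEWithLogitsLoss()
loss = loss_fn(logits, labels)
```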
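`nn.BCEWithLogitsLoss` with its default mean reduction applies a sigmoid and binary cross-entropy to every entry of the `[B, K]` logit matrix, then averages over all entries. The sketch below does the same in plain Python. It uses the overflow-safe form $\max(z,0) - zy + \log(1+e^{-|z|})$, which is algebraically equal to $-[y\log\sigma(z) + (1-y)\log(1-\sigma(z))]$.

```python
import math

def bce_with_logits(batch_logits, batch_targets):
    # Per-entry loss softplus(z) - y*z, written so exp() never overflows.
    total, count = 0.0, 0
    for logits, targets in zip(batch_logits, batch_targets):
        for z, y in zip(logits, targets):
            total += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
            count += 1
    return total / count  # mean over all B * K entries

loss = bce_with_logits(
    [[2.0, -1.0, 0.5], [-0.2, 1.4, 3.1]],
    [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]],
)
```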

For prediction, we apply a threshold:

```python
probs = torch.sigmoid(logits)
preds = probs > 0.5
```

The threshold does not need to be 0.5. In imbalanced classification, the threshold is often tuned on a validation set.
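One simple way to tune the threshold is a per-class grid search on validation data: try several candidate thresholds and keep the one with the highest F1 for that class. The plain-Python sketch below assumes that F1 is the target and uses a grid from 0.05 to 0.95. Both choices are assumptions, and another metric or grid can be substituted. The function names are hypothetical.

```python
def f1_at(probs, labels, threshold):
    # Binary F1 = 2TP / (2TP + FP + FN) when predicting positive above threshold.
    tp = sum(1 for p, y in zip(probs, labels) if p > threshold and y == 1)
    fp = sum(1 for p, y in zip(probs, labels) if p > threshold and y == 0)
    fn = sum(1 for p, y in zip(probs, labels) if p <= threshold and y == 1)
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0

def tune_thresholds(val_probs, val_labels, grid=None):
    # val_probs, val_labels: [N][K] nested lists of sigmoid outputs and 0/1 targets.
    grid = grid or [i / 20 for i in range(1, 20)]  # 0.05, 0.10, ..., 0.95
    thresholds = []
    for k in range(len(val_probs[0])):
        probs_k = [row[k] for row in val_probs]
        labels_k = [row[k] for row in val_labels]
        # max() keeps the first (smallest) threshold among ties.
        thresholds.append(max(grid, key=lambda t: f1_at(probs_k, labels_k, t)))
    return thresholds

thresholds = tune_thresholds(
    [[0.9, 0.3], [0.6, 0.2], [0.4, 0.25], [0.1, 0.1]],
    [[1, 1], [1, 1], [0, 0], [0, 0]],
)
```

Choose the thresholds on validation data, then apply them unchanged to the test set. Tuning on the test set would leak information into the evaluation.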

A Simple Text Classifier

A minimal neural text classifier can be built from an embedding layer, a pooling operation, and a linear classifier.

The model maps token IDs to embeddings:

$$E \in \mathbb{R}^{V \times D},$$

where $V$ is the vocabulary size and $D$ is the embedding dimension.

For an input sequence of length $T$, the embedding layer produces

$$H \in \mathbb{R}^{T \times D}.$$

To classify the whole sequence, we need a fixed-size vector. A simple method is mean pooling:

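$$h = \frac{1}{T}\sum_{t=1}^{T} H_t.$$

Here $T$ counts only real tokens. In a padded batch, both the sum and the count skip positions where the attention mask is 0, as the implementation below does.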

Then a linear classifier computes

$$z = Wh + b,$$

where $W \in \mathbb{R}^{K \times D}$ and $b \in \mathbb{R}^{K}$.

A PyTorch implementation:

```python
import torch
import torch.nn as nn

class MeanPoolTextClassifier(nn.Module):
    def __init__(self, vocab_size: int, embed_dim: int, num_classes: int):
        super().__init__()
        # padding_idx=0: the padding embedding starts at zero and is never updated.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, input_ids, attention_mask):
        # input_ids: [B, T]
        # attention_mask: [B, T]

        x = self.embedding(input_ids)  # [B, T, D]

        # Zero out padding positions so they do not contribute to the sum.
        mask = attention_mask.unsqueeze(-1).to(x.dtype)  # [B, T, 1]
        x = x * mask

        # Divide by the number of real tokens (at least 1 to avoid division by zero).
        lengths = mask.sum(dim=1).clamp(min=1)  # [B, 1]
        pooled = x.sum(dim=1) / lengths         # [B, D]

        logits = self.classifier(pooled)        # [B, K]
        return logits
```

This model is small and easy to train. Mean pooling gives the same result for any ordering of the tokens, so the model ignores word order entirely and treats the input as a bag of embeddings. For many simple classification tasks, this baseline is still useful.

Recurrent and Convolutional Classifiers

Before transformers became dominant, text classifiers commonly used recurrent neural networks and one-dimensional convolutional networks.

An RNN reads tokens sequentially and updates a hidden state:

$$h_t = \phi(W_x x_t + W_h h_{t-1} + b),$$

where $x_t$ is the embedding of token $t$ and $\phi$ is a nonlinearity such as $\tanh$.

The final hidden state can be used as the sequence representation. LSTMs and GRUs improve this model by using gates to control information flow.
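A minimal Elman-style RNN forward pass in plain Python follows the update above, with $\phi = \tanh$ assumed. The weights passed in are placeholders rather than trained values. The point is the recurrence itself, and the use of the final hidden state as the sequence representation.

```python
import math

def matvec(W, v):
    # Matrix-vector product for nested lists: W is [rows][cols], v is [cols].
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def rnn_encode(xs, W_x, W_h, b):
    # xs: T input vectors (e.g. token embeddings), each of length D_in.
    # W_x: [H][D_in], W_h: [H][H], b: [H]. Returns the final hidden state h_T.
    h = [0.0] * len(b)  # h_0 = 0
    for x_t in xs:
        pre = [a + c + bias for a, c, bias in zip(matvec(W_x, x_t), matvec(W_h, h), b)]
        h = [math.tanh(v) for v in pre]
    return h
```

In a classifier, a linear layer would then map $h_T$ to $K$ logits. PyTorch's `nn.RNN`, `nn.LSTM`, and `nn.GRU` provide batched, trainable versions of this loop.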

A convolutional text classifier applies one-dimensional filters over token embeddings. A filter may detect local patterns such as phrases or n-grams. After convolution, max pooling selects the strongest feature.
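The convolution-plus-max-pooling idea can be sketched in plain Python for a single filter. The filter spans `width` consecutive token embeddings, so a width-2 filter looks at bigrams. The ReLU activation and the example weights are illustrative assumptions. A real model would learn many filters, often of several widths, for example with PyTorch's `nn.Conv1d`.

```python
def conv1d_max_pool(embeddings, filt, bias=0.0):
    # embeddings: T vectors of dimension D; filt: `width` vectors of dimension D.
    width = len(filt)
    if len(embeddings) < width:
        raise ValueError("sequence is shorter than the filter width")
    scores = []
    for start in range(len(embeddings) - width + 1):
        window = embeddings[start:start + width]
        s = bias + sum(
            w * x
            for f_row, e_row in zip(filt, window)
            for w, x in zip(f_row, e_row)
        )
        scores.append(max(s, 0.0))  # ReLU
    # Max-over-time pooling keeps the strongest match anywhere in the sequence.
    return max(scores)
```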

RNNs and CNNs remain useful when models must be small, fast, or deployable without a large pretrained transformer. However, transformer encoders now dominate most high-accuracy text classification systems.

Transformer-Based Text Classification

A transformer encoder maps each token into a contextual embedding. The representation of a token depends on the other tokens in the sequence.

For a batch of token IDs,

$$X \in \mathbb{Z}^{B \times T},$$

the transformer produces hidden states

$$H \in \mathbb{R}^{B \times T \times D}.$$

A classifier must convert this sequence of hidden states into one vector per example. Common choices include:

| Method | Description |
|---|---|
| CLS pooling | Use the hidden state of a special classification token |
| Mean pooling | Average token hidden states using the attention mask |
| Max pooling | Take the element-wise maximum across non-padding tokens |
| Attention pooling | Learn weights over token positions |
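These pooling strategies take only a few lines each. This plain-Python sketch works on one example's hidden states `H` (a list of $T$ vectors of size $D$) and its attention mask. It assumes, as in BERT-style inputs, that position 0 holds the `[CLS]` token. In a real model, attention pooling gets its scores from a learned layer, so here the scores are simply passed in.

```python
import math

def cls_pool(H, mask):
    # Hidden state at position 0 ([CLS] in BERT-style inputs); mask unused.
    return list(H[0])

def mean_pool(H, mask):
    real = [h for h, m in zip(H, mask) if m == 1]
    return [sum(col) / len(real) for col in zip(*real)]

def max_pool(H, mask):
    real = [h for h, m in zip(H, mask) if m == 1]
    return [max(col) for col in zip(*real)]

def attention_pool(H, mask, scores):
    # Softmax over the scores of real tokens only; padded positions get no weight.
    real = [(h, s) for h, s, m in zip(H, scores, mask) if m == 1]
    top = max(s for _, s in real)
    weights = [math.exp(s - top) for _, s in real]
    total = sum(weights)
    return [sum(w * h[d] for w, (h, _) in zip(weights, real)) / total
            for d in range(len(H[0]))]
```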

For BERT-like models, CLS pooling is common:

$$h = H_{\text{CLS}}.$$

The classification head is usually a linear layer:

$$z = Wh + b.$$

In practice, the whole model is fine-tuned end to end.

A simplified transformer classifier looks like this:

```python
import torch.nn as nn

class TransformerTextClassifier(nn.Module):
    def __init__(self, encoder, hidden_dim: int, num_classes: int):
        super().__init__()
        self.encoder = encoder  # e.g. a pretrained transformer encoder
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, input_ids, attention_mask):
        # Assumes a Hugging Face-style encoder whose output
        # exposes last_hidden_state with shape [B, T, D].
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

        # For BERT-like models, position 0 holds the [CLS] token:
        cls_state = outputs.last_hidden_state[:, 0, :]  # [B, D]

        logits = self.classifier(cls_state)             # [B, K]
        return logits
```

The encoder may be a pretrained model such as BERT, RoBERTa, DeBERTa, DistilBERT, or another transformer encoder.

Fine-Tuning a Pretrained Model

Fine-tuning starts from a pretrained language model and adapts it to a labeled classification dataset.

The usual workflow is:

  1. Load a pretrained tokenizer.
  2. Tokenize the dataset.
  3. Load a pretrained encoder with a classification head.
  4. Train using labeled examples.
  5. Evaluate on held-out validation and test sets.

Fine-tuning works well because the pretrained model already contains general linguistic and semantic representations. The classification dataset only needs to teach the model how those representations map to the task labels.

A typical training objective for single-label classification is

$$\mathcal{L} = -\frac{1}{B} \sum_{i=1}^{B} \log p_\theta(y_i \mid x_i).$$

This is the average cross-entropy loss over a batch.

Fine-tuning has several important hyperparameters:

| Hyperparameter | Typical concern |
|---|---|
| Learning rate | Often much smaller than training from scratch |
| Batch size | Limited by sequence length and GPU memory |
| Sequence length | Controls context and memory cost |
| Weight decay | Helps regularize the classification head and encoder |
| Warmup steps | Stabilizes early training |
| Number of epochs | Often small, commonly 2 to 5 for many datasets |

A common mistake is using a learning rate that is too large. Transformer fine-tuning is sensitive to this. Values such as $2\times10^{-5}$, $3\times10^{-5}$, and $5\times10^{-5}$ are common starting points for BERT-like models.
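Warmup steps are often paired with a decay. One common choice, assumed here rather than prescribed by the text, is linear warmup to the peak learning rate followed by linear decay to zero. Here is that schedule as a plain-Python function of the step number. The function name is hypothetical.

```python
def linear_warmup_decay(step, peak_lr, warmup_steps, total_steps):
    # Ramp linearly from 0 to peak_lr over warmup_steps, then decay linearly to 0.
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    remaining = max(0, total_steps - step)
    return peak_lr * remaining / max(1, total_steps - warmup_steps)

schedule = [linear_warmup_decay(s, 2e-5, 100, 1000) for s in (0, 50, 100, 550, 1000)]
```

In PyTorch, a variant that returns a multiplier instead of a learning rate (for example, with `peak_lr=1.0`) can be passed to `torch.optim.lr_scheduler.LambdaLR`.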

Class Imbalance

Many real text classification datasets are imbalanced. Spam may be rarer than normal email. Fraud reports may be rare. Toxic content may be a small fraction of all comments.

If class imbalance is ignored, the model may learn to predict the majority class too often.

Suppose 95 percent of examples are non-spam and 5 percent are spam. A classifier that always predicts non-spam has 95 percent accuracy, but it is useless for detecting spam.

Better metrics include precision, recall, F1 score, ROC-AUC, and PR-AUC.

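Precision is the fraction of predicted positives that are actually positive:

$$\text{precision} = \frac{TP}{TP+FP},$$

where $TP$, $FP$, and $FN$ are the counts of true positives, false positives, and false negatives.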

Recall is the fraction of actual positives that the model finds:

$$\text{recall} = \frac{TP}{TP+FN}.$$

F1 is the harmonic mean of precision and recall:

$$F_1 = \frac{2\cdot \text{precision}\cdot \text{recall}}{\text{precision}+\text{recall}}.$$
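These formulas translate directly into code. The plain-Python sketch below handles binary labels and treats 1 as the positive class. It returns 0 when a denominator is zero, which is one common convention; libraries differ in how they handle that case. The usage line applies it to the always-"non-spam" classifier from the example above.

```python
def precision_recall_f1(preds, labels, positive=1):
    tp = sum(1 for p, y in zip(preds, labels) if p == positive and y == positive)
    fp = sum(1 for p, y in zip(preds, labels) if p == positive and y != positive)
    fn = sum(1 for p, y in zip(preds, labels) if p != positive and y == positive)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall else 0.0)
    return precision, recall, f1

# 95 non-spam (0) and 5 spam (1) examples; the model always predicts non-spam.
labels = [0] * 95 + [1] * 5
preds = [0] * 100
print(precision_recall_f1(preds, labels))
```

This classifier has 95 percent accuracy but finds no spam, so its recall and F1 for the spam class are both zero.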

Common methods for imbalance include class-weighted loss, resampling, threshold tuning, and collecting more minority-class data.

In PyTorch, single-label class weighting can be implemented with CrossEntropyLoss:

```python
class_weights = torch.tensor([0.2, 1.8])  # example weights
loss_fn = nn.CrossEntropyLoss(weight=class_weights)
```

For multi-label classification, BCEWithLogitsLoss supports positive-class weights:

```python
pos_weight = torch.tensor([1.0, 3.0, 5.0])
loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
```
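The weights above are placeholders. Two common heuristics derive them from label counts instead; both are sketched below in plain Python. For single-label data, "balanced" class weights are $N / (K \cdot N_k)$, where $N_k$ is the number of examples in class $k$. For multi-label data, each class's `pos_weight` is set to its ratio of negative to positive examples, the rule suggested in the PyTorch documentation for `BCEWithLogitsLoss`. Both are starting points to tune rather than fixed rules, and the function names are hypothetical.

```python
from collections import Counter

def balanced_class_weights(labels, num_classes):
    # w_k = N / (K * N_k): rarer classes get larger weights (0.0 if a class is absent).
    counts = Counter(labels)
    n = len(labels)
    return [n / (num_classes * counts[k]) if counts[k] else 0.0
            for k in range(num_classes)]

def multilabel_pos_weight(label_rows):
    # label_rows: [N][K] lists of 0/1. pos_weight_k = negatives_k / positives_k.
    n = len(label_rows)
    positives = [sum(col) for col in zip(*label_rows)]
    # Fall back to a neutral weight of 1.0 for classes with no positives.
    return [(n - p) / p if p else 1.0 for p in positives]

weights = balanced_class_weights([0] * 90 + [1] * 10, num_classes=2)
```

The resulting lists can be wrapped in `torch.tensor(...)` and passed as `weight` or `pos_weight`.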

Evaluation Metrics

Accuracy is useful when classes are balanced and errors have similar cost. For many classification systems, accuracy alone is insufficient.

| Metric | Use case |
|---|---|
| Accuracy | Balanced single-label classification |
| Precision | False positives are costly |
| Recall | False negatives are costly |
| F1 score | Need balance between precision and recall |
| Macro F1 | Classes are imbalanced |
| Micro F1 | Aggregate performance across all examples |
| ROC-AUC | Ranking quality for binary classification |
| PR-AUC | Rare positive class problems |
| Calibration error | Probability quality matters |

For single-label classification, a confusion matrix is often useful. It shows which classes are confused with each other.
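For $K$ classes, a confusion matrix is a $K \times K$ table of counts. The plain-Python sketch below puts true classes in rows and predicted classes in columns. That orientation is an assumed convention; some tools transpose it.

```python
def confusion_matrix(labels, preds, num_classes):
    # cm[i][j] counts examples whose true class is i and predicted class is j.
    cm = [[0] * num_classes for _ in range(num_classes)]
    for y, p in zip(labels, preds):
        cm[y][p] += 1
    return cm

cm = confusion_matrix([0, 0, 1, 2, 2, 2], [0, 1, 1, 2, 0, 2], num_classes=3)
for row in cm:
    print(row)
```

Correct predictions sit on the diagonal. Large off-diagonal entries show which pairs of classes the model confuses.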

For multi-label classification, metrics can be computed per class, then averaged. Macro averaging gives each class equal weight. Micro averaging pools decisions across all classes.
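The difference between the two averages is easiest to see from per-class counts. In the plain-Python sketch below, each class contributes a `(TP, FP, FN)` triple. Macro F1 averages the per-class F1 scores, while micro F1 sums the counts first and then computes a single F1. The example counts are made up for illustration.

```python
def f1_from_counts(tp, fp, fn):
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0

def macro_micro_f1(per_class_counts):
    # per_class_counts: list of (tp, fp, fn), one triple per class.
    macro = sum(f1_from_counts(*c) for c in per_class_counts) / len(per_class_counts)
    tp, fp, fn = (sum(col) for col in zip(*per_class_counts))
    micro = f1_from_counts(tp, fp, fn)
    return macro, micro

# A frequent class the model handles well and a rare class it always misses.
macro, micro = macro_micro_f1([(90, 5, 5), (0, 0, 5)])
print(macro, micro)
```

Micro F1 stays high because the frequent class dominates the pooled counts. Macro F1 exposes the failure on the rare class.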

Calibration

A classifier is calibrated when its predicted probabilities match empirical frequencies.

If a model assigns probability 0.8 to many examples, and about 80 percent of those predictions are correct, the model is well calibrated at that confidence level.
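A common way to summarize this across all confidence levels is the expected calibration error (ECE), one form of the calibration error listed in the metrics table. Group predictions into confidence bins, compare each bin's average confidence with its accuracy, and take the count-weighted average of the gaps. The plain-Python sketch below is one common variant, with 10 equal-width bins as an assumed default.

```python
def expected_calibration_error(confidences, correct, num_bins=10):
    # confidences: top predicted probability per example, in [0, 1].
    # correct: 1 if the predicted class was right, else 0.
    n = len(confidences)
    bins = [[] for _ in range(num_bins)]
    for c, ok in zip(confidences, correct):
        idx = min(int(c * num_bins), num_bins - 1)  # c == 1.0 goes to the last bin
        bins[idx].append((c, ok))
    ece = 0.0
    for b in bins:
        if b:
            avg_conf = sum(c for c, _ in b) / len(b)
            accuracy = sum(ok for _, ok in b) / len(b)
            ece += len(b) / n * abs(avg_conf - accuracy)
    return ece

# Ten predictions at 0.8 confidence, eight of them correct: well calibrated.
ece_good = expected_calibration_error([0.8] * 10, [1] * 8 + [0] * 2)
```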

Calibration matters in systems where probabilities affect decisions. Medical triage, fraud detection, moderation queues, and legal search systems often need calibrated confidence scores.

Deep neural networks can be overconfident: they may assign high probabilities even when they are wrong. Common calibration methods include temperature scaling, Platt scaling, and isotonic regression. Validation-based threshold tuning is a related adjustment that changes decisions rather than the probabilities themselves.

Temperature scaling divides the logits by a scalar temperature $T > 0$ (unrelated to the sequence length $T$ used earlier):

$$p(y=k \mid x) = \frac{\exp(z_k/T)}{\sum_j \exp(z_j/T)}.$$

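A temperature above 1 produces softer probabilities, and a temperature below 1 makes them sharper. The temperature is typically selected by minimizing negative log-likelihood on a validation set. Dividing all logits by the same positive constant preserves their order, so temperature scaling changes confidence without changing the predicted class.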
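A minimal way to select $T$ is a grid search that minimizes the average negative log-likelihood of the validation labels. The plain-Python sketch below assumes the validation logits have already been computed. Gradient-based optimization of $T$ is a common alternative. The grid includes $T = 1$, so the chosen temperature never has a worse validation NLL than leaving the model unscaled. The function names are hypothetical.

```python
import math

def nll_at_temperature(val_logits, val_labels, T):
    total = 0.0
    for z, y in zip(val_logits, val_labels):
        scaled = [v / T for v in z]
        m = max(scaled)
        log_norm = m + math.log(sum(math.exp(v - m) for v in scaled))
        total += log_norm - scaled[y]  # -log softmax(z / T)[y]
    return total / len(val_labels)

def fit_temperature(val_logits, val_labels, grid=None):
    grid = grid or [i / 20 for i in range(10, 101)]  # 0.5, 0.55, ..., 5.0
    return min(grid, key=lambda T: nll_at_temperature(val_logits, val_labels, T))

# An overconfident model: large logit gaps, but one of four predictions is wrong.
val_logits = [[10.0, 0.0]] * 4
val_labels = [0, 1, 0, 0]
T_star = fit_temperature(val_logits, val_labels)
```

For these made-up overconfident logits, the search selects a temperature above 1, which softens the probabilities.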

Data Leakage

Text classification systems are vulnerable to data leakage. Leakage occurs when information from the validation or test set appears in training.

Examples include:

| Leakage source | Example |
|---|---|
| Duplicate text | Same review appears in train and test |
| Near duplicates | Slightly edited copies across splits |
| User leakage | Same user appears in train and test |
| Time leakage | Future examples used to predict past behavior |
| Label artifacts | Label names appear directly in input text |
| Preprocessing leakage | Vocabulary or normalization fitted on all data |

Time-based tasks require time-based splits. For example, an email classifier intended for future email should be tested on emails later than the training emails.
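A time-based split and a basic duplicate check each take only a few lines of plain Python. The record layout, dicts with hypothetical `"text"` and `"timestamp"` keys, is an assumption. The normalization (lowercasing and collapsing whitespace) only catches copies that differ in those trivial ways. Near duplicates need fuzzier matching.

```python
import re

def time_split(records, cutoff):
    # Train on everything before the cutoff; evaluate on everything at or after it.
    train = [r for r in records if r["timestamp"] < cutoff]
    test = [r for r in records if r["timestamp"] >= cutoff]
    return train, test

def normalize(text):
    return re.sub(r"\s+", " ", text.lower()).strip()

def cross_split_duplicates(train, test):
    # Test texts whose normalized form also appears in the training split.
    seen = {normalize(r["text"]) for r in train}
    return [r["text"] for r in test if normalize(r["text"]) in seen]

records = [
    {"text": "Great product!", "timestamp": 1},
    {"text": "Terrible.", "timestamp": 2},
    {"text": "great  PRODUCT!", "timestamp": 3},
    {"text": "Okay.", "timestamp": 4},
]
train, test = time_split(records, cutoff=3)
print(cross_split_duplicates(train, test))
```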

Data leakage can make a model appear much better than it really is.

Error Analysis

After training, the most useful diagnostic step is error analysis. This means reading model mistakes and grouping them by cause.

Common error categories include:

| Error type | Example |
|---|---|
| Ambiguous text | “This was sick” may mean good or bad |
| Missing context | A reply depends on previous messages |
| Label noise | Human annotation is wrong or inconsistent |
| Rare terminology | Domain-specific terms absent from training |
| Negation failure | “Not good” classified as positive |
| Sarcasm | Literal meaning differs from intended meaning |
| Long document truncation | Important evidence appears after max length |
| Distribution shift | Test data differs from training data |

Error analysis often suggests the next improvement: better labels, better split design, longer context, domain-specific pretraining, class rebalancing, or a different metric.

Long Documents

Many classifiers assume that the full input fits within a fixed maximum sequence length. Transformers have a maximum context length, and the cost of standard self-attention grows quadratically with sequence length.

For long documents, common strategies include:

| Strategy | Description |
|---|---|
| Truncation | Keep the first $T$ tokens |
| Head-tail truncation | Keep beginning and ending tokens |
| Sliding windows | Classify overlapping chunks |
| Hierarchical encoding | Encode chunks, then aggregate chunk representations |
| Retrieval-based classification | Select relevant passages before classification |
| Long-context transformers | Use models designed for longer sequences |
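Head-tail truncation and sliding windows operate purely on token ID lists. In the plain-Python helpers below, `head_tail` keeps a fixed number of tokens from each end, and `sliding_windows` produces overlapping chunks with a given stride. Both names are hypothetical and the window sizes are placeholders. Special tokens such as `[CLS]` and `[SEP]` would still need to be added to each chunk. The chunk-level predictions must then be aggregated, for example by averaging probabilities or taking the maximum.

```python
def head_tail(ids, max_len, head=None):
    # Keep `head` tokens from the start and the remaining budget from the end.
    if len(ids) <= max_len:
        return list(ids)
    head = max_len // 2 if head is None else head
    tail = max_len - head
    return list(ids[:head]) + list(ids[len(ids) - tail:])

def sliding_windows(ids, window, stride):
    # Overlapping chunks; the last chunk is aligned to the end so no tokens are dropped.
    if len(ids) <= window:
        return [list(ids)]
    starts = list(range(0, len(ids) - window + 1, stride))
    if starts[-1] + window < len(ids):
        starts.append(len(ids) - window)
    return [list(ids[s:s + window]) for s in starts]
```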

Naive truncation may fail when the important evidence appears late in the document. Legal, medical, and scientific documents often require hierarchical or retrieval-based methods.

Practical PyTorch Training Loop

A simple supervised training loop has four steps: forward pass, loss computation, backward pass, and optimizer update.

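```python
def train_one_epoch(model, dataloader, optimizer, loss_fn, device):
    model.train()  # enable dropout and other training-mode behavior

    total_loss = 0.0

    for batch in dataloader:
        # Move the batch to the same device as the model.
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        # Clear gradients left over from the previous step.
        optimizer.zero_grad(set_to_none=True)

        logits = model(input_ids, attention_mask)  # forward pass: [B, K]
        loss = loss_fn(logits, labels)

        loss.backward()   # backward pass: compute gradients
        optimizer.step()  # update parameters

        total_loss += loss.item()

    # Average of the per-batch mean losses.
    return total_loss / len(dataloader)
```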

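Evaluation switches the model to eval mode and disables gradient tracking. This version computes accuracy for single-label classification: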

```python
import torch

@torch.no_grad()  # gradients are not needed for evaluation
def evaluate(model, dataloader, loss_fn, device):
    model.eval()  # disable dropout

    total_loss = 0.0
    all_preds = []
    all_labels = []

    for batch in dataloader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        logits = model(input_ids, attention_mask)
        loss = loss_fn(logits, labels)

        preds = logits.argmax(dim=-1)  # single-label: highest-scoring class

        total_loss += loss.item()
        all_preds.append(preds.cpu())
        all_labels.append(labels.cpu())

    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)

    accuracy = (all_preds == all_labels).float().mean().item()

    return {
        "loss": total_loss / len(dataloader),
        "accuracy": accuracy,
    }
```

This loop is intentionally minimal. Production training often adds gradient clipping, mixed precision, distributed training, checkpointing, learning rate schedules, metric logging, and early stopping.

Common Failure Modes

Text classifiers fail for predictable reasons.

The model may overfit when the dataset is small. It may rely on spurious correlations, such as author name, formatting, source website, or repeated boilerplate. It may perform poorly on minority dialects, rare topics, or new events. It may be overconfident on inputs far from the training distribution.

Common implementation mistakes include using the wrong loss function, applying softmax before CrossEntropyLoss, forgetting attention masks, padding labels incorrectly, mixing up label IDs, evaluating on splits that share duplicated examples, or comparing probabilities against class IDs.

A reliable text classification system depends on both modeling and data discipline. Most failures come from the dataset, split, labels, or evaluation protocol rather than from the classifier architecture itself.

Summary

Text classification maps text inputs to labels. The basic pipeline converts text into token IDs, encodes the sequence, pools the representation, and applies a classification head.

Single-label classification uses softmax-style cross-entropy. Multi-label classification uses sigmoid-style binary cross-entropy. Transformer encoders provide strong default models because they produce contextual representations. Fine-tuning adapts these representations to a specific label space.

Good classification systems require careful metrics, clean data splits, class imbalance handling, calibration, and error analysis. In PyTorch, the implementation is straightforward, but the quality of the result depends on matching the model, objective, and evaluation method to the real decision problem.