Text Classification

Text classification assigns one or more labels to a piece of text. The input may be a sentence, paragraph, document, review, message, query, or conversation turn. The output is a class label, a probability distribution over labels, or a set of active labels.

Common examples include sentiment analysis, spam detection, topic classification, intent detection, toxicity detection, language identification, document routing, and product category prediction.

A text classifier has the same basic structure as other neural classifiers:

$$\text{text} \rightarrow \text{token IDs} \rightarrow \text{embeddings} \rightarrow \text{encoder} \rightarrow \text{pooled representation} \rightarrow \text{classifier} \rightarrow \text{logits}.$$

The tokenizer converts text into token IDs. The embedding layer converts token IDs into vectors. The encoder computes contextual representations. The pooling step converts a variable-length sequence into a fixed-size vector. The classifier maps that vector to class scores.

Single-Label Classification

In single-label classification, each input belongs to exactly one class.

For example, a support message may belong to one intent:

| Text | Label |
|---|---|
| I forgot my password | account_access |
| Where is my order? | shipping_status |
| Please cancel my subscription | billing |

If there are $K$ classes, the model returns $K$ logits:

$$z \in \mathbb{R}^K.$$

The predicted class is the index with the largest logit:

$$\hat{y} = \arg\max_k z_k.$$

During training, the usual loss is cross-entropy loss. In PyTorch, class labels are integer IDs with shape [batch_size].

import torch
import torch.nn as nn

logits = torch.randn(32, 5)        # [batch_size, num_classes]
labels = torch.randint(0, 5, (32,)) # [batch_size]

loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, labels)

print(loss.shape)  # torch.Size([])

nn.CrossEntropyLoss expects raw logits, not probabilities. It internally applies log-softmax and negative log-likelihood.
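A quick check of this equivalence:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.randn(4, 3)
labels = torch.tensor([0, 2, 1, 2])

# CrossEntropyLoss on raw logits ...
loss_a = nn.CrossEntropyLoss()(logits, labels)

# ... equals log-softmax followed by negative log-likelihood.
loss_b = F.nll_loss(F.log_softmax(logits, dim=1), labels)

print(torch.allclose(loss_a, loss_b))  # True
```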

Multi-Label Classification

In multi-label classification, an input may have several labels at once.

For example, a news article may be labeled with several topics:

| Text | Labels |
|---|---|
| New battery plant opens in Germany | technology, energy, business |
| Central bank raises interest rates | finance, policy |

The model still returns $K$ logits, but each class is treated as an independent binary decision:

$$z \in \mathbb{R}^K.$$

The targets are binary vectors:

$$y \in \{0,1\}^K.$$

In PyTorch, the usual loss is binary cross-entropy with logits:

logits = torch.randn(32, 6)                 # [batch_size, num_labels]
targets = torch.randint(0, 2, (32, 6)).float()

loss_fn = nn.BCEWithLogitsLoss()
loss = loss_fn(logits, targets)

At inference time, logits are passed through a sigmoid function:

probs = torch.sigmoid(logits)
preds = probs > 0.5

The threshold does not have to be 0.5. For imbalanced datasets, each label may need its own threshold.
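One way to pick per-label thresholds is a grid search that maximizes validation F1 for each label separately. A minimal sketch (the candidate grid of 19 values is an arbitrary choice):

```python
import torch

def tune_thresholds(val_logits, val_targets):
    # val_logits: [N, K], val_targets: [N, K] binary
    # For each label, pick the threshold with the best validation F1.
    probs = torch.sigmoid(val_logits)
    targets = val_targets.bool()
    candidates = torch.linspace(0.05, 0.95, 19)
    best = torch.full((probs.size(1),), 0.5)
    for k in range(probs.size(1)):
        best_f1 = -1.0
        for t in candidates:
            preds = probs[:, k] >= t
            tp = (preds & targets[:, k]).sum().item()
            fp = (preds & ~targets[:, k]).sum().item()
            fn = (~preds & targets[:, k]).sum().item()
            f1 = 2 * tp / max(2 * tp + fp + fn, 1)
            if f1 > best_f1:
                best_f1, best[k] = f1, t
    return best
```

The returned tensor can then replace the scalar `0.5` in the prediction step.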

A Bag-of-Embeddings Classifier

The simplest neural text classifier embeds each token, averages the embeddings, and applies a linear classifier.

This model ignores word order, but it is fast and often strong on small classification tasks.

import torch
import torch.nn as nn

class BagOfEmbeddingsClassifier(nn.Module):
    def __init__(self, vocab_size, embedding_dim, num_classes, padding_idx=0):
        super().__init__()
        self.padding_idx = padding_idx
        self.embedding = nn.Embedding(
            vocab_size,
            embedding_dim,
            padding_idx=padding_idx,
        )
        self.classifier = nn.Linear(embedding_dim, num_classes)

    def forward(self, input_ids):
        # input_ids: [B, T]
        x = self.embedding(input_ids)
        # x: [B, T, D]

        mask = input_ids != self.padding_idx
        # mask: [B, T]

        mask = mask.unsqueeze(-1)
        # mask: [B, T, 1]

        x = x * mask

        lengths = mask.sum(dim=1).clamp(min=1)
        # lengths: [B, 1]

        pooled = x.sum(dim=1) / lengths
        # pooled: [B, D]

        logits = self.classifier(pooled)
        # logits: [B, K]

        return logits

The important shape transformation is:

[B, T] -> [B, T, D] -> [B, D] -> [B, K]

The model begins with token IDs, computes one vector per token, averages across the sequence axis, and produces one vector of logits per example.
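As a point of comparison, PyTorch's built-in nn.EmbeddingBag fuses the lookup and mean-pooling steps into a single op. A small sketch with arbitrary sizes; note that with mode="mean" it averages over exactly the IDs it receives, so padding is avoided by passing packed token IDs with per-example offsets rather than a padded batch:

```python
import torch
import torch.nn as nn

# nn.EmbeddingBag with mode="mean" looks up and averages in one step.
embedding = nn.EmbeddingBag(num_embeddings=100, embedding_dim=16, mode="mean")
classifier = nn.Linear(16, 4)

# Two examples packed into one flat ID tensor with offsets:
# example 0 is ids[0:3], example 1 is ids[3:5].
input_ids = torch.tensor([3, 7, 9, 2, 5])
offsets = torch.tensor([0, 3])

pooled = embedding(input_ids, offsets)  # [2, 16]
logits = classifier(pooled)             # [2, 4]
```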

Pooling Sequence Representations

A classifier needs a fixed-size representation of the input text. Sequence encoders produce one vector per token, so the model must pool the sequence.

Common pooling methods include mean pooling, max pooling, first-token pooling, last-token pooling, and attention pooling.

| Pooling method | Definition | Common use |
|---|---|---|
| Mean pooling | Average token vectors | Sentence embeddings, simple classifiers |
| Max pooling | Take maximum over time | CNN and RNN classifiers |
| First-token pooling | Use first hidden state | BERT-style [CLS] classifiers |
| Last-token pooling | Use final hidden state | GPT-style classifiers |
| Attention pooling | Learn weighted average | Document classification |

Suppose the encoder output is

$$H \in \mathbb{R}^{B \times T \times D}.$$

Mean pooling produces

$$h \in \mathbb{R}^{B \times D}.$$

With an attention mask, padding positions should be excluded:

def mean_pool(hidden_states, attention_mask):
    # hidden_states: [B, T, D]
    # attention_mask: [B, T], 1 for real tokens, 0 for padding

    mask = attention_mask.unsqueeze(-1)
    hidden_states = hidden_states * mask

    lengths = mask.sum(dim=1).clamp(min=1)
    pooled = hidden_states.sum(dim=1) / lengths

    return pooled

Pooling choice matters. Mean pooling works well when all tokens contribute to the label. First-token pooling works well when the model has been trained to place sentence-level information in a special token. Last-token pooling is natural for decoder-only language models because the final position has attended to all previous tokens.
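Max pooling needs the same care with padding as mean pooling; one option is to fill padded positions with a very negative value before taking the maximum, so they can never win:

```python
import torch

def max_pool(hidden_states, attention_mask):
    # hidden_states: [B, T, D]
    # attention_mask: [B, T], 1 for real tokens, 0 for padding
    mask = attention_mask.unsqueeze(-1).bool()
    masked = hidden_states.masked_fill(~mask, float("-inf"))
    return masked.max(dim=1).values  # [B, D]
```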

CNN Text Classifiers

A convolutional text classifier applies one-dimensional convolutions over token embeddings.

If the input embedding tensor has shape

[B, T, D]

a 1D convolution in PyTorch usually expects

[B, D, T]

so we transpose the tensor before applying nn.Conv1d.

class CNNTextClassifier(nn.Module):
    def __init__(self, vocab_size, embedding_dim, num_classes, padding_idx=0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)

        self.conv3 = nn.Conv1d(embedding_dim, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv1d(embedding_dim, 128, kernel_size=5, padding=2)

        self.activation = nn.ReLU()
        self.classifier = nn.Linear(256, num_classes)

    def forward(self, input_ids):
        # input_ids: [B, T]
        x = self.embedding(input_ids)
        # x: [B, T, D]

        x = x.transpose(1, 2)
        # x: [B, D, T]

        h3 = self.activation(self.conv3(x))
        h5 = self.activation(self.conv5(x))
        # h3, h5: [B, 128, T]

        h3 = h3.max(dim=2).values
        h5 = h5.max(dim=2).values
        # h3, h5: [B, 128]

        h = torch.cat([h3, h5], dim=1)
        # h: [B, 256]

        logits = self.classifier(h)
        # logits: [B, K]

        return logits

CNN classifiers are useful when local phrases strongly determine the label. For example, short phrases such as free money, reset password, or cancel subscription may be predictive.

RNN Text Classifiers

An RNN classifier processes the token sequence from left to right. An LSTM or GRU can build a hidden state that summarizes the sequence.

class LSTMTextClassifier(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_classes, padding_idx=0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)

        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            batch_first=True,
            bidirectional=True,
        )

        self.classifier = nn.Linear(2 * hidden_dim, num_classes)

    def forward(self, input_ids):
        # input_ids: [B, T]
        x = self.embedding(input_ids)
        # x: [B, T, D]

        output, (h_n, c_n) = self.lstm(x)
        # h_n: [2, B, H] for one bidirectional layer

        forward_last = h_n[0]
        backward_last = h_n[1]

        h = torch.cat([forward_last, backward_last], dim=1)
        # h: [B, 2H]

        logits = self.classifier(h)
        # logits: [B, K]

        return logits

RNN classifiers are less dominant than transformers in modern NLP, but they remain useful when models must be small, fast, or deployed under strict memory limits.

Transformer Text Classifiers

A transformer classifier uses self-attention to compute contextual token representations. In practice, many classifiers fine-tune a pretrained transformer.

A BERT-style classifier commonly prepends a special [CLS] token. The final hidden state at that position is used as the sentence representation.

The shape flow is:

input_ids:      [B, T]
attention_mask: [B, T]
hidden_states:  [B, T, D]
cls_state:      [B, D]
logits:         [B, K]

A simplified classifier head looks like this:

class TransformerClassifierHead(nn.Module):
    def __init__(self, hidden_dim, num_classes):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, hidden_states):
        # hidden_states: [B, T, D]
        cls_state = hidden_states[:, 0, :]
        # cls_state: [B, D]

        h = self.norm(cls_state)
        h = self.dropout(h)

        logits = self.classifier(h)
        # logits: [B, K]

        return logits

Decoder-only models, such as GPT-style models, often use the hidden state of the final non-padding token instead of the first token.

def last_token_pool(hidden_states, attention_mask):
    # hidden_states: [B, T, D]
    # attention_mask: [B, T]

    lengths = attention_mask.sum(dim=1) - 1
    # lengths: [B]

    batch_idx = torch.arange(hidden_states.size(0), device=hidden_states.device)
    pooled = hidden_states[batch_idx, lengths]

    return pooled

The classifier head can then map pooled to logits.
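Putting the two pieces together, with made-up sizes:

```python
import torch
import torch.nn as nn

B, T, D, K = 2, 5, 8, 3
hidden_states = torch.randn(B, T, D)
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])

# Index of the last real (non-padding) token per example.
lengths = attention_mask.sum(dim=1) - 1
batch_idx = torch.arange(B)
pooled = hidden_states[batch_idx, lengths]
# pooled: [B, D]

head = nn.Linear(D, K)
logits = head(pooled)
# logits: [B, K]
```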

Training Loop

A standard PyTorch training loop for text classification looks like this:

def train_one_epoch(model, dataloader, optimizer, loss_fn, device):
    model.train()

    total_loss = 0.0
    total_examples = 0

    for batch in dataloader:
        input_ids = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)

        logits = model(input_ids)
        loss = loss_fn(logits, labels)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        batch_size = input_ids.size(0)
        total_loss += loss.item() * batch_size
        total_examples += batch_size

    return total_loss / total_examples

For models that use attention masks:

logits = model(
    input_ids=batch["input_ids"].to(device),
    attention_mask=batch["attention_mask"].to(device),
)

The loss function depends on the task:

| Task type | Output shape | Target shape | Loss |
|---|---|---|---|
| Single-label | [B, K] | [B] | CrossEntropyLoss |
| Binary | [B] or [B, 1] | [B] or [B, 1] | BCEWithLogitsLoss |
| Multi-label | [B, K] | [B, K] | BCEWithLogitsLoss |
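A matching evaluation loop disables gradient tracking and switches the model to eval mode. A sketch for the single-label case, assuming batches shaped like those in the training loop:

```python
import torch

@torch.no_grad()
def evaluate(model, dataloader, device):
    # Returns accuracy over the validation loader.
    model.eval()
    correct, total = 0, 0
    for batch in dataloader:
        input_ids = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)
        logits = model(input_ids)
        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    return correct / total
```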

Evaluation Metrics

Accuracy is appropriate when classes are balanced and each mistake has similar cost. Many text classification problems violate these conditions. Spam detection, fraud detection, toxicity detection, and rare intent detection often have imbalanced labels.

Common metrics include precision, recall, F1 score, ROC AUC, PR AUC, and confusion matrices.

| Metric | Meaning | Useful when |
|---|---|---|
| Accuracy | Fraction of correct predictions | Classes are balanced |
| Precision | Fraction of predicted positives that are correct | False positives are costly |
| Recall | Fraction of actual positives found | False negatives are costly |
| F1 | Harmonic mean of precision and recall | Need balance between precision and recall |
| ROC AUC | Ranking quality across thresholds | Binary classification |
| PR AUC | Precision-recall tradeoff | Rare positive class |

For single-label classification:

def accuracy(logits, labels):
    preds = logits.argmax(dim=1)
    return (preds == labels).float().mean()

For multi-label classification:

def multilabel_predictions(logits, threshold=0.5):
    probs = torch.sigmoid(logits)
    return probs >= threshold

Choosing the threshold is part of model design. A threshold that maximizes validation F1 may differ from a threshold that minimizes business cost.

Class Imbalance

Text datasets often contain class imbalance. Some labels occur frequently, while others are rare.

There are several standard responses.

One approach is weighted loss. In single-label classification, CrossEntropyLoss accepts class weights:

class_counts = torch.tensor([1000, 200, 50], dtype=torch.float)
class_weights = 1.0 / class_counts
class_weights = class_weights / class_weights.sum() * len(class_counts)

loss_fn = nn.CrossEntropyLoss(weight=class_weights)

For multi-label classification, BCEWithLogitsLoss accepts pos_weight:

positive_counts = torch.tensor([1000, 200, 50], dtype=torch.float)
negative_counts = torch.tensor([9000, 9800, 9950], dtype=torch.float)

pos_weight = negative_counts / positive_counts

loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

Another approach is sampling. A training loader can oversample rare classes or undersample common classes. This changes the training distribution, so validation should still be performed on the natural distribution.
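Oversampling can be sketched with PyTorch's WeightedRandomSampler, weighting each example by the inverse frequency of its class (the label tensor here is illustrative):

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Hypothetical labels: class 0 is common, classes 1 and 2 are rare.
labels = torch.tensor([0, 0, 0, 0, 1, 2])

# Weight each example by the inverse frequency of its class.
class_counts = torch.bincount(labels).float()
example_weights = 1.0 / class_counts[labels]

sampler = WeightedRandomSampler(
    weights=example_weights,
    num_samples=len(labels),
    replacement=True,
)
# Pass `sampler=sampler` to the DataLoader instead of `shuffle=True`.
```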

A third approach is threshold tuning. For rare labels, a default threshold of 0.5 may be too conservative.

Calibration

A classifier produces scores. In many applications, we want these scores to behave like probabilities.

A model is calibrated when examples assigned probability 0.8 are correct about 80 percent of the time.

Neural classifiers are often miscalibrated. They may be overconfident, especially after extensive fine-tuning. Calibration can be improved with temperature scaling.

For logits $z$, temperature scaling computes

$$p = \operatorname{softmax}(z / T),$$

where $T > 0$ is learned on a validation set. Larger $T$ makes the distribution softer. Smaller $T$ makes it sharper.

Calibration matters when probabilities are used for ranking, triage, deferral, or risk estimation.
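One way to fit the temperature is to minimize validation NLL over a single scalar. A sketch using LBFGS; parameterizing through exp keeps $T$ positive:

```python
import torch
import torch.nn as nn

def fit_temperature(val_logits, val_labels, max_iter=50):
    # Learn a scalar T > 0 that minimizes NLL on validation data.
    log_t = torch.zeros(1, requires_grad=True)  # T = exp(log_t), starts at 1
    optimizer = torch.optim.LBFGS([log_t], lr=0.1, max_iter=max_iter)
    loss_fn = nn.CrossEntropyLoss()

    def closure():
        optimizer.zero_grad()
        loss = loss_fn(val_logits / log_t.exp(), val_labels)
        loss.backward()
        return loss

    optimizer.step(closure)
    return log_t.exp().item()
```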

Long Documents

Long document classification creates additional problems. Many transformer models have a maximum context length. A document may exceed that length.

Common strategies include truncation, sliding windows, hierarchical models, retrieval, and sparse attention.

| Strategy | Method | Tradeoff |
|---|---|---|
| Truncation | Keep first or last tokens | Simple but may discard evidence |
| Sliding window | Classify chunks and aggregate | More compute |
| Hierarchical model | Encode chunks, then encode chunk vectors | More complex |
| Retrieval | Select relevant passages | Depends on retrieval quality |
| Long-context model | Use model with larger context window | Higher memory cost |

For document classification, mean pooling over chunks is often a strong baseline. More advanced systems may use attention pooling over chunk representations.
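A sliding-window split can be sketched as follows; the chunk size, stride, and pad ID of 0 are arbitrary choices:

```python
import torch

def chunk_input_ids(input_ids, chunk_size, stride):
    # Split one long token sequence [T] into overlapping
    # windows [N, chunk_size], padding the last window with 0.
    chunks = []
    for start in range(0, len(input_ids), stride):
        chunk = input_ids[start:start + chunk_size]
        if len(chunk) < chunk_size:
            pad = torch.zeros(chunk_size - len(chunk), dtype=chunk.dtype)
            chunk = torch.cat([chunk, pad])
        chunks.append(chunk)
        if start + chunk_size >= len(input_ids):
            break
    return torch.stack(chunks)
```

Each window is then encoded separately, and the per-chunk vectors are pooled before the classifier head.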

Error Analysis

Text classification errors should be analyzed with examples, not only aggregate metrics.

Useful questions include:

| Question | Purpose |
|---|---|
| Which classes are confused? | Detect overlapping labels |
| Which examples have high confidence but wrong predictions? | Detect systematic failures |
| Which examples are low confidence? | Find ambiguous data |
| Are rare classes ignored? | Diagnose imbalance |
| Are labels noisy? | Improve dataset quality |
| Are certain groups or dialects harmed? | Check fairness and robustness |

A confusion matrix is often the first diagnostic tool for single-label classification. For multi-label classification, per-label precision, recall, and threshold curves are more informative.
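A confusion matrix can be computed directly from predictions:

```python
import torch

def confusion_matrix(preds, labels, num_classes):
    # counts[i, j] = number of examples with true class i predicted as class j
    counts = torch.zeros(num_classes, num_classes, dtype=torch.long)
    for t, p in zip(labels.tolist(), preds.tolist()):
        counts[t, p] += 1
    return counts
```

Off-diagonal entries show which class pairs the model confuses most often.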

Common Errors

A frequent error is applying softmax before CrossEntropyLoss. In PyTorch, CrossEntropyLoss expects logits. Passing probabilities can make training less stable.

Correct:

loss = nn.CrossEntropyLoss()(logits, labels)

Incorrect:

probs = torch.softmax(logits, dim=1)
loss = nn.CrossEntropyLoss()(probs, labels)

Another common error is using CrossEntropyLoss for multi-label classification. Multi-label targets need BCEWithLogitsLoss.

A third error is averaging embeddings without masking padding. Padding tokens should not contribute to the pooled representation.

A fourth error is evaluating on tokenized or truncated text without checking how much content was removed. Silent truncation can create misleading results.

Summary

Text classification maps text to labels. A neural classifier usually consists of a tokenizer, embedding layer, encoder, pooling step, and classifier head.

Single-label classification uses one class per example and is usually trained with cross-entropy loss. Multi-label classification allows several active labels and is usually trained with binary cross-entropy with logits.

Simple bag-of-embeddings models are fast and useful baselines. CNNs capture local phrases. RNNs summarize sequences. Transformers provide contextual representations and are the dominant architecture for high-accuracy NLP systems.