# Text Classification

Text classification assigns one or more labels to a piece of text. The input may be a sentence, paragraph, document, review, message, query, or conversation turn. The output is a class label, a probability distribution over labels, or a set of active labels.

Common examples include sentiment analysis, spam detection, topic classification, intent detection, toxicity detection, language identification, document routing, and product category prediction.

A text classifier has the same basic structure as other neural classifiers:

$$
\text{text}
\rightarrow
\text{token IDs}
\rightarrow
\text{embeddings}
\rightarrow
\text{encoder}
\rightarrow
\text{pooled representation}
\rightarrow
\text{classifier}
\rightarrow
\text{logits}.
$$

The tokenizer converts text into token IDs. The embedding layer converts token IDs into vectors. The encoder computes contextual representations. The pooling step converts a variable-length sequence into a fixed-size vector. The classifier maps that vector to class scores.

### Single-Label Classification

In single-label classification, each input belongs to exactly one class.

For example, a support message may belong to one intent:

| Text | Label |
|---|---|
| `I forgot my password` | `account_access` |
| `Where is my order?` | `shipping_status` |
| `Please cancel my subscription` | `billing` |

If there are $K$ classes, the model returns $K$ logits:

$$
z \in \mathbb{R}^K.
$$

The predicted class is the index with the largest logit:

$$
\hat{y} = \arg\max_k z_k.
$$

During training, the usual objective is cross-entropy loss. In PyTorch, class labels are integer IDs with shape `[batch_size]`.

```python
import torch
import torch.nn as nn

logits = torch.randn(32, 5)        # [batch_size, num_classes]
labels = torch.randint(0, 5, (32,)) # [batch_size]

loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, labels)

print(loss.shape)  # torch.Size([])
```

`nn.CrossEntropyLoss` expects raw logits, not probabilities. It internally applies log-softmax and negative log-likelihood.
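
To see this, the built-in loss can be reproduced manually. A quick sanity check:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(32, 5)
labels = torch.randint(0, 5, (32,))

# CrossEntropyLoss is log-softmax followed by negative log-likelihood.
log_probs = F.log_softmax(logits, dim=1)
manual = F.nll_loss(log_probs, labels)

assert torch.allclose(manual, F.cross_entropy(logits, labels))
```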

### Multi-Label Classification

In multi-label classification, an input may have several labels at once.

For example, a news article may be labeled with several topics:

| Text | Labels |
|---|---|
| `New battery plant opens in Germany` | `technology`, `energy`, `business` |
| `Central bank raises interest rates` | `finance`, `policy` |

The model still returns $K$ logits, but each class is treated as an independent binary decision:

$$
z \in \mathbb{R}^K.
$$

The targets are binary vectors:

$$
y \in \{0,1\}^K.
$$

In PyTorch, the usual loss is binary cross-entropy with logits:

```python
logits = torch.randn(32, 6)                 # [batch_size, num_labels]
targets = torch.randint(0, 2, (32, 6)).float()

loss_fn = nn.BCEWithLogitsLoss()
loss = loss_fn(logits, targets)
```

At inference time, logits are passed through a sigmoid function:

```python
probs = torch.sigmoid(logits)
preds = probs > 0.5
```

The threshold does not have to be 0.5. For imbalanced datasets, each label may need its own threshold.
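
A minimal sketch of per-label thresholds, with illustrative (untuned) values for the six labels above:

```python
# Illustrative, untuned per-label thresholds
thresholds = torch.tensor([0.5, 0.3, 0.7, 0.5, 0.2, 0.5])

probs = torch.sigmoid(logits)  # [batch_size, num_labels]
preds = probs > thresholds     # thresholds broadcast across the batch
```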

### A Bag-of-Embeddings Classifier

The simplest neural text classifier embeds each token, averages the embeddings, and applies a linear classifier.

This model ignores word order, but it is fast and often a strong baseline on small classification tasks.

```python
import torch
import torch.nn as nn

class BagOfEmbeddingsClassifier(nn.Module):
    def __init__(self, vocab_size, embedding_dim, num_classes, padding_idx=0):
        super().__init__()
        self.padding_idx = padding_idx
        self.embedding = nn.Embedding(
            vocab_size,
            embedding_dim,
            padding_idx=padding_idx,
        )
        self.classifier = nn.Linear(embedding_dim, num_classes)

    def forward(self, input_ids):
        # input_ids: [B, T]
        x = self.embedding(input_ids)
        # x: [B, T, D]

        mask = input_ids != self.padding_idx
        # mask: [B, T]

        mask = mask.unsqueeze(-1)
        # mask: [B, T, 1]

        x = x * mask

        lengths = mask.sum(dim=1).clamp(min=1)
        # lengths: [B, 1]

        pooled = x.sum(dim=1) / lengths
        # pooled: [B, D]

        logits = self.classifier(pooled)
        # logits: [B, K]

        return logits
```

The important shape transformation is:

```text
[B, T] -> [B, T, D] -> [B, D] -> [B, K]
```

The model begins with token IDs, computes one vector per token, averages across the sequence axis, and produces one vector of logits per example.
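
A quick smoke test of the model, with illustrative sizes:

```python
model = BagOfEmbeddingsClassifier(vocab_size=10_000, embedding_dim=64, num_classes=5)

input_ids = torch.randint(1, 10_000, (32, 20))  # [B, T], IDs start at 1 so 0 stays padding
input_ids[:, 15:] = 0                           # simulate right padding

logits = model(input_ids)
print(logits.shape)  # torch.Size([32, 5])
```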

### Pooling Sequence Representations

A classifier needs a fixed-size representation of the input text. Sequence encoders produce one vector per token, so the model must pool the sequence.

Common pooling methods include mean pooling, max pooling, first-token pooling, last-token pooling, and attention pooling.

| Pooling method | Definition | Common use |
|---|---|---|
| Mean pooling | Average token vectors | Sentence embeddings, simple classifiers |
| Max pooling | Take maximum over time | CNN and RNN classifiers |
| First-token pooling | Use first hidden state | BERT-style `[CLS]` classifiers |
| Last-token pooling | Use final hidden state | GPT-style classifiers |
| Attention pooling | Learn weighted average | Document classification |

Suppose the encoder output is

$$
H \in \mathbb{R}^{B \times T \times D}.
$$

Mean pooling produces

$$
h \in \mathbb{R}^{B \times D}.
$$

With an attention mask, padding positions should be excluded:

```python
def mean_pool(hidden_states, attention_mask):
    # hidden_states: [B, T, D]
    # attention_mask: [B, T], 1 for real tokens, 0 for padding

    mask = attention_mask.unsqueeze(-1)
    hidden_states = hidden_states * mask

    lengths = mask.sum(dim=1).clamp(min=1)
    pooled = hidden_states.sum(dim=1) / lengths

    return pooled
```

Pooling choice matters. Mean pooling works well when all tokens contribute to the label. First-token pooling works well when the model has been trained to place sentence-level information in a special token. Last-token pooling is natural for decoder-only language models because the final position has attended to all previous tokens.
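
Attention pooling, listed in the table above, replaces the fixed average with learned weights. A minimal sketch:

```python
class AttentionPool(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        # One learned score per token, normalized into a weighted average
        self.score = nn.Linear(hidden_dim, 1)

    def forward(self, hidden_states, attention_mask):
        # hidden_states: [B, T, D], attention_mask: [B, T]
        scores = self.score(hidden_states).squeeze(-1)  # [B, T]
        scores = scores.masked_fill(attention_mask == 0, float("-inf"))

        weights = torch.softmax(scores, dim=1)          # [B, T]
        pooled = (weights.unsqueeze(-1) * hidden_states).sum(dim=1)
        return pooled                                   # [B, D]
```

Masked positions receive a score of negative infinity, so the softmax assigns them zero weight.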

### CNN Text Classifiers

A convolutional text classifier applies one-dimensional convolutions over token embeddings.

If the input embedding tensor has shape

```text
[B, T, D]
```

a 1D convolution in PyTorch usually expects

```text
[B, D, T]
```

so we transpose the tensor before applying `nn.Conv1d`.

```python
class CNNTextClassifier(nn.Module):
    def __init__(self, vocab_size, embedding_dim, num_classes, padding_idx=0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)

        self.conv3 = nn.Conv1d(embedding_dim, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv1d(embedding_dim, 128, kernel_size=5, padding=2)

        self.activation = nn.ReLU()
        self.classifier = nn.Linear(256, num_classes)

    def forward(self, input_ids):
        # input_ids: [B, T]
        x = self.embedding(input_ids)
        # x: [B, T, D]

        x = x.transpose(1, 2)
        # x: [B, D, T]

        h3 = self.activation(self.conv3(x))
        h5 = self.activation(self.conv5(x))
        # h3, h5: [B, 128, T]

        h3 = h3.max(dim=2).values
        h5 = h5.max(dim=2).values
        # h3, h5: [B, 128]

        h = torch.cat([h3, h5], dim=1)
        # h: [B, 256]

        logits = self.classifier(h)
        # logits: [B, K]

        return logits
```

CNN classifiers are useful when local phrases strongly determine the label. For example, short phrases such as `free money`, `reset password`, or `cancel subscription` may be predictive.

### RNN Text Classifiers

An RNN classifier processes the token sequence from left to right. An LSTM or GRU can build a hidden state that summarizes the sequence.

```python
class LSTMTextClassifier(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_classes, padding_idx=0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)

        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            batch_first=True,
            bidirectional=True,
        )

        self.classifier = nn.Linear(2 * hidden_dim, num_classes)

    def forward(self, input_ids):
        # input_ids: [B, T]
        x = self.embedding(input_ids)
        # x: [B, T, D]

        output, (h_n, c_n) = self.lstm(x)
        # h_n: [2, B, H] for one bidirectional layer

        forward_last = h_n[0]   # last hidden state of the forward direction
        backward_last = h_n[1]  # last hidden state of the backward direction

        h = torch.cat([forward_last, backward_last], dim=1)
        # h: [B, 2H]

        logits = self.classifier(h)
        # logits: [B, K]

        return logits
```

RNN classifiers are less dominant than transformers in modern NLP, but they remain useful when models must be small, fast, or deployed under strict memory limits.

### Transformer Text Classifiers

A transformer classifier uses self-attention to compute contextual token representations. In practice, many classifiers fine-tune a pretrained transformer.

A BERT-style classifier commonly prepends a special `[CLS]` token. The final hidden state at that position is used as the sentence representation.

The shape flow is:

```text
input_ids:      [B, T]
attention_mask: [B, T]
hidden_states:  [B, T, D]
cls_state:      [B, D]
logits:         [B, K]
```

A simplified classifier head looks like this:

```python
class TransformerClassifierHead(nn.Module):
    def __init__(self, hidden_dim, num_classes):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, hidden_states):
        # hidden_states: [B, T, D]
        cls_state = hidden_states[:, 0, :]
        # cls_state: [B, D]

        h = self.norm(cls_state)
        h = self.dropout(h)

        logits = self.classifier(h)
        # logits: [B, K]

        return logits
```

Decoder-only models, such as GPT-style models, often use the hidden state of the final non-padding token instead of the first token, since under causal attention only the last position has seen the whole input.

```python
def last_token_pool(hidden_states, attention_mask):
    # hidden_states: [B, T, D]
    # attention_mask: [B, T]

    last_idx = attention_mask.sum(dim=1) - 1
    # last_idx: [B], index of the last real token (assumes right padding)

    batch_idx = torch.arange(hidden_states.size(0), device=hidden_states.device)
    pooled = hidden_states[batch_idx, last_idx]

    return pooled
```

The classifier head can then map `pooled` to logits.
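
For example, with `hidden_dim` and `num_classes` defined as before, a plain linear layer suffices as the head:

```python
classifier = nn.Linear(hidden_dim, num_classes)

pooled = last_token_pool(hidden_states, attention_mask)  # [B, D]
logits = classifier(pooled)                              # [B, K]
```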

### Training Loop

A standard PyTorch training loop for text classification looks like this:

```python
def train_one_epoch(model, dataloader, optimizer, loss_fn, device):
    model.train()

    total_loss = 0.0
    total_examples = 0

    for batch in dataloader:
        input_ids = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)

        logits = model(input_ids)
        loss = loss_fn(logits, labels)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        batch_size = input_ids.size(0)
        total_loss += loss.item() * batch_size
        total_examples += batch_size

    return total_loss / total_examples
```

For models that use attention masks:

```python
logits = model(
    input_ids=batch["input_ids"].to(device),
    attention_mask=batch["attention_mask"].to(device),
)
```

The loss function depends on the task:

| Task type | Output shape | Target shape | Loss |
|---|---:|---:|---|
| Single-label | `[B, K]` | `[B]` | `CrossEntropyLoss` |
| Binary | `[B]` or `[B, 1]` | `[B]` or `[B, 1]` | `BCEWithLogitsLoss` |
| Multi-label | `[B, K]` | `[B, K]` | `BCEWithLogitsLoss` |

### Evaluation Metrics

Accuracy is appropriate when classes are balanced and each mistake has similar cost. Many text classification problems violate these conditions. Spam detection, fraud detection, toxicity detection, and rare intent detection often have imbalanced labels.

Common metrics include precision, recall, F1 score, ROC AUC, PR AUC, and confusion matrices.

| Metric | Meaning | Useful when |
|---|---|---|
| Accuracy | Fraction of correct predictions | Classes are balanced |
| Precision | Fraction of predicted positives that are correct | False positives are costly |
| Recall | Fraction of actual positives found | False negatives are costly |
| F1 | Harmonic mean of precision and recall | Need balance between precision and recall |
| ROC AUC | Ranking quality across thresholds | Binary classification |
| PR AUC | Precision-recall tradeoff | Rare positive class |

For single-label classification:

```python
def accuracy(logits, labels):
    preds = logits.argmax(dim=1)
    return (preds == labels).float().mean()
```

For multi-label classification:

```python
def multilabel_predictions(logits, threshold=0.5):
    probs = torch.sigmoid(logits)
    return probs >= threshold
```

Choosing the threshold is part of model design. A threshold that maximizes validation F1 may differ from a threshold that minimizes business cost.
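
A minimal sketch of selecting the F1-maximizing threshold for one label, assuming `probs` and `targets` are 1-D validation tensors:

```python
def best_f1_threshold(probs, targets):
    # probs: [N] sigmoid outputs for one label; targets: [N] in {0, 1}
    targets = targets.float()
    best_t, best_f1 = 0.5, 0.0

    for t in torch.linspace(0.05, 0.95, steps=19):
        preds = (probs >= t).float()
        tp = (preds * targets).sum()
        precision = tp / preds.sum().clamp(min=1)
        recall = tp / targets.sum().clamp(min=1)
        f1 = 2 * precision * recall / (precision + recall).clamp(min=1e-8)
        if f1 > best_f1:
            best_t, best_f1 = float(t), float(f1)

    return best_t
```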

### Class Imbalance

Text datasets often contain class imbalance. Some labels occur frequently, while others are rare.

There are several standard responses.

One approach is weighted loss. In single-label classification, `CrossEntropyLoss` accepts class weights:

```python
class_counts = torch.tensor([1000, 200, 50], dtype=torch.float)
class_weights = 1.0 / class_counts
class_weights = class_weights / class_weights.sum() * len(class_counts)

loss_fn = nn.CrossEntropyLoss(weight=class_weights)
```

For multi-label classification, `BCEWithLogitsLoss` accepts `pos_weight`:

```python
positive_counts = torch.tensor([1000, 200, 50], dtype=torch.float)
negative_counts = torch.tensor([9000, 9800, 9950], dtype=torch.float)

pos_weight = negative_counts / positive_counts

loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
```

Another approach is sampling. A training loader can oversample rare classes or undersample common classes. This changes the training distribution, so validation should still be performed on the natural distribution.
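
One common way to oversample rare classes is `WeightedRandomSampler`. A sketch, assuming `labels` holds integer class IDs for the full training set and `train_dataset` is an existing `Dataset`:

```python
from torch.utils.data import DataLoader, WeightedRandomSampler

# Weight each example by the inverse frequency of its class
class_counts = torch.bincount(labels)
example_weights = 1.0 / class_counts[labels].float()

sampler = WeightedRandomSampler(
    weights=example_weights,
    num_samples=len(labels),
    replacement=True,
)
loader = DataLoader(train_dataset, batch_size=32, sampler=sampler)
```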

A third approach is threshold tuning. For rare labels, a default threshold of 0.5 may be too conservative.

### Calibration

A classifier produces scores. In many applications, we want these scores to behave like probabilities.

A model is calibrated when examples assigned probability 0.8 are correct about 80 percent of the time.

Neural classifiers are often miscalibrated. They may be overconfident, especially after extensive fine-tuning. Calibration can be improved with temperature scaling.

For logits $z$, temperature scaling computes

$$
p = \operatorname{softmax}(z / T),
$$

where $T > 0$ is learned on a validation set. Larger $T$ makes the distribution softer. Smaller $T$ makes it sharper.
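
A minimal sketch of fitting $T$ by minimizing validation negative log-likelihood, assuming `val_logits` and `val_labels` have been collected with the model frozen:

```python
import torch
import torch.nn.functional as F

def fit_temperature(val_logits, val_labels):
    # val_logits: [N, K], val_labels: [N]
    log_t = torch.zeros(1, requires_grad=True)  # T = exp(log_t) stays positive
    optimizer = torch.optim.LBFGS([log_t], lr=0.1, max_iter=50)

    def closure():
        optimizer.zero_grad()
        loss = F.cross_entropy(val_logits / log_t.exp(), val_labels)
        loss.backward()
        return loss

    optimizer.step(closure)
    return log_t.exp().item()
```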

Calibration matters when probabilities are used for ranking, triage, deferral, or risk estimation.

### Long Documents

Long document classification creates additional problems. Many transformer models have a maximum context length. A document may exceed that length.

Common strategies include truncation, sliding windows, hierarchical models, retrieval, and sparse attention.

| Strategy | Method | Tradeoff |
|---|---|---|
| Truncation | Keep first or last tokens | Simple but may discard evidence |
| Sliding window | Classify chunks and aggregate | More compute |
| Hierarchical model | Encode chunks, then encode chunk vectors | More complex |
| Retrieval | Select relevant passages | Depends on retrieval quality |
| Long-context model | Use model with larger context window | Higher memory cost |

For document classification, mean pooling over chunks is often a strong baseline. More advanced systems may use attention pooling over chunk representations.
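
A sketch of the chunk-then-pool baseline, assuming a hypothetical `encode` function that maps a `[1, T_chunk]` ID tensor to a `[1, D]` chunk vector, and a linear `classifier`:

```python
def classify_long_document(token_ids, encode, classifier, chunk_size=512):
    # token_ids: [T_total]; chunks are non-overlapping here for simplicity
    chunks = token_ids.split(chunk_size)
    chunk_vecs = [encode(chunk.unsqueeze(0)) for chunk in chunks]  # each [1, D]

    doc_vec = torch.stack(chunk_vecs, dim=1).mean(dim=1)  # mean over chunks: [1, D]
    return classifier(doc_vec)                            # [1, K]
```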

### Error Analysis

Text classification errors should be analyzed with examples, not only aggregate metrics.

Useful questions include:

| Question | Purpose |
|---|---|
| Which classes are confused? | Detect overlapping labels |
| Which examples have high confidence but wrong predictions? | Detect systematic failures |
| Which examples are low confidence? | Find ambiguous data |
| Are rare classes ignored? | Diagnose imbalance |
| Are labels noisy? | Improve dataset quality |
| Are certain groups or dialects harmed? | Check fairness and robustness |

A confusion matrix is often the first diagnostic tool for single-label classification. For multi-label classification, per-label precision, recall, and threshold curves are more informative.
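
A confusion matrix is simple to compute directly from predictions. A sketch:

```python
def confusion_matrix(preds, labels, num_classes):
    # preds, labels: [N] integer class IDs
    # entry [i, j] counts examples with true class i predicted as class j
    flat = labels * num_classes + preds
    counts = torch.bincount(flat, minlength=num_classes ** 2)
    return counts.reshape(num_classes, num_classes)
```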

### Common Errors

A frequent error is applying softmax before `CrossEntropyLoss`. In PyTorch, `CrossEntropyLoss` expects logits. Passing probabilities applies softmax twice, which compresses the scores and weakens the gradients; training often still runs, but converges to a worse model.

Correct:

```python
loss = nn.CrossEntropyLoss()(logits, labels)
```

Incorrect:

```python
probs = torch.softmax(logits, dim=1)
loss = nn.CrossEntropyLoss()(probs, labels)
```

Another common error is using `CrossEntropyLoss` for multi-label classification. Multi-label targets need `BCEWithLogitsLoss`.

A third error is averaging embeddings without masking padding. Padding tokens should not contribute to the pooled representation.

A fourth error is evaluating on truncated inputs without checking how much content was removed. Silent truncation can produce misleading results.
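
A quick check, assuming a hypothetical `tokenize` function that returns the untruncated token list for a text:

```python
# `tokenize` is a stand-in for whatever tokenizer the pipeline uses
def truncation_rate(texts, tokenize, max_length=512):
    truncated = sum(1 for text in texts if len(tokenize(text)) > max_length)
    return truncated / len(texts)
```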

### Summary

Text classification maps text to labels. A neural classifier usually consists of a tokenizer, embedding layer, encoder, pooling step, and classifier head.

Single-label classification uses one class per example and is usually trained with cross-entropy loss. Multi-label classification allows several active labels and is usually trained with binary cross-entropy with logits.

Simple bag-of-embeddings models are fast and useful baselines. CNNs capture local phrases. RNNs summarize sequences. Transformers provide contextual representations and are the dominant architecture for high-accuracy NLP systems.

