Training, Validation, and Test Sets

A machine learning dataset is usually divided into three parts: a training set, a validation set, and a test set. Each part has a different role. The training set is used to fit model parameters. The validation set is used to make design choices. The test set is used only for final evaluation.

This separation is necessary because a model can perform well on examples it has already seen while performing poorly on new examples. The purpose of evaluation is to estimate performance on unseen data, not to measure memorization.

Why We Split Data

Suppose we train a classifier on a dataset of images. During training, the model repeatedly sees the same examples. It may learn useful visual patterns, but it may also memorize accidental details in the training set.

If we evaluate the model only on the training set, the result can be misleading.

A separate evaluation set gives a better estimate of generalization. Generalization means performance on examples that were not used for fitting the model.

The standard split is:

| Split | Used for | Should update model parameters? |
| --- | --- | --- |
| Training set | Fit weights and biases | Yes |
| Validation set | Choose hyperparameters and model design | No |
| Test set | Final unbiased evaluation | No |

The validation and test sets must be held out from training. They should represent the same problem distribution that the model will face after deployment.

Training Set

The training set is the part of the data used by the optimization algorithm.

Given training examples

$$\mathcal{D}_{\text{train}} = \{(x^{(i)}, y^{(i)})\}_{i=1}^{N_{\text{train}}},$$

the model minimizes empirical training loss:

$$\hat{\mathcal{R}}_{\text{train}}(\theta) = \frac{1}{N_{\text{train}}} \sum_{i=1}^{N_{\text{train}}} L(f_\theta(x^{(i)}), y^{(i)}).$$

During backpropagation, gradients are computed from training batches only. The optimizer updates model parameters using these gradients.

In PyTorch:

model.train()  # enable training-mode behavior (dropout, batch norm)

for x_batch, y_batch in train_loader:
    optimizer.zero_grad()            # clear gradients from the previous step

    logits = model(x_batch)          # forward pass
    loss = loss_fn(logits, y_batch)

    loss.backward()                  # compute gradients from this training batch
    optimizer.step()                 # update parameters

The call model.train() puts certain layers into training mode. This matters for modules such as dropout and batch normalization.
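As a small illustration of the difference, a dropout layer is stochastic in training mode and an identity function in evaluation mode:

import torch

drop = torch.nn.Dropout(p=0.5)
x = torch.ones(4)

drop.train()
print(drop(x))  # roughly half the entries are zeroed; survivors are scaled by 1 / (1 - p) = 2

drop.eval()
print(drop(x))  # tensor([1., 1., 1., 1.])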

Validation Set

The validation set is used during model development, but not for direct parameter updates.

It answers questions such as:

| Question | Example |
| --- | --- |
| Which architecture should we use? | MLP, CNN, transformer |
| Which hyperparameters work best? | Learning rate, batch size, weight decay |
| When should training stop? | Early stopping epoch |
| Which checkpoint should be kept? | Lowest validation loss |
| Which preprocessing works best? | Tokenizer, image size, normalization |

The validation set estimates how well design choices generalize beyond the training set.

A validation loop does not call backward() and does not call optimizer.step():

model.eval()  # switch dropout and batch norm to evaluation behavior

total_loss = 0.0
total_examples = 0

with torch.no_grad():
    for x_batch, y_batch in val_loader:
        logits = model(x_batch)
        loss = loss_fn(logits, y_batch)

        # weight each batch loss by its size so the average stays
        # exact even when the final batch is smaller
        batch_size = x_batch.size(0)
        total_loss += loss.item() * batch_size
        total_examples += batch_size

val_loss = total_loss / total_examples

The call model.eval() switches dropout and batch normalization into evaluation behavior. The context torch.no_grad() disables gradient tracking, reducing memory use and computation.

Test Set

The test set is used once the model development process is complete.

It should answer one question:

How well does the final selected model perform on unseen data?

The test set should not be used repeatedly to guide model design. If many choices are made based on test performance, the test set becomes part of the training process indirectly. The final score then becomes optimistically biased.

A clean workflow is:

  1. Train many candidate models on the training set.
  2. Select among them using the validation set.
  3. Evaluate the final chosen model once on the test set.

The test set is therefore a proxy for future deployment data. It should be protected from model selection decisions.
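As a minimal sketch of step 3, assuming the selected checkpoint was saved to best.pt and an evaluate helper that runs a loop like the validation loop above:

# hypothetical names: best.pt and evaluate follow the conventions used later in this article
model.load_state_dict(torch.load("best.pt"))
test_loss = evaluate(model, test_loader)  # called exactly once, at the very end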

Hyperparameters and Model Selection

Parameters are learned by gradient descent. Hyperparameters are chosen by the practitioner or an outer search process.

Examples of learned parameters:

| Learned parameter | Example |
| --- | --- |
| Weight matrix | Linear layer weights |
| Bias vector | Linear layer bias |
| Embedding table | Token embeddings |
| Normalization scale | LayerNorm gain |

Examples of hyperparameters:

| Hyperparameter | Example values |
| --- | --- |
| Learning rate | $10^{-3}$, $3 \times 10^{-4}$, $10^{-5}$ |
| Batch size | 32, 128, 1024 |
| Weight decay | 0, 0.01, 0.1 |
| Number of layers | 6, 12, 24 |
| Hidden dimension | 512, 1024, 4096 |
| Dropout rate | 0.0, 0.1, 0.5 |

Validation performance is used to choose hyperparameters. Test performance is reserved for the final estimate.
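As an illustration, a simple grid search over learning rates might look like the following sketch; train_model and evaluate are hypothetical helpers wrapping the training and validation loops shown earlier:

best_lr = None
best_val_loss = float("inf")

for lr in [1e-3, 3e-4, 1e-5]:
    model = train_model(train_loader, lr=lr)   # fits parameters on the training set
    val_loss = evaluate(model, val_loader)     # scores the design choice on the validation set

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_lr = lr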

Data Leakage

Data leakage occurs when information from validation or test data influences training.

Leakage can make a model appear much better than it really is.

Common leakage patterns include:

| Leakage type | Example |
| --- | --- |
| Duplicate leakage | Same image appears in train and test |
| Preprocessing leakage | Normalization statistics computed on full dataset |
| Temporal leakage | Training uses future information |
| User leakage | Same user appears in train and test when evaluating new-user generalization |
| Label leakage | Input contains a field derived from the target |

Consider feature normalization. A common mistake is computing mean and standard deviation using all data before splitting. The test set then influences preprocessing.

Correct procedure:

# compute normalization statistics on the training split only
mean = train_data.mean()
std = train_data.std()

# apply the same training-derived statistics to every split
train_data = (train_data - mean) / std
val_data = (val_data - mean) / std
test_data = (test_data - mean) / std

The validation and test sets are transformed using statistics from the training set only.

Random Splits

A random split assigns examples randomly to training, validation, and test sets.

Common proportions are:

| Training | Validation | Test |
| --- | --- | --- |
| 80% | 10% | 10% |
| 70% | 15% | 15% |
| 90% | 5% | 5% |

In PyTorch:

from torch.utils.data import random_split

dataset = MyDataset()

n = len(dataset)
n_train = int(0.8 * n)
n_val = int(0.1 * n)
n_test = n - n_train - n_val

train_set, val_set, test_set = random_split(
    dataset,
    [n_train, n_val, n_test],
    generator=torch.Generator().manual_seed(42),
)

The fixed random seed makes the split reproducible.

Random splitting works when examples are independent and identically distributed. It can fail when data has groups, time order, duplicates, or repeated entities.

Stratified Splits

For classification problems, a random split may accidentally change class proportions across splits. This is especially harmful when classes are imbalanced.

A stratified split preserves label proportions.

Suppose a dataset has 90% class A and 10% class B. A stratified split tries to keep approximately the same ratio in training, validation, and test sets.

This matters because a validation set with too few minority-class examples gives unstable metrics.

PyTorch does not provide a high-level stratified split utility in its core data API, but it can be done using scikit-learn:

from sklearn.model_selection import train_test_split

indices = list(range(len(labels)))

train_idx, temp_idx = train_test_split(
    indices,
    test_size=0.2,
    stratify=labels,
    random_state=42,
)

val_idx, test_idx = train_test_split(
    temp_idx,
    test_size=0.5,  # split the held-out 20% evenly: 10% validation, 10% test
    stratify=[labels[i] for i in temp_idx],
    random_state=42,
)

Then Subset can wrap the dataset:

from torch.utils.data import Subset

train_set = Subset(dataset, train_idx)
val_set = Subset(dataset, val_idx)
test_set = Subset(dataset, test_idx)

Group Splits

Group splitting keeps related examples in the same split.

This is important when multiple examples come from the same source.

Examples:

| Domain | Group |
| --- | --- |
| Medical imaging | Patient ID |
| Recommendation | User ID |
| Speech | Speaker ID |
| Documents | Source document |
| Video | Video ID |
| Web data | Website or domain |

If images from the same patient appear in both training and test sets, the model may learn patient-specific artifacts. The test score may overestimate performance on new patients.

Group splitting evaluates generalization to new groups. This is often closer to the real deployment problem.
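scikit-learn's GroupShuffleSplit implements this. A minimal sketch, assuming a groups list where groups[i] identifies the source (for example, the patient ID) of example i:

from sklearn.model_selection import GroupShuffleSplit

indices = list(range(len(groups)))

splitter = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
train_idx, heldout_idx = next(splitter.split(indices, groups=groups))

Every group lands entirely on one side of the split, and the held-out indices can be divided again into validation and test sets with a second GroupShuffleSplit.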

Temporal Splits

Temporal splitting respects time order.

Training uses earlier data. Validation and test use later data.

This is necessary for forecasting, finance, recommendation systems, logs, and many production systems. Random splitting can leak future information into training.

A typical temporal split:

| Split | Time range |
| --- | --- |
| Training | January to August |
| Validation | September to October |
| Test | November to December |

Temporal validation better simulates deployment, where the model is trained on past data and used on future data.
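A minimal index-based sketch, assuming a timestamps list aligned with the dataset:

# sort example indices by time, then slice: earliest 80% for training,
# next 10% for validation, most recent 10% for testing
order = sorted(range(len(timestamps)), key=lambda i: timestamps[i])

n = len(order)
train_idx = order[: int(0.8 * n)]
val_idx = order[int(0.8 * n) : int(0.9 * n)]
test_idx = order[int(0.9 * n) :]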

Cross-Validation

Cross-validation is useful when data is limited.

In $K$-fold cross-validation, the dataset is divided into $K$ folds. The model is trained $K$ times. Each fold is used once as validation data, while the remaining folds are used for training.

The final validation score is averaged across folds.

This gives a more stable estimate than a single split, but it costs more computation.

Cross-validation is common for small tabular datasets. It is less common for large deep learning workloads because training each model is expensive.
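A sketch of 5-fold cross-validation using scikit-learn's KFold, with train_model and evaluate again as hypothetical helpers wrapping the loops shown earlier:

from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, Subset

kf = KFold(n_splits=5, shuffle=True, random_state=42)
fold_losses = []

for train_idx, val_idx in kf.split(list(range(len(dataset)))):
    train_loader = DataLoader(Subset(dataset, train_idx), batch_size=64, shuffle=True)
    val_loader = DataLoader(Subset(dataset, val_idx), batch_size=64)

    model = train_model(train_loader)
    fold_losses.append(evaluate(model, val_loader))

cv_loss = sum(fold_losses) / len(fold_losses)  # validation loss averaged across folds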

Early Stopping

Early stopping uses validation performance to decide when to stop training.

During training, the training loss may continue decreasing while validation loss starts increasing. This usually indicates overfitting.

A simple early stopping rule is:

  1. Track validation loss after each epoch.
  2. Save the model checkpoint with the lowest validation loss.
  3. Stop if validation loss does not improve for several epochs.

The number of epochs to wait is called patience.

best_val_loss = float("inf")
patience = 5
bad_epochs = 0

for epoch in range(num_epochs):
    train_one_epoch(model, train_loader)
    val_loss = evaluate(model, val_loader)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        bad_epochs = 0
        torch.save(model.state_dict(), "best.pt")
    else:
        bad_epochs += 1

    if bad_epochs >= patience:
        break

Early stopping uses validation data for model selection. The test set remains untouched.

Distribution Shift

A split is useful only if it reflects the intended deployment setting.

If deployment data differs strongly from the training and test data, test performance may fail to predict real-world behavior.

Distribution shift can occur because of:

| Shift type | Example |
| --- | --- |
| Time shift | User behavior changes over months |
| Geography shift | Model trained in one country, deployed in another |
| Device shift | Images from different cameras |
| Population shift | Medical model trained on one hospital |
| Label shift | Class frequencies change |
| Concept shift | Meaning of labels changes over time |

For this reason, dataset splitting is part of problem design, not just bookkeeping.

A good split asks: what kind of future data must this model handle?

A Practical PyTorch Split Template

A minimal split and loader setup looks like this:

import torch
from torch.utils.data import DataLoader, random_split

dataset = MyDataset()

n = len(dataset)
n_train = int(0.8 * n)
n_val = int(0.1 * n)
n_test = n - n_train - n_val

train_set, val_set, test_set = random_split(
    dataset,
    [n_train, n_val, n_test],
    generator=torch.Generator().manual_seed(42),
)

train_loader = DataLoader(
    train_set,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

val_loader = DataLoader(
    val_set,
    batch_size=64,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

test_loader = DataLoader(
    test_set,
    batch_size=64,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

Training data is shuffled. Validation and test data are usually not shuffled because their order does not affect metrics.

Summary

The training set fits model parameters. The validation set guides model design and hyperparameter selection. The test set gives the final estimate of generalization.

A reliable split prevents leakage and reflects the deployment problem. Random splits are useful for simple independent data. Stratified, group, and temporal splits are needed when class balance, repeated entities, or time order matters.

Good evaluation begins before training. It starts with deciding what kind of unseen data the model must handle.