A machine learning dataset is usually divided into three parts: a training set, a validation set, and a test set. Each part has a different role. The training set is used to fit model parameters. The validation set is used to make design choices. The test set is used only for final evaluation.
This separation is necessary because a model can perform well on examples it has already seen while performing poorly on new examples. The purpose of evaluation is to estimate performance on unseen data, not to measure memorization.
## Why We Split Data
Suppose we train a classifier on a dataset of images. During training, the model repeatedly sees the same examples. It may learn useful visual patterns, but it may also memorize accidental details in the training set.
If we evaluate the model only on the training set, the result can be misleading.
A separate evaluation set gives a better estimate of generalization. Generalization means performance on examples that were not used for fitting the model.
The standard split is:
| Split | Used for | Should update model parameters? |
|---|---|---|
| Training set | Fit weights and biases | Yes |
| Validation set | Choose hyperparameters and model design | No |
| Test set | Final unbiased evaluation | No |
The validation and test sets must be held out from training. They should represent the same problem distribution that the model will face after deployment.
## Training Set
The training set is the part of the data used by the optimization algorithm.
Given training examples $\{(x_i, y_i)\}_{i=1}^{N}$, the model minimizes the empirical training loss:

$$
L(\theta) = \frac{1}{N} \sum_{i=1}^{N} \ell\big(f_\theta(x_i),\, y_i\big)
$$

where $f_\theta$ is the model with parameters $\theta$ and $\ell$ is the per-example loss.
During backpropagation, gradients are computed from training batches only. The optimizer updates model parameters using these gradients.
In PyTorch:

```python
model.train()
for x_batch, y_batch in train_loader:
    optimizer.zero_grad()
    logits = model(x_batch)
    loss = loss_fn(logits, y_batch)
    loss.backward()
    optimizer.step()
```

The call `model.train()` puts certain layers into training mode. This matters for modules such as dropout and batch normalization.
## Validation Set
The validation set is used during model development, but not for direct parameter updates.
It answers questions such as:
| Question | Example |
|---|---|
| Which architecture should we use? | MLP, CNN, transformer |
| Which hyperparameters work best? | Learning rate, batch size, weight decay |
| When should training stop? | Early stopping epoch |
| Which checkpoint should be kept? | Lowest validation loss |
| Which preprocessing works best? | Tokenizer, image size, normalization |
The validation set estimates how well design choices generalize beyond the training set.
A validation loop does not call `backward()` and does not call `optimizer.step()`:

```python
model.eval()
total_loss = 0.0
total_examples = 0

with torch.no_grad():
    for x_batch, y_batch in val_loader:
        logits = model(x_batch)
        loss = loss_fn(logits, y_batch)
        batch_size = x_batch.size(0)
        total_loss += loss.item() * batch_size
        total_examples += batch_size

val_loss = total_loss / total_examples
```

The call `model.eval()` switches dropout and batch normalization into evaluation behavior. The context `torch.no_grad()` disables gradient tracking, reducing memory use and computation.
## Test Set
The test set is used once the model development process is complete.
It should answer one question:
How well does the final selected model perform on unseen data?
The test set should not be used repeatedly to guide model design. If many choices are made based on test performance, the test set becomes part of the training process indirectly. The final score then becomes optimistically biased.
A clean workflow is:
- Train many candidate models on the training set.
- Select among them using the validation set.
- Evaluate the final chosen model once on the test set.
The test set is therefore a proxy for future deployment data. It should be protected from model selection decisions.
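The three steps above can be sketched in plain Python. The functions `train_model` and `evaluate_on` below are hypothetical stand-ins for a real training and evaluation pipeline; only the selection logic is the point.

```python
# Hypothetical stand-ins for a real training/evaluation pipeline.
def train_model(learning_rate):
    # A real implementation would fit a model on the training set.
    return {"lr": learning_rate}

def evaluate_on(model, split):
    # A real implementation would compute a loss on the given split.
    # Here we pretend lr = 0.001 generalizes best.
    return abs(model["lr"] - 0.001)

# 1. Train many candidate models on the training set.
candidates = [train_model(lr) for lr in (0.0001, 0.001, 0.01)]

# 2. Select among them using the validation set.
best_model = min(candidates, key=lambda m: evaluate_on(m, "validation"))

# 3. Evaluate the final chosen model once on the test set.
test_score = evaluate_on(best_model, "test")
```

The test set appears exactly once, in the last line; no decision depends on it.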
## Hyperparameters and Model Selection
Parameters are learned by gradient descent. Hyperparameters are chosen by the practitioner or an outer search process.
Examples of learned parameters:
| Learned parameter | Example |
|---|---|
| Weight matrix | Linear layer weights |
| Bias vector | Linear layer bias |
| Embedding table | Token embeddings |
| Normalization scale | LayerNorm gain |
Examples of hyperparameters:
| Hyperparameter | Example values |
|---|---|
| Learning rate | 0.0001, 0.001, 0.01 |
| Batch size | 32, 128, 1024 |
| Weight decay | 0, 0.01, 0.1 |
| Number of layers | 6, 12, 24 |
| Hidden dimension | 512, 1024, 4096 |
| Dropout rate | 0.0, 0.1, 0.5 |
Validation performance is used to choose hyperparameters. Test performance is reserved for the final estimate.
## Data Leakage
Data leakage occurs when information from validation or test data influences training.
Leakage can make a model appear much better than it really is.
Common leakage patterns include:
| Leakage type | Example |
|---|---|
| Duplicate leakage | Same image appears in train and test |
| Preprocessing leakage | Normalization statistics computed on full dataset |
| Temporal leakage | Training uses future information |
| User leakage | Same user appears in train and test when evaluating new-user generalization |
| Label leakage | Input contains a field derived from the target |
Consider feature normalization. A common mistake is computing mean and standard deviation using all data before splitting. The test set then influences preprocessing.
Correct procedure:

```python
# Compute statistics on the training split only
mean = train_data.mean()
std = train_data.std()

# Apply the same transform to all splits
train_data = (train_data - mean) / std
val_data = (val_data - mean) / std
test_data = (test_data - mean) / std
```

The validation and test sets are transformed using statistics from the training set only.
## Random Splits
A random split assigns examples randomly to training, validation, and test sets.
Common proportions are:
| Training | Validation | Test |
|---|---|---|
| 80% | 10% | 10% |
| 70% | 15% | 15% |
| 90% | 5% | 5% |
In PyTorch:

```python
import torch
from torch.utils.data import random_split

dataset = MyDataset()
n = len(dataset)
n_train = int(0.8 * n)
n_val = int(0.1 * n)
n_test = n - n_train - n_val

train_set, val_set, test_set = random_split(
    dataset,
    [n_train, n_val, n_test],
    generator=torch.Generator().manual_seed(42),
)
```

The fixed random seed makes the split reproducible.
Random splitting works when examples are independent and identically distributed. It can fail when data has groups, time order, duplicates, or repeated entities.
## Stratified Splits
For classification problems, a random split may accidentally change class proportions across splits. This is especially harmful when classes are imbalanced.
A stratified split preserves label proportions.
Suppose a dataset has 90% class A and 10% class B. A stratified split tries to keep approximately the same ratio in training, validation, and test sets.
This matters because a validation set with too few minority-class examples gives unstable metrics.
PyTorch does not provide a high-level stratified split utility in its core data API, but it can be done using scikit-learn:
```python
from sklearn.model_selection import train_test_split

indices = list(range(len(labels)))

# First split off 20% for validation + test combined
train_idx, temp_idx = train_test_split(
    indices,
    test_size=0.2,
    stratify=labels,
    random_state=42,
)

# Then split that 20% evenly into validation and test
val_idx, test_idx = train_test_split(
    temp_idx,
    test_size=0.5,
    stratify=[labels[i] for i in temp_idx],
    random_state=42,
)
```

Then `Subset` can wrap the dataset:

```python
from torch.utils.data import Subset

train_set = Subset(dataset, train_idx)
val_set = Subset(dataset, val_idx)
test_set = Subset(dataset, test_idx)
```

## Group Splits
Group splitting keeps related examples in the same split.
This is important when multiple examples come from the same source.
Examples:
| Domain | Group |
|---|---|
| Medical imaging | Patient ID |
| Recommendation | User ID |
| Speech | Speaker ID |
| Documents | Source document |
| Video | Video ID |
| Web data | Website or domain |
If images from the same patient appear in both training and test sets, the model may learn patient-specific artifacts. The test score may overestimate performance on new patients.
Group splitting evaluates generalization to new groups. This is often closer to the real deployment problem.
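A group split can be sketched without any library: shuffle the unique group IDs with a fixed seed and assign whole groups to each split. The helper `group_split` below is illustrative, not a standard API; scikit-learn offers equivalent utilities such as `GroupShuffleSplit`.

```python
import random

def group_split(groups, val_frac=0.1, test_frac=0.1, seed=42):
    """Assign whole groups (e.g. patient IDs) to splits,
    so no group ever appears in more than one split."""
    unique = sorted(set(groups))
    random.Random(seed).shuffle(unique)

    n = len(unique)
    n_test = int(test_frac * n)
    n_val = int(val_frac * n)
    test_groups = set(unique[:n_test])
    val_groups = set(unique[n_test:n_test + n_val])

    train_idx = [i for i, g in enumerate(groups)
                 if g not in test_groups and g not in val_groups]
    val_idx = [i for i, g in enumerate(groups) if g in val_groups]
    test_idx = [i for i, g in enumerate(groups) if g in test_groups]
    return train_idx, val_idx, test_idx

# Ten examples from five patients, two images each
patients = ["p1", "p1", "p2", "p2", "p3", "p3", "p4", "p4", "p5", "p5"]
train_idx, val_idx, test_idx = group_split(patients, val_frac=0.2, test_frac=0.2)
```

Both images of any given patient always land in the same split, so the test score reflects performance on unseen patients.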
## Temporal Splits
Temporal splitting respects time order.
Training uses earlier data. Validation and test use later data.
This is necessary for forecasting, finance, recommendation systems, logs, and many production systems. Random splitting can leak future information into training.
A typical temporal split:
| Split | Time range |
|---|---|
| Training | January to August |
| Validation | September to October |
| Test | November to December |
Temporal validation better simulates deployment, where the model is trained on past data and used on future data.
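A temporal split is just a sort followed by chronological slicing. The sketch below assumes one timestamp per example; the 80/10/10 proportions are illustrative.

```python
# One timestamp per example (here: day numbers; could be datetimes)
timestamps = [5, 1, 9, 3, 7, 2, 8, 4, 10, 6]

# Sort example indices by time, then slice chronologically
order = sorted(range(len(timestamps)), key=lambda i: timestamps[i])
n = len(order)
train_idx = order[: int(0.8 * n)]             # earliest 80%
val_idx = order[int(0.8 * n): int(0.9 * n)]   # next 10%
test_idx = order[int(0.9 * n):]               # latest 10%
```

Every training example precedes every validation example in time, and every validation example precedes every test example, so no future information reaches training.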
## Cross-Validation
Cross-validation is useful when data is limited.
In k-fold cross-validation, the dataset is divided into k folds. The model is trained k times. Each fold is used once as validation data, while the remaining k - 1 folds are used for training.
The final validation score is averaged across folds.
This gives a more stable estimate than a single split, but it costs more computation.
Cross-validation is common for small tabular datasets. It is less common for large deep learning workloads because training each model is expensive.
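The fold bookkeeping can be sketched in a few lines of plain Python. `kfold_indices` below is an illustrative helper, not a standard API; libraries such as scikit-learn provide equivalent utilities (e.g. `KFold`).

```python
def kfold_indices(n, k):
    """Yield (train_indices, val_indices) pairs for k-fold cross-validation."""
    indices = list(range(n))
    # Distribute any remainder across the first n % k folds
    fold_sizes = [n // k + (1 if i < n % k else 0) for i in range(k)]

    folds = []
    start = 0
    for size in fold_sizes:
        folds.append(indices[start:start + size])
        start += size

    for i in range(k):
        val = folds[i]
        train = [idx for j, fold in enumerate(folds) if j != i for idx in fold]
        yield train, val

splits = list(kfold_indices(10, 5))
```

Each of the ten examples serves as validation data exactly once across the five splits.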
## Early Stopping
Early stopping uses validation performance to decide when to stop training.
During training, the training loss may continue decreasing while validation loss starts increasing. This usually indicates overfitting.
A simple early stopping rule is:
- Track validation loss after each epoch.
- Save the model checkpoint with the lowest validation loss.
- Stop if validation loss does not improve for several epochs.
The number of epochs to wait is called patience.
```python
best_val_loss = float("inf")
patience = 5
bad_epochs = 0

for epoch in range(num_epochs):
    train_one_epoch(model, train_loader)
    val_loss = evaluate(model, val_loader)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        bad_epochs = 0
        torch.save(model.state_dict(), "best.pt")
    else:
        bad_epochs += 1
        if bad_epochs >= patience:
            break
```

Early stopping uses validation data for model selection. The test set remains untouched.
## Distribution Shift
A split is useful only if it reflects the intended deployment setting.
If deployment data differs strongly from the training and test data, test performance may fail to predict real-world behavior.
Distribution shift can occur because of:
| Shift type | Example |
|---|---|
| Time shift | User behavior changes over months |
| Geography shift | Model trained in one country, deployed in another |
| Device shift | Images from different cameras |
| Population shift | Medical model trained on one hospital |
| Label shift | Class frequencies change |
| Concept shift | Meaning of labels changes over time |
For this reason, dataset splitting is part of problem design, not just bookkeeping.
A good split asks: what kind of future data must this model handle?
## A Practical PyTorch Split Template
A minimal split and loader setup looks like this:
```python
import torch
from torch.utils.data import DataLoader, random_split

dataset = MyDataset()
n = len(dataset)
n_train = int(0.8 * n)
n_val = int(0.1 * n)
n_test = n - n_train - n_val

train_set, val_set, test_set = random_split(
    dataset,
    [n_train, n_val, n_test],
    generator=torch.Generator().manual_seed(42),
)

train_loader = DataLoader(
    train_set,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)
val_loader = DataLoader(
    val_set,
    batch_size=64,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)
test_loader = DataLoader(
    test_set,
    batch_size=64,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)
```

Training data is shuffled. Validation and test data are usually not shuffled because their order does not affect metrics.
## Summary
The training set fits model parameters. The validation set guides model design and hyperparameter selection. The test set gives the final estimate of generalization.
A reliable split prevents leakage and reflects the deployment problem. Random splits are useful for simple independent data. Stratified, group, and temporal splits are needed when class balance, repeated entities, or time order matters.
Good evaluation begins before training. It starts with deciding what kind of unseen data the model must handle.