Large-scale training means training models on datasets, model sizes, or hardware configurations that exceed a simple single-GPU workflow. In image classification, this often means millions of images, large backbones, long schedules, high-resolution inputs, or multi-GPU training.
The goal is not only to make training faster. The goal is to keep optimization stable, data loading efficient, validation reliable, and checkpoints recoverable while the system grows.
What Changes at Scale
A small training job can tolerate inefficient code. A large training job cannot. At scale, small inefficiencies become expensive.
| Area | Small training | Large-scale training |
|---|---|---|
| Data | Fits on local disk | Often sharded, streamed, cached |
| Batch size | Tens of images | Hundreds or thousands of images |
| Hardware | CPU or one GPU | Multiple GPUs or nodes |
| Precision | Usually float32 | Often mixed precision |
| Checkpoints | Occasional manual saves | Regular resumable checkpoints |
| Validation | Simple loop | Distributed, scheduled, logged |
| Failure handling | Restart from scratch | Resume from checkpoint |
| Profiling | Optional | Required |
Large-scale training is a systems problem as much as a modeling problem.
Throughput
Throughput measures how many examples the system processes per second.
For training, throughput depends on data loading, preprocessing, host-to-device transfer, forward pass, backward pass, optimizer step, and synchronization.
A simple measurement:
```python
import time

def measure_throughput(model, loader, loss_fn, optimizer, device, steps=100):
    model.train()
    start = time.time()
    total_images = 0
    for step, (images, labels) in enumerate(loader):
        if step >= steps:
            break
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        logits = model(images)
        loss = loss_fn(logits, labels)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        total_images += images.size(0)
    elapsed = time.time() - start
    return total_images / elapsed
```

Use throughput together with validation metrics. A faster system that reaches worse accuracy may not be a better system.
Data Loading Bottlenecks
At scale, the GPU should spend most of its time computing, not waiting for data. Slow data loading is one of the most common causes of poor hardware utilization.
Useful DataLoader settings:
```python
train_loader = DataLoader(
    train_set,
    batch_size=256,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=4,
)
```

Important options:
| Option | Purpose |
|---|---|
| num_workers | Uses subprocesses for data loading |
| pin_memory | Speeds CPU-to-GPU transfer |
| persistent_workers | Keeps workers alive across epochs |
| prefetch_factor | Loads future batches in advance |
| non_blocking=True | Allows asynchronous device transfer |
Add workers only as long as throughput keeps improving. Too many workers can saturate CPU, disk, or memory bandwidth.
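One practical way to choose num_workers is to sweep a few values and measure throughput directly. A minimal sketch, reusing measure_throughput from above and assuming train_set, model, loss_fn, optimizer, and device are already defined:

```python
# Sketch: sweep num_workers and keep the fastest setting.
from torch.utils.data import DataLoader

for workers in (2, 4, 8, 16):
    loader = DataLoader(
        train_set,
        batch_size=256,
        shuffle=True,
        num_workers=workers,
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=4,
    )
    ips = measure_throughput(model, loader, loss_fn, optimizer, device, steps=50)
    print(f"num_workers={workers}: {ips:.1f} images/s")
```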
Sharded Datasets
Large datasets are often stored as shards rather than individual image files. A directory containing millions of small files can be slow because metadata access dominates.
Common large-scale storage formats include:
| Format | Typical use |
|---|---|
| Tar shards | Sequential image loading |
| WebDataset | Streaming samples from tar files |
| LMDB | Key-value image storage |
| TFRecord | TensorFlow-style records |
| Parquet | Structured metadata and features |
| Object storage shards | Cloud-scale training |
A sharded dataset improves sequential reads and makes distributed training easier. Each worker can read different shards, reducing contention.
A simple shard naming convention:
```
train-000000.tar
train-000001.tar
train-000002.tar
...
```

The shard size should balance two costs. Very small shards increase scheduling overhead. Very large shards reduce flexibility and make recovery more expensive.
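One common way to consume such shards is the WebDataset library, which streams samples directly out of tar files. A minimal sketch, assuming each sample stores a jpg image and a cls label, and that train_transform is a torchvision pipeline ending in ToTensor(); the decode keys and exact API details depend on how the shards were written and on the webdataset version:

```python
# Sketch: streaming tar shards with the webdataset package (assumed installed).
import webdataset as wds
from torch.utils.data import DataLoader

shard_pattern = "train-{000000..000002}.tar"  # brace expansion over the shards above

dataset = (
    wds.WebDataset(shard_pattern, shardshuffle=True)
    .shuffle(1000)                            # sample-level shuffle buffer
    .decode("pil")                            # image bytes -> PIL images
    .to_tuple("jpg", "cls")                   # (image, label) pairs
    .map_tuple(train_transform, lambda y: y)  # tensorize images, keep labels
)

loader = DataLoader(dataset, batch_size=256, num_workers=8)
```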
Large Batch Training
Large-scale image training often uses large batches. A larger batch gives a more stable gradient estimate and better hardware utilization. However, it can also reduce generalization or require learning rate adjustment.
A common rule is linear learning rate scaling:
For example, if the base learning rate is $\eta_{\text{base}}$ for batch size $B_{\text{base}}$, then for batch size $B$:

$$\eta = \eta_{\text{base}} \cdot \frac{B}{B_{\text{base}}}$$
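A minimal sketch of this rule in code, assuming an illustrative reference recipe tuned at batch size 256 with learning rate 0.1, and a model already in scope:

```python
# Sketch: linear learning rate scaling (reference values are illustrative).
import torch

base_lr = 0.1
base_batch_size = 256

batch_size = 1024  # the global batch size actually used for this run
scaled_lr = base_lr * batch_size / base_batch_size  # -> 0.4

optimizer = torch.optim.SGD(model.parameters(), lr=scaled_lr, momentum=0.9)
```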
Large learning rates should usually be combined with warmup.
Learning Rate Warmup
Warmup gradually increases the learning rate at the start of training. This is useful when training with large batches, mixed precision, or very deep models.
```python
warmup = torch.optim.lr_scheduler.LinearLR(
    optimizer,
    start_factor=0.01,
    total_iters=5,
)
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=95,
)
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    schedulers=[warmup, cosine],
    milestones=[5],
)
```

A common schedule is:
| Phase | Learning rate behavior |
|---|---|
| Warmup | Increase from small value |
| Main training | Cosine decay, step decay, or polynomial decay |
| Final phase | Small updates near convergence |
Warmup prevents early unstable updates before the network’s activations and normalization statistics settle.
Mixed Precision Training
Mixed precision uses lower precision arithmetic where possible while keeping numerically sensitive parts stable. In PyTorch, this is usually done with automatic mixed precision.
```python
scaler = torch.cuda.amp.GradScaler()

for images, labels in train_loader:
    images = images.to(device, non_blocking=True)
    labels = labels.to(device, non_blocking=True)
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.amp.autocast():
        logits = model(images)
        loss = loss_fn(logits, labels)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```

Mixed precision reduces memory use and can increase throughput. It is especially useful on modern GPUs with tensor cores.
The GradScaler helps prevent gradient underflow when using float16. With bfloat16, which has the same exponent range as float32, explicit gradient scaling is usually unnecessary, but the exact behavior depends on hardware and PyTorch configuration.
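For comparison, a bfloat16 variant on hardware that supports it can usually drop the scaler entirely. A minimal sketch, assuming a recent PyTorch and a GPU with bfloat16 support:

```python
# Sketch: bfloat16 autocast; no GradScaler needed in the common case.
for images, labels in train_loader:
    images = images.to(device, non_blocking=True)
    labels = labels.to(device, non_blocking=True)
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        logits = model(images)
        loss = loss_fn(logits, labels)
    loss.backward()
    optimizer.step()
```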
Gradient Accumulation
Gradient accumulation simulates a larger batch size by accumulating gradients over several smaller mini-batches before taking an optimizer step.
```python
accum_steps = 4

optimizer.zero_grad(set_to_none=True)
for step, (images, labels) in enumerate(train_loader):
    images = images.to(device)
    labels = labels.to(device)
    logits = model(images)
    loss = loss_fn(logits, labels)
    loss = loss / accum_steps
    loss.backward()
    if (step + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
```

If the per-device batch size is 64 and accum_steps = 4, the effective batch size per device is 256.
With distributed training, the effective global batch size is:

$$\text{global batch size} = \text{per-device batch size} \times \text{accum\_steps} \times \text{world size}$$

For example, with a per-device batch size of 64, accum_steps = 4, and 8 GPUs, the global batch size is 64 × 4 × 8 = 2048. This value should be used when choosing the learning rate.
Distributed Data Parallel Training
DistributedDataParallel, usually called DDP, is the standard PyTorch method for multi-GPU training. Each process owns one GPU and one copy of the model. Each process receives a different data shard. Gradients are synchronized across processes during backpropagation.
A minimal DDP setup has these parts:
```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_ddp(rank, world_size):
    dist.init_process_group(
        backend="nccl",
        rank=rank,
        world_size=world_size,
    )
    torch.cuda.set_device(rank)
```

Wrap the model:
```python
model = model.to(rank)
model = DDP(model, device_ids=[rank])
```

Use a distributed sampler:
```python
from torch.utils.data.distributed import DistributedSampler

sampler = DistributedSampler(
    train_set,
    num_replicas=world_size,
    rank=rank,
    shuffle=True,
)
train_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    sampler=sampler,
    num_workers=8,
    pin_memory=True,
)
```

At the start of each epoch:
```python
sampler.set_epoch(epoch)
```

This ensures that shuffling differs across epochs while remaining synchronized across processes.
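In practice, these pieces are usually launched with torchrun, which assigns each process its rank through environment variables. A minimal sketch of an environment-based setup (the helper name is illustrative):

```python
# Sketch: environment-based DDP setup for use with torchrun.
# Launch with: torchrun --nproc_per_node=8 train.py
import os
import torch
import torch.distributed as dist

def setup_ddp_from_env():
    dist.init_process_group(backend="nccl")     # reads RANK and WORLD_SIZE from env
    local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun for each process
    torch.cuda.set_device(local_rank)
    return local_rank
```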
Checkpointing at Scale
Large training jobs must be resumable. A checkpoint should contain enough state to continue training without changing the optimization trajectory too much.
A useful checkpoint contains:
| Item | Reason |
|---|---|
| Model state | Learned parameters |
| Optimizer state | Momentum and adaptive statistics |
| Scheduler state | Learning rate schedule position |
| Grad scaler state | Mixed precision state |
| Epoch and step | Resume location |
| Best metric | Model selection |
| Class mapping | Correct inference labels |
| Config | Reproducibility |
Example:
```python
def save_checkpoint(path, model, optimizer, scheduler, scaler, epoch, step, best_metric, config):
    state = {
        "model": model.module.state_dict() if hasattr(model, "module") else model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict() if scheduler is not None else None,
        "scaler": scaler.state_dict() if scaler is not None else None,
        "epoch": epoch,
        "step": step,
        "best_metric": best_metric,
        "config": config,
    }
    torch.save(state, path)
```

Only one process should usually write checkpoints in DDP:
```python
if rank == 0:
    save_checkpoint(...)
```

This avoids multiple processes writing the same file.
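The matching resume path restores each piece of state before training continues. A minimal sketch, mirroring save_checkpoint above:

```python
# Sketch: resuming from a checkpoint written by save_checkpoint.
def load_checkpoint(path, model, optimizer, scheduler, scaler, device):
    state = torch.load(path, map_location=device)
    target = model.module if hasattr(model, "module") else model
    target.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    if scheduler is not None and state["scheduler"] is not None:
        scheduler.load_state_dict(state["scheduler"])
    if scaler is not None and state["scaler"] is not None:
        scaler.load_state_dict(state["scaler"])
    return state["epoch"], state["step"], state["best_metric"]
```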
Validation at Scale
Validation can be expensive on large datasets. It should be scheduled deliberately.
Common options:
| Strategy | Use case |
|---|---|
| Validate every epoch | Medium datasets |
| Validate every N steps | Long epochs |
| Validate on subset | Fast feedback |
| Full validation before checkpoint | Reliable model selection |
| Final test only once | Unbiased final estimate |
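The step-based strategy can be expressed with a simple counter in the training loop. A minimal sketch, where eval_interval and validate are illustrative names:

```python
# Sketch: validate every N optimizer steps instead of every epoch.
eval_interval = 1000  # illustrative value

if global_step % eval_interval == 0:
    val_metrics = validate(model, val_loader, device)  # hypothetical helper
    if rank == 0:
        print(f"step {global_step}: val_acc={val_metrics['accuracy']:.4f}")
```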
For distributed validation, each process evaluates a shard of the validation set. Then counts and losses are reduced across processes.
Conceptually:
```python
correct_tensor = torch.tensor([local_correct], device=device)
count_tensor = torch.tensor([local_count], device=device)

dist.all_reduce(correct_tensor, op=dist.ReduceOp.SUM)
dist.all_reduce(count_tensor, op=dist.ReduceOp.SUM)

global_accuracy = correct_tensor.item() / count_tensor.item()
```

Validation metrics should be computed from summed counts, not by averaging per-process accuracies when batch counts differ.
Reproducibility
Large-scale training is often only approximately reproducible. Parallel data loading, GPU kernels, distributed synchronization, and nondeterministic algorithms can introduce variation.
Still, experiments should record:
```python
config = {
    "model": "resnet50",
    "weights": "imagenet",
    "image_size": 224,
    "batch_size": 256,
    "optimizer": "adamw",
    "lr": 3e-4,
    "weight_decay": 1e-4,
    "epochs": 100,
    "augmentation": "randaugment_erasing",
    "precision": "amp_fp16",
    "seed": 1234,
}
```

Set seeds when possible:
```python
import random
import numpy as np
import torch

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
```

Perfect determinism can reduce performance. For large training runs, reproducibility usually means recording enough information to explain and repeat the result within expected variance.
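When stricter determinism is needed for debugging, PyTorch exposes flags that trade speed for repeatability. A minimal sketch:

```python
# Sketch: stricter determinism; expect slower training and warnings or errors
# from ops that have no deterministic implementation.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True, warn_only=True)
```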
Profiling
Profiling finds where training time is spent. A system may be bottlenecked by data loading, CPU transforms, GPU kernels, synchronization, or disk I/O.
Basic timing:
```python
import time

start = time.time()
for step, batch in enumerate(train_loader):
    data_time = time.time() - start

    images, labels = batch
    images = images.to(device, non_blocking=True)
    labels = labels.to(device, non_blocking=True)

    compute_start = time.time()
    logits = model(images)
    loss = loss_fn(logits, labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    compute_time = time.time() - compute_start

    print(f"data_time={data_time:.4f} compute_time={compute_time:.4f}")
    start = time.time()
```

If data_time is large, improve data loading. If compute_time dominates but GPU utilization is low, inspect batch size, kernels, model shape, and synchronization.
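For a more detailed breakdown, the built-in profiler attributes time to individual CPU and GPU operations. A minimal sketch over a handful of steps:

```python
# Sketch: operator-level profiling with torch.profiler.
from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for step, (images, labels) in enumerate(train_loader):
        if step >= 10:
            break
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        loss = loss_fn(model(images), labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
```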
Memory Management
Large-scale training often hits memory limits before compute limits. Main memory consumers include activations, parameters, gradients, optimizer state, and input batches.
| Method | Effect |
|---|---|
| Mixed precision | Reduces activation and parameter memory |
| Gradient checkpointing | Recomputes activations to save memory |
| Smaller batch size | Reduces activation memory |
| Gradient accumulation | Preserves effective batch size |
| Activation offloading | Moves tensors to CPU or storage |
| Optimizer state sharding | Reduces per-GPU optimizer memory |
For large CNNs and transformers, activation memory often dominates. Gradient checkpointing trades extra compute for lower memory use.
```python
from torch.utils.checkpoint import checkpoint

def forward(self, x):
    # Recompute these blocks' activations during backward instead of storing them.
    x = checkpoint(self.block1, x, use_reentrant=False)
    x = checkpoint(self.block2, x, use_reentrant=False)
    x = self.head(x)
    return x
```

Use checkpointing selectively. It can slow training because forward computations are repeated during backpropagation.
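To check whether such changes help, peak GPU memory can be measured around a few training steps. A minimal sketch:

```python
# Sketch: measuring peak GPU memory over a few steps.
torch.cuda.reset_peak_memory_stats(device)

for step, (images, labels) in enumerate(train_loader):
    if step >= 5:
        break
    images = images.to(device, non_blocking=True)
    labels = labels.to(device, non_blocking=True)
    loss = loss_fn(model(images), labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

peak_gb = torch.cuda.max_memory_allocated(device) / 1e9
print(f"peak allocated memory: {peak_gb:.2f} GB")
```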
Practical Large-Scale Recipe
A stable large-scale image classification setup:
| Component | Practical default |
|---|---|
| Storage | Sharded dataset |
| Loader | Many workers, pinned memory, persistent workers |
| Precision | AMP float16 or bfloat16 |
| Optimizer | SGD with momentum or AdamW |
| Schedule | Warmup plus cosine decay |
| Batch size | Largest stable global batch |
| Regularization | Weight decay, augmentation, label smoothing |
| Distributed method | DDP |
| Checkpointing | Model, optimizer, scheduler, scaler, config |
| Validation | Distributed full validation at fixed intervals |
| Logging | Throughput, loss, learning rate, memory, accuracy |
Start with a small run before launching a large one. A short overfit test on a small subset can verify that the model, loss, labels, and optimizer are wired correctly.
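A minimal sketch of such an overfit test, using a small Subset of the training data (the subset size and epoch count are illustrative):

```python
# Sketch: overfit a tiny subset to sanity-check the training loop.
from torch.utils.data import DataLoader, Subset

tiny_set = Subset(train_set, range(256))      # a few hundred images
tiny_loader = DataLoader(tiny_set, batch_size=64, shuffle=True)

model.train()
for epoch in range(50):                       # loss should approach zero
    for images, labels in tiny_loader:
        images, labels = images.to(device), labels.to(device)
        loss = loss_fn(model(images), labels)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
    print(f"epoch {epoch}: loss={loss.item():.4f}")
```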
Summary
Large-scale training requires control over data throughput, memory, precision, distributed synchronization, checkpointing, and validation. The model code may look similar to single-GPU training, but the system around it becomes more important.
A reliable large-scale pipeline keeps GPUs fed, uses mixed precision safely, scales batch size and learning rate together, saves resumable checkpoints, validates correctly across workers, and records enough configuration to make results interpretable.