Large-Scale Training

Large-scale training means training models on datasets, model sizes, or hardware configurations that exceed a simple single-GPU workflow. In image classification, this often means millions of images, large backbones, long schedules, high-resolution inputs, or multi-GPU training.

The goal is not only to make training faster. The goal is to keep optimization stable, data loading efficient, validation reliable, and checkpoints recoverable while the system grows.

What Changes at Scale

A small training job can tolerate inefficient code. A large training job cannot. At scale, small inefficiencies become expensive.

| Area | Small training | Large-scale training |
|---|---|---|
| Data | Fits on local disk | Often sharded, streamed, cached |
| Batch size | Tens of images | Hundreds or thousands of images |
| Hardware | CPU or one GPU | Multiple GPUs or nodes |
| Precision | Usually float32 | Often mixed precision |
| Checkpoints | Occasional manual saves | Regular, resumable checkpoints |
| Validation | Simple loop | Distributed, scheduled, logged |
| Failure handling | Restart from scratch | Resume from checkpoint |
| Profiling | Optional | Required |

Large-scale training is a systems problem as much as a modeling problem.

Throughput

Throughput measures how many examples the system processes per second.

\text{throughput} = \frac{\text{number of images processed}}{\text{elapsed time}}

For training, throughput depends on data loading, preprocessing, host-to-device transfer, forward pass, backward pass, optimizer step, and synchronization.

A simple measurement:

import time

import torch

def measure_throughput(model, loader, loss_fn, optimizer, device, steps=100):
    model.train()

    start = time.time()
    total_images = 0

    for step, (images, labels) in enumerate(loader):
        if step >= steps:
            break

        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        logits = model(images)
        loss = loss_fn(logits, labels)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        total_images += images.size(0)

    # CUDA kernels run asynchronously; wait for queued work so the
    # measured time covers everything the loop launched.
    if torch.cuda.is_available():
        torch.cuda.synchronize(device)

    elapsed = time.time() - start
    return total_images / elapsed

Use throughput together with validation metrics. A faster system that reaches worse accuracy may not be a better system.

Data Loading Bottlenecks

At scale, the GPU should spend most of its time computing, not waiting for data. Slow data loading is one of the most common causes of poor hardware utilization.

Useful DataLoader settings:

train_loader = DataLoader(
    train_set,
    batch_size=256,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=4,
)

Important options:

| Option | Purpose |
|---|---|
| num_workers | Uses subprocesses for data loading |
| pin_memory | Speeds CPU-to-GPU transfer |
| persistent_workers | Keeps workers alive across epochs |
| prefetch_factor | Loads future batches in advance |
| non_blocking=True | Allows asynchronous device transfer |

Increase num_workers only as long as throughput improves. Too many workers can saturate CPU, disk, or memory bandwidth.
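
One way to choose num_workers is a short sweep that rebuilds the loader at each setting and reuses the measure_throughput helper from above. A minimal sketch; train_set, model, loss_fn, optimizer, and device are assumed to be defined:

from torch.utils.data import DataLoader

for workers in (2, 4, 8, 16):
    loader = DataLoader(
        train_set,
        batch_size=256,
        shuffle=True,
        num_workers=workers,
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=4,
    )
    ips = measure_throughput(model, loader, loss_fn, optimizer, device)
    print(f"num_workers={workers}: {ips:.1f} images/s")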

Sharded Datasets

Large datasets are often stored as shards rather than individual image files. A directory containing millions of small files can be slow because metadata access dominates.

Common large-scale storage formats include:

| Format | Typical use |
|---|---|
| Tar shards | Sequential image loading |
| WebDataset | Streaming samples from tar files |
| LMDB | Key-value image storage |
| TFRecord | TensorFlow-style records |
| Parquet | Structured metadata and features |
| Object storage shards | Cloud-scale training |

A sharded dataset improves sequential reads and makes distributed training easier. Each worker can read different shards, reducing contention.

A simple shard naming convention:

train-000000.tar
train-000001.tar
train-000002.tar
...

The shard size should balance two costs. Very small shards increase scheduling overhead. Very large shards reduce flexibility and make recovery more expensive.
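
For streaming training directly from such shards, one option is the third-party webdataset package. A minimal sketch, assuming shards whose samples contain a jpg image and a cls label:

import webdataset as wds

# Brace expansion selects the shard range; different loader workers
# read different shards, keeping reads sequential.
urls = "train-{000000..000255}.tar"

dataset = (
    wds.WebDataset(urls)
    .shuffle(1000)           # shuffle within a rolling buffer
    .decode("pil")           # decode images to PIL
    .to_tuple("jpg", "cls")  # yield (image, label) pairs
)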

Large Batch Training

Large-scale image training often uses large batches. A larger batch gives a more stable gradient estimate and better hardware utilization. However, it can also reduce generalization or require learning rate adjustment.

A common rule is linear learning rate scaling:

\eta_{\text{new}} = \eta_{\text{base}} \frac{B_{\text{new}}}{B_{\text{base}}}.

For example, if the base learning rate is 0.1 for batch size 256, then for batch size 1024:

\eta_{\text{new}} = 0.1 \times \frac{1024}{256} = 0.4.
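
A small helper makes the rule explicit; the name scale_lr is illustrative:

def scale_lr(base_lr, base_batch, new_batch):
    # Linear learning rate scaling rule.
    return base_lr * new_batch / base_batch

print(scale_lr(0.1, 256, 1024))  # 0.4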

Large learning rates should usually be combined with warmup.

Learning Rate Warmup

Warmup gradually increases the learning rate at the start of training. This is useful when training with large batches, mixed precision, or very deep models.

warmup = torch.optim.lr_scheduler.LinearLR(
    optimizer,
    start_factor=0.01,
    total_iters=5,
)

cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=95,
)

scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    schedulers=[warmup, cosine],
    milestones=[5],
)
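
The units of total_iters and T_max match however often the scheduler is stepped. A minimal sketch stepping once per epoch over a 100-epoch run; train_one_epoch is a hypothetical helper:

for epoch in range(100):
    train_one_epoch(model, train_loader, loss_fn, optimizer, device)
    scheduler.step()  # warmup for the first 5 epochs, then cosine decay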

A common schedule is:

| Phase | Learning rate behavior |
|---|---|
| Warmup | Increase from a small value |
| Main training | Cosine decay, step decay, or polynomial decay |
| Final phase | Small updates near convergence |

Warmup prevents early unstable updates before the network’s activations and normalization statistics settle.

Mixed Precision Training

Mixed precision uses lower precision arithmetic where possible while keeping numerically sensitive parts stable. In PyTorch, this is usually done with automatic mixed precision.

scaler = torch.cuda.amp.GradScaler()

for images, labels in train_loader:
    images = images.to(device, non_blocking=True)
    labels = labels.to(device, non_blocking=True)

    optimizer.zero_grad(set_to_none=True)

    with torch.cuda.amp.autocast():
        logits = model(images)
        loss = loss_fn(logits, labels)

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

Mixed precision reduces memory use and can increase throughput. It is especially useful on modern GPUs with tensor cores.

The GradScaler helps prevent gradient underflow when using float16. With bfloat16, explicit gradient scaling is often less important, but the exact behavior depends on hardware and PyTorch configuration.
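
On GPUs that support bfloat16, the same loop can often run without a scaler at all, since bfloat16 shares float32's exponent range and gradients rarely underflow. A sketch, assuming bfloat16 hardware support:

for images, labels in train_loader:
    images = images.to(device, non_blocking=True)
    labels = labels.to(device, non_blocking=True)

    optimizer.zero_grad(set_to_none=True)

    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        logits = model(images)
        loss = loss_fn(logits, labels)

    loss.backward()  # no GradScaler: bfloat16 rarely underflows
    optimizer.step()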

Gradient Accumulation

Gradient accumulation simulates a larger batch size by accumulating gradients over several smaller mini-batches before taking an optimizer step.

accum_steps = 4

optimizer.zero_grad(set_to_none=True)

for step, (images, labels) in enumerate(train_loader):
    images = images.to(device)
    labels = labels.to(device)

    logits = model(images)
    loss = loss_fn(logits, labels)
    loss = loss / accum_steps

    loss.backward()

    if (step + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

If the per-device batch size is 64 and accum_steps = 4, the effective batch size per device is 256.

With distributed training, the effective global batch size is:

B_{\text{global}} = B_{\text{per device}} \times N_{\text{devices}} \times N_{\text{accum steps}}.

This value should be used when choosing the learning rate.
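
For example, with an 8-GPU job (numbers illustrative):

per_device_batch = 64
num_devices = 8
accum_steps = 4

global_batch = per_device_batch * num_devices * accum_steps  # 2048
lr = 0.1 * global_batch / 256                                # 0.8 under linear scaling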

Distributed Data Parallel Training

DistributedDataParallel, usually called DDP, is the standard PyTorch method for multi-GPU training. Each process owns one GPU and one copy of the model. Each process receives a different data shard. Gradients are synchronized across processes during backpropagation.

A minimal DDP setup has these parts:

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_ddp(rank, world_size):
    # Assumes MASTER_ADDR and MASTER_PORT are set in the environment
    # so the processes can rendezvous.
    dist.init_process_group(
        backend="nccl",
        rank=rank,
        world_size=world_size,
    )
    torch.cuda.set_device(rank)

Wrap the model:

model = model.to(rank)
model = DDP(model, device_ids=[rank])

Use a distributed sampler:

from torch.utils.data.distributed import DistributedSampler

sampler = DistributedSampler(
    train_set,
    num_replicas=world_size,
    rank=rank,
    shuffle=True,
)

train_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    sampler=sampler,
    num_workers=8,
    pin_memory=True,
)

At the start of each epoch:

sampler.set_epoch(epoch)

This ensures that shuffling differs across epochs while remaining synchronized across processes.
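
In practice, rank and world size usually come from the launcher rather than being passed by hand. A sketch of the same setup driven by torchrun, which sets RANK, LOCAL_RANK, and WORLD_SIZE in each process's environment:

import os

import torch
import torch.distributed as dist

def setup_from_env():
    # With torchrun, init_process_group reads RANK and WORLD_SIZE
    # from the environment; no explicit arguments are needed.
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    return local_rank

# Launched as: torchrun --nproc_per_node=4 train.py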

Checkpointing at Scale

Large training jobs must be resumable. A checkpoint should contain enough state to continue training without changing the optimization trajectory too much.

A useful checkpoint contains:

| Item | Reason |
|---|---|
| Model state | Learned parameters |
| Optimizer state | Momentum and adaptive statistics |
| Scheduler state | Learning rate schedule position |
| Grad scaler state | Mixed precision state |
| Epoch and step | Resume location |
| Best metric | Model selection |
| Class mapping | Correct inference labels |
| Config | Reproducibility |

Example:

def save_checkpoint(path, model, optimizer, scheduler, scaler, epoch, step, best_metric, config):
    state = {
        # Unwrap DDP so the weights load into a plain, unwrapped model too.
        "model": model.module.state_dict() if hasattr(model, "module") else model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict() if scheduler is not None else None,
        "scaler": scaler.state_dict() if scaler is not None else None,
        "epoch": epoch,
        "step": step,
        "best_metric": best_metric,
        "config": config,
    }
    torch.save(state, path)
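
A matching load function restores the same pieces. A minimal sketch under the same assumptions; load_checkpoint is an illustrative name:

def load_checkpoint(path, model, optimizer, scheduler, scaler, device):
    state = torch.load(path, map_location=device)

    # Load into the underlying model whether or not it is wrapped in DDP.
    target = model.module if hasattr(model, "module") else model
    target.load_state_dict(state["model"])

    optimizer.load_state_dict(state["optimizer"])
    if scheduler is not None and state["scheduler"] is not None:
        scheduler.load_state_dict(state["scheduler"])
    if scaler is not None and state["scaler"] is not None:
        scaler.load_state_dict(state["scaler"])

    return state["epoch"], state["step"], state["best_metric"]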

Only one process should usually write checkpoints in DDP:

if rank == 0:
    save_checkpoint(...)

This avoids multiple processes writing the same file.
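
If other ranks must not proceed until the file exists (for example, because they reload it), a barrier keeps them synchronized:

if rank == 0:
    save_checkpoint(...)
dist.barrier()  # every rank waits here until rank 0 has finished writing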

Validation at Scale

Validation can be expensive on large datasets. It should be scheduled deliberately.

Common options:

| Strategy | Use case |
|---|---|
| Validate every epoch | Medium datasets |
| Validate every N steps | Long epochs |
| Validate on a subset | Fast feedback |
| Full validation before checkpoint | Reliable model selection |
| Final test only once | Unbiased final estimate |

For distributed validation, each process evaluates a shard of the validation set. Then counts and losses are reduced across processes.

Conceptually:

correct_tensor = torch.tensor([local_correct], device=device)
count_tensor = torch.tensor([local_count], device=device)

dist.all_reduce(correct_tensor, op=dist.ReduceOp.SUM)
dist.all_reduce(count_tensor, op=dist.ReduceOp.SUM)

global_accuracy = correct_tensor.item() / count_tensor.item()

Validation metrics should be computed from summed counts, not by averaging per-process accuracies when batch counts differ.
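
Putting the pieces together, a distributed validation pass might look like the sketch below. It assumes the process group is initialized and the validation loader uses a DistributedSampler; note that DistributedSampler may pad shards to equal length, which can slightly perturb the exact metric.

@torch.no_grad()
def evaluate_distributed(model, val_loader, device):
    model.eval()
    correct, count = 0, 0

    for images, labels in val_loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        preds = model(images).argmax(dim=1)
        correct += (preds == labels).sum().item()
        count += labels.size(0)

    # Sum raw counts across processes, then divide once.
    totals = torch.tensor([correct, count], device=device)
    dist.all_reduce(totals, op=dist.ReduceOp.SUM)
    return totals[0].item() / totals[1].item()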

Reproducibility

Large-scale training is often only approximately reproducible. Parallel data loading, GPU kernels, distributed synchronization, and nondeterministic algorithms can introduce variation.

Still, experiments should record:

config = {
    "model": "resnet50",
    "weights": "imagenet",
    "image_size": 224,
    "batch_size": 256,
    "optimizer": "adamw",
    "lr": 3e-4,
    "weight_decay": 1e-4,
    "epochs": 100,
    "augmentation": "randaugment_erasing",
    "precision": "amp_fp16",
    "seed": 1234,
}

Set seeds when possible:

import random
import numpy as np
import torch

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

Perfect determinism can reduce performance. For large training runs, reproducibility usually means recording enough information to explain and repeat the result within expected variance.
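
When stricter determinism is worth the slowdown, PyTorch exposes switches that force deterministic kernels. A sketch:

import torch

torch.backends.cudnn.benchmark = False      # autotuned kernel selection is nondeterministic
torch.backends.cudnn.deterministic = True   # prefer deterministic cuDNN kernels
torch.use_deterministic_algorithms(True)    # raise an error on known-nondeterministic ops
# Some CUDA ops additionally require CUBLAS_WORKSPACE_CONFIG=":4096:8"
# to be set in the environment.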

Profiling

Profiling finds where training time is spent. A system may be bottlenecked by data loading, CPU transforms, GPU kernels, synchronization, or disk I/O.

Basic timing:

import time

import torch

start = time.time()
for step, batch in enumerate(train_loader):
    data_time = time.time() - start

    images, labels = batch
    images = images.to(device, non_blocking=True)
    labels = labels.to(device, non_blocking=True)

    compute_start = time.time()

    logits = model(images)
    loss = loss_fn(logits, labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

    # CUDA kernels launch asynchronously; synchronize so compute_time
    # measures GPU work rather than just kernel-launch overhead.
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    compute_time = time.time() - compute_start

    print(f"data_time={data_time:.4f} compute_time={compute_time:.4f}")

    start = time.time()

If data_time is large, improve data loading. If compute_time dominates but GPU utilization is low, inspect batch size, kernels, model shape, and synchronization.
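
For finer-grained attribution, the built-in profiler breaks time down by operator. A minimal sketch using torch.profiler over a few steps:

from torch.profiler import ProfilerActivity, profile

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
) as prof:
    for step, (images, labels) in enumerate(train_loader):
        if step >= 10:  # profile only a handful of steps
            break
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        loss = loss_fn(model(images), labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))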

Memory Management

Large-scale training often hits memory limits before compute limits. Main memory consumers include activations, parameters, gradients, optimizer state, and input batches.

| Method | Effect |
|---|---|
| Mixed precision | Reduces activation and parameter memory |
| Gradient checkpointing | Recomputes activations to save memory |
| Smaller batch size | Reduces activation memory |
| Gradient accumulation | Preserves effective batch size |
| Activation offloading | Moves tensors to CPU or storage |
| Optimizer state sharding | Reduces per-GPU optimizer memory |

For large CNNs and transformers, activation memory often dominates. Gradient checkpointing trades extra compute for lower memory use.

from torch.utils.checkpoint import checkpoint

def forward(self, x):
    # use_reentrant=False selects the recommended non-reentrant
    # implementation in current PyTorch.
    x = checkpoint(self.block1, x, use_reentrant=False)
    x = checkpoint(self.block2, x, use_reentrant=False)
    x = self.head(x)
    return x

Use checkpointing selectively. It can slow training because forward computations are repeated during backpropagation.

Practical Large-Scale Recipe

A stable large-scale image classification setup:

| Component | Practical default |
|---|---|
| Storage | Sharded dataset |
| Loader | Many workers, pinned memory, persistent workers |
| Precision | AMP float16 or bfloat16 |
| Optimizer | SGD with momentum or AdamW |
| Schedule | Warmup plus cosine decay |
| Batch size | Largest stable global batch |
| Regularization | Weight decay, augmentation, label smoothing |
| Distributed method | DDP |
| Checkpointing | Model, optimizer, scheduler, scaler, config |
| Validation | Distributed full validation at fixed intervals |
| Logging | Throughput, loss, learning rate, memory, accuracy |

Start with a small run before launching a large one. A short overfit test on a small subset can verify that the model, loss, labels, and optimizer are wired correctly.
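
One such test: overfit a tiny fixed subset and confirm the loss falls toward zero. A sketch using torch.utils.data.Subset:

from torch.utils.data import DataLoader, Subset

tiny = Subset(train_set, range(64))  # a fixed handful of samples
tiny_loader = DataLoader(tiny, batch_size=16, shuffle=True)

model.train()
for epoch in range(50):
    for images, labels in tiny_loader:
        images = images.to(device)
        labels = labels.to(device)
        loss = loss_fn(model(images), labels)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
# If the loss does not approach zero here, the model, loss, labels,
# or optimizer wiring needs attention before scaling up.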

Summary

Large-scale training requires control over data throughput, memory, precision, distributed synchronization, checkpointing, and validation. The model code may look similar to single-GPU training, but the system around it becomes more important.

A reliable large-scale pipeline keeps GPUs fed, uses mixed precision safely, scales batch size and learning rate together, saves resumable checkpoints, validates correctly across workers, and records enough configuration to make results interpretable.