# Large-Scale Training

Large-scale training means training models on datasets, model sizes, or hardware configurations that exceed a simple single-GPU workflow. In image classification, this often means millions of images, large backbones, long schedules, high-resolution inputs, or multi-GPU training.

The goal is not only to make training faster. The goal is to keep optimization stable, data loading efficient, validation reliable, and checkpoints recoverable while the system grows.

### What Changes at Scale

A small training job can tolerate inefficient code. A large training job cannot. At scale, small inefficiencies become expensive.

| Area | Small training | Large-scale training |
|---|---|---|
| Data | Fits on local disk | Often sharded, streamed, cached |
| Batch size | Tens of images | Hundreds or thousands of images |
| Hardware | CPU or one GPU | Multiple GPUs or nodes |
| Precision | Usually float32 | Often mixed precision |
| Checkpoints | Occasional manual saves | Regular resumable checkpoints |
| Validation | Simple loop | Distributed, scheduled, logged |
| Failure handling | Restart from scratch | Resume from checkpoint |
| Profiling | Optional | Required |

Large-scale training is a systems problem as much as a modeling problem.

### Throughput

Throughput measures how many examples the system processes per second.

$$
\text{throughput} = \frac{\text{number of images processed}}{\text{elapsed time}}
$$

For training, throughput depends on data loading, preprocessing, host-to-device transfer, forward pass, backward pass, optimizer step, and synchronization.

A simple measurement:

```python
import time

import torch

def measure_throughput(model, loader, loss_fn, optimizer, device, steps=100):
    model.train()

    # CUDA launches kernels asynchronously, so synchronize before and
    # after timing; otherwise elapsed time measures kernel launches,
    # not finished work.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.time()
    total_images = 0

    for step, (images, labels) in enumerate(loader):
        if step >= steps:
            break

        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        logits = model(images)
        loss = loss_fn(logits, labels)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        total_images += images.size(0)

    if torch.cuda.is_available():
        torch.cuda.synchronize()
    elapsed = time.time() - start
    return total_images / elapsed
```

Use throughput together with validation metrics. A faster system that reaches worse accuracy may not be a better system.

### Data Loading Bottlenecks

At scale, the GPU should spend most of its time computing, not waiting for data. Slow data loading is one of the most common causes of poor hardware utilization.

Useful `DataLoader` settings:

```python
train_loader = DataLoader(
    train_set,
    batch_size=256,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=4,
)
```

Important options:

| Option | Purpose |
|---|---|
| `num_workers` | Number of subprocess workers loading data in parallel |
| `pin_memory` | Allocates page-locked host memory for faster CPU-to-GPU transfer |
| `persistent_workers` | Keeps workers alive across epochs |
| `prefetch_factor` | Batches each worker loads ahead of time |
| `non_blocking=True` | Asynchronous device transfer (an argument to `tensor.to`, not `DataLoader`; pairs with `pin_memory`) |

Add workers only while throughput improves. Too many workers can saturate CPU, disk, or memory bandwidth.

### Sharded Datasets

Large datasets are often stored as shards rather than individual image files. A directory containing millions of small files can be slow because metadata access dominates.

Common large-scale storage formats include:

| Format | Typical use |
|---|---|
| Tar shards | Sequential image loading |
| WebDataset | Streaming samples from tar files |
| LMDB | Key-value image storage |
| TFRecord | TensorFlow-style records |
| Parquet | Structured metadata and features |
| Object storage shards | Cloud-scale training |

A sharded dataset improves sequential reads and makes distributed training easier. Each worker can read different shards, reducing contention.

A simple shard naming convention:

```text
train-000000.tar
train-000001.tar
train-000002.tar
...
```

The shard size should balance two costs. Very small shards increase scheduling overhead. Very large shards reduce flexibility and make recovery more expensive.
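As a sketch of why shards help, the standard library's `tarfile` module can stream samples sequentially from one open file, avoiding a metadata lookup per sample. The shard path and sample contents below are illustrative:

```python
import io
import os
import tarfile
import tempfile

def write_shard(path, samples):
    # Pack (name, bytes) samples into one tar shard.
    with tarfile.open(path, "w") as tar:
        for name, data in samples:
            info = tarfile.TarInfo(name=name)
            info.size = len(data)
            tar.addfile(info, io.BytesIO(data))

def read_shard(path):
    # Stream samples back sequentially: one open() per shard,
    # not one per sample.
    with tarfile.open(path, "r") as tar:
        for member in tar:
            yield member.name, tar.extractfile(member).read()

shard_dir = tempfile.mkdtemp()
shard_path = os.path.join(shard_dir, "train-000000.tar")
write_shard(shard_path, [(f"{i:06d}.jpg", b"fake-image-bytes") for i in range(3)])

names = [name for name, _ in read_shard(shard_path)]
print(names)  # three sequential sample names
```

Libraries such as WebDataset build on the same idea, adding shuffling across shards and decoding pipelines.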

### Large Batch Training

Large-scale image training often uses large batches. A larger batch gives a more stable gradient estimate and better hardware utilization. However, it can also reduce generalization or require learning rate adjustment.

A common rule is linear learning rate scaling:

$$
\eta_{\text{new}} = \eta_{\text{base}} \frac{B_{\text{new}}}{B_{\text{base}}}.
$$

For example, if the base learning rate is $0.1$ for batch size $256$, then for batch size $1024$:

$$
\eta_{\text{new}} = 0.1 \times \frac{1024}{256} = 0.4.
$$

Large learning rates should usually be combined with warmup.
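The scaling rule can be wrapped in a small helper (the function name is illustrative):

```python
def scale_lr(base_lr, base_batch, new_batch):
    # Linear scaling rule: keep lr proportional to batch size.
    return base_lr * new_batch / base_batch

print(scale_lr(0.1, 256, 1024))  # 0.4
```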

### Learning Rate Warmup

Warmup gradually increases the learning rate at the start of training. This is useful when training with large batches, mixed precision, or very deep models.

```python
warmup = torch.optim.lr_scheduler.LinearLR(
    optimizer,
    start_factor=0.01,
    total_iters=5,
)

cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=95,
)

scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    schedulers=[warmup, cosine],
    milestones=[5],
)
```

A common schedule is:

| Phase | Learning rate behavior |
|---|---|
| Warmup | Increase from small value |
| Main training | Cosine decay, step decay, or polynomial decay |
| Final phase | Small updates near convergence |

Warmup prevents early unstable updates before the network’s activations and normalization statistics settle.
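The same shape can be written as a single step-indexed multiplier applied to the base learning rate (the function name is illustrative):

```python
import math

def lr_factor(step, warmup_steps, total_steps):
    # Linear warmup from a small value to 1.0, then cosine decay toward 0.
    if step < warmup_steps:
        return (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

# Warmup ramp, peak, and decay over a 100-step schedule with 5 warmup steps.
print([round(lr_factor(s, 5, 100), 3) for s in (0, 4, 5, 50, 99)])
```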

### Mixed Precision Training

Mixed precision uses lower precision arithmetic where possible while keeping numerically sensitive parts stable. In PyTorch, this is usually done with automatic mixed precision.

```python
# The torch.cuda.amp.* spellings are deprecated in recent PyTorch;
# torch.amp takes the device type explicitly.
scaler = torch.amp.GradScaler("cuda")

for images, labels in train_loader:
    images = images.to(device, non_blocking=True)
    labels = labels.to(device, non_blocking=True)

    optimizer.zero_grad(set_to_none=True)

    with torch.amp.autocast("cuda"):
        logits = model(images)
        loss = loss_fn(logits, labels)

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```

Mixed precision reduces memory use and can increase throughput. It is especially useful on modern GPUs with tensor cores.

The `GradScaler` helps prevent gradient underflow when using float16. With bfloat16, explicit gradient scaling is often less important, but the exact behavior depends on hardware and PyTorch configuration.

### Gradient Accumulation

Gradient accumulation simulates a larger batch size by accumulating gradients over several smaller mini-batches before taking an optimizer step.

```python
accum_steps = 4

optimizer.zero_grad(set_to_none=True)

for step, (images, labels) in enumerate(train_loader):
    images = images.to(device, non_blocking=True)
    labels = labels.to(device, non_blocking=True)

    logits = model(images)
    loss = loss_fn(logits, labels)
    # Scale the loss so the accumulated gradient matches one large batch.
    loss = loss / accum_steps

    loss.backward()

    # Note: a trailing partial accumulation is dropped here when
    # len(train_loader) is not divisible by accum_steps.
    if (step + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
```

If the per-device batch size is 64 and `accum_steps = 4`, the effective batch size per device is 256.

With distributed training, the effective global batch size is:

$$
B_{\text{global}} = B_{\text{per device}} \times N_{\text{devices}} \times N_{\text{accum steps}}.
$$

This value should be used when choosing the learning rate.
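Worked numerically, with illustrative values:

```python
per_device = 64
devices = 8
accum_steps = 4

# Global batch size seen by each optimizer step.
b_global = per_device * devices * accum_steps
print(b_global)  # 2048

# Linear scaling against the global batch (base: lr 0.1 at batch 256).
lr = 0.1 * b_global / 256
print(lr)  # 0.8
```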

### Distributed Data Parallel Training

`DistributedDataParallel`, usually called DDP, is the standard PyTorch method for multi-GPU training. Each process owns one GPU and one copy of the model. Each process receives a different data shard. Gradients are synchronized across processes during backpropagation.

A minimal DDP setup has these parts:

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_ddp(rank, world_size):
    dist.init_process_group(
        backend="nccl",
        rank=rank,
        world_size=world_size,
    )
    # On a single node the global rank doubles as the GPU index;
    # multi-node jobs should pass the local rank here instead.
    torch.cuda.set_device(rank)
```

Wrap the model:

```python
model = model.to(rank)
model = DDP(model, device_ids=[rank])
```

Use a distributed sampler:

```python
from torch.utils.data.distributed import DistributedSampler

sampler = DistributedSampler(
    train_set,
    num_replicas=world_size,
    rank=rank,
    shuffle=True,
)

train_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    sampler=sampler,
    num_workers=8,
    pin_memory=True,
)
```

At the start of each epoch:

```python
sampler.set_epoch(epoch)
```

This ensures that shuffling differs across epochs while remaining synchronized across processes.
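A stripped-down sketch of the mechanism, not the actual implementation: every rank builds the same epoch-seeded permutation and takes a strided slice, so shards are disjoint and cover the dataset. (The real `DistributedSampler` also pads so every rank receives the same number of samples.)

```python
import random

def shard_indices(num_samples, world_size, rank, epoch):
    # Same epoch-seeded permutation on every rank; each rank takes
    # a strided slice, so shards partition the dataset.
    indices = list(range(num_samples))
    random.Random(epoch).shuffle(indices)
    return indices[rank::world_size]

shards = [shard_indices(10, world_size=4, rank=r, epoch=0) for r in range(4)]
print(shards)
```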

### Checkpointing at Scale

Large training jobs must be resumable. A checkpoint should contain enough state to continue training without changing the optimization trajectory too much.

A useful checkpoint contains:

| Item | Reason |
|---|---|
| Model state | Learned parameters |
| Optimizer state | Momentum and adaptive statistics |
| Scheduler state | Learning rate schedule position |
| Grad scaler state | Mixed precision state |
| Epoch and step | Resume location |
| Best metric | Model selection |
| Class mapping | Correct inference labels |
| Config | Reproducibility |

Example:

```python
def save_checkpoint(path, model, optimizer, scheduler, scaler, epoch, step, best_metric, config):
    state = {
        "model": model.module.state_dict() if hasattr(model, "module") else model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict() if scheduler is not None else None,
        "scaler": scaler.state_dict() if scaler is not None else None,
        "epoch": epoch,
        "step": step,
        "best_metric": best_metric,
        "config": config,
    }
    torch.save(state, path)
```

Only one process should usually write checkpoints in DDP:

```python
if rank == 0:
    save_checkpoint(...)
```

This avoids multiple processes writing the same file.
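A matching resume sketch, assuming the checkpoint layout above (`load_checkpoint` is an illustrative name):

```python
import torch

def load_checkpoint(path, model, optimizer=None, scheduler=None,
                    scaler=None, device="cpu"):
    state = torch.load(path, map_location=device)
    # Unwrap DDP if needed, mirroring how the checkpoint was saved.
    target = model.module if hasattr(model, "module") else model
    target.load_state_dict(state["model"])
    if optimizer is not None and state["optimizer"] is not None:
        optimizer.load_state_dict(state["optimizer"])
    if scheduler is not None and state["scheduler"] is not None:
        scheduler.load_state_dict(state["scheduler"])
    if scaler is not None and state["scaler"] is not None:
        scaler.load_state_dict(state["scaler"])
    # Resume from the next epoch so completed work is not repeated.
    return state["epoch"] + 1, state["step"], state["best_metric"]
```

Restoring optimizer and scaler state matters: resuming with fresh momentum or loss-scale statistics changes the optimization trajectory.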

### Validation at Scale

Validation can be expensive on large datasets. It should be scheduled deliberately.

Common options:

| Strategy | Use case |
|---|---|
| Validate every epoch | Medium datasets |
| Validate every N steps | Long epochs |
| Validate on subset | Fast feedback |
| Full validation before checkpoint | Reliable model selection |
| Final test only once | Unbiased final estimate |

For distributed validation, each process evaluates a shard of the validation set. Then counts and losses are reduced across processes.

Conceptually:

```python
correct_tensor = torch.tensor([local_correct], device=device)
count_tensor = torch.tensor([local_count], device=device)

dist.all_reduce(correct_tensor, op=dist.ReduceOp.SUM)
dist.all_reduce(count_tensor, op=dist.ReduceOp.SUM)

global_accuracy = correct_tensor.item() / count_tensor.item()
```

Validation metrics should be computed from summed counts, not by averaging per-process accuracies when batch counts differ.
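A small numeric illustration of why summed counts matter when shards are unequal:

```python
# Two processes with unequal validation shards.
correct = [90, 10]   # correct predictions per process
count = [100, 20]    # examples per process

# Correct: reduce counts across processes, then divide.
global_acc = sum(correct) / sum(count)

# Wrong: average the per-process accuracies, ignoring shard sizes.
naive_acc = (correct[0] / count[0] + correct[1] / count[1]) / 2

print(global_acc)  # ~0.833
print(naive_acc)   # ~0.7, overweighting the small shard
```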

### Reproducibility

Large-scale training is often only approximately reproducible. Parallel data loading, GPU kernels, distributed synchronization, and nondeterministic algorithms can introduce variation.

Still, experiments should record:

```python
config = {
    "model": "resnet50",
    "weights": "imagenet",
    "image_size": 224,
    "batch_size": 256,
    "optimizer": "adamw",
    "lr": 3e-4,
    "weight_decay": 1e-4,
    "epochs": 100,
    "augmentation": "randaugment_erasing",
    "precision": "amp_fp16",
    "seed": 1234,
}
```

Set seeds when possible:

```python
import random
import numpy as np
import torch

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
```

Perfect determinism can reduce performance. For large training runs, reproducibility usually means recording enough information to explain and repeat the result within expected variance.

### Profiling

Profiling finds where training time is spent. A system may be bottlenecked by data loading, CPU transforms, GPU kernels, synchronization, or disk I/O.

Basic timing:

```python
import time

start = time.time()
for step, batch in enumerate(train_loader):
    data_time = time.time() - start

    images, labels = batch
    images = images.to(device, non_blocking=True)
    labels = labels.to(device, non_blocking=True)

    compute_start = time.time()

    logits = model(images)
    loss = loss_fn(logits, labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

    # Wait for queued GPU work so compute_time reflects finished
    # kernels, not just launches.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    compute_time = time.time() - compute_start

    print(f"data_time={data_time:.4f} compute_time={compute_time:.4f}")

    start = time.time()
```

If `data_time` is large, improve data loading. If `compute_time` dominates but GPU utilization is low, inspect batch size, kernels, model shape, and synchronization.
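Beyond manual timers, PyTorch ships `torch.profiler` for per-operator breakdowns. A minimal CPU-side sketch (a tiny model stands in for the real one; add `ProfilerActivity.CUDA` when profiling GPU runs):

```python
import torch
from torch.profiler import ProfilerActivity, profile

model = torch.nn.Linear(512, 512)
x = torch.randn(64, 512)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    for _ in range(10):
        y = model(x)
        y.sum().backward()

# Per-operator totals, sorted by CPU time.
table = prof.key_averages().table(sort_by="cpu_time_total", row_limit=5)
print(table)
```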

### Memory Management

Large-scale training often hits memory limits before compute limits. Main memory consumers include activations, parameters, gradients, optimizer state, and input batches.

| Method | Effect |
|---|---|
| Mixed precision | Reduces activation and parameter memory |
| Gradient checkpointing | Recomputes activations to save memory |
| Smaller batch size | Reduces activation memory |
| Gradient accumulation | Preserves effective batch size |
| Activation offloading | Moves tensors to CPU or storage |
| Optimizer state sharding | Reduces per-GPU optimizer memory |

For large CNNs and transformers, activation memory often dominates. Gradient checkpointing trades extra compute for lower memory use.

```python
from torch.utils.checkpoint import checkpoint

def forward(self, x):
    # use_reentrant=False selects the non-reentrant implementation,
    # which recent PyTorch versions recommend.
    x = checkpoint(self.block1, x, use_reentrant=False)
    x = checkpoint(self.block2, x, use_reentrant=False)
    x = self.head(x)
    return x
```

Use checkpointing selectively. It can slow training because forward computations are repeated during backpropagation.

### Practical Large-Scale Recipe

A stable large-scale image classification setup:

| Component | Practical default |
|---|---|
| Storage | Sharded dataset |
| Loader | Many workers, pinned memory, persistent workers |
| Precision | AMP float16 or bfloat16 |
| Optimizer | SGD with momentum or AdamW |
| Schedule | Warmup plus cosine decay |
| Batch size | Largest stable global batch |
| Regularization | Weight decay, augmentation, label smoothing |
| Distributed method | DDP |
| Checkpointing | Model, optimizer, scheduler, scaler, config |
| Validation | Distributed full validation at fixed intervals |
| Logging | Throughput, loss, learning rate, memory, accuracy |

Start with a small run before launching a large one. A short overfit test on a small subset can verify that the model, loss, labels, and optimizer are wired correctly.
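A sketch of such an overfit test, with a tiny model and a synthetic fixed batch standing in for the real pipeline:

```python
import torch

torch.manual_seed(0)

# Tiny stand-ins for the real model and data.
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 8 * 8, 10))
images = torch.randn(16, 3, 8, 8)
labels = torch.randint(0, 10, (16,))

optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
loss_fn = torch.nn.CrossEntropyLoss()

losses = []
for _ in range(200):
    logits = model(images)
    loss = loss_fn(logits, labels)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

# A correctly wired pipeline drives loss on a fixed batch toward zero.
print(f"{losses[0]:.3f} -> {losses[-1]:.3f}")
```

If the loss does not fall sharply on a batch the model sees every step, something is miswired: labels, loss, optimizer, or data.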

### Summary

Large-scale training requires control over data throughput, memory, precision, distributed synchronization, checkpointing, and validation. The model code may look similar to single-GPU training, but the system around it becomes more important.

A reliable large-scale pipeline keeps GPUs fed, uses mixed precision safely, scales batch size and learning rate together, saves resumable checkpoints, validates correctly across workers, and records enough configuration to make results interpretable.

