
Multi-Node Training

Multi-node training uses more than one machine for a single training job. Each machine contributes one or more accelerators, and all machines cooperate to train the same model.

A node is one physical or virtual machine. A typical node may contain 4 or 8 GPUs. If we train on 4 nodes with 8 GPUs each, the job uses 32 GPUs.

\text{world size} = \text{number of nodes} \times \text{GPUs per node}.

For 4 nodes and 8 GPUs per node:

\text{world size} = 4 \times 8 = 32.

Multi-node training extends the ideas from single-node distributed training. We still need ranks, process groups, distributed samplers, gradient synchronization, checkpointing, and failure handling. The difference is that communication now crosses machine boundaries, so networking becomes a first-class part of the training system.

Why Multi-Node Training Is Used

A single machine can provide only limited compute and memory. Multi-node training is used when a job needs:

| Need | Explanation |
| --- | --- |
| More throughput | More GPUs process more examples per second |
| Larger global batch size | Many workers contribute to one optimizer step |
| Larger model capacity | Model, optimizer state, or activations may be partitioned |
| Faster experimentation | Shorter wall-clock training time |
| Large-scale pretraining | Foundation models require many accelerator-hours |

For small models, multi-node training may add unnecessary complexity. For large models, it becomes unavoidable.

Process Layout

The common layout is one process per GPU.

If each node has 8 GPUs and we use 4 nodes, we launch 32 processes.

Each process receives:

| Identifier | Meaning |
| --- | --- |
| RANK | Global process ID |
| LOCAL_RANK | GPU index within the current node |
| WORLD_SIZE | Total process count |
| LOCAL_WORLD_SIZE | Process count within the current node |

Example:

| Node | Local GPU | Global rank | Local rank |
| --- | --- | --- | --- |
| 0 | 0 | 0 | 0 |
| 0 | 1 | 1 | 1 |
| 0 | 7 | 7 | 7 |
| 1 | 0 | 8 | 0 |
| 1 | 7 | 15 | 7 |
| 3 | 7 | 31 | 7 |

The global rank identifies the process across the whole job. The local rank selects the GPU on the current node.
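
Each process can recover its position in this layout from the environment variables that torchrun sets. A minimal sketch; the node index is not provided directly, but follows from the global rank and the per-node process count:

import os

rank = int(os.environ["RANK"])                          # global process ID
local_rank = int(os.environ["LOCAL_RANK"])              # GPU index on this node
world_size = int(os.environ["WORLD_SIZE"])              # total process count
local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])  # processes on this node

# With one process per GPU, integer division recovers the node index.
node_index = rank // local_world_size
print(f"node {node_index}: local rank {local_rank}, global rank {rank} of {world_size}")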

Launching with torchrun

PyTorch commonly launches multi-node jobs with torchrun.

For a 4-node job with 8 GPUs per node:

torchrun \
  --nnodes=4 \
  --nproc_per_node=8 \
  --node_rank=0 \
  --master_addr=10.0.0.5 \
  --master_port=29500 \
  train.py

On node 1, change only --node_rank:

torchrun \
  --nnodes=4 \
  --nproc_per_node=8 \
  --node_rank=1 \
  --master_addr=10.0.0.5 \
  --master_port=29500 \
  train.py

The same applies to nodes 2 and 3.

The master address points to the node that coordinates rendezvous. It does not mean that rank 0 performs all training. After initialization, all ranks participate in computation.

Initialization Code

Inside the training script, initialization is almost the same as single-node DDP.

import os
import torch
import torch.distributed as dist

def init_distributed():
    dist.init_process_group(backend="nccl")

    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    rank = dist.get_rank()
    world_size = dist.get_world_size()

    return rank, local_rank, world_size

Then wrap the model:

from torch.nn.parallel import DistributedDataParallel as DDP

rank, local_rank, world_size = init_distributed()

model = MyModel().cuda(local_rank)
model = DDP(model, device_ids=[local_rank])

The same code can often run on one node or many nodes, provided it is launched correctly.

Network Communication

Multi-node training depends heavily on network performance.

Within one node, GPUs may communicate through PCIe, NVLink, or NVSwitch. Across nodes, communication uses the network interface.

Common interconnects include:

| Interconnect | Typical use |
| --- | --- |
| Ethernet | General clusters, lower cost |
| RoCE | RDMA over Converged Ethernet |
| InfiniBand | High-performance GPU clusters |
| Cloud provider fabric | Managed accelerator clusters |

Gradient synchronization can require transferring large tensors every training step. If the network is slow, GPUs spend time waiting for communication instead of computing.

The most important network properties are:

| Property | Meaning |
| --- | --- |
| Bandwidth | How many bytes can move per second |
| Latency | How long a message takes to start |
| Topology | Which nodes communicate efficiently |
| Congestion | Whether other jobs share the fabric |
| Reliability | Whether long jobs survive without resets |

NCCL and GPU Communication

For NVIDIA GPU training, PyTorch usually uses NCCL. NCCL provides optimized collective communication primitives such as all-reduce, broadcast, reduce-scatter, and all-gather.

A standard setup uses:

dist.init_process_group(backend="nccl")

Useful NCCL debugging variables include:

export NCCL_DEBUG=INFO
export NCCL_DEBUG_SUBSYS=INIT,NET

On multi-node systems, NCCL must choose the correct network interface. Sometimes this must be set explicitly:

export NCCL_SOCKET_IFNAME=eth0

or for another interface:

export NCCL_SOCKET_IFNAME=ib0

Wrong interface selection is a common source of slow training or initialization failure.
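
One way to confirm that NCCL is actually using the intended interconnect is to time a large all-reduce and compare the achieved bandwidth with what the fabric should deliver. A rough sketch, assuming the process group has already been initialized as shown earlier and the CUDA device has been set:

import time
import torch
import torch.distributed as dist

def measure_allreduce(num_elements=64 * 1024 * 1024, iters=10):
    # About 256 MB of fp32, comparable to the gradient volume of a mid-sized model.
    x = torch.ones(num_elements, device="cuda")
    for _ in range(3):                      # warm-up iterations
        dist.all_reduce(x)
    torch.cuda.synchronize()

    start = time.time()
    for _ in range(iters):
        dist.all_reduce(x)
    torch.cuda.synchronize()

    seconds = (time.time() - start) / iters
    gigabytes = x.numel() * x.element_size() / 1e9
    if dist.get_rank() == 0:
        print(f"all-reduce of {gigabytes:.2f} GB took {seconds * 1000:.1f} ms per call")

If the per-call time is far slower than the nominal network bandwidth suggests, interface selection or routing is a likely suspect.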

Distributed Sampling Across Nodes

The dataset must be partitioned across all ranks, not just across GPUs within one node.

Use DistributedSampler with the global world size:

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

sampler = DistributedSampler(
    dataset,
    num_replicas=world_size,
    rank=rank,
    shuffle=True,
)

loader = DataLoader(
    dataset,
    batch_size=local_batch_size,
    sampler=sampler,
    num_workers=8,
    pin_memory=True,
)

At each epoch:

sampler.set_epoch(epoch)

This ensures each rank receives a unique shard and that shuffling changes between epochs.
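
Putting the sampler and loader together, a typical epoch loop looks like the sketch below, where num_epochs and train_step are placeholders for the surrounding training code:

for epoch in range(num_epochs):
    sampler.set_epoch(epoch)      # changes the shuffle consistently on every rank
    for batch in loader:
        train_step(model, batch)  # forward, backward, and optimizer step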

Global Batch Size

In multi-node training, the global batch size is:

B_{\text{global}} = B_{\text{local}} \times G \times N \times A,

where:

| Symbol | Meaning |
| --- | --- |
| B_{\text{local}} | Batch size per GPU |
| G | GPUs per node |
| N | Number of nodes |
| A | Gradient accumulation steps |

For example, with local batch size 8, 8 GPUs per node, 4 nodes, and 2 accumulation steps:

B_{\text{global}} = 8 \times 8 \times 4 \times 2 = 512.

Changing the number of nodes changes the global batch size unless the local batch size or the number of accumulation steps is adjusted, which may in turn require retuning the learning rate.
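
A small sketch of this arithmetic, reading the world size from the launcher so the value stays correct when the node count changes; the linear learning-rate scaling at the end is only a common heuristic, not a rule:

import os

local_batch_size = 8
accumulation_steps = 2
world_size = int(os.environ["WORLD_SIZE"])   # equals GPUs per node x number of nodes

global_batch_size = local_batch_size * world_size * accumulation_steps

# Heuristic: scale the learning rate linearly relative to a reference batch size.
base_lr = 3e-4
reference_batch_size = 256
lr = base_lr * global_batch_size / reference_batch_size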

Checkpointing in Multi-Node Jobs

Checkpointing becomes more expensive at multi-node scale.

For ordinary DDP, every rank has a full model replica. In that case, rank 0 can save the model:

if rank == 0:
    torch.save(
        {
            "model": model.module.state_dict(),
            "optimizer": optimizer.state_dict(),
            "step": step,
        },
        "checkpoint.pt",
    )

For sharded training, every rank may hold only part of the state. Then distributed checkpointing is required. Each rank writes its shard, and a manifest records how shards fit together.

A robust checkpoint directory may look like:

checkpoint_0001000/
  manifest.json
  rank_00000.pt
  rank_00001.pt
  rank_00002.pt
  ...
  rank_00031.pt
  COMPLETE

The COMPLETE marker indicates that all shards were written successfully.
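
A minimal sketch of writing such a directory by hand is shown below. It assumes a shared filesystem visible to every node, and shard_state stands for whatever state this rank owns; production systems typically rely on a library such as torch.distributed.checkpoint instead.

import json
import os
import torch
import torch.distributed as dist

def save_sharded_checkpoint(step, shard_state, root="checkpoints"):
    rank = dist.get_rank()
    ckpt_dir = os.path.join(root, f"checkpoint_{step:07d}")
    os.makedirs(ckpt_dir, exist_ok=True)      # safe to call from every rank

    # Each rank writes only the shard it owns.
    torch.save(shard_state, os.path.join(ckpt_dir, f"rank_{rank:05d}.pt"))

    # Wait until every shard is on disk before declaring the checkpoint complete.
    dist.barrier()
    if rank == 0:
        manifest = {"step": step, "world_size": dist.get_world_size()}
        with open(os.path.join(ckpt_dir, "manifest.json"), "w") as f:
            json.dump(manifest, f)
        open(os.path.join(ckpt_dir, "COMPLETE"), "w").close()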

Validation and Metrics

Metrics must be aggregated across all ranks.

For classification accuracy:

correct = torch.tensor(local_correct, device=f"cuda:{local_rank}")  # per-rank correct predictions
total = torch.tensor(local_total, device=f"cuda:{local_rank}")      # per-rank example count

dist.all_reduce(correct, op=dist.ReduceOp.SUM)
dist.all_reduce(total, op=dist.ReduceOp.SUM)

accuracy = correct / total

Only rank 0 should print the result:

if rank == 0:
    print(f"accuracy: {accuracy.item():.4f}")

Loss values can also be averaged across ranks. For exact weighting, reduce both total loss and number of examples rather than averaging per-rank averages.
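
A sketch of the exact-weighting approach, where local_loss_sum and local_num_examples are the per-rank totals accumulated during validation:

loss_sum = torch.tensor(local_loss_sum, device=f"cuda:{local_rank}")
num_examples = torch.tensor(local_num_examples, device=f"cuda:{local_rank}")

dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
dist.all_reduce(num_examples, op=dist.ReduceOp.SUM)

# Correct global mean even if ranks processed different numbers of examples.
mean_loss = loss_sum / num_examples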

Avoiding Stragglers

A straggler is a worker that runs slower than the others. In synchronous training, all workers must wait for the slowest one.

Stragglers can come from:

| Cause | Example |
| --- | --- |
| Hardware variation | One GPU throttles |
| Data imbalance | One rank processes longer samples |
| Slow storage | One node reads data slowly |
| Network congestion | One link has lower bandwidth |
| Background processes | CPU contention on one node |

For language model training, sequence length variation can cause stragglers. If one rank receives unusually long sequences, its forward and backward pass may take longer.

Mitigation strategies include:

| Strategy | Effect |
| --- | --- |
| Length-based batching | Reduces per-batch variation |
| Data prefetching | Hides input latency |
| Balanced sharding | Avoids uneven datasets |
| Health monitoring | Detects slow nodes |
| Dedicated interconnect | Reduces congestion |
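
As an illustration of length-based batching (the first strategy above), a simple approach is to sort example indices by length and cut the sorted order into batches, so every batch contains sequences of similar length. A simplified sketch; real samplers usually also shuffle the resulting buckets to avoid ordering bias:

def length_sorted_batches(lengths, batch_size):
    # lengths[i] is the token count of example i.
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    return [order[i:i + batch_size] for i in range(0, len(order), batch_size)]

# Short sequences are grouped with short ones, long with long:
# returns [[0, 2], [3, 1]] for the lengths below.
batches = length_sorted_batches(lengths=[12, 480, 16, 470], batch_size=2)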

Multi-Node Failure Modes

Multi-node training introduces failure modes that rarely appear on one machine.

Common failures include:

| Failure | Symptom |
| --- | --- |
| Wrong master address | Processes cannot rendezvous |
| Port blocked | Initialization hangs |
| Wrong network interface | NCCL timeout or very slow training |
| Rank mismatch | Some workers wait forever |
| Different code versions | Silent divergence or crashes |
| Filesystem inconsistency | Checkpoint load fails |
| Clock or timeout issues | Spurious process failure |

A common debugging sequence is:

export NCCL_DEBUG=INFO
export TORCH_DISTRIBUTED_DEBUG=DETAIL

Then verify:

  1. all nodes can reach MASTER_ADDR:MASTER_PORT (a quick check is sketched after this list)
  2. each node sees the expected number of GPUs
  3. all nodes run the same code and dependency versions
  4. the dataset paths are valid on every node
  5. the network interface is correct
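
A quick way to test the first item is a plain TCP connection attempt from each node to the rendezvous endpoint while rank 0's launcher is running. A small sketch using only the standard library:

import os
import socket

addr = os.environ.get("MASTER_ADDR", "10.0.0.5")
port = int(os.environ.get("MASTER_PORT", "29500"))

try:
    # Fails quickly if the address is wrong or the port is blocked by a firewall.
    with socket.create_connection((addr, port), timeout=5):
        print(f"reached {addr}:{port}")
except OSError as exc:
    print(f"cannot reach {addr}:{port}: {exc}")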

Reproducibility Across Nodes

Reproducibility is harder across nodes than on a single GPU.

Sources of variation include:

  • different communication ordering
  • non-deterministic kernels
  • different worker restart timing
  • filesystem ordering
  • random data augmentation
  • floating-point reduction order

Set seeds per rank carefully:

base_seed = 1234
seed = base_seed + rank

torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

This gives each rank a distinct but reproducible random stream.

For exact reproducibility, save RNG states in checkpoints. For most large training runs, statistical reproducibility is more practical than bitwise identity.
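
For the exact case, a sketch of what saving and restoring the per-rank RNG state inside a checkpoint might look like:

# When saving, alongside this rank's model and optimizer state:
rng_state = {
    "torch": torch.get_rng_state(),
    "cuda": torch.cuda.get_rng_state_all(),
}
# Python's random and NumPy generators need the same treatment if they are used.

# When resuming:
torch.set_rng_state(rng_state["torch"])
torch.cuda.set_rng_state_all(rng_state["cuda"])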

Combining Multi-Node with Other Parallelisms

Multi-node training is an execution environment, not a single parallelism strategy.

A large run may combine:

| Parallelism | Role |
| --- | --- |
| Data parallelism | Replicate training across data shards |
| Tensor parallelism | Split large matrix operations |
| Pipeline parallelism | Split model layers into stages |
| Sharded data parallelism | Partition parameters and optimizer states |

For example, a 64-GPU job may use:

| Dimension | Value |
| --- | --- |
| Data parallel groups | 8 |
| Tensor parallel size | 4 |
| Pipeline parallel size | 2 |
| Total GPUs | 8 × 4 × 2 = 64 |

Each rank belongs to several communication groups. One group handles data parallel gradient synchronization. Another handles tensor-parallel collectives. Another handles pipeline transfers.

Correct group construction is one of the main complexities in large-scale training systems.
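
As a simplified sketch in two dimensions (data parallel and tensor parallel only; pipeline parallelism would add a third), the groups can be built with dist.new_group. Every rank must take part in every new_group call, including calls for groups it does not belong to:

import torch.distributed as dist

def build_2d_groups(tp_size):
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    dp_size = world_size // tp_size

    tp_group = None
    dp_group = None

    # Tensor-parallel groups: blocks of consecutive ranks.
    for d in range(dp_size):
        ranks = list(range(d * tp_size, (d + 1) * tp_size))
        group = dist.new_group(ranks)
        if rank in ranks:
            tp_group = group

    # Data-parallel groups: ranks that hold the same tensor-parallel slice.
    for t in range(tp_size):
        ranks = list(range(t, world_size, tp_size))
        group = dist.new_group(ranks)
        if rank in ranks:
            dp_group = group

    return tp_group, dp_group

Gradient synchronization then runs over dp_group, while tensor-parallel collectives use tp_group.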

When Multi-Node Training Is Worth It

Use multi-node training when a single node cannot provide enough compute, memory, or throughput.

Avoid it when:

| Situation | Reason |
| --- | --- |
| Model trains quickly on one node | Added complexity gives little benefit |
| Dataset pipeline is slow | More GPUs will wait for data |
| Network is weak | Communication dominates |
| Code is still unstable | Debugging becomes harder |
| Batch-size scaling is poor | More workers may harm optimization |

A practical progression is:

  1. make the model correct on one GPU
  2. scale to all GPUs on one node
  3. verify DDP correctness and throughput
  4. scale to multiple nodes
  5. add sharding, tensor parallelism, or pipeline parallelism only when needed

Multi-node training is mainly an engineering multiplier. It turns more hardware into faster training only when the software, data pipeline, network, and optimization setup scale together.