Distributed Data Parallel

Distributed Data Parallel, usually abbreviated as DDP, is PyTorch’s primary system for synchronous multi-GPU training. DDP extends ordinary data parallelism to distributed environments while minimizing Python overhead and communication inefficiency.

The central design principle is simple: one process controls one device. Each process owns a complete replica of the model, computes gradients locally, and synchronizes gradients with other processes during backpropagation.

Compared with older single-process approaches such as nn.DataParallel, DDP achieves substantially better scaling because computation and communication are distributed across independent worker processes.

From Single-GPU Training to DDP

A standard single-GPU training loop looks like this:

model = MyModel().cuda()

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

for x, y in loader:
    x = x.cuda()
    y = y.cuda()

    logits = model(x)
    loss = loss_fn(logits, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In DDP, the core logic remains nearly identical. The main differences are:

  1. Multiple processes are launched.
  2. Each process is assigned one GPU.
  3. The model is wrapped with DistributedDataParallel.
  4. The dataset is partitioned across processes.

The training loop itself changes very little.

This is one reason DDP became the standard distributed training system in PyTorch. It preserves the mental model of ordinary training while scaling to many devices and machines.

Distributed Process Groups

DDP depends on a distributed process group. A process group is a collection of worker processes that can communicate with each other using collective operations such as:

  • all-reduce
  • broadcast
  • gather
  • scatter
  • barrier

PyTorch initializes the process group with:

import torch.distributed as dist

dist.init_process_group(
    backend="nccl"
)

The backend determines how communication is implemented.

Common backends include:

Backend   Typical use
nccl      GPU training on NVIDIA hardware
gloo      CPU training and debugging
mpi       HPC clusters with MPI
ucc       Unified Collective Communication (UCC) library

For modern GPU training, nccl is almost always preferred because it is optimized for high-bandwidth GPU communication.
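
A minimal sketch of backend selection, assuming we want NCCL on GPU machines and Gloo everywhere else:

import torch
import torch.distributed as dist

# Prefer NCCL when CUDA is available; fall back to Gloo for CPU-only runs.
backend = "nccl" if torch.cuda.is_available() else "gloo"

dist.init_process_group(backend=backend)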

World Size and Rank

Every process in DDP receives two important identifiers.

The world size is the total number of participating processes.

The rank is the unique process ID.
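
Both values can be queried once the process group is initialized:

rank = dist.get_rank()
world_size = dist.get_world_size()

print(f"worker {rank} of {world_size}")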

If we launch training on 8 GPUs:

GPU     Rank
GPU 0   0
GPU 1   1
GPU 2   2
GPU 3   3
GPU 4   4
GPU 5   5
GPU 6   6
GPU 7   7

Rank 0 is usually treated as the primary worker. It commonly handles:

  • checkpoint saving
  • logging
  • metric printing
  • validation summaries

Example:

if dist.get_rank() == 0:
    print("Saving checkpoint")

Without this condition, every process may try to write the same file simultaneously.

Launching Distributed Training

DDP training is normally launched with torchrun.

Example:

torchrun --nproc_per_node=8 train.py

This command launches 8 independent Python processes.

PyTorch automatically sets several environment variables:

Variable      Meaning
RANK          Global process rank
WORLD_SIZE    Total number of processes
LOCAL_RANK    GPU index on the current machine
MASTER_ADDR   Address of the primary node
MASTER_PORT   Communication port

The training script reads these variables:

import os

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
local_rank = int(os.environ["LOCAL_RANK"])

Then the correct GPU is selected:

torch.cuda.set_device(local_rank)

Each process must exclusively control its assigned GPU.
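
Putting these pieces together, a typical setup helper looks like the following sketch (the function name setup is arbitrary):

import os

import torch
import torch.distributed as dist

def setup():
    # torchrun provides these variables to every spawned process.
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])

    # Bind this process to its GPU before any CUDA work happens.
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl")

    return rank, world_size, local_rank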

Wrapping Models with DDP

After creating the model, we wrap it:

from torch.nn.parallel import DistributedDataParallel as DDP

model = MyModel().cuda(local_rank)

model = DDP(
    model,
    device_ids=[local_rank]
)

The wrapped model behaves almost exactly like the original model.

Forward pass:

logits = model(x)

Backward pass:

loss.backward()

DDP automatically intercepts gradients during backpropagation and synchronizes them across workers.

The optimizer remains unchanged:

optimizer.step()

This design keeps distributed training close to ordinary PyTorch code.

How Gradient Synchronization Works

The key operation inside DDP is gradient all-reduce.

Suppose a parameter tensor W exists on every worker. During backpropagation, each worker computes its local gradient:

g_1, \quad g_2, \quad \ldots, \quad g_K.

DDP computes:

g = \frac{1}{K} \sum_{k=1}^{K} g_k.

Then every worker replaces its local gradient with the averaged result.

Because all workers start with identical parameters and apply identical updates, parameter replicas remain synchronized.

Conceptually:

for param in model.parameters():
    dist.all_reduce(param.grad)   # sum gradients across workers
    param.grad /= world_size      # convert the sum into an average

DDP performs this automatically and efficiently.
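
The averaging can be reproduced in a small self-contained demo. This sketch uses the gloo backend with two CPU processes so it runs without GPUs; the port number is arbitrary:

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank, world_size):
    # Each process joins the same group over localhost.
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:29500",
        rank=rank,
        world_size=world_size,
    )

    # Simulate a local gradient that differs per worker.
    grad = torch.tensor([float(rank + 1)])

    dist.all_reduce(grad)   # sum: 1.0 + 2.0 = 3.0
    grad /= world_size      # average: 1.5 on every worker

    print(f"rank {rank}: averaged gradient = {grad.item()}")
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)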

Bucketing and Communication Overlap

Naively synchronizing gradients after the entire backward pass would waste time. GPUs would sit idle waiting for communication.

DDP avoids this by using gradient bucketing.

Parameters are grouped into buckets. As soon as gradients for one bucket are ready, DDP begins communicating them while backpropagation continues for remaining layers.

This overlaps:

  • gradient computation
  • communication
  • parameter synchronization

The result is much better scaling efficiency.

Large transformer models depend heavily on this overlap. Without it, communication overhead would dominate training time.
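
The bucket size is tunable through the DDP constructor. Whether a larger value helps depends on the model and interconnect, so treat this as a starting point rather than a recommendation:

model = DDP(
    model,
    device_ids=[local_rank],
    bucket_cap_mb=50  # default is 25; larger buckets mean fewer, larger transfers
)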

Backward Hooks

DDP internally uses autograd hooks.

A hook is attached to each parameter tensor. When autograd finishes computing a parameter gradient, the hook triggers communication for that parameter’s bucket.

Conceptually:

def hook(grad):
    synchronize_gradient(grad)

param.register_hook(hook)

This integration with autograd is one reason DDP scales efficiently while preserving PyTorch’s eager execution model.

Static Graph Assumptions

DDP assumes that the same parameters participate in the backward pass across workers.

If one worker skips a parameter while another uses it, synchronization can deadlock.

Dynamic control flow sometimes violates this assumption:

import random

if random.random() > 0.5:
    y = branch_a(x)
else:
    y = branch_b(x)

If different workers take different branches, some gradients may be missing.

PyTorch provides:

find_unused_parameters=True

Example:

model = DDP(
    model,
    device_ids=[local_rank],
    find_unused_parameters=True
)

This allows DDP to track unused parameters, but it introduces extra overhead.

Whenever possible, distributed training graphs should remain structurally consistent across workers.
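
When the graph really is fixed, recent PyTorch versions also accept a static_graph flag, which lets DDP skip some of this bookkeeping; availability depends on your PyTorch version:

model = DDP(
    model,
    device_ids=[local_rank],
    static_graph=True  # promises the same graph structure every iteration
)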

Distributed Data Loading

Each worker must receive different data.

DDP training normally uses:

from torch.utils.data.distributed import DistributedSampler

Example:

sampler = DistributedSampler(
    dataset,
    shuffle=True
)

loader = DataLoader(
    dataset,
    batch_size=32,
    sampler=sampler
)

Without a distributed sampler, every worker may process the same mini-batch, wasting compute and harming optimization.

At the beginning of each epoch:

sampler.set_epoch(epoch)

This reseeds the sampler so that each epoch uses a fresh shuffle order while all workers still agree on the same permutation and take disjoint shards of it.
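
In practice the epoch loop looks like the following sketch (num_epochs and loss_fn are assumed from context):

for epoch in range(num_epochs):
    # Reseed the sampler so each epoch sees a different shuffle order.
    sampler.set_epoch(epoch)

    for x, y in loader:
        x = x.cuda(local_rank)
        y = y.cuda(local_rank)

        loss = loss_fn(model(x), y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()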

Initialization Synchronization

At startup, DDP ensures all workers begin with identical parameters.

Typically rank 0 initializes the model, then parameters are broadcast:

for param in model.parameters():
    dist.broadcast(param.data, src=0)

This guarantees consistency across replicas.

If workers start with different random initialization, averaging gradients would no longer correspond to valid synchronized optimization.

Distributed Loss Computation

Each worker computes its own local loss.

Suppose worker k computes a local loss:

L_k.

The backward pass uses the local loss directly. DDP synchronizes gradients, not losses.

Therefore this is correct:

loss.backward()

No explicit averaging of the loss is required for optimization correctness.

However, for logging and metrics, losses are often averaged across workers:

loss_tensor = torch.tensor(
    loss.item(),
    device=local_rank
)

dist.all_reduce(loss_tensor)

loss_tensor /= world_size

This produces a global mean loss for reporting.

Validation in Distributed Training

Validation may also be distributed.

Each worker evaluates a subset of validation data:

with torch.no_grad():
    logits = model(x)

Metrics are then aggregated across workers.

For example, total correct predictions:

correct = torch.tensor(correct_count).cuda()

dist.all_reduce(correct)

This allows validation throughput to scale with the number of GPUs.
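
A sketch of a globally aggregated accuracy, assuming each worker has already counted correct_count and total_count on its shard of the validation set:

stats = torch.tensor(
    [correct_count, total_count],
    dtype=torch.float64,
    device=local_rank
)

# Sum both counts across all workers (the default reduce op is SUM).
dist.all_reduce(stats)

accuracy = (stats[0] / stats[1]).item()

if dist.get_rank() == 0:
    print(f"validation accuracy: {accuracy:.4f}")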

Some projects instead run validation only on rank 0 to simplify implementation.

Checkpointing with DDP

The DDP wrapper stores the original model in:

model.module

Therefore checkpoints usually save:

torch.save(
    model.module.state_dict(),
    "checkpoint.pt"
)

Loading, mapping tensors to the local device and loading into the unwrapped module:

state_dict = torch.load(
    "checkpoint.pt",
    map_location=f"cuda:{local_rank}"
)

model.module.load_state_dict(state_dict)

Only rank 0 should save checkpoints:

if dist.get_rank() == 0:
    save_checkpoint()

Otherwise multiple workers may overwrite the same file.
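
A barrier can coordinate this so that no worker reads the checkpoint before rank 0 finishes writing it (a sketch):

if dist.get_rank() == 0:
    torch.save(model.module.state_dict(), "checkpoint.pt")

# Everyone waits until rank 0 has finished writing.
dist.barrier()

state_dict = torch.load("checkpoint.pt", map_location=f"cuda:{local_rank}")
model.module.load_state_dict(state_dict)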

Multi-Node DDP

DDP also supports training across multiple machines.

Example cluster:

Node     GPUs
Node 0   8
Node 1   8
Node 2   8
Node 3   8

Total world size: 32.
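
Such a job is typically launched by running torchrun on every node; the master address below is illustrative:

torchrun \
    --nnodes=4 \
    --nproc_per_node=8 \
    --node_rank=$NODE_RANK \
    --master_addr=10.0.0.1 \
    --master_port=29500 \
    train.py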

Workers communicate over interconnects such as:

  • InfiniBand
  • Ethernet
  • RoCE
  • NVLink (within a single node)

Multi-node training introduces additional concerns:

  • network bandwidth
  • latency
  • node failure
  • clock skew
  • distributed filesystem performance

At large scale, communication topology becomes critically important.

Communication Bottlenecks

DDP scales well only when communication cost remains manageable.

Communication overhead grows with:

  • parameter count
  • gradient size
  • synchronization frequency
  • network latency

Large transformer models may contain billions of parameters. Synchronizing gradients for every step can dominate runtime.

Common mitigation strategies include:

Technique             Purpose
Mixed precision       Reduce communication volume
Gradient compression  Compress transmitted gradients
Larger batches        Increase the computation-to-communication ratio
Faster interconnects  Reduce transfer latency
Bucket tuning         Improve overlap efficiency

Modern large-scale training systems spend enormous engineering effort minimizing communication overhead.
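
As one example, PyTorch ships a DDP communication hook that casts gradients to float16 before the all-reduce, roughly halving the transmitted volume at some cost in precision:

from torch.distributed.algorithms.ddp_comm_hooks import default_hooks

# Compress gradients to fp16 for transport; results are cast back afterwards.
model.register_comm_hook(state=None, hook=default_hooks.fp16_compress_hook)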

Failure Handling

Distributed systems fail more often than single-device systems.

Possible failures include:

  • GPU out-of-memory
  • network interruption
  • worker crashes
  • NCCL hangs
  • deadlocks
  • filesystem corruption

A common debugging tool:

export NCCL_DEBUG=INFO

This enables verbose NCCL logging.

Barrier synchronization is also useful:

dist.barrier()

A barrier blocks all workers until every process reaches the synchronization point.

Barriers help isolate where distributed programs hang.

DDP Versus Fully Sharded Training

DDP replicates the full model on every GPU.

If the model contains P parameters, then every GPU stores:

  • parameters
  • gradients
  • optimizer states

Memory usage scales poorly for very large models.

This limitation motivated:

  • Fully Sharded Data Parallel (FSDP)
  • ZeRO optimization
  • tensor parallelism
  • pipeline parallelism

These systems partition model state across devices instead of fully replicating it.

Still, ordinary DDP remains the standard baseline because it is simple, stable, and highly effective when the model fits on one GPU.

Why DDP Became the Standard

DDP succeeded because it aligns closely with the structure of deep learning workloads.

Neural networks naturally process batches independently. Gradients are additive across examples. Autograd systems already organize computation as graphs. GPUs excel at dense tensor operations.

DDP exploits all of these properties while preserving PyTorch’s imperative programming style.

From the user perspective, distributed training often requires only four conceptual changes:

  1. launch multiple processes
  2. initialize a process group
  3. shard the dataset
  4. wrap the model with DDP

Everything else remains close to ordinary PyTorch training.
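
Put together, a minimal train.py sketch looks like this (MyModel, MyDataset, loss_fn, and the hyperparameters are placeholders):

import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def main():
    # 1. torchrun has already launched one process per GPU.
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # 2. Initialize the process group.
    dist.init_process_group(backend="nccl")

    # 3. Shard the dataset across processes.
    dataset = MyDataset()
    sampler = DistributedSampler(dataset, shuffle=True)
    loader = DataLoader(dataset, batch_size=32, sampler=sampler)

    # 4. Wrap the model with DDP.
    model = DDP(MyModel().cuda(local_rank), device_ids=[local_rank])
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

    for epoch in range(10):
        sampler.set_epoch(epoch)
        for x, y in loader:
            x, y = x.cuda(local_rank), y.cuda(local_rank)

            loss = loss_fn(model(x), y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # Only the primary worker writes the checkpoint.
    if dist.get_rank() == 0:
        torch.save(model.module.state_dict(), "checkpoint.pt")

    dist.destroy_process_group()

if __name__ == "__main__":
    main()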

That simplicity made DDP one of the most important infrastructure abstractions in modern deep learning.