# Data Parallelism

Data parallelism is the simplest and most widely used form of distributed deep learning. The idea is to keep a copy of the same model on several devices, feed each device a different part of the batch, compute gradients independently, and then combine the gradients before updating the parameters.

Suppose we have a model with parameters $\theta$, a loss function $\ell$, and a mini-batch

$$
B = \{(x_1, y_1), \ldots, (x_m, y_m)\}.
$$

The mini-batch loss is usually written as

$$
L(\theta) =
\frac{1}{m}
\sum_{i=1}^{m}
\ell(f_\theta(x_i), y_i).
$$

In data parallel training, the batch is split into $K$ disjoint shards, one per device:

$$
B = B_1 \cup B_2 \cup \cdots \cup B_K.
$$

Each device $k$ receives a local batch $B_k$, computes a local loss,

$$
L_k(\theta) =
\frac{1}{|B_k|}
\sum_{(x_i,y_i)\in B_k}
\ell(f_\theta(x_i), y_i),
$$

and computes local gradients,

$$
g_k = \nabla_\theta L_k(\theta).
$$

When the shards have equal size, the global gradient is the simple average of the local gradients:

$$
g =
\frac{1}{K}
\sum_{k=1}^{K} g_k.
$$

Then every device applies the same update:

$$
\theta \leftarrow \theta - \eta g.
$$

The important point is that each device ends the step with the same parameters. If the replicas ever diverge, the training procedure no longer matches ordinary mini-batch gradient descent.

### Why Data Parallelism Works

Data parallelism works because the gradient of an average loss is the average of the per-example gradients. If a mini-batch is split into equal-sized pieces, each device can compute the gradient for its piece. Averaging these gradients gives the same result as computing the gradient on the whole batch, up to small numerical differences caused by floating-point arithmetic and operation ordering.
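The equivalence can be checked numerically without any deep learning framework. The following is a minimal sketch under stated assumptions: a one-parameter model $f_w(x) = wx$ with squared-error loss, a made-up batch of eight examples, and four equal-sized shards. The names `mean_gradient`, `shards`, and `averaged_grad` are illustrative and not part of any library.

```python
# Toy check: averaging equal-sized shard gradients matches the full-batch gradient.
# Assumed model: f(x) = w * x with squared-error loss (w * x - y) ** 2.

def mean_gradient(w, batch):
    """Gradient of the mean squared error over `batch` with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)

w = 0.5
batch = [(1.0, 2.0), (2.0, 3.0), (3.0, 5.0), (4.0, 4.0),
         (5.0, 7.0), (6.0, 5.0), (7.0, 9.0), (8.0, 8.0)]

# Split the batch into one contiguous, equal-sized shard per simulated device.
num_devices = 4
shard_size = len(batch) // num_devices
shards = [batch[k * shard_size:(k + 1) * shard_size] for k in range(num_devices)]

# Each "device" computes its local gradient; the results are then averaged.
local_grads = [mean_gradient(w, shard) for shard in shards]
averaged_grad = sum(local_grads) / num_devices

# Reference: the gradient computed on the whole batch at once.
full_grad = mean_gradient(w, batch)

print(averaged_grad, full_grad)
```

The two printed values should agree to within floating-point rounding.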

This makes data parallelism especially natural for deep learning. Most models apply the same computation independently to each example in a batch. Images in an image batch, sentences in a language batch, and examples in a tabular batch can usually be divided across devices without changing model semantics.

For example, if one GPU can process 64 images at a time, then 8 GPUs can process a global batch of 512 images by assigning 64 images to each GPU. Each GPU performs the same forward and backward pass on a different shard of the batch.

### Local Batch Size and Global Batch Size

Two batch sizes must be distinguished.

The local batch size is the number of examples processed by one device. The global batch size is the total number of examples processed across all devices in one optimization step.

If we have $K$ devices and each device processes $b$ examples, then

$$
\text{global batch size} = Kb.
$$

For example, with 8 GPUs and local batch size 32,

$$
\text{global batch size} = 8 \times 32 = 256.
$$

This distinction matters because optimization depends on the global batch size. Increasing the number of GPUs while keeping the local batch size fixed increases the global batch size. This can change training dynamics, sometimes requiring a larger learning rate, warmup, or a different schedule.

A common rule of thumb is linear learning rate scaling:

$$
\eta_{\text{new}} =
\eta_{\text{old}}
\cdot
\frac{B_{\text{new}}}{B_{\text{old}}}.
$$

This rule often works for moderate batch-size increases, but it is not a theorem. Very large batches may reduce gradient noise too much, harm generalization, or require careful learning rate warmup.
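As a rough illustration, the batch-size arithmetic and the linear scaling rule can be written as small helpers. This is a sketch, not a recommendation: the function names, the base learning rate of 0.1 at batch size 256, and the simple linear warmup shape are all assumptions chosen for the example.

```python
def global_batch_size(local_batch_size, num_devices):
    """Total examples per optimization step across all devices."""
    return local_batch_size * num_devices

def linearly_scaled_lr(base_lr, base_batch_size, new_batch_size):
    """Linear scaling rule: the learning rate grows in proportion to batch size."""
    return base_lr * new_batch_size / base_batch_size

def warmup_lr(target_lr, step, warmup_steps):
    """Ramp the learning rate linearly up to target_lr over warmup_steps steps."""
    if step >= warmup_steps:
        return target_lr
    return target_lr * (step + 1) / warmup_steps

# Hypothetical setup: tuned at batch size 256, now running 8 GPUs x 128 examples.
new_batch = global_batch_size(local_batch_size=128, num_devices=8)
target = linearly_scaled_lr(base_lr=0.1, base_batch_size=256, new_batch_size=new_batch)
schedule = [warmup_lr(target, step, warmup_steps=5) for step in range(7)]
print(new_batch, target, schedule)
```

Warmup starts with a small learning rate and reaches the scaled value after a few steps, which is one common way to stabilize the early phase of large-batch training.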

### Synchronous Data Parallelism

The standard form of data parallelism is synchronous data parallelism. All workers compute gradients for their local batches. Then the workers communicate and average gradients. No worker updates its parameters until all workers have contributed.

A single training step looks like this:

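1. Start from identical model parameters on every device. DDP broadcasts them from rank 0 once, when the model is wrapped, rather than copying them at every step.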
2. Split the batch across devices.
3. Run the forward pass independently on each device.
4. Compute the local loss on each device.
5. Run the backward pass independently on each device.
6. Average gradients across devices.
7. Apply the optimizer update on every device.

In PyTorch, this is usually implemented with `DistributedDataParallel`, commonly abbreviated as `DDP`.

A minimal single-process-per-GPU DDP training pattern looks like this:

```python
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim

from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

def setup_distributed():
    dist.init_process_group(backend="nccl")

    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    return local_rank

def train(rank, model, dataset, epochs):
    sampler = DistributedSampler(dataset, shuffle=True)

    loader = DataLoader(
        dataset,
        batch_size=32,
        sampler=sampler,
        num_workers=4,
        pin_memory=True,
    )

    model = model.cuda(rank)
    model = DDP(model, device_ids=[rank])

    optimizer = optim.AdamW(model.parameters(), lr=3e-4)
    loss_fn = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        sampler.set_epoch(epoch)

        for x, y in loader:
            x = x.cuda(rank, non_blocking=True)
            y = y.cuda(rank, non_blocking=True)

            logits = model(x)
            loss = loss_fn(logits, y)

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()

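def main():
    local_rank = setup_distributed()

    # MyModel and MyDataset are placeholders for your own nn.Module and Dataset.
    model = MyModel()
    dataset = MyDataset()

    train(local_rank, model, dataset, epochs=10)

    dist.destroy_process_group()

if __name__ == "__main__":
    main()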
```

This script is usually launched with `torchrun`:

```bash
torchrun --nproc_per_node=8 train.py
```

Here `--nproc_per_node=8` starts one process per GPU on an 8-GPU machine.

### Distributed Samplers

When using data parallelism, each process must see a different shard of the dataset. If every process reads the same examples, the effective batch contains duplicate data and the compute is wasted.

PyTorch solves this with `DistributedSampler`.

```python
sampler = DistributedSampler(dataset, shuffle=True)

loader = DataLoader(
    dataset,
    batch_size=32,
    sampler=sampler,
)
```

The sampler partitions the dataset by process rank. If there are 8 processes, rank 0 sees one shard, rank 1 sees another shard, and so on.

At the beginning of every epoch, call:

```python
sampler.set_epoch(epoch)
```

This ensures that shuffling changes across epochs while remaining synchronized across workers.

Without this call, every epoch reuses the same shuffled ordering, which can subtly reduce training quality.
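To make the sharding and the role of the epoch concrete, here is a simplified, framework-free sketch of the idea. It is not PyTorch's implementation. The function name `shard_indices`, the padding strategy, and the `seed + epoch` seeding are assumptions for illustration, and the sketch assumes the dataset has at least as many examples as there are processes.

```python
import math
import random

def shard_indices(dataset_size, world_size, rank, epoch, seed=0):
    """Return the dataset indices that one rank would read in one epoch."""
    indices = list(range(dataset_size))
    # Every rank uses the same seed, so all ranks agree on the shuffled order.
    # Mixing in the epoch changes that order from one epoch to the next.
    random.Random(seed + epoch).shuffle(indices)
    # Pad by repeating leading indices so every rank gets the same count.
    per_rank = math.ceil(dataset_size / world_size)
    total = per_rank * world_size
    indices += indices[: total - dataset_size]
    # Strided selection: rank r takes positions r, r + world_size, and so on.
    return indices[rank:total:world_size]

epoch0 = [shard_indices(10, world_size=4, rank=r, epoch=0) for r in range(4)]
epoch1 = [shard_indices(10, world_size=4, rank=r, epoch=1) for r in range(4)]
print(epoch0)
print(epoch1)
```

Because the shuffle depends only on the shared seed and the epoch, the ranks never need to communicate to agree on who reads which example. If the epoch were never updated, `epoch1` would be identical to `epoch0`.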

### Gradient Synchronization

In DDP, gradient synchronization happens during the backward pass. When `loss.backward()` is called, PyTorch computes gradients and uses collective communication to average them across processes.

The main communication operation is usually `all-reduce`.

An all-reduce takes one tensor from each worker, combines the tensors, and returns the combined result to every worker. For gradient averaging, the operation is typically sum followed by division by the number of workers.

Conceptually:

```python
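world_size = dist.get_world_size()
for param in model.parameters():
    if param.grad is not None:  # parameters unused in this step have no gradient
        dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
        param.grad /= world_size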
```

DDP performs this efficiently and overlaps communication with backpropagation. As soon as gradients for a group of parameters are ready, DDP can begin communicating them while the backward pass continues for earlier layers.

This overlap is one reason DDP is usually much faster than older approaches such as `nn.DataParallel`.
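The semantics of an averaging all-reduce can be simulated in plain Python. The sketch below only models the result that each worker ends up with. It says nothing about how a real backend moves data, and the helper name `all_reduce_mean` is hypothetical.

```python
def all_reduce_mean(worker_tensors):
    """Simulate an averaging all-reduce over one gradient vector per worker."""
    world_size = len(worker_tensors)
    length = len(worker_tensors[0])
    # Reduce: elementwise sum across all workers.
    summed = [sum(t[i] for t in worker_tensors) for i in range(length)]
    # Average, then hand every worker its own copy of the same result.
    mean = [s / world_size for s in summed]
    return [list(mean) for _ in range(world_size)]

# Four simulated workers, each holding a different two-element gradient.
grads = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
synced = all_reduce_mean(grads)
print(synced)
```

After the call, every worker holds the same averaged gradient, which is what allows all replicas to apply an identical optimizer update.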

### `DataParallel` Versus `DistributedDataParallel`

PyTorch provides two data parallel APIs: `nn.DataParallel` and `DistributedDataParallel`.

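`nn.DataParallel` drives multiple devices from one Python process using threads. It is easy to use but usually slower: it re-replicates the model on every forward pass, contends for Python's global interpreter lock, and gathers outputs on the main device, which unbalances memory and work and leads to weaker scaling.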

`DistributedDataParallel` runs one process per device. It is the recommended method for serious training. It scales better within one machine and across multiple machines.

| Feature | `nn.DataParallel` | `DistributedDataParallel` |
|---|---|---|
| Process model | One process controls many GPUs | Usually one process per GPU |
| Performance | Lower | Higher |
| Multi-node support | Poor | Standard |
| Communication | Less efficient | Optimized all-reduce |
| Recommended use | Small experiments | Production training |

For almost all modern PyTorch training, use `DistributedDataParallel`.

### Rank, World Size, and Local Rank

Distributed training uses a few standard terms.

The world size is the total number of processes participating in training. If we train on 4 machines with 8 GPUs each, the world size is 32.

The rank is the unique global process ID. Ranks range from 0 to `world_size - 1`.

The local rank is the process ID within one machine. On an 8-GPU node, local ranks usually range from 0 to 7.

For example:

| Machine | GPU | Global rank | Local rank |
|---|---:|---:|---:|
| Node 0 | GPU 0 | 0 | 0 |
| Node 0 | GPU 1 | 1 | 1 |
| … | … | … | … |
| Node 0 | GPU 7 | 7 | 7 |
| Node 1 | GPU 0 | 8 | 0 |
| Node 1 | GPU 1 | 9 | 1 |
| … | … | … | … |
| Node 1 | GPU 7 | 15 | 7 |
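The numbering in the table follows a simple pattern, sketched below. It assumes that every node has the same number of GPUs and that ranks are assigned contiguously node by node, as in the table. A launcher may number processes differently, so treat the helper names `global_rank` and `locate` as illustrative only.

```python
def global_rank(node_rank, local_rank, gpus_per_node):
    """Global rank under contiguous, node-by-node rank assignment."""
    return node_rank * gpus_per_node + local_rank

def locate(rank, gpus_per_node):
    """Inverse mapping: return (node_rank, local_rank) for a global rank."""
    return divmod(rank, gpus_per_node)

# Reproduce the last table row: node 1, GPU 7, on 8-GPU nodes.
print(global_rank(node_rank=1, local_rank=7, gpus_per_node=8))
print(locate(rank=9, gpus_per_node=8))
```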

The local rank is used to choose the CUDA device:

```python
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
```

The global rank is used for distributed coordination, logging, checkpointing, and deciding which process performs special tasks.

Usually only rank 0 should write logs and checkpoints:

```python
if dist.get_rank() == 0:
    torch.save(model.module.state_dict(), "model.pt")
```

The `.module` is used because the original model is wrapped inside the DDP object.

### Correct Loss Scaling

Loss reduction must be handled carefully.

Suppose each worker computes a mean loss over its local batch. DDP averages gradients across workers. If all local batches have the same size, this matches the gradient of the global mean loss.

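However, if local batch sizes differ across workers, for example with custom sharding at the end of an epoch, the simple average of worker gradients differs from the gradient of the global mean loss, because examples in smaller batches receive more weight. The difference is usually small, but it can matter in precise training setups.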
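A small numeric sketch shows the size of the discrepancy. The six per-example gradient values and the 4-versus-2 split are made up for illustration. Weighting each worker's mean by its local batch size recovers the exact global mean.

```python
def mean(values):
    return sum(values) / len(values)

# Hypothetical per-example gradients, split unevenly across two workers.
per_example_grads = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
shards = [per_example_grads[:4], per_example_grads[4:]]  # sizes 4 and 2

# What a single device would compute on the whole batch.
global_mean = mean(per_example_grads)

# What plain gradient averaging computes: the mean of the local means.
simple_average = mean([mean(shard) for shard in shards])

# Weighting each local mean by its shard size removes the bias.
weighted_average = sum(len(shard) * mean(shard) for shard in shards) / len(per_example_grads)

print(global_mean, simple_average, weighted_average)
```

Here the simple average is 4.0 while the global mean and the weighted average are both 3.5, because the two examples on the smaller worker count twice as much as they should.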

A common practical solution is to use `drop_last=True` in the distributed data loader:

```python
loader = DataLoader(
    dataset,
    batch_size=32,
    sampler=sampler,
    drop_last=True,
)
```

This keeps local batch sizes equal.

For most large-scale training runs, dropping the final incomplete batch is acceptable.

### Gradient Accumulation with Data Parallelism

Gradient accumulation simulates a larger batch by accumulating gradients over several forward and backward passes before calling `optimizer.step()`.

If each GPU has local batch size $b$, there are $K$ GPUs, and gradients are accumulated for $A$ steps, then the effective global batch size is

$$
B_{\text{effective}} = bKA.
$$

For example, with local batch size 8, 16 GPUs, and 4 accumulation steps,

$$
B_{\text{effective}} = 8 \times 16 \times 4 = 512.
$$

In DDP, unnecessary synchronization during intermediate accumulation steps should be avoided. PyTorch provides `no_sync()` for this purpose:

```python
import contextlib

for step, (x, y) in enumerate(loader):
    x = x.cuda(local_rank, non_blocking=True)
    y = y.cuda(local_rank, non_blocking=True)

    # Synchronize gradients only on the last micro-batch of each group.
    is_accumulating = (step + 1) % accum_steps != 0
    sync_context = model.no_sync() if is_accumulating else contextlib.nullcontext()

    with sync_context:
        logits = model(x)
        loss = loss_fn(logits, y) / accum_steps
        loss.backward()

    if not is_accumulating:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
```

Dividing the loss by `accum_steps` keeps the gradient scale consistent with an ordinary large batch.
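The effect of the division can be verified on a toy problem. The sketch below reuses an assumed one-parameter model $f_w(x) = wx$ with squared-error loss and made-up data. Because the gradient is linear in the loss, dividing each micro-batch gradient by `accum_steps` is equivalent to dividing the loss. The helper name `mse_gradient` is hypothetical.

```python
def mse_gradient(w, batch):
    """Gradient of the mean squared error over `batch` with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)

w = 0.5
large_batch = [(1.0, 2.0), (2.0, 3.0), (3.0, 5.0), (4.0, 4.0),
               (5.0, 7.0), (6.0, 5.0), (7.0, 9.0), (8.0, 8.0)]
accum_steps = 4
micro_size = len(large_batch) // accum_steps

accumulated = 0.0   # with the 1 / accum_steps scaling
unscaled = 0.0      # without it, for comparison
for step in range(accum_steps):
    micro_batch = large_batch[step * micro_size:(step + 1) * micro_size]
    grad = mse_gradient(w, micro_batch)
    accumulated += grad / accum_steps  # mirrors `loss = loss / accum_steps`
    unscaled += grad

# Reference: one backward pass over the whole large batch.
reference = mse_gradient(w, large_batch)
print(accumulated, unscaled, reference)
```

The scaled accumulation matches the large-batch gradient, while the unscaled sum is `accum_steps` times too large, which would act like an unintended learning rate increase.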

### Batch Normalization in Data Parallel Training

Batch normalization computes statistics over a mini-batch. In data parallel training, each GPU sees only its local batch. Therefore ordinary batch normalization computes statistics separately on each GPU.

This can be acceptable when the local batch size is large. But if each GPU receives only a small batch, the batch statistics may be noisy.

PyTorch provides synchronized batch normalization:

```python
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = DDP(model, device_ids=[local_rank])
```

Synchronized batch normalization computes batch statistics across workers. This gives statistics closer to those from the full global batch, but it adds communication overhead.
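The difference between local and synchronized statistics can be illustrated with plain numbers. This is a conceptual sketch only, with made-up activations for one channel on four devices. It combines per-device counts, sums, and sums of squares, which is one way to obtain global statistics, and it is not meant to describe how `SyncBatchNorm` is implemented internally.

```python
def mean_and_var(values):
    """Mean and biased variance, as used for normalization within a batch."""
    m = sum(values) / len(values)
    v = sum((x - m) ** 2 for x in values) / len(values)
    return m, v

# Hypothetical activations of one channel, local batch size 2 on each of 4 devices.
local_batches = [[1.0, 2.0], [3.0, 6.0], [10.0, 12.0], [0.0, 2.0]]

# Ordinary batch norm: every device normalizes with its own noisy statistics.
local_stats = [mean_and_var(batch) for batch in local_batches]

# Synchronized idea: reduce count, sum, and sum of squares across devices.
count = sum(len(batch) for batch in local_batches)
total = sum(sum(batch) for batch in local_batches)
total_sq = sum(sum(x * x for x in batch) for batch in local_batches)
sync_mean = total / count
sync_var = total_sq / count - sync_mean ** 2

print(local_stats)
print(sync_mean, sync_var)
```

With only two examples per device, the local means range from 1.0 to 11.0, while the combined statistics match those of the full eight-example batch.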

For many modern architectures, especially transformers, layer normalization is preferred because it does not depend on the batch dimension.

### Checkpointing in Data Parallel Training

Since each process has the same model parameters after each synchronized step, only one process needs to save the model.

A common pattern is:

```python
if dist.get_rank() == 0:
    checkpoint = {
        "model": model.module.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
    }
    torch.save(checkpoint, "checkpoint.pt")
```

When loading:

```python
checkpoint = torch.load("checkpoint.pt", map_location="cpu")
model.module.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
```

If the model has not yet been wrapped in DDP, load into the plain model first:

```python
model.load_state_dict(checkpoint["model"])
model = DDP(model.cuda(local_rank), device_ids=[local_rank])
```

Checkpointing must also account for random number generator state, learning rate scheduler state, gradient scaler state for mixed precision, and sampler position if exact resumption is required.

### Common Failure Modes

Data parallel training often fails for operational reasons rather than mathematical reasons.

The first common failure is duplicated data. This happens when a normal `DataLoader` is used without a `DistributedSampler`. Each process trains on the same examples.

The second common failure is excessive logging or checkpointing. If every process writes logs or checkpoints, outputs may be duplicated or corrupted. Usually only rank 0 should perform these operations.

The third common failure is device mismatch. Each process must use its assigned GPU. Forgetting `torch.cuda.set_device(local_rank)` often causes every process on a node to default to GPU 0, so multiple processes compete for the same device.

The fourth common failure is hidden synchronization cost. Some operations force communication or CPU-GPU synchronization. Frequent calls to `.item()`, excessive metric reduction, or logging every step can reduce throughput.

The fifth common failure is uneven batches. If one worker runs out of data earlier than others, training may hang because collective operations require all workers to participate.

### Measuring Scaling Efficiency

Data parallelism is useful only if adding devices increases throughput. The ideal case is linear scaling: doubling the number of GPUs doubles examples processed per second.

Scaling efficiency can be measured as

$$
\text{efficiency} =
\frac{\text{throughput with }K\text{ devices}}
{K \times \text{throughput with one device}}.
$$

If one GPU processes 1,000 examples per second, then 8 GPUs would ideally process 8,000 examples per second. If the actual throughput is 6,400 examples per second, then the scaling efficiency is

$$
\frac{6400}{8 \times 1000} = 0.8.
$$

So the system has 80 percent scaling efficiency.
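The same calculation as a small helper, using the throughput numbers from the example above. The function name `scaling_efficiency` is illustrative.

```python
def scaling_efficiency(throughput_k, num_devices, throughput_single):
    """Measured throughput divided by ideal linear-scaling throughput."""
    return throughput_k / (num_devices * throughput_single)

# 8 GPUs reach 6,400 examples/s, while one GPU alone reaches 1,000 examples/s.
efficiency = scaling_efficiency(throughput_k=6400, num_devices=8, throughput_single=1000)
print(f"{efficiency:.0%}")
```

When measuring real systems, both throughput numbers should use the same local batch size and exclude startup time, otherwise the ratio is hard to interpret.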

Efficiency is reduced by communication, input pipeline bottlenecks, synchronization delays, small local batch sizes, and imbalance across devices.

### When Data Parallelism Is Enough

Data parallelism works best when the model fits on one device and the main goal is to process more examples per unit time. This covers many CNNs, vision transformers, recommender systems, speech models, and moderate-sized language models.

Data parallelism becomes insufficient when the model, optimizer state, activations, or batch no longer fit in device memory. In that case, we need model parallelism, pipeline parallelism, tensor parallelism, sharded optimizers, or fully sharded data parallelism.

Still, data parallelism remains the baseline. More advanced distributed methods often combine with it. For example, a large language model may use tensor parallelism within each layer, pipeline parallelism across groups of layers, and data parallelism across replicas of the full pipeline.
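When these methods are combined, the device count is typically factored into one degree per parallelism axis. The sketch below shows that bookkeeping under the assumption that each model replica occupies exactly `tensor_parallel * pipeline_parallel` devices. The function name and the 64-GPU configuration are hypothetical.

```python
def data_parallel_degree(world_size, tensor_parallel, pipeline_parallel):
    """Number of full model replicas when each replica spans tp * pp devices."""
    devices_per_replica = tensor_parallel * pipeline_parallel
    if world_size % devices_per_replica != 0:
        raise ValueError("world_size must be divisible by tensor_parallel * pipeline_parallel")
    return world_size // devices_per_replica

# Hypothetical cluster: 64 GPUs, 8-way tensor parallel, 2 pipeline stages.
replicas = data_parallel_degree(world_size=64, tensor_parallel=8, pipeline_parallel=2)
print(replicas)
```

In this configuration, gradient averaging runs across the 4 replicas, and the global batch size is the per-replica batch size multiplied by 4.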

Data parallelism provides the first scaling axis: more devices process more data while preserving the same model computation.

