# Distributed Gradient Computation

Distributed gradient computation arises when a differentiable program no longer fits comfortably on one device or one machine. The reason may be model size, data volume, sequence length, simulation resolution, or throughput requirements. At that point, automatic differentiation becomes a distributed systems problem.

The derivative rules remain local. The execution is no longer local.

A distributed AD system must compute gradients while coordinating:

- partitioned tensors,
- replicated parameters,
- communication collectives,
- remote memory,
- failure modes,
- synchronization,
- and numerical consistency.

### Basic Setting

Suppose the objective is:

$$
L(\theta) = \frac{1}{N}\sum_{i=1}^{N} \ell(x_i, \theta).
$$

A single machine can compute:

$$
\nabla_\theta L(\theta).
$$

In distributed training, the dataset or computation is split across workers. Worker $k$ computes the gradient of its local loss $L_k$, the average of $\ell$ over its own shard:

$$
g_k = \nabla_\theta L_k(\theta).
$$

With equal-sized shards, the global gradient is then the average of the local gradients:

$$
g = \frac{1}{K}\sum_{k=1}^{K} g_k.
$$

The mathematical operation is simple. The systems problem is not.

### Data Parallel Gradient Computation

In data parallelism, each worker stores a full copy of the model.

Each worker receives a different minibatch shard.

The steps are:

1. run forward pass on local data,
2. run backward pass locally,
3. synchronize gradients,
4. apply the same optimizer update on every worker.

For $K$ workers:

$$
g = \frac{1}{K}(g_1 + g_2 + \cdots + g_K).
$$

This is the most common distributed training pattern because it preserves the single-device programming model.
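
A minimal single-process sketch of these four steps, using NumPy arrays to stand in for the workers (the toy model, data, and worker count are illustrative assumptions, not any particular framework's API):

```python
import numpy as np

rng = np.random.default_rng(0)
K = 4                                   # number of simulated workers
theta = rng.normal(size=3)              # parameters, replicated on every worker
shards = [rng.normal(size=(8, 3)) for _ in range(K)]        # per-worker data shards
targets = [X @ np.array([1.0, -2.0, 0.5]) for X in shards]

def local_gradient(theta, X, y):
    # Gradient of the mean squared error 0.5 * mean((X @ theta - y)^2).
    return X.T @ (X @ theta - y) / len(y)

# Steps 1-2: each worker runs forward and backward on its own shard.
local_grads = [local_gradient(theta, X, y) for X, y in zip(shards, targets)]

# Step 3: synchronize by averaging the local gradients (the all-reduce step).
g = sum(local_grads) / K

# Step 4: every worker applies the same update, so the replicas stay identical.
eta = 0.1
theta = theta - eta * g
print(theta)
```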

### All-Reduce

Gradient synchronization usually uses all-reduce.

All-reduce computes a reduction across workers and returns the result to every worker.

For gradients:

$$
(g_1, g_2, \dots, g_K)
\mapsto
\sum_{k=1}^{K} g_k.
$$

Every worker receives the same summed gradient.

After division by $K$, each worker applies the same update.

### Ring All-Reduce

Ring all-reduce partitions gradients into chunks.

Each worker sends chunks to neighbors in a ring.

The algorithm has two phases:

| Phase | Purpose |
|---|---|
| Reduce-scatter | Sum chunks and distribute ownership |
| All-gather | Share reduced chunks with all workers |

Ring all-reduce uses bandwidth efficiently, but latency grows with the number of workers.
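
A single-process simulation of the two phases, with each worker's gradient cut into $K$ chunks (the schedule shown is one standard ring variant; worker count and gradient size are illustrative assumptions):

```python
import numpy as np

K = 4                                                 # simulated workers arranged in a ring
grads = [np.arange(K * 2, dtype=float) + k for k in range(K)]   # one gradient vector per worker
chunks = [np.array_split(g, K) for g in grads]        # each gradient cut into K chunks

# Phase 1: reduce-scatter. After K-1 steps, worker k holds the fully
# summed chunk (k + 1) % K.
for step in range(K - 1):
    # Snapshot what each worker sends this step (its partially reduced chunk).
    sent = [chunks[k][(k - step) % K].copy() for k in range(K)]
    for k in range(K):
        recv = (k - 1 - step) % K                     # chunk received from the left neighbor
        chunks[k][recv] = chunks[k][recv] + sent[(k - 1) % K]

# Phase 2: all-gather. The fully reduced chunks circulate around the ring
# until every worker holds the complete summed gradient.
for step in range(K - 1):
    sent = [chunks[k][(k + 1 - step) % K].copy() for k in range(K)]
    for k in range(K):
        recv = (k - step) % K
        chunks[k][recv] = sent[(k - 1) % K]

full_sum = sum(grads)
assert all(np.allclose(np.concatenate(c), full_sum) for c in chunks)
print(np.concatenate(chunks[0]) / K)                  # averaged gradient on worker 0
```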

### Communication Cost

If the parameter vector has size:

$$
|\theta|,
$$

then each worker must communicate gradient data proportional to:

$$
O(|\theta|).
$$

For large models, communication can dominate computation.

A model with billions of parameters may produce gradient tensors of many gigabytes per step.
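
As a rough illustration with assumed numbers: a model with $7 \times 10^9$ parameters whose gradients are exchanged as 16-bit floats moves about $7 \times 10^9 \times 2 \text{ bytes} \approx 14$ GB of gradient data per worker per step, before any sharding or compression.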

### Gradient Bucketing

Systems reduce synchronization overhead by grouping gradients into buckets.

Instead of communicating each tensor separately, they communicate larger contiguous buffers.

This improves bandwidth efficiency.

Backward computation can overlap with communication:

- early layer gradients become ready late,
- later layer gradients become ready early,
- buckets are sent as soon as complete.

This overlap hides communication latency.
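
A schematic sketch of bucketing in plain Python, where `send` stands in for launching an asynchronous all-reduce (bucket capacity and gradient sizes are illustrative assumptions):

```python
import numpy as np

BUCKET_BYTES = 4 * 1024                 # illustrative bucket capacity
sent_buckets = []

def send(bucket):
    # Stand-in for launching an asynchronous all-reduce on one bucket.
    sent_buckets.append(np.concatenate(bucket))

# Backward produces gradients from the last layer toward the first.
grads_in_backward_order = [np.ones(n, dtype=np.float32) for n in (600, 500, 200, 700)]

bucket, bucket_bytes = [], 0
for g in grads_in_backward_order:
    bucket.append(g)
    bucket_bytes += g.nbytes
    if bucket_bytes >= BUCKET_BYTES:
        send(bucket)                    # communicate while later gradients are still computed
        bucket, bucket_bytes = [], 0
if bucket:
    send(bucket)                        # flush the final partial bucket

print([b.size for b in sent_buckets])   # e.g. [1100, 900]
```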

### Model Parallelism

Data parallelism fails when the model does not fit on one device.

Model parallelism partitions the model itself.

Common forms:

| Form | Partition |
|---|---|
| Tensor parallelism | Split tensor operations |
| Pipeline parallelism | Split layers into stages |
| Expert parallelism | Split experts across workers |
| Parameter sharding | Split parameters and optimizer state |

Distributed AD must compute correct gradients across these partitions.

### Tensor Parallel Gradients

For matrix multiplication:

$$
Y = XW.
$$

The backward rules are:

$$
\bar{X} = \bar{Y}W^T,
$$

$$
\bar{W} = X^T\bar{Y}.
$$

If $W$ is split column-wise across devices, for example, each device holds a block $W_k$ and computes only the partial product $\bar{Y}_k W_k^T$; the full $\bar{X}$ requires summing these contributions across workers.

Thus the backward pass introduces communication even when the forward pass appears local.

Partitioning choices determine communication cost.
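
A NumPy check of this effect, with $W$ split column-wise across two simulated devices (shapes and the split point are illustrative assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(4, 6))
W = rng.normal(size=(6, 8))
Ybar = rng.normal(size=(4, 8))           # adjoint arriving from downstream

# Reference, single-device backward.
Xbar_ref = Ybar @ W.T
Wbar_ref = X.T @ Ybar

# Column-split W across two "devices": W = [W0 | W1], so Y = [X W0 | X W1].
W0, W1 = W[:, :4], W[:, 4:]
Ybar0, Ybar1 = Ybar[:, :4], Ybar[:, 4:]

# Each device computes a local weight gradient and a *partial* input adjoint.
Wbar0, Wbar1 = X.T @ Ybar0, X.T @ Ybar1
Xbar_partial0, Xbar_partial1 = Ybar0 @ W0.T, Ybar1 @ W1.T

# The full input adjoint requires summing partials across devices (an all-reduce).
Xbar = Xbar_partial0 + Xbar_partial1

assert np.allclose(Xbar, Xbar_ref)
assert np.allclose(np.concatenate([Wbar0, Wbar1], axis=1), Wbar_ref)
```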

### Pipeline Parallel Gradients

Pipeline parallelism divides layers across workers.

A simple pipeline:

$$
x \to S_1 \to S_2 \to S_3 \to y.
$$

Forward activations move forward across stages.

Backward gradients move backward:

$$
\bar{y} \to S_3 \to S_2 \to S_1.
$$

Each stage computes gradients for its local parameters.

The main challenge is scheduling microbatches so devices do not sit idle.
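
A minimal single-process sketch of three linear stages passing activations forward and adjoints backward (the stage functions and shapes are illustrative assumptions; microbatch scheduling is omitted):

```python
import numpy as np

rng = np.random.default_rng(1)
# Three stages, each owning one weight matrix (its local parameters).
W = [rng.normal(size=(4, 4)) for _ in range(3)]

def stage_forward(k, x):
    # Each stage applies a linear map and keeps x for its own backward.
    return x @ W[k]

def stage_backward(k, x_saved, ybar):
    Wbar_k = x_saved.T @ ybar            # gradient for this stage's local parameters
    xbar = ybar @ W[k].T                 # adjoint sent to the previous stage
    return Wbar_k, xbar

x = rng.normal(size=(2, 4))

# Forward: activations move stage to stage; each stage saves its input.
saved, h = [], x
for k in range(3):
    saved.append(h)
    h = stage_forward(k, h)

# Backward: adjoints move in the reverse direction through the stages.
ybar = np.ones_like(h)                   # adjoint of the pipeline output
Wbars = [None] * 3
for k in reversed(range(3)):
    Wbars[k], ybar = stage_backward(k, saved[k], ybar)

print([g.shape for g in Wbars])
```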

### Pipeline Bubbles

Pipeline execution has idle periods called bubbles.

A worker may wait for:

- input activations during the forward pass,
- output adjoints during the backward pass,
- synchronization at step boundaries.

More microbatches reduce bubble overhead but increase activation memory.

### Parameter Sharding

Parameter sharding partitions model state across workers.

Instead of every worker storing:

- parameters,
- gradients,
- optimizer state,

each worker stores only a shard.

This reduces memory.

However, workers must fetch or assemble parameter shards during computation.

The backward pass must return gradient shards to the owning workers.

### Optimizer State Sharding

Optimizers such as Adam store extra state.

For each parameter:

$$
m_t,
\quad
v_t.
$$

These moment estimates can require more memory than the parameters themselves.

Sharding optimizer state is often necessary for very large models.
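
As a rough illustration with assumed precisions: if parameters are stored as 16-bit floats while $m_t$ and $v_t$ are kept in 32-bit floats, the moments alone cost 8 bytes per parameter against 2 bytes for the weights, and sharding them across $K$ workers reduces the per-worker cost to roughly $8/K$ bytes per parameter.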

### Reduce-Scatter

Reduce-scatter combines reduction and partitioning.

Given local gradients, it computes the summed gradient and leaves each worker with only one shard.

This is useful when gradients are sharded.

Instead of every worker receiving the full gradient, each receives the part it owns.

### All-Gather

All-gather assembles distributed shards.

During forward or backward computation, a worker may need parameter shards owned by others.

All-gather collects those shards.

Together, reduce-scatter and all-gather form the basis of many memory-efficient distributed training systems.
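
A single-process sketch of the sharded pattern these two collectives enable, with the collectives simulated as plain functions (worker count, model, and the stand-in gradients are illustrative assumptions):

```python
import numpy as np

K = 4
rng = np.random.default_rng(0)
full_param = rng.normal(size=8)
param_shards = np.array_split(full_param, K)     # each worker owns one shard

def all_gather(shards):
    # Assemble the full parameter vector from the owned shards.
    return np.concatenate(shards)

def reduce_scatter(local_grads):
    # Sum the full gradients, then hand each worker only the shard it owns.
    total = sum(local_grads)
    return np.array_split(total, K)

# Forward/backward: every worker needs the full parameters...
theta = all_gather(param_shards)
local_grads = [theta * (k + 1) for k in range(K)]   # stand-in per-worker gradients

# ...but only keeps the gradient shard it owns.
grad_shards = reduce_scatter(local_grads)

# Each worker updates only its own parameter shard.
eta = 0.01
param_shards = [p - eta * g / K for p, g in zip(param_shards, grad_shards)]
print(param_shards[0])
```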

### Communication as a Differentiable Operation

Distributed communication operations have adjoints.

For example, forward broadcast:

$$
y_k = x
$$

for each worker $k$.

The backward operation is reduction:

$$
\bar{x} = \sum_k \bar{y}_k.
$$

The common pairings are:

| Forward collective | Backward collective |
|---|---|
| Broadcast | Reduce (sum of the adjoints) |
| Scatter | Gather |
| Gather | Scatter |
| All-reduce (sum) | All-reduce (sum) of the adjoints |

Distributed AD systems must encode these adjoints explicitly.
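
A sketch of the broadcast/reduce pair written as an explicit forward rule and its adjoint (the function names are hypothetical, not any framework's API):

```python
import numpy as np

def broadcast_forward(x, K):
    # Forward: every worker receives a copy of x.
    return [x.copy() for _ in range(K)]

def broadcast_backward(ybars):
    # Backward: the adjoint of a broadcast is the sum of the received adjoints.
    return sum(ybars)

x = np.array([1.0, 2.0, 3.0])
ys = broadcast_forward(x, 4)
ybars = [np.full(3, 0.25) for _ in ys]       # adjoints arriving on each worker
xbar = broadcast_backward(ybars)
print(xbar)                                   # [1. 1. 1.]
```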

### Synchronization Semantics

Distributed gradient computation can be synchronous or asynchronous.

In synchronous training, all workers compute gradients for the same parameter version.

The update is:

$$
\theta_{t+1} =
\theta_t -
\eta
\frac{1}{K}
\sum_k g_k(\theta_t).
$$

This is simple and stable, but slow workers delay everyone.

In asynchronous training, workers may compute with stale parameters:

$$
g_k(\theta_{t-\tau_k}).
$$

This improves throughput but changes the optimization algorithm.

### Stale Gradients

Staleness means a gradient was computed using old parameters.

The update becomes:

$$
\theta_{t+1} =
\theta_t -
\eta g(\theta_{t-\tau}).
$$

Large staleness can destabilize optimization.

The gradient may point in a direction that no longer reduces the current loss.
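
A small sketch on a quadratic loss comparing a fresh gradient with one computed at stale parameters (the loss, step size, and staleness are illustrative assumptions):

```python
import numpy as np

def grad(theta):
    # Gradient of the quadratic loss 0.5 * ||theta||^2.
    return theta

eta = 0.5
theta_old = np.array([4.0, -4.0])     # parameters several steps ago
theta_now = np.array([1.0, -1.0])     # current parameters

fresh_step = theta_now - eta * grad(theta_now)
stale_step = theta_now - eta * grad(theta_old)

print(fresh_step)   # moves toward the minimum at zero
print(stale_step)   # overshoots: the stale gradient is too large for the current point
```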

### Numerical Effects of Distribution

Distributed execution changes floating point results.

Gradient sums may occur in different orders depending on:

- reduction topology,
- worker count,
- network scheduling,
- tensor partitioning,
- bucket size.

Because floating point addition is not associative, distributed gradients may differ from single-device gradients.

Usually the differences are small. In sensitive systems, they can affect convergence.
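
A small float32 demonstration that reduction order changes the result (the values are contrived to make the effect visible):

```python
import numpy as np

vals = np.array([1e8, 1.0, -1e8, 1.0], dtype=np.float32)

left_to_right = np.float32(0.0)
for v in vals:
    left_to_right += v

reordered = np.float32(0.0)
for v in vals[[0, 2, 1, 3]]:          # same numbers, different reduction order
    reordered += v

print(left_to_right, reordered)        # e.g. 1.0 vs 2.0
```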

### Gradient Compression

Communication can be reduced by compressing gradients.

Methods include:

| Method | Idea |
|---|---|
| Quantization | Use fewer bits per gradient |
| Sparsification | Send only selected entries |
| Low-rank approximation | Send factorized updates |
| Error feedback | Accumulate compression residuals |

Compression reduces bandwidth cost but changes optimization behavior.

AD computes the local gradient; compression modifies the distributed update.
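
A sketch of top-$k$ sparsification with error feedback on one simulated worker (the compression ratio and stand-in gradients are illustrative assumptions):

```python
import numpy as np

def topk_compress(g, k):
    # Keep only the k largest-magnitude entries; the rest become the residual.
    idx = np.argsort(np.abs(g))[-k:]
    compressed = np.zeros_like(g)
    compressed[idx] = g[idx]
    return compressed, g - compressed

rng = np.random.default_rng(0)
residual = np.zeros(8)
for step in range(3):
    g = rng.normal(size=8)                   # local gradient this step
    corrected = g + residual                 # error feedback: re-add the past residual
    to_send, residual = topk_compress(corrected, k=2)
    # `to_send` is what would be communicated; `residual` carries to the next step.
    print(step, np.count_nonzero(to_send))
```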

### Fault Tolerance

Distributed AD systems may run for hours, days, or weeks.

Failures are expected.

Fault tolerance requires:

- checkpointing parameters,
- checkpointing optimizer state,
- recording data position,
- restoring random number state,
- recovering distributed topology.

A gradient computation is not reproducible if recovery changes the training stream.

### Stragglers

A straggler is a slow worker.

In synchronous training, one straggler delays the whole step.

Causes include:

- hardware variation,
- network congestion,
- data loading imbalance,
- uneven sparse routing.

Systems handle stragglers through better scheduling, elastic training, or asynchronous updates.

### Load Imbalance in Sparse Models

Mixture-of-experts models route tokens to different experts.

Some experts may receive more tokens than others.

This creates load imbalance.

The backward pass inherits the same imbalance because gradients must flow through the selected experts.

Routing therefore affects both performance and gradient variance.

### Distributed Checkpointing

Large distributed models cannot checkpoint by writing one file from one worker.

Each worker writes its own shard.

A consistent checkpoint must capture:

| State | Reason |
|---|---|
| Parameter shards | Model recovery |
| Optimizer shards | Training continuity |
| Scheduler state | Learning rate correctness |
| RNG state | Reproducibility |
| Data loader state | Batch sequence recovery |

Checkpoint consistency is part of distributed AD correctness.

### Gradient Accumulation Across Microbatches

When global batch size is too large for memory, workers accumulate gradients across microbatches.

For microbatches $j=1,\dots,M$:

$$
g_k =
\sum_{j=1}^{M}
\nabla_\theta L_{k,j}.
$$

Then workers synchronize.

This reduces communication frequency and activation memory pressure.

However, it changes latency and may affect optimizer semantics if scaling is handled incorrectly.
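
A sketch of the accumulation loop on one simulated worker (the toy model, microbatch count, and scaling convention are illustrative assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)
theta = rng.normal(size=3)
M = 4                                    # microbatches per optimizer step

def microbatch_grad(theta, X, y):
    # Gradient of the mean squared error on one microbatch.
    return X.T @ (X @ theta - y) / len(y)

accum = np.zeros_like(theta)
for j in range(M):
    X = rng.normal(size=(8, 3))
    y = X @ np.array([1.0, -2.0, 0.5])
    accum += microbatch_grad(theta, X, y)

g_local = accum / M      # scale so the result matches one large local batch
# Only now would the worker synchronize g_local with the other workers.
print(g_local)
```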

### Large Batch Effects

Increasing worker count often increases global batch size.

Large batches reduce gradient noise.

This may require learning-rate adjustment.

The usual linear scaling rule suggests:

$$
\eta_K \approx K\eta_1.
$$

But this is heuristic. Very large batches may harm generalization or convergence.

### Distributed Autograd Engines

A distributed autograd engine must track dependencies across machines.

It must know:

- where tensors live,
- which worker owns each operation,
- how adjoints flow across communication edges,
- when a remote gradient is ready,
- how to avoid deadlock.

The backward graph is no longer a local data structure.

It is a distributed execution plan.

### Deadlocks

Distributed backward passes can deadlock if communication order is inconsistent.

Example:

- worker A waits to receive from B,
- worker B waits to receive from A.

Correct systems impose global ordering or use nonblocking communication carefully.

### Overlap of Backward and Communication

In reverse mode, gradients become ready layer by layer.

Distributed systems exploit this by communicating early gradients while computing later ones.

This requires bucket scheduling.

Good overlap can make communication nearly invisible.

Poor overlap leaves devices idle.

### Communication Topology

Physical topology matters.

Workers may be connected by:

- PCIe,
- NVLink,
- InfiniBand,
- Ethernet,
- TPU interconnects.

A theoretically efficient algorithm may perform poorly if it ignores topology.

Distributed AD runtimes therefore map collectives to hardware-aware communication patterns.

### Consistency of Randomness

Distributed systems often use stochastic operations.

Examples:

- dropout,
- sampling,
- randomized data augmentation.

Correct distributed reproducibility requires careful RNG partitioning.

Each worker needs independent but reproducible random streams.

Changing worker count should not silently change statistical assumptions unless intended.
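
A sketch of deriving per-worker streams from one global seed using NumPy's `SeedSequence` (the seed and worker count are illustrative; the spawning mechanism is NumPy's, not any particular training framework's):

```python
import numpy as np

GLOBAL_SEED = 1234
K = 4

# Derive one independent, reproducible child sequence per worker rank.
children = np.random.SeedSequence(GLOBAL_SEED).spawn(K)
worker_rngs = [np.random.default_rng(s) for s in children]

# Each worker draws its own dropout mask: different across ranks,
# identical if the job is restarted with the same global seed.
masks = [rng.random(5) < 0.9 for rng in worker_rngs]
print(masks[0], masks[1])
```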

### Verification

Distributed gradients are harder to test than local gradients.

Common checks include:

| Check | Purpose |
|---|---|
| Compare single-device and distributed gradients | Validate partitioning |
| Test small models exactly | Catch communication bugs |
| Disable compression | Establish baseline |
| Run deterministic reductions | Debug numerical drift |
| Check gradient norms per shard | Detect missing contributions |

Most distributed AD bugs are not calculus errors. They are ownership, synchronization, or scaling errors.
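
A sketch of the first check in the table, comparing a single-device gradient with a simulated data-parallel one (the toy model and data are illustrative assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)
theta = rng.normal(size=3)
X = rng.normal(size=(16, 3))
y = X @ np.array([1.0, -2.0, 0.5])

def grad(theta, X, y):
    # Gradient of the mean squared error over a batch.
    return X.T @ (X @ theta - y) / len(y)

# Single-device reference gradient over the full batch.
g_single = grad(theta, X, y)

# "Distributed" gradient: split the batch across K simulated workers and average.
K = 4
g_dist = sum(grad(theta, Xs, ys) for Xs, ys in
             zip(np.array_split(X, K), np.array_split(y, K))) / K

assert np.allclose(g_single, g_dist, atol=1e-6)
print(np.max(np.abs(g_single - g_dist)))
```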

### Core Idea

Distributed gradient computation extends automatic differentiation across devices and machines. The derivative rules remain local, but gradient execution depends on communication, partitioning, synchronization, numerical reduction order, and fault tolerance. Scalable AD systems must therefore treat communication operations as differentiable primitives and design the backward pass as a distributed execution graph.

