Model Parallelism

Model parallelism splits a model across multiple devices. Instead of copying the whole model onto every GPU, different parts of the model live on different GPUs.
This is useful when the model is too large to fit on one device. Data parallelism replicates the model, so each device must hold a full copy. Model parallelism removes this requirement by partitioning the model itself.

A simple example is a network with two large blocks:

f_\theta(x) = f_2(f_1(x)).

We can place f_1 on GPU 0 and f_2 on GPU 1:

class TwoPartModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.part1 = Block1().to("cuda:0")
        self.part2 = Block2().to("cuda:1")

    def forward(self, x):
        x = x.to("cuda:0")
        h = self.part1(x)

        h = h.to("cuda:1")
        y = self.part2(h)

        return y

The activation tensor h must move from GPU 0 to GPU 1. This transfer is the main cost of simple model parallelism.
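A runnable sketch of this pattern, with small Linear layers standing in for the hypothetical Block1 and Block2, and a CPU fallback so it runs even without two GPUs:

```python
import torch

# Fall back to CPU when two GPUs are not available; with two GPUs
# this places the parts on cuda:0 and cuda:1 as in the text.
dev0 = "cuda:0" if torch.cuda.device_count() >= 2 else "cpu"
dev1 = "cuda:1" if torch.cuda.device_count() >= 2 else "cpu"

class TwoPartModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.part1 = torch.nn.Linear(16, 32).to(dev0)  # stand-in for Block1
        self.part2 = torch.nn.Linear(32, 8).to(dev1)   # stand-in for Block2

    def forward(self, x):
        h = self.part1(x.to(dev0))
        return self.part2(h.to(dev1))  # activation crosses devices here

model = TwoPartModel()
y = model(torch.randn(4, 16))
print(y.shape)  # torch.Size([4, 8])
```
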

Why Model Parallelism Is Needed

Large models require memory for several objects:

| Object | Description |
| --- | --- |
| Parameters | Trainable weights |
| Gradients | Derivatives of the loss with respect to parameters |
| Optimizer state | Momentum, variance, and other optimizer tensors |
| Activations | Intermediate tensors saved for backpropagation |
| Temporary buffers | Workspace used by kernels and communication libraries |

For AdamW training in float32, each parameter may require storage for the parameter, gradient, first moment, and second moment. Mixed precision can add master weights and casting buffers.

As models grow, a single GPU may run out of memory even for batch size 1. Model parallelism allows the model state and activations to be spread across devices.

Layer-Wise Model Parallelism

The simplest form of model parallelism places different layers on different devices.

For example:

self.layers = torch.nn.ModuleList([
    Layer0().to("cuda:0"),
    Layer1().to("cuda:0"),
    Layer2().to("cuda:1"),
    Layer3().to("cuda:1"),
])

The forward pass moves activations between devices when needed:

def forward(self, x):
    x = x.to("cuda:0")

    x = self.layers[0](x)
    x = self.layers[1](x)

    x = x.to("cuda:1")

    x = self.layers[2](x)
    x = self.layers[3](x)

    return x

This approach is easy to understand but often inefficient. While GPU 0 computes the early layers, GPU 1 waits. After the activation is transferred, GPU 1 computes the later layers while GPU 0 waits.

The result is poor device utilization.

The Idle Device Problem

Consider a two-stage model:

x \rightarrow f_1 \rightarrow f_2.

If f_1 is on GPU 0 and f_2 is on GPU 1, only one GPU may be active at a time for a single batch.

Timeline:

| Time | GPU 0 | GPU 1 |
| --- | --- | --- |
| Step 1 | Compute f_1(x) | Idle |
| Step 2 | Transfer activation | Idle or receive |
| Step 3 | Idle | Compute f_2(h) |

This wastes compute. The model fits in memory, but throughput may be worse than single-GPU training if communication and idle time dominate.

Pipeline parallelism, covered next, addresses this by splitting a batch into microbatches and keeping multiple devices busy at the same time.

Tensor Parallelism

Tensor parallelism splits individual tensor operations across devices. Instead of placing whole layers on different GPUs, it partitions the matrices inside a layer.

A linear layer computes

Y = XW + b.

If W \in \mathbb{R}^{d \times h}, we can split W by columns:

W = \begin{bmatrix} W_1 & W_2 & \cdots & W_K \end{bmatrix}.

Each GPU computes part of the output:

Y_k = XW_k.

The partial outputs are concatenated:

Y = \begin{bmatrix} Y_1 & Y_2 & \cdots & Y_K \end{bmatrix}.

This is column parallelism.

We can also split W by rows:

W = \begin{bmatrix} W_1 \\ W_2 \\ \vdots \\ W_K \end{bmatrix}.

Each device computes a partial contribution, and the results are summed with an all-reduce.

Tensor parallelism is common in large transformer training because transformer blocks contain huge matrix multiplications in attention and feedforward layers.

Column-Parallel Linear Layers

A column-parallel linear layer splits the output dimension across devices.

Suppose

X \in \mathbb{R}^{B \times d}, \quad W \in \mathbb{R}^{d \times h}.

With K devices, each device stores

W_k \in \mathbb{R}^{d \times h/K}.

Each device computes:

Y_k = XW_k.

The local output has shape:

Y_k \in \mathbb{R}^{B \times h/K}.

If the next operation can consume partitioned outputs, no immediate gather is needed. Otherwise, the outputs are gathered to form:

Y \in \mathbb{R}^{B \times h}.

Column parallelism reduces parameter memory per device and distributes matrix multiplication.
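As a numerical sketch (single process, with torch.chunk standing in for per-device shards), we can check that the gathered column-parallel slices reproduce the full product:

```python
import torch

torch.manual_seed(0)
B, d, h, K = 4, 8, 12, 3   # batch, input dim, output dim, number of shards

X = torch.randn(B, d)
W = torch.randn(d, h)

# Column parallelism: shard k holds W_k of shape (d, h/K)
# and computes its slice of the output locally.
shards = W.chunk(K, dim=1)
Y_parts = [X @ W_k for W_k in shards]   # each of shape (B, h/K)

# Gathering the slices along the output dimension reproduces Y = XW.
Y = torch.cat(Y_parts, dim=1)
assert torch.allclose(Y, X @ W, atol=1e-5)
```
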

Row-Parallel Linear Layers

A row-parallel linear layer splits the input dimension.

Suppose each device receives part of the input:

X = \begin{bmatrix} X_1 & X_2 & \cdots & X_K \end{bmatrix},

and stores a matching partition:

W = \begin{bmatrix} W_1 \\ W_2 \\ \vdots \\ W_K \end{bmatrix}.

Each device computes:

Z_k = X_k W_k.

The full output is the sum:

Y = \sum_{k=1}^{K} Z_k.

This requires an all-reduce across devices.
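The same single-process check works for the row-parallel case, with a plain sum standing in for the all-reduce:

```python
import torch

torch.manual_seed(0)
B, d, h, K = 4, 8, 12, 2   # batch, input dim, output dim, number of shards

X = torch.randn(B, d)
W = torch.randn(d, h)

# Row parallelism: shard k holds rows W_k of shape (d/K, h) and the
# matching input slice X_k of shape (B, d/K).
X_parts = X.chunk(K, dim=1)
W_parts = W.chunk(K, dim=0)
Z = [X_k @ W_k for X_k, W_k in zip(X_parts, W_parts)]

# Summing the partial products plays the role of the all-reduce.
Y = torch.stack(Z).sum(dim=0)
assert torch.allclose(Y, X @ W, atol=1e-4)
```
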

Column-parallel and row-parallel layers are often paired. A transformer feedforward block may use column parallelism in the first projection and row parallelism in the second projection, reducing unnecessary communication between them.
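A sketch of that pairing for a feedforward block (again simulating shards with torch.chunk): the column-parallel output feeds the row-parallel input directly, so each shard's activation stays local and only one reduction is needed per block.

```python
import torch

torch.manual_seed(0)
B, d, h, K = 2, 8, 16, 2

X = torch.randn(B, d)
W1 = torch.randn(d, h)   # up projection, split by columns
W2 = torch.randn(h, d)   # down projection, split by rows

W1_parts = W1.chunk(K, dim=1)
W2_parts = W2.chunk(K, dim=0)

# Each shard applies the elementwise ReLU to its own slice, so the
# intermediate activation never needs to be gathered.
Z = [torch.relu(X @ W1_k) @ W2_k for W1_k, W2_k in zip(W1_parts, W2_parts)]

# One all-reduce (here a sum) at the end of the block.
Y = torch.stack(Z).sum(dim=0)
assert torch.allclose(Y, torch.relu(X @ W1) @ W2, atol=1e-4)
```
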

Tensor Parallelism in Transformers

A transformer block contains several large linear operations:

| Component | Typical tensor operation |
| --- | --- |
| Query projection | XW_Q |
| Key projection | XW_K |
| Value projection | XW_V |
| Attention output projection | XW_O |
| Feedforward up projection | XW_1 |
| Feedforward down projection | XW_2 |

These matrices are natural targets for tensor parallelism.

Attention heads can also be partitioned. If a model has 32 attention heads and 4 GPUs, each GPU can process 8 heads. This reduces memory and computation per device while keeping the mathematical structure intact.
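Because heads never interact inside the attention computation, splitting them across devices is exact. A single-process sketch of the 32-heads-on-4-devices case (head indices chunked to simulate per-device shards):

```python
import torch

torch.manual_seed(0)
B, H, T, dh, K = 2, 32, 5, 4, 4   # batch, heads, seq len, head dim, devices

Q = torch.randn(B, H, T, dh)
K_ = torch.randn(B, H, T, dh)
V = torch.randn(B, H, T, dh)

def attention(q, k, v):
    # Scaled dot-product attention, applied independently per head.
    scores = q @ k.transpose(-2, -1) / dh ** 0.5
    return torch.softmax(scores, dim=-1) @ v

# Each "device" processes H/K = 8 heads; concatenating the shard
# outputs matches running all heads on one device.
parts = [attention(Q[:, s], K_[:, s], V[:, s])
         for s in torch.arange(H).chunk(K)]
out = torch.cat(parts, dim=1)
assert torch.allclose(out, attention(Q, K_, V), atol=1e-5)
```
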

Communication Costs

Model parallelism trades memory savings for communication.

Communication occurs when:

  • activations move between layers on different devices
  • partial outputs are gathered
  • partial sums are all-reduced
  • gradients flow backward across partitions

The performance of model parallelism depends on the ratio between computation and communication. Large matrix multiplications are favorable because they perform many arithmetic operations per byte communicated.

Small layers or frequent device transfers are unfavorable.
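A back-of-the-envelope estimate makes this concrete. For Y = XW with X of shape (B, d) and W of shape (d, h), transferring the fp32 output costs 4Bh bytes against roughly 2Bdh floating-point operations, so the ratio grows with d:

```python
# Rough arithmetic-intensity estimate for transferring the output of
# Y = X @ W across devices (fp32, 4 bytes per element).
def flops_per_byte(B, d, h):
    flops = 2 * B * d * h       # multiply-accumulate count
    bytes_moved = 4 * B * h     # output activation tensor
    return flops / bytes_moved  # simplifies to d / 2

# A large hidden dimension amortizes the transfer well...
print(flops_per_byte(32, 8192, 8192))   # 4096.0
# ...while a small layer does not.
print(flops_per_byte(32, 64, 64))       # 32.0
```
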

This is why model parallelism works best with large dense models and fast interconnects such as NVLink, NVSwitch, or high-bandwidth cluster networks.

Autograd Across Devices

PyTorch autograd can track operations across devices. If an activation is moved from one GPU to another, autograd records the transfer as part of the computation graph.

Example:

x = x.to("cuda:0")
h = self.part1(x)

h = h.to("cuda:1")
y = self.part2(h)

loss = loss_fn(y, target.to("cuda:1"))
loss.backward()

During the backward pass, gradients flow from GPU 1 back through the transfer operation to GPU 0.

This makes simple model parallelism easy to implement. However, ease of implementation does not guarantee high throughput: communication and idle time can still dominate performance.
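A runnable version of the backward pass above (with Linear stand-ins for part1 and part2, and a CPU fallback when two GPUs are not available), confirming that gradients land on the device that owns each parameter:

```python
import torch

dev0 = "cuda:0" if torch.cuda.device_count() >= 2 else "cpu"
dev1 = "cuda:1" if torch.cuda.device_count() >= 2 else "cpu"

part1 = torch.nn.Linear(8, 8).to(dev0)
part2 = torch.nn.Linear(8, 2).to(dev1)

x = torch.randn(4, 8)
target = torch.randn(4, 2)

h = part1(x.to(dev0))
y = part2(h.to(dev1))          # autograd records this transfer

loss = torch.nn.functional.mse_loss(y, target.to(dev1))
loss.backward()

# Gradients flowed back through the transfer onto the first device.
assert part1.weight.grad is not None
assert part1.weight.grad.device == part1.weight.device
```
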

Model Parallelism Versus Data Parallelism

Data parallelism and model parallelism split different axes of the training problem.

| Method | What is split | What each device stores | Best when |
| --- | --- | --- | --- |
| Data parallelism | Batch | Full model replica | Model fits on one device |
| Model parallelism | Model | Part of model | Model does not fit on one device |
| Tensor parallelism | Layer tensors | Slices of parameters | Individual layers are too large |
| Pipeline parallelism | Layer groups | Sequential model stages | Many layers can be staged |

These methods are often combined. Large language model training commonly uses data parallelism across groups of replicas, tensor parallelism inside each transformer block, and pipeline parallelism across blocks.

Memory Accounting

To decide whether model parallelism is needed, estimate memory per parameter.

For AdamW training, approximate memory can include:

| Item | Bytes per parameter |
| --- | --- |
| Parameter in fp16 or bf16 | 2 |
| Gradient in fp16 or bf16 | 2 |
| Master parameter in fp32 | 4 |
| Adam first moment | 4 |
| Adam second moment | 4 |
| Total | 16 |

A 7-billion-parameter model may require about:

7 \times 10^9 \times 16 = 112 \times 10^9

bytes, or roughly 112 GB, before accounting for activations, temporary buffers, and fragmentation.

This exceeds the memory of many single GPUs. Model parallelism or sharded training becomes necessary.
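The accounting can be written as a small helper (16 bytes per parameter for mixed-precision AdamW, excluding activations and buffers as noted above):

```python
# Bytes per parameter for mixed-precision AdamW training.
BYTES = {
    "param_bf16": 2,
    "grad_bf16": 2,
    "master_fp32": 4,
    "adam_m_fp32": 4,
    "adam_v_fp32": 4,
}

def model_state_gb(n_params):
    """Model-state memory in GB, ignoring activations and buffers."""
    return n_params * sum(BYTES.values()) / 1e9

print(model_state_gb(7e9))   # 112.0 (GB for a 7B-parameter model)
```
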

Practical PyTorch Patterns

For small model-parallel experiments, explicit device placement is enough:

class SplitMLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(4096, 4096).to("cuda:0")
        self.l2 = torch.nn.Linear(4096, 4096).to("cuda:1")

    def forward(self, x):
        x = x.to("cuda:0")
        x = torch.relu(self.l1(x))

        x = x.to("cuda:1")
        x = self.l2(x)

        return x

For serious large-model training, manual placement becomes hard to maintain. Common tools include:

  • PyTorch FSDP
  • PyTorch tensor parallel APIs
  • DeepSpeed
  • Megatron-LM
  • Hugging Face Accelerate
  • FairScale

These systems automate partitioning, communication, and memory management.

Common Failure Modes

The first common failure is excessive activation transfer. If tensors move across GPUs too often, communication dominates runtime.

The second failure is device mismatch. Inputs, targets, and model parts must be on compatible devices.

Example error:

Expected all tensors to be on the same device

The third failure is poor load balance. If one GPU owns much more computation than another, faster devices wait for the slowest stage.

The fourth failure is hidden memory imbalance. One GPU may hold embeddings, output heads, or normalization layers that make it run out of memory first.

The fifth failure is optimizer state placement. Parameters placed on multiple devices require optimizer state on those devices. A careless optimizer setup may create state in unexpected locations.

When to Use Model Parallelism

Use model parallelism when the model cannot fit on one GPU, or when a single layer is too large for one device.

Avoid simple layer-wise model parallelism when the model already fits on one GPU and throughput is the main goal. Data parallelism is usually faster and simpler.

Use tensor parallelism when large matrix multiplications dominate and the interconnect is fast.

Use pipeline parallelism when the model has many sequential layers that can be divided into balanced stages.

Use sharded data parallelism when the model mostly fits computationally but full replication wastes too much memory.

Model parallelism solves a memory problem first. Efficient model parallelism also solves a scheduling problem. The challenge is to partition the model so that each device stores less, computes enough, and communicates as little as possible.