Scaling Transformers

Scaling a transformer means increasing its capacity, data exposure, context length, training compute, or serving throughput. In practice, scaling is controlled by several coupled variables: number of parameters, number of training tokens, model dimension, depth, attention heads, sequence length, batch size, optimizer state, hardware memory, and inference latency.

A transformer can be scaled in many ways, but useful scaling is constrained by compute, memory, data quality, and optimization stability.

What Scaling Means

A transformer has several main size dimensions:

Dimension   Meaning
D           Hidden size or model dimension
L           Number of layers
h           Number of attention heads
D_ff        Feedforward hidden size
T           Context length
V           Vocabulary size
B           Batch size

Increasing any of these may improve model quality, but each has a cost.

Increasing D makes every token representation wider. Increasing L makes the model deeper. Increasing T lets the model read longer sequences. Increasing B improves hardware utilization and gradient estimates, but requires more memory.

Parameter Scaling

Most transformer parameters come from attention projections and feedforward layers.

For one layer, self-attention has query, key, value, and output projections. Each is approximately D × D, so attention parameters are roughly

4D^2.

The feedforward network usually has two projections:

D \rightarrow D_{\text{ff}}, \quad D_{\text{ff}} \rightarrow D.

This gives roughly

2DD_{\text{ff}}

parameters.

If D_ff = 4D, then the feedforward block has about

8D^2

parameters.

So one transformer layer has approximately

12D^2

main parameters.

For L layers, the transformer stack has approximately

12LD^2

parameters, ignoring embeddings, layer norms, and small biases.

The embedding table contributes

VD

parameters.

Thus, an approximate decoder-only transformer parameter count is

N \approx 12LD^2 + VD.

This estimate is useful for understanding why width is expensive. Doubling D roughly quadruples the parameters in the transformer stack.
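
As a rough sanity check, the estimate can be computed directly. The sketch below is a minimal Python version of the formula above; the function name and the example sizes are illustrative, not taken from any particular published model.

def approx_param_count(num_layers, d_model, vocab_size):
    # N ≈ 12 * L * D^2 + V * D, ignoring layer norms, biases, and any untied output head
    return 12 * num_layers * d_model ** 2 + vocab_size * d_model

# Doubling D roughly quadruples the stack term:
print(approx_param_count(32, 2048, 50_000) / 1e9)   # ~1.7 billion parameters
print(approx_param_count(32, 4096, 50_000) / 1e9)   # ~6.6 billion parameters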

Compute Scaling

Training compute is dominated by matrix multiplications. A rough estimate for dense transformer training compute is proportional to

\text{compute} \propto N \times \text{training tokens}.

For each token, the model performs forward computation and backward computation. Larger models cost more per token. More training data costs more total steps.
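
A widely used rule of thumb, an approximation rather than an exact constant, is that dense training costs about 6 floating-point operations per parameter per token: roughly 2 for the forward pass and 4 for the backward pass. A minimal sketch under that assumption:

def approx_train_flops(n_params, n_tokens):
    # ~2 FLOPs per parameter per token forward + ~4 backward
    return 6 * n_params * n_tokens

# Illustrative numbers: a 7e9-parameter model trained on 1e12 tokens
print(approx_train_flops(7e9, 1e12))   # ~4.2e22 FLOPs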

Sequence length also affects compute. Self-attention forms a score matrix of size

T \times T.

The attention cost grows as

O(T^2 D).

The feedforward cost grows as

O(T D D_{\text{ff}}).

For moderate sequence lengths, feedforward layers often dominate compute. For long contexts, attention becomes a major bottleneck.
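
The crossover point can be estimated from these two terms alone. The sketch below compares the per-layer scaling terms T^2 D (attention scores) and T D D_ff (feedforward), ignoring constant factors and the projection costs inside attention; the sizes are illustrative.

def attention_score_term(T, D):
    return T * T * D          # score-matrix term, up to constant factors

def feedforward_term(T, D, D_ff):
    return T * D * D_ff       # two projections D -> D_ff -> D, up to constant factors

D = 4096
D_ff = 4 * D
for T in (2_048, 16_384, 131_072):
    print(T, attention_score_term(T, D) / feedforward_term(T, D, D_ff))
# The ratio is T / D_ff, so with D_ff = 4D the attention term overtakes
# the feedforward term once T grows past roughly 4D.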

Memory Scaling

Training memory contains several parts:

Component           Description
Parameters          Model weights
Gradients           One gradient tensor per trainable parameter
Optimizer state     Extra tensors such as Adam moments
Activations         Intermediate tensors needed for backpropagation
Attention matrices  T × T tensors for attention
Temporary buffers   Workspace used by kernels and communication

Adam usually stores two moment tensors per parameter. In standard full-precision accounting, each trainable parameter may require memory for the parameter, gradient, first moment, and second moment.

Mixed precision changes the exact accounting, but optimizer state remains a major memory cost.
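
As a rough illustration of why optimizer state dominates, the sketch below counts bytes for parameters, gradients, and the two Adam moments under a simple all-float32 accounting; real mixed-precision recipes store some of these tensors in 16-bit formats plus float32 master copies, so treat the byte counts as assumptions.

def adam_state_gb(n_params, bytes_per_value=4):
    # parameters + gradients + Adam first moment + Adam second moment
    tensors_per_param = 4
    return n_params * tensors_per_param * bytes_per_value / 1e9

# A 7e9-parameter model needs roughly 112 GB for this state alone,
# before any activations or attention buffers are counted.
print(adam_state_gb(7e9))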

Activation memory grows with batch size, sequence length, hidden size, and layer count:

O(BTLD).

Attention memory can grow as

O(B h T^2).

This is why long-context training can become memory-bound even when parameter count is moderate.
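
To see how fast the T^2 term grows, the sketch below estimates the memory of one layer's attention score matrices if they are fully materialized, assuming 2 bytes per value (fp16 or bfloat16). Fused kernels such as FlashAttention avoid materializing this matrix, so this is a worst-case figure.

def score_matrix_gb(batch, heads, T, bytes_per_value=2):
    # batch * heads * T * T attention scores for a single layer
    return batch * heads * T * T * bytes_per_value / 1e9

for T in (4_096, 32_768):
    print(T, score_matrix_gb(batch=8, heads=32, T=T))
# ~8.6 GB per layer at a 4k context, ~550 GB per layer at 32k.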

Depth Versus Width

A transformer can be scaled by adding layers or by increasing hidden size.

Increasing depth adds more sequential transformations. Deeper models can represent more stages of computation, but they are harder to optimize and less parallel across layers.

Increasing width gives each token representation more dimensions. Wider models increase matrix multiplication sizes and parameter count rapidly.

Common scaling patterns increase both depth and width. For stable design, the feedforward dimension and the number of attention heads are usually scaled with D.

A typical rule is:

D_{\text{ff}} \approx 4D.

The number of heads h is chosen so that each head has a reasonable dimension:

d_h = \frac{D}{h}.

Common head dimensions include 64 and 128. For example, a model with D = 4096 and a head dimension of 128 uses h = 32 heads.

Scaling Sequence Length

Longer context lets the model condition on more information. This is useful for long documents, code repositories, conversations, retrieval-augmented systems, and agent traces.

But increasing context length is expensive. If sequence length doubles, the attention matrix size increases by about four times.

For full attention:

T \rightarrow 2T

implies

T^2 \rightarrow 4T^2.

This affects both memory and compute.

Long-context models often use one or more of the following:

Method                    Purpose
RoPE scaling              Extend usable context length
Sliding-window attention  Reduce attention span per token
Sparse attention          Attend to selected positions
Global tokens             Provide long-range communication
Retrieval                 Move some memory outside the context
KV cache compression      Reduce inference memory
Chunking                  Process long inputs in segments

Long context should be treated as a systems problem, not only as an architectural change.

Data Scaling

More parameters help only when the model sees enough useful data. A large model trained on too little data may overfit or underuse its capacity. A small model trained on too much data may become compute-inefficient.

Data scaling has three dimensions:

Dimension  Meaning
Quantity   Number of training tokens or examples
Quality    Accuracy, diversity, deduplication, filtering
Mixture    Balance among domains, languages, tasks, and formats

For language models, data quality can matter as much as raw token count. Duplicate, low-quality, or contaminated data can waste compute and harm evaluation reliability.

For vision and multimodal models, scaling data includes image quality, caption quality, label consistency, resolution, and domain diversity.

Scaling Laws

Scaling laws describe empirical relationships between model size, dataset size, compute, and loss.

A simplified view is:

\text{loss} \approx a N^{-\alpha} + b D_{\text{data}}^{-\beta} + c C^{-\gamma} + \epsilon.

Here N is model size, D_data is dataset size, and C is compute. The exponents describe how quickly loss improves as each resource increases.

The practical message is simple: model size, data size, and compute must be balanced. Scaling only one axis eventually gives diminishing returns.

Scaling laws are empirical, not universal mathematical laws. They depend on architecture, data distribution, optimizer, training recipe, and evaluation target.
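
The additive form above can be written directly as a function. The sketch below is purely illustrative: the coefficients and exponents must be fit empirically for a specific setup, so they are left as arguments rather than filled in with invented values.

def scaling_law_loss(N, D_data, C, *, a, alpha, b, beta, c, gamma, eps):
    # Simplified additive form: each resource reduces only its own term.
    return a * N ** -alpha + b * D_data ** -beta + c * C ** -gamma + eps

# With D_data and C held fixed, increasing N shrinks only the first term,
# so the loss flattens toward b * D_data**-beta + c * C**-gamma + eps.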

Compute-Optimal Training

Compute-optimal training asks: given a fixed compute budget, how large should the model be and how many tokens should it see?

A model that is too large may consume most of the budget in expensive updates but see too few tokens. A model that is too small may see many tokens but lack enough capacity.

Compute-optimal scaling balances parameter count and training tokens.
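
One commonly cited heuristic from compute-optimal scaling studies (the Chinchilla line of work) is to train with on the order of 20 tokens per parameter. Combined with the compute ≈ 6 · N · tokens approximation used earlier, a fixed budget can be split as sketched below; the ratio is a rule of thumb, not a law, and the best value depends on data quality and on how much inference cost matters.

import math

def compute_optimal_split(flops_budget, tokens_per_param=20):
    # compute ≈ 6 * N * tokens and tokens ≈ r * N  =>  N ≈ sqrt(budget / (6 * r))
    n_params = math.sqrt(flops_budget / (6 * tokens_per_param))
    return n_params, tokens_per_param * n_params

n, tokens = compute_optimal_split(1e23)              # illustrative budget
print(f"{n:.2e} parameters, {tokens:.2e} tokens")    # ~2.9e10 and ~5.8e11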

For practical training, this means:

Bad allocation                Problem
Huge model, few tokens        Undertrained model
Tiny model, excessive tokens  Capacity bottleneck
Large context too early       Expensive training with limited benefit
Poor data quality             Wasted compute

A good scaling plan decides model size, token count, batch size, sequence length, and data mixture together.

Batch Size and Optimization

Larger batch sizes improve hardware utilization and reduce gradient noise, but they also change optimization behavior.

The effective batch size is often measured in tokens:

B_{\text{tokens}} = B \times T.

In distributed training, the global batch size is

B_{\text{global}} = B_{\text{per device}} \times \text{number of devices}.

Gradient accumulation can increase the effective batch size without increasing per-device memory:

# Each micro-batch contributes gradients; one optimizer step is taken
# after grad_accum_steps micro-batches, as if they formed one large batch.
optimizer.zero_grad()

for micro_batch in micro_batches:
    loss = model_loss(micro_batch)
    loss = loss / grad_accum_steps   # average the loss over the accumulation window
    loss.backward()                  # gradients accumulate in param.grad across micro-batches

optimizer.step()

Large batch training usually requires learning rate tuning. Too large a batch can reduce generalization or cause optimization instability if the learning rate schedule is poorly chosen.

Learning Rate Scaling

When scaling batch size, the learning rate often needs adjustment.

A common heuristic is linear scaling:

\eta_{\text{new}} = \eta_{\text{old}} \frac{B_{\text{new}}}{B_{\text{old}}}.

This is only a heuristic. Transformers are sensitive to warmup, decay schedule, optimizer settings, initialization, and normalization.

A common schedule uses warmup followed by cosine decay:

import math

def cosine_lr(step, *, warmup_steps, total_steps, max_lr, min_lr=0.0):
    # Linear warmup from 0 to max_lr over the first warmup_steps updates.
    if step < warmup_steps:
        return max_lr * step / warmup_steps

    # Cosine decay from max_lr to min_lr over the remaining steps.
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    progress = min(progress, 1.0)  # hold at min_lr if training runs past total_steps
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + (max_lr - min_lr) * cosine

Warmup prevents early large updates when the model is poorly calibrated. Decay reduces update size later in training.

Mixed Precision Scaling

Large transformers are usually trained with mixed precision. Instead of storing and computing everything in float32, many operations use float16 or bfloat16.

Mixed precision reduces memory use and improves throughput on modern accelerators.

Common choices:

Type      Notes
float32   Stable but expensive
float16   Fast, lower dynamic range
bfloat16  Wider dynamic range, common for large training
float8    Emerging for very large systems

In PyTorch, automatic mixed precision can be used as follows:

scaler = torch.cuda.amp.GradScaler()

for batch in loader:
    optimizer.zero_grad()

    with torch.cuda.amp.autocast():
        loss = model_loss(batch)

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

For bfloat16, gradient scaling is often unnecessary:

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    loss = model_loss(batch)

loss.backward()
optimizer.step()

Exact support depends on hardware.

Activation Checkpointing

Activation checkpointing reduces memory by not storing all intermediate activations. Instead, selected activations are recomputed during the backward pass.

This trades compute for memory.

In PyTorch:

from torch.utils.checkpoint import checkpoint

def forward(self, x):
    for layer in self.layers:
        x = checkpoint(layer, x, use_reentrant=False)
    return x

Activation checkpointing is useful when model parameters fit in memory but training activations do not.

The cost is extra forward computation during backpropagation.

Distributed Scaling

Large transformer training usually requires distributed computation.

Common parallelism methods include:

Method                Split                             Main benefit
Data parallelism      Batch                             Simple scaling
Tensor parallelism    Matrix operations                 Fits wider layers
Pipeline parallelism  Layers                            Fits deeper models
Sequence parallelism  Sequence dimension                Helps long context
Expert parallelism    Experts in MoE layers             Scales sparse models
ZeRO-style sharding   Optimizer, gradients, parameters  Reduces memory per device

Distributed training must handle communication cost. Adding devices only helps if computation dominates communication enough to maintain good utilization.
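
As a minimal example of the simplest method, data parallelism, the sketch below shows a typical PyTorch DistributedDataParallel setup launched with torchrun. The model constructor, data loader, and loss helper are placeholders as in the earlier sketches, and details such as the NCCL backend are assumptions rather than a complete recipe.

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Launched with: torchrun --nproc_per_node=<num_gpus> train.py
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = build_model().cuda(local_rank)         # placeholder model constructor
model = DDP(model, device_ids=[local_rank])    # gradients are all-reduced across ranks

for batch in loader:                           # each rank reads a different shard of the data
    optimizer.zero_grad()
    loss = model_loss(batch)                   # placeholder loss helper, as in earlier examples
    loss.backward()                            # DDP synchronizes gradients during backward()
    optimizer.step()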

Scaling Inference

Training and inference have different bottlenecks.

During inference, batch size may be small, latency matters, and generation is sequential. Decoder-only models also store KV caches for all active sequences.

The KV cache memory grows with:

O(B L T D).

For long-context generation, KV cache memory can dominate parameter memory.
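
A rough estimate follows directly from this formula: each layer stores keys and values of width D for every token of every active sequence. The sketch below assumes 2 bytes per value (fp16 or bfloat16) and full-width keys and values, i.e. no grouped-query or multi-query attention, so it is an upper-end estimate.

def kv_cache_gb(batch, layers, T, D, bytes_per_value=2):
    # keys + values: 2 tensors of shape (batch, T, D) per layer
    return 2 * batch * layers * T * D * bytes_per_value / 1e9

# Illustrative: 32 active sequences of 32k tokens on a 32-layer, D = 4096 model.
print(kv_cache_gb(batch=32, layers=32, T=32_768, D=4096))   # ~550 GB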

Inference scaling techniques include:

Technique                  Purpose
KV caching                 Avoid recomputing previous tokens
Continuous batching        Improve serving throughput
Quantization               Reduce memory and bandwidth
Speculative decoding       Reduce latency
Tensor parallel inference  Split large models across devices
Prefix caching             Reuse prompt computation
Paged attention            Manage KV memory efficiently

A model that is easy to train may still be expensive to serve. Deployment constraints should influence scaling decisions early.

Mixture-of-Experts Scaling

Mixture-of-Experts models increase parameter count without activating all parameters for every token.

An MoE feedforward layer contains multiple experts. A router selects a small number of experts for each token.

If there are E experts and each token uses k of them, the total parameter count can grow with E, while compute per token grows mostly with k.

This gives sparse scaling:

\text{total parameters} \gg \text{active parameters per token}.

MoE models can improve quality for a given compute budget, but they introduce routing instability, load balancing losses, communication overhead, and serving complexity.
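
A minimal sketch of the routing step described above, assuming a plain softmax router with top-k selection. Real MoE layers add capacity limits, auxiliary load-balancing losses, and expert-parallel communication, all omitted here.

import torch
import torch.nn.functional as F

def route_tokens(x, router_weight, k=2):
    # x: (tokens, D), router_weight: (E, D); returns top-k expert ids and mixing weights
    logits = x @ router_weight.t()                   # (tokens, E) router scores
    probs = F.softmax(logits, dim=-1)
    topk_probs, topk_ids = probs.topk(k, dim=-1)     # pick k experts per token
    topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)
    return topk_ids, topk_probs

# Each token is processed by only its k selected experts, so compute per token
# scales with k while total parameters scale with the number of experts E.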

Practical Scaling Checklist

Before scaling a transformer, check the smaller model carefully.

Question                                                      Why it matters
Does the loss decrease smoothly?                              Confirms basic optimization
Are masks correct?                                            Prevents leakage and padding errors
Are gradients finite?                                         Detects instability
Is data loading fast enough?                                  Prevents accelerator idle time
Is validation reliable?                                       Avoids scaling the wrong metric
Is memory dominated by parameters, activations, or KV cache?  Determines optimization strategy
Is the model undertrained or capacity-limited?                Guides parameter-data balance

Scaling amplifies mistakes. A masking bug, data leak, unstable learning rate, or poor tokenizer can waste large amounts of compute.

Summary

Scaling transformers involves more than increasing parameter count. The main scaling axes are model size, training tokens, context length, batch size, data quality, and hardware parallelism.

Parameter count grows roughly with LD^2. Attention memory grows roughly with T^2. Training memory includes parameters, gradients, optimizer states, activations, and attention buffers. Inference memory is often dominated by the KV cache for long sequences.

Useful scaling requires balance. A good scaling recipe coordinates architecture, data, optimizer, precision, distributed training, and serving constraints.