Scaling a transformer means increasing its capacity, data exposure, context length, training compute, or serving throughput. In practice, scaling is controlled by several coupled variables: number of parameters, number of training tokens, model dimension, depth, attention heads, sequence length, batch size, optimizer state, hardware memory, and inference latency.
A transformer can be scaled in many ways, but useful scaling is constrained by compute, memory, data quality, and optimization stability.
What Scaling Means
A transformer has several main size dimensions:
| Dimension | Meaning |
|---|---|
| Hidden size or model dimension | Width $d_{\text{model}}$ of each token representation |
| Number of layers | Depth $L$ of the transformer stack |
| Number of attention heads | Count $h$ of parallel attention heads per layer |
| Feedforward hidden size | Inner width $d_{\text{ff}}$ of the feedforward block |
| Context length | Maximum sequence length $n$ the model attends over |
| Vocabulary size | Number $V$ of distinct tokens in the embedding table |
| Batch size | Number $B$ of sequences processed per step |
Increasing any of these may improve model quality, but each has a cost.
Increasing $d_{\text{model}}$ makes every token representation wider. Increasing $L$ makes the model deeper. Increasing $n$ lets the model read longer sequences. Increasing $B$ improves hardware utilization and gradient estimates, but requires more memory.
Parameter Scaling
Most transformer parameters come from attention projections and feedforward layers.
For one layer, self-attention has query, key, value, and output projections. Each is approximately $d_{\text{model}} \times d_{\text{model}}$, so attention parameters are roughly

$$4 \, d_{\text{model}}^2$$

The feedforward network usually has two projections, $d_{\text{model}} \times d_{\text{ff}}$ and $d_{\text{ff}} \times d_{\text{model}}$. This gives roughly

$$2 \, d_{\text{model}} \, d_{\text{ff}}$$

parameters.

If $d_{\text{ff}} = 4 \, d_{\text{model}}$, then the feedforward block has about

$$8 \, d_{\text{model}}^2$$

parameters.

So one transformer layer has approximately

$$12 \, d_{\text{model}}^2$$

main parameters.
For $L$ layers, the transformer stack has approximately

$$12 \, L \, d_{\text{model}}^2$$

parameters, ignoring embeddings, layer norms, and small biases.

The embedding table contributes

$$V \, d_{\text{model}}$$

parameters.

Thus, an approximate decoder-only transformer parameter count is

$$P \approx 12 \, L \, d_{\text{model}}^2 + V \, d_{\text{model}}$$
This estimate is useful for understanding why width is expensive. Doubling $d_{\text{model}}$ roughly quadruples the parameters in the transformer stack.
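As a sanity check, this estimate fits in a few lines of Python. The function below is a sketch of the approximation above; the name and defaults are illustrative, and real models add norms, biases, and sometimes untied output embeddings.

```python
def approx_param_count(d_model, n_layers, vocab_size, d_ff=None):
    """Approximate decoder-only parameters: 12 * L * d_model^2 + V * d_model."""
    if d_ff is None:
        d_ff = 4 * d_model                  # common default ratio
    attn = 4 * d_model * d_model            # Q, K, V, and output projections
    ffn = 2 * d_model * d_ff                # up- and down-projections
    embed = vocab_size * d_model            # token embedding table
    return n_layers * (attn + ffn) + embed
```

For example, `approx_param_count(4096, 32, 32000)` gives roughly 6.6 billion, in the range of common 7B-class models.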
Compute Scaling
Training compute is dominated by matrix multiplications. A rough estimate for dense transformer training compute is

$$C \approx 6 \, N \, D$$

where $N$ is the parameter count and $D$ is the number of training tokens. For each token, the forward pass costs roughly $2N$ FLOPs and the backward pass roughly $4N$ FLOPs. Larger models cost more per token. More training data costs more total steps.
Sequence length also affects compute. Self-attention forms a score matrix of size

$$n \times n$$

The attention cost grows as

$$O(n^2 \, d_{\text{model}})$$

The feedforward cost grows as

$$O(n \, d_{\text{model}} \, d_{\text{ff}})$$
For moderate sequence lengths, feedforward layers often dominate compute. For long contexts, attention becomes a major bottleneck.
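The crossover can be made concrete with a rough FLOP count per layer. The function below is a sketch using the 2-FLOPs-per-multiply-add convention; the name and exact constants are illustrative.

```python
def per_layer_flops(n, d_model, d_ff):
    """Approximate forward FLOPs for one layer on one sequence of length n."""
    attn_proj = 8 * n * d_model * d_model   # four d_model x d_model projections
    attn_scores = 4 * n * n * d_model       # QK^T scores plus attention-weighted values
    ffn = 4 * n * d_model * d_ff            # two feedforward matmuls
    return {"attention": attn_proj + attn_scores, "feedforward": ffn}
```

With $d_{\text{ff}} = 4 \, d_{\text{model}}$, the quadratic score term overtakes the feedforward term only once $n$ exceeds about $4 \, d_{\text{model}}$.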
Memory Scaling
Training memory contains several parts:
| Component | Description |
|---|---|
| Parameters | Model weights |
| Gradients | One gradient tensor per trainable parameter |
| Optimizer state | Extra tensors such as Adam moments |
| Activations | Intermediate tensors needed for backpropagation |
| Attention matrices | $n \times n$ score tensors, per head and batch element |
| Temporary buffers | Workspace used by kernels and communication |
Adam usually stores two moment tensors per parameter. In standard full-precision accounting, each trainable parameter may require memory for the parameter, gradient, first moment, and second moment, or about 16 bytes per parameter at 4 bytes each.
Mixed precision changes the exact accounting, but optimizer state remains a major memory cost.
Activation memory grows with batch size, sequence length, hidden size, and layer count:

$$O(B \cdot n \cdot d_{\text{model}} \cdot L)$$

Attention memory can grow as

$$O(B \cdot h \cdot n^2)$$

per layer when full score matrices are materialized.
This is why long-context training can become memory-bound even when parameter count is moderate.
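The steady components (parameters, gradients, Adam moments) are easy to estimate. This sketch leaves out activation and attention memory, which depend heavily on implementation details such as checkpointing and fused kernels:

```python
def optimizer_memory_gb(n_params, bytes_per_value=4):
    """GB for parameters + gradients + two Adam moments, each at bytes_per_value."""
    return 4 * n_params * bytes_per_value / 1e9
```

In full precision this is 16 bytes per parameter: `optimizer_memory_gb(7e9)` is about 112 GB before any activations, which is why mixed precision and ZeRO-style sharding matter.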
Depth Versus Width
A transformer can be scaled by adding layers or by increasing hidden size.
Increasing depth adds more sequential transformations. Deeper models can represent more stages of computation, but they are harder to optimize and less parallel across layers.
Increasing width gives each token representation more dimensions. Wider models increase matrix multiplication sizes and parameter count rapidly.
Common scaling patterns increase both depth and width. For stable design, the feedforward dimension and number of attention heads are usually scaled with $d_{\text{model}}$.

A typical rule is:

$$d_{\text{ff}} = 4 \, d_{\text{model}}$$

The number of heads is chosen so that each head has a reasonable dimension:

$$d_{\text{head}} = \frac{d_{\text{model}}}{h}$$
Common head dimensions include 64 and 128.
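These rules can be bundled into a small configuration helper; the function and its defaults are illustrative, not a standard API:

```python
def make_config(d_model, n_layers, head_dim=128):
    """Derive dependent sizes from d_model using the common rules above."""
    assert d_model % head_dim == 0, "d_model must be divisible by head_dim"
    return {
        "d_model": d_model,
        "n_layers": n_layers,
        "d_ff": 4 * d_model,               # typical feedforward ratio
        "n_heads": d_model // head_dim,    # keeps per-head dimension fixed
    }
```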
Scaling Sequence Length
Longer context lets the model condition on more information. This is useful for long documents, code repositories, conversations, retrieval-augmented systems, and agent traces.
But increasing context length is expensive. If sequence length doubles, the attention matrix size increases by about four times.
For full attention,

$$n \to 2n \quad \Rightarrow \quad n^2 \to 4 n^2$$
This affects both memory and compute.
Long-context models often use one or more of the following:
| Method | Purpose |
|---|---|
| RoPE scaling | Extend usable context length |
| Sliding-window attention | Reduce attention span per token |
| Sparse attention | Attend to selected positions |
| Global tokens | Provide long-range communication |
| Retrieval | Move some memory outside the context |
| KV cache compression | Reduce inference memory |
| Chunking | Process long inputs in segments |
Long context should be treated as a systems problem, not only as an architectural change.
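As one concrete example, sliding-window attention can be expressed as a mask. This minimal PyTorch sketch builds a causal mask that limits each token to the previous `window` positions:

```python
import torch

def sliding_window_mask(n, window):
    """Boolean [n, n] mask: True where attention is allowed."""
    pos = torch.arange(n)
    offset = pos[:, None] - pos[None, :]       # query index minus key index
    return (offset >= 0) & (offset < window)   # causal, limited to `window` tokens
```

The memory and compute savings appear only when the kernel skips masked positions, reducing live scores from $n^2$ to roughly $n \times \text{window}$ entries.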
Data Scaling
More parameters help only when the model sees enough useful data. A large model trained on too little data may overfit or underuse its capacity. A small model trained on too much data may become compute-inefficient.
Data scaling has three dimensions:
| Dimension | Meaning |
|---|---|
| Quantity | Number of training tokens or examples |
| Quality | Accuracy, diversity, deduplication, filtering |
| Mixture | Balance among domains, languages, tasks, and formats |
For language models, data quality can matter as much as raw token count. Duplicate, low-quality, or contaminated data can waste compute and harm evaluation reliability.
For vision and multimodal models, scaling data includes image quality, caption quality, label consistency, resolution, and domain diversity.
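As one small illustration of the quality dimension, exact deduplication can be done by hashing normalized text. Near-duplicate detection needs more machinery (for example MinHash), which this sketch does not attempt:

```python
import hashlib

def dedup_exact(texts):
    """Keep the first occurrence of each exact (normalized) duplicate."""
    seen, unique = set(), []
    for t in texts:
        key = hashlib.sha256(t.strip().lower().encode("utf-8")).hexdigest()
        if key not in seen:
            seen.add(key)
            unique.append(t)
    return unique
```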
Scaling Laws
Scaling laws describe empirical relationships between model size, dataset size, compute, and loss.
A simplified view is:

$$L(N) \propto N^{-\alpha}, \qquad L(D) \propto D^{-\beta}, \qquad L(C) \propto C^{-\gamma}$$

Here $N$ is model size, $D$ is dataset size, and $C$ is compute. The exponents $\alpha$, $\beta$, and $\gamma$ describe how quickly loss improves as each resource increases.
The practical message is simple: model size, data size, and compute must be balanced. Scaling only one axis eventually gives diminishing returns.
Scaling laws are empirical, not universal mathematical laws. They depend on architecture, data distribution, optimizer, training recipe, and evaluation target.
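To make the shape of these curves concrete, here is a toy power law with invented constants; it illustrates diminishing returns, not any measured system:

```python
def powerlaw_loss(n_params, a=10.0, alpha=0.1, floor=1.8):
    """Toy curve: loss = floor + a * N^(-alpha). Constants are illustrative only."""
    return floor + a * n_params ** (-alpha)
```

Each doubling of $N$ multiplies the power-law term by $2^{-\alpha} \approx 0.93$ here, so equal loss reductions require geometrically more resources.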
Compute-Optimal Training
Compute-optimal training asks: given a fixed compute budget, how large should the model be and how many tokens should it see?
A model that is too large may consume most of the budget in expensive updates but see too few tokens. A model that is too small may see many tokens but lack enough capacity.
Compute-optimal scaling balances parameter count and training tokens; one widely cited empirical result (Chinchilla) suggests roughly 20 training tokens per parameter.
For practical training, this means:
| Bad allocation | Problem |
|---|---|
| Huge model, few tokens | Undertrained model |
| Tiny model, excessive tokens | Capacity bottleneck |
| Large context too early | Expensive training with limited benefit |
| Poor data quality | Wasted compute |
A good scaling plan decides model size, token count, batch size, sequence length, and data mixture together.
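One way to sketch this balance is to combine the training-compute estimate $C \approx 6ND$ with a tokens-per-parameter ratio (roughly 20 in the Chinchilla result). The function below is a heuristic back-of-envelope, not a substitute for scaling experiments:

```python
import math

def compute_optimal_split(flops_budget, tokens_per_param=20.0):
    """With C ~ 6*N*D and D = rho*N, solve C = 6*rho*N^2 for N and D."""
    n_params = math.sqrt(flops_budget / (6.0 * tokens_per_param))
    n_tokens = tokens_per_param * n_params
    return n_params, n_tokens
```

A budget of $10^{23}$ FLOPs gives roughly $N \approx 2.9 \times 10^{10}$ parameters and $D \approx 5.8 \times 10^{11}$ tokens under these assumptions.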
Batch Size and Optimization
Larger batch sizes improve hardware utilization and reduce gradient noise, but they also change optimization behavior.
The effective batch size is often measured in tokens:

$$\text{tokens per step} = B \times n$$

In distributed training, the global batch size is

$$B_{\text{global}} = B_{\text{per-device}} \times (\text{number of devices}) \times (\text{accumulation steps})$$

Gradient accumulation can increase the effective batch size without increasing per-device memory:
```python
optimizer.zero_grad()
for micro_batch in micro_batches:
    loss = model_loss(micro_batch)
    loss = loss / grad_accum_steps  # average so the update matches one large batch
    loss.backward()                 # gradients accumulate across micro-batches
optimizer.step()
```

Large batch training usually requires learning rate tuning. Too large a batch can reduce generalization or cause optimization instability if the learning rate schedule is poorly chosen.
Learning Rate Scaling
When scaling batch size, the learning rate often needs adjustment.
A common heuristic is linear scaling:

$$\eta_{\text{new}} = \eta_{\text{base}} \times \frac{B_{\text{new}}}{B_{\text{base}}}$$
This is only a heuristic. Transformers are sensitive to warmup, decay schedule, optimizer settings, initialization, and normalization.
A common schedule uses warmup followed by cosine decay:
```python
import math

def cosine_lr(step, *, warmup_steps, total_steps, max_lr, min_lr=0.0):
    # Linear warmup from 0 to max_lr
    if step < warmup_steps:
        return max_lr * step / warmup_steps
    # Cosine decay from max_lr down to min_lr
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + (max_lr - min_lr) * cosine
```

Warmup prevents early large updates when the model is poorly calibrated. Decay reduces update size later in training.
Mixed Precision Scaling
Large transformers are usually trained with mixed precision. Instead of storing and computing everything in float32, many operations use float16 or bfloat16.
Mixed precision reduces memory use and improves throughput on modern accelerators.
Common choices:
| Type | Notes |
|---|---|
| float32 | Stable but expensive |
| float16 | Fast, lower dynamic range |
| bfloat16 | Wider dynamic range, common for large training |
| float8 | Emerging for very large systems |
In PyTorch, automatic mixed precision can be used as follows:
```python
import torch

scaler = torch.cuda.amp.GradScaler()

for batch in loader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():    # run eligible ops in float16
        loss = model_loss(batch)
    scaler.scale(loss).backward()      # scale loss to avoid float16 gradient underflow
    scaler.step(optimizer)             # unscales gradients, then steps
    scaler.update()                    # adjusts the scale factor for the next step
```

For bfloat16, gradient scaling is often unnecessary:
```python
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    loss = model_loss(batch)
loss.backward()      # bfloat16 has float32-like range, so no GradScaler is needed
optimizer.step()
```

Exact support depends on hardware.
Activation Checkpointing
Activation checkpointing reduces memory by not storing all intermediate activations. Instead, selected activations are recomputed during the backward pass.
This trades compute for memory.
In PyTorch:
```python
from torch.utils.checkpoint import checkpoint

def forward(self, x):
    # Recompute each layer's activations in the backward pass instead of storing them
    for layer in self.layers:
        x = checkpoint(layer, x, use_reentrant=False)
    return x
```

Activation checkpointing is useful when model parameters fit in memory but training activations do not.
The cost is extra forward computation during backpropagation.
Distributed Scaling
Large transformer training usually requires distributed computation.
Common parallelism methods include:
| Method | Split | Main benefit |
|---|---|---|
| Data parallelism | Batch | Simple scaling |
| Tensor parallelism | Matrix operations | Fits wider layers |
| Pipeline parallelism | Layers | Fits deeper models |
| Sequence parallelism | Sequence dimension | Helps long context |
| Expert parallelism | Experts in MoE layers | Scales sparse models |
| ZeRO-style sharding | Optimizer, gradients, parameters | Reduces memory per device |
Distributed training must handle communication cost. Adding devices only helps if computation dominates communication enough to maintain good utilization.
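For reference, the simplest of these, data parallelism, takes only a few lines in PyTorch. This sketch assumes one process per GPU launched with `torchrun`, and `build_model` is a placeholder for your model constructor:

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend="nccl")        # reads rank/world size from torchrun env
local_rank = dist.get_rank() % torch.cuda.device_count()
torch.cuda.set_device(local_rank)

model = build_model().cuda(local_rank)         # build_model is a placeholder
model = DDP(model, device_ids=[local_rank])    # gradients are all-reduced during backward
```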
Scaling Inference
Training and inference have different bottlenecks.
During inference, batch size may be small, latency matters, and generation is sequential. Decoder-only models also store KV caches for all active sequences.
The KV cache memory grows with:

$$\text{KV cache bytes} = 2 \times B \times L \times n \times d_{\text{model}} \times \text{bytes per value}$$

where the factor of two counts both keys and values.
For long-context generation, KV cache memory can dominate parameter memory.
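A quick estimate of that memory, assuming standard multi-head attention (grouped-query attention stores fewer KV heads and shrinks this):

```python
def kv_cache_gb(n_layers, n_ctx, d_model, batch_size=1, bytes_per_value=2):
    """KV cache in GB: keys and values of shape [n_ctx, d_model] per layer."""
    return 2 * n_layers * n_ctx * d_model * batch_size * bytes_per_value / 1e9
```

At 32 layers, $d_{\text{model}} = 4096$, a 32k context, and 2-byte values, this is about 17 GB per sequence, comparable to the weights of a 7B-parameter model in the same precision.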
Inference scaling techniques include:
| Technique | Purpose |
|---|---|
| KV caching | Avoid recomputing previous tokens |
| Continuous batching | Improve serving throughput |
| Quantization | Reduce memory and bandwidth |
| Speculative decoding | Reduce latency |
| Tensor parallel inference | Split large models across devices |
| Prefix caching | Reuse prompt computation |
| Paged attention | Manage KV memory efficiently |
A model that is easy to train may still be expensive to serve. Deployment constraints should influence scaling decisions early.
Mixture-of-Experts Scaling
Mixture-of-Experts models increase parameter count without activating all parameters for every token.
An MoE feedforward layer contains multiple experts. A router selects a small number of experts for each token.
If there are $E$ experts and each token uses $k$ experts, the total parameter count can grow with $E$, while compute per token grows mostly with $k$.

This gives sparse scaling:

$$\text{parameters} \propto E, \qquad \text{compute per token} \propto k$$
MoE models can improve quality for a given compute budget, but they introduce routing instability, load balancing losses, communication overhead, and serving complexity.
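The routing step itself is small. This sketch shows its shape with `torch.topk`, omitting load-balancing losses and expert-capacity limits:

```python
import torch

def route_tokens(router_logits, k=2):
    """Select top-k experts per token and normalize their gate weights.

    router_logits: [num_tokens, num_experts]
    """
    weights, expert_ids = torch.topk(router_logits, k, dim=-1)
    weights = torch.softmax(weights, dim=-1)   # renormalize over the chosen experts
    return weights, expert_ids                 # each token is dispatched to k experts
```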
Practical Scaling Checklist
Before scaling a transformer, check the smaller model carefully.
| Question | Why it matters |
|---|---|
| Does the loss decrease smoothly? | Confirms basic optimization |
| Are masks correct? | Prevents leakage and padding errors |
| Are gradients finite? | Detects instability |
| Is data loading fast enough? | Prevents accelerator idle time |
| Is validation reliable? | Avoids scaling the wrong metric |
| Is memory dominated by parameters, activations, or KV cache? | Determines optimization strategy |
| Is the model undertrained or capacity-limited? | Guides parameter-data balance |
Scaling amplifies mistakes. A masking bug, data leak, unstable learning rate, or poor tokenizer can waste large amounts of compute.
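Several of these checks are one-liners worth keeping in the training loop; for example, gradient finiteness:

```python
import torch

def grads_finite(model):
    """Return False if any gradient contains NaN or Inf."""
    return all(torch.isfinite(p.grad).all().item()
               for p in model.parameters() if p.grad is not None)
```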
Summary
Scaling transformers involves more than increasing parameter count. The main scaling axes are model size, training tokens, context length, batch size, data quality, and hardware parallelism.
Parameter count grows roughly with $12 \, L \, d_{\text{model}}^2$. Attention memory grows roughly with $n^2$. Training memory includes parameters, gradients, optimizer states, activations, and attention buffers. Inference memory is often dominated by the KV cache for long sequences.
Useful scaling requires balance. A good scaling recipe coordinates architecture, data, optimizer, precision, distributed training, and serving constraints.