# Population-Based Training

Population-based training, or PBT, is a hyperparameter optimization method that trains many models at the same time. Each model has its own weights and hyperparameters. During training, weak models are replaced or modified using information from stronger models.

Grid search, random search, and Bayesian optimization usually treat each trial as a separate run. A configuration is selected before training begins, and it usually stays fixed until the run ends. PBT changes this assumption. Hyperparameters can change while training is already in progress.

This makes PBT useful for hyperparameters whose best values vary over time, such as learning rate, weight decay, dropout, augmentation strength, entropy bonus, or reinforcement learning exploration rate.

## The Basic Idea

PBT keeps a population of workers. Each worker trains a model.

A worker has:

| Component | Meaning |
|---|---|
| Weights $\theta_i$ | Current model parameters |
| Hyperparameters $\lambda_i$ | Current training settings |
| Score $s_i$ | Current validation or reward metric |
| Checkpoint | Saved state for replacement |

At regular intervals, the population is evaluated. Strong workers are allowed to continue. Weak workers copy weights from stronger workers and then perturb their hyperparameters.

The algorithm alternates between two phases:

| Phase | Purpose |
|---|---|
| Explore | Modify hyperparameters |
| Exploit | Copy from better-performing workers |

This creates an evolutionary process over both model states and hyperparameters.

## Why PBT Is Different

In ordinary hyperparameter search, each trial trains from scratch. If a trial starts with a poor learning rate, the entire run may be wasted.

PBT can recover from poor choices. A weak worker may copy a better checkpoint and continue training with modified hyperparameters.

This means PBT searches over schedules, not just fixed values. For example, instead of choosing one learning rate for all training, PBT may discover that a high learning rate works early and a low learning rate works later.

A fixed configuration looks like:

$$\lambda = (\eta, \lambda_{\text{wd}}, p_{\text{drop}}).$$

A schedule is a function of training time:

$$\lambda(t) = (\eta(t), \lambda_{\text{wd}}(t), p_{\text{drop}}(t)).$$

PBT searches for useful schedules by adapting hyperparameters during training.
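As a toy illustration (all numbers hypothetical), the schedule PBT produces is not a function chosen up front; it is the piecewise record of the hyperparameter changes a worker experienced during training:

```python
# A fixed configuration: one value for the whole run (hypothetical number).
fixed = {"learning_rate": 1e-3}

# Under PBT, the effective schedule emerges from exploit/explore events.
# Each entry is (training step, value active from that step onward).
history = [
    (0,    1e-3),     # initial value
    (1000, 1.2e-3),   # an explore step multiplied it by 1.2
    (2000, 9.6e-4),   # a later explore step multiplied it by 0.8
]

def learning_rate_at(t, history):
    """Return the learning rate that was active at training step t."""
    active = history[0][1]
    for step, value in history:
        if step <= t:
            active = value
    return active

print(learning_rate_at(1500, history))  # 0.0012
```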

## Population State

Suppose the population has $N$ workers.

At time $t$, worker $i$ has state:

$$(\theta_i^{(t)}, \lambda_i^{(t)}, s_i^{(t)}).$$

Here:

| Symbol | Meaning |
|---|---|
| $\theta_i^{(t)}$ | Model weights for worker $i$ at time $t$ |
| $\lambda_i^{(t)}$ | Hyperparameters for worker $i$ at time $t$ |
| $s_i^{(t)}$ | Validation score for worker $i$ at time $t$ |

A worker trains locally for some number of steps, then reports its score. The population controller compares workers and decides whether any worker should be replaced.
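A minimal sketch of that per-worker record and the controller's ranking step, using illustrative names rather than any library's API:

```python
from dataclasses import dataclass
from typing import Dict, Optional

@dataclass
class WorkerState:
    worker_id: int
    weights: Dict[str, float]        # stand-in for theta_i
    hyperparams: Dict[str, float]    # lambda_i
    score: Optional[float] = None    # s_i, filled in after evaluation

def rank_population(population):
    """Sort workers from weakest to strongest by reported score."""
    return sorted(population, key=lambda w: w.score)

pop = [
    WorkerState(0, {}, {"learning_rate": 1e-3}, score=0.52),
    WorkerState(1, {}, {"learning_rate": 3e-3}, score=0.67),
    WorkerState(2, {}, {"learning_rate": 3e-4}, score=0.49),
]
print([w.worker_id for w in rank_population(pop)])  # [2, 0, 1]
```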

## Exploitation

Exploitation means copying from a stronger worker.

For example, after evaluation, workers are ranked by validation score. A worker in the bottom 20 percent may copy the checkpoint of a worker in the top 20 percent.
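This bottom/top-20-percent rule is a form of truncation selection. A sketch of that selection step (the fraction is an illustrative choice):

```python
def truncation_select(scores, frac=0.2):
    """Return (bottom, top) worker indices under truncation selection.

    Workers whose score falls in the bottom `frac` are candidates for
    replacement; workers in the top `frac` are candidates to copy from.
    """
    order = sorted(range(len(scores)), key=lambda i: scores[i])
    k = max(1, int(len(scores) * frac))
    return order[:k], order[-k:]

bottom, top = truncation_select([0.61, 0.74, 0.58, 0.80, 0.69])
print(bottom, top)  # [2] [3]
```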

If worker $j$ is strong and worker $i$ is weak, exploitation performs:

$$\theta_i \leftarrow \theta_j, \qquad \lambda_i \leftarrow \lambda_j.$$

This gives the weak worker a better starting point. It avoids spending more compute on a clearly poor trajectory.

In practice, exploitation copies:

| Item | Usually copied |
|---|---|
| Model weights | Yes |
| Optimizer state | Often yes |
| Learning rate scheduler state | Depends |
| Hyperparameters | Yes |
| Training step | Usually yes |
| Data loader state | Usually no |

Copying optimizer state can matter. For AdamW, the first- and second-moment buffers influence future updates. If weights are copied but optimizer state is not, the inherited weights are paired with stale buffers from the discarded trajectory, and the first updates after replacement can be erratic.
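A sketch of the replacement step with plain dictionaries standing in for real checkpoints (key names are illustrative):

```python
import copy

def exploit(weak, strong):
    """Overwrite a weak worker's state with a deep copy of a strong one's.

    Weights and optimizer state are copied together, so momentum buffers
    stay consistent with the copied parameters.
    """
    weak["model"] = copy.deepcopy(strong["model"])
    weak["optimizer"] = copy.deepcopy(strong["optimizer"])
    weak["config"] = dict(strong["config"])
    weak["step"] = strong["step"]

strong = {"model": {"w": 0.9}, "optimizer": {"momentum": 0.1},
          "config": {"learning_rate": 1e-3}, "step": 2000}
weak = {"model": {"w": 0.2}, "optimizer": {"momentum": -0.4},
        "config": {"learning_rate": 5e-2}, "step": 2000}

exploit(weak, strong)
print(weak["model"], weak["config"])  # {'w': 0.9} {'learning_rate': 0.001}
```

Deep copies matter here: the two workers must not share mutable state afterward, or a later update to one would silently change the other.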

## Exploration

Exploration means modifying copied hyperparameters.

After copying from a strong worker, the weak worker perturbs the inherited hyperparameters. This prevents all workers from becoming identical.

A simple perturbation rule is multiplicative:

$$\eta \leftarrow \eta \times r,$$

where

$$r \in \{0.8,\ 1.2\}.$$

For example, if the copied learning rate is $10^{-3}$, exploration may change it to $8 \times 10^{-4}$ or $1.2 \times 10^{-3}$.

For continuous hyperparameters, perturbation may use random noise:

$$\eta \leftarrow \eta \cdot \exp(\epsilon), \qquad \epsilon \sim \mathcal{N}(0, \sigma^2).$$
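The log-normal rule above takes only a few lines; the sketch below uses $\sigma = 0.2$ as an arbitrary illustrative choice:

```python
import math
import random

def perturb_lognormal(value, sigma=0.2, rng=None):
    """Multiply `value` by exp(eps), where eps ~ N(0, sigma^2)."""
    rng = rng or random.Random()
    return value * math.exp(rng.gauss(0.0, sigma))

rng = random.Random(0)  # seeded for reproducibility
lr = 1e-3
for _ in range(3):
    lr = perturb_lognormal(lr, sigma=0.2, rng=rng)
# lr drifts multiplicatively; it can never become zero or negative.
```

Multiplicative noise keeps the value on a log scale, which suits hyperparameters like learning rate that vary over orders of magnitude.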

For categorical hyperparameters, exploration may randomly resample from a set:

```python
optimizer = random.choice(["SGD", "AdamW"])
```

Exploration should respect valid ranges:

```python
learning_rate = min(max(learning_rate, 1e-5), 1e-1)
dropout = min(max(dropout, 0.0), 0.5)
```

## A Minimal PBT Algorithm

A simple PBT loop looks like this:

```python
import random

population = initialize_workers(num_workers)

for step in range(total_steps):
    for worker in population:
        worker.train(num_steps=steps_per_interval)

    for worker in population:
        worker.score = worker.evaluate()

    # Ascending order: weakest workers first.
    ranked = sort_by_score(population)

    bottom = ranked[:num_replace]
    top = ranked[-num_replace:]

    for weak_worker in bottom:
        strong_worker = random.choice(top)

        weak_worker.load_checkpoint(strong_worker.checkpoint)
        weak_worker.config = perturb(strong_worker.config)
```

This code omits many engineering details, but it shows the core pattern.

The key difference from random search is that a worker can inherit both weights and hyperparameters from another worker.

## PyTorch Worker Structure

A PBT worker needs to save and load complete training state.

```python
import torch

class Worker:
    def __init__(self, model, optimizer, config, device):
        self.model = model.to(device)
        self.optimizer = optimizer
        self.config = config
        self.device = device
        self.step = 0
        self.score = None

    def train(self, loader, num_steps):
        self.model.train()

        iterator = iter(loader)

        for _ in range(num_steps):
            try:
                x, y = next(iterator)
            except StopIteration:
                iterator = iter(loader)
                x, y = next(iterator)

            x = x.to(self.device)
            y = y.to(self.device)

            logits = self.model(x)
            loss = torch.nn.functional.cross_entropy(logits, y)

            self.optimizer.zero_grad(set_to_none=True)
            loss.backward()
            self.optimizer.step()

            self.step += 1

    @torch.no_grad()
    def evaluate(self, loader):
        self.model.eval()

        correct = 0
        total = 0

        for x, y in loader:
            x = x.to(self.device)
            y = y.to(self.device)

            logits = self.model(x)
            pred = logits.argmax(dim=1)

            correct += (pred == y).sum().item()
            total += y.numel()

        self.score = correct / total
        return self.score
```

The checkpoint should include model weights, optimizer state, configuration, and step:

```python
# These are methods of the Worker class above.
def state_dict(self):
    return {
        "model": self.model.state_dict(),
        "optimizer": self.optimizer.state_dict(),
        "config": self.config,
        "step": self.step,
        "score": self.score,
    }

def load_state_dict(self, state):
    self.model.load_state_dict(state["model"])
    self.optimizer.load_state_dict(state["optimizer"])
    self.config = dict(state["config"])
    self.step = state["step"]
    self.score = state["score"]
```

## Perturbing Hyperparameters

A perturbation function should know which hyperparameters are mutable.

```python
import copy
import random

def perturb(config):
    config = copy.deepcopy(config)

    for name in ["learning_rate", "weight_decay"]:
        if name in config:
            factor = random.choice([0.8, 1.2])
            config[name] *= factor

    if "dropout" in config:
        config["dropout"] += random.choice([-0.05, 0.05])
        config["dropout"] = min(max(config["dropout"], 0.0), 0.5)

    # Clamp only the keys that are present, mirroring the checks above.
    if "learning_rate" in config:
        config["learning_rate"] = min(max(config["learning_rate"], 1e-5), 1e-1)
    if "weight_decay" in config:
        config["weight_decay"] = min(max(config["weight_decay"], 1e-6), 1e-1)

    return config
```

If the optimizer uses the learning rate stored in param_groups, the optimizer must be updated after perturbation:

```python
def apply_config_to_optimizer(optimizer, config):
    for group in optimizer.param_groups:
        group["lr"] = config["learning_rate"]
        group["weight_decay"] = config["weight_decay"]
```

For architecture hyperparameters such as hidden dimension or number of layers, simple PBT cannot modify them after training starts because the parameter shapes would change. PBT is best suited for training hyperparameters and regularization settings.

## Choosing the Evaluation Interval

PBT requires choosing the interval between exploit-and-explore steps.

If the interval is too short, scores are noisy and workers may copy from models that are only temporarily ahead.

If the interval is too long, weak workers waste compute before being replaced.

A practical interval depends on the task:

| Task | Typical interval |
|---|---|
| Small image classification | Every few epochs |
| Large supervised training | Every few thousand steps |
| Reinforcement learning | Every fixed number of environment steps |
| Language model fine-tuning | Every validation checkpoint |
| Diffusion training | Less frequent, due to noisy metrics |

The interval should be long enough that validation scores contain useful signal.

## Metrics for Selection

PBT needs a scalar score for ranking workers.

For classification, this might be validation accuracy. For language modeling, it may be negative validation loss or negative perplexity. For reinforcement learning, it may be average episodic return.

When the objective has multiple terms, a scalar score can combine them:

$$s = \text{accuracy} - \alpha \cdot \text{latency} - \beta \cdot \text{memory}.$$

This allows PBT to optimize under deployment constraints.

The score should be stable enough to compare workers. If validation measurement is noisy, use moving averages or repeated evaluations.
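One way to implement the moving-average option is an exponential moving average over each worker's recent scores; this is a sketch, with α an illustrative smoothing factor to tune:

```python
def smoothed_score(history, alpha=0.5):
    """Exponential moving average of a worker's score history.

    Higher alpha weights recent scores more; lower alpha smooths harder.
    Returns None for an empty history.
    """
    avg = None
    for s in history:
        avg = s if avg is None else alpha * s + (1 - alpha) * avg
    return avg
```

Ranking workers on `smoothed_score` rather than the latest raw score makes a single lucky evaluation less likely to trigger a replacement.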

## Population Size

Population size controls diversity.

A small population is cheaper but explores fewer schedules. A large population explores more schedules but requires more hardware.

| Population size | Behavior |
|---|---|
| 4 to 8 | Minimal, useful for small experiments |
| 16 to 32 | Common practical range |
| 64 or more | Large-scale search |

PBT works best when workers run in parallel. If only one GPU is available, PBT loses much of its advantage because the population must be simulated sequentially.

## Strengths and Weaknesses

| Strengths | Weaknesses |
|---|---|
| Searches dynamic schedules | Requires parallel compute |
| Reuses partial training progress | More complex than random search |
| Can recover from poor early choices | Harder to reproduce exactly |
| Works well for RL and large training runs | Needs careful checkpoint management |
| Handles nonstationary hyperparameters | Less useful for architecture choices |

PBT is especially useful when the best hyperparameters change during training.

## When to Use PBT

PBT is appropriate when:

| Situation | Reason |
|---|---|
| Many workers can run in parallel | PBT is population-based |
| Hyperparameters should change over time | PBT discovers schedules |
| Training is long | Mid-training adaptation helps |
| Early bad choices are costly | Workers can recover |
| Reinforcement learning is involved | RL often has unstable dynamics |

PBT is less appropriate when training is cheap, when only a few trials are possible, or when the main choices are fixed architecture decisions.

## PBT and Learning Rate Schedules

Learning rate schedules are a natural fit for PBT. Instead of choosing a schedule manually, PBT can adapt the learning rate based on observed performance.

For example, one worker may keep a high learning rate longer and improve quickly. Another may reduce the learning rate earlier and generalize better. PBT can copy from the better trajectory and perturb it.

The resulting schedule may be irregular:

$$10^{-3} \rightarrow 1.2 \times 10^{-3} \rightarrow 9.6 \times 10^{-4} \rightarrow 7.7 \times 10^{-4} \rightarrow 9.2 \times 10^{-4}.$$

This kind of schedule may perform well, even if it looks less clean than cosine decay or step decay.

## Summary

Population-based training trains a population of models while adapting their hyperparameters during training. Weak workers copy weights and hyperparameters from stronger workers, then perturb those hyperparameters to continue exploration.

PBT combines training, selection, checkpoint reuse, and hyperparameter search into one process. It is useful for long-running training jobs, reinforcement learning, and hyperparameters whose best values change over time.

Its cost is engineering complexity. A reliable PBT system needs parallel workers, checkpoints, reproducible logging, stable evaluation, and careful handling of optimizer state.