Transfer Learning

Transfer learning reuses a model trained on one task as the starting point for another task. In image classification, this usually means taking a convolutional network or vision transformer trained on a large image dataset, replacing its final classifier, and fine-tuning it on a smaller target dataset.

The central idea is simple: early and middle layers learn reusable visual features. They detect edges, colors, textures, shapes, object parts, and higher-level patterns. A new task often needs different class labels, but it can still benefit from these learned representations.

Why Transfer Learning Works

A randomly initialized model starts with no useful visual features. It must learn low-level patterns and high-level decision boundaries from the target dataset. This requires more data, more compute, and more tuning.

A pretrained model already contains useful representations. Fine-tuning adapts those representations to the new task.

Let a pretrained model be written as

f_\theta(x) = g_\phi(h_\psi(x)).

Here h_\psi is the feature extractor and g_\phi is the classifier head. In transfer learning, we usually keep h_\psi, replace g_\phi, and train the new model on the target classes.

For a target dataset with K classes, the new classifier produces

z \in \mathbb{R}^{K}.

The feature extractor may be frozen, partially unfrozen, or fully fine-tuned.
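For a torchvision ResNet, this decomposition maps directly onto the model: everything up to the global pooling layer plays the role of h_\psi, and the final fully connected layer is g_\phi. The following minimal sketch makes that split explicit (the Sequential wrapper and the flatten call are illustrative, not a required API):

import torch
import torch.nn as nn
from torchvision import models

model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# h_psi: every layer except the final fully connected head
backbone = nn.Sequential(*list(model.children())[:-1])
# g_phi: the original 1000-class ImageNet head
head = model.fc

x = torch.randn(1, 3, 224, 224)
features = backbone(x).flatten(1)  # shape [1, 512]
logits = head(features)           # shape [1, 1000]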

Feature Extraction Versus Fine-Tuning

There are two common transfer learning modes.

| Mode | What changes | Best when |
| --- | --- | --- |
| Feature extraction | Freeze pretrained backbone, train only new head | Dataset is small, classes are similar to pretraining data |
| Fine-tuning | Train some or all pretrained layers | Dataset is larger, task differs from pretraining data |

Feature extraction is cheaper and less likely to overfit. Fine-tuning is more flexible and usually gives better final accuracy when enough data is available.

Loading a Pretrained Model

PyTorch provides pretrained models through torchvision.models.

import torch
import torch.nn as nn
from torchvision import models

model = models.resnet18(
    weights=models.ResNet18_Weights.DEFAULT
)

The model was trained with a specific preprocessing convention. The weights object provides the matching transform:

weights = models.ResNet18_Weights.DEFAULT
transform = weights.transforms()

Using the correct transform matters. If the model was pretrained on normalized images, inference and fine-tuning should use the same normalization.
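As a minimal sketch, assuming a local file named image.jpg, the weight-specific transform is applied to a PIL image before the tensor is passed to the model:

from PIL import Image

image = Image.open("image.jpg").convert("RGB")  # hypothetical input file
batch = transform(image).unsqueeze(0)           # shape [1, 3, 224, 224]

with torch.no_grad():
    logits = model(batch)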

Replacing the Classifier Head

A ResNet classifier ends with a fully connected layer called fc.

num_classes = 5

model.fc = nn.Linear(
    in_features=model.fc.in_features,
    out_features=num_classes,
)

Now the model outputs logits of shape [B, 5], where B is the batch size.

Only the final layer shape changed. The convolutional backbone remains the same.

For other architectures, the classifier location differs. For example, many torchvision models expose the head as classifier instead of fc.

model = models.efficientnet_b0(
    weights=models.EfficientNet_B0_Weights.DEFAULT
)

num_classes = 5
in_features = model.classifier[1].in_features
model.classifier[1] = nn.Linear(in_features, num_classes)

Always inspect the model before replacing the head:

print(model)

Freezing the Backbone

To use the pretrained model as a fixed feature extractor, disable gradients for the backbone.

for param in model.parameters():
    param.requires_grad = False

model.fc = nn.Linear(model.fc.in_features, num_classes)

The new head has requires_grad=True by default. The optimizer should receive only trainable parameters:

optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=1e-3,
    weight_decay=1e-4,
)

This avoids updating frozen parameters and reduces optimizer state memory.
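A quick way to confirm the freeze worked is to count trainable versus total parameters; for ResNet-18 with a 5-class head, only a few thousand parameters should remain trainable:

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable parameters: {trainable:,} / {total:,}")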

Fine-Tuning the Whole Model

For full fine-tuning, leave all parameters trainable:

for param in model.parameters():
    param.requires_grad = True

Then use a smaller learning rate than you would when training from scratch:

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-5,
    weight_decay=1e-4,
)

Pretrained weights already encode useful structure. Large learning rates can destroy that structure quickly. This is sometimes called catastrophic forgetting.

Differential Learning Rates

Often, the classifier head should learn faster than the pretrained backbone. PyTorch allows parameter groups with different learning rates.

backbone_params = []
head_params = []

for name, param in model.named_parameters():
    if name.startswith("fc."):
        head_params.append(param)
    else:
        backbone_params.append(param)

optimizer = torch.optim.AdamW(
    [
        {"params": backbone_params, "lr": 3e-5},
        {"params": head_params, "lr": 3e-4},
    ],
    weight_decay=1e-4,
)

This trains the new classifier more aggressively while making smaller updates to the pretrained representation.

Progressive Unfreezing

Progressive unfreezing starts by training only the classifier head. Then it gradually unfreezes deeper parts of the backbone.

A common schedule is:

| Phase | Trainable parameters |
| --- | --- |
| Phase 1 | Classifier head only |
| Phase 2 | Last block plus classifier head |
| Phase 3 | Full model |

This approach is useful when the target dataset is small. It reduces the risk of damaging useful pretrained features early in training.

For ResNet, the final residual block is usually layer4.

for param in model.parameters():
    param.requires_grad = False

for param in model.fc.parameters():
    param.requires_grad = True

# later
for param in model.layer4.parameters():
    param.requires_grad = True

After changing which parameters are trainable, recreate the optimizer so it tracks the correct parameter set.
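For example, after unfreezing layer4, the optimizer can be rebuilt over the currently trainable parameters (the learning rate here is illustrative):

optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=3e-5,
    weight_decay=1e-4,
)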

Batch Normalization During Transfer Learning

Batch normalization needs special care. A batch normalization layer has two kinds of state:

| State | Example | Updated by |
| --- | --- | --- |
| Trainable parameters | scale and bias | gradients |
| Running statistics | running mean and variance | forward passes in training mode |

Freezing parameters does not automatically freeze running statistics. If the model remains in training mode, batch normalization statistics may still change.

For small target datasets, this can hurt performance. One option is to keep batch normalization layers in evaluation mode while training the classifier.

def set_batchnorm_eval(module):
    if isinstance(module, nn.BatchNorm2d):
        module.eval()

model.apply(set_batchnorm_eval)

This keeps running statistics fixed. The right choice depends on dataset size and domain shift.
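One detail worth noting: calling model.train() puts every submodule back into training mode, so if the running statistics should stay fixed, reapply the helper after each model.train() call, for example at the start of every epoch:

model.train()
model.apply(set_batchnorm_eval)  # re-freeze running statistics after train()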

Transfer Learning Training Loop

The training loop is the same as ordinary classification. The main differences are the pretrained initialization, replaced classifier head, and optimizer parameter selection.

def train_one_epoch(model, loader, loss_fn, optimizer, device):
    model.train()

    total_loss = 0.0
    total_correct = 0
    total_count = 0

    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)

        logits = model(images)
        loss = loss_fn(logits, labels)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        preds = logits.argmax(dim=1)

        total_loss += loss.item() * labels.size(0)
        total_correct += (preds == labels).sum().item()
        total_count += labels.size(0)

    return {
        "loss": total_loss / total_count,
        "accuracy": total_correct / total_count,
    }

Validation remains unchanged:

@torch.no_grad()
def evaluate(model, loader, loss_fn, device):
    model.eval()

    total_loss = 0.0
    total_correct = 0
    total_count = 0

    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)

        logits = model(images)
        loss = loss_fn(logits, labels)

        preds = logits.argmax(dim=1)

        total_loss += loss.item() * labels.size(0)
        total_correct += (preds == labels).sum().item()
        total_count += labels.size(0)

    return {
        "loss": total_loss / total_count,
        "accuracy": total_correct / total_count,
    }

Complete Example

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, models

device = "cuda" if torch.cuda.is_available() else "cpu"

weights = models.ResNet18_Weights.DEFAULT
transform = weights.transforms()

train_set = datasets.ImageFolder("dataset/train", transform=transform)
val_set = datasets.ImageFolder("dataset/val", transform=transform)

train_loader = DataLoader(
    train_set,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

val_loader = DataLoader(
    val_set,
    batch_size=64,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

num_classes = len(train_set.classes)

model = models.resnet18(weights=weights)

for param in model.parameters():
    param.requires_grad = False

model.fc = nn.Linear(model.fc.in_features, num_classes)

model = model.to(device)

loss_fn = nn.CrossEntropyLoss()

optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=1e-3,
    weight_decay=1e-4,
)

best_acc = 0.0

for epoch in range(10):
    train_metrics = train_one_epoch(
        model=model,
        loader=train_loader,
        loss_fn=loss_fn,
        optimizer=optimizer,
        device=device,
    )

    val_metrics = evaluate(
        model=model,
        loader=val_loader,
        loss_fn=loss_fn,
        device=device,
    )

    print(
        f"epoch={epoch + 1} "
        f"train_loss={train_metrics['loss']:.4f} "
        f"train_acc={train_metrics['accuracy']:.4f} "
        f"val_loss={val_metrics['loss']:.4f} "
        f"val_acc={val_metrics['accuracy']:.4f}"
    )

    if val_metrics["accuracy"] > best_acc:
        best_acc = val_metrics["accuracy"]

        torch.save(
            {
                "model": model.state_dict(),
                "classes": train_set.classes,
                "class_to_idx": train_set.class_to_idx,
                "weights": "ResNet18_Weights.DEFAULT",
                "val_acc": best_acc,
            },
            "transfer_classifier.pt",
        )

This version trains only the final classifier. It is a strong baseline for small image datasets.
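To use the saved checkpoint later, rebuild the same architecture, restore the weights, and keep the stored class order for mapping predictions back to labels. A minimal sketch, assuming the transfer_classifier.pt file produced above:

checkpoint = torch.load("transfer_classifier.pt", map_location=device)
classes = checkpoint["classes"]

model = models.resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, len(classes))
model.load_state_dict(checkpoint["model"])
model = model.to(device)
model.eval()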

When to Fine-Tune More Layers

The decision depends on the target data.

| Situation | Recommended approach |
| --- | --- |
| Very small dataset | Freeze backbone, train head |
| Small dataset similar to ImageNet | Freeze backbone, then unfreeze final block |
| Medium dataset | Fine-tune final blocks with small learning rate |
| Large dataset | Fine-tune full model |
| Strong domain shift | Fine-tune more layers |
| Medical, satellite, or scientific images | Fine-tune more layers, possibly from self-supervised weights |

A dataset of dog and cat photos is close to common pretraining data. A dataset of microscope images is farther away. Larger domain shift usually requires deeper adaptation.

Common Mistakes

The most common transfer learning errors are simple.

| Mistake | Consequence |
| --- | --- |
| Wrong normalization | Poor accuracy |
| Forgetting to replace classifier head | Wrong output shape |
| Training frozen parameters accidentally | Wasted memory and compute |
| Freezing all parameters including the new head | No learning |
| Learning rate too high | Destroyed pretrained features |
| Changing class order at inference | Wrong labels |
| Random validation transforms | Noisy metrics |
| Ignoring batch normalization behavior | Unstable fine-tuning |

Before training, verify which parameters are trainable:

for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)

This simple check prevents many silent failures.

Practical Defaults

A good first transfer learning setup is:

| Component | Default |
| --- | --- |
| Model | ResNet18 or EfficientNet-B0 |
| Weights | Official pretrained weights |
| Image size | 224 |
| Transform | Weight-specific transform |
| First phase | Freeze backbone |
| Optimizer | AdamW |
| Head learning rate | 1e-3 |
| Fine-tune learning rate | 1e-5 to 1e-4 |
| Loss | Cross-entropy |
| Metric | Validation accuracy |
| Checkpoint | Best validation metric |

These defaults are not always optimal, but they are usually stable. Once this baseline works, tune augmentation, learning rate, batch size, unfreezing depth, and model size.

Summary

Transfer learning uses pretrained models as reusable representation learners. For image classification, the usual workflow is to load pretrained weights, replace the classifier head, train the head, and optionally fine-tune deeper layers.

Feature extraction is safer and cheaper. Full fine-tuning is more powerful but more sensitive to learning rate, dataset size, and normalization. The best practice is to begin with a frozen-backbone baseline, then unfreeze progressively when the validation set shows that more adaptation is needed.