Transfer learning reuses a model trained on one task as the starting point for another task. In image classification, this usually means taking a convolutional network or vision transformer trained on a large image dataset, replacing its final classifier, and fine-tuning it on a smaller target dataset.
The central idea is simple: early and middle layers learn reusable visual features. They detect edges, colors, textures, shapes, object parts, and higher-level patterns. A new task often needs different class labels, but it can still benefit from these learned representations.
## Why Transfer Learning Works
A randomly initialized model starts with no useful visual features. It must learn low-level patterns and high-level decision boundaries from the target dataset. This requires more data, more compute, and more tuning.
A pretrained model already contains useful representations. Fine-tuning adapts those representations to the new task.
Let a pretrained model be written as `f(x) = h(g(x))`. Here `g` is the feature extractor, and `h` is the classifier head. In transfer learning, we usually keep `g`, replace `h`, and train the new model on the target classes.
For a target dataset with `K` classes, the new classifier `h'` produces a vector of `K` logits, one per class.
The feature extractor may be frozen, partially unfrozen, or fully fine-tuned.
## Feature Extraction Versus Fine-Tuning
There are two common transfer learning modes.
| Mode | What changes | Best when |
|---|---|---|
| Feature extraction | Freeze pretrained backbone, train only new head | Dataset is small, classes are similar to pretraining data |
| Fine-tuning | Train some or all pretrained layers | Dataset is larger, task differs from pretraining data |
Feature extraction is cheaper and less likely to overfit. Fine-tuning is more flexible and usually gives better final accuracy when enough data is available.
## Loading a Pretrained Model
PyTorch provides pretrained models through `torchvision.models`.
```python
import torch
import torch.nn as nn
from torchvision import models

model = models.resnet18(
    weights=models.ResNet18_Weights.DEFAULT
)
```

The model was trained with a specific preprocessing convention. The weights object provides the matching transform:
```python
weights = models.ResNet18_Weights.DEFAULT
transform = weights.transforms()
```

Using the correct transform matters. If the model was pretrained on normalized images, inference and fine-tuning should use the same normalization.
## Replacing the Classifier Head
A ResNet classifier ends with a fully connected layer called `fc`.
```python
num_classes = 5

model.fc = nn.Linear(
    in_features=model.fc.in_features,
    out_features=num_classes,
)
```

Now the model outputs logits with shape `[B, 5]`. Only the final layer shape changed. The convolutional backbone remains the same.
For other architectures, the classifier location differs. For example, many torchvision models use `classifier` instead of `fc`.
```python
model = models.efficientnet_b0(
    weights=models.EfficientNet_B0_Weights.DEFAULT
)

num_classes = 5
in_features = model.classifier[1].in_features
model.classifier[1] = nn.Linear(in_features, num_classes)
```

Always inspect the model before replacing the head:

```python
print(model)
```

## Freezing the Backbone
To use the pretrained model as a fixed feature extractor, disable gradients for the backbone.
```python
for param in model.parameters():
    param.requires_grad = False

model.fc = nn.Linear(model.fc.in_features, num_classes)
```

The new head has `requires_grad=True` by default. The optimizer should receive only trainable parameters:

```python
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=1e-3,
    weight_decay=1e-4,
)
```

This avoids updating frozen parameters and reduces optimizer state memory.
## Fine-Tuning the Whole Model
For full fine-tuning, leave all parameters trainable:
```python
for param in model.parameters():
    param.requires_grad = True
```

Then use a smaller learning rate than ordinary training from scratch:

```python
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-5,
    weight_decay=1e-4,
)
```

Pretrained weights already encode useful structure. Large learning rates can destroy that structure quickly. This is sometimes called catastrophic forgetting.
## Differential Learning Rates
Often, the classifier head should learn faster than the pretrained backbone. PyTorch allows parameter groups with different learning rates.
```python
backbone_params = []
head_params = []

for name, param in model.named_parameters():
    if name.startswith("fc."):
        head_params.append(param)
    else:
        backbone_params.append(param)

optimizer = torch.optim.AdamW(
    [
        {"params": backbone_params, "lr": 3e-5},
        {"params": head_params, "lr": 3e-4},
    ],
    weight_decay=1e-4,
)
```

This trains the new classifier more aggressively while making smaller updates to the pretrained representation.
## Progressive Unfreezing
Progressive unfreezing starts by training only the classifier head. Then it gradually unfreezes deeper parts of the backbone.
A common schedule is:
| Phase | Trainable parameters |
|---|---|
| Phase 1 | Classifier head only |
| Phase 2 | Last block plus classifier head |
| Phase 3 | Full model |
This approach is useful when the target dataset is small. It reduces the risk of damaging useful pretrained features early in training.
For ResNet, the final residual block is usually `layer4`.
```python
for param in model.parameters():
    param.requires_grad = False

for param in model.fc.parameters():
    param.requires_grad = True

# later
for param in model.layer4.parameters():
    param.requires_grad = True
```

After changing which parameters are trainable, recreate the optimizer so it tracks the correct parameter set.
## Batch Normalization During Transfer Learning
Batch normalization needs special care. A batch normalization layer has two kinds of state:
| State | Example | Updated by |
|---|---|---|
| Trainable parameters | scale and bias | gradients |
| Running statistics | running mean and variance | forward passes in training mode |
Freezing parameters does not automatically freeze running statistics. If the model remains in training mode, batch normalization statistics may still change.
For small target datasets, this can hurt performance. One option is to keep batch normalization layers in evaluation mode while training the classifier.
```python
def set_batchnorm_eval(module):
    if isinstance(module, nn.BatchNorm2d):
        module.eval()

model.apply(set_batchnorm_eval)
```

This keeps running statistics fixed. The right choice depends on dataset size and domain shift.
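The train/eval distinction for running statistics can be demonstrated on a single `BatchNorm2d` layer in isolation. In training mode a forward pass shifts the running mean; in evaluation mode it does not:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)
before = bn.running_mean.clone()

bn.train()
bn(torch.randn(8, 3, 16, 16) + 5.0)      # training mode updates running stats
changed = not torch.allclose(bn.running_mean, before)

bn.eval()
frozen_before = bn.running_mean.clone()
bn(torch.randn(8, 3, 16, 16) + 5.0)      # eval mode leaves them untouched
unchanged = torch.allclose(bn.running_mean, frozen_before)
print(changed, unchanged)  # True True
```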
## Transfer Learning Training Loop
The training loop is the same as ordinary classification. The main differences are the pretrained initialization, replaced classifier head, and optimizer parameter selection.
```python
def train_one_epoch(model, loader, loss_fn, optimizer, device):
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_count = 0
    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)
        logits = model(images)
        loss = loss_fn(logits, labels)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        preds = logits.argmax(dim=1)
        total_loss += loss.item() * labels.size(0)
        total_correct += (preds == labels).sum().item()
        total_count += labels.size(0)
    return {
        "loss": total_loss / total_count,
        "accuracy": total_correct / total_count,
    }
```

Validation remains unchanged:
```python
@torch.no_grad()
def evaluate(model, loader, loss_fn, device):
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_count = 0
    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)
        logits = model(images)
        loss = loss_fn(logits, labels)
        preds = logits.argmax(dim=1)
        total_loss += loss.item() * labels.size(0)
        total_correct += (preds == labels).sum().item()
        total_count += labels.size(0)
    return {
        "loss": total_loss / total_count,
        "accuracy": total_correct / total_count,
    }
```

## Complete Example
```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, models

device = "cuda" if torch.cuda.is_available() else "cpu"

weights = models.ResNet18_Weights.DEFAULT
transform = weights.transforms()

train_set = datasets.ImageFolder("dataset/train", transform=transform)
val_set = datasets.ImageFolder("dataset/val", transform=transform)

train_loader = DataLoader(
    train_set,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)
val_loader = DataLoader(
    val_set,
    batch_size=64,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

num_classes = len(train_set.classes)

model = models.resnet18(weights=weights)
for param in model.parameters():
    param.requires_grad = False
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=1e-3,
    weight_decay=1e-4,
)

best_acc = 0.0
for epoch in range(10):
    train_metrics = train_one_epoch(
        model=model,
        loader=train_loader,
        loss_fn=loss_fn,
        optimizer=optimizer,
        device=device,
    )
    val_metrics = evaluate(
        model=model,
        loader=val_loader,
        loss_fn=loss_fn,
        device=device,
    )
    print(
        f"epoch={epoch + 1} "
        f"train_loss={train_metrics['loss']:.4f} "
        f"train_acc={train_metrics['accuracy']:.4f} "
        f"val_loss={val_metrics['loss']:.4f} "
        f"val_acc={val_metrics['accuracy']:.4f}"
    )
    if val_metrics["accuracy"] > best_acc:
        best_acc = val_metrics["accuracy"]
        torch.save(
            {
                "model": model.state_dict(),
                "classes": train_set.classes,
                "class_to_idx": train_set.class_to_idx,
                "weights": "ResNet18_Weights.DEFAULT",
                "val_acc": best_acc,
            },
            "transfer_classifier.pt",
        )
```

This version trains only the final classifier. It is a strong baseline for small image datasets.
## When to Fine-Tune More Layers
The decision depends on the target data.
| Situation | Recommended approach |
|---|---|
| Very small dataset | Freeze backbone, train head |
| Small dataset similar to ImageNet | Freeze backbone, then unfreeze final block |
| Medium dataset | Fine-tune final blocks with small learning rate |
| Large dataset | Fine-tune full model |
| Strong domain shift | Fine-tune more layers |
| Medical, satellite, or scientific images | Fine-tune more layers, possibly from self-supervised weights |
A dataset of dog and cat photos is close to common pretraining data. A dataset of microscope images is farther away. Larger domain shift usually requires deeper adaptation.
## Common Mistakes
The most common transfer learning errors are simple.
| Mistake | Consequence |
|---|---|
| Wrong normalization | Poor accuracy |
| Forgetting to replace classifier head | Wrong output shape |
| Training frozen parameters accidentally | Wasted memory and compute |
| Freezing all parameters including the new head | No learning |
| Learning rate too high | Destroyed pretrained features |
| Changing class order at inference | Wrong labels |
| Random validation transforms | Noisy metrics |
| Ignoring batch normalization behavior | Unstable fine-tuning |
Before training, verify which parameters are trainable:
```python
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)
```

This simple check prevents many silent failures.
## Practical Defaults
A good first transfer learning setup is:
| Component | Default |
|---|---|
| Model | ResNet18 or EfficientNet-B0 |
| Weights | Official pretrained weights |
| Image size | 224 |
| Transform | Weight-specific transform |
| First phase | Freeze backbone |
| Optimizer | AdamW |
| Head learning rate | 1e-3 |
| Fine-tune learning rate | 3e-5 to 3e-4 |
| Loss | Cross-entropy |
| Metric | Validation accuracy |
| Checkpoint | Best validation metric |
These defaults are not always optimal, but they are usually stable. Once this baseline works, tune augmentation, learning rate, batch size, unfreezing depth, and model size.
## Summary
Transfer learning uses pretrained models as reusable representation learners. For image classification, the usual workflow is to load pretrained weights, replace the classifier head, train the head, and optionally fine-tune deeper layers.
Feature extraction is safer and cheaper. Full fine-tuning is more powerful but more sensitive to learning rate, dataset size, and normalization. The best practice is to begin with a frozen-backbone baseline, then unfreeze progressively when the validation set shows that more adaptation is needed.