Fine-tuning adapts a pretrained model to a target dataset by continuing training from learned weights instead of starting from random initialization. In transfer learning, we may train only a new classifier head. In fine-tuning, we update part or all of the pretrained backbone as well.
Fine-tuning is useful when the target task is close enough to the pretraining task that learned features remain useful, but different enough that the model must adapt. In image classification, this often means adapting an ImageNet-pretrained model to a smaller, domain-specific dataset.
What Fine-Tuning Changes
A pretrained classifier can be written as

$$
f(x) = h(g(x))
$$

Here $g$ is the backbone and $h$ is the classifier head. Fine-tuning changes $g$, $h$, or both.
There are three common settings:
| Setting | Trainable parameters |
|---|---|
| Head-only tuning | New classifier head |
| Partial fine-tuning | Final blocks plus classifier head |
| Full fine-tuning | Entire model |
Head-only tuning is often the first baseline. Partial fine-tuning is the usual next step. Full fine-tuning gives the model the most freedom, but it also increases compute cost and overfitting risk.
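To make the difference concrete, a small helper (a sketch, not part of any library) can report how many parameters each setting actually trains:

```python
def count_trainable(model):
    # Parameters that will receive gradient updates under the current freezing policy.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
```

Calling it after each freezing step confirms that head-only tuning trains only the new head's parameters, while full fine-tuning trains all of them.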
Replacing the Head First
A pretrained model usually has an output layer for its original pretraining classes. For ImageNet models, this is commonly 1000 classes. A target dataset may have a different number of classes.
```python
import torch
import torch.nn as nn
from torchvision import models

num_classes = 12

weights = models.ResNet50_Weights.DEFAULT
model = models.resnet50(weights=weights)
model.fc = nn.Linear(model.fc.in_features, num_classes)
```

The new head starts from random initialization. The backbone starts from pretrained weights.
For EfficientNet:
```python
weights = models.EfficientNet_B0_Weights.DEFAULT
model = models.efficientnet_b0(weights=weights)

in_features = model.classifier[1].in_features
model.classifier[1] = nn.Linear(in_features, num_classes)
```

The exact replacement depends on the architecture. Always inspect the model structure before editing the classifier.
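A quick way to inspect the structure is to print the top-level modules and look at the current head before replacing it. A minimal sketch:

```python
# List the top-level modules to find where the classifier lives.
for name, module in model.named_children():
    print(name, type(module).__name__)

# Print the existing head directly (model.classifier for EfficientNet, model.fc for ResNet).
print(model.classifier)
```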
Training Only the Head
The safest first phase freezes the pretrained backbone and trains only the new head.
```python
for param in model.parameters():
    param.requires_grad = False

for param in model.fc.parameters():
    param.requires_grad = True
```

Then create the optimizer after freezing:
```python
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=1e-3,
    weight_decay=1e-4,
)
```

Creating the optimizer after selecting trainable parameters avoids storing optimizer state for frozen weights.
Unfreezing the Final Block
After the head learns a reasonable classifier, the final backbone block can be unfrozen.
For ResNet, the last block is usually layer4.
```python
for param in model.layer4.parameters():
    param.requires_grad = True
```

Then recreate the optimizer with parameter groups:
```python
optimizer = torch.optim.AdamW(
    [
        {"params": model.layer4.parameters(), "lr": 1e-5},
        {"params": model.fc.parameters(), "lr": 1e-4},
    ],
    weight_decay=1e-4,
)
```

The classifier head receives a larger learning rate because it is newly initialized. The pretrained block receives a smaller learning rate because its features are already useful.
Full Fine-Tuning
Full fine-tuning makes all parameters trainable.
```python
for param in model.parameters():
    param.requires_grad = True
```

Use a small learning rate:
```python
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-5,
    weight_decay=1e-4,
)
```

Full fine-tuning is usually appropriate when the target dataset is large enough, the validation set is reliable, and the target domain differs from the pretraining domain.
Discriminative Learning Rates
Different parts of the model should often move at different speeds. Early layers learn generic visual features. Later layers learn more task-specific features. The classifier head is entirely task-specific.
A common fine-tuning setup uses lower learning rates for earlier layers and higher learning rates for later layers.
```python
optimizer = torch.optim.AdamW(
    [
        {"params": model.conv1.parameters(), "lr": 1e-6},
        {"params": model.bn1.parameters(), "lr": 1e-6},
        {"params": model.layer1.parameters(), "lr": 1e-6},
        {"params": model.layer2.parameters(), "lr": 3e-6},
        {"params": model.layer3.parameters(), "lr": 1e-5},
        {"params": model.layer4.parameters(), "lr": 3e-5},
        {"params": model.fc.parameters(), "lr": 3e-4},
    ],
    weight_decay=1e-4,
)
```

This pattern is called discriminative learning rates. It preserves low-level pretrained features while allowing task-specific layers to adapt.
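The same idea can be written more compactly by applying a per-stage decay factor instead of hand-picking each value. A hedged sketch, assuming a torchvision ResNet-50; the head learning rate and the 0.3 decay factor are illustrative choices:

```python
def resnet_param_groups(model, head_lr=3e-4, decay=0.3, weight_decay=1e-4):
    # Stages ordered from closest to the output down to the input stem.
    stages = [model.fc, model.layer4, model.layer3, model.layer2,
              model.layer1, model.bn1, model.conv1]
    groups = []
    lr = head_lr
    for stage in stages:
        groups.append({"params": stage.parameters(), "lr": lr,
                       "weight_decay": weight_decay})
        lr *= decay  # each earlier stage gets a smaller learning rate
    return groups

optimizer = torch.optim.AdamW(resnet_param_groups(model))
```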
Warmup and Learning Rate Decay
Fine-tuning can be sensitive during the first few epochs. Warmup gradually increases the learning rate from a small value to the target value.
A simple warmup plus cosine schedule can be implemented with SequentialLR.
```python
warmup = torch.optim.lr_scheduler.LinearLR(
    optimizer,
    start_factor=0.1,
    total_iters=3,
)
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=27,
)
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    schedulers=[warmup, cosine],
    milestones=[3],
)
```

Then call `scheduler.step()` once per epoch.
Warmup reduces abrupt updates to pretrained weights. Cosine decay gradually lowers the learning rate so the model can settle into a better solution.
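Before committing to a long run, it can help to dry-run the schedule and print the learning rate per epoch. A small sketch, assuming the 30-epoch setup above (3 warmup epochs plus 27 cosine epochs); PyTorch will warn that `optimizer.step()` has not been called, which is expected in a dry run:

```python
for epoch in range(30):
    print(epoch, scheduler.get_last_lr())
    scheduler.step()
```

Rebuild the scheduler before real training, since stepping it here consumes the schedule.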
Batch Normalization Policy
Batch normalization is one of the main fine-tuning hazards.
Batch normalization layers contain trainable affine parameters and non-trainable running statistics. The running statistics are updated during training mode. Small target datasets may produce noisy running means and variances.
There are three common policies:
| Policy | Effect |
|---|---|
| Keep BatchNorm in training mode | Update running statistics |
| Freeze BatchNorm parameters | Do not train scale and bias |
| Keep BatchNorm in eval mode | Preserve pretrained running statistics |
For small datasets, preserving pretrained running statistics often works better.
```python
def freeze_batchnorm(module):
    if isinstance(module, nn.BatchNorm2d):
        module.eval()
        for param in module.parameters():
            param.requires_grad = False

model.apply(freeze_batchnorm)
```

If the target domain differs strongly from pretraining data and the batch size is large, updating batch normalization statistics may help.
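A middle ground in that case, not shown above, is to keep BatchNorm in training mode but lower its momentum so the pretrained running statistics drift slowly instead of being overwritten. The 0.01 value here is an illustrative choice:

```python
def soften_batchnorm(module):
    if isinstance(module, nn.BatchNorm2d):
        module.momentum = 0.01  # default is 0.1; smaller means slower running-stat updates

model.apply(soften_batchnorm)
```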
Weight Decay Policy
Weight decay should usually apply to weight matrices and convolution kernels. It should usually not apply to bias terms or normalization parameters.
A cleaner optimizer setup separates decay and no-decay parameters:
```python
def make_param_groups(model, weight_decay):
    decay = []
    no_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if name.endswith(".bias") or "bn" in name.lower() or "norm" in name.lower():
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

optimizer = torch.optim.AdamW(
    make_param_groups(model, weight_decay=1e-4),
    lr=3e-5,
)
```

This prevents regularization from distorting normalization layers and bias terms.
Gradient Clipping
Fine-tuning may produce unstable gradients, especially when the classifier head is new or the dataset has noisy labels.
Gradient clipping limits the size of the gradient update.
```python
loss.backward()
torch.nn.utils.clip_grad_norm_(
    model.parameters(),
    max_norm=1.0,
)
optimizer.step()
```

Clipping is not always necessary for CNN fine-tuning, but it is useful for transformers, multimodal models, and unstable datasets.
Fine-Tuning Loop
A fine-tuning loop is the same basic training loop, with stricter control over modes, learning rates, and checkpoints.
```python
def train_one_epoch(model, loader, loss_fn, optimizer, device):
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_count = 0
    for images, labels in loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        logits = model(images)
        loss = loss_fn(logits, labels)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        preds = logits.argmax(dim=1)
        total_loss += loss.item() * labels.size(0)
        total_correct += (preds == labels).sum().item()
        total_count += labels.size(0)
    return {
        "loss": total_loss / total_count,
        "accuracy": total_correct / total_count,
    }
```

If BatchNorm layers should remain frozen, apply that policy after `model.train()`:

```python
model.train()
model.apply(freeze_batchnorm)
```

This is necessary because `model.train()` recursively puts modules into training mode.
Monitoring Fine-Tuning
Fine-tuning should be monitored with validation metrics. Training loss alone gives an incomplete picture.
Important signals include:
| Signal | Interpretation |
|---|---|
| Training loss decreases, validation improves | Fine-tuning is working |
| Training improves, validation worsens | Overfitting |
| Both training and validation fail to improve | Learning rate too low, wrong layers frozen, or bad labels |
| Loss becomes NaN | Learning rate too high, numerical instability |
| Validation jumps erratically | Small validation set or unstable preprocessing |
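Computing the validation metric requires an evaluation pass that mirrors `train_one_epoch` but disables gradients and training-mode behavior. A minimal sketch, assuming the same `loss_fn` and `device` as the training loop:

```python
@torch.no_grad()
def evaluate(model, loader, loss_fn, device):
    model.eval()
    total_loss, total_correct, total_count = 0.0, 0, 0
    for images, labels in loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        logits = model(images)
        loss = loss_fn(logits, labels)
        preds = logits.argmax(dim=1)
        total_loss += loss.item() * labels.size(0)
        total_correct += (preds == labels).sum().item()
        total_count += labels.size(0)
    return {"loss": total_loss / total_count,
            "accuracy": total_correct / total_count}
```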
The safest checkpoint policy saves the model with the best validation metric.
```python
if val_acc > best_val_acc:
    best_val_acc = val_acc
    torch.save(
        {
            "model": model.state_dict(),
            "class_to_idx": train_set.class_to_idx,
            "epoch": epoch,
            "val_acc": val_acc,
        },
        "best_finetuned_model.pt",
    )
```

Fine-Tuning Vision Transformers
Vision transformers can also be fine-tuned. The pattern is similar: load pretrained weights, replace the head, and train with small learning rates.
```python
weights = models.ViT_B_16_Weights.DEFAULT
model = models.vit_b_16(weights=weights)

in_features = model.heads.head.in_features
model.heads.head = nn.Linear(in_features, num_classes)
```

Vision transformers are often more sensitive to optimization choices than smaller CNNs. Useful defaults include:
| Setting | Typical value |
|---|---|
| Optimizer | AdamW |
| Learning rate | 1e-5 to 1e-4 |
| Weight decay | 0.01 to 0.1 |
| Warmup | 5 percent to 10 percent of training |
| Augmentation | Moderate to strong |
| Gradient clipping | Often useful |
Transformers usually benefit from larger datasets. On small datasets, a pretrained CNN may outperform a vision transformer unless the transformer pretraining is strong.
Avoiding Data Leakage
Fine-tuning can look better than it really is if data leaks across splits.
Common leakage cases include:
| Leakage source | Example |
|---|---|
| Duplicate images | Same image appears in train and validation |
| Near duplicates | Augmented or resized copies cross splits |
| Subject leakage | Same patient, product, or scene appears in both splits |
| Time leakage | Future data appears in training |
| Label leakage | Filename or folder reveals the answer in preprocessing |
Fine-tuning a strong pretrained model can exploit leakage quickly. Always split data by the real unit of independence. For medical images, split by patient. For product photos, split by product. For video frames, split by video or scene, not by individual frame.
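A simple way to enforce this is to split by group identifier rather than by individual sample. The sketch below assumes each sample exposes a group id (patient, product, video); the 80/20 ratio and the helper name are illustrative:

```python
import random

def split_by_group(samples, group_of, val_fraction=0.2, seed=0):
    # group_of(sample) returns the unit of independence (patient id, product id, video id).
    groups = sorted({group_of(s) for s in samples})
    random.Random(seed).shuffle(groups)
    n_val = max(1, int(len(groups) * val_fraction))
    val_groups = set(groups[:n_val])
    train = [s for s in samples if group_of(s) not in val_groups]
    val = [s for s in samples if group_of(s) in val_groups]
    return train, val
```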
Practical Fine-Tuning Recipe
A reliable fine-tuning sequence is (see the code sketch after this list):
- Load pretrained weights and their matching transforms.
- Replace the classifier head.
- Freeze the backbone.
- Train the head for a few epochs.
- Unfreeze the final block.
- Train with a lower backbone learning rate.
- Optionally unfreeze the whole model.
- Save the checkpoint with the best validation metric.
- Test once on a held-out test set.
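Put together, the staged phases might look like the sketch below. The epoch counts and learning rates are illustrative, `train_one_epoch` and `evaluate` are the helpers shown earlier, and `train_loader`, `val_loader`, `loss_fn`, and `device` are assumed to exist:

```python
# Phase 1: freeze the backbone, train only the new head.
for p in model.parameters():
    p.requires_grad = False
for p in model.fc.parameters():
    p.requires_grad = True
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)
for epoch in range(3):
    train_one_epoch(model, train_loader, loss_fn, optimizer, device)

# Phase 2: unfreeze the final block, use a lower backbone learning rate.
for p in model.layer4.parameters():
    p.requires_grad = True
optimizer = torch.optim.AdamW([
    {"params": model.layer4.parameters(), "lr": 1e-5},
    {"params": model.fc.parameters(), "lr": 1e-4},
])
for epoch in range(10):
    train_one_epoch(model, train_loader, loss_fn, optimizer, device)
    metrics = evaluate(model, val_loader, loss_fn, device)
```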
For a ResNet-style model, reasonable defaults are:
| Component | Default |
|---|---|
| Head-only learning rate | 1e-3 |
| Final block learning rate | 1e-5 to 3e-5 |
| Head learning rate during fine-tuning | 1e-4 to 3e-4 |
| Optimizer | AdamW or SGD with momentum |
| Batch size | Largest stable batch size |
| Scheduler | Cosine decay with warmup |
| BatchNorm | Freeze for small datasets |
| Checkpoint metric | Validation accuracy or macro F1 |
For imbalanced datasets, macro F1 is often better than accuracy.
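Macro F1 averages per-class F1 scores, so rare classes weigh as much as common ones. If scikit-learn is available, it can be computed from predictions collected during validation; `all_labels` and `all_preds` here are illustrative names for those collected arrays:

```python
from sklearn.metrics import f1_score

# all_labels and all_preds are integer class indices gathered over the validation set.
macro_f1 = f1_score(all_labels, all_preds, average="macro")
```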
Summary
Fine-tuning updates pretrained representations for a new task. It begins with replacing the classifier head and then deciding how much of the backbone should adapt.
The safest method is staged fine-tuning: train the head first, unfreeze later layers, then fine-tune more of the model only when validation metrics justify it. Use small learning rates for pretrained layers, larger learning rates for new layers, careful BatchNorm policy, and validation-based checkpointing.