
Transfer Learning

Transfer learning adapts models pre-trained on large datasets to new tasks with limited data. This is particularly critical in domains like medical imaging where labeled data is scarce and expensive.

Core Paradigm

Instead of training from scratch, transfer learning follows a two-stage process:

  1. Pre-training: Train on a large, general dataset (e.g., ImageNet with 1.2M images)
  2. Fine-tuning: Adapt to a specific, often smaller target dataset

This works because deep networks learn hierarchical features transferable across tasks.

Why Transfer Learning Works

CNNs learn hierarchical representations:

  • Early layers: General features (edges, colors, textures) - universal across visual tasks
  • Middle layers: Intermediate patterns (shapes, parts) - somewhat task-specific
  • Late layers: High-level, task-specific concepts (object classes, semantic features)

Early layer features transfer well across domains, even when source and target tasks differ significantly.
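
A quick way to see this hierarchy is to list ResNet-50's top-level blocks. The sketch below simply prints each block's name and parameter count; conv1 and layer1 sit at the general end of the spectrum, while layer4 and fc are the most task-specific.

import torchvision.models as models

# Inspect ResNet-50's top-level blocks: conv1/layer1 learn general features,
# layer4 and fc are the most task-specific.
model = models.resnet50(pretrained=True)
for name, module in model.named_children():
    num_params = sum(p.numel() for p in module.parameters())
    print(f"{name}: {num_params:,} parameters")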

Implementation Patterns

Basic Fine-Tuning

import torch
import torch.nn as nn
import torchvision.models as models

# 1. Load pre-trained model
model = models.resnet50(pretrained=True)

# 2. Replace the final layer for the new task
num_classes = 5  # your task's number of classes
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, num_classes)

# 3. Define which parameters to train
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

Freezing Early Layers

# Freeze all pre-trained parameters initially
for param in model.parameters():
    param.requires_grad = False

# Only train the new final layer (created after freezing, so its parameters remain trainable)
model.fc = nn.Linear(num_features, num_classes)
optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)

Differential Learning Rates

Train different layers with different learning rates:

optimizer = torch.optim.Adam([
    {'params': model.layer1.parameters(), 'lr': 1e-5},
    {'params': model.layer2.parameters(), 'lr': 1e-5},
    {'params': model.layer3.parameters(), 'lr': 1e-4},
    {'params': model.layer4.parameters(), 'lr': 1e-4},
    {'params': model.fc.parameters(), 'lr': 1e-3}
])

Lower learning rates for early layers preserve learned general features; higher rates for late layers allow task-specific adaptation.

Strategy Selection by Dataset Size

Small Dataset, Similar Domain

Example: 1,000 skin lesion images (similar to natural images)

Strategy: Feature extraction only

  • Freeze all pre-trained layers
  • Train only the final classification layer
  • Fast and prevents overfitting

for param in model.parameters():
    param.requires_grad = False

model.fc = nn.Linear(num_features, num_classes)
optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)

Medium Dataset, Similar Domain

Example: 10,000 chest X-rays

Strategy: Partial fine-tuning

  • Freeze early layers (layer1, layer2)
  • Fine-tune later layers with low learning rate
  • Balance between transfer and adaptation

for name, param in model.named_parameters():
    if 'layer1' in name or 'layer2' in name:
        param.requires_grad = False

optimizer = torch.optim.Adam([
    {'params': model.layer3.parameters(), 'lr': 1e-4},
    {'params': model.layer4.parameters(), 'lr': 1e-4},
    {'params': model.fc.parameters(), 'lr': 1e-3}
])

Large Dataset, Different Domain

Example: 100,000 medical CT scans (very different from ImageNet)

Strategy: Full fine-tuning with differential rates

  • Train all layers but with careful learning rates
  • Early layers change slowly, late layers adapt more

optimizer = torch.optim.Adam([
    {'params': model.layer1.parameters(), 'lr': 1e-5},
    {'params': model.layer2.parameters(), 'lr': 1e-5},
    {'params': model.layer3.parameters(), 'lr': 1e-4},
    {'params': model.layer4.parameters(), 'lr': 1e-4},
    {'params': model.fc.parameters(), 'lr': 1e-3}
])

Feature Extraction vs Fine-Tuning

Feature Extraction

Use the pre-trained model as a fixed feature extractor:

model = models.resnet50(pretrained=True)

# Remove final classification layer
model = nn.Sequential(*list(model.children())[:-1])
model.eval()

# Extract features (no gradients needed)
with torch.no_grad():
    features = model(images)  # Shape: (batch, 2048, 1, 1)
features = features.squeeze()  # Shape: (batch, 2048)

# Train a simple classifier on extracted features
classifier = nn.Linear(2048, num_classes)
optimizer = torch.optim.Adam(classifier.parameters())

When to use:

  • Very small datasets (< 1,000 samples)
  • Quick prototyping and experimentation
  • Fast inference required
  • Computational resources limited

Fine-Tuning

Adjust pre-trained weights through backpropagation:

model = models.resnet50(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, num_classes)

# Fine-tune with careful learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Training loop with backpropagation
for images, labels in train_loader:
    optimizer.zero_grad()
    outputs = model(images)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

When to use:

  • Medium to large datasets (> 5,000 samples)
  • Domain differs significantly from pre-training data
  • Need maximum performance
  • Sufficient computational resources

Medical Imaging Considerations

Domain Shift

Pre-training on ImageNet (natural images) generalizes surprisingly well to medical images despite visual differences. However, domain-specific pre-training can provide additional benefits when available.

Medical-Specific Pre-Trained Models

Several models are pre-trained on medical data:

  • CheXNet: Pre-trained on chest X-rays (ChestX-ray14 dataset)
  • RadImageNet: Pre-trained on diverse radiology images
  • Models from Medical Segmentation Decathlon: Pre-trained on medical imaging tasks

When available, prefer pre-trained weights that match your imaging modality (e.g., chest X-rays → CheXNet; general radiology → RadImageNet).
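
If you find such weights, they can usually be loaded into a standard torchvision backbone. The sketch below assumes a checkpoint saved as a plain state_dict; the file name is a hypothetical placeholder, not a real distributed file.

import torch
import torch.nn as nn
import torchvision.models as models

# Hypothetical example: load domain-specific weights into a ResNet-50 backbone.
# "radimagenet_resnet50.pth" is a placeholder path for illustration only.
model = models.resnet50(pretrained=False)
state_dict = torch.load("radimagenet_resnet50.pth", map_location="cpu")
model.load_state_dict(state_dict, strict=False)  # strict=False tolerates a missing/renamed head

# Replace the classification head for the target task
num_classes = 5
model.fc = nn.Linear(model.fc.in_features, num_classes)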

Data Considerations

Medical imaging often has:

  • Limited labeled data: Transfer learning essential
  • Class imbalance: Rare diseases underrepresented
  • High variance: Different imaging protocols, machines, institutions
  • Regulatory requirements: Model validation and explainability critical

Use transfer learning with data augmentation and careful validation strategies.
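
For the class-imbalance point above, one common remedy is to weight the loss by inverse class frequency. The sketch below uses made-up class counts purely for illustration.

import torch
import torch.nn as nn

# Weight the loss by inverse class frequency (the counts below are placeholders)
class_counts = torch.tensor([5000.0, 800.0, 300.0, 150.0, 50.0])
class_weights = class_counts.sum() / (len(class_counts) * class_counts)
criterion = nn.CrossEntropyLoss(weight=class_weights)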

Available Pre-Trained Models

from torchvision import models

# Classic CNNs
resnet50 = models.resnet50(pretrained=True)        # Standard choice
resnet101 = models.resnet101(pretrained=True)      # Deeper variant
densenet121 = models.densenet121(pretrained=True)  # Dense connections

# Efficient architectures
efficientnet_b0 = models.efficientnet_b0(pretrained=True)  # Mobile/edge
mobilenet_v3 = models.mobilenet_v3_small(pretrained=True)  # Very efficient

# Modern architectures
vit_b_16 = models.vit_b_16(pretrained=True)  # Vision Transformer

Best Practices

1. Always Start with Pre-Trained Weights

Unless you have > 100,000 labeled examples, pre-trained weights will almost always outperform random initialization.

2. Use Appropriate Learning Rates

  • Early layers: Very low LR (1e-5 to 1e-6) or frozen
  • Middle layers: Low LR (1e-4 to 1e-5)
  • New layers: Higher LR (1e-3 to 1e-4)

3. Apply Data Augmentation

Critical for small datasets to prevent overfitting:

from torchvision import transforms

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

4. Monitor for Overfitting

Small medical datasets are prone to overfitting:

  • Use a validation set for early stopping (a minimal sketch follows this list)
  • Track train vs validation metrics closely
  • Consider ensemble methods
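
A minimal early-stopping sketch, assuming train_one_epoch and validate are helpers you already have (validate returns the validation loss):

import torch

best_val_loss = float("inf")
patience, patience_counter = 5, 0

for epoch in range(num_epochs):
    train_one_epoch(model, train_loader, optimizer, criterion)  # assumed helper
    val_loss = validate(model, val_loader, criterion)           # assumed helper

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        torch.save(model.state_dict(), "best_model.pth")  # keep the best checkpoint
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch}")
            break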

5. Validate on Held-Out Test Set

Never tune hyperparameters on your test set. Split your data three ways (a splitting sketch follows the list):

  • Training set: Model training
  • Validation set: Hyperparameter tuning, early stopping
  • Test set: Final evaluation only
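
If you only have a single labeled pool, a reproducible three-way split can be made with random_split; the 70/15/15 ratio below is just one reasonable choice, and full_dataset is assumed to be an existing torch Dataset.

import torch
from torch.utils.data import random_split

n = len(full_dataset)
n_train, n_val = int(0.7 * n), int(0.15 * n)
train_set, val_set, test_set = random_split(
    full_dataset,
    [n_train, n_val, n - n_train - n_val],
    generator=torch.Generator().manual_seed(42),  # fixed seed for reproducibility
)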

6. Consider Ensemble Methods

Average predictions from multiple fine-tuned models for improved robustness:

# Train multiple models with different initializations (train_model is an assumed helper)
trained_models = [train_model(seed=i) for i in range(5)]

# Ensemble prediction: average the individual outputs
predictions = [m(x) for m in trained_models]
ensemble_pred = torch.mean(torch.stack(predictions), dim=0)

Complete Example

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms, datasets

# 1. Load pre-trained model
model = models.resnet50(pretrained=True)

# 2. Modify for target task
num_classes = 5
model.fc = nn.Linear(model.fc.in_features, num_classes)

# 3. Freeze early layers
for name, param in model.named_parameters():
    if 'layer1' in name or 'layer2' in name:
        param.requires_grad = False

# 4. Differential learning rates
optimizer = optim.Adam([
    {'params': model.layer3.parameters(), 'lr': 1e-4},
    {'params': model.layer4.parameters(), 'lr': 1e-4},
    {'params': model.fc.parameters(), 'lr': 1e-3}
])

# 5. Data augmentation
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# 6. Training loop (train_loader is assumed to be built with train_transform)
criterion = nn.CrossEntropyLoss()
num_epochs = 10
for epoch in range(num_epochs):
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

Practical Exercises

Example 1: Multi-Label Satellite Imagery Classification

Dataset: Planet Amazon Rainforest 

Transfer learning on satellite imagery demonstrates domain adaptation from natural images (ImageNet) to remote sensing data. This multi-label classification task requires detecting multiple conditions per image (e.g., clear sky + primary rainforest + road).

Key Challenges:

  • Multi-label classification (multiple classes per image)
  • Domain shift from ground-level to satellite perspective
  • Class imbalance across atmospheric and land cover categories

Transfer Learning Strategy: Fine-tune ResNet or EfficientNet with sigmoid activation for multi-label output, using BCE loss instead of cross-entropy.
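
A minimal sketch of that setup, assuming 17 labels (the label count is an assumption) and float multi-hot targets:

import torch
import torch.nn as nn
import torchvision.models as models

num_labels = 17  # assumed label count for illustration
model = models.resnet50(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, num_labels)

# BCEWithLogitsLoss applies the sigmoid internally, so the head outputs raw logits
criterion = nn.BCEWithLogitsLoss()

# At inference time, threshold each label's probability independently:
# probs = torch.sigmoid(model(images)); preds = probs > 0.5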

Example 2: Industrial Defect Detection

Dataset: Severstal Steel Defect Detection 

Industrial quality control often has very limited defect examples. Transfer learning enables robust defect classification despite severe class imbalance (most samples are defect-free).

Key Challenges:

  • Extreme class imbalance (rare defect classes)
  • High-resolution images requiring careful resizing
  • Critical precision requirements for industrial deployment

Transfer Learning Strategy: Feature extraction with frozen early layers, then train classifier with class-weighted loss or focal loss to handle imbalance.
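
A focal-loss sketch for the multi-class case (the gamma value and optional class weights are illustrative choices):

import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    """Down-weights easy, majority-class examples (Lin et al., 2017)."""
    def __init__(self, gamma=2.0, weight=None):
        super().__init__()
        self.gamma = gamma
        self.weight = weight  # optional per-class weights

    def forward(self, logits, targets):
        ce = F.cross_entropy(logits, targets, weight=self.weight, reduction="none")
        pt = torch.exp(-ce)  # probability assigned to the true class
        return ((1.0 - pt) ** self.gamma * ce).mean()

criterion = FocalLoss(gamma=2.0)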

Example 3: Agricultural Disease Detection

Dataset: Plant Disease Recognition 

Agricultural pathology detection parallels medical diagnosis—both require identifying disease patterns in images. Transfer learning works remarkably well despite the domain shift from ImageNet’s natural images.

Key Challenges:

  • Visual similarity between disease classes
  • Variation in lighting, background, and image quality
  • Need for interpretability (which features indicate disease?)

Transfer Learning Strategy: Partial fine-tuning with data augmentation (rotation, color jitter) to handle field condition variability. Similar approach to medical imaging applications.

Related Concepts

  • Fine-Tuning Strategies - Detailed fine-tuning techniques
  • Feature Extraction - Using models as fixed feature extractors
  • Domain Adaptation - Adapting to distribution shifts

Key Papers

  • How transferable are features in deep neural networks? (Yosinski et al., 2014) - Systematic study of feature transferability
  • Deep Residual Learning for Image Recognition (He et al., 2015) - ResNet, most commonly used for transfer learning
  • ImageNet Classification with Deep Convolutional Neural Networks (Krizhevsky et al., 2012) - AlexNet, started the transfer learning era
