Deep Learning for Medical Imaging
Deep learning, particularly convolutional neural networks (CNNs), has revolutionized medical imaging analysis. This page covers core techniques for applying modern computer vision to clinical imaging tasks.
Why CNNs for Medical Imaging?
Medical images have inherent properties that make CNNs ideal:
Spatial Hierarchies: CNNs learn progressively abstract features
- Early layers: Edges, textures, basic patterns
- Middle layers: Anatomical structures, tissue types
- Later layers: Pathological patterns, disease signatures
Translation Invariance: Pathology can appear anywhere in an image
- Tumors don’t always appear in the same location
- CNNs detect patterns regardless of position
- Pooling layers provide spatial invariance
Parameter Efficiency: Compared to fully-connected networks
- Shared weights across spatial locations
- Dramatically fewer parameters
- Faster training and reduced overfitting risk on limited medical data
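The sketch below is a minimal illustration of these ideas in PyTorch: a small convolutional stack whose filters are shared across all spatial positions, with pooling providing a degree of translation invariance. The layer sizes here are arbitrary placeholders, not a recommended architecture.

```python
import torch
import torch.nn as nn

# Minimal CNN feature extractor: each Conv2d reuses the same small set of
# filters at every spatial location (parameter sharing), and pooling makes
# the response less sensitive to where a pattern appears.
tiny_cnn = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3, padding=1),   # early layer: edges, textures
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(16, 32, kernel_size=3, padding=1),  # deeper layer: larger structures
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),                      # global pooling: position-agnostic summary
    nn.Flatten(),
    nn.Linear(32, 2),                             # e.g., normal vs. abnormal
)

x = torch.randn(1, 1, 224, 224)   # one grayscale image
logits = tiny_cnn(x)              # shape: (1, 2)
```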
Transfer Learning: The Essential Strategy
Medical datasets are typically small compared to natural image datasets:
- ImageNet: 1.2M images
- Typical medical dataset: 1,000-50,000 images
- Rare conditions: Often <100 cases
Solution: Transfer learning from pre-trained models
Three-Stage Approach
```python
import torch
import torch.nn as nn
from torchvision.models import resnet50, efficientnet_b0

# Stage 1: Load pre-trained model (trained on ImageNet)
model = resnet50(pretrained=True)

# Stage 2: Replace the classification head
num_classes = 5  # e.g., 5 disease categories
model.fc = nn.Linear(model.fc.in_features, num_classes)

# Stage 3: Fine-tune with layer-specific learning rates
# (very small learning rates for early layers preserve general features)
optimizer = torch.optim.Adam([
    {'params': model.layer1.parameters(), 'lr': 1e-5},
    {'params': model.layer2.parameters(), 'lr': 5e-5},
    {'params': model.layer3.parameters(), 'lr': 1e-4},
    {'params': model.layer4.parameters(), 'lr': 5e-4},
    {'params': model.fc.parameters(), 'lr': 1e-3},   # new head learns fastest
])
```
Why ImageNet Pre-training Works for Medical Images
Despite visual differences between natural and medical images:
- Low-level features transfer: Edges, textures, gradients are universal
- Mid-level features adapt: Anatomical structures learned during fine-tuning
- High-level features specialize: Disease-specific patterns emerge
Published studies frequently report large gains (on the order of 30-40%) from transfer learning compared with training from scratch on small medical datasets.
Transfer learning works across domains: similar techniques apply to agricultural disease detection (e.g., plant pathology), industrial inspection, and satellite imagery analysis, demonstrating the universality of the approach for visual diagnosis tasks.
Common Medical Imaging Modalities
X-Ray and Radiography
- Input: 2D grayscale images (1 channel)
- Resolution: Typically 512×512 to 2048×2048
- Applications: Chest X-ray analysis, fracture detection
- Architectures: ResNet-50, DenseNet-121
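One practical detail for grayscale X-rays: ImageNet backbones such as ResNet-50 expect 3-channel input. A common workaround, sketched below for a torchvision ResNet, is either to repeat the grayscale channel three times or to swap the stem convolution for a 1-channel version.

```python
import torch
import torch.nn as nn
from torchvision.models import resnet50

model = resnet50(pretrained=True)

# Option 1: repeat the grayscale channel so the pretrained 3-channel stem is reused
xray = torch.randn(4, 1, 512, 512)            # batch of grayscale X-rays
logits = model(xray.repeat(1, 3, 1, 1))       # (4, 1000) with the default head

# Option 2: swap the stem for a 1-channel convolution (its pretrained weights are lost)
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
logits = model(xray)
```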
CT and MRI Scans
- Input: 3D volumetric data (multiple slices)
- Resolution: 256×256×N slices
- Applications: Tumor segmentation, organ analysis
- Architectures: 3D CNNs, U-Net for segmentation
Pathology Images
- Input: Gigapixel whole-slide images
- Resolution: 10,000×10,000+ pixels
- Applications: Cancer detection, tissue classification
- Architectures: Patch-based CNNs, attention-based aggregation
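Whole-slide images are far too large to feed to a CNN directly, so patch-based pipelines tile a slide (or a region of it) into fixed-size patches, score each patch, and aggregate the patch-level outputs. The sketch below is a simplified illustration using plain tensor operations; `patch_cnn` is an assumed two-class classifier, and real pipelines typically read regions with a WSI library such as OpenSlide and use smarter aggregation (e.g., attention pooling).

```python
import torch

def tile_region(region: torch.Tensor, patch: int = 224) -> torch.Tensor:
    """Split a (C, H, W) region into non-overlapping (C, patch, patch) tiles."""
    c, h, w = region.shape
    tiles = region.unfold(1, patch, patch).unfold(2, patch, patch)   # (C, nH, nW, p, p)
    return tiles.permute(1, 2, 0, 3, 4).reshape(-1, c, patch, patch)

region = torch.rand(3, 2240, 2240)            # region cropped from a gigapixel slide
patches = tile_region(region)                 # (100, 3, 224, 224)

# Score every patch, then aggregate (here: mean of patch-level probabilities)
with torch.no_grad():
    patch_logits = patch_cnn(patches)         # patch_cnn: any 2-class CNN (assumed)
slide_score = patch_logits.softmax(dim=-1)[:, 1].mean()
```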
Ultrasound
- Input: Real-time video or static frames
- Resolution: Variable; image quality is often lower and operator-dependent
- Applications: Cardiac function, fetal monitoring
- Architectures: ResNet, MobileNet (efficiency important)
Architecture Selection
| Use Case | Recommended Architecture | Why |
|---|---|---|
| Limited data (<1000 samples) | ResNet-18, EfficientNet-B0 | Fewer parameters, less overfitting |
| Moderate data (1000-10000) | ResNet-50, DenseNet-121 | Good accuracy-efficiency trade-off |
| Large data (>10000) | ResNet-101, EfficientNet-B4 | Can leverage capacity |
| 3D volumes | 3D ResNet, Med3D | Spatial-temporal modeling |
| Segmentation | U-Net, U-Net++ | Encoder-decoder with skip connections |
| Limited compute | MobileNet, EfficientNet | Optimized for efficiency |
Data Augmentation for Medical Images
Augmentation is critical with limited data, but clinical validity is essential:
Safe Augmentations
```python
import torchvision.transforms as T

safe_augmentations = T.Compose([
    T.RandomRotation(degrees=15),              # small rotations OK
    T.RandomAffine(degrees=0,
                   translate=(0.1, 0.1)),      # small shifts OK
    T.ColorJitter(brightness=0.2,
                  contrast=0.2),               # mild lighting variation
    T.GaussianBlur(kernel_size=3),             # simulate acquisition blur
    T.ToTensor(),                              # PIL image -> tensor
    T.Normalize(mean=[0.485], std=[0.229]),    # standardize (single channel)
])
```
Dangerous Augmentations ⚠️
Avoid transformations that change medical meaning:
❌ Horizontal flips: May not preserve anatomy (left vs right side)
- Sometimes acceptable for chest X-rays (roughly symmetric), though cardiac position is lateralized
- NOT OK for liver CT (organ position matters)
❌ Aggressive color changes: Can alter pathology appearance
- Redness may indicate inflammation
- Color changes can obscure disease
❌ Heavy distortions: Can create unrealistic anatomy
Always validate augmentations with clinical experts!
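One way to keep such constraints explicit in code is to build augmentation pipelines per modality, so laterality-sensitive studies never receive flips. A minimal sketch (the modality names and parameter values are illustrative, not a validated protocol):

```python
import torchvision.transforms as T

def build_augmentations(modality: str) -> T.Compose:
    # Transforms generally considered safe for most modalities
    common = [T.RandomRotation(degrees=10), T.ToTensor()]
    if modality == "chest_xray":
        # Flips are sometimes tolerated for roughly symmetric chest studies
        return T.Compose([T.RandomHorizontalFlip(p=0.5)] + common)
    if modality == "abdominal_ct":
        # No flips: organ laterality (e.g., liver on the right) is diagnostic
        return T.Compose(common)
    return T.Compose(common)

xray_aug = build_augmentations("chest_xray")
ct_aug = build_augmentations("abdominal_ct")
```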
Feature Extraction for Multimodal Fusion
Medical diagnosis often combines multiple data sources. CNNs extract visual features for fusion:
```python
class MedicalMultimodalModel(nn.Module):
    def __init__(self, num_outcomes=5):
        super().__init__()
        # Vision encoder: extract features from medical images
        self.vision_encoder = resnet50(pretrained=True)
        self.vision_encoder.fc = nn.Identity()  # remove classifier, keep 2048-d features
        vision_dim = 2048

        # Text encoder: clinical notes (Hugging Face transformers)
        from transformers import AutoModel
        self.text_encoder = AutoModel.from_pretrained('emilyalsentzer/Bio_ClinicalBERT')
        text_dim = 768

        # Project both modalities into a shared embedding space
        shared_dim = 512
        self.vision_projection = nn.Linear(vision_dim, shared_dim)
        self.text_projection = nn.Linear(text_dim, shared_dim)

        # Fusion and prediction (batch_first so inputs are (batch, seq, dim))
        self.fusion = nn.MultiheadAttention(embed_dim=shared_dim, num_heads=8,
                                            batch_first=True)
        self.classifier = nn.Sequential(
            nn.Linear(shared_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_outcomes)
        )

    def forward(self, image, clinical_notes):
        # Extract visual features
        image_features = self.vision_encoder(image)
        image_emb = self.vision_projection(image_features)

        # Extract text features
        text_output = self.text_encoder(**clinical_notes)
        text_emb = self.text_projection(text_output.pooler_output)

        # Fuse modalities via attention over the two modality tokens
        modalities = torch.stack([image_emb, text_emb], dim=1)  # (batch, 2, shared_dim)
        fused, _ = self.fusion(modalities, modalities, modalities)

        # Predict outcome from the pooled, fused representation
        pooled = fused.mean(dim=1)
        return self.classifier(pooled)
```
Interpretability for Clinical Use
Clinical adoption requires explainability. Doctors must understand model decisions.
1. Saliency Maps (Grad-CAM)
Highlight which image regions influenced predictions:
```python
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

# Create Grad-CAM explainer on the last convolutional block
cam = GradCAM(model=model, target_layers=[model.layer4])

# Generate a heatmap for a specific class index
grayscale_cam = cam(input_tensor=image, targets=[ClassifierOutputTarget(target_class)])

# Overlay on the original image (rgb_image: float numpy array in [0, 1], shape HxWx3)
visualization = show_cam_on_image(rgb_image, grayscale_cam[0], use_rgb=True)
```
Clinical value:
- “Model focused on upper right lung lobe” (appropriate for pneumonia)
- “Model looked at heart shadow” (appropriate for cardiac diagnosis)
- Identifies when model focuses on irrelevant artifacts
2. Feature Visualization
Understand what patterns activate neurons:
```python
def visualize_layer_features(model, image, layer_name):
    """Capture the activations of a named layer for visualization."""
    activation = {}

    def get_activation(name):
        def hook(module, inputs, output):
            activation[name] = output.detach()
        return hook

    # Register a forward hook on the requested layer
    layer = dict(model.named_modules())[layer_name]
    handle = layer.register_forward_hook(get_activation(layer_name))

    # Forward pass populates the activation dict
    _ = model(image)
    handle.remove()

    # Feature maps for visualization: (batch, channels, H, W)
    return activation[layer_name]
```
3. Attention Visualization (for ViT)
Vision Transformers provide built-in attention maps:
```python
from transformers import ViTForImageClassification

model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')

# Get attention weights for every layer
outputs = model(pixel_values=image, output_attentions=True)
attentions = outputs.attentions  # one attention map per layer

# The attention maps show which patches the model attends to,
# giving an interpretable view of its spatial reasoning.
```
Handling Class Imbalance
Medical datasets often have severe imbalance (rare diseases):
Strategies
1. Weighted Loss Functions:
```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

class_weights = compute_class_weight('balanced',
                                     classes=np.unique(labels),
                                     y=labels)
criterion = nn.CrossEntropyLoss(weight=torch.tensor(class_weights, dtype=torch.float32))
```
2. Oversampling Rare Classes:
```python
from imblearn.over_sampling import RandomOverSampler

# X must be 2D (n_samples, n_features), e.g. flattened images or sample indices
oversampler = RandomOverSampler(sampling_strategy='minority')
X_resampled, y_resampled = oversampler.fit_resample(X, y)
```
3. Focal Loss (from RetinaNet paper):
```python
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)  # probability of the true class
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        return focal_loss.mean()
```
Validation Strategies
Evaluation should combine standard ML metrics with clinically oriented metrics:
Model Performance Metrics
- AUROC: Area under ROC curve (overall discrimination)
- AUPRC: Area under precision-recall curve (for imbalanced data)
- Sensitivity/Specificity: At clinically relevant thresholds
Clinical Utility Metrics
- Net Benefit: Decision curve analysis
- Positive/Negative Predictive Value: Prevalence-dependent measures of real-world utility
- Calibration: Do predicted probabilities match true frequencies?
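Calibration can be checked with a reliability curve and the Brier score. The sketch below uses scikit-learn and assumes binary labels `y_true` and predicted probabilities `y_prob` are already available.

```python
import numpy as np
from sklearn.calibration import calibration_curve
from sklearn.metrics import brier_score_loss

# y_true: binary ground-truth labels, y_prob: predicted probabilities (assumed available)
prob_true, prob_pred = calibration_curve(y_true, y_prob, n_bins=10)

# Perfectly calibrated predictions lie on the diagonal prob_true == prob_pred
for p_pred, p_true in zip(prob_pred, prob_true):
    print(f"predicted {p_pred:.2f} -> observed {p_true:.2f}")

print("Brier score:", brier_score_loss(y_true, y_prob))  # lower is better
```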
Fairness Metrics
- Subgroup Analysis: Performance across demographics (age, sex, ethnicity)
- Equalized Odds: Fair error rates across groups
- Demographic Parity: Equal positive prediction rates
```python
from sklearn.metrics import roc_auc_score, average_precision_score

# Standard metrics (y_pred: predicted probabilities/scores, not hard labels)
auroc = roc_auc_score(y_true, y_pred)
auprc = average_precision_score(y_true, y_pred)

# Subgroup analysis across demographic groups (age, sex, ethnicity, ...)
for group in demographic_groups:
    group_mask = demographics == group
    group_auroc = roc_auc_score(y_true[group_mask], y_pred[group_mask])
    print(f"{group}: AUROC = {group_auroc:.3f}")
```
Best Practices Summary
✅ Do:
- Use transfer learning from ImageNet (or medical-specific pre-training)
- Fine-tune with layer-specific learning rates
- Validate augmentations with clinical experts
- Provide interpretability via saliency maps
- Evaluate on clinical utility metrics, not just accuracy
- Perform subgroup analysis for fairness
- Collaborate closely with clinicians
❌ Don’t:
- Train from scratch with limited data
- Use augmentations that change medical meaning
- Ignore class imbalance
- Deploy without interpretability
- Forget regulatory requirements (FDA, CE marking)
- Assume model will work across different hospitals/scanners
Regulatory Considerations
Medical AI must meet regulatory standards:
- FDA (US): Software as Medical Device (SaMD) classification
- CE Mark (Europe): Medical Device Regulation (MDR)
- ISO 13485: Quality management for medical devices
Key requirements:
- Clinical validation studies (retrospective + prospective)
- Risk management (ISO 14971)
- Documentation of training data and model performance
- Post-market surveillance
Related Concepts
Convolution Operations - Core CNN building block
Pooling Layers - Spatial downsampling
Transfer Learning - Pre-training and fine-tuning
ResNet Architecture - Most common medical imaging backbone
Vision Transformers - Alternative to CNNs for imaging
Clinical Interpretability - Explaining model decisions
Learning Resources
Papers
- CheXNet (Rajpurkar et al., 2017): Radiologist-level pneumonia detection
- Deep Learning for Dermatology (Esteva et al., 2017): Skin cancer classification
- Brain Tumor Segmentation (Menze et al., 2015): BraTS challenge overview
Datasets
- ChestX-ray14: 100,000+ chest X-rays with 14 disease labels
- MIMIC-CXR: 377,000 chest X-rays with radiology reports
- BraTS: Brain tumor MRI segmentation challenge
- PathMNIST: Pathology image classification
Libraries
- MONAI: Medical Open Network for AI (PyTorch-based)
- TorchIO: Medical image preprocessing and augmentation
- PyTorch Grad-CAM: Interpretability visualizations