Practical Training Techniques
Training large deep learning models requires careful attention to stability and efficiency. Modern techniques make it practical to train models with billions of parameters that would otherwise be too unstable or too slow to train.
Core Techniques
This guide covers four essential techniques:
- Learning Rate Warmup - Gradual LR increase for stability
- Gradient Clipping - Prevent exploding gradients
- Mixed Precision Training - 2-3x speedup with FP16
- Layer-wise Learning Rates - Different LRs for different layers
Learning Rate Warmup
Problem: At the start of training, gradients are noisy and the weights are random, so a large learning rate can destabilize training.
Solution: Start with a small learning rate and gradually increase it over the first 1-10% of training.
Why Warmup Works
At the beginning of training:
- Gradients are noisy - The model hasn’t seen much data yet
- Weights are random - May be far from good solutions
- Large updates are dangerous - Can destabilize training or cause divergence
Warmup prevents early instabilities by using small, cautious steps initially.
Linear Warmup Implementation
```python
import math

def get_lr_with_warmup(step, warmup_steps, max_lr, total_steps):
    """
    Linear warmup followed by cosine decay.

    Args:
        step: Current training step
        warmup_steps: Number of warmup steps (typically 1-10% of total)
        max_lr: Maximum learning rate after warmup
        total_steps: Total training steps

    Returns:
        Current learning rate
    """
    if step < warmup_steps:
        # Linear warmup: lr = max_lr * (step / warmup_steps)
        return max_lr * (step / warmup_steps)
    else:
        # Cosine decay after warmup
        progress = (step - warmup_steps) / (total_steps - warmup_steps)
        return max_lr * 0.5 * (1 + math.cos(math.pi * progress))
```

Key insight: The learning rate follows this curve:
```
LR
│         ╱────╲
│       ╱       ────╲
│     ╱              ────╲
│   ╱                     ────╲
│ ╱                            ────
└──────────────────────────────────────> Steps
  │←warmup→│←────── cosine decay ─────→│
```

PyTorch Implementation
```python
import math
import torch
from torch.optim.lr_scheduler import LambdaLR

# Create optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Create warmup scheduler
warmup_steps = 1000
total_steps = len(dataloader) * num_epochs

def lr_lambda(step):
    if step < warmup_steps:
        # Linear warmup
        return step / warmup_steps
    # After warmup, use cosine decay
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return 0.5 * (1 + math.cos(math.pi * progress))

scheduler = LambdaLR(optimizer, lr_lambda)

# Training loop
for epoch in range(num_epochs):
    for batch in dataloader:
        optimizer.zero_grad()
        loss = model(batch)
        loss.backward()
        optimizer.step()
        scheduler.step()  # Update LR every step (not per epoch!)
```

Warmup Guidelines by Model Size
| Model Size | Warmup Steps | Typical Max LR | Warmup Duration |
|---|---|---|---|
| Small (< 50M params) | 500-1000 | 1e-3 to 5e-4 | 1-2% of training |
| Medium (50M-500M) | 1000-5000 | 5e-4 to 1e-4 | 2-5% of training |
| Large (> 500M) | 5000-10000 | 1e-4 to 5e-5 | 5-10% of training |
Rule of thumb: Use warmup for 1-10% of total training steps.
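As one way to apply this rule of thumb in code, the sketch below derives the warmup length from the total step count; the compute_warmup_steps helper and its defaults are illustrative, not a standard API:

```python
def compute_warmup_steps(total_steps, warmup_fraction=0.05, min_steps=500):
    """Pick the number of warmup steps as a fraction of total training steps.

    warmup_fraction: 0.01-0.10 per the rule of thumb above (illustrative default: 5%)
    min_steps: floor so that very short runs still get some warmup
    """
    return max(min_steps, int(total_steps * warmup_fraction))

# Example: a 100,000-step run with 5% warmup -> 5,000 warmup steps
warmup_steps = compute_warmup_steps(total_steps=100_000, warmup_fraction=0.05)
```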
Gradient Clipping
Problem: Gradients can occasionally explode to very large values, causing training to diverge.
Solution: Limit gradient magnitude to a maximum threshold.
Two Types of Clipping
1. Clip by Norm (Recommended)
Scales gradients if their total norm exceeds a threshold:
```python
import torch

# Clip the gradient norm to a maximum value (call after loss.backward())
max_norm = 1.0
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

# Formula: if ||g|| > max_norm, scale g = g * (max_norm / ||g||)
```

How it works:
- Compute the total gradient norm: ||g|| = sqrt(sum(g_i^2))
- If ||g|| > max_norm, scale all gradients: g = g * (max_norm / ||g||)
- Otherwise, leave the gradients unchanged
Advantage: Preserves gradient direction, only scales magnitude.
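To make the scaling rule concrete, the sketch below applies the formula by hand on a toy linear model and checks the result against torch.nn.utils.clip_grad_norm_; the model, seed, and threshold are illustrative:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 10)

# Produce some gradients
model(torch.randn(4, 10)).sum().backward()

# Apply the formula above manually
max_norm = 1.0
total_norm = torch.sqrt(sum(p.grad.pow(2).sum() for p in model.parameters()))
if total_norm > max_norm:
    for p in model.parameters():
        p.grad.mul_(max_norm / total_norm)  # g = g * (max_norm / ||g||), direction preserved

# clip_grad_norm_ returns the total norm it measured; after the manual clip
# it should print a value <= max_norm and change essentially nothing.
print(torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm))
```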
2. Clip by Value
Clips each gradient element independently:
```python
# Clip each gradient element to the range [-clip_value, clip_value]
clip_value = 0.5
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value)
```

Disadvantage: Can distort gradient directions. Less common.
When to Use Gradient Clipping
✅ Always use for:
- RNNs and LSTMs - Prone to exploding gradients
- Transformers - Deep architecture with multiplicative interactions
- Very deep networks (> 50 layers) - Long backpropagation paths
- Long sequence training - More timesteps = more gradient accumulation
⚠️ May not need for:
- Small CNNs with batch normalization
- Well-initialized architectures (ResNets with proper initialization)
- Shallow networks with careful learning rate tuning
Practical Training Loop
```python
for batch in dataloader:
    optimizer.zero_grad()

    # Forward pass
    outputs = model(batch['input'])
    loss = criterion(outputs, batch['target'])

    # Backward pass
    loss.backward()

    # Clip gradients BEFORE the optimizer step
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    # Update weights
    optimizer.step()
```

Monitoring Gradient Norms
Track gradient norms to choose appropriate clipping threshold:
```python
def get_grad_norm(model):
    """Compute the total gradient norm across all parameters."""
    total_norm = 0.0
    for p in model.parameters():
        if p.grad is not None:
            param_norm = p.grad.detach().norm(2)
            total_norm += param_norm.item() ** 2
    return total_norm ** 0.5

# During training
grad_norm = get_grad_norm(model)
print(f"Gradient norm: {grad_norm:.2f}")

# If you frequently see gradient norms > 10, add clipping
# Set max_norm slightly above typical gradient norms (e.g., 1.0-5.0)
```

Guideline: If gradient norms frequently exceed 10-100, use clipping with max_norm=1.0 or max_norm=5.0.
Mixed Precision Training
Problem: Training large models in full precision (FP32) is slow and memory-intensive.
Solution: Use FP16 for most operations, FP32 for critical steps. Get 2-3x speedup with minimal accuracy loss.
Benefits
- 2-3x speedup - FP16 operations are faster on modern GPUs (Tensor Cores)
- 50% memory reduction - FP16 uses half the memory of FP32
- Larger batch sizes - Reduced memory enables bigger batches
- Minimal quality loss - Careful gradient scaling prevents underflow
How Mixed Precision Works
```python
import torch
from torch.cuda.amp import autocast, GradScaler

# Create gradient scaler (handles FP16 underflow)
scaler = GradScaler()

for batch in dataloader:
    optimizer.zero_grad()

    # Forward pass in FP16 (automatic casting)
    with autocast():
        outputs = model(batch['input'])
        loss = criterion(outputs, batch['target'])

    # Backward pass with scaled gradients (prevents FP16 underflow)
    scaler.scale(loss).backward()

    # Unscale gradients before clipping (so the threshold applies to true gradient values)
    scaler.unscale_(optimizer)

    # Gradient clipping (on unscaled gradients)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    # Optimizer step via the scaler (skips the step if infs/NaNs are found)
    scaler.step(optimizer)

    # Update the scale factor for the next iteration
    scaler.update()
```

Why Gradient Scaling?
FP16 has a limited range (≈ 6e-8 to 65504):

Problem: small gradients underflow to zero

```python
gradient = 1e-8  # Underflows to 0 in FP16
```

Solution: scale gradients up during the backward pass

```python
scale_factor = 65536  # 2^16
scaled_gradient = 1e-8 * 65536  # ≈ 6.6e-4, representable in FP16

# After the backward pass, divide by scale_factor before updating weights
```

Process:
- Scale loss before backward pass (multiply by large constant)
- Compute gradients in scaled space (prevents underflow)
- Unscale gradients before optimizer step (divide by constant)
- Update weights in FP32 (master weights)
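The underflow and the effect of scaling can be checked directly on a single value; a minimal numeric illustration, separate from any training loop:

```python
import torch

g = torch.tensor(1e-8)
scale = 65536.0  # 2**16, GradScaler's default initial scale

print(g.half())                            # underflows to 0 in FP16
print((g * scale).half())                  # ≈ 6.55e-4, representable in FP16
print((g * scale).half().float() / scale)  # ≈ 1e-8 recovered after unscaling
```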
Mixed Precision Best Practices
```python
# ✅ DO: Use autocast for the forward pass
with autocast():
    loss = model(input)

# ✅ DO: Scale the loss before backward
scaler.scale(loss).backward()

# ✅ DO: Unscale before gradient clipping
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

# ✅ DO: Call scaler.update() every iteration
scaler.update()

# ❌ DON'T: Use autocast for data loading or preprocessing
# ❌ DON'T: Use autocast for loss computation if the loss requires high precision
# ❌ DON'T: Forget to call scaler.update()
```

What Gets Computed in FP16 vs FP32?
| Operation | Precision | Reason |
|---|---|---|
| Matrix multiplications | FP16 | Fast on Tensor Cores |
| Convolutions | FP16 | Fast on Tensor Cores |
| Activations | FP16 | Memory savings |
| Batch norm | FP32 | Numerical stability |
| Softmax | FP32 | Numerical stability |
| Loss computation | FP32 | Precision needed |
| Optimizer state | FP32 | Master weights |
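This split can be spot-checked by inspecting output dtypes under autocast. A minimal sketch, assuming a CUDA-capable GPU (autocast behaves differently on CPU) and illustrative tensor shapes:

```python
import torch
from torch.cuda.amp import autocast

if torch.cuda.is_available():
    x = torch.randn(8, 16, device="cuda")
    w = torch.randn(16, 16, device="cuda")

    with autocast():
        y = x @ w                     # matmul runs in FP16 on Tensor Cores
        s = torch.softmax(y, dim=-1)  # softmax is autocast to FP32 for stability

    print(y.dtype)  # torch.float16
    print(s.dtype)  # torch.float32
```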
Layer-wise Learning Rates
Problem: Different layers may need different learning rates (especially during fine-tuning).
Solution: Assign different learning rates to different parameter groups.
When to Use
Transfer Learning / Fine-tuning:
- Lower LR for pre-trained layers - Preserve learned features
- Higher LR for new layers - Learn task-specific features quickly
Deep Networks:
- Lower LR for early layers - Low-level features (edges, textures)
- Higher LR for later layers - High-level features (task-specific)
Implementation with Parameter Groups
```python
# Example: Fine-tuning a pre-trained model
model = PretrainedModel()

optimizer = torch.optim.AdamW([
    {
        'params': model.encoder.parameters(),
        'lr': 1e-5,            # Low LR for the pre-trained encoder
        'weight_decay': 0.01
    },
    {
        'params': model.head.parameters(),
        'lr': 1e-3,            # High LR for the new classification head
        'weight_decay': 0.0    # No weight decay on the head
    }
])
```

LR ratio: New layers typically use a 10-100x higher LR than pre-trained layers.
Layer-wise LR Decay (for very deep networks)
Progressively decrease learning rate for earlier layers:
```python
def get_layer_wise_lr(model, base_lr=1e-3, decay_rate=0.95):
    """
    Assign decreasing learning rates to earlier parameters.

    Note: this simple version treats each parameter tensor as its own "layer";
    for true per-layer decay, group parameters by module first.

    Args:
        base_lr: Learning rate for the final parameter group
        decay_rate: Multiplicative decay per group (e.g., 0.95)

    Returns:
        List of parameter groups with group-specific LRs
    """
    named_params = list(model.named_parameters())
    num_params = len(named_params)

    param_groups = []
    for i, (name, param) in enumerate(named_params):
        # Earlier parameters get a lower LR; the final one gets base_lr
        lr = base_lr * (decay_rate ** (num_params - 1 - i))
        param_groups.append({
            'params': param,
            'lr': lr,
            'name': name
        })
    return param_groups

optimizer = torch.optim.AdamW(
    get_layer_wise_lr(model, base_lr=1e-3, decay_rate=0.95)
)
```

BERT-style Discriminative Fine-tuning
For transformer models like BERT:
```python
def get_bert_param_groups(model, lr=2e-5):
    """
    BERT-style discriminative learning rates:
    - Embeddings: lr * 0.95^(num_layers + 1) (lowest)
    - Encoder layer i: lr * 0.95^(num_layers - i)
    - Classification head: lr * 10 (highest)
    """
    num_layers = len(model.encoder.layer)  # 12 for BERT-base
    return [
        # Embeddings (lowest LR)
        {
            'params': model.embeddings.parameters(),
            'lr': lr * (0.95 ** (num_layers + 1)),
            'weight_decay': 0.01
        },
        # BERT encoder layers (with layer-wise decay)
        *[
            {
                'params': layer.parameters(),
                'lr': lr * (0.95 ** (num_layers - i)),
                'weight_decay': 0.01
            }
            for i, layer in enumerate(model.encoder.layer)
        ],
        # Classification head (highest LR)
        {
            'params': model.classifier.parameters(),
            'lr': lr * 10,
            'weight_decay': 0.01
        }
    ]

optimizer = torch.optim.AdamW(get_bert_param_groups(model, lr=2e-5))
```

Typical ratios:
- Pre-trained encoder: 1x (base LR)
- New task head: 10-100x (10-100 times higher)
Putting It All Together
Complete training loop combining all techniques:
```python
import math
import torch
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import LambdaLR

# Model setup with layer-wise LR
model = MyLargeModel()
optimizer = torch.optim.AdamW([
    {'params': model.encoder.parameters(), 'lr': 1e-5},
    {'params': model.head.parameters(), 'lr': 1e-3}
], weight_decay=0.01)

# Warmup + cosine decay scheduler
warmup_steps = 1000
total_steps = len(dataloader) * num_epochs

def lr_lambda(step):
    if step < warmup_steps:
        # Linear warmup
        return step / warmup_steps
    # Cosine decay
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return 0.5 * (1 + math.cos(math.pi * progress))

scheduler = LambdaLR(optimizer, lr_lambda)

# Mixed precision scaler
scaler = GradScaler()

# Training loop
global_step = 0
for epoch in range(num_epochs):
    for batch in dataloader:
        optimizer.zero_grad()

        # Mixed precision forward pass
        with autocast():
            outputs = model(batch['input'])
            loss = criterion(outputs, batch['target'])

        # Scaled backward pass
        scaler.scale(loss).backward()

        # Unscale gradients before clipping
        scaler.unscale_(optimizer)

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        # Optimizer step via the scaler
        scaler.step(optimizer)
        scaler.update()

        # Update learning rate
        scheduler.step()
        global_step += 1

        # Logging
        if global_step % 100 == 0:
            current_lr = scheduler.get_last_lr()[0]
            grad_norm = get_grad_norm(model)
            print(f"Step {global_step}, LR: {current_lr:.2e}, "
                  f"Loss: {loss.item():.4f}, Grad Norm: {grad_norm:.2f}")
```

Technique Summary
| Technique | When to Use | Primary Benefit | Complexity |
|---|---|---|---|
| Warmup | Always (transformers, large models) | Training stability | Low |
| Gradient Clipping | RNNs, Transformers, Deep nets | Prevents divergence | Low |
| Mixed Precision | GPU training (V100, A100, RTX) | 2-3x speed + 50% memory | Medium |
| Layer-wise LR | Fine-tuning, transfer learning | Better adaptation | Medium |
Quick Decision Guide
Starting a new project?
- ✅ Start with warmup (1-5% of training steps)
- ✅ Add gradient clipping if training is unstable
- ✅ Use mixed precision if training on GPU
- ⚠️ Add layer-wise LR only if fine-tuning
Debugging training issues?
- Loss is NaN → Add gradient clipping, reduce LR, check data (see the guard sketch after this list)
- Training is slow → Use mixed precision, increase batch size
- Model won’t converge → Increase warmup steps, reduce max LR
- Fine-tuning hurts pre-trained features → Use layer-wise LR
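For the NaN case, an explicit guard in the training loop can help localize the problem. This is an illustrative sketch reusing the model, criterion, optimizer, and dataloader names from the loops above (mixed precision omitted for brevity), not a required pattern:

```python
import torch

for batch in dataloader:
    optimizer.zero_grad()
    loss = criterion(model(batch['input']), batch['target'])

    # Skip the update if the loss is already non-finite (helps localize bad batches)
    if not torch.isfinite(loss):
        print("Non-finite loss, skipping batch")
        continue

    loss.backward()

    # clip_grad_norm_ returns the pre-clipping norm: log it to spot exploding gradients
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    if grad_norm > 10.0:
        print(f"Large gradient norm before clipping: {grad_norm:.1f}")

    optimizer.step()
```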
Healthcare AI Applications
For multimodal healthcare models (e.g., EHR + imaging):
```python
# Typical setup for a healthcare multimodal model
optimizer = torch.optim.AdamW([
    # Pre-trained image encoder (ResNet or ViT)
    {'params': model.image_encoder.parameters(), 'lr': 1e-5},
    # Pre-trained EHR encoder (ETHOS or BEHRT)
    {'params': model.ehr_encoder.parameters(), 'lr': 1e-5},
    # New fusion module
    {'params': model.fusion.parameters(), 'lr': 5e-4},
    # New prediction head
    {'params': model.outcome_head.parameters(), 'lr': 1e-3}
])

# Conservative settings for medical data
warmup_steps = len(dataloader) * 2   # 2 epochs of warmup
max_norm = 1.0                       # Gradient clipping threshold
use_amp = True                       # Mixed precision for speed
```

Medical AI considerations:
- Conservative warmup - Medical data is noisy, use longer warmup
- Lower learning rates - Pre-trained clinical models are valuable
- Careful monitoring - Track metrics on multiple demographic groups
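As a minimal illustration of the last point, validation metrics can be computed separately per demographic group; the arrays below and the choice of AUROC are hypothetical placeholders, with scikit-learn used only as an example metric library:

```python
import numpy as np
from sklearn.metrics import roc_auc_score

# Hypothetical validation outputs: predicted probability, label, and group id per patient
preds = np.array([0.9, 0.2, 0.7, 0.4, 0.8, 0.3])
labels = np.array([1, 0, 1, 0, 1, 1])
groups = np.array(["A", "A", "A", "B", "B", "B"])

# Report the metric per group, not just overall, to catch uneven performance
for g in np.unique(groups):
    mask = groups == g
    print(f"Group {g}: AUROC = {roc_auc_score(labels[mask], preds[mask]):.3f}")
```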
Related Concepts
- Optimization - Foundation optimization algorithms (SGD, Adam)
- Transformer Training - Transformer-specific training techniques
- Training Dynamics - Understanding double descent and scaling
- Language Model Training - LM-specific training details
- Regularization - Preventing overfitting
Further Reading
Warmup
- “Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour” (Goyal et al., 2017)
- “On the adequacy of untuned warmup for adaptive optimization” (Ma & Yarats, 2019)
Mixed Precision
- PyTorch Mixed Precision Documentation
- “Mixed Precision Training” (Micikevicius et al., 2018)
- NVIDIA Apex
Layer-wise Learning Rates
- “Universal Language Model Fine-tuning for Text Classification” (Howard & Ruder, 2018) - Introduces discriminative fine-tuning
- “BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding” (Devlin et al., 2019)
Tools
- PyTorch Lightning - Handles mixed precision, gradient clipping automatically
- Hugging Face Trainer - Built-in support for all techniques
- DeepSpeed - Advanced optimization for very large models