
Practical Training Techniques

Training large deep learning models requires careful attention to stability and efficiency. Modern techniques make it practical to train models with billions of parameters.

Core Techniques

This guide covers four essential techniques:

  1. Learning Rate Warmup - Gradual LR increase for stability
  2. Gradient Clipping - Prevent exploding gradients
  3. Mixed Precision Training - 2-3x speedup with FP16
  4. Layer-wise Learning Rates - Different LRs for different layers

Learning Rate Warmup

Problem: At training start, gradients are uncertain and weights are random. Large learning rates can destabilize training.

Solution: Start with a small learning rate and gradually increase it over the first 1-10% of training.

Why Warmup Works

At the beginning of training:

  • Gradients are uncertain - Model hasn’t seen much data
  • Weights are random - May be far from good solutions
  • Large updates are dangerous - Can destabilize training or cause divergence

Warmup prevents early instabilities by using small, cautious steps initially.

Linear Warmup Implementation

import math

def get_lr_with_warmup(step, warmup_steps, max_lr, total_steps):
    """
    Linear warmup followed by cosine decay.

    Args:
        step: Current training step
        warmup_steps: Number of warmup steps (typically 1-10% of total)
        max_lr: Maximum learning rate after warmup
        total_steps: Total training steps

    Returns:
        Current learning rate
    """
    if step < warmup_steps:
        # Linear warmup: lr = max_lr * (step / warmup_steps)
        return max_lr * (step / warmup_steps)
    else:
        # Cosine decay after warmup
        progress = (step - warmup_steps) / (total_steps - warmup_steps)
        return max_lr * 0.5 * (1 + math.cos(math.pi * progress))

Key insight: Learning rate follows this curve:

LR
 │      ╱────╲
 │     ╱      ────╲
 │    ╱            ────╲
 │   ╱                  ────╲
 │  ╱                        ────
 └────────────────────────────────> Steps
   │←warmup→│←──── cosine decay ──→│

PyTorch Implementation

import math
import torch
from torch.optim.lr_scheduler import LambdaLR

# Create optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Create warmup scheduler
warmup_steps = 1000
total_steps = len(dataloader) * num_epochs

def lr_lambda(step):
    if step < warmup_steps:
        return step / warmup_steps
    # After warmup, use cosine decay
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return 0.5 * (1 + math.cos(math.pi * progress))

scheduler = LambdaLR(optimizer, lr_lambda)

# Training loop
for epoch in range(num_epochs):
    for batch in dataloader:
        optimizer.zero_grad()
        loss = model(batch)
        loss.backward()
        optimizer.step()
        scheduler.step()  # Update LR every step (not per epoch!)

Warmup Guidelines by Model Size

| Model Size | Warmup Steps | Typical Max LR | Warmup Duration |
| --- | --- | --- | --- |
| Small (< 50M params) | 500-1000 | 1e-3 to 5e-4 | 1-2% of training |
| Medium (50M-500M) | 1000-5000 | 5e-4 to 1e-4 | 2-5% of training |
| Large (> 500M) | 5000-10000 | 1e-4 to 5e-5 | 5-10% of training |

Rule of thumb: Use warmup for 1-10% of total training steps.
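One rough way to turn this rule of thumb into code is shown below (a minimal sketch; pick_warmup_steps and warmup_fraction are illustrative names, not a standard API):

def pick_warmup_steps(total_steps, warmup_fraction=0.05):
    """Choose warmup_steps as a fraction of total training steps."""
    # Keep the fraction inside the 1-10% range suggested above
    warmup_fraction = min(max(warmup_fraction, 0.01), 0.10)
    return max(1, int(total_steps * warmup_fraction))

# Example: 100,000 total steps with 5% warmup -> 5,000 warmup steps
warmup_steps = pick_warmup_steps(total_steps=100_000, warmup_fraction=0.05)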


Gradient Clipping

Problem: Gradients can occasionally explode to very large values, causing training to diverge.

Solution: Limit gradient magnitude to a maximum threshold.

Two Types of Clipping

1. Clip by Norm

Scales gradients if their total norm exceeds a threshold:

import torch

# Clip gradient norm to maximum value
max_norm = 1.0
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

# Formula: if ||g|| > max_norm, scale g = g * (max_norm / ||g||)

How it works:

  1. Compute total gradient norm: ||g|| = sqrt(sum(g_i^2))
  2. If ||g|| > max_norm, scale all gradients: g = g * (max_norm / ||g||)
  3. Otherwise, leave gradients unchanged

Advantage: Preserves gradient direction, only scales magnitude.
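The same steps can be written out by hand (a minimal sketch equivalent in spirit to torch.nn.utils.clip_grad_norm_, shown only to make the formula concrete):

import torch

def clip_by_norm_manually(parameters, max_norm):
    """Scale all gradients in place so their total L2 norm is at most max_norm."""
    grads = [p.grad for p in parameters if p.grad is not None]
    if not grads:
        return torch.tensor(0.0)
    # 1. Total norm across all gradients: ||g|| = sqrt(sum(g_i^2))
    total_norm = torch.sqrt(sum((g ** 2).sum() for g in grads))
    # 2. If ||g|| > max_norm, scale every gradient by max_norm / ||g||
    if total_norm > max_norm:
        scale = max_norm / (total_norm + 1e-6)
        for g in grads:
            g.mul_(scale)
    # 3. Otherwise, gradients are left unchanged
    return total_norm

Calling clip_by_norm_manually(model.parameters(), 1.0) after loss.backward() has the same effect as the built-in call above; in practice, prefer torch.nn.utils.clip_grad_norm_.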

2. Clip by Value

Clips each gradient element independently:

# Clip each gradient element to range [-clip_value, clip_value]
clip_value = 0.5
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value)

Disadvantage: Can distort gradient directions. Less common.

When to Use Gradient Clipping

Always use for:

  • RNNs and LSTMs - Prone to exploding gradients
  • Transformers - Deep architecture with multiplicative interactions
  • Very deep networks (> 50 layers) - Long backpropagation paths
  • Long sequence training - More timesteps = more gradient accumulation

⚠️ May not need for:

  • Small CNNs with batch normalization
  • Well-initialized architectures (ResNets with proper initialization)
  • Shallow networks with careful learning rate tuning

Practical Training Loop

for batch in dataloader:
    optimizer.zero_grad()

    # Forward pass
    outputs = model(batch['input'])
    loss = criterion(outputs, batch['target'])

    # Backward pass
    loss.backward()

    # Clip gradients BEFORE optimizer step
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    # Update weights
    optimizer.step()

Monitoring Gradient Norms

Track gradient norms to choose appropriate clipping threshold:

def get_grad_norm(model):
    """Compute total gradient norm across all parameters"""
    total_norm = 0.0
    for p in model.parameters():
        if p.grad is not None:
            param_norm = p.grad.data.norm(2)
            total_norm += param_norm.item() ** 2
    total_norm = total_norm ** 0.5
    return total_norm

# During training
grad_norm = get_grad_norm(model)
print(f"Gradient norm: {grad_norm:.2f}")

# If you frequently see gradient norms > 10, add clipping
# Set max_norm slightly above typical gradient norms (e.g., 1.0-5.0)

Guideline: If gradient norms frequently exceed 10-100, use clipping with max_norm=1.0 or max_norm=5.0.


Mixed Precision Training

Problem: Training large models in full precision (FP32) is slow and memory-intensive.

Solution: Use FP16 for most operations, FP32 for critical steps. Get 2-3x speedup with minimal accuracy loss.

Benefits

  • 2-3x speedup - FP16 operations are faster on modern GPUs (Tensor Cores)
  • 50% memory reduction - FP16 uses half the memory of FP32
  • Larger batch sizes - Reduced memory enables bigger batches
  • Minimal quality loss - Careful gradient scaling prevents underflow

How Mixed Precision Works

import torch
from torch.cuda.amp import autocast, GradScaler

# Create gradient scaler (handles FP16 underflow)
scaler = GradScaler()

for batch in dataloader:
    optimizer.zero_grad()

    # Forward pass in FP16 (automatic casting)
    with autocast():
        outputs = model(batch['input'])
        loss = criterion(outputs, batch['target'])

    # Backward pass with scaled gradients (prevents FP16 underflow)
    scaler.scale(loss).backward()

    # Unscale gradients before clipping (back to their true scale)
    scaler.unscale_(optimizer)

    # Gradient clipping (on unscaled gradients)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    # Optimizer step with scaler
    scaler.step(optimizer)

    # Update scaler for next iteration
    scaler.update()

Why Gradient Scaling?

FP16 has limited range (≈ 6e-8 to 65504):

Problem: Small gradients underflow to zero

gradient = 1e-8 # Underflows to 0 in FP16

Solution: Scale gradients up during backward pass

scale_factor = 65536  # 2^16
scaled_gradient = 1e-8 * 65536  # ≈ 0.00066 (representable in FP16)

# After backward pass, divide by scale_factor before updating weights

Process (a manual sketch follows this list):

  1. Scale loss before backward pass (multiply by large constant)
  2. Compute gradients in scaled space (prevents underflow)
  3. Unscale gradients before optimizer step (divide by constant)
  4. Update weights in FP32 (master weights)
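A minimal manual sketch of these four steps, written without GradScaler so the mechanics are visible (it reuses the model, criterion, optimizer, and dataloader names from the loops above; in practice use GradScaler, which also skips steps when overflow is detected):

import torch
from torch.cuda.amp import autocast

scale = 65536.0  # 2^16, a fixed scale factor for illustration

for batch in dataloader:
    optimizer.zero_grad()

    with autocast():
        loss = criterion(model(batch['input']), batch['target'])

    # 1-2. Scale the loss so small gradients survive FP16, then backprop
    (loss * scale).backward()

    # 3. Unscale gradients before the optimizer step
    with torch.no_grad():
        for p in model.parameters():
            if p.grad is not None:
                p.grad.div_(scale)

    # 4. Update the FP32 master weights
    optimizer.step()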

Mixed Precision Best Practices

# ✅ DO: Use autocast for forward pass
with autocast():
    loss = model(input)

# ✅ DO: Scale loss before backward
scaler.scale(loss).backward()

# ✅ DO: Unscale before gradient clipping
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

# ✅ DO: Call scaler.update() every iteration
scaler.update()

# ❌ DON'T: Use autocast for data loading or preprocessing
# ❌ DON'T: Use autocast for loss computation if the loss requires high precision
# ❌ DON'T: Forget to call scaler.update()

What Gets Computed in FP16 vs FP32?

| Operation | Precision | Reason |
| --- | --- | --- |
| Matrix multiplications | FP16 | Fast on Tensor Cores |
| Convolutions | FP16 | Fast on Tensor Cores |
| Activations | FP16 | Memory savings |
| Batch norm | FP32 | Numerical stability |
| Softmax | FP32 | Numerical stability |
| Loss computation | FP32 | Precision needed |
| Optimizer state | FP32 | Master weights |
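A quick way to see autocast applying these rules (a minimal sketch, assuming a CUDA GPU is available):

import torch
from torch.cuda.amp import autocast

a = torch.randn(8, 8, device="cuda")
b = torch.randn(8, 8, device="cuda")

with autocast():
    mm = a @ b                  # matmul runs in FP16
    sm = torch.softmax(mm, -1)  # softmax is kept in FP32 for stability

print(mm.dtype, sm.dtype)  # torch.float16 torch.float32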

Layer-wise Learning Rates

Problem: Different layers may need different learning rates (especially during fine-tuning).

Solution: Assign different learning rates to different parameter groups.

When to Use

Transfer Learning / Fine-tuning:

  • Lower LR for pre-trained layers - Preserve learned features
  • Higher LR for new layers - Learn task-specific features quickly

Deep Networks:

  • Lower LR for early layers - Low-level features (edges, textures)
  • Higher LR for later layers - High-level features (task-specific)

Implementation with Parameter Groups

# Example: Fine-tuning a pre-trained model
model = PretrainedModel()

optimizer = torch.optim.AdamW([
    {
        'params': model.encoder.parameters(),
        'lr': 1e-5,           # Low LR for pre-trained encoder
        'weight_decay': 0.01
    },
    {
        'params': model.head.parameters(),
        'lr': 1e-3,           # High LR for new classification head
        'weight_decay': 0.0   # No decay on head
    }
])

LR ratio: New layers typically use 10-100x higher LR than pre-trained layers.

Layer-wise LR Decay (for very deep networks)

Progressively decrease learning rate for earlier layers:

def get_layer_wise_lr(model, base_lr=1e-3, decay_rate=0.95):
    """
    Assign decreasing learning rates to earlier layers.

    Args:
        base_lr: Learning rate for the final layer
        decay_rate: Multiplicative decay per layer (e.g., 0.95)

    Returns:
        List of parameter groups with layer-specific LRs
    """
    param_groups = []
    num_layers = len(list(model.named_parameters()))

    for i, (name, param) in enumerate(model.named_parameters()):
        # Earlier layers get lower LR; the final parameter keeps base_lr
        lr = base_lr * (decay_rate ** (num_layers - 1 - i))
        param_groups.append({
            'params': param,
            'lr': lr,
            'name': name
        })
    return param_groups

optimizer = torch.optim.AdamW(
    get_layer_wise_lr(model, base_lr=1e-3, decay_rate=0.95)
)

BERT-style Discriminative Fine-tuning

For transformer models like BERT:

def get_bert_param_groups(model, lr=2e-5):
    """
    BERT-style discriminative learning rates:
      - Embeddings: lr * 0.95^(num_layers + 1)  (lowest)
      - Encoder layer i: lr * 0.95^(num_layers - i)
      - Classification head: lr * 10  (highest)
    """
    num_layers = len(model.encoder.layer)  # 12 for BERT-base
    return [
        # Embeddings (lowest LR)
        {
            'params': model.embeddings.parameters(),
            'lr': lr * (0.95 ** (num_layers + 1)),
            'weight_decay': 0.01
        },
        # BERT encoder layers (layer-wise decay: deeper layers get higher LR)
        *[
            {
                'params': layer.parameters(),
                'lr': lr * (0.95 ** (num_layers - i)),
                'weight_decay': 0.01
            }
            for i, layer in enumerate(model.encoder.layer)
        ],
        # Classification head (highest LR)
        {
            'params': model.classifier.parameters(),
            'lr': lr * 10,
            'weight_decay': 0.01
        }
    ]

optimizer = torch.optim.AdamW(get_bert_param_groups(model, lr=2e-5))

Typical ratios:

  • Pre-trained encoder: 1x (base LR)
  • New task head: 10-100x the base LR

Putting It All Together

Complete training loop combining all techniques:

import math

import torch
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import LambdaLR

# Model setup with layer-wise LR
model = MyLargeModel()
optimizer = torch.optim.AdamW([
    {'params': model.encoder.parameters(), 'lr': 1e-5},
    {'params': model.head.parameters(), 'lr': 1e-3}
], weight_decay=0.01)

# Warmup + cosine decay scheduler
warmup_steps = 1000
total_steps = len(dataloader) * num_epochs

def lr_lambda(step):
    if step < warmup_steps:
        # Linear warmup
        return step / warmup_steps
    # Cosine decay
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return 0.5 * (1 + math.cos(math.pi * progress))

scheduler = LambdaLR(optimizer, lr_lambda)

# Mixed precision scaler
scaler = GradScaler()

# Training loop
global_step = 0
for epoch in range(num_epochs):
    for batch in dataloader:
        optimizer.zero_grad()

        # Mixed precision forward pass
        with autocast():
            outputs = model(batch['input'])
            loss = criterion(outputs, batch['target'])

        # Scaled backward pass
        scaler.scale(loss).backward()

        # Unscale gradients before clipping
        scaler.unscale_(optimizer)

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        # Optimizer step with scaler
        scaler.step(optimizer)
        scaler.update()

        # Update learning rate
        scheduler.step()
        global_step += 1

        # Logging (get_grad_norm is defined in the Monitoring Gradient Norms section above)
        if global_step % 100 == 0:
            current_lr = scheduler.get_last_lr()[0]
            grad_norm = get_grad_norm(model)
            print(f"Step {global_step}, LR: {current_lr:.2e}, "
                  f"Loss: {loss.item():.4f}, Grad Norm: {grad_norm:.2f}")

Technique Summary

| Technique | When to Use | Primary Benefit | Complexity |
| --- | --- | --- | --- |
| Warmup | Always (transformers, large models) | Training stability | Low |
| Gradient Clipping | RNNs, Transformers, deep nets | Prevents divergence | Low |
| Mixed Precision | GPU training (V100, A100, RTX) | 2-3x speed + 50% memory | Medium |
| Layer-wise LR | Fine-tuning, transfer learning | Better adaptation | Medium |

Quick Decision Guide

Starting a new project?

  1. ✅ Start with warmup (1-5% of training steps)
  2. ✅ Add gradient clipping if training is unstable
  3. ✅ Use mixed precision if training on GPU
  4. ⚠️ Add layer-wise LR only if fine-tuning

Debugging training issues?

  • Loss is NaN → Add gradient clipping, reduce LR, check data
  • Training is slow → Use mixed precision, increase batch size
  • Model won’t converge → Increase warmup steps, reduce max LR
  • Fine-tuning hurts pre-trained features → Use layer-wise LR

Healthcare AI Applications

For multimodal healthcare models (e.g., EHR + imaging):

# Typical setup for a healthcare multimodal model
optimizer = torch.optim.AdamW([
    # Pre-trained image encoder (ResNet or ViT)
    {'params': model.image_encoder.parameters(), 'lr': 1e-5},
    # Pre-trained EHR encoder (ETHOS or BEHRT)
    {'params': model.ehr_encoder.parameters(), 'lr': 1e-5},
    # New fusion module
    {'params': model.fusion.parameters(), 'lr': 5e-4},
    # New prediction head
    {'params': model.outcome_head.parameters(), 'lr': 1e-3}
])

# Conservative settings for medical data
warmup_steps = len(dataloader) * 2  # 2 epochs of warmup
max_norm = 1.0   # Gradient clipping
use_amp = True   # Mixed precision for speed

Medical AI considerations:

  • Conservative warmup - Medical data is noisy, use longer warmup
  • Lower learning rates - Pre-trained clinical models are valuable
  • Careful monitoring - Track metrics on multiple demographic groups

Further Reading

Warmup

  • “Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour” (Goyal et al., 2017)
  • “On the adequacy of untuned warmup for adaptive optimization” (Ma & Yarats, 2019)

Mixed Precision

  • “Mixed Precision Training” (Micikevicius et al., 2018)

Layer-wise Learning Rates

  • “Universal Language Model Fine-tuning for Text Classification” (Howard & Ruder, 2018) - Introduced discriminative (layer-wise) fine-tuning
  • “BERT: Pre-training of Deep Bidirectional Transformers” (Devlin et al., 2019)

Tools