Training Dynamics and Double Descent
Modern deep learning defies classical machine learning wisdom. Models with billions of parameters that perfectly fit training data still generalize remarkably well. This phenomenon - and the theory explaining it - represents one of the most important shifts in our understanding of how deep learning works.
The Classical View (Incomplete)
Traditional machine learning taught a simple story:
Classical Bias-Variance Tradeoff:

Test Error
 |
 |╲                 ╱
 | ╲               ╱
 |  ╲             ╱
 |   ╲___________╱
 |_______________________
         Model Size
  ← Underfit | Optimal | Overfit →

Classical wisdom:
- Underfit when model is too simple
- Optimal at “just right” complexity
- Overfit when model is too complex
Prescription: Keep models simple, add regularization, stop before overfitting.
The problem: This doesn’t explain GPT-4, Stable Diffusion, or any modern large-scale model!
The Modern Reality: Double Descent
Test error actually follows a double descent curve with model complexity:
Test Error
 |
 |╲          ╱╲          Classical
 | ╲        ╱  ╲         U-curve
 |  ╲      ╱    ╲___
 |   ╲____╱         ╲___
 |          ↑           ╲___   Modern regime
 |    Interpolation         ╲___
 |      threshold
 |______________________________________
                Model Size
  ← Under | Critical | Overparameterized →

Three Distinct Regimes
| Regime | Behavior | Performance | Example |
|---|---|---|---|
| Underparameterized | Classical bias-variance tradeoff | Improves then degrades | Linear models, small networks |
| Interpolation threshold | Can just barely fit training data | Worst performance | Critical model size |
| Overparameterized | Many ways to fit training data | Improves with size! | Modern deep networks |
| Highly overparameterized | Massive capacity | Best performance | GPT-4, CLIP, Stable Diffusion |
The Interpolation Threshold
The interpolation threshold is where the model has just enough parameters to perfectly fit the training data - and this is often where performance is worst!
Why the threshold is dangerous:
- Below: Model can’t fit all training examples → high bias
- At threshold: Model can barely fit training data → fits noise, overfits badly
- Above threshold: Model can fit training data many different ways → finds good solutions
The worst performance often occurs at the interpolation threshold - when your model can just barely fit the training data. Going larger actually helps generalization!
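The double descent curve can be reproduced in miniature. The sketch below is illustrative only: it assumes a toy teacher-student regression task, a width sweep over a one-hidden-layer MLP, and full-batch Adam training, none of which come from a specific paper. On many runs the test error rises near the width where training error first reaches roughly zero and falls again beyond it; the exact shape depends on the noise level, training budget, and seed.

import torch
import torch.nn as nn

torch.manual_seed(0)

def make_data(w, n, noise=0.25, seed=0):
    g = torch.Generator().manual_seed(seed)
    X = torch.randn(n, w.shape[0], generator=g)
    return X, X @ w + noise * torch.randn(n, 1, generator=g)

def fit_and_score(width, X_tr, y_tr, X_te, y_te, steps=4000):
    model = nn.Sequential(nn.Linear(X_tr.shape[1], width), nn.ReLU(), nn.Linear(width, 1))
    opt = torch.optim.Adam(model.parameters(), lr=1e-2)
    for _ in range(steps):                      # full-batch training toward interpolation
        opt.zero_grad()
        loss = nn.functional.mse_loss(model(X_tr), y_tr)
        loss.backward()
        opt.step()
    with torch.no_grad():
        return (nn.functional.mse_loss(model(X_tr), y_tr).item(),
                nn.functional.mse_loss(model(X_te), y_te).item())

d = 20
w_true = torch.randn(d, 1) / d ** 0.5           # teacher weights shared by train and test
X_tr, y_tr = make_data(w_true, n=100, seed=1)   # 100 noisy training examples
X_te, y_te = make_data(w_true, n=2000, seed=2)
for width in [1, 2, 4, 8, 16, 64, 256, 1024]:   # sweep from under- to overparameterized
    tr, te = fit_and_score(width, X_tr, y_tr, X_te, y_te)
    print(f"width={width:5d}  train_mse={tr:.4f}  test_mse={te:.4f}")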
Implications for Model Design
This completely changes how we approach deep learning:
✅ Modern principles:
- Larger models generalize better (counterintuitive!)
- Don’t stop at interpolation threshold - go bigger
- Perfect training accuracy is okay - doesn’t necessarily mean overfitting
- Early stopping may hurt in overparameterized regime
- Let training continue past perfect train accuracy
❌ Outdated principles:
- Stop training when you hit perfect accuracy
- Always add more regularization
- Keep models as small as possible
- Perfect train accuracy means overfitting
The scale of modern foundation models makes the point:
- GPT-4: 1.8T parameters trained on ~13T tokens
- Stable Diffusion: 890M parameters trained on 2B image-text pairs
- CLIP: 400M parameters trained on 400M image-text pairs
All of these models are massively overparameterized relative to classical ML intuitions, yet they generalize exceptionally well.
Why Overparameterization Helps
Multiple complementary mechanisms explain why larger models generalize better:
1. Implicit Regularization via SGD
Stochastic Gradient Descent doesn’t just find any solution - it prefers certain solutions that generalize well.
Flat vs Sharp Minima:
Loss Landscape

Flat Minimum (Good):              Sharp Minimum (Bad):

  \                 /               \        /
   \               /                 \      /
    \_____________/                   \    /
                                       \  /
                                        \/

Robust to perturbations           Sensitive to perturbations
Generalizes well                  Overfits

How SGD finds flat minima:
- Stochastic noise: Batch sampling randomness explores broadly
- Learning rate: Acts as temperature, prefers wider basins
- Momentum: Smooths optimization path, avoids sharp valleys
- Batch size: Smaller batches → more noise → flatter minima
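A crude way to make the flat-versus-sharp picture concrete is to perturb the trained weights slightly and measure how much the loss moves. The sketch below is an illustrative probe, not one of the formal sharpness measures from the literature; `loss_fn`, `inputs`, `targets`, and the noise scale `sigma` are placeholder assumptions.

import copy
import torch

def sharpness_probe(model, loss_fn, inputs, targets, sigma=0.01, trials=10):
    """Average loss increase under small random weight perturbations.

    A larger increase suggests a sharper minimum; a smaller one, a flatter basin.
    """
    base_loss = loss_fn(model(inputs), targets).item()
    increases = []
    for _ in range(trials):
        perturbed = copy.deepcopy(model)
        with torch.no_grad():
            for p in perturbed.parameters():
                p.add_(sigma * torch.randn_like(p))
        increases.append(loss_fn(perturbed(inputs), targets).item() - base_loss)
    return sum(increases) / trials

The optimizer settings below then show the knobs through which SGD biases training toward the flatter kind of minimum.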
# SGD implicitly regularizes through:
# 1. Stochastic batch sampling (randomness)
# 2. Learning rate (temperature)
# 3. Momentum (smoothing)
import torch
from torch.utils.data import DataLoader  # model and dataset assumed defined elsewhere

optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.01,         # Larger LR → stronger implicit regularization
    momentum=0.9,    # Smooths optimization
    weight_decay=0   # Often not needed in overparameterized regime!
)

# Small batches also help (more stochastic noise)
data_loader = DataLoader(dataset, batch_size=32)  # vs 1024

2. More Paths to Good Solutions
With more parameters, there are exponentially more ways to fit the data:
| Regime | Solution Space | Optimization |
|---|---|---|
| Underparameterized | Few solutions, may not include good ones | Hard - limited paths |
| At threshold | Unique solution, often bad | Hardest - no flexibility |
| Overparameterized | Many solutions, including good ones | Easier - many paths |
The larger the model, the more likely SGD will find a solution that:
- Fits training data (necessary)
- Uses simple patterns (implicit regularization)
- Generalizes to test data (desired outcome)
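To see why "many ways to fit the data" is literally true, consider an overparameterized linear model: with more parameters than examples there is an entire affine subspace of weight vectors that fit the training set exactly. The sketch below is a toy setup (the dimensions and the sparse `w_true` are arbitrary choices) that constructs two different exact interpolators, the minimum-norm one and a perturbed one.

import torch

n, d = 20, 200                                  # 20 examples, 200 parameters: overparameterized
g = torch.Generator().manual_seed(0)
X = torch.randn(n, d, generator=g)
w_true = torch.zeros(d, 1)
w_true[:5] = 1.0                                # only a few directions actually matter
y = X @ w_true

w_min_norm = torch.linalg.pinv(X) @ y           # minimum-norm interpolator
print("residual:", torch.norm(X @ w_min_norm - y).item())   # ~0: fits training data exactly

# Adding any null-space direction of X gives another exact interpolator:
delta = torch.randn(d, 1, generator=g)
delta = delta - torch.linalg.pinv(X) @ (X @ delta)          # project out the row space
w_other = w_min_norm + delta
print("also fits:", torch.norm(X @ w_other - y).item())     # still ~0, different weights

In this linear setting, gradient descent started from zero converges to the minimum-norm interpolator, which is one concrete form of implicit regularization selecting among the many solutions.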
3. The Lottery Ticket Hypothesis
Core claim: A randomly initialized dense network contains sparse “winning ticket” subnetworks that can train to similar accuracy.
Finding winning tickets:
import copy
import torch

def find_lottery_ticket(model, data, prune_ratio=0.9):
    """
    Lottery Ticket Hypothesis: train, prune by magnitude,
    rewind to the original initialization, and retrain the sparse subnetwork.
    """
    # Step 1: Save initial random weights
    initial_weights = copy.deepcopy(model.state_dict())

    # Step 2: Train network to convergence
    train(model, data, epochs=100)

    # Step 3: Identify important weights by magnitude
    importance = {
        name: param.detach().abs()
        for name, param in model.named_parameters()
    }

    # Step 4: Create mask keeping the top (1 - prune_ratio) weights
    threshold = torch.quantile(
        torch.cat([v.flatten() for v in importance.values()]),
        prune_ratio
    )
    mask = {
        name: (importance[name] > threshold).float()
        for name in importance
    }

    # Step 5: Reset to initial weights
    model.load_state_dict(initial_weights)

    # Step 6: Train sparse network with the mask applied
    train_with_mask(model, data, mask, epochs=100)

    # Result: the sparse network (~10% of weights) can achieve
    # performance similar to the full dense network
    return model, mask

Implications:
- Dense networks easier to train: More “lottery tickets” means higher chance of lucky initialization
- Initialization matters: The right sparse network exists, but you need the right initial weights
- Pruning works: Train large → prune → fine-tune is effective
- Overparameterization helps optimization: Not just representation capacity!
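The sketch above leaves the `train_with_mask` helper undefined. A minimal version might look like the following, under assumptions not in the original (a classification loss, plain SGD with momentum, and `data` as an iterable of `(inputs, targets)` batches):

import torch

def train_with_mask(model, data, mask, epochs=100, lr=0.01):
    """Train while keeping pruned weights frozen at zero."""
    # Zero out pruned weights before the first update
    with torch.no_grad():
        for name, param in model.named_parameters():
            param.mul_(mask[name])

    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    for _ in range(epochs):
        for inputs, targets in data:
            optimizer.zero_grad()
            loss = torch.nn.functional.cross_entropy(model(inputs), targets)
            loss.backward()
            optimizer.step()
            # Re-apply the mask so pruned weights stay exactly zero after each update
            with torch.no_grad():
                for name, param in model.named_parameters():
                    param.mul_(mask[name])
    return model

Re-applying the mask after every optimizer step keeps the pruned weights at zero without requiring sparse tensor support.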
4. Grokking: Delayed Generalization
Models can suddenly generalize long after achieving perfect training accuracy!
The grokking phenomenon:
Accuracy
100% |    _________________________________ Training
     |   /                            _____ Test
     |  /                            /
 50% |  | __________________________/
     |  |/
     |__/___________________________________
        0       1k          5k        10k   epochs

Timeline:
- Epoch 1000: 100% training accuracy, 50% test accuracy (memorization)
- Epoch 10000: 100% training accuracy, 100% test accuracy (generalization!)
- Model “groks” the underlying pattern after memorizing
Why grokking happens:
- Phase 1 (fast): Memorize training data (easy, any solution works)
- Phase 2 (slow): Implicit regularization gradually simplifies solution
- Phase 3 (sudden): Model discovers generalizable pattern
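As a concrete picture of the setup, the sketch below builds the modular-addition task referenced in the annotated timeline further down and trains far longer than is needed to fit the training split. The architecture and hyperparameters are illustrative assumptions, not the original paper's settings; relatively strong weight decay is commonly reported to encourage grokking.

import torch
import torch.nn as nn

p = 97
pairs = torch.cartesian_prod(torch.arange(p), torch.arange(p))   # all 97 * 97 = 9409 (a, b) pairs
labels = (pairs[:, 0] + pairs[:, 1]) % p

perm = torch.randperm(len(pairs))                                # random 50/50 split:
train_idx, test_idx = perm[: len(perm) // 2], perm[len(perm) // 2 :]  # held-out pairs test generalization

model = nn.Sequential(
    nn.Embedding(p, 128),            # shared embedding for both operands
    nn.Flatten(),
    nn.Linear(2 * 128, 256),
    nn.ReLU(),
    nn.Linear(256, p),               # predict (a + b) mod p as a class
)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1.0)

for step in range(100_000):          # far longer than needed to fit the training split
    logits = model(pairs[train_idx])
    loss = nn.functional.cross_entropy(logits, labels[train_idx])
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Periodically evaluate accuracy on pairs[test_idx] to watch for the delayed jump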
Don’t stop training too early! Generalization can happen long after you’ve achieved perfect training accuracy. The model is still improving through implicit regularization.
Practical example:
# Training a modular arithmetic task (e.g., a + b mod 97)
# Epoch 1000:
# - Training: 100% (memorized all 97*97 = 9409 examples)
# - Test: 50% (no generalization)
# Epoch 5000:
# - Training: 100% (still memorized)
# - Test: 50% (still no generalization)
# Epoch 10000:
# - Training: 100%
# - Test: 100% (suddenly discovered modular structure!)
# The model grokked that:
# (a + b) mod n = ((a mod n) + (b mod n)) mod n

Practical Implications
For Model Architecture Design
When designing models:
✅ Do:
- Use large models when computationally feasible
- Don’t fear overparameterization
- Trust implicit regularization (especially with SGD)
- Let training continue past perfect train accuracy
- Consider scaling up before adding explicit regularization
❌ Don’t:
- Stop at interpolation threshold (worst performance!)
- Assume perfect training accuracy means overfitting
- Over-regularize large models (can prevent grokking)
- Give up training too early
- Always default to smallest possible model
For Training Strategy
Effective training in overparameterized regime:
# Good practices for large models (model and data loaders assumed defined)
import torch
from transformers import get_linear_schedule_with_warmup

# 1. AdamW (or SGD) with a reasonable learning rate
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-4,            # Not too small!
    weight_decay=0.01   # Light weight decay, if any
)

# 2. Learning rate warmup
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=1000,    # Gradual warmup
    num_training_steps=10000
)

# 3. Small batch sizes for implicit regularization
batch_size = 32  # vs 1024

# 4. Train longer than you think
#    - Don't stop at perfect train accuracy
#    - Wait for generalization (grokking)
#    - Monitor validation loss, not just train loss
epochs = 100  # Or until validation stops improving

# 5. Gradient clipping for stability (inside the training loop, after backward())
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

For Healthcare AI and Domain Applications
Scenario: Training a multimodal model for EHR + symptoms with:
- Model: 50M parameters
- Training data: 100K labeled examples
Analysis: You’re in the overparameterized regime (500 parameters per example).
Strategy:
- ✅ Use the full 50M parameter model (don’t reduce size)
- ✅ Train to perfect training accuracy, then keep going
- ✅ Use light regularization (dropout 0.1, small weight decay)
- ✅ Let training run 50-100 epochs past perfect train accuracy
- ❌ Don’t stop early due to “overfitting” fears
- ❌ Don’t aggressively regularize (can prevent learning)
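A minimal configuration sketch matching the strategy above. The layer sizes are placeholders, not a real 50M-parameter multimodal architecture; the dropout, weight decay, and epoch budget simply follow the suggestions listed here.

import torch
import torch.nn as nn

model = nn.Sequential(            # stand-in; not a real 50M-parameter multimodal model
    nn.Linear(512, 1024),
    nn.ReLU(),
    nn.Dropout(p=0.1),            # light dropout, as suggested above
    nn.Linear(1024, 2),
)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-4,
    weight_decay=0.01,            # small weight decay, not aggressive regularization
)
num_epochs = 150                  # leaves room to keep training 50-100 epochs past perfect train accuracy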
For Debugging Training Issues
If model isn’t learning:
- Check if you're at the interpolation threshold

  # Can the model fit the training data?
  train_acc = evaluate(model, train_loader)
  if train_acc < 0.95:
      # Underparameterized or at the threshold → try a larger model
      model = increase_model_size(model)

- Check optimization stability

  # Use gradient clipping
  torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
  # Use learning rate warmup
  scheduler = WarmupScheduler(...)
  # Check for NaN loss
  if torch.isnan(loss):
      # Reduce the learning rate or add gradient clipping
      ...

- Give training more time

  # Don't stop at epoch 10 if the model is still improving;
  # grokking can take 100+ epochs
  patience = 50  # Wait 50 epochs for improvement
Theoretical Foundations
Neural Tangent Kernel (NTK) Theory
In the infinite-width limit, neural networks behave like kernel methods:
- Infinite width: Network barely changes during training (lazy regime)
- Finite width: Network adapts, features learn (feature learning regime)
- Practical networks: Between lazy and feature learning
This theory helps explain why overparameterized networks can generalize, though the analysis is technically involved and applies most directly to the infinite-width (lazy) limit.
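The finite-width, empirical version of the NTK is easy to compute for a tiny network, which makes the idea less abstract: the kernel value for two inputs is the inner product of the network-output gradients with respect to all parameters. The sketch below uses an arbitrary small architecture and random inputs purely for illustration.

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 1))

def param_grad(x):
    """Gradient of the scalar network output w.r.t. all parameters, flattened."""
    net.zero_grad()
    net(x).sum().backward()
    return torch.cat([p.grad.flatten() for p in net.parameters()])

x1, x2 = torch.randn(1, 4), torch.randn(1, 4)
ntk_value = torch.dot(param_grad(x1), param_grad(x2))   # empirical NTK entry Θ(x1, x2)
print(ntk_value.item())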
Scaling Laws
Empirical scaling laws for large language models show consistent patterns:
# Chinchilla scaling law (approximate): optimal training tokens ≈ 20 × parameters
num_parameters = 10e9                     # 10B-parameter model
optimal_tokens = 20 * num_parameters      # ≈ 200B tokens

# Test loss scales predictably:
# L(N, D) ≈ (N_c / N)^α + (D_c / D)^β
# where N = params, D = data, α ≈ 0.076, β ≈ 0.095

See Language Model Scaling Laws for details.
Historical Context
- 2017: Deep learning achieves state-of-the-art despite “overfitting” classical intuitions
- 2018: “Deep Double Descent” paper (Belkin et al.) formalizes the phenomenon
- 2019: “Lottery Ticket Hypothesis” (Frankle & Carbin) explains role of initialization
- 2020: NTK theory provides mathematical framework for infinite-width limits
- 2021: “Grokking” paper shows delayed generalization in algorithmic tasks
- 2022: Scaling laws become central to foundation model development
- 2023-2025: Double descent and overparameterization are standard practice
Common Misconceptions
Myth 1: “Perfect training accuracy means overfitting”
- Reality: In overparameterized regime, perfect train accuracy is fine and expected
Myth 2: “Always use as little capacity as possible”
- Reality: Models at or just below the interpolation threshold often do worse than much larger ones; minimal capacity is not automatically safer
Myth 3: “More parameters always means overfitting”
- Reality: More parameters + implicit regularization = better generalization
Myth 4: “Stop training once you hit 100% train accuracy”
- Reality: Grokking can happen long after - keep training!
Myth 5: “Deep learning is just memorization”
- Reality: Models compress patterns through implicit regularization
Key Takeaways
- Double descent is real: Test error improves → gets worse (threshold) → improves again
- Interpolation threshold is dangerous: Avoid models that can just barely fit training data
- Bigger is often better: Overparameterized models generalize better due to implicit regularization
- SGD is special: Stochastic optimization implicitly regularizes toward flat, generalizable minima
- Train longer: Grokking can happen long after perfect training accuracy
- Trust the process: Modern deep learning works despite violating classical ML intuitions
Related Concepts
- Bias-Variance Tradeoff - Classical ML perspective
- Regularization Techniques - Explicit regularization methods
- Optimization Algorithms - SGD and implicit regularization
- Language Model Scaling Laws - Empirical scaling behavior
- Training Best Practices - Practical training guidelines
Further Reading
Papers
- “Reconciling modern machine-learning practice and the classical bias-variance trade-off” (Belkin et al., 2019) - Original double descent paper
- “Deep Double Descent: Where Bigger Models and More Data Hurt” (Nakkiran et al., 2019) - Comprehensive analysis
- “The Lottery Ticket Hypothesis” (Frankle & Carbin, 2019) - Sparse subnetworks
- “Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets” (Power et al., 2021)
- “Scaling Laws for Neural Language Models” (Kaplan et al., 2020) - Empirical scaling
Videos
- “What the Books Get Wrong about AI (Double Descent)” - Welch Labs (essential viewing!)
- “The Lottery Ticket Hypothesis” - Yannic Kilcher
- “Deep Double Descent” - Paper explanation by authors
Code & Experiments
- OpenAI Scaling Laws: Empirical notebook
- Lottery Ticket Pruning: PyTorch implementation
- Grokking Reproduction: Algorithmic task examples