
Training Dynamics and Double Descent

Modern deep learning defies classical machine learning wisdom. Models with billions of parameters that perfectly fit training data still generalize remarkably well. This phenomenon - and the theory explaining it - represents one of the most important shifts in our understanding of how deep learning works.

The Classical View (Incomplete)

Traditional machine learning taught a simple story:

Classical Bias-Variance Tradeoff:

[Figure: the classical U-shaped curve of test error vs. model size, moving from underfit on the left through optimal complexity to overfit on the right.]

Classical wisdom:

  1. Underfit when model is too simple
  2. Optimal at “just right” complexity
  3. Overfit when model is too complex

Prescription: Keep models simple, add regularization, stop before overfitting.

The problem: This doesn’t explain GPT-4, Stable Diffusion, or any modern large-scale model!

The Modern Reality: Double Descent

Test error actually follows a double descent curve with model complexity:

[Figure: the double descent curve of test error vs. model size. The classical U-curve appears on the left, test error peaks at the interpolation threshold, then descends again in the overparameterized "modern regime" on the right.]

Three Distinct Regimes

| Regime | Behavior | Performance | Example |
| --- | --- | --- | --- |
| Underparameterized | Classical bias-variance tradeoff | Improves then degrades | Linear models, small networks |
| Interpolation threshold | Can just barely fit training data | Worst performance | Critical model size |
| Overparameterized | Many ways to fit training data | Improves with size! | Modern deep networks |
| Highly overparameterized | Massive capacity | Best performance | GPT-4, CLIP, Stable Diffusion |
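
To make the curve concrete, here is a toy sketch (not from the original text) using random ReLU features and a least-squares fit; the minimum-norm solution returned by `np.linalg.lstsq` typically shows a test-error spike near `n_features ≈ n_train` and a second descent for much wider models. All names and constants below are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
n_train, n_test, d = 100, 2000, 10

# Noisy nonlinear target along a fixed direction
w_true = rng.normal(size=d)
w_true /= np.linalg.norm(w_true)
X_train, X_test = rng.normal(size=(n_train, d)), rng.normal(size=(n_test, d))
y_train = np.sin(3 * X_train @ w_true) + 0.1 * rng.normal(size=n_train)
y_test = np.sin(3 * X_test @ w_true)

for n_features in [10, 50, 90, 100, 110, 200, 1000, 5000]:
    W = rng.normal(size=(d, n_features))          # random projection
    train_feats = np.maximum(X_train @ W, 0.0)    # ReLU features
    test_feats = np.maximum(X_test @ W, 0.0)
    # Least-squares fit; minimum-norm solution once overparameterized
    coef, *_ = np.linalg.lstsq(train_feats, y_train, rcond=None)
    test_mse = np.mean((test_feats @ coef - y_test) ** 2)
    print(f"n_features={n_features:5d}  test MSE={test_mse:10.3f}")
```

For this model family the interpolation threshold sits at `n_features = n_train = 100`, which is where the printed test error usually peaks.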

The Interpolation Threshold

The interpolation threshold is where the model has just enough parameters to perfectly fit the training data - and this is often where performance is worst!

Why the threshold is dangerous:

  • Below: Model can’t fit all training examples → high bias
  • At threshold: Model can barely fit training data → fits noise, overfits badly
  • Above threshold: Model can fit training data many different ways → finds good solutions

Critical Insight

The worst performance often occurs at the interpolation threshold - when your model can just barely fit the training data. Going larger actually helps generalization!

Implications for Model Design

This completely changes how we approach deep learning:

Modern principles:

  1. Larger models generalize better (counterintuitive!)
  2. Don’t stop at interpolation threshold - go bigger
  3. Perfect training accuracy is okay - doesn’t necessarily mean overfitting
  4. Early stopping may hurt in the overparameterized regime
  5. Let training continue past perfect train accuracy

Outdated principles:

  1. Stop training when you hit perfect accuracy
  2. Always add more regularization
  3. Keep models as small as possible
  4. Perfect train accuracy means overfitting

Modern Deep Learning Scale

GPT-4: parameter count undisclosed; widely circulated estimates suggest on the order of 1.8T parameters trained on ~13T tokens

Stable Diffusion: 890M parameters trained on 2B image-text pairs

CLIP: 400M parameters trained on 400M image-text pairs

All of these models are massively overparameterized relative to classical ML intuitions, yet they generalize exceptionally well.

Why Overparameterization Helps

Multiple complementary mechanisms explain why larger models generalize better:

1. Implicit Regularization via SGD

Stochastic Gradient Descent doesn’t just find any solution - it prefers certain solutions that generalize well.

Flat vs Sharp Minima:

[Diagram: a loss landscape contrasting a wide, flat minimum (robust to perturbations, generalizes well) with a narrow, sharp minimum (sensitive to perturbations, overfits).]
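
One way to make "flatness" concrete (not from the original text; all names here are illustrative) is to perturb the trained weights with small Gaussian noise and measure how much the loss rises: a flat minimum barely moves, a sharp one blows up.

```python
import copy
import torch

@torch.no_grad()
def sharpness_probe(model, loss_fn, batch, sigma=0.01, n_samples=10):
    """Average loss increase under small random weight perturbations
    (a rough, illustrative proxy for sharpness)."""
    inputs, targets = batch
    base_loss = loss_fn(model(inputs), targets).item()

    original = copy.deepcopy(model.state_dict())
    increases = []
    for _ in range(n_samples):
        for p in model.parameters():
            p.add_(sigma * torch.randn_like(p))    # perturb every weight
        increases.append(loss_fn(model(inputs), targets).item() - base_loss)
        model.load_state_dict(original)            # restore original weights
    return sum(increases) / n_samples              # small value ⇒ flat minimum
```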

How SGD finds flat minima:

  • Stochastic noise: Batch sampling randomness explores broadly
  • Learning rate: Acts as temperature, prefers wider basins
  • Momentum: Smooths optimization path, avoids sharp valleys
  • Batch size: Smaller batches → more noise → flatter minima

```python
import torch
from torch.utils.data import DataLoader

# SGD implicitly regularizes through:
# 1. Stochastic batch sampling (randomness)
# 2. Learning rate (temperature)
# 3. Momentum (smoothing)
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.01,         # Larger LR → stronger implicit regularization
    momentum=0.9,    # Smooths the optimization path
    weight_decay=0,  # Often not needed in the overparameterized regime!
)

# Small batches also help (more stochastic noise)
data_loader = DataLoader(dataset, batch_size=32)  # vs. 1024
```

2. More Paths to Good Solutions

With more parameters, there are exponentially more ways to fit the data:

| Regime | Solution Space | Optimization |
| --- | --- | --- |
| Underparameterized | Few solutions, may not include good ones | Hard - limited paths |
| At threshold | Unique solution, often bad | Hardest - no flexibility |
| Overparameterized | Many solutions, including good ones | Easier - many paths |

The larger the model, the more likely it is that SGD finds a solution that:

  • Fits training data (necessary)
  • Uses simple patterns (implicit regularization)
  • Generalizes to test data (desired outcome)
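
A tiny numerical illustration (not from the original text; the names are arbitrary): in an overparameterized linear model, any null-space direction can be added to an interpolating solution without changing the training fit, and gradient descent from zero initialization converges to the minimum-norm interpolator among them.

```python
import numpy as np

rng = np.random.default_rng(0)
n, p = 20, 100                       # 20 examples, 100 parameters
X = rng.normal(size=(n, p))
y = rng.normal(size=n)

# Minimum-norm interpolating solution (what GD from zero init converges to)
w_min_norm = np.linalg.pinv(X) @ y

# Add any null-space direction: the training fit stays perfect
direction = rng.normal(size=p)
direction -= np.linalg.pinv(X) @ (X @ direction)   # project out the row space
w_other = w_min_norm + 5.0 * direction

print(np.allclose(X @ w_min_norm, y))              # True
print(np.allclose(X @ w_other, y))                 # True, but much larger norm
print(np.linalg.norm(w_min_norm), np.linalg.norm(w_other))
```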

3. The Lottery Ticket Hypothesis

Core claim: a randomly initialized dense network contains sparse “winning ticket” subnetworks that, when trained in isolation from their original initialization, can reach accuracy comparable to the full network.

Finding winning tickets:

```python
import copy

import torch


def find_lottery_ticket(model, data, prune_ratio=0.9):
    """Lottery Ticket Hypothesis: find and retrain a sparse 'winning ticket'."""
    # Step 1: Save the initial random weights
    initial_weights = copy.deepcopy(model.state_dict())

    # Step 2: Train the dense network to convergence
    train(model, data, epochs=100)

    # Step 3: Identify important weights by magnitude
    importance = {
        name: param.abs()
        for name, param in model.named_parameters()
    }

    # Step 4: Create a mask keeping the top (1 - prune_ratio) weights
    threshold = torch.quantile(
        torch.cat([v.flatten() for v in importance.values()]),
        prune_ratio,
    )
    mask = {
        name: (importance[name] > threshold).float()
        for name in importance
    }

    # Step 5: Reset to the initial weights
    model.load_state_dict(initial_weights)

    # Step 6: Retrain the sparse network under the mask
    train_with_mask(model, data, mask, epochs=100)

    # Result: a sparse network (~10% of the weights) often achieves
    # performance similar to the full dense network
    return model, mask
```
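
The snippet assumes `train` and `train_with_mask` helpers that it does not define. A minimal sketch of one way `train_with_mask` could work, re-applying the mask after every optimizer step so pruned weights stay at zero (the loss function and optimizer choices here are assumptions, not from the source):

```python
import torch

def train_with_mask(model, data, mask, epochs=100, lr=0.01):
    """Sketch: masked training that keeps pruned weights frozen at zero."""
    loss_fn = torch.nn.CrossEntropyLoss()   # assumes a classification task
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    for _ in range(epochs):
        for inputs, targets in data:
            optimizer.zero_grad()
            loss = loss_fn(model(inputs), targets)
            loss.backward()
            optimizer.step()
            # Re-apply the mask so pruned weights stay exactly zero
            with torch.no_grad():
                for name, param in model.named_parameters():
                    param.mul_(mask[name])
```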

Implications:

  • Dense networks are easier to train: more “lottery tickets” mean a higher chance of a lucky initialization
  • Initialization matters: The right sparse network exists, but you need the right initial weights
  • Pruning works: Train large → prune → fine-tune is effective
  • Overparameterization helps optimization: Not just representation capacity!

4. Grokking: Delayed Generalization

Models can suddenly generalize long after achieving perfect training accuracy!

The grokking phenomenon:

[Plot: accuracy vs. training epochs. Training accuracy reaches 100% by roughly epoch 1k while test accuracy stays near 50%; around epoch 10k test accuracy suddenly jumps to ~100%.]

Timeline:

  • Epoch 1000: 100% training accuracy, 50% test accuracy (memorization)
  • Epoch 10000: 100% training accuracy, 100% test accuracy (generalization!)
  • Model “groks” the underlying pattern after memorizing

Why grokking happens:

  1. Phase 1 (fast): Memorize training data (easy, any solution works)
  2. Phase 2 (slow): Implicit regularization gradually simplifies solution
  3. Phase 3 (sudden): Model discovers generalizable pattern

Training Lesson

Don’t stop training too early! Generalization can happen long after you’ve achieved perfect training accuracy. The model is still improving through implicit regularization.

Practical example:

```python
# Training a modular arithmetic task (e.g., (a + b) mod 97)

# Epoch 1000:
# - Training: 100% (memorized all 97 * 97 = 9409 examples)
# - Test:      50% (no generalization)

# Epoch 5000:
# - Training: 100% (still memorized)
# - Test:      50% (still no generalization)

# Epoch 10000:
# - Training: 100%
# - Test:     100% (suddenly discovered the modular structure!)

# The model grokked that:
# (a + b) mod n = ((a mod n) + (b mod n)) mod n
```
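
For reference, a minimal sketch of how such a dataset is usually built (illustrative only; this code is not from the original text). Grokking experiments typically train on a fraction of the full (a + b) mod p table and test on the held-out rest:

```python
import torch

p = 97
pairs = torch.cartesian_prod(torch.arange(p), torch.arange(p))   # 9409 (a, b) pairs
labels = (pairs[:, 0] + pairs[:, 1]) % p

perm = torch.randperm(len(pairs))
split = len(pairs) // 2                                          # 50/50 split
train_x, train_y = pairs[perm[:split]], labels[perm[:split]]
test_x, test_y = pairs[perm[split:]], labels[perm[split:]]
```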

Practical Implications

For Model Architecture Design

When designing models:

Do:

  • Use large models when computationally feasible
  • Don’t fear overparameterization
  • Trust implicit regularization (especially with SGD)
  • Let training continue past perfect train accuracy
  • Consider scaling up before adding explicit regularization

Don’t:

  • Stop at interpolation threshold (worst performance!)
  • Assume perfect training accuracy means overfitting
  • Over-regularize large models (can prevent grokking)
  • Give up training too early
  • Always default to smallest possible model

For Training Strategy

Effective training in overparameterized regime:

```python
import torch
from transformers import get_linear_schedule_with_warmup  # Hugging Face helper

# Good practices for large models

# 1. SGD or AdamW with a reasonable learning rate
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-4,            # Not too small!
    weight_decay=0.01,  # Light weight decay, if any
)

# 2. Learning rate warmup
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=1000,     # Gradual warmup
    num_training_steps=10000,
)

# 3. Small batch sizes for implicit regularization
batch_size = 32  # vs. 1024

# 4. Train longer than you think:
#    - Don't stop at perfect train accuracy
#    - Wait for generalization (grokking)
#    - Monitor validation loss, not just train loss
epochs = 100  # Or until validation stops improving

# 5. Gradient clipping for stability
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```

For Healthcare AI and Domain Applications

Domain-Specific Considerations

Scenario: Training a multimodal model for EHR + symptoms with:

  • Model: 50M parameters
  • Training data: 100K labeled examples

Analysis: You’re in the overparameterized regime (500 parameters per example).

Strategy:

  • ✅ Use the full 50M parameter model (don’t reduce size)
  • ✅ Train to perfect training accuracy, then keep going
  • ✅ Use light regularization (dropout 0.1, small weight decay)
  • ✅ Let training run 50-100 epochs past perfect train accuracy
  • ❌ Don’t stop early due to “overfitting” fears
  • ❌ Don’t aggressively regularize (can prevent learning)
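
As a quick sanity check, the parameters-per-example ratio is easy to compute directly (a hypothetical helper, not part of the original text):

```python
import torch

def params_per_example(model: torch.nn.Module, n_examples: int) -> float:
    """Rough overparameterization check (illustrative heuristic only)."""
    n_params = sum(p.numel() for p in model.parameters())
    return n_params / n_examples

# For the scenario above: 50M parameters / 100K examples = 500 parameters
# per example, i.e., well into the overparameterized regime.
```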

For Debugging Training Issues

If model isn’t learning:

  1. Check if you’re at interpolation threshold

     ```python
     # Can the model fit the training data?
     train_acc = evaluate(model, train_loader)
     if train_acc < 0.95:
         # Underparameterized or at the interpolation threshold
         # → Try a larger model!
         model = increase_model_size(model)
     ```
  2. Check optimization stability

     ```python
     # Use gradient clipping
     torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

     # Use learning rate warmup
     scheduler = WarmupScheduler(...)

     # Watch for NaN losses; if they appear, reduce the learning rate
     # or tighten gradient clipping
     if torch.isnan(loss):
         ...
     ```
  3. Give training more time

     ```python
     # Don't stop at epoch 10 if the model is still improving;
     # grokking can take 100+ epochs
     patience = 50  # Wait 50 epochs for improvement before giving up
     ```
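
Putting the last point into practice, a patience-based loop might look like the following sketch (the `train_one_epoch`, `evaluate`, and loader names are assumptions, not from the source):

```python
best_val_loss, epochs_since_best = float("inf"), 0
patience = 50

for epoch in range(1000):
    train_one_epoch(model, train_loader, optimizer)
    val_loss = evaluate(model, val_loader)
    if val_loss < best_val_loss:
        best_val_loss, epochs_since_best = val_loss, 0
    else:
        epochs_since_best += 1
    if epochs_since_best >= patience:
        break  # stop only after a long stretch with no validation improvement
```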

Theoretical Foundations

Neural Tangent Kernel (NTK) Theory

In the infinite-width limit, neural networks behave like kernel methods:

  • Infinite width: Network barely changes during training (lazy regime)
  • Finite width: Network adapts, features learn (feature learning regime)
  • Practical networks: Between lazy and feature learning

NTK theory helps explain why heavily overparameterized networks can generalize, though real finite-width networks also learn features, which the purely lazy analysis does not capture.
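
To make the idea tangible, the empirical (finite-width) NTK between two inputs is just the inner product of the network's parameter gradients at those inputs. A minimal sketch for a tiny scalar-output network (all names and sizes here are illustrative, not from the original text):

```python
import torch

# Tiny scalar-output network, purely for illustration
net = torch.nn.Sequential(
    torch.nn.Linear(4, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 1),
)

def empirical_ntk(net, x1, x2):
    """k(x1, x2) = <df(x1)/dθ, df(x2)/dθ> for a scalar-output network."""
    def param_grads(x):
        net.zero_grad()
        net(x).sum().backward()   # collapse the (1, 1) output to a scalar
        return torch.cat([p.grad.flatten() for p in net.parameters()])
    return torch.dot(param_grads(x1), param_grads(x2))

x1, x2 = torch.randn(1, 4), torch.randn(1, 4)
print(empirical_ntk(net, x1, x2).item())
```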

Scaling Laws

Empirical scaling laws for large language models show consistent patterns:

```python
# Chinchilla rule of thumb (Hoffmann et al., 2022): roughly 20 training
# tokens per parameter is compute-optimal.
num_parameters = 10e9                    # 10B-parameter model
optimal_tokens = 20 * num_parameters     # ≈ 200B tokens

# Test loss scales predictably as a power law in parameters N and data D
# (exponents from Kaplan et al., 2020, approximately):
#   L(N, D) ∝ (N_c / N)**alpha + (D_c / D)**beta
#   alpha ≈ 0.076, beta ≈ 0.095
```
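
Combining the 20-tokens-per-parameter rule with the standard training-compute approximation C ≈ 6 · N · D gives a quick way to size a compute-optimal model (a back-of-the-envelope sketch, not an official recipe):

```python
import math

def chinchilla_optimal(compute_flops: float) -> tuple[float, float]:
    """Compute-optimal (params, tokens) under C ≈ 6·N·D and D ≈ 20·N."""
    n_params = math.sqrt(compute_flops / 120)   # C = 6·N·(20·N) = 120·N²
    n_tokens = 20 * n_params
    return n_params, n_tokens

n, d = chinchilla_optimal(1e23)   # example FLOP budget
print(f"~{n / 1e9:.0f}B parameters, ~{d / 1e9:.0f}B tokens")
```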

See Language Model Scaling Laws for details.

Historical Context

  • 2017: Deep learning achieves state-of-the-art despite “overfitting” classical intuitions
  • 2018: Neural Tangent Kernel theory (Jacot et al.) gives a mathematical framework for the infinite-width limit
  • 2018-2019: Belkin et al. formalize double descent in “Reconciling modern machine-learning practice and the classical bias-variance trade-off”
  • 2019: “The Lottery Ticket Hypothesis” (Frankle & Carbin) highlights the role of initialization; “Deep Double Descent” (Nakkiran et al.) maps the phenomenon across model size, data, and training time
  • 2021: “Grokking” paper shows delayed generalization in algorithmic tasks
  • 2022: Scaling laws become central to foundation model development
  • 2023-2025: Double descent and overparameterization are standard practice

Common Misconceptions

Myth 1: “Perfect training accuracy means overfitting”

  • Reality: In the overparameterized regime, perfect train accuracy is fine and expected

Myth 2: “Always use as little capacity as possible”

  • Reality: Models at or just below the interpolation threshold often perform worse than much larger ones

Myth 3: “More parameters always means overfitting”

  • Reality: More parameters + implicit regularization = better generalization

Myth 4: “Stop training once you hit 100% train accuracy”

  • Reality: Grokking can happen long after - keep training!

Myth 5: “Deep learning is just memorization”

  • Reality: Models compress patterns through implicit regularization

Key Takeaways

  1. Double descent is real: Test error improves → gets worse (threshold) → improves again
  2. Interpolation threshold is dangerous: Avoid models that can just barely fit training data
  3. Bigger is often better: Overparameterized models generalize better due to implicit regularization
  4. SGD is special: Stochastic optimization implicitly regularizes toward flat, generalizable minima
  5. Train longer: Grokking can happen long after perfect training accuracy
  6. Trust the process: Modern deep learning works despite violating classical ML intuitions

Further Reading

Papers

  • “Reconciling modern machine-learning practice and the classical bias-variance trade-off” (Belkin et al., 2019) - Original double descent paper
  • “Deep Double Descent: Where Bigger Models and More Data Hurt” (Nakkiran et al., 2019) - Comprehensive analysis
  • “The Lottery Ticket Hypothesis” (Frankle & Carbin, 2019) - Sparse subnetworks
  • “Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets” (Power et al., 2021)
  • “Scaling Laws for Neural Language Models” (Kaplan et al., 2020) - Empirical scaling

Videos

  • “What the Books Get Wrong about AI (Double Descent)” - Welch Labs (essential viewing!)
  • “The Lottery Ticket Hypothesis” - Yannic Kilcher
  • “Deep Double Descent” - Paper explanation by authors

Code & Experiments

  • OpenAI Scaling Laws: Empirical notebook
  • Lottery Ticket Pruning: PyTorch implementation
  • Grokking Reproduction: Algorithmic task examples