Language Model Training
Training autoregressive language models like GPT involves predicting the next token in a sequence using cross-entropy loss. Modern training techniques include gradient accumulation, learning rate schedules, and careful optimization to achieve stable, efficient training at scale.
Loss Computation
Language models are trained with next-token prediction: predict each token given all previous tokens.
Cross-Entropy Loss
For each position in the sequence, compute the cross-entropy between the predicted distribution and the actual next token:

$$\mathcal{L} = -\frac{1}{T} \sum_{t=1}^{T} \log P(x_t \mid x_{<t})$$

where:
- $T$ is the sequence length
- $x_t$ is the token at position $t$
- $x_{<t}$ represents all tokens before position $t$
Data Preparation
import torch

def get_batch(data, batch_size, block_size):
"""Sample random chunks from data for training."""
# Sample random starting positions
ix = torch.randint(len(data) - block_size, (batch_size,))
# Extract sequences
x = torch.stack([data[i:i+block_size] for i in ix])
y = torch.stack([data[i+1:i+1+block_size] for i in ix])
return x, y
# Example:
# x = [[1, 2, 3, 4], [5, 6, 7, 8]] # Input sequences
# y = [[2, 3, 4, 5], [6, 7, 8, 9]]  # Target sequences (shifted by 1)

Training Step
import torch.nn.functional as F

# Get batch of data
x, y = get_batch(train_data, batch_size=32, block_size=256)
# Forward pass
logits = model(x) # (B, T, vocab_size)
# Reshape for cross-entropy: (B*T, vocab_size) and (B*T,)
B, T, C = logits.shape
logits = logits.view(B*T, C)
targets = y.view(B*T)
# Compute loss
loss = F.cross_entropy(logits, targets)
# Backward pass
loss.backward()
# Optimizer step
optimizer.step()
optimizer.zero_grad()

Why This Works
At each position $t$, the model:
- Input: Receives tokens $x_1, \dots, x_t$
- Output: Predicts a distribution over the vocabulary for $x_{t+1}$
- Target: The actual token $x_{t+1}$
- Loss: Cross-entropy between the predicted distribution and the actual token
All positions contribute to the loss in parallel during training: causal masking preserves the autoregressive dependency structure, while teacher forcing (conditioning on the ground-truth prefix) lets every position be scored in a single forward pass.
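A quick way to see this concretely is to compute the per-position loss by hand and check it against F.cross_entropy on the flattened logits. This is a minimal, self-contained sketch with random logits standing in for model(x); the shapes and names are illustrative, not the GPT model from this note.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, V = 2, 4, 10                      # batch, sequence length, vocab size
logits = torch.randn(B, T, V)           # stand-in for model(x)
targets = torch.randint(V, (B, T))      # stand-in for the shifted targets y

# Manual computation: -log P(x_t | x_<t) at every position, averaged
log_probs = F.log_softmax(logits, dim=-1)                                   # (B, T, V)
token_log_probs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)   # (B, T)
manual_loss = -token_log_probs.mean()

# Library computation on flattened (B*T, V) logits, as in the training step above
library_loss = F.cross_entropy(logits.view(B * T, V), targets.view(B * T))

print(torch.allclose(manual_loss, library_loss))  # True: all positions scored in parallel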
Gradient Accumulation
When GPU memory is limited, gradient accumulation simulates larger batch sizes by accumulating gradients over multiple mini-batches before updating weights.
Implementation
accumulation_steps = 4 # Effective batch size = batch_size × accumulation_steps
optimizer.zero_grad()
for i, (x, y) in enumerate(dataloader):
# Forward pass
logits = model(x)
loss = F.cross_entropy(logits.view(-1, vocab_size), y.view(-1))
# Normalize loss by accumulation steps
loss = loss / accumulation_steps
# Backward pass (gradients accumulate)
loss.backward()
# Update weights every accumulation_steps batches
if (i + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()

Why It Works
Gradient accumulation exploits the linearity of gradients: the gradient of the mean loss over $N$ mini-batches equals the mean of the per-mini-batch gradients:

$$\nabla_\theta \left( \frac{1}{N} \sum_{i=1}^{N} L_i \right) = \frac{1}{N} \sum_{i=1}^{N} \nabla_\theta L_i$$

Benefits:
- Effective batch size = batch_size × accumulation_steps
- Same weight updates as training with a larger batch
- Uses less GPU memory than a single large batch
- Critical for training large models on limited hardware
Example: With batch size 16 and accumulation steps 32, you get effective batch size of 512.
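The linearity claim is easy to verify numerically. The sketch below compares one full-batch gradient against gradients accumulated over four mini-batches on a tiny made-up linear model with MSE loss (chosen purely for brevity); the equality holds because the mini-batches are equal-sized and each loss is scaled by 1/N.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
X, y = torch.randn(8, 4), torch.randn(8, 1)

def make_model():
    torch.manual_seed(1)                 # identical initialization for both runs
    return torch.nn.Linear(4, 1)

# (a) One backward pass over the full batch
full = make_model()
F.mse_loss(full(X), y).backward()

# (b) Accumulate gradients over 4 mini-batches of size 2, scaling each loss by 1/4
accum = make_model()
for Xb, yb in zip(X.chunk(4), y.chunk(4)):
    (F.mse_loss(accum(Xb), yb) / 4).backward()

print(torch.allclose(full.weight.grad, accum.weight.grad))  # True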
Learning Rate Schedule
Modern language model training uses warmup followed by cosine decay.
Implementation
import math

def get_lr(it, warmup_iters, max_iters, max_lr, min_lr):
"""Learning rate schedule with linear warmup and cosine decay."""
# 1. Linear warmup
if it < warmup_iters:
return max_lr * (it / warmup_iters)
# 2. After max_iters, hold at the minimum learning rate
if it > max_iters:
return min_lr
# 3. Cosine decay
decay_ratio = (it - warmup_iters) / (max_iters - warmup_iters)
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
return min_lr + coeff * (max_lr - min_lr)

Schedule Visualization
Learning Rate
↑
max_lr ┐ ╱╲
│ ╱ ╲
│ ╱ ╲___
│ ╱ ╲___
│ ╱ ╲___
min_lr └──┴──────────────────────→ Iterations
│←warmup→│←cosine decay→│

Phase Breakdown
1. Linear Warmup (first 2-10% of training):
- Gradually increase LR from 0 to maximum
- Why: Prevents instability from large updates early in training
- Typical duration: 2,000-10,000 steps
2. Cosine Decay (remaining training):
- Smoothly decrease LR following cosine curve
- Why: Allows fine-tuning in later stages
- End value: Usually 10% of max LR (not 0)
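Printing the schedule at a few iterations makes the phases visible. This assumes the get_lr function defined above and uses small illustrative settings rather than real training lengths.

# Assumes get_lr from the implementation above
warmup, total = 100, 1000
for it in [0, 50, 100, 550, 1000, 1200]:
    lr = get_lr(it, warmup_iters=warmup, max_iters=total, max_lr=6e-4, min_lr=6e-5)
    print(f"iter {it:>5}: lr = {lr:.2e}")

# iter     0: lr = 0.00e+00   (warmup starts at zero)
# iter    50: lr = 3.00e-04   (halfway through warmup)
# iter   100: lr = 6.00e-04   (peak learning rate)
# iter   550: lr = 3.30e-04   (midpoint of cosine decay)
# iter  1000: lr = 6.00e-05   (decayed to min_lr)
# iter  1200: lr = 6.00e-05   (held at min_lr past max_iters)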
Typical Hyperparameters
# GPT-2 style training
max_lr = 6e-4 # Peak learning rate
min_lr = 6e-5 # Final learning rate (10% of max)
warmup_iters = 2000 # Warmup steps
max_iters = 600000    # Total training steps

Optimizer: AdamW
Language models typically use AdamW (Adam with decoupled weight decay).
Configuration
optimizer = torch.optim.AdamW(
model.parameters(),
lr=6e-4, # Peak learning rate
betas=(0.9, 0.95), # Momentum parameters (β1, β2)
eps=1e-8, # Numerical stability epsilon
weight_decay=0.1 # L2 regularization strength
)

Why AdamW?
Adaptive learning rates:
- Different learning rate for each parameter
- Fast convergence for sparse gradients
- Reduces need for manual per-layer LR tuning
Momentum:
- beta1 (0.9): First moment (mean) for smoothing gradients
- beta2 (0.95): Second moment (variance) for adaptive scaling
- Helps escape saddle points and navigate ravines
Decoupled weight decay:
- Applies weight decay directly to the weights instead of folding an L2 penalty into the gradient, where Adam's adaptive scaling distorts it
- Regularization strength is independent of the gradient magnitude
- Standard for transformer training
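A common refinement (seen in nanoGPT-style training scripts, for example) is to apply weight decay only to weight matrices and exempt biases and LayerNorm/embedding gains. The helper below is a hedged sketch of that split, assuming a standard nn.Module; the heuristic of decaying only 2D-or-higher parameters is a convention, not a requirement.

import torch

def configure_adamw(model, lr=6e-4, weight_decay=0.1, betas=(0.9, 0.95)):
    # Decay 2D+ parameters (linear/embedding weights); skip 1D ones (biases, norm gains)
    decay, no_decay = [], []
    for p in model.parameters():
        if not p.requires_grad:
            continue
        (decay if p.dim() >= 2 else no_decay).append(p)
    groups = [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]
    return torch.optim.AdamW(groups, lr=lr, betas=betas, eps=1e-8)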
Adam Update Rule
$$m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t$$
$$v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2$$
$$\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}$$
$$\theta_t = \theta_{t-1} - \eta \left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda \theta_{t-1} \right)$$

where the last term, $\lambda \theta_{t-1}$, is the decoupled weight decay.
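To connect the update rule to code, here is a single-tensor AdamW step written out by hand. It is a didactic sketch mirroring the equations above, not a replacement for torch.optim.AdamW, which additionally keeps per-parameter state (m, v, t) for every tensor in the model.

import torch

def adamw_step(theta, grad, m, v, t, lr=6e-4, beta1=0.9, beta2=0.95, eps=1e-8, wd=0.1):
    """One AdamW update for a single tensor, mirroring the equations above."""
    m = beta1 * m + (1 - beta1) * grad            # first moment (momentum)
    v = beta2 * v + (1 - beta2) * grad ** 2       # second moment (adaptive scaling)
    m_hat = m / (1 - beta1 ** t)                  # bias correction
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * (m_hat / (v_hat.sqrt() + eps) + wd * theta)  # decoupled decay
    return theta, m, v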
Complete Training Loop
# Configuration
max_iters = 5000
eval_interval = 500
batch_size = 64
block_size = 256
grad_clip = 1.0
# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=6e-4, weight_decay=0.1)
# Training loop
for iter in range(max_iters):
# Periodic evaluation
if iter % eval_interval == 0:
model.eval()
with torch.no_grad():
train_loss = estimate_loss(model, train_data)
val_loss = estimate_loss(model, val_data)
model.train()
print(f"step {iter}: train loss {train_loss:.4f}, val loss {val_loss:.4f}")
# Sample batch
xb, yb = get_batch(train_data, batch_size, block_size)
# Forward pass
logits = model(xb)
loss = F.cross_entropy(logits.view(-1, vocab_size), yb.view(-1))
# Backward pass
optimizer.zero_grad()
loss.backward()
# Gradient clipping (prevent exploding gradients)
torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
# Update learning rate
lr = get_lr(iter, warmup_iters=100, max_iters=max_iters,
max_lr=6e-4, min_lr=6e-5)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
# Optimizer step
optimizer.step()

Gradient Clipping
Prevents exploding gradients by capping gradient norm:
# Clip gradients to max norm of 1.0
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

How it works:
- Compute the total gradient norm $g_{\text{norm}} = \left\| \nabla_\theta L \right\|_2$ across all parameters
- If $g_{\text{norm}} > \text{max\_norm}$: scale all gradients by $\frac{\text{max\_norm}}{g_{\text{norm}}}$
- Maintains the gradient direction, only reduces the magnitude (a manual sketch follows below)
When to use:
- Essential for RNNs (exploding gradients common)
- Helpful for transformers (occasional spikes in gradients)
- Standard practice for LLM training
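A manual version of the norm-and-rescale logic, equivalent in spirit to clip_grad_norm_ (which also returns the pre-clip norm, handy for logging). The sketch assumes gradients have already been populated by loss.backward() and is illustrative rather than a drop-in replacement.

import torch

def clip_gradients(parameters, max_norm=1.0):
    grads = [p.grad for p in parameters if p.grad is not None]
    # Total L2 norm over all gradients, treated as one concatenated vector
    total_norm = torch.norm(torch.stack([g.norm(2) for g in grads]), 2)
    if total_norm > max_norm:
        scale = max_norm / (total_norm + 1e-6)
        for g in grads:
            g.mul_(scale)            # same direction, smaller magnitude
    return total_norm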
Checkpointing
Save model state periodically to resume training or deploy:
# Save checkpoint
checkpoint = {
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'iter': iter,
'best_val_loss': best_val_loss,
'config': model_config,
}
torch.save(checkpoint, f'ckpt_iter_{iter}.pt')
# Load checkpoint
checkpoint = torch.load('ckpt_iter_5000.pt')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_iter = checkpoint['iter'] + 1

Best practices:
- Save every N steps (e.g., every 1000 steps)
- Keep last K checkpoints (avoid filling disk)
- Save best model based on validation loss
- Include optimizer state for resuming training
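A small helper for the "keep the last K checkpoints" practice. The file naming follows the ckpt_iter_{iter}.pt pattern used above; everything else (directory layout, K=3) is an assumption for illustration.

import os
import re

def prune_checkpoints(ckpt_dir, keep_last=3):
    """Delete all but the newest `keep_last` checkpoints named ckpt_iter_<N>.pt."""
    pattern = re.compile(r"ckpt_iter_(\d+)\.pt$")
    ckpts = []
    for name in os.listdir(ckpt_dir):
        m = pattern.match(name)
        if m:
            ckpts.append((int(m.group(1)), name))
    for _, name in sorted(ckpts)[:-keep_last]:
        os.remove(os.path.join(ckpt_dir, name))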
Training Dynamics
Typical Loss Curves
Healthy training:
- Rapid initial decrease (first 10-20% of training)
- Slower, steady improvement
- Train and val loss track each other
- Eventually plateau (convergence)
Warning signs:
- Large train/val gap: Overfitting (need regularization)
- Val loss increasing: Severe overfitting (stop training)
- Loss not decreasing: Learning rate too low or data issues
- Loss exploding: Learning rate too high or gradient issues
Common Issues and Solutions
Exploding Gradients:
- Symptoms: Loss becomes NaN, gradients > 1000
- Solutions:
  - Enable gradient clipping (clip_grad_norm_); logging the norm it returns helps catch spikes early (see the sketch below)
  - Reduce the learning rate
  - Check for bugs (e.g., missing normalization)
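Because clip_grad_norm_ returns the total norm computed before clipping, spike detection can piggyback on the clipping call already present in the training loop. A sketch of logging it (the threshold value is illustrative):

# Inside the training loop, after loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
if grad_norm > 100.0:                       # illustrative threshold for "suspicious" spikes
    print(f"step {iter}: gradient norm spike {grad_norm:.1f}")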
Vanishing Gradients:
- Symptoms: Very slow learning, gradients near 0
- Solutions:
- Rare in transformers (residuals help)
- Check layer normalization is working
- Verify proper weight initialization
Slow Convergence:
- Solutions:
- Increase learning rate (if stable)
- Increase batch size / gradient accumulation
- Verify warmup is long enough
- Check data quality and preprocessing
Overfitting:
- Solutions:
- Increase dropout (try 0.1 → 0.2)
- Increase weight decay (try 0.1 → 0.2)
- Get more training data
- Reduce model size
Advanced Techniques
Mixed Precision Training
Use FP16 instead of FP32 to reduce memory and increase speed:
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()
for x, y in dataloader:
optimizer.zero_grad()
# Forward in FP16
with autocast():
logits = model(x)
loss = F.cross_entropy(logits.view(-1, vocab_size), y.view(-1))
# Backward with gradient scaling
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

Benefits:
- 2× memory reduction
- 2-3× training speedup
- Minimal quality impact
- Standard for large model training
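If mixed precision is combined with gradient clipping, the gradients must be unscaled before clipping so the clip threshold sees their true magnitudes; GradScaler.unscale_ exists for exactly this. A sketch extending the loop above:

scaler.scale(loss).backward()

# Unscale gradients in place so clipping operates on their true magnitudes
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

scaler.step(optimizer)      # skips the update if inf/NaN gradients are detected
scaler.update()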
Gradient Checkpointing
Trade compute for memory by recomputing activations during backward pass:
from torch.utils.checkpoint import checkpoint
# In transformer block forward:
def forward(self, x):
x = checkpoint(self.attention_block, x, use_reentrant=False)  # Recompute during backward
x = checkpoint(self.mlp_block, x, use_reentrant=False)
return x

Benefits:
- Reduces memory by ~40-60%
- Allows training larger models
- Cost: ~30% slower training
Key Insights
- Batch size matters: Larger batches stabilize training but may hurt generalization (use gradient accumulation)
- Learning rate is critical: Most important hyperparameter; warmup prevents early instability
- Gradient clipping is essential: Prevents occasional gradient spikes from derailing training
- Monitor validation loss: Use for early stopping and checkpoint selection
- Mixed precision is free performance: 2× speedup with minimal effort
Related Concepts
- GPT Architecture - The model being trained
- Optimization Algorithms - SGD, Adam, and optimization fundamentals
- Regularization - Techniques to prevent overfitting
- Backpropagation - How gradients are computed
- Text Generation - Using the trained model for inference
Learning Resources
Papers
- Attention Is All You Need (Vaswani et al., 2017) - Original transformer training details
- Language Models are Few-Shot Learners (Brown et al., 2020) - GPT-3 training at scale
- Decoupled Weight Decay Regularization (Loshchilov & Hutter, 2019) - AdamW paper
Implementation Guides
- Andrej Karpathy’s nanoGPT - Clean, minimal GPT training loop
- PyTorch Optimization Tutorial
- Hugging Face Training Guide
Video Tutorials
- Let’s Build GPT from Scratch - Karpathy builds and trains GPT
- Stanford CS224N: Training Neural Nets - Training best practices
Articles
- The Illustrated Transformer - Visual guide to transformer training
- Training Tips for Transformers - Hugging Face optimization guide