
Language Model Training

Training autoregressive language models like GPT involves predicting the next token in a sequence using cross-entropy loss. Modern training techniques include gradient accumulation, learning rate schedules, and careful optimization to achieve stable, efficient training at scale.

Loss Computation

Language models are trained with next-token prediction: predict each token given all previous tokens.

Cross-Entropy Loss

For each position in the sequence, compute cross-entropy between predicted distribution and actual next token:

\mathcal{L} = -\frac{1}{T} \sum_{t=1}^{T} \log P(x_t | x_{<t})

where:

  • T is the sequence length
  • x_t is the token at position t
  • x_{<t} represents all tokens before position t
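
As a quick numerical check of the loss above, the per-token negative log-probability computed by hand from log-softmax should match F.cross_entropy. A minimal sketch with a toy batch and random logits (shapes and values are made up for illustration):

import torch
import torch.nn.functional as F

# Toy setup: batch of 2 sequences, 4 positions, vocabulary of 10 tokens
B, T, V = 2, 4, 10
logits = torch.randn(B, T, V)        # stand-in for model outputs
targets = torch.randint(V, (B, T))   # stand-in for the actual next tokens

# Manual: average negative log-probability assigned to the correct token
log_probs = F.log_softmax(logits, dim=-1)                        # (B, T, V)
nll = -log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)   # (B, T)
manual_loss = nll.mean()

# Library: the same value via F.cross_entropy on flattened tensors
library_loss = F.cross_entropy(logits.view(B * T, V), targets.view(B * T))

print(manual_loss.item(), library_loss.item())  # equal up to floating-point error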

Data Preparation

import torch

def get_batch(data, batch_size, block_size):
    """Sample random chunks from data for training."""
    # Sample random starting positions
    ix = torch.randint(len(data) - block_size, (batch_size,))
    # Extract sequences
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+1+block_size] for i in ix])
    return x, y

# Example:
# x = [[1, 2, 3, 4], [5, 6, 7, 8]]  # Input sequences
# y = [[2, 3, 4, 5], [6, 7, 8, 9]]  # Target sequences (shifted by 1)

Training Step

# Get batch of data
x, y = get_batch(train_data, batch_size=32, block_size=256)

# Forward pass
logits = model(x)  # (B, T, vocab_size)

# Reshape for cross-entropy: (B*T, vocab_size) and (B*T,)
B, T, C = logits.shape
logits = logits.view(B*T, C)
targets = y.view(B*T)

# Compute loss
loss = F.cross_entropy(logits, targets)

# Backward pass
loss.backward()

# Optimizer step
optimizer.step()
optimizer.zero_grad()

Why This Works

At each position t, the model:

  1. Input: Receives tokens [x_0, x_1, ..., x_t]
  2. Output: Predicts a distribution over the vocabulary for x_{t+1}
  3. Target: The actual next token x_{t+1}
  4. Loss: Cross-entropy between predicted and actual

All positions contribute to loss in parallel during training (despite sequential dependencies enforced by causal masking).
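
To make the "all positions in parallel" point concrete, F.cross_entropy can return unreduced losses, giving one value per position; the scalar training loss is simply their mean. A small illustrative sketch (toy shapes, not from the training code above):

import torch
import torch.nn.functional as F

B, T, V = 2, 4, 10
logits = torch.randn(B, T, V)
targets = torch.randint(V, (B, T))

# One cross-entropy value per (sequence, position) pair
per_position = F.cross_entropy(
    logits.view(B * T, V), targets.view(B * T), reduction='none'
).view(B, T)

print(per_position.shape)   # torch.Size([2, 4]) -- every position contributes a loss
print(per_position.mean())  # equals the usual scalar training loss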

Gradient Accumulation

When GPU memory is limited, gradient accumulation simulates larger batch sizes by accumulating gradients over multiple mini-batches before updating weights.

Implementation

accumulation_steps = 4  # Effective batch size = batch_size × accumulation_steps

optimizer.zero_grad()
for i, (x, y) in enumerate(dataloader):
    # Forward pass
    logits = model(x)
    loss = F.cross_entropy(logits.view(-1, vocab_size), y.view(-1))

    # Normalize loss by accumulation steps
    loss = loss / accumulation_steps

    # Backward pass (gradients accumulate)
    loss.backward()

    # Update weights every accumulation_steps batches
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

Why It Works

Gradient accumulation exploits the linearity of gradients:

\nabla_\theta \mathcal{L}_{\text{total}} = \nabla_\theta \mathcal{L}_1 + \nabla_\theta \mathcal{L}_2 + \dots + \nabla_\theta \mathcal{L}_k

Benefits:

  • Effective batch size = batch_size × accumulation_steps
  • Same weight updates as training with larger batch
  • Uses less GPU memory than large batch
  • Critical for training large models on limited hardware

Example: with a batch size of 16 and 32 accumulation steps, the effective batch size is 16 × 32 = 512.
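
The linearity claim can be verified directly on a tiny model: accumulating gradients over k equally sized mini-batches (with each loss divided by k) yields the same gradients as one pass over the full batch. A minimal sketch using a made-up linear classifier:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
model = torch.nn.Linear(8, 3)           # toy stand-in for a language model
x = torch.randn(16, 8)
y = torch.randint(3, (16,))

# Gradients from one full batch
model.zero_grad()
F.cross_entropy(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Same data as 4 accumulated mini-batches of 4, loss scaled by 1/4
model.zero_grad()
for xb, yb in zip(x.chunk(4), y.chunk(4)):
    (F.cross_entropy(model(xb), yb) / 4).backward()

print(torch.allclose(full_grad, model.weight.grad, atol=1e-6))  # True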

Learning Rate Schedule

Modern language model training uses warmup followed by cosine decay.

Implementation

import math

def get_lr(it, warmup_iters, max_iters, max_lr, min_lr):
    """Learning rate schedule with linear warmup and cosine decay."""
    # 1. Linear warmup
    if it < warmup_iters:
        return max_lr * (it / warmup_iters)
    # 2. After the schedule ends, hold at the minimum
    if it > max_iters:
        return min_lr
    # 3. Cosine decay from max_lr down to min_lr
    decay_ratio = (it - warmup_iters) / (max_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (max_lr - min_lr)

Schedule Visualization

Schedule shape: the learning rate rises linearly from 0 to max_lr during the warmup phase, then follows a cosine curve down to min_lr over the remaining iterations.
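
To see the schedule in numbers, the get_lr function above can be evaluated at a handful of iterations. The hyperparameters below are the GPT-2-style values from the "Typical Hyperparameters" subsection; this sketch assumes get_lr is already in scope:

# Probe the schedule at a few points
for it in [0, 1000, 2000, 100000, 300000, 600000]:
    lr = get_lr(it, warmup_iters=2000, max_iters=600000, max_lr=6e-4, min_lr=6e-5)
    print(f"iter {it:>6}: lr = {lr:.2e}")

# Roughly: 0 at iteration 0, half of max_lr mid-warmup, max_lr exactly when warmup
# ends, then a smooth cosine decay that reaches min_lr at max_iters.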

Phase Breakdown

1. Linear Warmup (first 2-10% of training):

  • Gradually increase LR from 0 to maximum
  • Why: Prevents instability from large updates early in training
  • Typical duration: 2,000-10,000 steps

2. Cosine Decay (remaining training):

  • Smoothly decrease LR following cosine curve
  • Why: Allows fine-tuning in later stages
  • End value: Usually 10% of max LR (not 0)

Typical Hyperparameters

# GPT-2 style training
max_lr = 6e-4        # Peak learning rate
min_lr = 6e-5        # Final learning rate (10% of max)
warmup_iters = 2000  # Warmup steps
max_iters = 600000   # Total training steps

Optimizer: AdamW

Language models typically use AdamW (Adam with decoupled weight decay).

Configuration

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=6e-4,            # Peak learning rate
    betas=(0.9, 0.95),  # Momentum parameters (beta1, beta2)
    eps=1e-8,           # Numerical stability epsilon
    weight_decay=0.1    # L2 regularization strength
)
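
A common refinement, not shown in the snippet above and included here only as a sketch of standard practice, is to apply weight decay to the weight matrices while exempting biases and normalization parameters via optimizer parameter groups:

import torch

# Assumes `model` is any nn.Module; the "2-D parameters get decay" heuristic
# follows common GPT-style training setups.
decay, no_decay = [], []
for name, p in model.named_parameters():
    if not p.requires_grad:
        continue
    (decay if p.dim() >= 2 else no_decay).append(p)

optimizer = torch.optim.AdamW(
    [
        {"params": decay, "weight_decay": 0.1},
        {"params": no_decay, "weight_decay": 0.0},
    ],
    lr=6e-4,
    betas=(0.9, 0.95),
)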

Why AdamW?

Adaptive learning rates:

  • Different learning rate for each parameter
  • Fast convergence for sparse gradients
  • Reduces need for manual per-layer LR tuning

Momentum:

  • beta1 (0.9): First moment (mean) for smoothing gradients
  • beta2 (0.95): Second moment (variance) for adaptive scaling
  • Helps escape saddle points and navigate ravines

Decoupled weight decay:

  • Decouples the weight-decay term from Adam's adaptive gradient scaling
  • Regularization strength is independent of per-parameter gradient magnitudes
  • Standard for transformer training

Adam Update Rule

m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t
v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}
\theta_t = \theta_{t-1} - \alpha \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} - \alpha \lambda \theta_{t-1}

where the last term is the decoupled weight decay (scaled by the learning rate, as in PyTorch's AdamW).
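
To make the moment estimates concrete, here is an illustrative re-implementation of the update for a single parameter tensor (plain tensors, no optimizer class; names and values are my own, not the library's internals):

import torch

def adamw_step(theta, grad, m, v, t, lr=6e-4, beta1=0.9, beta2=0.95,
               eps=1e-8, weight_decay=0.1):
    """One AdamW update on a single tensor, following the equations above."""
    m = beta1 * m + (1 - beta1) * grad        # first moment (running mean of gradients)
    v = beta2 * v + (1 - beta2) * grad ** 2   # second moment (running mean of squares)
    m_hat = m / (1 - beta1 ** t)              # bias correction for zero-initialized moments
    v_hat = v / (1 - beta2 ** t)
    # Adam step plus decoupled weight decay, both scaled by the learning rate
    theta = theta - lr * m_hat / (v_hat.sqrt() + eps) - lr * weight_decay * theta
    return theta, m, v

# Usage: a few steps on a toy parameter with random stand-in gradients
theta, m, v = torch.randn(4), torch.zeros(4), torch.zeros(4)
for t in range(1, 4):
    theta, m, v = adamw_step(theta, torch.randn(4), m, v, t)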

Complete Training Loop

# Configuration
max_iters = 5000
eval_interval = 500
batch_size = 64
block_size = 256
grad_clip = 1.0

# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=6e-4, weight_decay=0.1)

# Training loop
for iter in range(max_iters):
    # Periodic evaluation
    if iter % eval_interval == 0:
        model.eval()
        with torch.no_grad():
            train_loss = estimate_loss(model, train_data)
            val_loss = estimate_loss(model, val_data)
        model.train()
        print(f"step {iter}: train loss {train_loss:.4f}, val loss {val_loss:.4f}")

    # Sample batch
    xb, yb = get_batch(train_data, batch_size, block_size)

    # Forward pass
    logits = model(xb)
    loss = F.cross_entropy(logits.view(-1, vocab_size), yb.view(-1))

    # Backward pass
    optimizer.zero_grad()
    loss.backward()

    # Gradient clipping (prevent exploding gradients)
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

    # Update learning rate
    lr = get_lr(iter, warmup_iters=100, max_iters=max_iters, max_lr=6e-4, min_lr=6e-5)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # Optimizer step
    optimizer.step()

Gradient Clipping

Prevents exploding gradients by capping gradient norm:

# Clip gradients to max norm of 1.0
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

How it works:

  • Compute the total gradient norm: g_{\text{norm}} = \sqrt{\sum_i g_i^2}
  • If g_{\text{norm}} > \text{max\_norm}, scale all gradients by \frac{\text{max\_norm}}{g_{\text{norm}}}
  • Maintains gradient direction, just reduces magnitude
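
For intuition, the same rescaling can be written out by hand; in practice clip_grad_norm_ does this (plus an epsilon for numerical stability) in a single call. A sketch of the math:

import torch

def clip_by_global_norm(parameters, max_norm):
    """Scale all gradients in place so their global L2 norm is at most max_norm."""
    grads = [p.grad for p in parameters if p.grad is not None]
    total_norm = torch.sqrt(sum((g ** 2).sum() for g in grads))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        for g in grads:
            g.mul_(scale)   # same direction, smaller magnitude
    return total_norm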

When to use:

  • Essential for RNNs (exploding gradients common)
  • Helpful for transformers (occasional spikes in gradients)
  • Standard practice for LLM training

Checkpointing

Save model state periodically to resume training or deploy:

# Save checkpoint
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'iter': iter,
    'best_val_loss': best_val_loss,
    'config': model_config,
}
torch.save(checkpoint, f'ckpt_iter_{iter}.pt')

# Load checkpoint
checkpoint = torch.load('ckpt_iter_5000.pt')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_iter = checkpoint['iter'] + 1

Best practices:

  • Save every N steps (e.g., every 1000 steps)
  • Keep last K checkpoints (avoid filling disk)
  • Save best model based on validation loss
  • Include optimizer state for resuming training
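
One way to follow the "keep last K checkpoints" advice is a small rotation helper. This is a sketch that assumes the ckpt_iter_<n>.pt naming used above:

import glob
import os
import torch

def save_rotating_checkpoint(checkpoint, iter_num, keep_last=3, directory="."):
    """Save a checkpoint and delete all but the most recent `keep_last` files."""
    path = os.path.join(directory, f"ckpt_iter_{iter_num}.pt")
    torch.save(checkpoint, path)

    # Sort existing checkpoints by iteration number and drop the oldest ones
    ckpts = glob.glob(os.path.join(directory, "ckpt_iter_*.pt"))
    ckpts.sort(key=lambda p: int(p.rsplit("_", 1)[-1].split(".")[0]))
    for old in ckpts[:-keep_last]:
        os.remove(old)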

Training Dynamics

Typical Loss Curves

Healthy training:

  • Rapid initial decrease (first 10-20% of training)
  • Slower, steady improvement
  • Train and val loss track each other
  • Eventually plateau (convergence)

Warning signs:

  • Large train/val gap: Overfitting (need regularization)
  • Val loss increasing: Severe overfitting (stop training)
  • Loss not decreasing: Learning rate too low or data issues
  • Loss exploding: Learning rate too high or gradient issues

Common Issues and Solutions

Exploding Gradients:

  • Symptoms: Loss becomes NaN, gradients > 1000
  • Solutions:
    • Enable gradient clipping (clip_grad_norm_)
    • Reduce learning rate
    • Check for bugs (e.g., missing normalization)
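
A cheap way to catch these symptoms early is to log the gradient norm every step and skip the update when the loss is not finite. A defensive sketch, not part of the loop above (assumes model, optimizer, and loss are already in scope):

# After loss.backward(), before optimizer.step():
if not torch.isfinite(loss):
    optimizer.zero_grad()   # drop this batch rather than corrupting the weights
else:
    # clip_grad_norm_ returns the pre-clipping global norm, handy for logging
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    if grad_norm > 100.0:
        print(f"warning: large gradient norm {grad_norm.item():.1f}")
    optimizer.step()
    optimizer.zero_grad()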

Vanishing Gradients:

  • Symptoms: Very slow learning, gradients near 0
  • Solutions:
    • Rare in transformers (residuals help)
    • Check layer normalization is working
    • Verify proper weight initialization

Slow Convergence:

  • Solutions:
    • Increase learning rate (if stable)
    • Increase batch size / gradient accumulation
    • Verify warmup is long enough
    • Check data quality and preprocessing

Overfitting:

  • Solutions:
    • Increase dropout (try 0.1 → 0.2)
    • Increase weight decay (try 0.1 → 0.2)
    • Get more training data
    • Reduce model size

Advanced Techniques

Mixed Precision Training

Use FP16 instead of FP32 to reduce memory and increase speed:

from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for x, y in dataloader:
    optimizer.zero_grad()

    # Forward in FP16
    with autocast():
        logits = model(x)
        loss = F.cross_entropy(logits.view(-1, vocab_size), y.view(-1))

    # Backward with gradient scaling
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

Benefits:

  • 2× memory reduction
  • 2-3× training speedup
  • Minimal quality impact
  • Standard for large model training

Gradient Checkpointing

Trade compute for memory by recomputing activations during backward pass:

from torch.utils.checkpoint import checkpoint

# In transformer block forward:
def forward(self, x):
    x = checkpoint(self.attention_block, x)  # Recompute during backward
    x = checkpoint(self.mlp_block, x)
    return x

Benefits:

  • Reduces memory by ~40-60%
  • Allows training larger models
  • Cost: ~30% slower training

Key Insights

  1. Batch size matters: Larger batches stabilize training but may hurt generalization (use gradient accumulation)
  2. Learning rate is critical: Most important hyperparameter; warmup prevents early instability
  3. Gradient clipping is essential: Prevents occasional gradient spikes from derailing training
  4. Monitor validation loss: Use for early stopping and checkpoint selection
  5. Mixed precision is free performance: 2× speedup with minimal effort

Learning Resources

Papers

  • Attention Is All You Need (Vaswani et al., 2017) - Original transformer training details
  • Language Models are Few-Shot Learners (Brown et al., 2020) - GPT-3 training at scale
  • Decoupled Weight Decay Regularization (Loshchilov & Hutter, 2019) - AdamW paper
