
Training Transformers

Training transformer models requires careful attention to masking strategies, loss functions, and learning rate schedules to achieve stable and effective learning. This guide covers the essential techniques from the “Attention Is All You Need” paper and modern best practices.

Masking Strategies

Transformers use different types of masks to control attention patterns and prevent information leakage.

Padding Mask

Purpose: Prevent attention to padding tokens (used to batch variable-length sequences).

Problem: Sequences in a batch have different lengths, so we pad shorter sequences:

Batch of sequences:
  Sequence 1: "The cat sat"  (3 tokens) + 2 padding
  Sequence 2: "Hello world"  (2 tokens) + 3 padding
  Sequence 3: "A"            (1 token)  + 4 padding

Solution: Create a mask that marks padding positions as invalid for attention.

Implementation:

def create_padding_mask(seq, pad_token_id=0):
    """
    Create mask for padding tokens

    Args:
        seq: Token IDs (batch, seq_len)
        pad_token_id: ID used for padding (usually 0)
    Returns:
        mask: (batch, 1, 1, seq_len) for broadcasting across heads
    """
    # Mark non-padding positions as 1 (True), padding as 0 (False)
    mask = (seq != pad_token_id).unsqueeze(1).unsqueeze(2)
    # Shape: (batch, 1, 1, seq_len)
    return mask

# Usage in scaled dot-product attention:
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
scores = scores.masked_fill(mask == 0, -1e9)  # Set padding scores to a large negative value
attn = F.softmax(scores, dim=-1)              # ≈ 0 weight at padding positions

Why -1e9? A score of -1e9 acts like −∞ inside the softmax: exp(-1e9) underflows to 0, so padding positions receive essentially zero attention weight.
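A quick numerical check of the trick (a minimal, self-contained sketch; the scores are made up):

import torch
import torch.nn.functional as F

# Raw attention scores for one query over 4 keys; the last key is padding
scores = torch.tensor([2.0, 1.0, 0.5, 3.0])
mask = torch.tensor([1, 1, 1, 0])             # 0 marks the padding position

masked = scores.masked_fill(mask == 0, -1e9)  # padding score -> -1e9
print(F.softmax(masked, dim=-1))
# tensor([0.6285, 0.2312, 0.1402, 0.0000]) -- the padding key gets ~0 weight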

Causal Mask (Look-Ahead Mask)

Purpose: Prevent decoder from attending to future positions during training (maintains autoregressive property).

Why needed? During training, we have the full target sequence, but we want to ensure the model learns to generate each token based only on previous tokens.

Causal mask structure:

Position can attend to:
  Position 0: [0]          → only itself
  Position 1: [0, 1]       → 0 and itself
  Position 2: [0, 1, 2]    → 0, 1, and itself
  Position 3: [0, 1, 2, 3] → all previous + itself

Implementation:

def create_causal_mask(seq_len):
    """
    Create lower triangular mask for causal attention

    Args:
        seq_len: Sequence length
    Returns:
        mask: (seq_len, seq_len) lower triangular matrix
    """
    # Lower triangular matrix of ones, e.g. for seq_len=4:
    # [[1, 0, 0, 0],
    #  [1, 1, 0, 0],
    #  [1, 1, 1, 0],
    #  [1, 1, 1, 1]]
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask

# Equivalent using the upper triangle:
mask = (torch.triu(torch.ones(seq_len, seq_len), diagonal=1) == 0)

Effect: Position i can only attend to positions ≤ i, enforcing left-to-right information flow.

Combined Mask

In practice, decoders need both padding and causal masks:

def create_combined_mask(tgt_seq, pad_token_id=0):
    """
    Combine causal and padding masks for decoder self-attention

    Args:
        tgt_seq: Target sequence (batch, seq_len)
        pad_token_id: Padding token ID
    Returns:
        mask: Combined mask (batch, 1, seq_len, seq_len)
    """
    batch_size, seq_len = tgt_seq.size()

    # Padding mask: (batch, 1, 1, seq_len), boolean
    padding_mask = create_padding_mask(tgt_seq, pad_token_id)

    # Causal mask: (1, 1, seq_len, seq_len); cast to bool so it can be
    # combined with the boolean padding mask
    causal_mask = create_causal_mask(seq_len).bool()
    causal_mask = causal_mask.unsqueeze(0).unsqueeze(0).to(tgt_seq.device)

    # Combine: BOTH conditions must be satisfied (logical AND)
    mask = padding_mask & causal_mask
    return mask

When to use which mask (see the sketch after this list):

  • Encoder self-attention: Padding mask only (bidirectional)
  • Decoder self-attention: Padding + causal mask (unidirectional)
  • Cross-attention (decoder attending to encoder): Padding mask on source only
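To make this concrete, here is a minimal sketch of how the three masks might be wired into an encoder-decoder forward pass. The encode/decode method names are illustrative assumptions, not a specific library API:

src_mask = create_padding_mask(src)           # source padding only
tgt_mask = create_combined_mask(tgt_input)    # target padding AND causal

memory = model.encode(src, src_mask=src_mask)            # encoder self-attention: bidirectional
logits = model.decode(tgt_input, memory,
                      tgt_mask=tgt_mask,                  # decoder self-attention: unidirectional
                      memory_mask=src_mask)               # cross-attention: mask source padding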

Loss Functions

Transformers typically use cross-entropy loss over the vocabulary for each position.

Standard Cross-Entropy

For sequence-to-sequence tasks, the loss sums over all positions:

L = -\sum_{t=1}^{T} \log P(y_t \mid y_{<t}, x)

Where:

  • y_t: Target token at position t
  • y_{<t}: All previous target tokens
  • x: Source sequence (for encoder-decoder models)

Implementation:

def compute_loss(logits, targets, pad_token_id=0):
    """
    Compute cross-entropy loss, ignoring padding positions

    Args:
        logits: Model output (batch, seq_len, vocab_size)
        targets: Target tokens (batch, seq_len)
        pad_token_id: Padding token ID to ignore
    Returns:
        loss: Scalar loss value
    """
    batch_size, seq_len, vocab_size = logits.size()

    # Reshape for cross-entropy:
    #   logits:  (batch * seq_len, vocab_size)
    #   targets: (batch * seq_len)
    logits_flat = logits.view(-1, vocab_size)
    targets_flat = targets.view(-1)

    # Compute cross-entropy, skipping padding positions
    loss = F.cross_entropy(
        logits_flat,
        targets_flat,
        ignore_index=pad_token_id,  # Don't compute loss for padding
        reduction='mean'
    )
    return loss

Label Smoothing

Purpose: Regularization technique to prevent overconfidence and improve generalization.

Problem with hard targets: Model becomes overconfident

Target:                 [0, 0, 1, 0, 0]                  # One-hot for token ID 2
Model learns to output: [0.001, 0.001, 0.998, 0.0, 0.0]  # Overconfident

Solution with label smoothing: Distribute some probability mass to other tokens

Smoothed target:        [0.025, 0.025, 0.9, 0.025, 0.025]  # ε = 0.1, V = 5
Model learns to output: [0.05, 0.05, 0.85, 0.03, 0.02]     # More calibrated

Formula:

y'_i = \begin{cases} 1 - \epsilon & \text{if } i = \text{true class} \\ \frac{\epsilon}{V - 1} & \text{otherwise} \end{cases}

Where ε is the smoothing parameter (e.g., 0.1) and V is the vocabulary size.

Implementation:

class LabelSmoothingLoss(nn.Module):
    def __init__(self, vocab_size, smoothing=0.1, pad_token_id=0):
        super().__init__()
        self.vocab_size = vocab_size
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing
        self.pad_token_id = pad_token_id

    def forward(self, logits, targets):
        """
        Args:
            logits: (batch, seq_len, vocab_size)
            targets: (batch, seq_len)
        """
        batch_size, seq_len, vocab_size = logits.size()

        # Flatten
        logits = logits.view(-1, vocab_size)
        targets = targets.view(-1)

        # Log probabilities
        log_probs = F.log_softmax(logits, dim=-1)

        # Create smoothed target distribution
        true_dist = torch.zeros_like(log_probs)
        # Distribute smoothing probability among non-target tokens
        true_dist.fill_(self.smoothing / (vocab_size - 2))  # -2 for target and pad
        # Set confidence for true class
        true_dist.scatter_(1, targets.unsqueeze(1), self.confidence)
        # Don't compute loss for padding
        true_dist[:, self.pad_token_id] = 0
        mask = (targets != self.pad_token_id).float()

        # Cross-entropy between smoothed targets and predictions
        # (same gradient as KL divergence, since the target entropy is constant)
        loss = -(true_dist * log_probs).sum(dim=-1)
        loss = (loss * mask).sum() / mask.sum()
        return loss

# Usage:
criterion = LabelSmoothingLoss(vocab_size=50000, smoothing=0.1)
loss = criterion(logits, targets)
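If you do not need the custom handling of the padding token above, recent PyTorch releases (1.10+) support label smoothing directly in the built-in cross-entropy loss. A minimal sketch (note the smoothing mass is spread over all classes, which differs slightly from the class above):

import torch.nn.functional as F

# Built-in alternative: ignore padding and smooth labels in one call
loss = F.cross_entropy(
    logits.view(-1, logits.size(-1)),   # (batch * seq_len, vocab_size)
    targets.view(-1),                   # (batch * seq_len)
    ignore_index=0,                     # pad_token_id
    label_smoothing=0.1
)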

Benefits of label smoothing:

  • Prevents overconfidence
  • Improves calibration
  • Better generalization (especially for BLEU scores in translation)
  • Recommended ε = 0.1 in the original transformer paper

Learning Rate Schedules

Transformers require careful learning rate scheduling with warmup followed by decay.

Warmup + Inverse Square Root Decay

The original “Attention Is All You Need” paper uses:

\text{lr} = d_{\text{model}}^{-0.5} \cdot \min\left(\text{step}^{-0.5},\ \text{step} \cdot \text{warmup}^{-1.5}\right)

Behavior:

  1. Warmup phase (steps 0 to warmup_steps): Linear increase from 0
  2. Decay phase (after warmup_steps): Inverse square root decay

Why warmup?

  • At initialization, parameters are random → gradients are large and noisy
  • Large learning rate at start causes instability and divergence
  • Warmup allows model to “settle” into a reasonable basin before aggressive learning

Visualization:

Learning Rate
│       ╱╲
│      ╱  ╲__
│     ╱      ╲_____
│    ╱             ╲__________
│   ╱                         ╲________________
│  ╱
└───────┬──────────────────────────────────────> Steps
   warmup_steps (4000)

Implementation:

class TransformerLRScheduler:
    def __init__(self, optimizer, d_model, warmup_steps=4000):
        """
        Learning rate scheduler from "Attention Is All You Need"

        Args:
            optimizer: PyTorch optimizer
            d_model: Model dimension (e.g., 512)
            warmup_steps: Number of warmup steps (e.g., 4000)
        """
        self.optimizer = optimizer
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.current_step = 0

    def step(self):
        """Update learning rate"""
        self.current_step += 1

        # Compute learning rate using the formula above
        lr = (self.d_model ** -0.5) * min(
            self.current_step ** -0.5,
            self.current_step * (self.warmup_steps ** -1.5)
        )

        # Update all parameter groups
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return lr

# Usage:
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1.0,  # Will be controlled by the scheduler
    betas=(0.9, 0.98),
    eps=1e-9
)
scheduler = TransformerLRScheduler(optimizer, d_model=512, warmup_steps=4000)

for epoch in range(num_epochs):
    for batch in dataloader:
        optimizer.zero_grad()
        loss = compute_loss(model(batch))
        loss.backward()
        optimizer.step()
        scheduler.step()  # Update LR every step (not every epoch!)

Alternative: Cosine Annealing with Warmup

Modern transformers (GPT, BERT) often use cosine annealing:

from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

# Warmup phase: linear increase
warmup_scheduler = LinearLR(
    optimizer,
    start_factor=0.1,          # Start at 10% of the base LR
    total_iters=warmup_steps
)

# Decay phase: cosine annealing to 0
cosine_scheduler = CosineAnnealingLR(
    optimizer,
    T_max=total_steps - warmup_steps,
    eta_min=0
)

# Combine sequentially: warmup first, then cosine decay
scheduler = SequentialLR(
    optimizer,
    schedulers=[warmup_scheduler, cosine_scheduler],
    milestones=[warmup_steps]
)

Comparison:

  • Inverse sqrt decay: Slower decay, never reaches zero, used in original transformer
  • Cosine annealing: Faster decay, reaches zero at end, popular in modern LLMs
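To see the difference numerically, the sketch below evaluates both schedules at a few steps. The peak LR, warmup, and total-step values are made up for illustration:

import math

def inverse_sqrt_lr(step, d_model=512, warmup=4000):
    # Original transformer schedule
    return (d_model ** -0.5) * min(step ** -0.5, step * warmup ** -1.5)

def cosine_lr(step, peak=1e-3, warmup=4000, total=100_000):
    # Linear warmup to `peak`, then cosine decay to zero
    if step < warmup:
        return peak * step / warmup
    progress = (step - warmup) / (total - warmup)
    return peak * 0.5 * (1 + math.cos(math.pi * progress))

for step in [1_000, 4_000, 20_000, 100_000]:
    print(f"{step:>7}  inv-sqrt: {inverse_sqrt_lr(step):.2e}  cosine: {cosine_lr(step):.2e}")
# Inverse sqrt is still nonzero at step 100,000; cosine has decayed to ~0.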

See Language Model Training for more on modern LR schedules.

Complete Training Loop

Putting it all together:

def train_transformer(model, train_loader, val_loader, num_epochs=10):
    """
    Train a transformer model with all of the best practices above

    Args:
        model: Transformer model (encoder-decoder or decoder-only)
        train_loader: Training data loader
        val_loader: Validation data loader
        num_epochs: Number of training epochs
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # Optimizer (note: lr=1.0, the actual LR is set by the scheduler)
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=1.0,
        betas=(0.9, 0.98),  # Transformer paper values
        eps=1e-9
    )

    # Learning rate scheduler
    scheduler = TransformerLRScheduler(
        optimizer,
        d_model=model.d_model,
        warmup_steps=4000
    )

    # Loss function with label smoothing
    criterion = LabelSmoothingLoss(
        vocab_size=model.tgt_vocab_size,
        smoothing=0.1
    )

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0

        for batch_idx, (src, tgt) in enumerate(train_loader):
            src = src.to(device)
            tgt = tgt.to(device)

            # Prepare decoder input (shift target right by one position)
            tgt_input = tgt[:, :-1]   # Remove last token
            tgt_output = tgt[:, 1:]   # Remove first token (usually <BOS>)

            # Create masks
            src_mask = create_padding_mask(src)
            tgt_mask = create_combined_mask(tgt_input)

            # Forward pass
            logits = model(src, tgt_input, src_mask, tgt_mask)

            # Compute loss
            loss = criterion(logits, tgt_output)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()

            # Gradient clipping (prevent exploding gradients)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            # Update weights and learning rate
            optimizer.step()
            scheduler.step()

            total_loss += loss.item()

            # Logging
            if (batch_idx + 1) % 100 == 0:
                avg_loss = total_loss / (batch_idx + 1)
                lr = optimizer.param_groups[0]['lr']
                print(f"Epoch {epoch}, Batch {batch_idx}, "
                      f"Loss: {avg_loss:.4f}, LR: {lr:.6f}")

        # Validation
        val_loss = evaluate(model, val_loader, criterion, device)
        print(f"Epoch {epoch}, Val Loss: {val_loss:.4f}")

Optimization Techniques

Gradient Clipping

Purpose: Prevent exploding gradients (common in RNNs, less common in transformers but still useful).

# Clip gradients by norm
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# Or clip by value
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)

Typical value: max_norm=1.0 for transformers

Gradient Accumulation

Purpose: Simulate larger batch sizes when GPU memory is limited.

accumulation_steps = 4  # Effective batch size = batch_size * 4

for i, batch in enumerate(dataloader):
    loss = compute_loss(model(batch))
    loss = loss / accumulation_steps  # Normalize loss
    loss.backward()                   # Accumulate gradients

    if (i + 1) % accumulation_steps == 0:
        optimizer.step()        # Update weights
        optimizer.zero_grad()   # Reset gradients
        scheduler.step()        # Update LR

Why normalize the loss? Gradients from successive backward() calls are summed, so dividing each micro-batch loss by accumulation_steps keeps the accumulated gradient equal to what a single batch of accumulation_steps × batch_size examples would produce.
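A quick way to convince yourself of this is to compare the accumulated gradient against the full-batch gradient on a toy loss (the model and values here are illustrative):

import torch

torch.manual_seed(0)
w = torch.randn(3, requires_grad=True)
data = torch.randn(8, 3)              # one "large" batch of 8 examples

def loss_fn(x, w):
    return ((x @ w) ** 2).mean()      # toy loss

# Gradient of the full batch in one pass
loss_fn(data, w).backward()
full_grad = w.grad.clone()

# Same batch as 4 accumulated micro-batches of 2, each loss divided by 4
w.grad = None
for chunk in data.split(2):
    (loss_fn(chunk, w) / 4).backward()

print(torch.allclose(full_grad, w.grad, atol=1e-6))   # True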

Mixed Precision Training (FP16)

Purpose: Train faster with half precision (FP16) while maintaining stability.

from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()  # Handles loss scaling

for batch in dataloader:
    optimizer.zero_grad()

    # Forward pass in FP16
    with autocast():
        logits = model(batch)
        loss = criterion(logits, targets)

    # Backward pass with loss scaling
    scaler.scale(loss).backward()

    # Unscale before clipping
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    # Update with scaled gradients
    scaler.step(optimizer)
    scaler.update()
    scheduler.step()

Benefits: typically 2-3× faster training and roughly 50% less memory

Hyperparameters from “Attention Is All You Need”

Base model:

  • Model dimension: d_model = 512
  • Feed-forward dimension: d_ff = 2048 (4× expansion)
  • Number of attention heads: h = 8
  • Number of encoder/decoder layers: N = 6
  • Dropout: p = 0.1
  • Warmup steps: 4000
  • Label smoothing: ε = 0.1
  • Batch size: ~25,000 tokens per batch
  • Adam parameters: β₁ = 0.9, β₂ = 0.98, ε = 10⁻⁹

Big model (better performance):

  • d_model = 1024, d_ff = 4096, h = 16
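For convenience, here is a minimal sketch collecting these hyperparameters into a config object; the dataclass and field names are my own, not taken from any library:

from dataclasses import dataclass

@dataclass
class TransformerConfig:
    # "Attention Is All You Need" base model (big-model values in comments)
    d_model: int = 512            # 1024 for the big model
    d_ff: int = 2048              # 4096 for the big model
    num_heads: int = 8            # 16 for the big model
    num_layers: int = 6
    dropout: float = 0.1
    warmup_steps: int = 4000
    label_smoothing: float = 0.1
    tokens_per_batch: int = 25_000
    adam_betas: tuple = (0.9, 0.98)
    adam_eps: float = 1e-9

base_config = TransformerConfig()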

Common Training Issues

Loss explodes early in training

  • Symptoms: NaN loss within first few batches
  • Solutions:
    • Increase warmup steps (try 8000 instead of 4000)
    • Reduce initial learning rate
    • Check for bugs in mask implementation
    • Enable gradient clipping

Model doesn’t learn (loss plateaus immediately)

  • Symptoms: Loss doesn’t decrease from initial value
  • Solutions:
    • Verify masks are correct (especially causal mask for decoders)
    • Check data preprocessing (tokenization, padding)
    • Ensure sufficient model capacity
    • Verify optimizer is updating parameters

Out of memory (OOM)

  • Symptoms: CUDA OOM error during forward/backward pass
  • Solutions:
    • Reduce batch size
    • Use gradient accumulation to maintain effective batch size
    • Enable gradient checkpointing (trades compute for memory; see the sketch after this list)
    • Use mixed precision training (FP16)
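Gradient checkpointing is the one technique in this list not shown earlier. A minimal sketch using torch.utils.checkpoint; the per-layer loop assumes the model exposes its transformer blocks as a list, which is an illustrative assumption:

import torch.utils.checkpoint as checkpoint

def forward_with_checkpointing(x, layers):
    # Recompute each block's activations during backward instead of storing
    # them, trading extra compute for lower peak memory.
    for layer in layers:
        x = checkpoint.checkpoint(layer, x)
    return x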

Slow convergence

  • Symptoms: Model learns but very slowly
  • Solutions:
    • Increase batch size (larger batches → more stable gradients)
    • Tune learning rate (try different warmup steps)
    • Check data quality (noisy data slows learning)
    • Verify data is shuffled properly

Key Takeaways

  1. Masking is critical: Padding masks prevent attention to padding, causal masks enforce autoregressive property
  2. Label smoothing helps: Prevents overconfidence, improves generalization (ε = 0.1 is standard)
  3. Warmup is essential: Start with low LR, ramp up, then decay (typical: 4000-8000 steps)
  4. Gradient clipping recommended: Prevents instability (max_norm=1.0)
  5. Batch size matters: As large as memory allows (use gradient accumulation if needed)
  6. Mixed precision speeds up training: 2-3× faster with FP16
  7. Update LR every step: Not per epoch, per optimization step