Training Transformers
Training transformer models requires careful attention to masking strategies, loss functions, and learning rate schedules to achieve stable and effective learning. This guide covers the essential techniques from the “Attention Is All You Need” paper and modern best practices.
Masking Strategies
Transformers use different types of masks to control attention patterns and prevent information leakage.
Padding Mask
Purpose: Prevent attention to padding tokens (used to batch variable-length sequences).
Problem: Sequences in a batch have different lengths, so we pad shorter sequences:
Batch of sequences:
Sequence 1: "The cat sat" (3 tokens) + 2 padding
Sequence 2: "Hello world" (2 tokens) + 3 padding
Sequence 3: "A" (1 token) + 4 padding
Solution: Create a mask that marks padding positions as invalid for attention.
Implementation:
def create_padding_mask(seq, pad_token_id=0):
"""
Create mask for padding tokens
Args:
seq: Token IDs (batch, seq_len)
pad_token_id: ID used for padding (usually 0)
Returns:
mask: (batch, 1, 1, seq_len) for broadcasting across heads
"""
# Mark non-padding positions as 1, padding as 0
mask = (seq != pad_token_id).unsqueeze(1).unsqueeze(2)
# Shape: (batch, 1, 1, seq_len)
return mask
# Usage in scaled dot-product attention:
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
scores = scores.masked_fill(mask == 0, -1e9) # Set padding to -inf
attn = F.softmax(scores, dim=-1)  # softmax(-inf) ≈ 0
Why -1e9? After softmax, exp(-1e9) ≈ 0, so padding positions receive essentially zero attention weight.
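As a quick sanity check, here is the mask for a toy batch (an illustrative example reusing the create_padding_mask helper above, with pad_token_id=0 and torch already imported):

seq = torch.tensor([[5, 7, 9, 0],
                    [3, 0, 0, 0]])   # two sequences padded to length 4
mask = create_padding_mask(seq)      # shape: (2, 1, 1, 4)
print(mask.squeeze())
# tensor([[ True,  True,  True, False],
#         [ True, False, False, False]])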
Causal Mask (Look-Ahead Mask)
Purpose: Prevent decoder from attending to future positions during training (maintains autoregressive property).
Why needed? During training, we have the full target sequence, but we want to ensure the model learns to generate each token based only on previous tokens.
Causal mask structure:
Each position can attend to:
Position 0: [0] → only itself
Position 1: [0, 1] → 0 and itself
Position 2: [0, 1, 2] → 0, 1, and itself
Position 3: [0, 1, 2, 3] → all previous + itself
Implementation:
def create_causal_mask(seq_len):
"""
Create lower triangular mask for causal attention
Args:
seq_len: Sequence length
Returns:
mask: (seq_len, seq_len) lower triangular matrix
"""
# Create lower triangular matrix
mask = torch.tril(torch.ones(seq_len, seq_len))
# [[1, 0, 0, 0],
# [1, 1, 0, 0],
# [1, 1, 1, 0],
# [1, 1, 1, 1]]
return mask
# Equivalent using upper triangular:
mask = (torch.triu(torch.ones(seq_len, seq_len), diagonal=1) == 0)
Effect: Position i can only attend to positions j ≤ i, enforcing left-to-right information flow.
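A quick illustrative check that the mask enforces this (reusing create_causal_mask and the masked_fill pattern shown earlier):

scores = torch.randn(4, 4)
mask = create_causal_mask(4)
attn = F.softmax(scores.masked_fill(mask == 0, -1e9), dim=-1)
# Row 0 places all of its weight on position 0, row 1 spreads weight over
# positions 0-1, and so on; entries above the diagonal come out (numerically) zero.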
Combined Mask
In practice, decoders need both padding and causal masks:
def create_combined_mask(tgt_seq, pad_token_id=0):
"""
Combine causal and padding masks for decoder self-attention
Args:
tgt_seq: Target sequence (batch, seq_len)
pad_token_id: Padding token ID
Returns:
mask: Combined mask (batch, 1, seq_len, seq_len)
"""
batch_size, seq_len = tgt_seq.size()
# Padding mask: (batch, 1, 1, seq_len)
padding_mask = create_padding_mask(tgt_seq, pad_token_id)
# Causal mask: (1, 1, seq_len, seq_len)
causal_mask = create_causal_mask(seq_len)
causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
# Combine: BOTH conditions must be satisfied (logical AND)
    mask = padding_mask & causal_mask.to(tgt_seq.device).bool()  # cast to bool so the logical AND is well-defined
    return mask
When to use which mask (a short sketch follows the list):
- Encoder self-attention: Padding mask only (bidirectional)
- Decoder self-attention: Padding + causal mask (unidirectional)
- Cross-attention (decoder attending to encoder): Padding mask on source only
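The sketch below ties this list together using the helpers defined earlier; the model call at the end is a hypothetical interface, not a specific library API:

def make_transformer_masks(src, tgt, pad_token_id=0):
    """Build the attention masks for an encoder-decoder transformer."""
    # Encoder self-attention and cross-attention: mask source padding only
    src_mask = create_padding_mask(src, pad_token_id)    # (batch, 1, 1, src_len)
    # Decoder self-attention: padding + causal
    tgt_mask = create_combined_mask(tgt, pad_token_id)   # (batch, 1, tgt_len, tgt_len)
    return src_mask, tgt_mask

# Hypothetical usage:
# src_mask, tgt_mask = make_transformer_masks(src, tgt_input)
# logits = model(src, tgt_input, src_mask, tgt_mask)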
Loss Functions
Transformers typically use cross-entropy loss over the vocabulary for each position.
Standard Cross-Entropy
For sequence-to-sequence tasks, the loss sums over all positions:
$$\mathcal{L} = -\sum_{t=1}^{T} \log P(y_t \mid y_{<t}, x)$$

Where:
- $y_t$: Target token at position $t$
- $y_{<t}$: All previous target tokens
- $x$: Source sequence (for encoder-decoder)
Implementation:
def compute_loss(logits, targets, pad_token_id=0):
"""
Compute cross-entropy loss, ignoring padding positions
Args:
logits: Model output (batch, seq_len, vocab_size)
targets: Target tokens (batch, seq_len)
pad_token_id: Padding token ID to ignore
Returns:
loss: Scalar loss value
"""
batch_size, seq_len, vocab_size = logits.size()
# Reshape for cross-entropy
# logits: (batch * seq_len, vocab_size)
# targets: (batch * seq_len)
logits_flat = logits.view(-1, vocab_size)
targets_flat = targets.view(-1)
# Compute cross-entropy (automatically ignores pad_token_id)
loss = F.cross_entropy(
logits_flat,
targets_flat,
ignore_index=pad_token_id, # Don't compute loss for padding
reduction='mean'
)
    return loss
Label Smoothing
Purpose: Regularization technique to prevent overconfidence and improve generalization.
Problem with hard targets: Model becomes overconfident
Target: [0, 0, 1, 0, 0] # One-hot for token ID 2
Model learns to output: [0.001, 0.001, 0.998, 0.0, 0.0]  # Overconfident
Solution with label smoothing: Distribute some probability mass to other tokens
Smoothed target: [0.02, 0.02, 0.92, 0.02, 0.02] # ε = 0.1
Model learns to output: [0.05, 0.05, 0.85, 0.03, 0.02]  # More calibrated
Formula:

$$q_i = (1 - \varepsilon) \cdot \mathbb{1}[i = y] + \frac{\varepsilon}{V}$$

Where $\varepsilon$ is the smoothing parameter (e.g., 0.1) and $V$ is the vocabulary size.
Implementation:
class LabelSmoothingLoss(nn.Module):
def __init__(self, vocab_size, smoothing=0.1, pad_token_id=0):
super().__init__()
self.vocab_size = vocab_size
self.smoothing = smoothing
self.confidence = 1.0 - smoothing
self.pad_token_id = pad_token_id
def forward(self, logits, targets):
"""
Args:
logits: (batch, seq_len, vocab_size)
targets: (batch, seq_len)
"""
batch_size, seq_len, vocab_size = logits.size()
# Flatten
logits = logits.view(-1, vocab_size)
targets = targets.view(-1)
# Log probabilities
log_probs = F.log_softmax(logits, dim=-1)
# Create smoothed target distribution
true_dist = torch.zeros_like(log_probs)
# Distribute smoothing probability among non-target tokens
true_dist.fill_(self.smoothing / (vocab_size - 2)) # -2 for target and pad
# Set confidence for true class
true_dist.scatter_(1, targets.unsqueeze(1), self.confidence)
# Don't compute loss for padding
true_dist[:, self.pad_token_id] = 0
mask = (targets != self.pad_token_id).float()
# KL divergence between smoothed targets and predictions
loss = -(true_dist * log_probs).sum(dim=-1)
loss = (loss * mask).sum() / mask.sum()
return loss
# Usage:
criterion = LabelSmoothingLoss(vocab_size=50000, smoothing=0.1)
loss = criterion(logits, targets)
Benefits of label smoothing:
- Prevents overconfidence
- Improves calibration
- Better generalization (especially for BLEU scores in translation)
- Recommended in original transformer paper
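If the custom class above is more than you need, recent PyTorch versions (1.10+) expose label smoothing directly in the cross-entropy API. Note that the built-in version smooths over the full vocabulary, so its value differs slightly from the class above, which zeroes the padding column:

# Built-in label smoothing, with flattened logits/targets as in compute_loss
loss = F.cross_entropy(
    logits.view(-1, logits.size(-1)),
    targets.view(-1),
    ignore_index=pad_token_id,
    label_smoothing=0.1,
)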
Learning Rate Schedules
Transformers require careful learning rate scheduling with warmup followed by decay.
Warmup + Inverse Square Root Decay
The original “Attention Is All You Need” paper uses:
$$lr = d_{model}^{-0.5} \cdot \min\left(step^{-0.5},\ step \cdot warmup\_steps^{-1.5}\right)$$

Behavior:
- Warmup phase (steps 0 to warmup_steps): linear increase from 0
- Decay phase (after warmup_steps): inverse square root decay
Why warmup?
- At initialization, parameters are random → gradients are large and noisy
- Large learning rate at start causes instability and divergence
- Warmup allows model to “settle” into a reasonable basin before aggressive learning
Visualization:
Learning Rate
│ ╱‾‾‾‾‾‾⤵
│ ╱ ⤵
│ ╱ ⤵
│ ╱ ⤵
│ ╱ ⤵
│╱ ⤵___
└────────────────────────────> Steps
↑
warmup_steps (4000)
Implementation:
class TransformerLRScheduler:
def __init__(self, optimizer, d_model, warmup_steps=4000):
"""
Learning rate scheduler from "Attention Is All You Need"
Args:
optimizer: PyTorch optimizer
d_model: Model dimension (e.g., 512)
warmup_steps: Number of warmup steps (e.g., 4000)
"""
self.optimizer = optimizer
self.d_model = d_model
self.warmup_steps = warmup_steps
self.current_step = 0
def step(self):
"""Update learning rate"""
self.current_step += 1
# Compute learning rate using formula
lr = (self.d_model ** -0.5) * min(
self.current_step ** -0.5,
self.current_step * (self.warmup_steps ** -1.5)
)
# Update all parameter groups
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
return lr
# Usage:
optimizer = torch.optim.Adam(
model.parameters(),
lr=1.0, # Will be controlled by scheduler
betas=(0.9, 0.98),
eps=1e-9
)
scheduler = TransformerLRScheduler(optimizer, d_model=512, warmup_steps=4000)
for epoch in range(num_epochs):
for batch in dataloader:
optimizer.zero_grad()
loss = compute_loss(model(batch))
loss.backward()
optimizer.step()
        scheduler.step()  # Update LR every step (not epoch!)
Alternative: Cosine Annealing with Warmup
Modern transformers (e.g., GPT-style language models) often use cosine annealing with warmup:
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
# Warmup phase: linear increase
warmup_scheduler = LinearLR(
optimizer,
start_factor=0.1, # Start at 10% of base LR
total_iters=warmup_steps
)
# Decay phase: cosine annealing to 0
cosine_scheduler = CosineAnnealingLR(
optimizer,
T_max=total_steps - warmup_steps,
eta_min=0
)
# Combine sequentially
scheduler = SequentialLR(
optimizer,
schedulers=[warmup_scheduler, cosine_scheduler],
milestones=[warmup_steps]
)
Comparison (a small numeric sketch follows this list):
- Inverse sqrt decay: Slower decay, never reaches zero, used in original transformer
- Cosine annealing: Faster decay, reaches zero at end, popular in modern LLMs
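To make the comparison concrete, both schedules can be sampled as plain functions; the sketch below is illustrative, and base_lr and total_steps are arbitrary example values:

import math

def inverse_sqrt_lr(step, d_model=512, warmup_steps=4000):
    step = max(step, 1)
    return (d_model ** -0.5) * min(step ** -0.5, step * warmup_steps ** -1.5)

def cosine_lr(step, base_lr=1e-3, warmup_steps=4000, total_steps=100_000):
    if step < warmup_steps:
        return base_lr * step / warmup_steps                   # linear warmup
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return 0.5 * base_lr * (1 + math.cos(math.pi * progress))  # decays to 0

for step in (1_000, 4_000, 20_000, 100_000):
    print(step, f"{inverse_sqrt_lr(step):.2e}", f"{cosine_lr(step):.2e}")
# The inverse-sqrt schedule flattens out but never reaches zero;
# cosine annealing hits zero exactly at total_steps.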
See Language Model Training for more on modern LR schedules.
Complete Training Loop
Putting it all together:
def train_transformer(model, train_loader, val_loader, num_epochs=10):
"""
Train transformer model with all best practices
Args:
model: Transformer model (encoder-decoder or decoder-only)
train_loader: Training data loader
val_loader: Validation data loader
num_epochs: Number of training epochs
"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
# Optimizer (Note: lr=1.0, will be controlled by scheduler)
optimizer = torch.optim.Adam(
model.parameters(),
lr=1.0,
betas=(0.9, 0.98), # Transformer paper values
eps=1e-9
)
# Learning rate scheduler
scheduler = TransformerLRScheduler(
optimizer,
d_model=model.d_model,
warmup_steps=4000
)
# Loss function with label smoothing
criterion = LabelSmoothingLoss(
vocab_size=model.tgt_vocab_size,
smoothing=0.1
)
for epoch in range(num_epochs):
model.train()
total_loss = 0
for batch_idx, (src, tgt) in enumerate(train_loader):
src = src.to(device)
tgt = tgt.to(device)
# Prepare decoder input (shift target right by 1)
tgt_input = tgt[:, :-1] # Remove last token
tgt_output = tgt[:, 1:] # Remove first token (usually <BOS>)
# Create masks
src_mask = create_padding_mask(src)
tgt_mask = create_combined_mask(tgt_input)
# Forward pass
logits = model(src, tgt_input, src_mask, tgt_mask)
# Compute loss
loss = criterion(logits, tgt_output)
# Backward pass
optimizer.zero_grad()
loss.backward()
# Gradient clipping (prevent exploding gradients)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# Update weights and learning rate
optimizer.step()
scheduler.step()
total_loss += loss.item()
# Logging
if (batch_idx + 1) % 100 == 0:
avg_loss = total_loss / (batch_idx + 1)
lr = optimizer.param_groups[0]['lr']
print(f"Epoch {epoch}, Batch {batch_idx}, "
f"Loss: {avg_loss:.4f}, LR: {lr:.6f}")
# Validation
val_loss = evaluate(model, val_loader, criterion, device)
print(f"Epoch {epoch}, Val Loss: {val_loss:.4f}")Optimization Techniques
Gradient Clipping
Purpose: Prevent exploding gradients (common in RNNs, less common in transformers but still useful).
# Clip gradients by norm
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# Or clip by value
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)
Typical value: max_norm=1.0 for transformers
Gradient Accumulation
Purpose: Simulate larger batch sizes when GPU memory is limited.
accumulation_steps = 4 # Effective batch size = batch_size * 4
for i, batch in enumerate(dataloader):
loss = compute_loss(model(batch))
loss = loss / accumulation_steps # Normalize loss
loss.backward() # Accumulate gradients
if (i + 1) % accumulation_steps == 0:
optimizer.step() # Update weights
optimizer.zero_grad() # Reset gradients
        scheduler.step()  # Update LR
Why normalize loss? So gradient magnitude is independent of accumulation_steps.
Mixed Precision Training (FP16)
Purpose: Train faster with half precision (FP16) while maintaining stability.
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler() # Handles loss scaling
for batch in dataloader:
optimizer.zero_grad()
# Forward pass in FP16
with autocast():
logits = model(batch)
loss = criterion(logits, targets)
# Backward pass with loss scaling
scaler.scale(loss).backward()
# Unscale before clipping
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# Update with scaled gradients
scaler.step(optimizer)
scaler.update()
    scheduler.step()
Benefits: 2-3× faster training, 50% less memory
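On hardware that supports bfloat16, a common variant is to run autocast with dtype=torch.bfloat16; because bfloat16 keeps the FP32 exponent range, it generally does not need a GradScaler. A sketch under that assumption, reusing the same model/criterion/scheduler names:

# bfloat16 variant: no loss scaling required
for batch in dataloader:
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        logits = model(batch)
        loss = criterion(logits, targets)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    scheduler.step()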
Hyperparameters from “Attention Is All You Need”
Base model:
- Model dimension: d_model = 512
- Feed-forward dimension: d_ff = 2048 (4× expansion)
- Number of attention heads: h = 8
- Number of encoder/decoder layers: N = 6
- Dropout: P_drop = 0.1
- Warmup steps: 4000
- Label smoothing: ε = 0.1
- Batch size: ~25,000 tokens per batch
- Adam parameters: β1 = 0.9, β2 = 0.98, ε = 1e-9
Big model (better performance):
- d_model = 1024, d_ff = 4096, h = 16
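For convenience, these settings can be collected into a small config object; the TransformerConfig class below is a hypothetical sketch, not part of any library:

from dataclasses import dataclass

@dataclass
class TransformerConfig:
    # "Attention Is All You Need" base model defaults
    d_model: int = 512
    d_ff: int = 2048
    num_heads: int = 8
    num_layers: int = 6
    dropout: float = 0.1
    warmup_steps: int = 4000
    label_smoothing: float = 0.1
    adam_betas: tuple = (0.9, 0.98)
    adam_eps: float = 1e-9

# Big model variant
big = TransformerConfig(d_model=1024, d_ff=4096, num_heads=16)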
Common Training Issues
Loss explodes early in training
- Symptoms: NaN loss within first few batches
- Solutions:
- Increase warmup steps (try 8000 instead of 4000)
- Reduce initial learning rate
- Check for bugs in mask implementation
- Enable gradient clipping (a minimal skip-batch guard is sketched after this list)
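One cheap guard while debugging is to skip any batch whose loss comes out non-finite instead of letting it corrupt the weights; a minimal sketch to drop into the training loop above:

loss = criterion(logits, tgt_output)
if not torch.isfinite(loss):
    print("Non-finite loss detected, skipping batch")
    optimizer.zero_grad()
    continue                 # requires being inside the batch loop
loss.backward()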
Model doesn’t learn (loss plateaus immediately)
- Symptoms: Loss doesn’t decrease from initial value
- Solutions:
- Verify masks are correct (especially causal mask for decoders)
- Check data preprocessing (tokenization, padding)
- Ensure sufficient model capacity
- Verify optimizer is updating parameters
Out of memory (OOM)
- Symptoms: CUDA OOM error during forward/backward pass
- Solutions:
- Reduce batch size
- Use gradient accumulation to maintain effective batch size
- Enable gradient checkpointing (trades compute for memory; see the sketch after this list)
- Use mixed precision training (FP16)
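Gradient checkpointing recomputes activations during the backward pass instead of storing them. A rough sketch using torch.utils.checkpoint, assuming each layer is callable as layer(x, mask):

from torch.utils.checkpoint import checkpoint

def forward_with_checkpointing(layers, x, mask=None):
    # Recompute each layer's activations in backward instead of caching them
    for layer in layers:
        x = checkpoint(layer, x, mask, use_reentrant=False)
    return x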
Slow convergence
- Symptoms: Model learns but very slowly
- Solutions:
- Increase batch size (larger batches → more stable gradients)
- Tune learning rate (try different warmup steps)
- Check data quality (noisy data slows learning)
- Verify data is shuffled properly
Key Takeaways
- Masking is critical: Padding masks prevent attention to padding, causal masks enforce autoregressive property
- Label smoothing helps: Prevents overconfidence, improves generalization (ε = 0.1 is standard)
- Warmup is essential: Start with low LR, ramp up, then decay (typical: 4000-8000 steps)
- Gradient clipping recommended: Prevents instability (max_norm=1.0)
- Batch size matters: As large as memory allows (use gradient accumulation if needed)
- Mixed precision speeds up training: 2-3× faster with FP16
- Update LR every step: Not per epoch, per optimization step
Related Concepts
- Attention Is All You Need - The original transformer paper with training details
- Attention Mechanism - Understanding what you’re training
- Language Model Training - Modern LM training techniques
- Optimization - Optimizer fundamentals
- Training Practices - Debugging and best practices