Regularization and L2 Weight Decay
The goal of machine learning is not just to fit the training data well, but to generalize to new, unseen data. Regularization techniques help prevent overfitting and improve generalization.
Understanding Overfitting
The Three Datasets
Training set: Data used to compute gradients and update parameters
Validation set: Data used to:
- Tune hyperparameters (learning rate, regularization, architecture)
- Monitor overfitting
- Decide when to stop training
Test set: Data used only once at the end to estimate final performance
- Never use for any decisions during training!
- Simulates deployment on truly new data
Never use test data for any training decisions. The test set should be locked away until final evaluation. If you tune hyperparameters based on test performance, you’re effectively training on the test set!
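As a minimal sketch of how such a split might be produced, assuming a NumPy feature matrix `X` and label vector `y` (the function name and the 80/10/10 proportions are only illustrative):

```python
import numpy as np

def train_val_test_split(X, y, val_frac=0.1, test_frac=0.1, seed=0):
    """Randomly split (X, y) into training, validation, and test subsets."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(X))          # shuffle example indices
    n_test = int(len(X) * test_frac)
    n_val = int(len(X) * val_frac)
    test_idx = idx[:n_test]
    val_idx = idx[n_test:n_test + n_val]
    train_idx = idx[n_test + n_val:]
    return (X[train_idx], y[train_idx],
            X[val_idx], y[val_idx],
            X[test_idx], y[test_idx])
```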
Signs of Overfitting
Overfitting indicators:
- Training accuracy continues improving
- Validation accuracy plateaus or decreases
- Large gap between training and validation error
Underfitting indicators:
- Both training and validation accuracy are low
- Model is too simple to capture patterns
import matplotlib.pyplot as plt

# Plot learning curves (the *_history lists are collected during training)
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(train_loss_history, label='Train')
plt.plot(val_loss_history, label='Validation')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Loss Curves')
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(train_acc_history, label='Train')
plt.plot(val_acc_history, label='Validation')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Accuracy Curves')
plt.legend()
plt.tight_layout()
plt.show()

L2 Regularization (Weight Decay)
Add a penalty term to the loss function that discourages large weights:

$$L = L_{\text{data}} + \frac{\lambda}{2} \sum_i w_i^2$$

where $\lambda$ is the regularization strength (hyperparameter).
Why it works:
- Large weights = more complex function = more likely to overfit
- Small weights = smoother function = better generalization
- Geometric interpretation: equivalent to constraining the weights to lie inside a sphere (an L2 ball) around the origin
Gradient of L2 regularization:

$$\frac{\partial}{\partial w_i} \left( \frac{\lambda}{2} \sum_j w_j^2 \right) = \lambda w_i$$

Update rule (called “weight decay”):

$$w \leftarrow w - \eta \left( \nabla_w L_{\text{data}} + \lambda w \right) = (1 - \eta \lambda)\, w - \eta \nabla_w L_{\text{data}}$$

Note the $(1 - \eta\lambda)$ factor: the weights are shrunk toward zero at every step!
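Before the full implementation below, here is a minimal sketch of the update rule in isolation, assuming plain SGD on a NumPy weight array (the function name is illustrative):

```python
import numpy as np

def sgd_step_with_weight_decay(W, dW_data, lr=1e-3, reg=1e-4):
    """One SGD step with L2 weight decay.

    W       : weight array
    dW_data : gradient of the data loss with respect to W
    lr      : learning rate (eta)
    reg     : regularization strength (lambda)
    """
    # Equivalent forms: W - lr * (dW_data + reg * W) == (1 - lr * reg) * W - lr * dW_data
    return (1.0 - lr * reg) * W - lr * dW_data
```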
Implementation
# Methods added to a two-layer network class such as the TwoLayerNet used below
# (assumes `import numpy as np`).
def loss_with_regularization(self, X, y, reg=0.0):
    """
    Compute loss with L2 regularization.

    Args:
        X: Input data (N, D)
        y: Labels (N,)
        reg: Regularization strength (lambda)

    Returns:
        loss: Total loss (data loss + regularization loss)
    """
    # Compute data loss (e.g., cross-entropy)
    scores, cache = self.forward(X)
    data_loss = self.softmax_loss(scores, y)

    # Compute regularization loss: 0.5 * lambda * sum of squared weights
    reg_loss = 0.5 * reg * (
        np.sum(self.params['W1'] ** 2) +
        np.sum(self.params['W2'] ** 2)
    )

    total_loss = data_loss + reg_loss
    return total_loss

def backward_with_regularization(self, X, y, cache, reg=0.0):
    """Backward pass with regularization gradient."""
    # Compute gradients from data loss
    grads = self.backward(X, y, cache)

    # Add regularization gradient: d/dW (0.5 * reg * W^2) = reg * W
    grads['W1'] += reg * self.params['W1']
    grads['W2'] += reg * self.params['W2']
    # Note: Don't regularize biases (common practice)

    return grads

Typical values: $\lambda$ is usually searched on a logarithmic scale, roughly $10^{-5}$ to $10^{-2}$ (see the tuning example below); $\lambda = 0$ disables regularization.
L1 Regularization
Penalty based on absolute values:

$$L = L_{\text{data}} + \lambda \sum_i |w_i|$$
Effect: Encourages sparsity (many weights become exactly zero)
Use case: Feature selection, when you want interpretable models
Comparison:
- L2: Prefers many small weights
- L1: Prefers few large weights (sparsity)
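A minimal NumPy sketch of the L1 penalty and its (sub)gradient, assuming a dictionary of weight arrays like the `self.params` used elsewhere in this chapter (the function name is illustrative):

```python
import numpy as np

def l1_penalty_and_grad(params, reg):
    """L1 penalty and its subgradient for every weight array in params.

    params : dict of NumPy arrays, e.g. {'W1': ..., 'W2': ...}
    reg    : regularization strength (lambda)
    """
    penalty = reg * sum(np.sum(np.abs(W)) for W in params.values())
    # The subgradient of |w| is sign(w); at w = 0 any value in [-1, 1] is valid.
    grads = {name: reg * np.sign(W) for name, W in params.items()}
    return penalty, grads
```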
Early Stopping
Stop training when validation performance stops improving:
best_val_acc = 0
patience = 5  # Number of epochs to wait
patience_counter = 0

for epoch in range(max_epochs):
    # Train for one epoch
    train_one_epoch()

    # Evaluate on validation set
    val_acc = evaluate(X_val, y_val)

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        patience_counter = 0
        save_checkpoint()  # Save best model
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch}")
            break

Benefits:
- Simple and effective
- Prevents overfitting by stopping before memorization
- Saves training time
Data Augmentation
Create more training examples by applying transformations:
For Images:
- Rotation, flipping, cropping
- Color jittering, brightness adjustment
- Random erasing, cutout
For Text:
- Synonym replacement
- Back-translation
- Random insertion/deletion
For Audio:
- Time stretching
- Pitch shifting
- Adding noise
Benefits:
- Artificially increases dataset size
- Teaches invariances (e.g., rotation-invariant object recognition)
- Very effective, especially for small datasets
from torchvision import transforms
# Image augmentation example
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
])

Regularization Strategy
Step-by-Step Approach
Step 1: Train without regularization
- Establish a baseline
- See if model can fit training data
Step 2: Check for overfitting
- Plot train vs. validation curves
- Large gap? → Add regularization
Step 3: Apply regularization
- Start with L2 regularization (e.g., $\lambda = 10^{-4}$)
- Add dropout if still overfitting
- Consider data augmentation if applicable
Step 4: Tune regularization strength
- Grid search over $\lambda$, e.g. $\{0, 10^{-5}, 10^{-4}, 10^{-3}, 10^{-2}\}$ as in the example below
- Use validation set to select best values
Recommended regularization stack:
- L2 regularization: a small $\lambda$ tuned on the validation set (e.g., around $10^{-4}$)
- Dropout in fully-connected layers: drop probability around $p = 0.5$
- Data augmentation (if applicable)
- Early stopping based on validation performance
This combination works well for most problems.
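As an illustrative sketch only (this chapter's own examples use a NumPy TwoLayerNet; here PyTorch's nn.Dropout and the weight_decay argument of torch.optim.SGD are used, and the layer sizes and class count are placeholders), the stack might be wired together like this:

```python
import torch
import torch.nn as nn

# Hypothetical model with dropout in the fully-connected layer.
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(3 * 224 * 224, 256),   # matches the 224x224 crops above
    nn.ReLU(),
    nn.Dropout(p=0.5),               # dropout in the fully-connected layers
    nn.Linear(256, 10),              # placeholder: 10 output classes
)

# weight_decay applies an L2 penalty (lambda) to every parameter update.
# Note: unlike the manual NumPy implementation above, this also decays biases.
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, weight_decay=1e-4)
```

Data augmentation is handled by a transform pipeline like the one shown earlier, and early stopping by the validation loop above.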
Tuning Regularization
# Compare different regularization strengths
# (assumes TwoLayerNet, train, and the data splits from the sections above).
import matplotlib.pyplot as plt

regularization_strengths = [0, 1e-5, 1e-4, 1e-3, 1e-2]
results = {}

for reg in regularization_strengths:
    net = TwoLayerNet(input_dim, hidden_dim, output_dim)
    history = train(net, X_train, y_train, X_val, y_val,
                    learning_rate=1e-3, reg=reg, num_epochs=100)
    results[f'reg={reg}'] = history

# Plot overfitting behavior
plt.figure(figsize=(15, 5))

plt.subplot(1, 2, 1)
for name, history in results.items():
    plt.plot(history['train_acc_history'], label=f'{name} (train)')
plt.xlabel('Epoch')
plt.ylabel('Training Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
for name, history in results.items():
    plt.plot(history['val_acc_history'], label=f'{name} (val)')
plt.xlabel('Epoch')
plt.ylabel('Validation Accuracy')
plt.legend()

plt.tight_layout()
plt.show()

Expected observations:
- No regularization ($\lambda = 0$): Large train-val gap
- Too much regularization: Both train and val accuracy low (underfitting)
- Optimal regularization: Small train-val gap, best val accuracy
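A small follow-up sketch for picking the winner from the sweep above, reusing the results dictionary and its 'val_acc_history' entries:

```python
# Select the strength whose best validation accuracy is highest
best_name = max(results, key=lambda name: max(results[name]['val_acc_history']))
best_acc = max(results[best_name]['val_acc_history'])
print(f"Best setting: {best_name} (best val acc = {best_acc:.3f})")
```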
Learning Resources
Reading
- CS231n: Neural Networks Part 2 (regularization section)
- Nielsen: Neural Networks and Deep Learning - Chapter 3
Related Concepts
- Dropout - Powerful regularization technique
- Bias-Variance Tradeoff - Understanding overfitting
- Batch Normalization - Also acts as regularization
- Optimization - Training algorithms
Next Steps
- Learn about dropout for stronger regularization
- Understand the bias-variance tradeoff
- Explore batch normalization
- Apply in MNIST example