Masked Prediction
Masked prediction is a self-supervised learning paradigm in which models learn by predicting hidden portions of their input. Because the hidden content can only be recovered from the surrounding context, the model is forced to develop rich contextual representations that capture the relationships and structure in the data.
Core Concept
Fundamental Idea: Hide parts of the input and train the model to reconstruct them from context.
Learning Signal: Reconstruction error provides supervision without requiring labels.
Key Insight: To predict masked content accurately, the model must learn semantic relationships and contextual dependencies.
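A minimal, framework-agnostic sketch of this recipe is shown below (illustrative only: `model` is assumed to map a batch of token ids to per-position logits, and `mask_token_id` is a reserved id for hidden positions; neither comes from a specific library):

```python
import torch
import torch.nn.functional as F

def masked_prediction_loss(model, inputs, mask_token_id, mask_prob=0.15):
    """One masked-prediction step: hide random positions, supervise only those positions."""
    # Choose ~15% of positions to hide
    mask = torch.rand(inputs.shape, device=inputs.device) < mask_prob
    corrupted = inputs.clone()
    corrupted[mask] = mask_token_id              # replace hidden positions with a mask id
    logits = model(corrupted)                    # [batch, seq_len, vocab_size]
    # Reconstruction error on the hidden positions is the only learning signal
    return F.cross_entropy(logits[mask], inputs[mask])
```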
BERT: Masked Language Modeling
BERT (Bidirectional Encoder Representations from Transformers) revolutionized NLP by introducing masked language modeling as a pre-training objective.
The Masking Strategy
Randomly mask 15% of tokens in a sequence and predict them based on surrounding context:
# Original sentence
tokens = ["The", "patient", "presented", "with", "severe", "chest", "pain"]
# Mask 15% of tokens randomly
masked_tokens = ["The", "patient", "[MASK]", "with", "severe", "chest", "pain"]
# Model predicts: "presented"
predictions = bert(masked_tokens)
loss = cross_entropy(predictions[2], tokens[2])  # target is the original token "presented"

BERT’s 80-10-10 Rule
For each selected token (15% of all tokens), BERT applies:
- 80% of the time: Replace with the [MASK] token
- 10% of the time: Replace with a random token
- 10% of the time: Keep the original token unchanged

Why this mixed strategy?
- Prevents overfitting to the [MASK] token
- Forces the model to work with noisy/imperfect inputs
- The model must always be prepared to reconstruct, even when it sees real tokens
- Bridges the gap between pre-training and fine-tuning (no [MASK] tokens appear in downstream tasks)
BERT Architecture
import torch
import torch.nn as nn
class BERT(nn.Module):
"""Bidirectional Encoder with masked language modeling"""
def __init__(self, vocab_size, hidden_size=768, num_layers=12, num_heads=12):
super().__init__()
# Token and position embeddings
self.token_embedding = nn.Embedding(vocab_size, hidden_size)
self.position_embedding = nn.Embedding(512, hidden_size) # Max 512 tokens
# Stack of bidirectional transformer encoders
self.transformer_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=hidden_size,
                nhead=num_heads,
                dim_feedforward=hidden_size * 4,
                batch_first=True  # expect [batch, seq, hidden] inputs
            )
for _ in range(num_layers)
])
self.layer_norm = nn.LayerNorm(hidden_size)
# Masked language modeling head
self.mlm_head = nn.Linear(hidden_size, vocab_size)
def forward(self, token_ids, masked_positions):
"""
Args:
token_ids: [batch_size, seq_len] - input tokens (some masked)
masked_positions: [batch_size, seq_len] - binary mask of positions to predict
Returns:
            predictions: [batch_size, seq_len, vocab_size] - logits for all positions
                (the training loop selects the masked positions when computing the loss)
"""
batch_size, seq_len = token_ids.shape
# Create embeddings
token_embed = self.token_embedding(token_ids)
positions = torch.arange(seq_len, device=token_ids.device)
position_embed = self.position_embedding(positions)
# Combine embeddings
x = token_embed + position_embed # [batch_size, seq_len, hidden_size]
# Pass through transformer layers
for layer in self.transformer_layers:
x = layer(x)
x = self.layer_norm(x)
# Predict all token positions
predictions = self.mlm_head(x) # [batch_size, seq_len, vocab_size]
return predictions
MASK_TOKEN_ID = 103  # id of the [MASK] token in the standard bert-base-uncased vocabulary

def create_masked_lm_batch(tokens, mask_prob=0.15, vocab_size=30000):
    """Create masked tokens following BERT's 80-10-10 strategy"""
masked_tokens = tokens.clone()
targets = tokens.clone()
# Sample 15% of positions to mask
mask_positions = torch.rand(tokens.shape) < mask_prob
for i in range(tokens.shape[0]):
for j in range(tokens.shape[1]):
if mask_positions[i, j]:
rand = torch.rand(1).item()
if rand < 0.8:
# 80%: Replace with [MASK] token
masked_tokens[i, j] = MASK_TOKEN_ID
elif rand < 0.9:
# 10%: Replace with random token
masked_tokens[i, j] = torch.randint(0, vocab_size, (1,)).item()
# else: 10% keep original token
return masked_tokens, targets, mask_positions
# Training loop
def train_bert_mlm(model, data_loader, epochs=10):
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model.train()
for epoch in range(epochs):
        for step, batch_tokens in enumerate(data_loader):
# Create masked batch
masked_tokens, targets, mask_positions = create_masked_lm_batch(
batch_tokens, mask_prob=0.15
)
# Forward pass
predictions = model(masked_tokens, mask_positions)
# Compute loss only on masked positions
predictions_masked = predictions[mask_positions]
targets_masked = targets[mask_positions]
loss = nn.CrossEntropyLoss()(predictions_masked, targets_masked)
# Backward pass
optimizer.zero_grad()
loss.backward()
optimizer.step()
if step % 100 == 0:
print(f"Epoch {epoch}, Step {step}, Loss: {loss.item():.4f}")
    return model

What BERT Learns
By predicting masked tokens, BERT learns:
- Semantic relationships: Words that often appear together
- Syntactic patterns: Grammatical structure and dependencies
- Contextual meaning: Word meanings depend on surrounding context
- Bidirectional context: Uses both left and right context (unlike GPT)
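These relationships can be probed directly with a pre-trained masked LM. Here is a minimal sketch using the Hugging Face `fill-mask` pipeline (the checkpoint name is just an example, and the exact top predictions depend on the model):

```python
from transformers import pipeline

# Ask a pre-trained BERT-style model to fill in a masked token
unmasker = pipeline("fill-mask", model="bert-base-uncased")

for candidate in unmasker("The patient presented with severe chest [MASK]."):
    # Each candidate carries the predicted token and its probability
    print(f"{candidate['token_str']:>12}  score={candidate['score']:.3f}")
```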
MAE: Masked Autoencoders for Images
Masked Autoencoders (MAE) apply the masked prediction paradigm to computer vision with remarkable success.
Key Innovation: High Masking Ratio
MAE masks 75% of image patches - far higher than BERT’s 15%!
Why such aggressive masking?
- Lower ratios (~15%) can be solved via local interpolation without understanding
- 75% masking forces global semantic understanding
- Harder task leads to better learned representations
- Computational efficiency: encoder only processes 25% of patches
MAE Architecture
import torch.nn.functional as F

class MAE(nn.Module):
    """Masked Autoencoder for self-supervised vision pre-training.

    Note: PatchEmbed and VisionTransformer are assumed helper modules (e.g., in the
    spirit of the timm implementations); here they simply embed patches and run
    transformer blocks over a sequence of patch tokens.
    """
def __init__(self, img_size=224, patch_size=16, in_channels=3,
encoder_embed_dim=1024, encoder_depth=24, encoder_num_heads=16,
decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16):
super().__init__()
self.patch_size = patch_size
num_patches = (img_size // patch_size) ** 2
# Patch embedding
self.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=encoder_embed_dim
)
# Encoder: Deep and wide (processes only visible patches)
self.encoder = VisionTransformer(
embed_dim=encoder_embed_dim,
depth=encoder_depth,
num_heads=encoder_num_heads
)
# Decoder: Shallow and narrow (reconstructs all patches)
self.decoder = VisionTransformer(
embed_dim=decoder_embed_dim,
depth=decoder_depth,
num_heads=decoder_num_heads
)
# Learnable mask token for decoder
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
# Project encoder output to decoder dimension
self.encoder_to_decoder = nn.Linear(encoder_embed_dim, decoder_embed_dim)
# Reconstruct pixels from patches
self.decoder_pred = nn.Linear(
decoder_embed_dim,
patch_size ** 2 * in_channels
)
def forward(self, imgs, mask_ratio=0.75):
"""
Args:
imgs: [batch_size, 3, H, W] - input images
mask_ratio: float - fraction of patches to mask
Returns:
loss: reconstruction loss on masked patches
            pred: [batch, num_patches, patch_size**2 * in_channels] - reconstructed pixel patches
mask: binary mask showing which patches were masked
"""
# Patchify image
patches = self.patch_embed(imgs) # [batch, num_patches, patch_dim]
batch_size, num_patches, patch_dim = patches.shape
# Random masking
        shuffle_indices = torch.rand(batch_size, num_patches, device=imgs.device).argsort(dim=1)
num_keep = int(num_patches * (1 - mask_ratio))
# Indices of visible and masked patches
visible_indices = shuffle_indices[:, :num_keep]
masked_indices = shuffle_indices[:, num_keep:]
# Keep only visible patches for encoder (25% of patches)
visible_patches = torch.gather(
patches, dim=1,
index=visible_indices.unsqueeze(-1).expand(-1, -1, patch_dim)
)
# Encode visible patches
encoded = self.encoder(visible_patches) # [batch, num_keep, encoder_dim]
# Project to decoder dimension
encoded = self.encoder_to_decoder(encoded)
# Create mask tokens for masked positions
mask_tokens = self.mask_token.repeat(
batch_size, num_patches - num_keep, 1
)
# Combine encoded visible patches with mask tokens
# (Restore original order with unshuffle operation)
decoder_input = torch.cat([encoded, mask_tokens], dim=1)
decoder_input = self.unshuffle(decoder_input, shuffle_indices)
# Decode to reconstruct patches
decoded = self.decoder(decoder_input) # [batch, num_patches, decoder_dim]
pred_patches = self.decoder_pred(decoded) # [batch, num_patches, patch_dim]
        # Reconstruction target: raw pixel patches (not the embedded patches),
        # so the target matches the dimension of decoder_pred's output
        p = self.patch_size
        target_patches = imgs.unfold(2, p, p).unfold(3, p, p)      # [B, C, H/p, W/p, p, p]
        target_patches = target_patches.permute(0, 2, 3, 1, 4, 5)  # [B, H/p, W/p, C, p, p]
        target_patches = target_patches.reshape(batch_size, num_patches, -1)
        # Compute loss only on masked patches
        mask = torch.zeros(batch_size, num_patches, device=imgs.device)
        mask.scatter_(1, masked_indices, 1)
        loss = F.mse_loss(
            pred_patches[mask.bool()],
            target_patches[mask.bool()]
        )
return loss, pred_patches, mask
def unshuffle(self, x, shuffle_indices):
"""Restore original patch order"""
batch_size, num_patches, dim = x.shape
unshuffle_indices = torch.argsort(shuffle_indices, dim=1)
x = torch.gather(
x, dim=1,
index=unshuffle_indices.unsqueeze(-1).expand(-1, -1, dim)
)
        return x

Asymmetric Encoder-Decoder Design
MAE uses an asymmetric architecture for computational efficiency:
| Component | Depth | Width | Input | Purpose |
|---|---|---|---|---|
| Encoder | Deep (24 layers) | Wide (1024-dim) | Only visible patches (25%) | Learn rich representations |
| Decoder | Shallow (8 layers) | Narrow (512-dim) | All patches (with mask tokens) | Lightweight reconstruction |
Why asymmetric?
- Encoder efficiency: Only processes 25% of patches → 3-4x training speedup
- Decoder simplicity: Reconstruction is easier than representation learning
- Encoder quality: Deep encoder learns the representations we care about
- Decoder disposability: Decoder is discarded after pre-training, only encoder is used
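A rough back-of-the-envelope calculation (illustrative only; it counts roughly depth × sequence-length² × width for attention and ignores constants and MLP terms) shows why this asymmetry pays off:

```python
# Why the asymmetric design is cheap: the encoder's quadratic attention cost
# is paid over only 25% of the patches.
num_patches = 196                          # 224x224 image, 16x16 patches
visible = int(num_patches * 0.25)          # 49 visible patches for the encoder

encoder_cost = 24 * visible ** 2 * 1024           # deep, wide encoder on visible patches
decoder_cost = 8 * num_patches ** 2 * 512         # shallow, narrow decoder on all patches
full_encoder_cost = 24 * num_patches ** 2 * 1024  # hypothetical encoder over every patch

print(f"Encoder: {encoder_cost / full_encoder_cost:.1%} of a full-input encoder")
print(f"Decoder: {decoder_cost / full_encoder_cost:.1%} of a full-input encoder")
```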
MAE Training Process
1. Split image into patches (e.g., 224×224 image → 196 patches of 16×16 pixels)
2. Randomly mask 75% of patches (keep only 49 visible patches)
3. Encode only visible patches through deep ViT encoder
4. Add learnable mask tokens for masked patch positions
5. Decode full sequence (49 encoded + 147 mask tokens) through shallow decoder
6. Reconstruct pixel values for all patches
7. Compute MSE loss only on masked patches (147 patches)

Visualization
Original Image (16 patches):
[🐕] [🐕] [🐕] [🐕]
[🐕] [🐕] [🐕] [🐕]
[🐕] [🐕] [🐕] [🐕]
[🐕] [🐕] [🐕] [🐕]
After 75% masking (⬜ = masked):
[🐕] [⬜] [⬜] [⬜]
[⬜] [🐕] [⬜] [⬜]
[⬜] [⬜] [⬜] [🐕]
[⬜] [⬜] [🐕] [⬜]
MAE Reconstruction:
[🐕] [🐕] [🐕] [🐕]
[🐕] [🐕] [🐕] [🐕]
[🐕] [🐕] [🐕] [🐕]
[🐕] [🐕] [🐕] [🐕]

The model learns global structure by being forced to reconstruct 75% of the image from just 25% of visible patches!
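Putting the steps above together, a minimal pre-training loop for the MAE module sketched earlier might look like the following (illustrative: the optimizer settings are just reasonable defaults, and the data loader is assumed to yield batches of images):

```python
import torch

def train_mae(model, data_loader, epochs=100, mask_ratio=0.75, device=None):
    """Sketch of MAE pre-training: the reconstruction loss on masked patches is the only signal."""
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1.5e-4, weight_decay=0.05)
    model.train()

    for epoch in range(epochs):
        for step, imgs in enumerate(data_loader):
            imgs = imgs.to(device)
            loss, pred_patches, mask = model(imgs, mask_ratio=mask_ratio)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % 100 == 0:
                print(f"Epoch {epoch}, Step {step}, Loss: {loss.item():.4f}")

    # After pre-training, keep model.encoder for downstream tasks and discard the decoder
    return model
```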
Application to Sequential Data
Masked prediction is particularly effective for sequential data like time series, event logs, and medical records.
Mask patient events in EHR sequences and predict them - similar to BERT for text, but with medical codes as tokens. This approach is used in models like BEHRT and Med-BERT. See Healthcare Foundation Models.
Example: Masked EHR Prediction
class MaskedEHRModel(nn.Module):
"""Masked prediction for electronic health records"""
def __init__(self, vocab_size, hidden_size=256, num_layers=6):
super().__init__()
self.event_embedding = nn.Embedding(vocab_size, hidden_size)
self.position_embedding = nn.Embedding(5000, hidden_size) # Max sequence length
self.transformer = nn.TransformerEncoder(
nn.TransformerEncoderLayer(
d_model=hidden_size,
nhead=8,
dim_feedforward=hidden_size * 4
),
num_layers=num_layers
)
self.prediction_head = nn.Linear(hidden_size, vocab_size)
def forward(self, event_sequence, mask_positions):
"""
Args:
event_sequence: [batch, seq_len] - sequence of medical event codes
mask_positions: [batch, seq_len] - positions to predict
Returns:
predictions: [batch, seq_len, vocab_size] - predictions for all positions
"""
batch_size, seq_len = event_sequence.shape
# Embed events and positions
event_embed = self.event_embedding(event_sequence)
positions = torch.arange(seq_len, device=event_sequence.device)
pos_embed = self.position_embedding(positions)
x = event_embed + pos_embed
# Transformer encoding (bidirectional)
x = self.transformer(x.transpose(0, 1)) # Transformer expects [seq, batch, dim]
x = x.transpose(0, 1) # Back to [batch, seq, dim]
# Predict events at all positions
predictions = self.prediction_head(x)
return predictions
MASK_EVENT_ID = 0  # id reserved for the [MASK] event in the EHR vocabulary (illustrative choice)

def mask_ehr_events(event_sequence, mask_prob=0.15, vocab_size=10000):
    """Create masked EHR sequence following BERT's 80-10-10 strategy"""
masked_sequence = event_sequence.clone()
targets = event_sequence.clone()
# Determine which positions to mask
mask_positions = torch.rand(event_sequence.shape) < mask_prob
# Apply 80-10-10 masking strategy
for i in range(event_sequence.shape[0]):
for j in range(event_sequence.shape[1]):
if mask_positions[i, j]:
rand = torch.rand(1).item()
if rand < 0.8:
masked_sequence[i, j] = MASK_EVENT_ID
elif rand < 0.9:
masked_sequence[i, j] = torch.randint(0, vocab_size, (1,)).item()
# else: keep original
    return masked_sequence, targets, mask_positions

What the Model Learns from EHR Masking
By predicting masked medical events, the model learns:
- Temporal patterns: Which events typically follow others in patient trajectories
- Clinical relationships: Which diagnoses, procedures, and medications co-occur
- Disease progression: How conditions evolve over time
- Treatment patterns: Standard care pathways and intervention sequences
These learned representations transfer to downstream tasks:
- Readmission prediction
- Mortality risk assessment
- Diagnosis prediction
- Treatment response forecasting
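As a sketch of how such a transfer can look in code (illustrative only; the readmission target, mean pooling, and class name are assumptions, not part of a specific library), a small classification head can be attached to the pre-trained encoder:

```python
import torch
import torch.nn as nn

class ReadmissionClassifier(nn.Module):
    """Illustrative downstream head on top of a pre-trained MaskedEHRModel."""

    def __init__(self, pretrained_ehr_model, hidden_size=256):
        super().__init__()
        self.pretrained = pretrained_ehr_model          # encoder pre-trained with masked prediction
        self.classifier = nn.Linear(hidden_size, 1)     # binary readmission logit

    def forward(self, event_sequence):
        # Reuse the pre-trained embeddings and transformer; nothing is masked at fine-tuning time
        positions = torch.arange(event_sequence.shape[1], device=event_sequence.device)
        x = self.pretrained.event_embedding(event_sequence) + self.pretrained.position_embedding(positions)
        x = self.pretrained.transformer(x.transpose(0, 1)).transpose(0, 1)  # [batch, seq, hidden]
        pooled = x.mean(dim=1)                          # simple mean pooling over the event sequence
        return self.classifier(pooled)
```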
Contrastive vs Masked Prediction
| Aspect | Contrastive Learning | Masked Prediction |
|---|---|---|
| Task | Match augmented views | Reconstruct masked input |
| Training Signal | Similarity/dissimilarity | Reconstruction error |
| Augmentation | Critical (defines positive pairs) | Less critical (masks provide signal) |
| Batch Size | Often large (2048+ for SimCLR) | Can be small (256) |
| Architecture | Usually symmetric encoders | Can be asymmetric (MAE) |
| Best for | Images, multimodal, spatial data | Sequences, structured data, text |
| Examples | SimCLR, MoCo, CLIP | BERT, MAE, GPT |
When to Use Each Approach
Use contrastive learning when:
- Strong augmentations are available (images, audio)
- You want invariance to transformations
- Data has natural pairs (text-image, audio-video)
- Working with multimodal data
Use masked prediction when:
- Data is inherently sequential (text, time series, logs)
- Structure and order are important
- Augmentations are hard to define
- You want to preserve local structure
Domain-specific guidance:
- Medical imaging: Contrastive (with medical-appropriate augmentations)
- EHR event sequences: Masked prediction (preserve temporal order)
- Clinical notes: Masked prediction (BERT-style)
- Radiology reports + images: Contrastive (clinical VLMs)
Key Insights
- Masking ratio matters: Higher masking (75% for images) forces global understanding
- Asymmetric design: Deep encoder + shallow decoder is efficient for MAE
- Mixed masking strategy: BERT’s 80-10-10 prevents overfitting to mask tokens
- Reconstruction target: Predict original input, not some transformation of it
- Transfer learning: Discard decoder after pre-training, use encoder for downstream tasks
Historical Context
- 2013: Word2Vec introduces context prediction for word embeddings
- 2018: BERT brings masked prediction to transformers, revolutionizes NLP
- 2019: ALBERT, RoBERTa improve BERT training
- 2020: GPT-3 shows autoregressive prediction (unidirectional variant) at massive scale
- 2021: MAE successfully applies masked prediction to computer vision
- 2022: Masked prediction becomes standard pre-training for vision transformers
Related Concepts
- Self-Supervised Learning - Learning from unlabeled data
- Contrastive Learning - Alternative self-supervised paradigm
- Language Model Training - Autoregressive variant of masked prediction
- Transformer Architecture - Underlying architecture for BERT
- Tokenization - Converting input to tokens for masking
- Healthcare Foundation Models - Domain-specific masked prediction
Further Reading
Papers
- BERT: “BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding” (Devlin et al., 2018)
- MAE: “Masked Autoencoders Are Scalable Vision Learners” (He et al., 2021)
- BEHRT: “BEHRT: Transformer for Electronic Health Records” (Li et al., 2020)
- SimMIM: “SimMIM: A Simple Framework for Masked Image Modeling” (Xie et al., 2021)
Tutorials
- The Illustrated BERT: Visual guide to masked language modeling (Jay Alammar)
- MAE Paper Explained: Video walkthrough (Yannic Kilcher)
- Hugging Face BERT Tutorial: Hands-on masked LM training
- PyTorch MAE Implementation: Official code walkthrough
Code Implementations
- Hugging Face Transformers: BERT, RoBERTa, and variants
- timm library: MAE and other masked vision models
- Official MAE: PyTorch implementation by Facebook Research
- MinBERT: Minimal BERT implementation for learning