
Masked Prediction

Masked prediction is a self-supervised learning paradigm where models learn by predicting hidden portions of their input. This forces the model to develop rich contextual representations that understand relationships and structure in the data.

Core Concept

Fundamental Idea: Hide parts of the input and train the model to reconstruct them from context.

Learning Signal: Reconstruction error provides supervision without requiring labels.

Key Insight: To predict masked content accurately, the model must learn semantic relationships and contextual dependencies.
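
In code, the whole loop is short: corrupt part of the input, predict it, and score the prediction against the original. The sketch below is a minimal, modality-agnostic version in PyTorch; encoder and reconstruction_head are hypothetical placeholder modules supplied by the caller.

import torch
import torch.nn.functional as F

def masked_prediction_step(x, encoder, reconstruction_head, mask_ratio=0.15):
    """One self-supervised step: hide parts of x, reconstruct them from context.

    x: [batch, seq_len, feature_dim] - any tokenized or patchified input
    encoder, reconstruction_head: hypothetical modules supplied by the caller
    """
    # Choose which positions to hide
    mask = torch.rand(x.shape[:2], device=x.device) < mask_ratio   # [batch, seq_len]

    # Corrupt the input at masked positions (here: zero them out)
    x_corrupted = x.masked_fill(mask.unsqueeze(-1), 0.0)

    # Predict the original content from the corrupted context
    pred = reconstruction_head(encoder(x_corrupted))               # [batch, seq_len, feature_dim]

    # The reconstruction error on the hidden positions is the training signal
    loss = F.mse_loss(pred[mask], x[mask])
    return loss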

BERT: Masked Language Modeling

BERT (Bidirectional Encoder Representations from Transformers) revolutionized NLP by introducing masked language modeling as a pre-training objective.

The Masking Strategy

Randomly mask 15% of tokens in a sequence and predict them based on surrounding context:

# Original sentence tokens = ["The", "patient", "presented", "with", "severe", "chest", "pain"] # Mask 15% of tokens randomly masked_tokens = ["The", "patient", "[MASK]", "with", "severe", "chest", "pain"] # Model predicts: "presented" predictions = bert(masked_tokens) loss = cross_entropy(predictions[2], original_tokens[2])

BERT’s 80-10-10 Rule

For each selected token (15% of all tokens), BERT applies:

  • 80% of the time: Replace with [MASK] token
  • 10% of the time: Replace with random token
  • 10% of the time: Keep original token unchanged

Why this mixed strategy?

  • Prevents overfitting to the [MASK] token
  • Forces model to work with noisy/imperfect inputs
  • Model must always be prepared to reconstruct, even when it sees real tokens
  • Bridges gap between pre-training and fine-tuning (no [MASK] in downstream tasks)
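
The rule is easy to implement. Below is a compact, vectorized sketch of the 80-10-10 selection (a loop-based version also appears in the training code later in this section); mask_token_id stands for whatever vocabulary index the tokenizer reserves for [MASK].

import torch

def apply_bert_masking(token_ids, mask_token_id, vocab_size, mask_prob=0.15):
    """Select ~15% of tokens, then apply BERT's 80-10-10 corruption rule."""
    labels = token_ids.clone()
    selected = torch.rand(token_ids.shape) < mask_prob          # ~15% of positions

    corrupted = token_ids.clone()
    roll = torch.rand(token_ids.shape)

    # 80% of selected positions -> [MASK]
    corrupted[selected & (roll < 0.8)] = mask_token_id
    # 10% of selected positions -> a random token
    random_ids = torch.randint(0, vocab_size, token_ids.shape)
    use_random = selected & (roll >= 0.8) & (roll < 0.9)
    corrupted[use_random] = random_ids[use_random]
    # Remaining 10% of selected positions keep the original token

    return corrupted, labels, selected   # loss is computed only where `selected` is True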

BERT Architecture

import torch
import torch.nn as nn

MASK_TOKEN_ID = 103  # vocabulary index of the [MASK] token (103 in BERT's WordPiece vocab)


class BERT(nn.Module):
    """Bidirectional Encoder with masked language modeling"""

    def __init__(self, vocab_size, hidden_size=768, num_layers=12, num_heads=12):
        super().__init__()
        # Token and position embeddings
        self.token_embedding = nn.Embedding(vocab_size, hidden_size)
        self.position_embedding = nn.Embedding(512, hidden_size)  # Max 512 tokens

        # Stack of bidirectional transformer encoders
        self.transformer_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=hidden_size,
                nhead=num_heads,
                dim_feedforward=hidden_size * 4,
                batch_first=True
            ) for _ in range(num_layers)
        ])
        self.layer_norm = nn.LayerNorm(hidden_size)

        # Masked language modeling head
        self.mlm_head = nn.Linear(hidden_size, vocab_size)

    def forward(self, token_ids, masked_positions):
        """
        Args:
            token_ids: [batch_size, seq_len] - input tokens (some masked)
            masked_positions: [batch_size, seq_len] - binary mask of positions to predict
        Returns:
            predictions: [batch_size, seq_len, vocab_size] - logits for all positions
        """
        batch_size, seq_len = token_ids.shape

        # Create embeddings
        token_embed = self.token_embedding(token_ids)
        positions = torch.arange(seq_len, device=token_ids.device)
        position_embed = self.position_embedding(positions)

        # Combine embeddings
        x = token_embed + position_embed  # [batch_size, seq_len, hidden_size]

        # Pass through transformer layers
        for layer in self.transformer_layers:
            x = layer(x)
        x = self.layer_norm(x)

        # Predict all token positions
        predictions = self.mlm_head(x)  # [batch_size, seq_len, vocab_size]
        return predictions


def create_masked_lm_batch(tokens, mask_prob=0.15, vocab_size=30000):
    """Create masked tokens following BERT's 80-10-10 strategy"""
    masked_tokens = tokens.clone()
    targets = tokens.clone()

    # Sample ~15% of positions to mask
    mask_positions = torch.rand(tokens.shape) < mask_prob

    for i in range(tokens.shape[0]):
        for j in range(tokens.shape[1]):
            if mask_positions[i, j]:
                rand = torch.rand(1).item()
                if rand < 0.8:
                    # 80%: Replace with [MASK] token
                    masked_tokens[i, j] = MASK_TOKEN_ID
                elif rand < 0.9:
                    # 10%: Replace with random token
                    masked_tokens[i, j] = torch.randint(0, vocab_size, (1,)).item()
                # else: 10% keep original token

    return masked_tokens, targets, mask_positions


# Training loop
def train_bert_mlm(model, data_loader, epochs=10):
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    model.train()

    for epoch in range(epochs):
        for step, batch_tokens in enumerate(data_loader):
            # Create masked batch
            masked_tokens, targets, mask_positions = create_masked_lm_batch(
                batch_tokens, mask_prob=0.15
            )

            # Forward pass
            predictions = model(masked_tokens, mask_positions)

            # Compute loss only on masked positions
            predictions_masked = predictions[mask_positions]
            targets_masked = targets[mask_positions]
            loss = nn.CrossEntropyLoss()(predictions_masked, targets_masked)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % 100 == 0:
                print(f"Epoch {epoch}, Step {step}, Loss: {loss.item():.4f}")

    return model

What BERT Learns

By predicting masked tokens, BERT learns:

  • Semantic relationships: Words that often appear together
  • Syntactic patterns: Grammatical structure and dependencies
  • Contextual meaning: Word meanings depend on surrounding context
  • Bidirectional context: Uses both left and right context (unlike GPT)
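
To make the last point concrete, the snippet below contrasts full bidirectional attention (every token sees left and right context, as in BERT) with a causal mask that hides future positions (as in GPT); it uses PyTorch's nn.MultiheadAttention purely for illustration.

import torch
import torch.nn as nn

seq_len, dim = 5, 16
x = torch.randn(1, seq_len, dim)                    # [batch, seq, dim]
attn = nn.MultiheadAttention(dim, num_heads=4, batch_first=True)

# BERT-style: every token may attend to every other token (left AND right context)
bidirectional_out, _ = attn(x, x, x)

# GPT-style: a causal mask blocks attention to future positions (left context only)
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
causal_out, _ = attn(x, x, x, attn_mask=causal_mask)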

MAE: Masked Autoencoders for Images

Masked Autoencoders (MAE) apply the masked prediction paradigm to computer vision with remarkable success.

Key Innovation: High Masking Ratio

MAE masks 75% of image patches, far higher than BERT's 15%!

Why such aggressive masking?

  • Low masking ratios (~15%) can often be solved by interpolating from neighboring patches, without real semantic understanding
  • 75% masking forces global semantic understanding
  • Harder task leads to better learned representations
  • Computational efficiency: encoder only processes 25% of patches

MAE Architecture

import torch
import torch.nn as nn
import torch.nn.functional as F

# PatchEmbed and VisionTransformer below stand in for standard ViT building blocks
# operating on patch tokens (e.g., the implementations provided by the timm library).


class MAE(nn.Module):
    """Masked Autoencoder for self-supervised vision pre-training"""

    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 encoder_embed_dim=1024, encoder_depth=24, encoder_num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16):
        super().__init__()
        self.patch_size = patch_size
        num_patches = (img_size // patch_size) ** 2

        # Patch embedding
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size,
            in_channels=in_channels, embed_dim=encoder_embed_dim
        )

        # Encoder: Deep and wide (processes only visible patches)
        self.encoder = VisionTransformer(
            embed_dim=encoder_embed_dim, depth=encoder_depth, num_heads=encoder_num_heads
        )

        # Decoder: Shallow and narrow (reconstructs all patches)
        self.decoder = VisionTransformer(
            embed_dim=decoder_embed_dim, depth=decoder_depth, num_heads=decoder_num_heads
        )

        # Learnable mask token for decoder
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        # Project encoder output to decoder dimension
        self.encoder_to_decoder = nn.Linear(encoder_embed_dim, decoder_embed_dim)

        # Reconstruct pixels from patches
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_channels)

    def forward(self, imgs, mask_ratio=0.75):
        """
        Args:
            imgs: [batch_size, 3, H, W] - input images
            mask_ratio: float - fraction of patches to mask
        Returns:
            loss: reconstruction loss on masked patches
            pred_patches: reconstructed patches
            mask: binary mask showing which patches were masked
        """
        # Patchify image into embedded patch tokens
        patches = self.patch_embed(imgs)  # [batch, num_patches, encoder_dim]
        batch_size, num_patches, patch_dim = patches.shape

        # Random masking: shuffle patch indices, keep the first (1 - mask_ratio) fraction
        shuffle_indices = torch.rand(batch_size, num_patches, device=imgs.device).argsort(dim=1)
        num_keep = int(num_patches * (1 - mask_ratio))

        # Indices of visible and masked patches
        visible_indices = shuffle_indices[:, :num_keep]
        masked_indices = shuffle_indices[:, num_keep:]

        # Keep only visible patches for the encoder (25% of patches)
        visible_patches = torch.gather(
            patches, dim=1,
            index=visible_indices.unsqueeze(-1).expand(-1, -1, patch_dim)
        )

        # Encode visible patches
        encoded = self.encoder(visible_patches)  # [batch, num_keep, encoder_dim]

        # Project to decoder dimension
        encoded = self.encoder_to_decoder(encoded)

        # Create mask tokens for the masked positions
        mask_tokens = self.mask_token.repeat(batch_size, num_patches - num_keep, 1)

        # Combine encoded visible patches with mask tokens
        # (Restore original order with unshuffle operation)
        decoder_input = torch.cat([encoded, mask_tokens], dim=1)
        decoder_input = self.unshuffle(decoder_input, shuffle_indices)

        # Decode to reconstruct patches
        decoded = self.decoder(decoder_input)     # [batch, num_patches, decoder_dim]
        pred_patches = self.decoder_pred(decoded)  # [batch, num_patches, patch_size^2 * C]

        # Reconstruction target: raw pixel values of each patch (not the patch embeddings)
        p = self.patch_size
        target_patches = imgs.unfold(2, p, p).unfold(3, p, p)      # [B, C, H/p, W/p, p, p]
        target_patches = target_patches.permute(0, 2, 3, 4, 5, 1)  # [B, H/p, W/p, p, p, C]
        target_patches = target_patches.reshape(batch_size, num_patches, -1)

        # Compute loss only on masked patches
        mask = torch.zeros(batch_size, num_patches, device=imgs.device)
        mask.scatter_(1, masked_indices, 1)
        loss = F.mse_loss(pred_patches[mask.bool()], target_patches[mask.bool()])

        return loss, pred_patches, mask

    def unshuffle(self, x, shuffle_indices):
        """Restore original patch order"""
        batch_size, num_patches, dim = x.shape
        unshuffle_indices = torch.argsort(shuffle_indices, dim=1)
        x = torch.gather(
            x, dim=1,
            index=unshuffle_indices.unsqueeze(-1).expand(-1, -1, dim)
        )
        return x

Asymmetric Encoder-Decoder Design

MAE uses an asymmetric architecture for computational efficiency:

| Component | Depth | Width | Input | Purpose |
| --- | --- | --- | --- | --- |
| Encoder | Deep (24 layers) | Wide (1024-dim) | Only visible patches (25%) | Learn rich representations |
| Decoder | Shallow (8 layers) | Narrow (512-dim) | All patches (with mask tokens) | Lightweight reconstruction |

Why asymmetric?

  • Encoder efficiency: Only processes 25% of patches → 3-4x training speedup (see the rough cost sketch after this list)
  • Decoder simplicity: Reconstruction is easier than representation learning
  • Encoder quality: Deep encoder learns the representations we care about
  • Decoder disposability: Decoder is discarded after pre-training, only encoder is used
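
The speedup figure above can be sanity-checked with rough arithmetic: per layer, self-attention cost grows roughly with the square of the token count and the feed-forward cost linearly, so encoding 49 visible patches instead of all 196 removes most of the encoder compute. A back-of-the-envelope sketch (constant factors and the small decoder ignored):

# Rough per-layer cost model: attention ~ n^2 * d, feed-forward ~ n * d^2
def encoder_cost(num_tokens, dim=1024, layers=24):
    attention = num_tokens ** 2 * dim
    feed_forward = num_tokens * (dim ** 2) * 4   # 4x hidden expansion in the MLP
    return layers * (attention + feed_forward)

full = encoder_cost(196)   # encode every patch
mae = encoder_cost(49)     # encode only the 25% of patches left visible by MAE
print(f"approximate encoder speedup: {full / mae:.1f}x")   # roughly 4x for these settings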

MAE Training Process

  1. Split the image into patches (e.g., a 224×224 image → 196 patches of 16×16 pixels)
  2. Randomly mask 75% of the patches (keep only 49 visible patches)
  3. Encode only the visible patches through the deep ViT encoder
  4. Add learnable mask tokens at the masked patch positions
  5. Decode the full sequence (49 encoded patches + 147 mask tokens) through the shallow decoder
  6. Reconstruct pixel values for all patches
  7. Compute MSE loss only on the masked patches (147 patches)
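
Assuming the MAE class sketched above (and its backbone modules), a single pre-training step looks roughly like this; the batch and hyperparameters are illustrative only.

import torch

model = MAE()                                   # defaults: deep/wide encoder, light decoder
optimizer = torch.optim.AdamW(model.parameters(), lr=1.5e-4, weight_decay=0.05)

images = torch.randn(8, 3, 224, 224)            # stand-in for a real image batch
loss, pred_patches, mask = model(images, mask_ratio=0.75)

optimizer.zero_grad()
loss.backward()
optimizer.step()

# After pre-training, keep only model.encoder for downstream fine-tuning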

Visualization

Original Image (16 patches):
[🐕] [🐕] [🐕] [🐕]
[🐕] [🐕] [🐕] [🐕]
[🐕] [🐕] [🐕] [🐕]
[🐕] [🐕] [🐕] [🐕]

After 75% masking (⬜ = masked):
[🐕] [⬜] [⬜] [⬜]
[⬜] [🐕] [⬜] [⬜]
[⬜] [⬜] [⬜] [🐕]
[⬜] [⬜] [🐕] [⬜]

MAE Reconstruction:
[🐕] [🐕] [🐕] [🐕]
[🐕] [🐕] [🐕] [🐕]
[🐕] [🐕] [🐕] [🐕]
[🐕] [🐕] [🐕] [🐕]

The model learns global structure by being forced to reconstruct 75% of the image from just the 25% of patches that remain visible!

Application to Sequential Data

Masked prediction is particularly effective for sequential data like time series, event logs, and medical records.

Healthcare Application

Mask patient events in EHR sequences and predict them, just as BERT does for text but with medical codes as tokens. This approach is used in models like BEHRT and Med-BERT. See Healthcare Foundation Models.

Example: Masked EHR Prediction

import torch
import torch.nn as nn

MASK_EVENT_ID = 0  # reserved vocabulary index used as the [MASK] event code


class MaskedEHRModel(nn.Module):
    """Masked prediction for electronic health records"""

    def __init__(self, vocab_size, hidden_size=256, num_layers=6):
        super().__init__()
        self.event_embedding = nn.Embedding(vocab_size, hidden_size)
        self.position_embedding = nn.Embedding(5000, hidden_size)  # Max sequence length
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=hidden_size,
                nhead=8,
                dim_feedforward=hidden_size * 4
            ),
            num_layers=num_layers
        )
        self.prediction_head = nn.Linear(hidden_size, vocab_size)

    def forward(self, event_sequence, mask_positions):
        """
        Args:
            event_sequence: [batch, seq_len] - sequence of medical event codes
            mask_positions: [batch, seq_len] - positions to predict
        Returns:
            predictions: [batch, seq_len, vocab_size] - predictions for all positions
        """
        batch_size, seq_len = event_sequence.shape

        # Embed events and positions
        event_embed = self.event_embedding(event_sequence)
        positions = torch.arange(seq_len, device=event_sequence.device)
        pos_embed = self.position_embedding(positions)
        x = event_embed + pos_embed

        # Transformer encoding (bidirectional)
        x = self.transformer(x.transpose(0, 1))  # Transformer expects [seq, batch, dim]
        x = x.transpose(0, 1)                    # Back to [batch, seq, dim]

        # Predict events at all positions
        predictions = self.prediction_head(x)
        return predictions


def mask_ehr_events(event_sequence, mask_prob=0.15, vocab_size=10000):
    """Create masked EHR sequence following BERT's 80-10-10 strategy"""
    masked_sequence = event_sequence.clone()
    targets = event_sequence.clone()

    # Determine which positions to mask
    mask_positions = torch.rand(event_sequence.shape) < mask_prob

    # Apply 80-10-10 masking strategy
    for i in range(event_sequence.shape[0]):
        for j in range(event_sequence.shape[1]):
            if mask_positions[i, j]:
                rand = torch.rand(1).item()
                if rand < 0.8:
                    masked_sequence[i, j] = MASK_EVENT_ID
                elif rand < 0.9:
                    masked_sequence[i, j] = torch.randint(0, vocab_size, (1,)).item()
                # else: keep original event

    return masked_sequence, targets, mask_positions

What the Model Learns from EHR Masking

By predicting masked medical events, the model learns:

  • Temporal patterns: Which events typically follow others in patient trajectories
  • Clinical relationships: Which diagnoses, procedures, and medications co-occur
  • Disease progression: How conditions evolve over time
  • Treatment patterns: Standard care pathways and intervention sequences

These learned representations transfer to downstream tasks (a minimal fine-tuning sketch follows the list):

  • Readmission prediction
  • Mortality risk assessment
  • Diagnosis prediction
  • Treatment response forecasting
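
As a concrete example of this transfer, the sketch below reuses the MaskedEHRModel defined earlier as the backbone for readmission prediction; the classification head and mean-pooling choice are illustrative assumptions, not part of any specific published model.

import torch
import torch.nn as nn

class ReadmissionClassifier(nn.Module):
    """Pre-trained EHR encoder + small task head (hypothetical downstream setup)."""

    def __init__(self, pretrained_model, hidden_size=256):
        super().__init__()
        self.backbone = pretrained_model               # a pre-trained MaskedEHRModel
        self.classifier = nn.Linear(hidden_size, 1)    # binary readmission logit

    def forward(self, event_sequence):
        # Reuse the embeddings and transformer; drop the masked-prediction head
        x = self.backbone.event_embedding(event_sequence)
        positions = torch.arange(event_sequence.shape[1], device=event_sequence.device)
        x = x + self.backbone.position_embedding(positions)
        x = self.backbone.transformer(x.transpose(0, 1)).transpose(0, 1)
        pooled = x.mean(dim=1)                         # simple mean pooling over events
        return self.classifier(pooled).squeeze(-1)     # one logit per patient

# Fine-tune on labeled data with BCEWithLogitsLoss, typically at a small learning rate.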

Contrastive vs Masked Prediction

| Aspect | Contrastive Learning | Masked Prediction |
| --- | --- | --- |
| Task | Match augmented views | Reconstruct masked input |
| Training Signal | Similarity/dissimilarity | Reconstruction error |
| Augmentation | Critical (defines positive pairs) | Less critical (masks provide signal) |
| Batch Size | Often large (2048+ for SimCLR) | Can be small (256) |
| Architecture | Usually symmetric encoders | Can be asymmetric (MAE) |
| Best for | Images, multimodal, spatial data | Sequences, structured data, text |
| Examples | SimCLR, MoCo, CLIP | BERT, MAE, BEHRT |

When to Use Each Approach

Use contrastive learning when:

  • Strong augmentations are available (images, audio)
  • You want invariance to transformations
  • Data has natural pairs (text-image, audio-video)
  • Working with multimodal data

Use masked prediction when:

  • Data is inherently sequential (text, time series, logs)
  • Structure and order are important
  • Augmentations are hard to define
  • You want to preserve local structure

Domain-specific guidance:

  • Medical imaging: Contrastive (with medical-appropriate augmentations)
  • EHR event sequences: Masked prediction (preserve temporal order)
  • Clinical notes: Masked prediction (BERT-style)
  • Radiology reports + images: Contrastive (clinical VLMs)

Key Insights

  1. Masking ratio matters: Higher masking (75% for images) forces global understanding
  2. Asymmetric design: Deep encoder + shallow decoder is efficient for MAE
  3. Mixed masking strategy: BERT’s 80-10-10 prevents overfitting to mask tokens
  4. Reconstruction target: Predict original input, not some transformation of it
  5. Transfer learning: Discard decoder after pre-training, use encoder for downstream tasks

Historical Context

  • 2013: Word2Vec introduces context prediction for word embeddings
  • 2018: BERT brings masked prediction to transformers, revolutionizes NLP
  • 2019: ALBERT, RoBERTa improve BERT training
  • 2020: GPT-3 shows autoregressive prediction (unidirectional variant) at massive scale
  • 2021: MAE successfully applies masked prediction to computer vision
  • 2022: Masked prediction becomes standard pre-training for vision transformers

Further Reading

Papers

  • BERT: “BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding” (Devlin et al., 2018)
  • MAE: “Masked Autoencoders Are Scalable Vision Learners” (He et al., 2021)
  • BEHRT: “BEHRT: Transformer for Electronic Health Records” (Li et al., 2020)
  • SimMIM: “SimMIM: A Simple Framework for Masked Image Modeling” (Xie et al., 2021)

Tutorials

  • The Illustrated BERT: Visual guide to masked language modeling (Jay Alammar)
  • MAE Paper Explained: Video walkthrough (Yannic Kilcher)
  • Hugging Face BERT Tutorial: Hands-on masked LM training
  • PyTorch MAE Implementation: Official code walkthrough

Code Implementations

  • Hugging Face Transformers: BERT, RoBERTa, and variants
  • timm library: MAE and other masked vision models
  • Official MAE: PyTorch implementation by Facebook Research
  • MinBERT: Minimal BERT implementation for learning