Healthcare Foundation Models
Overview
Healthcare foundation models are large-scale transformer architectures pre-trained on electronic health records and clinical text. These models learn generalizable representations of patient trajectories, enabling zero-shot and few-shot transfer to downstream clinical prediction tasks.
ETHOS: Zero-Shot Health Trajectory Prediction
PRIMARY REFERENCE for EmergAI thesis
Paper
Zero-shot Health Trajectory Prediction Using Transformers (Renc et al., 2024, npj Digital Medicine)
Key Contributions
- Encoder-only transformer for temporal EHR event sequences
- Masked event modeling pre-training objective (similar to BERT)
- Zero-shot transfer: Predicts new clinical tasks without task-specific fine-tuning
- State-of-the-art performance: Matches or exceeds task-specific models on mortality, readmission, length-of-stay
Architecture
```python
import torch
import torch.nn as nn


class TemporalEmbedding(nn.Module):
    """Minimal placeholder: projects a scalar event timestamp to d_model (an assumed stand-in, not the exact ETHOS formulation)."""
    def __init__(self, d_model):
        super().__init__()
        self.proj = nn.Linear(1, d_model)

    def forward(self, event_times):
        # event_times: (batch, seq_len) -> (batch, seq_len, d_model)
        return self.proj(event_times.unsqueeze(-1).float())


class ETHOS(nn.Module):
    """
    ETHOS-style architecture (simplified)
    """
    def __init__(self, vocab_size=50000, d_model=768, n_layers=12, n_heads=12):
        super().__init__()
        # Event embeddings
        self.event_embedding = nn.Embedding(vocab_size, d_model)
        self.temporal_embedding = TemporalEmbedding(d_model)
        self.type_embedding = nn.Embedding(4, d_model)  # diagnosis, procedure, med, lab

        # Transformer encoder (bidirectional)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_model * 4,
            dropout=0.1,
            activation='gelu',
            batch_first=True,  # inputs are (batch, seq_len, d_model)
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        # Head for masked-event pre-training (predicts vocabulary IDs of masked events)
        self.event_prediction_head = nn.Linear(d_model, vocab_size)

        # Prediction heads (task-specific, added on top)
        self.mortality_head = nn.Linear(d_model, 2)
        self.readmission_head = nn.Linear(d_model, 2)
        self.los_head = nn.Linear(d_model, 1)  # Continuous prediction

    def forward(self, event_ids, event_times, event_types):
        # Combine embeddings
        event_emb = self.event_embedding(event_ids)          # (batch, seq_len, d_model)
        temporal_emb = self.temporal_embedding(event_times)  # (batch, seq_len, d_model)
        type_emb = self.type_embedding(event_types)          # (batch, seq_len, d_model)

        # Sum embeddings
        x = event_emb + temporal_emb + type_emb

        # Encode sequence (bidirectional attention)
        encoded = self.transformer(x)  # (batch, seq_len, d_model)

        # Pool to a patient-level representation (mean pooling; no [CLS] event is prepended here)
        pooled = encoded.mean(dim=1)
        return pooled

    def predict_mortality(self, event_ids, event_times, event_types):
        pooled = self.forward(event_ids, event_times, event_types)
        return self.mortality_head(pooled)

    def predict_readmission(self, event_ids, event_times, event_types):
        pooled = self.forward(event_ids, event_times, event_types)
        return self.readmission_head(pooled)
```
Pre-training: Masked Event Modeling
Pre-train ETHOS by masking 15% of events and predicting them from context:
```python
import torch
import torch.nn.functional as F

MASK_TOKEN_ID = 0  # assumed: a reserved vocabulary ID for the [MASK] event


def pretrain_masked_event_modeling(model, event_sequence, event_times, event_types):
    """
    Masked event modeling pre-training (ETHOS objective)

    Args:
        event_sequence: (batch, seq_len) tensor of event IDs
        event_times: (batch, seq_len) tensor of timestamps
        event_types: (batch, seq_len) tensor of event types
    """
    batch_size, seq_len = event_sequence.shape

    # Randomly mask 15% of events
    mask_prob = 0.15
    mask = torch.rand(batch_size, seq_len, device=event_sequence.device) < mask_prob

    # Create masked input (replace masked events with the [MASK] token)
    masked_sequence = event_sequence.clone()
    masked_sequence[mask] = MASK_TOKEN_ID

    # Forward pass on the corrupted sequence
    encoded = model.transformer(
        model.event_embedding(masked_sequence) +
        model.temporal_embedding(event_times) +
        model.type_embedding(event_types)
    )

    # Predict the original event ID at every position
    predictions = model.event_prediction_head(encoded)  # (batch, seq_len, vocab_size)

    # Compute loss only on masked positions
    loss = F.cross_entropy(
        predictions[mask],     # (num_masked, vocab_size)
        event_sequence[mask]   # (num_masked,)
    )
    return loss
```
Zero-Shot Transfer
After pre-training, ETHOS can predict new tasks without fine-tuning:
```python
# Load pre-trained ETHOS weights (the checkpoint path is illustrative)
ethos = ETHOS()
ethos.load_state_dict(torch.load('ethos-base.pt'))

# Add a task-specific head for the new task
ethos.sepsis_head = nn.Linear(768, 2)

# Predict on the new task (event_ids, event_times, event_types: tensors for the new cohort)
pooled_repr = ethos(event_ids, event_times, event_types)
sepsis_prediction = ethos.sepsis_head(pooled_repr)
```
Performance
From the paper:
- Mortality prediction: AUROC 0.87 (matches task-specific LSTM)
- Readmission prediction: AUROC 0.72 (outperforms baselines)
- Length-of-stay: MAE 2.3 days (competitive with specialized models)
- Zero-shot: No fine-tuning required for new tasks
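These headline numbers use standard metrics: AUROC for the binary outcomes and MAE for length-of-stay. A minimal evaluation sketch with scikit-learn; the arrays below are placeholder values, not results from the paper:

```python
import numpy as np
from sklearn.metrics import roc_auc_score, mean_absolute_error

# Placeholder labels/scores for a held-out cohort (illustrative values only)
mortality_labels = np.array([0, 1, 0, 1])          # 1 = in-hospital death
mortality_scores = np.array([0.1, 0.8, 0.3, 0.6])  # predicted probability of death

los_true_days = np.array([3.0, 7.5, 2.0, 10.0])
los_pred_days = np.array([4.0, 6.0, 2.5, 8.0])

print("Mortality AUROC:", roc_auc_score(mortality_labels, mortality_scores))
print("Length-of-stay MAE (days):", mean_absolute_error(los_true_days, los_pred_days))
```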
BEHRT: Bidirectional Encoder Representations from Health Records
Paper
BEHRT: Transformer for Electronic Health Records (Li et al., 2020, Scientific Reports)
Architecture
- Similar to BERT for NLP, but for EHR event sequences
- Pre-training on longitudinal diagnosis code sequences
- Bidirectional encoding captures past and future context
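BEHRT's input representation sums several embeddings per diagnosis token: the code itself, the patient's age at that event, a positional encoding, and an alternating visit-segment embedding. A minimal sketch of that composition (module names and dimensions are illustrative, not taken from the official implementation):

```python
import torch
import torch.nn as nn

class BEHRTInputEmbedding(nn.Module):
    """Sketch of BEHRT-style input composition: diagnosis + age + position + visit segment."""
    def __init__(self, n_codes, n_ages=120, max_len=512, d_model=768):
        super().__init__()
        self.code_embedding = nn.Embedding(n_codes, d_model)      # diagnosis codes
        self.age_embedding = nn.Embedding(n_ages, d_model)        # patient age at each event
        self.segment_embedding = nn.Embedding(2, d_model)         # alternating visit segments (A/B)
        self.position_embedding = nn.Embedding(max_len, d_model)  # position in the sequence

    def forward(self, code_ids, ages, segments):
        # code_ids, ages, segments: (batch, seq_len) integer tensors
        positions = torch.arange(code_ids.size(1), device=code_ids.device).unsqueeze(0)
        return (self.code_embedding(code_ids)
                + self.age_embedding(ages)
                + self.segment_embedding(segments)
                + self.position_embedding(positions))
```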
Pre-training Tasks
- Masked diagnosis prediction: Mask random diagnoses, predict from context
- Visit type prediction: Predict whether event is inpatient/outpatient/ED
Key Differences from ETHOS
- BEHRT focuses on diagnosis codes only
- ETHOS uses all event types (diagnoses, procedures, meds, labs)
- BEHRT requires fine-tuning for downstream tasks
- ETHOS emphasizes zero-shot transfer
Med-BERT: Pre-trained Contextualized Embeddings
Paper
Med-BERT: Pre-trained Contextualized Embeddings on Large-Scale Structured EHRs for Disease Prediction (Rasmy et al., 2020)
Key Features
- Hierarchical embeddings: Both code-level and visit-level representations
- Structured data only: Pre-trained on diagnosis codes from large-scale structured EHRs (no clinical notes)
- Dual pre-training objectives: Masked code prediction plus prediction of a prolonged length of stay
- Strong downstream disease prediction: Boosts AUC on disease-onset prediction tasks, especially when labeled training data are limited
Hierarchical Structure
```python
class MedBERT(nn.Module):
    """Simplified hierarchical sketch: encode codes within each visit, then encode the visit sequence."""
    def __init__(self, vocab_size=50000, d_model=768, n_heads=12, n_layers=2):
        super().__init__()
        self.code_embedding = nn.Embedding(vocab_size, d_model)
        # Code-level encoder (codes within a single visit)
        code_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads, batch_first=True)
        self.code_encoder = nn.TransformerEncoder(code_layer, num_layers=n_layers)
        # Visit-level encoder (encode the sequence of visit representations)
        visit_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads, batch_first=True)
        self.visit_encoder = nn.TransformerEncoder(visit_layer, num_layers=n_layers)

    def forward(self, visits):
        # visits: list of visits for one patient; each visit is a 1-D LongTensor of code IDs
        visit_representations = []
        for visit_codes in visits:
            code_emb = self.code_embedding(visit_codes).unsqueeze(0)  # (1, n_codes, d_model)
            encoded_codes = self.code_encoder(code_emb)               # encode codes within the visit
            visit_repr = encoded_codes.mean(dim=1).squeeze(0)         # pool to a visit representation
            visit_representations.append(visit_repr)
        # Encode the visit sequence
        visit_sequence = torch.stack(visit_representations).unsqueeze(0)  # (1, n_visits, d_model)
        patient_repr = self.visit_encoder(visit_sequence)                 # (1, n_visits, d_model)
        return patient_repr
```
Advantages
- Two-level hierarchy: Captures both intra-visit and inter-visit patterns
- Visit-level reasoning: Understands temporal visit sequences
- Code-only inputs: Works directly on coded EHR data without requiring clinical notes
GatorTron: Large-Scale Clinical Language Model
Paper
A Large Language Model for Electronic Health Records (Yang et al., 2022, npj Digital Medicine)
Scale
- 8.9 billion parameters in its largest configuration
- Trained on >90 billion words of text, including >82 billion words of de-identified clinical notes from University of Florida Health
- BERT-style (Megatron-BERT) encoder architecture pre-trained with masked language modeling
Key Features
- Largest clinical LM (as of 2022)
- Clinical language understanding: Built for extracting and interpreting clinical text rather than generating it
- Strong performance on medical NLP benchmarks: clinical concept extraction (NER), relation extraction, semantic textual similarity, natural language inference, and medical question answering
Comparison to ClinicalBERT
- ClinicalBERT: 110M parameters, encoder-only, trained on MIMIC-III notes (newer variants can use MIMIC-IV with 269,573 additional ED notes)
- GatorTron: 8.9B parameters, encoder (Megatron-BERT style), trained on >90B words
- Use cases:
- ClinicalBERT: Classification, NER, embeddings
- GatorTron: Large-scale clinical NER, relation extraction, semantic similarity, inference, and question answering
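ClinicalBERT (the emilyalsentzer/Bio_ClinicalBERT checkpoint listed under Models below) is available on Hugging Face; a minimal sketch for turning a clinical note into a fixed-size embedding, using mean pooling over tokens as one common (but not prescribed) choice:

```python
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
model = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

note = "Patient presents with chest pain radiating to the left arm."
inputs = tokenizer(note, return_tensors="pt", truncation=True, max_length=512)

with torch.no_grad():
    outputs = model(**inputs)

# Mean-pool token embeddings into one note-level embedding of shape (1, 768)
note_embedding = outputs.last_hidden_state.mean(dim=1)
```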
Comparison Table
| Model | Architecture | Scale | Pre-training Data | Pre-training Task | Key Strength |
|---|---|---|---|---|---|
| ETHOS | Encoder (BERT-style) | ~100M params | EHR sequences (all event types) | Masked event modeling | Zero-shot transfer |
| BEHRT | Encoder (BERT-style) | ~110M params | Diagnosis code sequences | Masked diagnosis + visit type | Bidirectional context |
| Med-BERT | Encoder (hierarchical) | ~110M params | Structured diagnosis codes | Masked code prediction + prolonged LOS | Hierarchical structure |
| GatorTron | Encoder (Megatron-BERT) | 8.9B params | Clinical text (>90B words) | Masked language modeling | Scale for clinical NLP |
| ClinicalBERT | Encoder (BERT-style) | 110M params | MIMIC clinical notes | Masked language modeling | Clinical NLP |
Multimodal Extension for EmergAI
Building on ETHOS for the EmergAI thesis:
```python
import torch
import torch.nn as nn
from torchvision.models import resnet50
from transformers import AutoModel


class EmergAIModel(nn.Module):
    """
    Multimodal extension of ETHOS for EmergAI thesis
    Combines structured EHR, symptom text, and symptom sketches
    """
    def __init__(self, num_outcomes=2):
        super().__init__()
        # ETHOS-style encoder for structured EHR events (the ETHOS class defined above)
        self.ehr_encoder = ETHOS(vocab_size=50000, d_model=768, n_layers=12)

        # ClinicalBERT for patient-reported symptom text
        self.text_encoder = AutoModel.from_pretrained('emilyalsentzer/Bio_ClinicalBERT')

        # Vision encoder for symptom sketches (ResNet here; a ViT would be a drop-in alternative);
        # assumes 3D sketches are rendered as 2D images, with the classifier replaced by a 768-d projection
        self.sketch_encoder = resnet50(weights='DEFAULT')
        self.sketch_encoder.fc = nn.Linear(self.sketch_encoder.fc.in_features, 768)

        # Cross-modal fusion (attention-based)
        self.cross_attention = nn.MultiheadAttention(embed_dim=768, num_heads=12, batch_first=True)

        # Outcome prediction head
        self.outcome_head = nn.Linear(768, num_outcomes)

    def forward(self, ehr_events, ehr_times, ehr_types, symptom_text, symptom_sketch):
        # Encode each modality
        ehr_repr = self.ehr_encoder(ehr_events, ehr_times, ehr_types)  # (batch, 768)
        # symptom_text: tokenized inputs (input_ids, attention_mask, ...) for ClinicalBERT
        text_repr = self.text_encoder(**symptom_text).pooler_output    # (batch, 768)
        sketch_repr = self.sketch_encoder(symptom_sketch)              # (batch, 768)

        # Stack modalities for cross-attention
        modalities = torch.stack([ehr_repr, text_repr, sketch_repr], dim=1)  # (batch, 3, 768)

        # Cross-modal attention fusion
        fused_repr, _ = self.cross_attention(modalities, modalities, modalities)
        fused_repr = fused_repr.mean(dim=1)  # Pool across modalities

        # Predict outcome
        outcome = self.outcome_head(fused_repr)
        return outcome
```
Thesis Research Question: Does adding patient-reported multimodal data (symptom text + 3D sketches) improve ED outcome prediction over structured EHR alone (ETHOS baseline)?
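As a quick shape check for the model sketched above, a smoke test with dummy inputs (batch size, sequence length, image size, and the example symptom strings are arbitrary placeholders):

```python
import torch
from transformers import AutoTokenizer

model = EmergAIModel(num_outcomes=2)

batch, seq_len = 2, 32
ehr_events = torch.randint(0, 50000, (batch, seq_len))   # event vocabulary IDs
ehr_times = torch.rand(batch, seq_len)                    # normalized event times
ehr_types = torch.randint(0, 4, (batch, seq_len))         # diagnosis / procedure / med / lab

tokenizer = AutoTokenizer.from_pretrained('emilyalsentzer/Bio_ClinicalBERT')
symptom_text = tokenizer(["sharp chest pain for two hours", "dizziness and nausea"],
                         return_tensors='pt', padding=True, truncation=True)

symptom_sketch = torch.randn(batch, 3, 224, 224)          # rendered sketch images

with torch.no_grad():
    logits = model(ehr_events, ehr_times, ehr_types, symptom_text, symptom_sketch)
print(logits.shape)  # torch.Size([2, 2])
```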
Related Concepts
- EHR Structure and Medical Coding: Understanding the data these models process
- Medical Event Tokenization: How EHR events become token sequences
- Language Model Training: Pre-training objectives and fine-tuning strategies
- Healthcare Interpretability: Validating and explaining foundation model predictions
Learning Resources
Papers
- ETHOS: Zero-shot Health Trajectory Prediction (Renc et al., 2024) - Primary reference
- BEHRT: Transformer for Electronic Health Records (Li et al., 2020)
- Med-BERT: Pre-trained Contextualized Embeddings (Rasmy et al., 2020)
- GatorTron: A Large Language Model for Electronic Health Records (Yang et al., 2022)
- ClinicalBERT: Publicly Available Clinical BERT Embeddings (Alsentzer et al., 2019)
Code Repositories
- ETHOS GitHub - Official ETHOS implementation
- BEHRT GitHub - Official BEHRT implementation
- Med-BERT GitHub - Med-BERT code
Models
- emilyalsentzer/Bio_ClinicalBERT - ClinicalBERT on HuggingFace
- GatorTron - GatorTron models
Applications
- Mortality prediction: ICU and hospital mortality risk
- Readmission prediction: 30-day readmission risk stratification
- Length-of-stay estimation: Hospital/ICU duration forecasting
- Disease progression: Temporal modeling of chronic conditions
- Adverse event detection: Early warning systems for complications
- Clinical trial recruitment: Identify eligible patients from EHR
- Precision medicine: Personalized treatment recommendations
Key Takeaways
- ETHOS achieves zero-shot transfer through masked event pre-training on all EHR event types
- BEHRT pioneered BERT-style pre-training for diagnosis code sequences
- Med-BERT uses hierarchical encoding (code-level + visit-level) for richer representations
- GatorTron is the largest clinical LM (8.9B params), built for clinical language understanding at scale
- Pre-training paradigm: Masked modeling enables generalization across clinical tasks
- Multimodal extension: Combining ETHOS with patient-reported data (text + sketches) is the EmergAI thesis contribution