Healthcare Foundation Models

Overview

Healthcare foundation models are large-scale transformer architectures pre-trained on electronic health records and clinical text. These models learn generalizable representations of patient trajectories, enabling zero-shot and few-shot transfer to downstream clinical prediction tasks.

ETHOS: Zero-Shot Health Trajectory Prediction

PRIMARY REFERENCE for EmergAI thesis

Paper

Zero-shot Health Trajectory Prediction Using Transformers  (Renc et al., 2024, npj Digital Medicine)

Key Contributions

  1. Encoder-only transformer for temporal EHR event sequences
  2. Masked event modeling pre-training objective (similar to BERT)
  3. Zero-shot transfer: Predicts new clinical tasks without task-specific fine-tuning
  4. State-of-the-art performance: Matches or exceeds task-specific models on mortality, readmission, length-of-stay

Architecture

import torch
import torch.nn as nn


class ETHOS(nn.Module):
    """ETHOS-style architecture (simplified)."""

    def __init__(self, vocab_size=50000, d_model=768, n_layers=12, n_heads=12):
        super().__init__()
        # Event embeddings
        self.event_embedding = nn.Embedding(vocab_size, d_model)
        self.temporal_embedding = TemporalEmbedding(d_model)  # time-of-event embedding, defined elsewhere
        self.type_embedding = nn.Embedding(4, d_model)        # diagnosis, procedure, med, lab

        # Transformer encoder (bidirectional)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_model * 4,
            dropout=0.1,
            activation='gelu',
            batch_first=True,  # inputs are (batch, seq_len, d_model)
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        # Prediction heads (task-specific, added on top)
        self.mortality_head = nn.Linear(d_model, 2)
        self.readmission_head = nn.Linear(d_model, 2)
        self.los_head = nn.Linear(d_model, 1)  # continuous prediction

    def forward(self, event_ids, event_times, event_types):
        # Combine embeddings
        event_emb = self.event_embedding(event_ids)          # (batch, seq_len, d_model)
        temporal_emb = self.temporal_embedding(event_times)  # (batch, seq_len, d_model)
        type_emb = self.type_embedding(event_types)          # (batch, seq_len, d_model)

        # Sum embeddings
        x = event_emb + temporal_emb + type_emb

        # Encode sequence (bidirectional attention)
        encoded = self.transformer(x)  # (batch, seq_len, d_model)

        # Pool representation (use CLS token or mean pooling)
        pooled = encoded[:, 0, :]  # first-token pooling (stand-in for a [CLS] token)
        return pooled

    def predict_mortality(self, event_ids, event_times, event_types):
        pooled = self.forward(event_ids, event_times, event_types)
        return self.mortality_head(pooled)

    def predict_readmission(self, event_ids, event_times, event_types):
        pooled = self.forward(event_ids, event_times, event_types)
        return self.readmission_head(pooled)
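
A quick smoke test of the sketch above with dummy tensors. The TemporalEmbedding below is a hypothetical stand-in (a simple learned bucket embedding), not the paper's formulation:

import torch
import torch.nn as nn

class TemporalEmbedding(nn.Module):
    """Hypothetical stand-in: bucket timestamps into a learned embedding."""
    def __init__(self, d_model, n_bins=512):
        super().__init__()
        self.n_bins = n_bins
        self.emb = nn.Embedding(n_bins, d_model)

    def forward(self, times):
        # Bucket continuous timestamps (e.g., hours since admission) into discrete bins
        bins = times.clamp(min=0).long().clamp(max=self.n_bins - 1)
        return self.emb(bins)

model = ETHOS(vocab_size=1000, d_model=128, n_layers=2, n_heads=4)
event_ids = torch.randint(0, 1000, (2, 32))   # (batch=2, seq_len=32)
event_times = torch.rand(2, 32) * 100         # hours since admission
event_types = torch.randint(0, 4, (2, 32))    # diagnosis / procedure / med / lab

logits = model.predict_mortality(event_ids, event_times, event_types)
print(logits.shape)  # torch.Size([2, 2])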

Pre-training: Masked Event Modeling

Pre-train ETHOS by masking 15% of events and predicting them from context:

import torch
import torch.nn.functional as F


def pretrain_masked_event_modeling(model, event_sequence, event_times, event_types):
    """Masked event modeling pre-training (ETHOS objective).

    Args:
        event_sequence: (batch, seq_len) tensor of event IDs
        event_times:    (batch, seq_len) tensor of timestamps
        event_types:    (batch, seq_len) tensor of event types
    """
    batch_size, seq_len = event_sequence.shape

    # Randomly mask 15% of events
    mask_prob = 0.15
    mask = torch.rand(batch_size, seq_len) < mask_prob

    # Create masked input (replace masked events with the [MASK] token id)
    masked_sequence = event_sequence.clone()
    masked_sequence[mask] = MASK_TOKEN_ID  # reserved id in the event vocabulary

    # Forward pass through the encoder on the summed embeddings
    encoded = model.transformer(
        model.event_embedding(masked_sequence)
        + model.temporal_embedding(event_times)
        + model.type_embedding(event_types)
    )

    # Predict the original event at every position (vocab-sized head used for pre-training)
    predictions = model.event_prediction_head(encoded)  # (batch, seq_len, vocab_size)

    # Compute loss only on masked positions
    vocab_size = predictions.size(-1)
    loss = F.cross_entropy(
        predictions[mask].view(-1, vocab_size),
        event_sequence[mask].view(-1),
    )
    return loss
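
A minimal training loop around this objective with synthetic batches. The MASK_TOKEN_ID constant and the vocab-sized event_prediction_head are assumptions the function above relies on, defined here purely for illustration:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Reuses the small `model` (d_model=128, vocab_size=1000) from the smoke test above.
vocab = 1000
MASK_TOKEN_ID = 0                                     # assumed reserved id in the event vocabulary
model.event_prediction_head = nn.Linear(128, vocab)   # vocab-sized head used only during pre-training

# Synthetic patient timelines: event IDs (ids start at 1; 0 is [MASK]), timestamps, event types
dataset = TensorDataset(
    torch.randint(1, vocab, (64, 32)),
    torch.rand(64, 32) * 100,
    torch.randint(0, 4, (64, 32)),
)
loader = DataLoader(dataset, batch_size=8, shuffle=True)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
model.train()
for epoch in range(2):
    for event_ids, event_times, event_types in loader:
        loss = pretrain_masked_event_modeling(model, event_ids, event_times, event_types)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()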

Zero-Shot Transfer

After pre-training, ETHOS transfers to new tasks without updating the encoder; only a lightweight task-specific head is attached on top:

# Pre-trained ETHOS
ethos = ETHOS.from_pretrained('ethos-base')

# Add a task-specific head for the new task
ethos.sepsis_head = nn.Linear(768, 2)

# Predict on the new task (the pre-trained encoder stays frozen)
pooled_repr = ethos(event_ids, event_times, event_types)
sepsis_prediction = ethos.sepsis_head(pooled_repr)
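
One common way to realize this in practice is a linear probe: keep the pre-trained encoder frozen and fit only the new head on labels for the new task. The sketch below shows that recipe; it illustrates the transfer setup rather than the paper's exact protocol:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Freeze the pre-trained encoder; only the new sepsis head receives gradients.
for p in ethos.parameters():
    p.requires_grad = False
ethos.sepsis_head = nn.Linear(768, 2)

probe_optimizer = torch.optim.Adam(ethos.sepsis_head.parameters(), lr=1e-3)

def probe_step(event_ids, event_times, event_types, labels):
    with torch.no_grad():                       # encoder stays frozen
        pooled = ethos(event_ids, event_times, event_types)
    logits = ethos.sepsis_head(pooled)
    loss = F.cross_entropy(logits, labels)
    probe_optimizer.zero_grad()
    loss.backward()
    probe_optimizer.step()
    return loss.item()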

Performance

From the paper:

  • Mortality prediction: AUROC 0.87 (matches task-specific LSTM)
  • Readmission prediction: AUROC 0.72 (outperforms baselines)
  • Length-of-stay: MAE 2.3 days (competitive with specialized models)
  • Zero-shot: No fine-tuning required for new tasks
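
For orientation, these metrics correspond to standard scikit-learn calls. The toy numbers below only demonstrate the API; the figures above are the paper's:

from sklearn.metrics import roc_auc_score, mean_absolute_error

# AUROC for a binary outcome (mortality / readmission)
y_true = [0, 0, 1, 1]
y_prob = [0.1, 0.4, 0.35, 0.8]          # predicted probability of the positive class
print(roc_auc_score(y_true, y_prob))    # 0.75

# MAE for length-of-stay, in days
los_true = [3.0, 5.5, 2.0]
los_pred = [4.0, 5.0, 2.5]
print(mean_absolute_error(los_true, los_pred))  # ~0.67 days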

BEHRT: Bidirectional Encoder Representations from Health Records

Paper

BEHRT: Transformer for Electronic Health Records (Li et al., 2020, Scientific Reports)

Architecture

  • Similar to BERT for NLP, but for EHR event sequences
  • Pre-training on diagnosis code sequences (ICD-10)
  • Bidirectional encoding captures past and future context

Pre-training Tasks

  1. Masked diagnosis prediction: Mask random diagnoses, predict from context
  2. Visit type prediction: Predict whether event is inpatient/outpatient/ED
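
A sketch of how the two objectives above could be combined into a single pre-training loss; the head names, sizes, and weighting are illustrative assumptions, not taken from the BEHRT paper:

import torch
import torch.nn as nn
import torch.nn.functional as F

class BEHRTPretrainHeads(nn.Module):
    """Hypothetical heads for the two BEHRT-style objectives."""
    def __init__(self, d_model=768, n_diagnoses=2000, n_visit_types=3):
        super().__init__()
        self.mlm_head = nn.Linear(d_model, n_diagnoses)      # masked diagnosis prediction
        self.visit_head = nn.Linear(d_model, n_visit_types)  # inpatient / outpatient / ED

def pretrain_loss(heads, encoded, diag_targets, mlm_mask, visit_targets, alpha=1.0):
    # encoded: (batch, seq_len, d_model) contextual token representations
    mlm_logits = heads.mlm_head(encoded)                     # (batch, seq_len, n_diagnoses)
    mlm_loss = F.cross_entropy(mlm_logits[mlm_mask], diag_targets[mlm_mask])

    visit_logits = heads.visit_head(encoded)                 # (batch, seq_len, n_visit_types)
    visit_loss = F.cross_entropy(
        visit_logits.view(-1, visit_logits.size(-1)), visit_targets.view(-1)
    )
    return mlm_loss + alpha * visit_loss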

Key Differences from ETHOS

  • BEHRT focuses on diagnosis codes only
  • ETHOS uses all event types (diagnoses, procedures, meds, labs)
  • BEHRT requires fine-tuning for downstream tasks
  • ETHOS emphasizes zero-shot transfer

Med-BERT: Pre-trained Contextualized Embeddings

Paper

Med-BERT: Pre-trained Contextualized Embeddings on Large-Scale Structured EHRs for Disease Prediction  (Rasmy et al., 2020)

Key Features

  1. Hierarchical embeddings: Both code-level and visit-level representations
  2. Multimodal: Combines structured codes and clinical notes
  3. Strong readmission prediction: State-of-the-art on 30-day readmission

Hierarchical Structure

class MedBERT(nn.Module):
    def __init__(self, vocab_size=50000, d_model=768):
        super().__init__()
        # Code-level encoder
        self.code_encoder = nn.TransformerEncoder(...)
        # Visit-level encoder (encode sequence of visits)
        self.visit_encoder = nn.TransformerEncoder(...)

    def forward(self, visits):
        # visits: list of visits, each visit is a list of codes
        # Encode each visit
        visit_representations = []
        for visit_codes in visits:
            code_emb = self.code_encoder(visit_codes)  # encode codes within visit
            visit_repr = code_emb.mean(dim=0)          # pool to visit representation
            visit_representations.append(visit_repr)

        # Encode visit sequence
        visit_sequence = torch.stack(visit_representations)
        patient_repr = self.visit_encoder(visit_sequence)
        return patient_repr
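
Since the encoders above are left elided, here is a tiny self-contained illustration of the same two-level idea, with a plain embedding plus mean pooling standing in for the code-level encoder:

import torch
import torch.nn as nn

# Toy two-level encoder: embed codes, pool within each visit, then contextualize visits.
d_model = 64
code_embedding = nn.Embedding(500, d_model)
visit_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=4, batch_first=True)
visit_encoder = nn.TransformerEncoder(visit_layer, num_layers=2)

# Three visits for one patient, each a variable-length list of code IDs
visits = [torch.tensor([12, 47, 301]), torch.tensor([47]), torch.tensor([88, 12])]

visit_reprs = [code_embedding(codes).mean(dim=0) for codes in visits]  # intra-visit pooling
visit_seq = torch.stack(visit_reprs).unsqueeze(0)                      # (1, n_visits, d_model)
patient_repr = visit_encoder(visit_seq).mean(dim=1)                    # inter-visit encoding
print(patient_repr.shape)  # torch.Size([1, 64])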

Advantages

  • Two-level hierarchy: Captures both intra-visit and inter-visit patterns
  • Visit-level reasoning: Understands temporal visit sequences
  • Multimodal fusion: Integrates codes + clinical notes

GatorTron: Large-Scale Clinical Language Model

Paper

A large language model for electronic health records (Yang et al., 2022, npj Digital Medicine)

Scale

  • 8.9 billion parameters (largest variant; 345M and 3.9B versions were also released)
  • Trained on >90 billion words of text, including >82 billion words of de-identified clinical notes from University of Florida Health
  • BERT-style encoder architecture, scaled up with Megatron-LM

Key Features

  • Largest clinical LM (as of 2022)
  • Pre-trained with masked language modeling on clinical narratives at unprecedented scale
  • Strong performance on medical NLP benchmarks (NER, relation extraction, semantic textual similarity, NLI, medical QA); the follow-up GatorTronGPT models add GPT-style text generation

Comparison to ClinicalBERT

  • ClinicalBERT: 110M parameters, encoder-only, trained on MIMIC-III notes (newer variants can use MIMIC-IV with 269,573 additional ED notes)
  • GatorTron: 8.9B parameters, encoder (BERT-style), trained on >90B words
  • Use cases:
    • ClinicalBERT: Classification, NER, embeddings
    • GatorTron: The same encoder-style tasks at much larger scale (concept extraction, relation extraction, question answering)
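
As a concrete example of the encoder-style use case, Bio_ClinicalBERT (the same checkpoint used in the EmergAI sketch below) can embed a clinical sentence in a few lines; the mean-pooling choice here is a common default, not a prescribed recipe:

import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained('emilyalsentzer/Bio_ClinicalBERT')
encoder = AutoModel.from_pretrained('emilyalsentzer/Bio_ClinicalBERT')

text = "Patient presents with acute chest pain radiating to the left arm."
inputs = tokenizer(text, return_tensors='pt', truncation=True)

with torch.no_grad():
    outputs = encoder(**inputs)

# Mean-pool the last hidden state into a single 768-d sentence embedding
embedding = outputs.last_hidden_state.mean(dim=1)
print(embedding.shape)  # torch.Size([1, 768])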

Comparison Table

Model        | Architecture           | Scale        | Pre-training Data               | Pre-training Task              | Key Strength
ETHOS        | Encoder (BERT-style)   | ~100M params | EHR sequences (all event types) | Masked event modeling          | Zero-shot transfer
BEHRT        | Encoder (BERT-style)   | ~110M params | Diagnosis code sequences        | Masked diagnosis + visit type  | Bidirectional context
Med-BERT     | Encoder (hierarchical) | ~110M params | Codes + clinical notes          | Masked prediction + visit task | Hierarchical structure
GatorTron    | Encoder (BERT-style)   | 8.9B params  | Clinical text (>90B words)      | Masked language modeling       | Scale for clinical NLP
ClinicalBERT | Encoder (BERT-style)   | 110M params  | MIMIC clinical notes            | Masked language modeling       | Clinical NLP

Multimodal Extension for EmergAI

Building on ETHOS for the EmergAI thesis:

import torch
import torch.nn as nn
from transformers import AutoModel


class EmergAIModel(nn.Module):
    """Multimodal extension of ETHOS for the EmergAI thesis.

    Combines structured EHR, symptom text, and symptom sketches.
    """

    def __init__(self, num_outcomes=2):
        super().__init__()
        # ETHOS-style encoder for structured EHR events
        self.ehr_encoder = ETHOSEncoder(vocab_size=50000, d_model=768, n_layers=12)

        # ClinicalBERT for patient-reported symptom text
        self.text_encoder = AutoModel.from_pretrained('emilyalsentzer/Bio_ClinicalBERT')

        # Vision encoder for 3D symptom sketches (ResNet or ViT backbone projected to 768-d)
        self.sketch_encoder = ResNet50(pretrained=True, out_features=768)

        # Cross-modal fusion (attention-based)
        self.cross_attention = nn.MultiheadAttention(embed_dim=768, num_heads=12, batch_first=True)

        # Outcome prediction head
        self.outcome_head = nn.Linear(768, num_outcomes)

    def forward(self, ehr_events, ehr_times, ehr_types, symptom_text, symptom_sketch):
        # Encode each modality
        ehr_repr = self.ehr_encoder(ehr_events, ehr_times, ehr_types)  # (batch, 768)
        text_repr = self.text_encoder(**symptom_text).pooler_output    # (batch, 768); symptom_text is a dict of tokenized inputs
        sketch_repr = self.sketch_encoder(symptom_sketch)              # (batch, 768)

        # Stack modalities for cross-attention
        modalities = torch.stack([ehr_repr, text_repr, sketch_repr], dim=1)  # (batch, 3, 768)

        # Cross-modal attention fusion
        fused_repr, _ = self.cross_attention(modalities, modalities, modalities)
        fused_repr = fused_repr.mean(dim=1)  # pool across modalities

        # Predict outcome
        outcome = self.outcome_head(fused_repr)
        return outcome
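
To make the fusion step concrete in isolation, the standalone sketch below runs the attention-based fusion over three dummy pooled modality vectors; shapes follow the model above, but this is illustrative rather than the thesis implementation:

import torch
import torch.nn as nn

batch = 4
ehr_repr = torch.randn(batch, 768)     # stand-ins for the three pooled modality vectors
text_repr = torch.randn(batch, 768)
sketch_repr = torch.randn(batch, 768)

cross_attention = nn.MultiheadAttention(embed_dim=768, num_heads=12, batch_first=True)

modalities = torch.stack([ehr_repr, text_repr, sketch_repr], dim=1)   # (batch, 3, 768)
fused, attn_weights = cross_attention(modalities, modalities, modalities)
fused = fused.mean(dim=1)                                             # (batch, 768)

print(fused.shape, attn_weights.shape)  # torch.Size([4, 768]) torch.Size([4, 3, 3])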

Thesis Research Question: Does adding patient-reported multimodal data (symptom text + 3D sketches) improve ED outcome prediction over structured EHR alone (ETHOS baseline)?

Related Topics

  • EHR Structure and Medical Coding: Understanding the data these models process
  • Medical Event Tokenization: How EHR events become token sequences
  • Language Model Training: Pre-training objectives and fine-tuning strategies
  • Healthcare Interpretability: Validating and explaining foundation model predictions

Applications

  • Mortality prediction: ICU and hospital mortality risk
  • Readmission prediction: 30-day readmission risk stratification
  • Length-of-stay estimation: Hospital/ICU duration forecasting
  • Disease progression: Temporal modeling of chronic conditions
  • Adverse event detection: Early warning systems for complications
  • Clinical trial recruitment: Identify eligible patients from EHR
  • Precision medicine: Personalized treatment recommendations

Key Takeaways

  1. ETHOS achieves zero-shot transfer through masked event pre-training on all EHR event types
  2. BEHRT pioneered BERT-style pre-training for diagnosis code sequences
  3. Med-BERT uses hierarchical encoding (code-level + visit-level) for richer representations
  4. GatorTron is the largest clinical LM (8.9B params), scaling BERT-style pre-training to billions of clinical words
  5. Pre-training paradigm: Masked modeling enables generalization across clinical tasks
  6. Multimodal extension: Combining ETHOS with patient-reported data (text + sketches) is the EmergAI thesis contribution