
Transformers for Patient Event Sequences

Transformers have become a widely used architecture in healthcare AI for modeling patient trajectories and clinical event sequences. This page covers how attention mechanisms apply to Electronic Health Records (EHR) and outcome prediction.

Why Transformers for Healthcare?

Traditional sequence models (RNNs, LSTMs) have difficulty with long, variable-length patient histories. Transformers address several key challenges:

The Core Problem

Each patient visit generates a sequence of clinical events:

```text
Patient Timeline:
[Admission] → [Vitals] → [Lab Tests] → [Diagnosis] → [Treatment] → [Outcome]
    t=0          t=1         t=2           t=3           t=4          t=5
```

Challenges:

  • Variable length: Some patients have 10 events, others 100+
  • Long-range dependencies: Early symptoms may predict outcomes weeks later
  • Parallel processing: Need to process millions of patients efficiently
  • Multiple data types: Diagnoses, procedures, medications, lab results
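Variable-length trajectories are typically batched by padding them to a common length. A boolean mask is then passed so that attention skips the padded slots. The following is a minimal PyTorch sketch, assuming event ID 0 is reserved for padding; `pad_trajectories` is a hypothetical helper name.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

PAD_ID = 0  # assumption: ID 0 is reserved for padding

def pad_trajectories(trajectories):
    """Pad lists of event IDs into a (batch, max_len) tensor plus a padding mask.

    The mask follows PyTorch's `src_key_padding_mask` convention:
    True marks padded positions that attention should ignore.
    """
    lengths = torch.tensor([len(t) for t in trajectories])
    event_ids = pad_sequence(
        [torch.tensor(t, dtype=torch.long) for t in trajectories],
        batch_first=True,
        padding_value=PAD_ID,
    )
    # Position j is padding when j >= that patient's true length
    positions = torch.arange(event_ids.size(1)).unsqueeze(0)  # (1, max_len)
    padding_mask = positions >= lengths.unsqueeze(1)           # (batch, max_len)
    return event_ids, padding_mask

# Two patients with 3 and 5 events
event_ids, padding_mask = pad_trajectories([[1, 2, 3], [1, 2, 3, 4, 5]])
print(event_ids)
print(padding_mask)
```

The mask is built from the true lengths rather than from `event_ids == PAD_ID`. That way, the mask stays correct even if the vocabulary is later changed so that a real event shares ID 0.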

Transformer Advantages

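| Challenge | RNN/LSTM | Transformer |
|---|---|---|
| Long sequences | Vanishing gradients | Direct connections via attention |
| Parallelization | Sequential processing | Fully parallel across sequence positions |
| Long-range dependencies | Difficult over long spans | Modeled explicitly via self-attention |
| Interpretability | Black box | Attention weights can be inspected (a partial view, not a full explanation) |
| Context window | Fixed-size hidden state | Full sequence context (up to the maximum length) |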
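The "direct connections" row can be made concrete. In self-attention, every event's output is a softmax-weighted mix of all events in the sequence, so the path between any two events is a single step regardless of how far apart they are. The following is a from-scratch sketch of single-head scaled dot-product attention; random projection matrices stand in for learned weights.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d_model = 6, 16
x = torch.randn(seq_len, d_model)  # one patient's 6 event embeddings

# Random matrices stand in for learned query/key/value projections
W_q, W_k, W_v = (torch.randn(d_model, d_model) for _ in range(3))
q, k, v = x @ W_q, x @ W_k, x @ W_v

# (seq_len, seq_len): row i holds how much event i attends to every event j
scores = q @ k.T / d_model ** 0.5
weights = F.softmax(scores, dim=-1)
output = weights @ v  # (seq_len, d_model)

# Event 5 reaches event 0 directly, with no recurrence in between
print(weights[5, 0].item())
```

An RNN would have to carry information from event 0 through five hidden-state updates to reach event 5. Here it is a single entry of the `weights` matrix.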

Patient Event Sequence Modeling

Representing Patient Trajectories

Each medical event is a discrete token (similar to words in NLP):

```python
# Example patient trajectory
trajectory = [
    "DIAG:I21.0",         # ICD-10: Acute myocardial infarction
    "PROC:PCI",           # Procedure: Percutaneous coronary intervention
    "MED:aspirin",        # Medication
    "LAB:troponin_high",  # Lab result
    "VISIT:followup_30d"  # Follow-up visit
]

# Create event vocabulary
event_vocab = {
    "DIAG:I21.0": 1,
    "PROC:PCI": 2,
    "MED:aspirin": 3,
    "LAB:troponin_high": 4,
    "VISIT:followup_30d": 5,
    # ... thousands more events
}

# Tokenize trajectory
tokens = [event_vocab[event] for event in trajectory]  # [1, 2, 3, 4, 5]
```
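In practice, the vocabulary is built from the training data and reserves a few IDs for special tokens. The sketch below follows common BERT-style conventions; the special-token names and the `min_count` threshold are assumptions, not a fixed standard. Index 0 is kept for padding, consistent with the example above starting real events at 1. Events never seen in training map to `[UNK]`.

```python
from collections import Counter

SPECIAL_TOKENS = ["[PAD]", "[UNK]", "[MASK]", "[CLS]"]  # assumed conventions

def build_event_vocab(trajectories, min_count=1):
    """Map each event string to an integer ID; rare events fall back to [UNK]."""
    counts = Counter(event for traj in trajectories for event in traj)
    vocab = {tok: i for i, tok in enumerate(SPECIAL_TOKENS)}
    for event, count in sorted(counts.items()):
        if count >= min_count:
            vocab[event] = len(vocab)
    return vocab

def tokenize(trajectory, vocab):
    """Convert event strings to IDs, using [UNK] for out-of-vocabulary events."""
    unk = vocab["[UNK]"]
    return [vocab.get(event, unk) for event in trajectory]

train = [
    ["DIAG:I21.0", "PROC:PCI", "MED:aspirin"],
    ["DIAG:I21.0", "MED:aspirin", "LAB:troponin_high"],
]
vocab = build_event_vocab(train)
print(tokenize(["DIAG:I21.0", "MED:metoprolol"], vocab))  # [4, 1]
```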

Transformer Encoder for Patient Representation

```python
import torch
import torch.nn as nn


class PatientEventEncoder(nn.Module):
    """
    Transformer encoder for patient event sequences

    Encodes variable-length sequences of medical events into contextual
    per-event representations, which can be pooled for outcome prediction.
    """

    def __init__(self, event_vocab_size, d_model=512, num_heads=8,
                 num_layers=6, max_seq_len=512):
        super().__init__()

        # Event embedding: convert event IDs to dense vectors
        self.event_embedding = nn.Embedding(event_vocab_size, d_model)

        # Fixed sinusoidal positional encoding: capture event order.
        # Registered as a buffer so it follows .to(device) but is not trained.
        self.register_buffer(
            "positional_encoding",
            self._create_positional_encoding(max_seq_len, d_model)
        )

        # Transformer encoder layers
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=2048,
            dropout=0.1,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers
        )

        # Layer normalization
        self.norm = nn.LayerNorm(d_model)

    def forward(self, event_sequence, mask=None):
        """
        Args:
            event_sequence: (batch, seq_len) - sequence of event IDs (seq_len <= max_seq_len)
            mask: (batch, seq_len) bool padding mask, True at padded positions
        Returns:
            embeddings: (batch, seq_len, d_model) - contextual event representations
        """
        seq_len = event_sequence.shape[1]

        # Embed events
        x = self.event_embedding(event_sequence)  # (batch, seq_len, d_model)

        # Add positional information (broadcast over the batch)
        x = x + self.positional_encoding[:seq_len, :]

        # Transform with attention; padded positions are ignored as keys
        x = self.transformer_encoder(x, src_key_padding_mask=mask)

        # Normalize
        x = self.norm(x)

        return x  # (batch, seq_len, d_model)

    @staticmethod
    def _create_positional_encoding(max_len, d_model):
        """Sinusoidal positional encoding"""
        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float)
            * -(torch.log(torch.tensor(10000.0)) / d_model)
        )
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return pe
```
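There is a quick way to check that the padding mask is wired correctly. When `src_key_padding_mask` is used, the outputs at real positions should not change if extra padding is appended. The following is a self-contained check on a small `nn.TransformerEncoder` with random weights. The sizes are arbitrary, ID 0 is assumed to be padding, and dropout is disabled via `eval()` so that the two runs are comparable.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 32
embed = nn.Embedding(20, d_model, padding_idx=0)  # ID 0 = padding (assumption)
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model, nhead=4, batch_first=True), num_layers=2
)
encoder.eval()  # disable dropout so the two runs are comparable

events = torch.tensor([[5, 3, 9]])           # one patient, 3 events
padded = torch.tensor([[5, 3, 9, 0, 0, 0]])  # same patient plus 3 padding slots
pad_mask = padded == 0                       # True marks padding

with torch.no_grad():
    out_plain = encoder(embed(events))
    out_padded = encoder(embed(padded), src_key_padding_mask=pad_mask)

# Real positions match; padded positions carry no meaningful information
print(torch.allclose(out_plain, out_padded[:, :3], atol=1e-4))  # True
```

Because padded positions still produce outputs, any pooling step downstream of the encoder should also exclude them.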

Multimodal Fusion with Cross-Attention

Healthcare data is inherently multimodal: structured events, clinical text, imaging.

Architecture for Multi-Modal Patient Representation

```python
import torch
import torch.nn as nn
from transformers import AutoModel
from torchvision.models import resnet50, ResNet50_Weights


class MultimodalPatientEncoder(nn.Module):
    """
    Fuse structured EHR events, clinical text, and imaging
    using cross-modal attention
    """

    def __init__(self, d_model=512, num_heads=8):
        super().__init__()

        # Encoder for each modality
        self.ehr_encoder = PatientEventEncoder(
            event_vocab_size=10000, d_model=d_model, num_heads=num_heads
        )

        # Clinical text encoder (BERT-based, 768-dim hidden states)
        self.text_encoder = AutoModel.from_pretrained('emilyalsentzer/Bio_ClinicalBERT')
        self.text_projection = nn.Linear(768, d_model)

        # Image encoder: ImageNet-pretrained ResNet-50 with its classifier removed
        self.image_encoder = resnet50(weights=ResNet50_Weights.DEFAULT)
        self.image_encoder.fc = nn.Identity()
        self.image_projection = nn.Linear(2048, d_model)

        # Cross-modal attention layers
        self.ehr_to_text_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.text_to_ehr_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.image_fusion_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

        # One layer norm per residual branch
        self.norm_ehr = nn.LayerNorm(d_model)
        self.norm_text = nn.LayerNorm(d_model)
        self.norm_image = nn.LayerNorm(d_model)

    def forward(self, ehr_events, clinical_text=None, image=None):
        """
        Args:
            ehr_events: (batch, seq_len) - event sequence
            clinical_text: dict from a Hugging Face tokenizer with 'input_ids'
                and 'attention_mask', each (batch, text_len) (optional)
            image: (batch, 3, H, W) - ImageNet-normalized images (optional)
        Returns:
            fused_representation: (batch, d_model) - multimodal patient embedding
            attention_weights: dict - for interpretability
        """
        attn_ehr_to_text = attn_text_to_ehr = attn_image = None

        # Encode EHR events
        ehr_repr = self.ehr_encoder(ehr_events)  # (batch, seq_len, d_model)

        # If clinical text is available, fuse with cross-attention
        if clinical_text is not None:
            text_output = self.text_encoder(**clinical_text)
            text_repr = self.text_projection(
                text_output.last_hidden_state
            )  # (batch, text_len, d_model)
            # True where the text is padding, so attention ignores those tokens
            text_pad_mask = clinical_text['attention_mask'] == 0

            # Cross-attention: EHR attends to text
            ehr_enhanced, attn_ehr_to_text = self.ehr_to_text_attn(
                query=ehr_repr, key=text_repr, value=text_repr,
                key_padding_mask=text_pad_mask
            )
            ehr_enhanced = self.norm_ehr(ehr_repr + ehr_enhanced)

            # Cross-attention: text attends to EHR
            text_enhanced, attn_text_to_ehr = self.text_to_ehr_attn(
                query=text_repr, key=ehr_repr, value=ehr_repr
            )
            text_enhanced = self.norm_text(text_repr + text_enhanced)

            # Combine along the sequence dimension
            combined = torch.cat([ehr_enhanced, text_enhanced], dim=1)
        else:
            combined = ehr_repr

        # If an image is available, fuse via attention
        if image is not None:
            image_features = self.image_encoder(image)  # (batch, 2048)
            image_repr = self.image_projection(image_features).unsqueeze(1)  # (batch, 1, d_model)
            # With a single image key every attention weight is 1, so this
            # effectively adds a projected image vector to each position
            image_out, attn_image = self.image_fusion_attn(
                query=combined, key=image_repr, value=image_repr
            )
            # Residual connection keeps the EHR/text information
            combined = self.norm_image(combined + image_out)

        # Pool to a single representation (padding is not masked out here)
        fused_representation = combined.mean(dim=1)  # (batch, d_model)

        return fused_representation, {
            'ehr_to_text': attn_ehr_to_text,
            'text_to_ehr': attn_text_to_ehr,
            'image': attn_image,
        }
```
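The cross-attention step can be checked in isolation. The output length follows the query (the EHR events), and the returned weights have shape (batch, query_len, key_len). A `key_padding_mask` gives padded text tokens zero weight. The following is a small sketch in which random tensors stand in for the encoded modalities.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, ehr_len, text_len, d_model = 2, 5, 7, 32
ehr_repr = torch.randn(batch, ehr_len, d_model)    # stand-in for encoded events
text_repr = torch.randn(batch, text_len, d_model)  # stand-in for projected text

# The second note is shorter: its last 3 tokens are padding (True = ignore)
text_pad_mask = torch.zeros(batch, text_len, dtype=torch.bool)
text_pad_mask[1, 4:] = True

cross_attn = nn.MultiheadAttention(d_model, num_heads=4, batch_first=True)
out, weights = cross_attn(
    query=ehr_repr, key=text_repr, value=text_repr,
    key_padding_mask=text_pad_mask,
)

print(out.shape)      # torch.Size([2, 5, 32]): one output per EHR event
print(weights.shape)  # torch.Size([2, 5, 7]): averaged over heads by default
print(weights[1, :, 4:].abs().max().item())  # 0.0: padded tokens get no weight
```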

Temporal Modeling

Patient trajectories have temporal structure: when events happen, and how far apart they are, matters clinically.

Time-Aware Positional Encoding

Standard positional encoding captures order, but not actual time gaps:

```python
import math
import torch


def temporal_positional_encoding(timestamps, d_model, max_period=10000.0):
    """
    Encode actual time (not just position)

    Same sinusoidal form as the standard positional encoding, evaluated at
    real timestamps. There is no per-patient rescaling to [0, 1]: that would
    make a 2-day stay and a 2-year history look alike and erase gap sizes.

    Args:
        timestamps: (batch, seq_len) - timestamps (hours since admission)
        d_model: model dimension (even)
        max_period: longest wavelength, in the same units as timestamps
            (an assumed default; tune to the time scale of the data)
    Returns:
        temporal_pe: (batch, seq_len, d_model)
    """
    # Frequencies 1 / max_period^(2i / d_model), shape (d_model // 2,)
    i = torch.arange(d_model // 2, dtype=torch.float, device=timestamps.device)
    freqs = torch.exp(-math.log(max_period) * 2 * i / d_model)

    # (batch, seq_len, d_model // 2)
    angles = timestamps.unsqueeze(-1).float() * freqs

    pe = torch.zeros(*timestamps.shape, d_model, device=timestamps.device)
    pe[..., 0::2] = torch.sin(angles)
    pe[..., 1::2] = torch.cos(angles)
    return pe
```

Incorporating Time Deltas

```python
# Add time gaps as special tokens; times are in days since admission
trajectory_with_time = [
    ("DIAG:I21.0", 0),
    ("PROC:PCI", 0),
    ("TIME_DELTA:1day", None),     # explicit time-gap token (no timestamp)
    ("MED:aspirin", 1),
    ("TIME_DELTA:30days", None),
    ("VISIT:followup", 31),
]
```
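Gap tokens like these can also be generated automatically from timestamps instead of being written by hand. The following is a sketch assuming times in days; the bucket edges and token names are illustrative assumptions, not a standard scheme.

```python
# Illustrative gap buckets in days: (exclusive upper bound, token); an assumption
GAP_BUCKETS = [
    (1, "TIME_DELTA:<1day"),
    (7, "TIME_DELTA:1-7days"),
    (30, "TIME_DELTA:7-30days"),
    (float("inf"), "TIME_DELTA:30+days"),
]

def insert_time_deltas(events):
    """events: list of (event_code, day) sorted by day.

    Returns a flat token list with a bucketed gap token wherever time advances.
    """
    tokens, prev_day = [], None
    for code, day in events:
        if prev_day is not None and day > prev_day:
            gap = day - prev_day
            # First bucket whose upper bound exceeds the gap
            tokens.append(next(tok for bound, tok in GAP_BUCKETS if gap < bound))
        tokens.append(code)
        prev_day = day
    return tokens

events = [("DIAG:I21.0", 0), ("PROC:PCI", 0), ("MED:aspirin", 1), ("VISIT:followup", 31)]
print(insert_time_deltas(events))
# ['DIAG:I21.0', 'PROC:PCI', 'TIME_DELTA:1-7days', 'MED:aspirin', 'TIME_DELTA:30+days', 'VISIT:followup']
```

Bucketing keeps the vocabulary small at the cost of time resolution. Finer buckets (or a continuous encoding of the gap) trade vocabulary size for precision.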

Zero-Shot Clinical Prediction

The ETHOS Approach

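ETHOS (Enhanced Transformer for Health Outcome Simulation; Renc et al., 2024) enables zero-shot prediction in two stages:

Pre-training: Self-supervised learning on patient trajectories

  • Self-supervised objective: predict held-out events from their context. ETHOS trains a decoder-only model to predict the next event; BERT-style models such as BEHRT and Med-BERT instead predict randomly masked events, as in the example below
  • Learns general patterns of care and disease progression without outcome labels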
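```python
import torch
import torch.nn.functional as F


def mask_events(sequence, mask_prob, mask_token_id, pad_token_id=0):
    """Replace a random subset of non-padding events with [MASK] (simplified BERT masking).

    Returns the masked input and labels that are -100 (ignored by the loss)
    everywhere except at the masked positions.
    """
    selected = (torch.rand(sequence.shape, device=sequence.device) < mask_prob) \
        & (sequence != pad_token_id)
    labels = sequence.masked_fill(~selected, -100)
    masked_sequence = sequence.masked_fill(selected, mask_token_id)
    return masked_sequence, labels


def pretrain_patient_encoder(model, data_loader, vocab_size, mask_token_id, lr=1e-4):
    """
    Pre-train on unlabeled patient data (self-supervised)

    `model` is assumed to map (batch, seq_len) event IDs to
    (batch, seq_len, vocab_size) logits, e.g. PatientEventEncoder
    followed by a linear output layer.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    model.train()

    for batch in data_loader:  # batch: (batch, seq_len) event IDs
        # Randomly mask 15% of events
        masked_batch, labels = mask_events(batch, mask_prob=0.15, mask_token_id=mask_token_id)

        # Predict the original events at the masked positions
        logits = model(masked_batch)
        loss = F.cross_entropy(
            logits.reshape(-1, vocab_size), labels.reshape(-1), ignore_index=-100
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```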
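The function above uses masked event modeling. A next-event (autoregressive) objective, the kind used by decoder-only models such as ETHOS, differs mainly in its targets and its attention mask. The following is a toy sketch of that loss; the tiny model and sizes are arbitrary stand-ins, not ETHOS's architecture.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, d_model, seq_len, batch = 50, 32, 8, 4

embed = nn.Embedding(vocab_size, d_model)
layer = nn.TransformerEncoderLayer(d_model, nhead=4, batch_first=True)
decoder = nn.TransformerEncoder(layer, num_layers=2)  # encoder stack + causal mask = decoder-only
lm_head = nn.Linear(d_model, vocab_size)

events = torch.randint(1, vocab_size, (batch, seq_len))  # toy event IDs
inputs, targets = events[:, :-1], events[:, 1:]          # predict event t+1 from events <= t

# Upper-triangular -inf mask: position t cannot attend to positions after t
causal_mask = nn.Transformer.generate_square_subsequent_mask(inputs.size(1))
hidden = decoder(embed(inputs), mask=causal_mask)  # positional encoding omitted for brevity
logits = lm_head(hidden)                           # (batch, seq_len - 1, vocab_size)

loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))
loss.backward()
print(loss.item())
```

Because of the causal mask, every position is a training example, and no `[MASK]` token is needed. The same model can later generate trajectories one event at a time.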

Zero-shot transfer: Predict new outcomes without task-specific training

  • No retraining needed for new prediction tasks
  • Generalizes from learned patient representations
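```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def zero_shot_predict(model, patient_sequence, outcome_query):
    """
    Score an outcome without fine-tuning on that specific outcome
    (embedding-similarity illustration; not the trajectory-simulation
    procedure ETHOS itself uses).

    Assumes a hypothetical `model` that returns one pooled embedding per
    patient and exposes an `encode_text` method mapping an outcome description
    into the same embedding space (as in contrastively trained models).

    Args:
        patient_sequence: (batch, seq_len) historical event IDs
        outcome_query: outcome of interest (e.g., "ICU admission")
    Returns:
        score: (batch,) similarity rescaled to [0, 1]. This is NOT a calibrated
        probability; calibrate on labeled data before interpreting it as one.
    """
    # Encode patient trajectory
    patient_repr = model(patient_sequence)           # (batch, d_model)

    # Encode outcome query
    outcome_repr = model.encode_text(outcome_query)  # (1, d_model)

    # Cosine similarity lies in [-1, 1]; map it to [0, 1]
    similarity = F.cosine_similarity(patient_repr, outcome_repr, dim=-1)
    return (similarity + 1) / 2
```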
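ETHOS-style zero-shot prediction works differently. Roughly, an autoregressive model samples many possible future timelines, and the outcome probability is estimated as the fraction of simulated timelines in which the outcome appears. The following is a sketch of that Monte Carlo estimate, not ETHOS's implementation. It assumes any callable that returns next-event probabilities; here, a fixed toy transition table stands in for the trained transformer, and all token names and probabilities are made up.

```python
import random

def estimate_outcome_probability(next_event_probs, history, outcome, stop_tokens,
                                 num_rollouts=1000, max_steps=50, seed=0):
    """Monte Carlo estimate of P(outcome occurs before a stop token).

    next_event_probs: callable(timeline) -> dict {event: probability}
    """
    rng = random.Random(seed)
    hits = 0
    for _ in range(num_rollouts):
        timeline = list(history)
        for _ in range(max_steps):
            probs = next_event_probs(timeline)
            event = rng.choices(list(probs), weights=list(probs.values()))[0]
            timeline.append(event)
            if event == outcome:
                hits += 1
                break
            if event in stop_tokens:
                break
    return hits / num_rollouts

# Toy stand-in for a trained model: depends only on the last event
TRANSITIONS = {
    "ADMISSION": {"LAB:lactate_high": 0.5, "DISCHARGE": 0.5},
    "LAB:lactate_high": {"ICU_ADMISSION": 0.4, "DISCHARGE": 0.6},
}

def toy_model(timeline):
    return TRANSITIONS[timeline[-1]]

p = estimate_outcome_probability(toy_model, ["ADMISSION"], "ICU_ADMISSION",
                                 stop_tokens={"DISCHARGE"})
print(round(p, 2))  # close to the exact value 0.5 * 0.4 = 0.2
```

Any new outcome that exists as a token can be queried this way without retraining. The cost is many forward passes per patient, and the estimate's variance shrinks only as `num_rollouts` grows.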

Attention for Clinical Interpretability

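Key advantage: Attention weights show which events the model attended to when making a prediction. They are a useful diagnostic but not a faithful explanation on their own, so attention patterns should be reviewed with clinicians.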

Visualizing Clinical Reasoning

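```python
import torch


def visualize_patient_attention(model, patient_sequence, event_names, top_k=5):
    """
    Show which events the model focused on for prediction

    Assumes a hypothetical `model` that accepts `return_attention=True` and
    returns (prediction, attention_weights), where attention_weights is a
    (num_layers, num_heads, seq_len, seq_len) tensor for a single patient.
    nn.TransformerEncoder does not provide this out of the box.

    Args:
        patient_sequence: (1, seq_len) event IDs for one patient
        event_names: list of seq_len readable event labels
    """
    with torch.no_grad():
        prediction, attention_weights = model(patient_sequence, return_attention=True)

    # Average across layers and heads -> (seq_len, seq_len)
    avg_attention = attention_weights.mean(dim=(0, 1))

    # For each event, show which events it attended to most
    k = min(top_k, len(event_names))  # guard against short sequences
    for i, event in enumerate(event_names):
        top_attended = avg_attention[i].topk(k).indices.tolist()
        print(f"{event} attended to: {[event_names[j] for j in top_attended]}")

    return prediction
```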
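`nn.TransformerEncoder` does not return attention weights, so a `return_attention=True` option like the one assumed above has to be built by hand. One building block is to call `nn.MultiheadAttention` with `need_weights=True` and `average_attn_weights=False`. This returns per-head weights of shape (batch, num_heads, seq_len, seq_len). The following self-contained sketch uses a toy sequence; the event names and the untrained attention module are placeholders.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
events = ["Admission", "Chest Pain", "ECG Normal", "Discharge"]
d_model, num_heads = 16, 4
x = torch.randn(1, len(events), d_model)  # stand-in for embedded events

attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
attn.eval()
with torch.no_grad():
    _, weights = attn(x, x, x, need_weights=True, average_attn_weights=False)
# weights: (batch, num_heads, seq_len, seq_len)

avg = weights[0].mean(dim=0)  # average over heads -> (seq_len, seq_len)
k = min(2, len(events))       # guard against sequences shorter than k
for i, event in enumerate(events):
    top = avg[i].topk(k).indices.tolist()
    print(f"{event} attended to: {[events[j] for j in top]}")
```

Keeping per-head weights before averaging makes it possible to see whether different heads focus on different event types, which a single averaged map hides.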

Illustrative example (hypothetical attention weights):

```text
Predicting: Risk of another readmission within 30 days

Patient Event Sequence:
[Admission] [Chest Pain] [ECG Normal] [Discharge] [7 days gap] [Readmission] [SOB]

Attention Weights:
  Admission:   0.05
  Chest Pain:  0.35  ← High attention
  ECG Normal:  0.20  ← Moderate attention
  Discharge:   0.05
  7 days gap:  0.10  ← Time gap matters
  Readmission: 0.15
  SOB:         0.10

Possible interpretation: The model appears to focus on a pattern of recurring
cardiac symptoms (chest pain → normal ECG → short gap → return with
shortness of breath, SOB).
```

Practical Implementation

Complete Outcome Prediction System

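```python
import torch
import torch.nn as nn


class ClinicalOutcomePredictor(nn.Module):
    """
    End-to-end system for predicting clinical outcomes
    """

    def __init__(self, num_outcomes=5, d_model=512):
        super().__init__()

        # Patient encoder
        self.patient_encoder = PatientEventEncoder(
            event_vocab_size=10000,
            d_model=d_model,
            num_heads=8,
            num_layers=6
        )

        # Outcome prediction head
        self.prediction_head = nn.Sequential(
            nn.Linear(d_model, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_outcomes)
        )

    def forward(self, patient_sequence, padding_mask=None):
        # Encode patient trajectory; padding_mask is True at padded positions
        embeddings = self.patient_encoder(patient_sequence, mask=padding_mask)

        # Pool to a single representation: mean over real (non-padded) events.
        # A [CLS] token at position 0 is a common alternative.
        if padding_mask is None:
            patient_repr = embeddings.mean(dim=1)  # (batch, d_model)
        else:
            keep = (~padding_mask).unsqueeze(-1).float()
            patient_repr = (embeddings * keep).sum(dim=1) / keep.sum(dim=1).clamp(min=1.0)

        # Predict outcomes (one logit per outcome)
        logits = self.prediction_head(patient_repr)  # (batch, num_outcomes)
        return logits


# Training
model = ClinicalOutcomePredictor(num_outcomes=5)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = nn.BCEWithLogitsLoss()  # Multi-label: outcomes are not mutually exclusive

model.train()
# train_loader is assumed to yield (event_ids, padding_mask, outcomes),
# where outcomes holds 0/1 labels of shape (batch, num_outcomes)
for patient_sequences, padding_mask, outcomes in train_loader:
    logits = model(patient_sequences, padding_mask)
    loss = criterion(logits, outcomes.float())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```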

Best Practices

Do:

  • Pre-train on a large corpus of patient sequences (self-supervised)
  • Use positional encoding that captures temporal information
  • Implement attention visualization for interpretability
  • Validate attention patterns with clinicians
  • Fine-tune with lower learning rates for specific outcomes
  • Handle variable-length sequences with padding masks
  • Consider multi-modal fusion (EHR + text + imaging)
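Fine-tuning with a lower learning rate usually means giving the pre-trained encoder a smaller step size than the freshly initialized prediction head. The following is a sketch using AdamW parameter groups. The 1e-5 / 1e-3 values are illustrative, not tuned recommendations, and the small modules stand in for a pre-trained encoder and a new head.

```python
import torch
import torch.nn as nn

encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True), num_layers=2
)  # stand-in for a pre-trained patient encoder
head = nn.Linear(64, 5)  # new task-specific head (5 outcomes)

optimizer = torch.optim.AdamW([
    {"params": encoder.parameters(), "lr": 1e-5},  # small steps preserve pre-trained features
    {"params": head.parameters(), "lr": 1e-3},     # larger steps for randomly initialized weights
], weight_decay=0.01)

print([group["lr"] for group in optimizer.param_groups])  # [1e-05, 0.001]
```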

Don’t:

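  • Train from scratch when pre-trained weights or unlabeled data are available (pre-training helps most when labeled data is scarce)
  • Ignore temporal information (time gaps between events carry clinical signal)
  • Deploy without interpretability checks (attention weights are one useful signal, not a complete explanation)
  • Skip data-leakage checks (inputs must contain only events recorded before the prediction time)
  • Use excessively long context windows (self-attention cost grows quadratically with sequence length)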
  • Assume model generalizes across hospitals without validation
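A basic guard against leakage is to split each timeline at the prediction time. Only earlier events go into the input, and the label comes only from a window after that time. The following sketch assumes each event carries a timestamp in hours since admission; the function name and tuple layout are hypothetical.

```python
def make_example(events, prediction_time, outcome_code, window_hours):
    """Split one timeline into model input and label at prediction_time.

    events: list of (event_code, hours_since_admission) tuples
    Input: events strictly before prediction_time.
    Label: 1 if outcome_code occurs in (prediction_time, prediction_time + window_hours].
    Events exactly at prediction_time are excluded from both.
    """
    inputs = [code for code, t in events if t < prediction_time]
    label = int(any(
        code == outcome_code and prediction_time < t <= prediction_time + window_hours
        for code, t in events
    ))
    return inputs, label

timeline = [("ADMISSION", 0), ("LAB:troponin_high", 2), ("PROC:PCI", 6),
            ("ICU_ADMISSION", 30), ("DISCHARGE", 120)]
print(make_example(timeline, prediction_time=24, outcome_code="ICU_ADMISSION", window_hours=48))
# (['ADMISSION', 'LAB:troponin_high', 'PROC:PCI'], 1)
```

If the prediction time were moved past hour 30, the ICU admission would become part of the input and the label would drop to 0. This is exactly the kind of shift a leakage check should catch.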

Related Concepts

  • Attention Mechanism: core concept
  • Multi-Head Attention: parallel attention
  • Transformer Architecture: original paper
  • GPT Architecture: decoder-only alternative
  • EHR Data Structure: data format
  • Healthcare Foundation Models: pre-trained models (ETHOS, BEHRT)
  • Clinical Interpretability: explainability requirements

Learning Resources

Key Papers

  • BEHRT (Li et al., 2020): BERT for Electronic Health Records
  • ETHOS (Renc et al., 2024): Zero-shot health trajectory prediction
  • Med-BERT (Rasmy et al., 2021): Medical concept embedding
  • Hi-BEHRT (Li et al., 2021): Hierarchical transformers for EHR

Datasets

  • MIMIC-III: 53,423 hospital admissions of adult ICU patients (2001-2012) with detailed event sequences
  • MIMIC-IV v3.1: Over 65,000 ICU patients + over 200,000 ED patients (2008-2022) with enhanced data quality, ICD-9/10 codes, improved mortality data, and provider tracking
  • eICU: Multi-center ICU database (200,000 admissions)

Code

  • PyHealth: Python library for healthcare AI (includes EHR transformers)
  • Hugging Face Hub: Pre-trained clinical transformers (e.g., emilyalsentzer/Bio_ClinicalBERT, used above)