Transformers for Patient Event Sequences
Transformers have become a central architecture in healthcare AI for modeling patient trajectories and clinical event sequences. This page covers how attention mechanisms apply to Electronic Health Records (EHR) and outcome prediction.
Why Transformers for Healthcare?
Traditional sequence models (RNNs, LSTMs) struggle with the length and heterogeneity of patient data. Transformers address several key challenges:
The Core Problem
Each patient visit generates a sequence of clinical events:
Patient Timeline:
[Admission] → [Vitals] → [Lab Tests] → [Diagnosis] → [Treatment] → [Outcome]
    t=0          t=1          t=2           t=3            t=4           t=5

Challenges:
- Variable length: Some patients have 10 events, others 100+ (see the padding sketch after this list)
- Long-range dependencies: Early symptoms may predict outcomes weeks later
- Parallel processing: Need to process millions of patients efficiently
- Multiple data types: Diagnoses, procedures, medications, lab results
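As a concrete illustration of the variable-length point, event sequences of different lengths can be padded into a single batch tensor with an accompanying mask. A minimal sketch; the event IDs below are made up for illustration:

import torch
from torch.nn.utils.rnn import pad_sequence

# Two patients with different numbers of events (illustrative event IDs)
patient_a = torch.tensor([4, 17, 92, 3, 55])   # 5 events
patient_b = torch.tensor([4, 8, 110])          # 3 events

# Pad to a common length; ID 0 is reserved for padding here
batch = pad_sequence([patient_a, patient_b], batch_first=True, padding_value=0)
# batch -> [[  4,  17,  92,   3,  55],
#           [  4,   8, 110,   0,   0]]

# Boolean mask marking padded positions (True = position should be ignored)
padding_mask = batch == 0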
Transformer Advantages
| Challenge | RNN/LSTM | Transformer |
|---|---|---|
| Long sequences | Vanishing gradients | Direct connections via attention |
| Parallelization | Sequential processing | Fully parallel |
| Long-range dependencies | Difficult (>50 steps) | Explicit modeling via self-attention |
| Interpretability | Black box | Attention weights show reasoning |
| Context window | Fixed hidden state | Full sequence context |
Patient Event Sequence Modeling
Representing Patient Trajectories
Each medical event is a discrete token (similar to words in NLP):
# Example patient trajectory
trajectory = [
"DIAG:I21.0", # ICD-10: Acute myocardial infarction
"PROC:PCI", # Procedure: Percutaneous coronary intervention
"MED:aspirin", # Medication
"LAB:troponin_high", # Lab result
"VISIT:followup_30d" # Follow-up visit
]
# Create event vocabulary
event_vocab = {
"DIAG:I21.0": 1,
"PROC:PCI": 2,
"MED:aspirin": 3,
"LAB:troponin_high": 4,
"VISIT:followup_30d": 5,
# ... thousands more events
}
# Tokenize trajectory
tokens = [event_vocab[event] for event in trajectory]  # [1, 2, 3, 4, 5]

Transformer Encoder for Patient Representation
import torch
import torch.nn as nn
class PatientEventEncoder(nn.Module):
"""
Transformer encoder for patient event sequences
Encodes variable-length sequences of medical events into
fixed-size representations for outcome prediction.
"""
def __init__(self,
event_vocab_size,
d_model=512,
num_heads=8,
num_layers=6,
max_seq_len=512):
super().__init__()
# Event embedding: convert event IDs to dense vectors
self.event_embedding = nn.Embedding(event_vocab_size, d_model)
# Positional encoding: capture event order
self.positional_encoding = nn.Parameter(
self._create_positional_encoding(max_seq_len, d_model)
)
# Transformer encoder layers
encoder_layer = nn.TransformerEncoderLayer(
d_model=d_model,
nhead=num_heads,
dim_feedforward=2048,
dropout=0.1,
batch_first=True
)
self.transformer_encoder = nn.TransformerEncoder(
encoder_layer,
num_layers=num_layers
)
# Layer normalization
self.norm = nn.LayerNorm(d_model)
def forward(self, event_sequence, mask=None):
"""
Args:
event_sequence: (batch, seq_len) - sequence of event IDs
            mask: (batch, seq_len) - boolean padding mask (True marks padded positions)
Returns:
embeddings: (batch, seq_len, d_model) - contextual event representations
"""
batch_size, seq_len = event_sequence.shape
# Embed events
x = self.event_embedding(event_sequence) # (batch, seq_len, d_model)
# Add positional information
x = x + self.positional_encoding[:seq_len, :]
# Transform with attention
x = self.transformer_encoder(x, src_key_padding_mask=mask)
# Normalize
x = self.norm(x)
return x # (batch, seq_len, d_model)
def _create_positional_encoding(self, max_len, d_model):
"""Sinusoidal positional encoding"""
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) *
-(torch.log(torch.tensor(10000.0)) / d_model))
pe = torch.zeros(max_len, d_model)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
        return pe

Multimodal Fusion with Cross-Attention
Healthcare data is inherently multimodal: structured events, clinical text, imaging.
Architecture for Multi-Modal Patient Representation
class MultimodalPatientEncoder(nn.Module):
"""
Fuse structured EHR events, clinical text, and imaging
using cross-modal attention
"""
def __init__(self, d_model=512, num_heads=8):
super().__init__()
# Encoder for each modality
self.ehr_encoder = PatientEventEncoder(
event_vocab_size=10000,
d_model=d_model,
num_heads=num_heads
)
# Clinical text encoder (BERT-based)
from transformers import AutoModel
self.text_encoder = AutoModel.from_pretrained(
'emilyalsentzer/Bio_ClinicalBERT'
)
self.text_projection = nn.Linear(768, d_model)
        # Image encoder (ResNet-based, ImageNet-pretrained weights)
        from torchvision.models import resnet50, ResNet50_Weights
        self.image_encoder = resnet50(weights=ResNet50_Weights.DEFAULT)
        self.image_encoder.fc = nn.Identity()
        self.image_projection = nn.Linear(2048, d_model)
# Cross-modal attention layers
self.ehr_to_text_attn = nn.MultiheadAttention(
d_model, num_heads, batch_first=True
)
self.text_to_ehr_attn = nn.MultiheadAttention(
d_model, num_heads, batch_first=True
)
self.image_fusion_attn = nn.MultiheadAttention(
d_model, num_heads, batch_first=True
)
# Layer normalization
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
def forward(self, ehr_events, clinical_text=None, image=None):
"""
Args:
ehr_events: (batch, seq_len) - event sequence
            clinical_text: dict of tokenizer outputs (input_ids, attention_mask), each (batch, text_len)
image: (batch, 3, H, W) - medical images (optional)
Returns:
fused_representation: (batch, d_model) - multimodal patient embedding
attention_weights: dict - interpretability
"""
        batch_size = ehr_events.shape[0]
        # Attention maps are filled in only when the corresponding modality is present
        attn_ehr_to_text = attn_text_to_ehr = attn_image = None
        # Encode EHR events
        ehr_repr = self.ehr_encoder(ehr_events)  # (batch, seq_len, d_model)
# If clinical text available, fuse with cross-attention
if clinical_text is not None:
# Encode text
text_output = self.text_encoder(**clinical_text)
text_repr = self.text_projection(
text_output.last_hidden_state
) # (batch, text_len, d_model)
# Cross-attention: EHR attends to text
ehr_enhanced, attn_ehr_to_text = self.ehr_to_text_attn(
query=ehr_repr,
key=text_repr,
value=text_repr
)
ehr_enhanced = self.norm1(ehr_repr + ehr_enhanced)
# Cross-attention: Text attends to EHR
text_enhanced, attn_text_to_ehr = self.text_to_ehr_attn(
query=text_repr,
key=ehr_repr,
value=ehr_repr
)
text_enhanced = self.norm1(text_repr + text_enhanced)
# Combine
combined = torch.cat([ehr_enhanced, text_enhanced], dim=1)
else:
combined = ehr_repr
# If image available, fuse via attention
if image is not None:
image_features = self.image_encoder(image) # (batch, 2048)
image_repr = self.image_projection(image_features) # (batch, d_model)
image_repr = image_repr.unsqueeze(1) # (batch, 1, d_model)
combined, attn_image = self.image_fusion_attn(
query=combined,
key=image_repr,
value=image_repr
)
combined = self.norm2(combined)
# Pool to single representation
fused_representation = combined.mean(dim=1) # (batch, d_model)
        return fused_representation, {
            'ehr_to_text': attn_ehr_to_text,
            'text_to_ehr': attn_text_to_ehr,
            'image': attn_image
        }

Temporal Modeling
Patient trajectories have temporal structure - time matters!
Time-Aware Positional Encoding
Standard positional encoding captures order, but not actual time gaps:
def temporal_positional_encoding(timestamps, d_model):
"""
Encode actual time (not just position)
Args:
timestamps: (batch, seq_len) - absolute timestamps (hours since admission)
d_model: model dimension
Returns:
temporal_pe: (batch, seq_len, d_model)
"""
batch_size, seq_len = timestamps.shape
# Normalize timestamps to [0, 1] range
time_normalized = (timestamps - timestamps.min(dim=1, keepdim=True)[0])
time_normalized = time_normalized / (time_normalized.max(dim=1, keepdim=True)[0] + 1e-8)
# Apply sinusoidal encoding
    pe = torch.zeros(batch_size, seq_len, d_model, device=timestamps.device)
for i in range(d_model // 2):
freq = 1.0 / (10000 ** (2 * i / d_model))
pe[:, :, 2*i] = torch.sin(time_normalized * freq)
pe[:, :, 2*i+1] = torch.cos(time_normalized * freq)
    return pe

Incorporating Time Deltas
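Rather than relying on position alone, time gaps between consecutive events can be quantized into explicit TIME_DELTA tokens. A minimal sketch of one such scheme; the bucket boundaries and token names are illustrative assumptions, not a fixed standard:

def time_delta_token(gap_days):
    """Map a raw gap between events (in days) to a coarse TIME_DELTA token (illustrative buckets)."""
    if gap_days < 1:
        return None                            # same-day events: no gap token
    elif gap_days < 7:
        return f"TIME_DELTA:{int(gap_days)}day"
    elif gap_days < 30:
        return "TIME_DELTA:1-4weeks"
    else:
        return "TIME_DELTA:30days"

def insert_time_deltas(events, timestamps):
    """Interleave TIME_DELTA tokens between events, given timestamps in days."""
    sequence = [events[0]]
    for prev_t, curr_t, event in zip(timestamps, timestamps[1:], events[1:]):
        gap_token = time_delta_token(curr_t - prev_t)
        if gap_token is not None:
            sequence.append(gap_token)
        sequence.append(event)
    return sequence

Interleaved with the clinical events, the resulting trajectory looks like this: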
# Trajectory with explicit time-gap tokens interleaved
trajectory_with_time = [
    ("DIAG:I21.0", 0),            # t = 0 (days since admission)
    ("PROC:PCI", 0),
    ("TIME_DELTA:1day", None),    # explicit time-gap token
    ("MED:aspirin", 1),
    ("TIME_DELTA:30days", None),
    ("VISIT:followup", 31),
]

Zero-Shot Clinical Prediction
The ETHOS Approach
ETHOS (Enhanced Transformer for Health Outcome Simulation) enables zero-shot prediction:
Pre-training: Self-supervised learning on patient trajectories
- Masked event modeling: Predict randomly masked events (a masking sketch follows this list)
- Learns general patterns in healthcare
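The masking helper used in the pre-training loop is not defined on this page. A minimal sketch, assuming event sequences are integer tensors and a reserved [MASK] token id:

import torch

def mask_events(sequence, mask_token_id, mask_prob=0.15):
    """Randomly hide events for masked-event pretraining (BERT-style)."""
    sequence = sequence.clone()
    labels = torch.full_like(sequence, -100)         # -100 positions are ignored by cross_entropy
    masked = torch.rand(sequence.shape, device=sequence.device) < mask_prob
    labels[masked] = sequence[masked]                # targets = original event IDs at masked positions
    sequence[masked] = mask_token_id                 # hide them from the model
    return sequence, labels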
import torch.nn.functional as F

def pretrain_patient_encoder(model, patient_sequences, vocab_size, mask_token_id):
    """
    Pre-train on unlabeled patient data (self-supervised masked event modeling).
    `model` must map event sequences to per-position logits over the event vocabulary.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    for sequence in patient_sequences:
        # Randomly mask ~15% of events (see mask_events above)
        masked_sequence, labels = mask_events(sequence, mask_token_id, mask_prob=0.15)
        # Predict the masked events
        predictions = model(masked_sequence)
        loss = F.cross_entropy(
            predictions.view(-1, vocab_size), labels.view(-1), ignore_index=-100
        )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Zero-shot transfer: Predict new outcomes without task-specific training
- No retraining needed for new prediction tasks
- Generalizes from learned patient representations
@torch.no_grad()
def zero_shot_predict(model, patient_sequence, outcome_query):
"""
Predict outcome without fine-tuning on that specific outcome
Args:
patient_sequence: Historical events
outcome_query: Outcome of interest (e.g., "ICU admission")
Returns:
probability: Predicted probability
"""
    # Encode patient trajectory and pool to a single vector
    patient_repr = model(patient_sequence).mean(dim=1)         # (batch, d_model)
    # Encode the outcome query (assumes the model exposes a text encoder)
    outcome_repr = model.encode_text(outcome_query)             # (batch, d_model)
    # Use cosine similarity between patient and outcome embeddings as a score
    probability = torch.cosine_similarity(patient_repr, outcome_repr, dim=-1)
    return probability

Attention for Clinical Interpretability
Key advantage: attention weights indicate which past events the model focused on when making a prediction.
Visualizing Clinical Reasoning
def visualize_patient_attention(model, patient_sequence, outcome):
"""
Show which past events the model focused on for prediction
"""
# Forward pass with attention weights
prediction, attention_weights = model(
patient_sequence,
return_attention=True
)
# attention_weights: (num_layers, num_heads, seq_len, seq_len)
# Average across layers and heads
avg_attention = attention_weights.mean(dim=(0, 1)) # (seq_len, seq_len)
# For each event, show which past events it attended to
for i, event in enumerate(patient_sequence):
top_attended = avg_attention[i].topk(5) # Top 5 attended events
        print(f"{event} attended to: {[patient_sequence[j] for j in top_attended.indices]}")

Clinical validation example:
Predicting: Risk of readmission within 30 days
Patient Event Sequence:
[Admission] [Chest Pain] [ECG Normal] [Discharge] [7 days gap] [Readmission] [SOB]
Attention Weights:
Admission: 0.05
Chest Pain: 0.35 ← High attention
ECG Normal: 0.20 ← Moderate attention
Discharge: 0.05
7 days gap: 0.10 ← Time gap matters
Readmission: 0.15
SOB: 0.10
Interpretation: Model recognizes pattern of recurring cardiac symptoms
(chest pain → normal ECG → short gap → return with SOB, i.e. shortness of breath)

Practical Implementation
Complete Outcome Prediction System
class ClinicalOutcomePredictor(nn.Module):
"""
End-to-end system for predicting clinical outcomes
"""
def __init__(self, num_outcomes=5):
super().__init__()
# Patient encoder
self.patient_encoder = PatientEventEncoder(
event_vocab_size=10000,
d_model=512,
num_heads=8,
num_layers=6
)
# Outcome prediction head
self.prediction_head = nn.Sequential(
nn.Linear(512, 256),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(256, num_outcomes)
)
def forward(self, patient_sequence):
# Encode patient trajectory
embeddings = self.patient_encoder(patient_sequence)
# Pool to single representation (use [CLS] token or mean)
patient_repr = embeddings.mean(dim=1) # (batch, d_model)
# Predict outcomes
logits = self.prediction_head(patient_repr) # (batch, num_outcomes)
return logits
# Training
model = ClinicalOutcomePredictor(num_outcomes=5)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = nn.BCEWithLogitsLoss() # Multi-label classification
for patient_sequences, outcomes in train_loader:
logits = model(patient_sequences)
loss = criterion(logits, outcomes)
optimizer.zero_grad()
loss.backward()
    optimizer.step()

Best Practices
✅ Do:
- Pre-train on large corpus of patient sequences (self-supervised)
- Use positional encoding that captures temporal information
- Implement attention visualization for interpretability
- Validate attention patterns with clinicians
- Fine-tune with lower learning rates for specific outcomes
- Handle variable-length sequences with padding masks (see the sketch after this list)
- Consider multi-modal fusion (EHR + text + imaging)
❌ Don’t:
- Train from scratch (pre-training is crucial)
- Ignore temporal information (time matters in healthcare)
- Deploy without interpretability (attention is essential)
- Forget data-leakage checks (events after the prediction time must not appear in the input)
- Use excessively long context windows (computation cost)
- Assume model generalizes across hospitals without validation
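A minimal sketch of two of the points above (padding masks and data-leakage checks), using the PatientEventEncoder defined earlier; the event IDs, sequence lengths, and timestamps are made-up examples:

import torch
from torch.nn.utils.rnn import pad_sequence

encoder = PatientEventEncoder(event_vocab_size=10000)

# Variable-length event sequences, padded with ID 0
sequences = [torch.randint(1, 500, (n,)) for n in (12, 87, 40)]
batch = pad_sequence(sequences, batch_first=True, padding_value=0)
padding_mask = batch == 0                           # True at padded positions

# Data-leakage check: only time-ordered events at or before the prediction time
event_times = torch.tensor([0.0, 1.5, 2.0, 6.0])    # hours since admission (illustrative)
prediction_time = 24.0
assert torch.all(event_times[:-1] <= event_times[1:]), "events must be time-ordered"
assert torch.all(event_times <= prediction_time), "no events after the prediction time"

embeddings = encoder(batch, mask=padding_mask)       # (batch, max_seq_len, d_model)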
Related Concepts
Attention Mechanism - Core concept
Multi-Head Attention - Parallel attention
Transformer Architecture - Original paper
GPT Architecture - Decoder-only alternative
EHR Data Structure - Data format
Healthcare Foundation Models - Pre-trained models (ETHOS, BEHRT)
Clinical Interpretability - Explainability requirements
Learning Resources
Key Papers
- BEHRT (Li et al., 2020): BERT for Electronic Health Records
- ETHOS (Renc et al., 2024): Zero-shot health trajectory prediction
- Med-BERT (Rasmy et al., 2021): Medical concept embedding
- Hi-BEHRT (Pang et al., 2021): Hierarchical transformers for EHR
Datasets
- MIMIC-III: 53,423 hospital admissions of adult ICU patients (2001-2012) with detailed event sequences
- MIMIC-IV v3.1: Over 65,000 ICU patients + over 200,000 ED patients (2008-2022) with enhanced data quality, ICD-9/10 codes, improved mortality data, and provider tracking
- eICU: Multi-center ICU database (200,000 admissions)
Code
- PyHealth: Python library for healthcare AI (includes EHR transformers)
- Hugging Face Medical: Pre-trained clinical transformers