Interpretability and Validation for Healthcare AI

Overview

Medical AI systems require interpretable, explainable predictions that clinicians can validate and trust. Unlike general AI applications, healthcare models face strict regulatory requirements, safety concerns, and the need for clinical acceptance. This module covers attention visualization, SHAP analysis, clinical validation protocols, and fairness audits.

Why Interpretability Matters in Healthcare

Critical Requirements

  1. Clinical Validation: Doctors need to understand why the model made a prediction
  2. Safety: Wrong predictions can harm or kill patients
  3. Trust: Clinicians will not use black-box models in high-stakes decisions
  4. Regulatory Compliance: FDA requires explainability for medical device approval
  5. Legal Liability: Healthcare providers need to justify AI-assisted decisions

Consequences of Opacity

  • Missed diagnoses: If clinicians don’t trust predictions, they ignore them
  • Automation bias: Over-reliance on opaque models leads to uncritical acceptance
  • Bias amplification: Hidden biases in training data perpetuate healthcare disparities
  • Regulatory rejection: Opaque models fail FDA/EMA approval processes

Key Principle: In healthcare, accuracy alone is insufficient. Models must provide interpretable reasoning that clinicians can validate.

Attention Visualization

Attention mechanisms naturally provide interpretability by showing which past events the model focuses on when making predictions.

Extracting Attention Weights

import torch
import matplotlib.pyplot as plt
import seaborn as sns

def visualize_attention(model, event_sequence, event_names, event_times):
    """
    Visualize which past events the model attends to

    Args:
        model: Transformer model with attention
        event_sequence: (seq_len,) patient trajectory
        event_names: List of event descriptions
        event_times: (seq_len,) timestamps

    Returns:
        Attention heatmap averaged over layers and heads, shape (seq_len, seq_len)
    """
    with torch.no_grad():
        # Forward pass with attention outputs
        output, attention_weights = model(
            event_sequence.unsqueeze(0),
            return_attention=True
        )
        # attention_weights: (n_layers, n_heads, seq_len, seq_len)

    # Average attention across layers and heads
    avg_attention = attention_weights.mean(dim=(0, 1))  # (seq_len, seq_len)

    # Plot heatmap
    plt.figure(figsize=(12, 10))
    sns.heatmap(
        avg_attention.numpy(),
        xticklabels=event_names,
        yticklabels=event_names,
        cmap='Blues',
        vmin=0,
        vmax=1
    )
    plt.xlabel('Event (Key)')
    plt.ylabel('Event (Query)')
    plt.title('Attention Weights: Which Events Matter for Prediction?')
    plt.tight_layout()
    plt.show()

    return avg_attention
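
A hypothetical usage sketch, assuming a trained model and an already-tokenized trajectory; get_patient_trajectory and the variable names are placeholders, not functions defined in this module:

# Illustrative only: replace with whatever the surrounding pipeline actually produces
trajectory, names, times = get_patient_trajectory(patient_id="...")  # hypothetical helper
attention_map = visualize_attention(model, trajectory, names, times)

# Attention received by each event, averaged over query positions,
# gives a rough per-event importance score
per_event_importance = attention_map.mean(dim=0)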

Interpreting Attention Patterns

Example: Mortality prediction for cardiac patient

High attention on:

  • Recent lab values: Troponin (indicates cardiac damage), lactate (tissue hypoxia)
  • Critical diagnoses: Acute MI (I21.0), cardiogenic shock
  • High-risk procedures: Mechanical ventilation, vasopressor administration

Lower attention on:

  • Routine vitals: Stable blood pressure readings
  • Older events: Outpatient visits from months ago
  • Common diagnoses: Hypertension (I10), which most patients have

This pattern suggests the model correctly identifies acute severity markers rather than chronic background conditions.

Head-Specific Attention Analysis

Different attention heads may specialize in different patterns:

from scipy.stats import entropy

def analyze_attention_heads(model, event_sequence, event_names):
    """
    Analyze what different attention heads focus on
    """
    with torch.no_grad():
        _, attention_weights = model(event_sequence, return_attention=True)
        # attention_weights: (n_layers, n_heads, seq_len, seq_len)

    # Analyze last-layer attention heads
    last_layer_attention = attention_weights[-1]  # (n_heads, seq_len, seq_len)

    for head_idx in range(last_layer_attention.shape[0]):
        head_attention = last_layer_attention[head_idx]

        # Find which event each query position attends to most
        max_attention_idx = head_attention.argmax(dim=-1)

        print(f"Head {head_idx}:")
        print(f"  Primary focus: {event_names[max_attention_idx[0].item()]}")
        print(f"  Attention entropy: {entropy(head_attention[0].numpy()):.3f}")

Example findings:

  • Head 0: Focuses on diagnosis codes (disease-specific attention)
  • Head 3: Focuses on temporal proximity (recent events)
  • Head 7: Focuses on lab values (quantitative measurements)

SHAP: SHapley Additive exPlanations

SHAP values quantify each feature’s contribution to the prediction using game-theoretic Shapley values.
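
Concretely, for a model f over a feature set F, the Shapley value of feature i is its marginal contribution averaged over all subsets of the remaining features (the standard definition underlying SHAP; the notation here is generic rather than taken from this module's code):

\phi_i = \sum_{S \subseteq F \setminus \{i\}} \frac{|S|! \, (|F| - |S| - 1)!}{|F|!} \left[ f_{S \cup \{i\}}(x) - f_S(x) \right]

where f_S(x) is the model's output when only the features in S are known. The attributions satisfy f(x) = \mathbb{E}[f(X)] + \sum_i \phi_i, the additive property that force plots visualize.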

Implementation

import shap
import torch

def compute_shap_values(model, patient_trajectory, background_data):
    """
    Compute SHAP values for a patient trajectory

    Args:
        model: Trained healthcare AI model
        patient_trajectory: (seq_len, features) patient data
        background_data: (n_samples, seq_len, features) reference dataset

    Returns:
        SHAP values for each event, and the explainer (needed for plotting)
    """
    # Create SHAP explainer
    explainer = shap.DeepExplainer(model, background_data)

    # Compute SHAP values
    shap_values = explainer.shap_values(patient_trajectory)

    return shap_values, explainer


def visualize_shap(shap_values, explainer, event_names):
    """
    Visualize SHAP values as a force plot and a summary bar plot
    """
    shap.force_plot(
        base_value=explainer.expected_value,
        shap_values=shap_values,
        feature_names=event_names,
        matplotlib=True
    )

    # Bar plot of top contributing events
    shap.summary_plot(shap_values, feature_names=event_names, plot_type="bar")
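
A hypothetical end-to-end call, assuming tensors shaped as documented above; train_trajectories, test_trajectories, model, and event_names are placeholders rather than objects defined in this module:

# Placeholders: a background sample from training data and one patient's trajectory
background = train_trajectories[:100]   # (100, seq_len, features), hypothetical
trajectory = test_trajectories[0:1]     # (1, seq_len, features), hypothetical

shap_values, explainer = compute_shap_values(model, trajectory, background)
visualize_shap(shap_values, explainer, event_names)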

Interpreting SHAP Values

Example: 30-day readmission prediction

  • Positive SHAP (increase risk):

    • E11.65 (diabetes with hyperglycemia): +0.12
    • I50.9 (heart failure, unspecified): +0.08
    • High creatinine lab value: +0.05
  • Negative SHAP (decrease risk):

    • Completed cardiac rehab program: -0.07
    • Outpatient follow-up scheduled: -0.04
    • Stable vital signs at discharge: -0.03

SHAP values provide quantitative attribution of risk to specific events, enabling clinicians to validate model reasoning.

Clinical Validation Protocol

A rigorous multi-step validation process for medical AI systems.

1. Retrospective Evaluation

Test on historical data with known outcomes:

def retrospective_evaluation(model, test_data):
    """
    Evaluate model on historical data
    """
    results = {
        'accuracy': compute_accuracy(model, test_data),
        'auroc': compute_auroc(model, test_data),
        'auprc': compute_auprc(model, test_data),
        'calibration': compute_calibration_curve(model, test_data),
        'sensitivity': compute_sensitivity(model, test_data),
        'specificity': compute_specificity(model, test_data)
    }
    return results

2. Prospective Study

Test on new patients as they arrive (before outcomes known):

def prospective_validation(model, new_patients):
    """
    Prospective evaluation on new patients

    1. Model makes predictions on admission
    2. Predictions recorded but not shown to clinicians
    3. Actual outcomes recorded independently
    4. Compare predictions to actual outcomes
    """
    predictions = []
    actual_outcomes = []

    for patient in new_patients:
        # Model predicts at admission
        pred = model.predict(patient.admission_data)
        predictions.append(pred)

        # Wait for actual outcome
        actual_outcome = patient.wait_for_outcome()
        actual_outcomes.append(actual_outcome)

    # Evaluate prospective performance
    auroc = compute_auroc(predictions, actual_outcomes)
    return auroc

3. Clinician Review

Expert clinicians evaluate predictions for clinical plausibility:

import numpy as np

def clinician_review_study(model, test_cases, clinicians):
    """
    Have clinicians review model predictions

    Returns:
        Agreement rate, clinician confidence, perceived utility
    """
    results = []

    for case in test_cases:
        # Model prediction with explanations
        model_pred, attention_viz, shap_values = model.predict_with_explanation(case)

        # Clinician review
        for clinician in clinicians:
            review = clinician.evaluate(
                patient_data=case,
                model_prediction=model_pred,
                explanation=attention_viz,
                shap_values=shap_values
            )

            results.append({
                'case_id': case.id,
                'clinician_id': clinician.id,
                'agrees_with_model': review.agreement,
                'confidence': review.confidence,
                'would_use_in_practice': review.utility
            })

    # Analyze agreement
    agreement_rate = np.mean([r['agrees_with_model'] for r in results])
    return agreement_rate, results

4. Failure Analysis

Study incorrect predictions to identify systematic errors:

def failure_analysis(model, test_data):
    """
    Analyze model failures to find patterns
    """
    predictions = model.predict(test_data)

    # Categorize errors (binary outcome)
    false_positives = test_data[(predictions == 1) & (test_data.labels == 0)]
    false_negatives = test_data[(predictions == 0) & (test_data.labels == 1)]

    # Analyze error patterns
    print(f"False Positives: {len(false_positives)}")
    print(f"  Common diagnoses: {false_positives.diagnoses.value_counts()[:5]}")
    print(f"  Average age: {false_positives.age.mean():.1f}")

    print(f"False Negatives: {len(false_negatives)}")
    print(f"  Common diagnoses: {false_negatives.diagnoses.value_counts()[:5]}")
    print(f"  Average severity: {false_negatives.severity_score.mean():.2f}")

    return false_positives, false_negatives

5. Fairness Audits

Check for demographic bias:

def fairness_audit(model, test_data, protected_attrs=['race', 'sex', 'age_group']):
    """
    Audit model for fairness across demographic groups
    """
    fairness_metrics = {}

    for attr in protected_attrs:
        groups = test_data[attr].unique()

        # Compute metrics per group
        group_metrics = {}
        for group in groups:
            group_data = test_data[test_data[attr] == group]

            group_metrics[group] = {
                'auroc': compute_auroc(model, group_data),
                'accuracy': compute_accuracy(model, group_data),
                'fpr': compute_fpr(model, group_data),
                'fnr': compute_fnr(model, group_data),
                'calibration_slope': compute_calibration(model, group_data)
            }

        fairness_metrics[attr] = group_metrics

    return fairness_metrics
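
A small follow-up helper, not part of the original audit above, that could summarize the largest per-group AUROC gap from the dictionary fairness_audit returns; the 0.05 threshold is an arbitrary illustrative choice:

def summarize_auroc_gaps(fairness_metrics, max_gap=0.05):
    """Print the largest AUROC gap within each protected attribute."""
    for attr, group_metrics in fairness_metrics.items():
        aurocs = {group: m['auroc'] for group, m in group_metrics.items()}
        best = max(aurocs, key=aurocs.get)
        worst = min(aurocs, key=aurocs.get)
        gap = aurocs[best] - aurocs[worst]
        flag = "REVIEW" if gap > max_gap else "ok"
        print(f"{attr}: AUROC gap = {gap:.3f} ({worst} vs. {best}) [{flag}]")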

Comprehensive Validation Pipeline

Complete validation for healthcare AI:

def validate_clinical_model(model, test_data, clinicians, new_patients):
    """
    Comprehensive validation for healthcare AI
    """
    results = {}

    # 1. Performance metrics
    results['performance'] = {
        'accuracy': compute_accuracy(model, test_data),
        'auroc': compute_auroc(model, test_data),
        'auprc': compute_auprc(model, test_data),
        'calibration': compute_calibration(model, test_data),
        'sensitivity': compute_sensitivity(model, test_data),
        'specificity': compute_specificity(model, test_data)
    }

    # 2. Interpretability analysis
    results['interpretability'] = {
        'attention_maps': extract_attention_maps(model, test_data),
        'shap_values': compute_shap_values(model, test_data),
        'feature_importance': compute_feature_importance(model, test_data)
    }

    # 3. Fairness metrics
    results['fairness'] = fairness_audit(
        model, test_data,
        protected_attrs=['race', 'sex', 'age_group', 'insurance_type']
    )

    # 4. Clinical validation
    results['clinical_validation'] = {
        'clinician_agreement': clinician_review_study(model, test_data, clinicians),
        'failure_patterns': failure_analysis(model, test_data),
        'prospective_performance': prospective_validation(model, new_patients)
    }

    # 5. Regulatory compliance
    results['regulatory'] = {
        'meets_fda_criteria': check_fda_compliance(results),
        'documentation': generate_regulatory_documentation(model, results)
    }

    return results

Fairness in Healthcare AI

Medical AI must be fair across demographics to avoid perpetuating healthcare disparities.

Fairness Metrics

  1. Demographic Parity: Similar positive prediction rates across groups

    • P(\hat{Y}=1 | A=a) = P(\hat{Y}=1 | A=b) for all groups a, b
  2. Equalized Odds: Similar TPR and FPR across groups

    • P(\hat{Y}=1 | Y=1, A=a) = P(\hat{Y}=1 | Y=1, A=b) (Equal TPR)
    • P(\hat{Y}=1 | Y=0, A=a) = P(\hat{Y}=1 | Y=0, A=b) (Equal FPR)
  3. Calibration: Predictions equally reliable across groups

    • P(Y=1 | \hat{Y}=p, A=a) = p for all groups a and probabilities p
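
As a minimal illustration of these criteria (not code from this module), the sketch below computes group-wise positive rate, TPR, FPR, and a simple calibration summary from arrays of labels, binary predictions, and predicted probabilities; all names are hypothetical:

import numpy as np

def group_fairness_report(y_true, y_pred, y_prob, group):
    """Per-group positive rate, TPR, FPR, and predicted vs. observed risk."""
    y_true, y_pred, y_prob, group = map(np.asarray, (y_true, y_pred, y_prob, group))
    report = {}
    for g in np.unique(group):
        m = group == g
        pos = y_true[m] == 1
        neg = y_true[m] == 0
        report[g] = {
            'positive_rate': y_pred[m].mean(),                            # demographic parity
            'tpr': y_pred[m][pos].mean() if pos.any() else float('nan'),  # equalized odds (TPR)
            'fpr': y_pred[m][neg].mean() if neg.any() else float('nan'),  # equalized odds (FPR)
            'mean_predicted': y_prob[m].mean(),                           # calibration-in-the-large:
            'observed_rate': y_true[m].mean(),                            # these two should match
        }
    return report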

Common Sources of Bias

  • Historical bias: Training data reflects historical healthcare disparities
  • Representation bias: Underrepresentation of minority groups in training data
  • Measurement bias: Diagnostic criteria may vary by demographic group
  • Label bias: Ground truth labels may be influenced by provider biases

For EmergAI Thesis

Interpretability requirements for the multimodal EmergAI model:

1. Attention Visualization

  • EHR attention: Which past events (diagnoses, procedures, labs) matter?
  • Text attention: Which symptom phrases from Symptoms.se drive predictions?
  • Sketch attention: Which body regions in 3D sketches are most important?
  • Cross-modal attention: How do modalities interact?

2. SHAP Analysis

  • Modality-level SHAP: Quantify contribution of EHR vs. text vs. sketch
  • Event-level SHAP: Within EHR, which specific events contribute most?
  • Token-level SHAP: Within text, which symptom descriptions matter?
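
One possible way to obtain the modality-level attributions above is to aggregate per-feature SHAP values by modality. The sketch below assumes per-feature SHAP values and a feature-to-modality mapping are already available; none of these names come from the EmergAI codebase:

import numpy as np

def aggregate_shap_by_modality(shap_values, feature_modalities):
    """
    Sum absolute per-feature SHAP values into one score per modality.

    shap_values: (n_features,) SHAP values for one prediction (assumed shape)
    feature_modalities: modality label per feature, e.g. 'ehr', 'text', 'sketch' (assumed)
    """
    magnitudes = np.abs(np.asarray(shap_values))
    totals = {}
    for value, modality in zip(magnitudes, feature_modalities):
        totals[modality] = totals.get(modality, 0.0) + float(value)
    grand_total = sum(totals.values()) or 1.0
    return {m: v / grand_total for m, v in totals.items()}  # fraction of total attribution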

3. Failure Analysis

  • Error categories: Where does multimodal model fail vs. ETHOS baseline?
  • Modality-specific failures: Do errors correlate with missing modalities?
  • Demographic patterns: Are errors evenly distributed across demographics?

4. Fairness Audit

  • Demographics: Age, sex, triage level, insurance status
  • Metrics: AUROC, calibration, FPR, FNR per group
  • Disparities: Identify and document any performance gaps

Related Concepts

  • Attention Mechanism: The foundation for attention-based interpretability
  • Multi-Head Attention: Different heads may specialize in different patterns
  • Healthcare Foundation Models: Models that require interpretability

Applications

  • Clinical decision support: Explain predictions to doctors for validation
  • Regulatory approval: Demonstrate safety and interpretability to FDA/EMA
  • Fairness auditing: Identify and mitigate demographic biases
  • Model debugging: Find systematic errors and failure modes
  • Trust building: Enable clinician adoption through transparency

Key Takeaways

  1. Healthcare AI must be interpretable—accuracy alone is insufficient for clinical adoption
  2. Attention visualization shows which events the model focuses on
  3. SHAP values quantify feature contributions using game theory
  4. Clinical validation requires retrospective, prospective, and expert review
  5. Fairness audits ensure equitable performance across demographics
  6. Regulatory compliance (FDA/EMA) mandates explainability and safety documentation