
Patient Readmission: GNN on Electronic Health Record Graphs

Hospital readmissions cost Medicare $26B annually. CMS penalizes hospitals with high readmission rates. Here is how to build a GNN that predicts 30-day readmission by modeling the full clinical graph.


TL;DR

  1. Patient readmission is a node classification problem on a clinical graph. Patients are connected to diagnoses, procedures, medications, and providers, forming a network with rich predictive signal.
  2. HeteroConv on the EHR graph captures comorbidity patterns, care pathway quality, and provider-level risk signals that patient-level features miss.
  3. On RelBench benchmarks, GNNs achieve 75.83 AUROC vs 62.44 for flat-table LightGBM. Clinical network context provides 13+ points of lift.
  4. HIPAA compliance requires on-premise or HIPAA-cloud deployment. The entire graph pipeline must run within the compliance perimeter.
  5. KumoRFM predicts readmission with one PQL query (76.71 AUROC zero-shot), running within your HIPAA environment and handling temporal clinical data automatically.

The business problem

Hospital readmissions within 30 days cost Medicare $26 billion annually. Since 2012, CMS's Hospital Readmissions Reduction Program has penalized hospitals with excess readmission rates, reducing Medicare payments by up to 3%. Beyond the financial penalties, readmissions signal care quality gaps: inadequate discharge planning, missed follow-ups, or unaddressed comorbidities.

Traditional readmission models (LACE, HOSPITAL) use a handful of patient-level features: length of stay, acuity, comorbidities, and ED visits. They achieve modest discrimination (C-statistic ~0.60-0.65) because they miss the clinical context: which providers were involved, what care pathway was followed, and how similar patients with similar care patterns fared.
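For context, the LACE heuristic is simple enough to sketch in a few lines. The point values below follow the commonly published LACE table (length of stay, acuity, Charlson comorbidity index, ED visits), so treat them as illustrative rather than authoritative:

```python
def lace_score(los_days, acute_admission, charlson_index, ed_visits_6mo):
    """Sketch of the LACE readmission heuristic (0-19 scale)."""
    # L: length of stay, up to 7 points
    if los_days < 1:
        l = 0
    elif los_days <= 3:
        l = int(los_days)
    elif los_days <= 6:
        l = 4
    elif los_days <= 13:
        l = 5
    else:
        l = 7
    # A: acuity -- 3 points for an acute (ED) admission
    a = 3 if acute_admission else 0
    # C: Charlson comorbidity index, capped at 5 points
    c = charlson_index if charlson_index < 4 else 5
    # E: ED visits in the prior 6 months, capped at 4 points
    e = min(ed_visits_6mo, 4)
    return l + a + c + e
```

Every input is a patient-level scalar; nothing about the care pathway, providers, or similar patients enters the score, which is why its discrimination tops out in the low 0.60s.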

Why flat ML fails

  • No care pathway signal: Two patients with the same diagnosis can have very different readmission risk based on their care pathway: procedures, medication regimen, and discharge disposition. Flat models compress this into aggregate features.
  • No provider context: Provider-level quality varies. A patient discharged by a provider with high readmission rates carries risk that individual patient features do not capture.
  • Comorbidity interactions: The interaction between diabetes and heart failure is different from diabetes and COPD. The diagnosis graph captures these interactions naturally through shared patient nodes.
  • Medication interactions: Polypharmacy risk depends on the specific combination of medications, not just the count. The medication subgraph captures known and novel interaction patterns.

The relational schema

schema.txt
Node types:
  Patient    (id, age, sex, insurance_type)
  Encounter  (id, type, los, admit_date, discharge_date)
  Diagnosis  (id, icd10, category, chronic_flag)
  Procedure  (id, cpt_code, category, cost)
  Medication (id, drug_class, route, frequency)
  Provider   (id, specialty, years_exp, panel_size)

Edge types:
  Patient   --[had_encounter]-->   Encounter
  Encounter --[has_diagnosis]-->   Diagnosis
  Encounter --[has_procedure]-->   Procedure
  Encounter --[prescribed]-->      Medication
  Encounter --[treated_by]-->      Provider

Six node types capture the full clinical context. The graph encodes care pathways: which diagnoses led to which procedures and medications, under which provider.

PyG architecture: HeteroConv for EHR data

readmission_model.py
import torch
import torch.nn.functional as F
from torch_geometric.nn import HeteroConv, SAGEConv, Linear

class ReadmissionGNN(torch.nn.Module):
    def __init__(self, hidden_dim=64):
        super().__init__()
        self.patient_lin = Linear(-1, hidden_dim)
        self.encounter_lin = Linear(-1, hidden_dim)
        self.diagnosis_lin = Linear(-1, hidden_dim)
        self.procedure_lin = Linear(-1, hidden_dim)
        self.medication_lin = Linear(-1, hidden_dim)
        self.provider_lin = Linear(-1, hidden_dim)

        # Hop 1: clinical events flow into their encounter. The rev_*
        # edge types (added with T.ToUndirected()) are needed so that
        # messages flow toward the patient nodes we classify.
        self.conv1 = HeteroConv({
            ('diagnosis', 'rev_has_diagnosis', 'encounter'): SAGEConv(
                hidden_dim, hidden_dim),
            ('procedure', 'rev_has_procedure', 'encounter'): SAGEConv(
                hidden_dim, hidden_dim),
            ('medication', 'rev_prescribed', 'encounter'): SAGEConv(
                hidden_dim, hidden_dim),
            ('provider', 'rev_treated_by', 'encounter'): SAGEConv(
                hidden_dim, hidden_dim),
        }, aggr='sum')

        # Hop 2: encounters flow into the patient.
        self.conv2 = HeteroConv({
            ('encounter', 'rev_had_encounter', 'patient'): SAGEConv(
                hidden_dim, hidden_dim),
        }, aggr='sum')

        self.classifier = Linear(hidden_dim, 1)

    def forward(self, x_dict, edge_index_dict):
        for key, lin in [('patient', self.patient_lin),
                         ('encounter', self.encounter_lin),
                         ('diagnosis', self.diagnosis_lin),
                         ('procedure', self.procedure_lin),
                         ('medication', self.medication_lin),
                         ('provider', self.provider_lin)]:
            x_dict[key] = lin(x_dict[key])

        # HeteroConv returns only destination node types, so merge the
        # updates into x_dict rather than replacing it outright.
        x_dict.update({k: F.relu(v) for k, v in
                       self.conv1(x_dict, edge_index_dict).items()})
        x_dict.update(self.conv2(x_dict, edge_index_dict))

        return torch.sigmoid(
            self.classifier(x_dict['patient']).squeeze(-1))

HeteroConv over six node types. Two message-passing hops let each patient aggregate diagnosis, procedure, medication, and provider signals through their encounters.

Expected performance

  • LACE score (heuristic): ~60 AUROC (C-statistic ≈ 0.60)
  • LightGBM (flat-table): 62.44 AUROC
  • GNN (HeteroConv on EHR): 75.83 AUROC
  • KumoRFM (zero-shot): 76.71 AUROC

Or use KumoRFM in one line

KumoRFM PQL
PREDICT is_readmitted_30d FOR patient
USING patient, encounter, diagnosis, procedure, medication

One PQL query. KumoRFM constructs the clinical graph from EHR tables, handles temporal encounters, and predicts 30-day readmission risk.

Frequently asked questions

Why use GNNs for readmission prediction instead of logistic regression?

Logistic regression uses patient-level features (age, diagnosis, LOS). GNNs model the clinical network: which providers treated the patient, what procedures were ordered, which medications were prescribed, and how similar patients fared. A patient discharged by a provider with high readmission rates carries hidden risk that individual features miss.

What EHR data goes into the readmission graph?

Patients, encounters, diagnoses, procedures, medications, and providers. The graph connects patients to their clinical events and providers, creating a rich network that captures care pathways, comorbidity patterns, and provider-level quality signals.

How do you handle HIPAA compliance with GNN models?

Keep all data on-premise or in HIPAA-compliant cloud environments. De-identify patient nodes (use IDs, not names). The graph structure itself is considered PHI, so the entire pipeline must run within the HIPAA perimeter. Use federated learning if multi-institution training is needed.
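ID-level de-identification can be as simple as a keyed hash (a sketch only: HIPAA Safe Harbor and Expert Determination cover far more than identifiers, and the key itself must stay inside the compliance perimeter, e.g. in a KMS):

```python
import hashlib
import hmac

def pseudonymize(patient_id: str, secret_key: bytes) -> str:
    """Replace a patient identifier with a stable keyed hash, so the
    same patient always maps to the same node without exposing the
    original MRN."""
    digest = hmac.new(secret_key, patient_id.encode(), hashlib.sha256)
    return digest.hexdigest()[:16]
```

Because the mapping is deterministic per key, graph edges still join correctly across tables after pseudonymization.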

What is the standard readmission prediction window?

30-day all-cause readmission is the CMS standard for hospital penalties. Some models also predict 7-day readmission (urgent, high-acuity) and 90-day readmission (chronic disease management). The prediction is made at the time of discharge.
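The 30-day label itself can be derived from the encounters table alone: for each discharge, check whether the same patient is admitted again within the window. A sketch in plain Python (the patient_id, admit_date, and discharge_date fields are assumed column names):

```python
from datetime import date, timedelta

def label_readmissions(encounters, window_days=30):
    """encounters: list of dicts with patient_id, admit_date, and
    discharge_date. Returns {encounter index: is_readmitted_30d},
    labeling each discharge by whether the same patient is admitted
    again within `window_days` of discharge."""
    by_patient = {}
    for i, e in enumerate(encounters):
        by_patient.setdefault(e['patient_id'], []).append((i, e))
    labels = {}
    for visits in by_patient.values():
        visits.sort(key=lambda t: t[1]['admit_date'])
        for k, (i, e) in enumerate(visits):
            cutoff = e['discharge_date'] + timedelta(days=window_days)
            labels[i] = any(
                e['discharge_date'] < nxt['admit_date'] <= cutoff
                for _, nxt in visits[k + 1:])
    return labels
```

Anchoring the window at the discharge date (not the admit date) matches how the prediction is made at discharge time, and prevents label leakage from the index stay itself.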

Can KumoRFM predict patient readmission?

Yes. KumoRFM takes your EHR tables (patients, encounters, diagnoses, procedures, medications) and predicts readmission with one PQL query, running entirely within your HIPAA-compliant environment. It handles temporal clinical data and comorbidity patterns automatically.

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.