The business problem
Hospital readmissions within 30 days cost Medicare $26 billion annually. Since 2012, CMS's Hospital Readmissions Reduction Program penalizes hospitals with excess readmission rates, reducing Medicare payments by up to 3%. Beyond financial penalties, readmissions indicate care quality gaps: inadequate discharge planning, missed follow-ups, or unaddressed comorbidities.
Traditional readmission models (LACE, HOSPITAL) use a handful of patient-level features. LACE, for example, scores length of stay, acuity of the admission, comorbidities (Charlson index), and ED visits in the prior six months. These models achieve modest discrimination (C-statistic ~0.60-0.65) because they miss the clinical context: which providers were involved, what care pathway was followed, and how similar patients with similar care patterns fared.
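To make "a handful of patient-level features" concrete, here is a minimal sketch of a LACE score. It assumes the commonly published point table (length of stay 0 to 7 points, 3 points for an emergent admission, Charlson index with 4+ mapped to 5 points, ED visits in the prior six months capped at 4). The function name and signature are illustrative, not from any library.

```python
def lace_score(los_days: int, emergent: bool, charlson: int, ed_visits_6mo: int) -> int:
    """Hypothetical helper: LACE index in the range 0-19 (higher = higher risk)."""
    # L: length of stay in days
    if los_days < 1:
        l_pts = 0
    elif los_days <= 3:
        l_pts = los_days          # 1, 2, or 3 points
    elif los_days <= 6:
        l_pts = 4
    elif los_days <= 13:
        l_pts = 5
    else:
        l_pts = 7
    # A: acuity, 3 points for an emergent (unplanned) admission
    a_pts = 3 if emergent else 0
    # C: Charlson comorbidity index; a score of 4 or more maps to 5 points
    c_pts = charlson if charlson <= 3 else 5
    # E: ED visits in the previous 6 months, capped at 4 points
    e_pts = min(ed_visits_6mo, 4)
    return l_pts + a_pts + c_pts + e_pts


# A 5-day emergent stay, Charlson 2, one prior ED visit: 4 + 3 + 2 + 1
print(lace_score(5, True, 2, 1))  # 10
```

Every input is a single number per patient, which is exactly the limitation the next section describes: nothing about who treated the patient or which drugs were combined enters the score.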
Why flat ML fails
- No care pathway signal: Two patients with the same diagnosis can have very different readmission risk based on their care pathway: procedures, medication regimen, and discharge disposition. Flat models compress this into aggregate features.
- No provider context: Provider-level quality varies. A patient discharged by a provider with high readmission rates carries risk that individual patient features do not capture.
- Comorbidity interactions: The interaction between diabetes and heart failure is different from that between diabetes and COPD. The graph captures these interactions naturally because co-occurring diagnoses are linked through the encounter and patient nodes they share.
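- Medication interactions: Polypharmacy risk depends on the specific combination of medications, not just the count. The medication subgraph exposes which drug classes are prescribed together, so a model can learn from both documented interaction patterns and combinations that curated lists miss.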
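The compression problem in the list above can be shown with two made-up patients. Both have the same diagnosis, length of stay, and medication count, so a flat feature row cannot tell them apart, while the set of co-prescribed drug-class pairs (what a medication subgraph can expose) differs. All data and helper names here are hypothetical.

```python
# Two hypothetical heart-failure patients (invented data for illustration).
patients = {
    "A": {"dx": ["I50.9"], "los": 4,
          "meds": ["loop_diuretic", "ace_inhibitor", "beta_blocker", "statin"]},
    "B": {"dx": ["I50.9"], "los": 4,
          "meds": ["loop_diuretic", "nsaid", "insulin", "anticoagulant"]},
}


def flat_features(p):
    # What a typical flat table keeps: counts and aggregates.
    return (len(p["dx"]), p["los"], len(p["meds"]))


def med_pairs(p):
    # What a graph view can expose: which drug classes co-occur.
    meds = sorted(p["meds"])
    return {(a, b) for i, a in enumerate(meds) for b in meds[i + 1:]}


print(flat_features(patients["A"]) == flat_features(patients["B"]))  # True
print(med_pairs(patients["A"]) == med_pairs(patients["B"]))          # False
```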
The relational schema
Node types:
Patient (id, age, sex, insurance_type)
Encounter (id, type, los, admit_date, discharge_date)
Diagnosis (id, icd10, category, chronic_flag)
Procedure (id, cpt_code, category, cost)
Medication (id, drug_class, route, frequency)
Provider (id, specialty, years_exp, panel_size)
Edge types:
Patient --[had_encounter]--> Encounter
Encounter --[has_diagnosis]--> Diagnosis
Encounter --[has_procedure]--> Procedure
Encounter --[prescribed]--> Medication
Encounter --[treated_by]--> Provider
Six node types capture the full clinical context. The graph encodes care pathways: which diagnoses, procedures, and medications occurred together in the same encounter, and under which provider.
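A sketch of how relational EHR rows could become the edge lists of this schema, assuming a hypothetical layout: an encounter table carrying patient and provider foreign keys, plus a link table from encounters to ICD-10 codes. Each foreign-key link becomes one edge, and raw IDs are mapped to contiguous integer indices. Procedure and medication link tables would follow the same pattern.

```python
# Hypothetical EHR rows: (encounter_id, patient_id, provider_id)
encounters = [
    ("e1", "p1", "dr1"),
    ("e2", "p1", "dr2"),
    ("e3", "p2", "dr1"),
]
# Hypothetical link table: (encounter_id, icd10_code)
encounter_diagnosis = [("e1", "I50.9"), ("e2", "E11.9"), ("e3", "I50.9")]


def index_of(ids):
    # Map raw IDs to contiguous node indices in first-seen order.
    mapping = {}
    for i in ids:
        mapping.setdefault(i, len(mapping))
    return mapping


def edges(pairs, src_map, dst_map):
    # Build a 2 x num_edges list [sources, targets] from (src_id, dst_id) pairs.
    pairs = list(pairs)
    return [[src_map[s] for s, _ in pairs], [dst_map[d] for _, d in pairs]]


enc_idx = index_of(e for e, _, _ in encounters)
pat_idx = index_of(p for _, p, _ in encounters)
prov_idx = index_of(d for _, _, d in encounters)
dx_idx = index_of(dx for _, dx in encounter_diagnosis)

edge_index = {
    ("patient", "had_encounter", "encounter"):
        edges(((p, e) for e, p, _ in encounters), pat_idx, enc_idx),
    ("encounter", "has_diagnosis", "diagnosis"):
        edges(encounter_diagnosis, enc_idx, dx_idx),
    ("encounter", "treated_by", "provider"):
        edges(((e, d) for e, _, d in encounters), enc_idx, prov_idx),
}
print(edge_index[("encounter", "has_diagnosis", "diagnosis")])  # [[0, 1, 2], [0, 1, 0]]
```

Patients p1 and p2 end up sharing diagnosis node 0 (I50.9) and provider node 0 (dr1) through different encounters, which is how the graph links patients with similar care. In PyG, each list would typically be wrapped as a `torch.long` tensor and assigned to the matching edge type of a `HeteroData` object.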
PyG architecture: HeteroConv for EHR data
```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import HeteroConv, SAGEConv, Linear

NODE_TYPES = ['patient', 'encounter', 'diagnosis',
              'procedure', 'medication', 'provider']

# Schema edges plus their reverses (the 'rev_' naming matches what
# torch_geometric.transforms.ToUndirected() adds to HeteroData).
# HeteroConv only returns embeddings for destination node types, so
# without reverse edges 'patient' would never receive messages.
EDGE_TYPES = [
    ('patient', 'had_encounter', 'encounter'),
    ('encounter', 'has_diagnosis', 'diagnosis'),
    ('encounter', 'has_procedure', 'procedure'),
    ('encounter', 'prescribed', 'medication'),
    ('encounter', 'treated_by', 'provider'),
]
EDGE_TYPES += [(dst, f'rev_{rel}', src) for src, rel, dst in EDGE_TYPES]


class ReadmissionGNN(torch.nn.Module):
    def __init__(self, hidden_dim=64):
        super().__init__()
        # Per-type input projections; -1 lets PyG infer input sizes lazily.
        self.lins = torch.nn.ModuleDict(
            {t: Linear(-1, hidden_dim) for t in NODE_TYPES})

        def make_conv():
            # One SAGEConv per edge type; messages into the same
            # destination type are summed.
            return HeteroConv(
                {et: SAGEConv(hidden_dim, hidden_dim) for et in EDGE_TYPES},
                aggr='sum')

        self.conv1 = make_conv()
        self.conv2 = make_conv()
        self.classifier = Linear(hidden_dim, 1)

    def forward(self, x_dict, edge_index_dict):
        # Project every node type into the shared hidden space
        # (builds a new dict instead of mutating the caller's input).
        x_dict = {t: self.lins[t](x) for t, x in x_dict.items()}
        # Hop 1: encounters read diagnoses, procedures, medications,
        # and providers; patients read their encounters.
        x_dict = {t: F.relu(x) for t, x in
                  self.conv1(x_dict, edge_index_dict).items()}
        # Hop 2: patients read encounter embeddings that already carry
        # the clinical context gathered in hop 1.
        x_dict = self.conv2(x_dict, edge_index_dict)
        # 30-day readmission probability per patient.
        return torch.sigmoid(
            self.classifier(x_dict['patient']).squeeze(-1))
```

HeteroConv on 6 node types. With reverse edges in place, two hops let each patient aggregate its encounters and, through them, diagnosis, procedure, medication, and provider signals.
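Why the reverse edges matter can be checked without PyG. The sketch below uses a hypothetical helper (not a PyG API) that walks edge types backwards from `patient` to list which node types can influence a patient embedding after k message-passing layers. With forward edges only, nothing flows into `patient` at all.

```python
FORWARD = [
    ("patient", "had_encounter", "encounter"),
    ("encounter", "has_diagnosis", "diagnosis"),
    ("encounter", "has_procedure", "procedure"),
    ("encounter", "prescribed", "medication"),
    ("encounter", "treated_by", "provider"),
]
REVERSE = [(dst, f"rev_{rel}", src) for src, rel, dst in FORWARD]


def reachable_types(edge_types, target, num_layers):
    """Hypothetical helper: node types whose input features can reach
    `target` within `num_layers` layers of heterogeneous message passing."""
    seen = {target}
    frontier = {target}
    for _ in range(num_layers):
        # One layer: a node type receives from the source type of every
        # edge type that points at it.
        frontier = {src for src, _, dst in edge_types if dst in frontier}
        seen |= frontier
    return seen


print(sorted(reachable_types(FORWARD, "patient", 2)))            # ['patient']
print(sorted(reachable_types(FORWARD + REVERSE, "patient", 1)))  # ['encounter', 'patient']
print(len(reachable_types(FORWARD + REVERSE, "patient", 2)))     # 6
```

Two layers are enough for every node type to reach a patient, but other patients who share a diagnosis or provider sit four hops away (patient, encounter, diagnosis, encounter, patient), so a deeper model would be needed for patient-to-patient signal.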
Expected performance
- LACE score (heuristic): ~60 AUROC (C-statistic ~0.60)
- LightGBM (flat-table): 62.44 AUROC
- GNN (HeteroConv on EHR): 75.83 AUROC
- KumoRFM (zero-shot): 76.71 AUROC
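The list mixes the names C-statistic and AUROC. For a binary outcome such as 30-day readmission they are the same quantity: the probability that a randomly chosen readmitted patient is scored higher than a randomly chosen non-readmitted one, reported above on a 0-100 scale. A minimal pure-Python sketch (quadratic in the number of patients, so for illustration only; `sklearn.metrics.roc_auc_score` gives the same value for binary labels):

```python
def c_statistic(y_true, y_score):
    """Share of (readmitted, not readmitted) pairs ranked correctly; ties count 0.5."""
    pos = [s for y, s in zip(y_true, y_score) if y == 1]
    neg = [s for y, s in zip(y_true, y_score) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative label")
    concordant = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                concordant += 1.0
            elif p == n:
                concordant += 0.5
    return concordant / (len(pos) * len(neg))


# 5 of the 6 positive/negative pairs are ordered correctly.
auc = c_statistic([1, 0, 1, 0, 0], [0.9, 0.3, 0.4, 0.5, 0.1])
print(round(100 * auc, 2))  # 83.33
```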
Or use KumoRFM with a single query
PREDICT is_readmitted_30d FOR patient
USING patient, encounter, diagnosis, procedure, medication
One PQL query. KumoRFM constructs the clinical graph from EHR tables, handles temporal encounters, and predicts 30-day readmission risk.
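The query assumes a target like `is_readmitted_30d` already exists. One common way to derive such a label from encounter rows is sketched below, assuming a hypothetical tuple layout of (encounter_id, patient_id, admit_date, discharge_date). It labels each stay 1 if the same patient is admitted again 0 to 30 days after that stay's discharge, and it deliberately ignores CMS exclusions such as planned readmissions, transfers, and deaths. Rolling encounter labels up to the patient level is a separate design choice.

```python
from datetime import date


def readmitted_within_30d(encounters):
    """Hypothetical helper: encounter_id -> 1 if the same patient is
    admitted again within 30 days of this stay's discharge, else 0."""
    by_patient = {}
    for enc in encounters:
        by_patient.setdefault(enc[1], []).append(enc)
    labels = {}
    for stays in by_patient.values():
        stays.sort(key=lambda e: e[2])  # order each patient's stays by admit date
        for i, (enc_id, _, _, discharge) in enumerate(stays):
            labels[enc_id] = int(any(
                0 <= (later[2] - discharge).days <= 30 for later in stays[i + 1:]))
    return labels


encs = [
    ("e1", "p1", date(2024, 1, 1), date(2024, 1, 5)),
    ("e2", "p1", date(2024, 1, 20), date(2024, 1, 22)),  # 15 days after e1
    ("e3", "p1", date(2024, 3, 1), date(2024, 3, 3)),    # 39 days after e2
    ("e4", "p2", date(2024, 2, 1), date(2024, 2, 4)),
]
print(readmitted_within_30d(encs))  # {'e1': 1, 'e2': 0, 'e3': 0, 'e4': 0}
```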