The business problem
Bringing a drug to market costs an estimated $2.6 billion on average, and roughly 90% of drug candidates that enter clinical trials fail. A Phase III failure can cost $50-300 million in direct trial costs plus years of lost time. Better prediction of trial outcomes lets pharmaceutical companies allocate R&D budgets more effectively, redesign trials before launch, or terminate unlikely candidates earlier.
Trial success depends on far more than the drug itself. The sponsor's experience with the condition, the quality and enrollment capacity of trial sites, the choice of endpoints, and the competitive landscape of similar trials all affect outcomes. These are relational factors that require graph-level reasoning.
Why flat ML fails
- No sponsor context: A sponsor with 10 successful oncology trials brings expertise that increases success probability. Flat models see “sponsor_id = 42”, not the sponsor's track record in the specific condition area.
- No site quality: Trial sites vary enormously in enrollment speed, data quality, and investigator experience. The site graph captures these quality signals from past trial participation.
- No mechanism-of-action signal: If similar drugs (same mechanism) have failed for the same condition, the probability of success decreases. The intervention-condition graph propagates these failure signals.
- No competitive landscape: Multiple trials competing for the same patient population reduce enrollment speed and increase failure risk. The condition graph captures this competition.
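The first and last bullets can be made concrete with a small sketch. Assuming a hypothetical list of past trial records (the `TrialRecord` fields and the toy values are illustrative, not drawn from a real dataset), the code below computes two relational signals that a flat row with only `sponsor_id` never carries: the sponsor's prior success rate in a condition area, and the number of other trials running on the same condition at the same time.

```python
from dataclasses import dataclass


@dataclass
class TrialRecord:
    # Hypothetical fields, chosen for illustration only.
    trial_id: str
    sponsor_id: int
    condition: str
    start_year: int
    end_year: int
    succeeded: bool


# Toy history; values are made up.
HISTORY = [
    TrialRecord("T1", 42, "oncology", 2015, 2017, True),
    TrialRecord("T2", 42, "oncology", 2016, 2019, True),
    TrialRecord("T3", 42, "cardiology", 2017, 2018, False),
    TrialRecord("T4", 7, "oncology", 2016, 2018, False),
]


def sponsor_track_record(history, sponsor_id, condition, before_year):
    """Success rate of the sponsor's trials on this condition that ended
    before `before_year`; None if the sponsor has no such history."""
    past = [t for t in history
            if t.sponsor_id == sponsor_id
            and t.condition == condition
            and t.end_year < before_year]
    if not past:
        return None
    return sum(t.succeeded for t in past) / len(past)


def competing_trials(history, condition, start_year, end_year, exclude_id=None):
    """Number of trials on the same condition whose run overlaps
    the interval [start_year, end_year]."""
    return sum(1 for t in history
               if t.condition == condition
               and t.trial_id != exclude_id
               and t.start_year <= end_year
               and t.end_year >= start_year)


oncology_rate = sponsor_track_record(HISTORY, 42, "oncology", before_year=2020)
rivals = competing_trials(HISTORY, "oncology", 2016, 2018, exclude_id="T4")
```

Restricting the history to trials that ended before the prediction date keeps future outcomes out of the feature. Hand-built aggregates like these cover one hop at a time; the graph approach described in this article is meant to learn such signals, and multi-hop combinations of them, through message passing.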
The relational schema
Node types:
Trial (id, phase, enrollment, start_date, design)
Sponsor (id, type, portfolio_size, therapeutic_focus)
Site (id, institution, country, enrollment_capacity)
Condition (id, icd10, therapeutic_area, patient_population)
Intervention (id, drug_class, mechanism, route)
Edge types:
Trial --[sponsored_by]--> Sponsor
Trial --[conducted_at]--> Site (enrollment, quality)
Trial --[targets]--> Condition
Trial --[tests]--> Intervention
Intervention --[similar_to]--> Intervention (mechanism_overlap)
Five node types capture the full trial context. Similar-intervention edges propagate mechanism-of-action success/failure signals.
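One way to turn this schema into model inputs is to map every entity ID to a contiguous integer index per node type and to store each edge type as a pair of index lists. The sketch below does this in plain Python for hypothetical records (the record layout and IDs are assumptions). It also writes a reverse copy of each edge so that trial nodes can receive messages as well as send them; the `rev_` prefix is an assumed naming choice that mirrors the convention of PyG's `ToUndirected` transform.

```python
from collections import defaultdict

# Hypothetical trial records; field names and IDs are illustrative.
TRIALS = [
    {"id": "T1", "sponsor": "S1", "sites": ["X", "Y"],
     "condition": "C1", "intervention": "D1"},
    {"id": "T2", "sponsor": "S1", "sites": ["Y"],
     "condition": "C1", "intervention": "D2"},
]


def build_edges(trials):
    index = defaultdict(dict)              # node type -> {raw id: index}
    edges = defaultdict(lambda: ([], []))  # edge type -> (src idx, dst idx)

    def idx(node_type, raw_id):
        table = index[node_type]
        # Assign the next free index the first time an ID is seen.
        return table.setdefault(raw_id, len(table))

    def add(src_type, rel, dst_type, src_id, dst_id):
        s = idx(src_type, src_id)
        d = idx(dst_type, dst_id)
        fwd = edges[(src_type, rel, dst_type)]
        fwd[0].append(s)
        fwd[1].append(d)
        # Reverse edge so the source node type can also be a destination.
        bwd = edges[(dst_type, f"rev_{rel}", src_type)]
        bwd[0].append(d)
        bwd[1].append(s)

    for t in trials:
        add("trial", "sponsored_by", "sponsor", t["id"], t["sponsor"])
        for site in t["sites"]:
            add("trial", "conducted_at", "site", t["id"], site)
        add("trial", "targets", "condition", t["id"], t["condition"])
        add("trial", "tests", "intervention", t["id"], t["intervention"])
    return dict(index), dict(edges)


index, edges = build_edges(TRIALS)
```

Each pair of lists has the same layout as the two rows of a PyG `edge_index` tensor, so it could be converted with `torch.tensor` and stored under the matching edge type. Edge attributes such as enrollment or mechanism_overlap are omitted here for brevity.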
PyG architecture: HeteroConv for trial prediction
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, HeteroConv, Linear


class TrialSuccessGNN(torch.nn.Module):
    def __init__(self, hidden_dim=64):
        super().__init__()
        # Lazy per-type input projections (-1 infers the feature size).
        self.trial_lin = Linear(-1, hidden_dim)
        self.sponsor_lin = Linear(-1, hidden_dim)
        self.site_lin = Linear(-1, hidden_dim)
        self.condition_lin = Linear(-1, hidden_dim)
        self.intervention_lin = Linear(-1, hidden_dim)

        def sage():
            return SAGEConv(hidden_dim, hidden_dim)

        # Layer 1: entities aggregate their trials (forward edges) and
        # trials receive entity features (reverse edges). HeteroConv only
        # returns node types that are a destination of some edge type, so
        # without reverse edges 'trial' would never be updated. The
        # 'rev_*' edge types are assumed to exist in the graph, e.g. as
        # added by torch_geometric.transforms.ToUndirected().
        self.conv1 = HeteroConv({
            ('trial', 'sponsored_by', 'sponsor'): sage(),
            ('trial', 'conducted_at', 'site'): sage(),
            ('trial', 'targets', 'condition'): sage(),
            ('trial', 'tests', 'intervention'): sage(),
            ('intervention', 'similar_to', 'intervention'): sage(),
            ('sponsor', 'rev_sponsored_by', 'trial'): sage(),
            ('site', 'rev_conducted_at', 'trial'): sage(),
            ('condition', 'rev_targets', 'trial'): sage(),
            ('intervention', 'rev_tests', 'trial'): sage(),
        }, aggr='mean')
        # Layer 2: only trial embeddings feed the classifier, so messages
        # flow from the updated entities back to the trials.
        self.conv2 = HeteroConv({
            ('sponsor', 'rev_sponsored_by', 'trial'): sage(),
            ('condition', 'rev_targets', 'trial'): sage(),
            ('intervention', 'rev_tests', 'trial'): sage(),
        }, aggr='mean')
        self.classifier = Linear(hidden_dim, 1)

    def forward(self, x_dict, edge_index_dict):
        # Project into a new dict so the caller's x_dict is not mutated.
        x_dict = {
            'trial': self.trial_lin(x_dict['trial']),
            'sponsor': self.sponsor_lin(x_dict['sponsor']),
            'site': self.site_lin(x_dict['site']),
            'condition': self.condition_lin(x_dict['condition']),
            'intervention': self.intervention_lin(x_dict['intervention']),
        }
        x_dict = {k: F.relu(v) for k, v in
                  self.conv1(x_dict, edge_index_dict).items()}
        x_dict = self.conv2(x_dict, edge_index_dict)
        # Per-trial success probability in [0, 1].
        return torch.sigmoid(
            self.classifier(x_dict['trial']).squeeze(-1))
HeteroConv aggregates sponsor track record, site quality, condition precedent, and intervention similarity into per-trial success predictions.
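To see what the two rounds of aggregation contribute, here is a deliberately simplified sketch of mean message passing with scalar features and no learned weights. It is not PyG's implementation: a real `SAGEConv` also applies learned linear maps and adds a transformed copy of each node's own features. The function name, the toy graph, and the `rev_sponsored_by` edge type are assumptions made for illustration.

```python
from collections import defaultdict


def mean_aggregate(x, edges):
    """One simplified message-passing round with scalar features.

    x:     {node_type: [feature per node]}
    edges: {(src_type, relation, dst_type): (src_indices, dst_indices)}
    Returns new features only for node types that are the destination
    of at least one edge type.
    """
    per_relation = defaultdict(list)
    for (src, _rel, dst), (src_idx, dst_idx) in edges.items():
        sums = [0.0] * len(x[dst])
        counts = [0] * len(x[dst])
        for s, d in zip(src_idx, dst_idx):
            sums[d] += x[src][s]
            counts[d] += 1
        # Mean over incoming neighbors; nodes with no neighbors get 0.0.
        per_relation[dst].append(
            [total / n if n else 0.0 for total, n in zip(sums, counts)])
    # Average the per-relation results, analogous to aggr='mean'.
    return {node_type: [sum(vals) / len(vals) for vals in zip(*outs)]
            for node_type, outs in per_relation.items()}


# Toy graph: two trials with a scalar feature each, one shared sponsor.
x = {"trial": [1.0, 0.0], "sponsor": [0.0]}
forward = {("trial", "sponsored_by", "sponsor"): ([0, 1], [0, 0])}
reverse = {("sponsor", "rev_sponsored_by", "trial"): ([0, 0], [0, 1])}

# Round 1: the sponsor summarizes its trials; "trial" gets no update
# because it is never a destination of a forward edge.
hop1 = mean_aggregate(x, forward)

# Round 2: the reverse edge carries that summary back to each trial.
hop2 = mean_aggregate({"trial": x["trial"], **hop1}, reverse)
```

After the second round each trial carries a summary of its sponsor's other trials, which is the two-hop pattern behind the "sponsor track record" signal. The first round also shows why a node type that only ever appears as an edge source receives no new embedding.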
Expected performance
- Phase-based heuristic: ~55 AUROC
- LightGBM (flat features): 62.44 AUROC
- GNN (HeteroConv): 75.83 AUROC
- KumoRFM (zero-shot): 76.71 AUROC
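The figures above appear to be AUROC values on a 0 to 100 scale (an assumption based on their magnitude). AUROC is the probability that a randomly chosen successful trial receives a higher score than a randomly chosen failed one, with ties counted as half. A minimal pairwise sketch of the metric, using made-up labels and scores:

```python
def auroc(labels, scores):
    """Area under the ROC curve via pairwise comparison.

    labels: 1 for a successful trial, 0 for a failed one.
    scores: predicted success probabilities, same order as labels.
    """
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUROC needs at least one positive and one negative")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0
               for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


# Illustrative values only, not results from the benchmark above.
labels = [1, 0, 1, 0]
scores = [0.9, 0.8, 0.4, 0.1]
score_pct = 100 * auroc(labels, scores)
```

A score of 50 corresponds to random ranking and 100 to a perfect ranking. The pairwise loop is quadratic in the number of trials, so a rank-based implementation is preferable on large evaluation sets.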
Or use KumoRFM in one line
PREDICT is_successful FOR trial
USING trial, sponsor, site, condition, intervention
One PQL query. KumoRFM captures sponsor-site-condition patterns from historical trial data for success prediction.