
Clinical Trial Success: GNN on Sponsor-Site-Condition Graphs

90% of clinical trials fail. Each failed Phase III trial costs $50-300M. Here is how to build a GNN that predicts trial success from the network of sponsors, sites, conditions, and past trial outcomes.

TL;DR

  1. Clinical trial success is a graph prediction problem. Outcomes depend on the relationships between sponsors, sites, investigators, conditions, and interventions.
  2. HeteroConv on the trial-sponsor-site-condition graph captures sponsor track records, site capabilities, and condition-intervention patterns in a unified model.
  3. On RelBench benchmarks, GNNs achieve 75.83 AUROC vs 62.44 for flat-table LightGBM. Relational context (who is running the trial, where, and for what) drives the improvement.
  4. ClinicalTrials.gov provides structured data on 400K+ trials for training. Historical patterns predict future trial outcomes.
  5. KumoRFM predicts trial outcomes with one PQL query (76.71 AUROC zero-shot), capturing sponsor-site-condition patterns automatically.

The business problem

Bringing a drug to market costs $2.6 billion on average, and 90% of clinical trials fail. A Phase III failure can cost $50-300 million in direct trial costs plus years of lost time. Better prediction of trial outcomes enables pharmaceutical companies to allocate R&D budgets more effectively, redesign trials before launch, or terminate unlikely candidates earlier.

Trial success depends on far more than the drug itself. The sponsor's experience with the condition, the quality and enrollment capacity of trial sites, the choice of endpoints, and the competitive landscape of similar trials all affect outcomes. These are relational factors that require graph-level reasoning.

Why flat ML fails

  • No sponsor context: A sponsor with 10 successful oncology trials brings expertise that increases success probability. Flat models see “sponsor_id = 42”, not the sponsor's track record in the specific condition area.
  • No site quality: Trial sites vary enormously in enrollment speed, data quality, and investigator experience. The site graph captures these quality signals from past trial participation.
  • No mechanism-of-action signal: If similar drugs (same mechanism) have failed for the same condition, the probability of success decreases. The intervention-condition graph propagates these failure signals.
  • No competitive landscape: Multiple trials competing for the same patient population reduce enrollment speed and increase failure risk. The condition graph captures this competition.

The relational schema

schema.txt
Node types:
  Trial        (id, phase, enrollment, start_date, design)
  Sponsor      (id, type, portfolio_size, therapeutic_focus)
  Site         (id, institution, country, enrollment_capacity)
  Condition    (id, icd10, therapeutic_area, patient_population)
  Intervention (id, drug_class, mechanism, route)

Edge types:
  Trial       --[sponsored_by]-->  Sponsor
  Trial       --[conducted_at]-->  Site        (enrollment, quality)
  Trial       --[targets]-->       Condition
  Trial       --[tests]-->         Intervention
  Intervention --[similar_to]-->   Intervention (mechanism_overlap)

Five node types capture the full trial context. Similar-intervention edges propagate mechanism-of-action success/failure signals.

PyG architecture: HeteroConv for trial prediction

trial_model.py
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, HeteroConv, Linear

class TrialSuccessGNN(torch.nn.Module):
    def __init__(self, hidden_dim=64):
        super().__init__()
        # Lazy linear layers project each node type's raw features
        # into a shared hidden dimension.
        self.trial_lin = Linear(-1, hidden_dim)
        self.sponsor_lin = Linear(-1, hidden_dim)
        self.site_lin = Linear(-1, hidden_dim)
        self.condition_lin = Linear(-1, hidden_dim)
        self.intervention_lin = Linear(-1, hidden_dim)

        # Messages must flow INTO trial nodes, so the convolutions use
        # the reverse edge types ('rev_*') that applying
        # torch_geometric.transforms.ToUndirected() to the data creates.
        self.conv1 = HeteroConv({
            ('sponsor', 'rev_sponsored_by', 'trial'): SAGEConv(
                hidden_dim, hidden_dim),
            ('site', 'rev_conducted_at', 'trial'): SAGEConv(
                hidden_dim, hidden_dim),
            ('condition', 'rev_targets', 'trial'): SAGEConv(
                hidden_dim, hidden_dim),
            ('intervention', 'rev_tests', 'trial'): SAGEConv(
                hidden_dim, hidden_dim),
            ('intervention', 'similar_to', 'intervention'):
                SAGEConv(hidden_dim, hidden_dim),
        }, aggr='mean')

        self.conv2 = HeteroConv({
            ('sponsor', 'rev_sponsored_by', 'trial'): SAGEConv(
                hidden_dim, hidden_dim),
            ('condition', 'rev_targets', 'trial'): SAGEConv(
                hidden_dim, hidden_dim),
            ('intervention', 'rev_tests', 'trial'): SAGEConv(
                hidden_dim, hidden_dim),
        }, aggr='mean')

        self.classifier = Linear(hidden_dim, 1)

    def forward(self, x_dict, edge_index_dict):
        x_dict = {
            'trial': self.trial_lin(x_dict['trial']).relu(),
            'sponsor': self.sponsor_lin(x_dict['sponsor']).relu(),
            'site': self.site_lin(x_dict['site']).relu(),
            'condition': self.condition_lin(x_dict['condition']).relu(),
            'intervention': self.intervention_lin(
                x_dict['intervention']).relu(),
        }

        # HeteroConv only returns destination node types, so merge its
        # output back into x_dict to keep source embeddings available
        # for the second layer.
        x_dict.update({k: F.relu(v) for k, v in
                       self.conv1(x_dict, edge_index_dict).items()})
        x_dict.update(self.conv2(x_dict, edge_index_dict))

        return torch.sigmoid(
            self.classifier(x_dict['trial']).squeeze(-1))

HeteroConv aggregates sponsor track record, site quality, condition precedent, and intervention similarity into per-trial success predictions.

Expected performance

  • Phase-based heuristic: ~55 AUROC
  • LightGBM (flat features): 62.44 AUROC
  • GNN (HeteroConv): 75.83 AUROC
  • KumoRFM (zero-shot): 76.71 AUROC

Or use KumoRFM in one line

KumoRFM PQL
PREDICT is_successful FOR trial
USING trial, sponsor, site, condition, intervention

One PQL query. KumoRFM captures sponsor-site-condition patterns from historical trial data for success prediction.

Frequently asked questions

Why use GNNs to predict clinical trial success?

Trial outcomes depend on the relationships between sponsors, sites, investigators, conditions, and interventions. A sponsor with a strong track record in oncology running trials at high-enrollment sites with experienced investigators has a different success probability than the same drug tested at inexperienced sites. GNNs capture these relational patterns.

What data feeds the clinical trial graph?

ClinicalTrials.gov provides structured data on 400K+ trials: sponsors, conditions, interventions, phases, sites, enrollment, and outcomes. The graph connects trials to sponsors, conditions, and sites, with trial outcomes as labels. Historical patterns predict future trial success.
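As an illustration, a study record fetched from the ClinicalTrials.gov v2 API could be flattened into graph rows like this. The field paths follow the v2 `protocolSection` layout but should be verified against the live API schema; the NCT id and values below are placeholders:

```python
def study_to_rows(study: dict) -> dict:
    """Flatten one ClinicalTrials.gov v2 study record into graph rows.

    Field paths follow the v2 API's protocolSection layout; verify them
    against the live schema before depending on them.
    """
    proto = study.get('protocolSection', {})
    nct_id = proto.get('identificationModule', {}).get('nctId')
    sponsor = (proto.get('sponsorCollaboratorsModule', {})
                    .get('leadSponsor', {}).get('name'))
    conditions = proto.get('conditionsModule', {}).get('conditions', [])
    phases = proto.get('designModule', {}).get('phases', [])
    return {
        'trial': nct_id,                               # Trial node id
        'sponsored_by': (nct_id, sponsor),             # Trial -> Sponsor edge
        'targets': [(nct_id, c) for c in conditions],  # Trial -> Condition edges
        'phase': phases[0] if phases else None,        # Trial node attribute
    }

# Placeholder record containing only the fields used above.
sample = {'protocolSection': {
    'identificationModule': {'nctId': 'NCT00000000'},
    'sponsorCollaboratorsModule': {'leadSponsor': {'name': 'Example Pharma'}},
    'conditionsModule': {'conditions': ['Melanoma']},
    'designModule': {'phases': ['PHASE3']},
}}
rows = study_to_rows(sample)
```

Accumulating these rows across all fetched studies yields the node and edge tables that back the heterogeneous graph.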

How does the graph capture drug mechanism-of-action similarity?

Interventions (drugs) connect to conditions (diseases) they target. Similar drugs (same mechanism of action, similar molecular structure) share edges. If Drug A failed for Condition X via Mechanism Y, and Drug B uses a similar mechanism for the same condition, the graph propagates this negative signal.
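One simple way to construct the similar_to edges is Jaccard overlap on mechanism/target annotations. The sketch below is illustrative: the tag sets and the 0.5 threshold are made-up assumptions, not values from the source:

```python
def mechanism_jaccard(a: set, b: set) -> float:
    """Jaccard overlap between two sets of mechanism/target tags."""
    return len(a & b) / len(a | b) if (a | b) else 0.0

def similar_to_edges(mechanisms: dict, threshold: float = 0.5) -> list:
    """Pair up drugs whose mechanism overlap clears the threshold."""
    drugs = sorted(mechanisms)
    edges = []
    for i, d1 in enumerate(drugs):
        for d2 in drugs[i + 1:]:
            w = mechanism_jaccard(mechanisms[d1], mechanisms[d2])
            if w >= threshold:
                edges.append((d1, d2, w))  # (src, dst, mechanism_overlap)
    return edges

# Illustrative tag sets; real ones might come from DrugBank-style annotations.
mechanisms = {
    'drug_a': {'EGFR', 'kinase_inhibitor'},
    'drug_b': {'EGFR', 'kinase_inhibitor', 'oral'},
    'drug_c': {'opioid_agonist'},
}
edges = similar_to_edges(mechanisms)
```

The overlap weight can be stored as the `mechanism_overlap` edge attribute from the schema, letting the GNN weight how strongly a neighbor's failure signal should propagate.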

Can GNNs predict which trial design choices affect success?

Yes. Trial design features (phase, endpoint type, enrollment criteria, number of sites) become node attributes. The GNN learns which design choices correlate with success in the context of the specific condition, sponsor, and site network. This enables design optimization before trial launch.
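A minimal sketch of encoding design choices as trial node attributes (the feature set and log-scaling below are illustrative choices, not a prescribed recipe):

```python
import math

PHASES = ['PHASE1', 'PHASE2', 'PHASE3', 'PHASE4']

def design_features(phase: str, enrollment: int, n_sites: int,
                    randomized: bool) -> list:
    """Encode trial design choices as a fixed-length numeric vector."""
    onehot = [1.0 if phase == p else 0.0 for p in PHASES]
    # log1p tames the heavy right tail of enrollment counts.
    return onehot + [math.log1p(enrollment), float(n_sites), float(randomized)]

vec = design_features('PHASE3', enrollment=500, n_sites=12, randomized=True)
```

Vectors like this become the `x` rows for trial nodes, so the message-passing layers can relate design choices to the surrounding sponsor, site, and condition context.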

How does KumoRFM handle clinical trial prediction?

KumoRFM takes your clinical trial database (trials, sponsors, sites, conditions, interventions) and predicts trial outcomes with one PQL query. It captures sponsor track records, site capabilities, and condition-intervention patterns automatically.

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.