The business problem
Drug discovery is one of the most expensive and failure-prone endeavors in industry. The average cost to bring a drug to market is $2.6 billion, and the process takes 10-15 years. Approximately 90% of drug candidates that enter clinical trials fail. Early identification of problematic molecular properties (poor solubility, liver toxicity, low bioavailability) can eliminate bad candidates years earlier, saving hundreds of millions per avoided failure.
Traditional computational chemistry uses molecular fingerprints (ECFP, MACCS) and descriptors (molecular weight, LogP, TPSA) as fixed representations. These hand-engineered features lose structural information and cannot capture subtle 3D interactions. GNNs learn task-specific representations directly from the molecular graph.
Why flat ML fails
- Information loss: Molecular fingerprints are fixed-length bit vectors that hash substructures. Two structurally different molecules can produce identical fingerprints (hash collisions). GNNs maintain the full structural information.
- No bond awareness: Flat features treat all bonds equally. The difference between a single bond and an aromatic bond in a specific position can determine whether a molecule is a potent drug or an inert compound.
- No learned representations: Fingerprints are static. GNNs learn task-specific features: different molecular substructures matter for toxicity vs solubility vs binding affinity.
- No transfer learning: Fingerprint-based models train from scratch for each task. GNN representations transfer across tasks, enabling multi-task learning and few-shot prediction.
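The hash-collision point can be made concrete with a toy fixed-length fingerprint. This is not real ECFP: the fragment strings and the hash function are invented for illustration, but the failure mode is the same — two different fragment sets land on identical bits.

```python
# Toy illustration of fingerprint hash collisions (not real ECFP):
# substructure strings are hashed into a small fixed-length bit vector,
# so structurally different molecules can produce identical fingerprints.

def toy_fingerprint(substructures, n_bits=16):
    """Hash each substructure string into one of n_bits positions."""
    bits = [0] * n_bits
    for s in substructures:
        # Deterministic toy hash; real fingerprints use stronger hashes,
        # but any fixed-length hash admits collisions.
        idx = sum(ord(c) for c in s) % n_bits
        bits[idx] = 1
    return tuple(bits)

mol_a = ["C-C", "C=O"]  # fragments from one molecule
mol_b = ["C-O", "C=C"]  # a different fragment set
print(toy_fingerprint(mol_a) == toy_fingerprint(mol_b))  # -> True
```

A GNN sees the two fragment sets as different graphs; the fingerprint cannot tell them apart once the bits collide.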
The molecular graph
Node features (per atom):
Atomic number, degree, formal charge, hybridization,
aromaticity, num_hydrogens, is_in_ring
Edge features (per bond):
Bond type (single=1, double=2, triple=3, aromatic=4),
conjugated, in_ring, stereo_type
Global features (per molecule):
Molecular weight, LogP, TPSA, num_rotatable_bonds
# Example: Aspirin (C9H8O4)
# 13 heavy atoms (nodes), 13 bonds (edges)
# Bond types: 6 aromatic, 5 single, 2 double
Atoms are nodes, bonds are edges. RDKit converts SMILES strings to PyG Data objects with node features, edge features, and edge indices.
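The aspirin graph can be written out explicitly. The atom indexing below follows the SMILES string left to right and is hand-encoded for illustration; in practice RDKit builds this structure automatically from the SMILES.

```python
# Aspirin's heavy-atom graph: SMILES CC(=O)OC1=CC=CC=C1C(=O)O
atoms = ["C", "C", "O", "O", "C", "C", "C", "C", "C", "C", "C", "O", "O"]
#         0    1    2    3    4    5    6    7    8    9   10   11   12

# (src, dst, bond_type) triples
bonds = [
    (0, 1, "single"),    # CH3 - carbonyl C
    (1, 2, "double"),    # C=O (ester carbonyl)
    (1, 3, "single"),    # C-O (ester linkage)
    (3, 4, "single"),    # O - ring C
    (4, 5, "aromatic"), (5, 6, "aromatic"), (6, 7, "aromatic"),
    (7, 8, "aromatic"), (8, 9, "aromatic"), (9, 4, "aromatic"),
    (9, 10, "single"),   # ring C - carboxyl C
    (10, 11, "double"),  # C=O (acid carbonyl)
    (10, 12, "single"),  # C-OH
]

print(len(atoms), len(bonds))  # -> 13 13
```

Each triple becomes a row of the PyG edge index, and the bond type becomes an edge feature.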
PyG architecture: NNConv + global pooling
import torch
import torch.nn.functional as F
from torch_geometric.nn import NNConv, global_mean_pool, Linear

class MoleculeGNN(torch.nn.Module):
    def __init__(self, node_dim, edge_dim, hidden_dim=128):
        super().__init__()
        self.node_lin = Linear(node_dim, hidden_dim)
        # NNConv: edge-conditioned message passing.
        # The inner nn maps bond features to a hidden_dim x hidden_dim
        # weight matrix applied to each neighbor's features.
        nn1 = torch.nn.Sequential(
            Linear(edge_dim, hidden_dim * hidden_dim))
        self.conv1 = NNConv(hidden_dim, hidden_dim, nn1)
        nn2 = torch.nn.Sequential(
            Linear(edge_dim, hidden_dim * hidden_dim))
        self.conv2 = NNConv(hidden_dim, hidden_dim, nn2)
        # Prediction head applied after global pooling
        self.head = torch.nn.Sequential(
            Linear(hidden_dim, 64),
            torch.nn.ReLU(),
            Linear(64, 1),
        )

    def forward(self, x, edge_index, edge_attr, batch):
        x = F.relu(self.node_lin(x))
        x = F.relu(self.conv1(x, edge_index, edge_attr))
        x = self.conv2(x, edge_index, edge_attr)
        # Pool atom embeddings into one molecule embedding per graph
        x = global_mean_pool(x, batch)
        return self.head(x).squeeze(-1)
# Data preparation with RDKit
# from rdkit import Chem
# mol = Chem.MolFromSmiles('CC(=O)OC1=CC=CC=C1C(=O)O')
# Convert to PyG Data object with atom/bond features

NNConv generates bond-specific transformations from edge features. Global mean pooling aggregates atom-level embeddings into a molecule-level representation for property prediction.
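The core NNConv computation can be sketched without PyG. This NumPy version makes simplifying assumptions (a single graph, mean aggregation, no self/root term or bias), but it shows the key idea: each bond's features generate the weight matrix used for that bond's message.

```python
import numpy as np

rng = np.random.default_rng(0)
num_nodes, node_dim, edge_dim = 4, 8, 3

x = rng.normal(size=(num_nodes, node_dim))           # node features
edge_index = np.array([[0, 1, 2, 3],                 # source atoms
                       [1, 2, 3, 0]])                # destination atoms
edge_attr = rng.normal(size=(edge_index.shape[1], edge_dim))

# Edge network weights: map a bond feature vector to a
# node_dim x node_dim matrix (the role of nn1/nn2 inside NNConv).
W_edge = rng.normal(size=(edge_dim, node_dim * node_dim)) * 0.1

def edge_conditioned_conv(x, edge_index, edge_attr):
    out = np.zeros_like(x)
    count = np.zeros(len(x))
    for (src, dst), e in zip(edge_index.T, edge_attr):
        theta = (e @ W_edge).reshape(node_dim, node_dim)  # per-bond weights
        out[dst] += x[src] @ theta                        # message src -> dst
        count[dst] += 1
    return out / np.maximum(count, 1)[:, None]            # mean aggregation

h = edge_conditioned_conv(x, edge_index, edge_attr)
graph_embedding = h.mean(axis=0)   # global mean pooling
print(graph_embedding.shape)       # -> (8,)
```

A single bond changing type changes its weight matrix, and therefore the message it carries — the bond awareness that flat features lack.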
Training considerations
- Data sources: MoleculeNet provides standardized benchmarks (ESOL, FreeSolv, Tox21, HIV). ZINC and ChEMBL provide larger datasets for pre-training.
- Multi-task learning: Predict multiple ADMET properties simultaneously. Shared molecular representations improve performance on each individual task.
- Scaffold splitting: Use Bemis-Murcko scaffold splits (not random) for evaluation. This tests generalization to novel molecular scaffolds, which matters for real drug discovery.
- 3D structure: For binding affinity prediction, add 3D coordinates as node features or use SchNet/DimeNet architectures that incorporate inter-atomic distances.
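The scaffold-splitting idea can be sketched in a few lines. The molecule-to-scaffold mapping below is hypothetical; a real pipeline would compute Bemis-Murcko scaffolds from SMILES with RDKit's MurckoScaffold utilities. This is one simple grouping heuristic, not any library's exact splitter.

```python
from collections import defaultdict

# Hypothetical molecule -> scaffold mapping (RDKit would derive these)
molecules = {
    "mol_a": "c1ccccc1",        # benzene scaffold
    "mol_b": "c1ccccc1",
    "mol_c": "c1ccccc1",
    "mol_d": "c1ccc2ccccc2c1",  # naphthalene scaffold
    "mol_e": "c1ccc2ccccc2c1",
    "mol_f": "C1CCNCC1",        # piperidine scaffold
}

def scaffold_split(mol_to_scaffold, train_frac=0.7):
    """Keep every molecule sharing a scaffold in the same fold."""
    groups = defaultdict(list)
    for mol, scaffold in mol_to_scaffold.items():
        groups[scaffold].append(mol)
    train, test = [], []
    target = train_frac * len(mol_to_scaffold)
    # Fill the train fold with the largest scaffold groups first
    for scaffold in sorted(groups, key=lambda s: -len(groups[s])):
        (train if len(train) < target else test).extend(groups[scaffold])
    return train, test

train, test = scaffold_split(molecules)
```

Because whole scaffold groups move together, the test set contains only scaffolds the model never saw in training, unlike a random split.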
Expected performance
On MoleculeNet benchmarks, GNNs consistently outperform fingerprint-based methods:
- ESOL (solubility): GNN RMSE ~0.55, Random Forest RMSE ~0.90
- Tox21 (toxicity): GNN AUROC ~0.82, SVM AUROC ~0.75
- HIV (activity): GNN AUROC ~0.80, RF AUROC ~0.72
Or use KumoRFM for relational drug data
PREDICT trial_success FOR compound
USING compound, assay_result, clinical_trial, patient_outcome

For relational drug discovery data (trials, assays, patient outcomes), KumoRFM predicts outcomes from database tables. For molecular graph tasks, use PyG directly.
Drug discovery involves both molecular graph tasks (property prediction, where PyG excels) and relational data tasks (clinical trial prediction, assay result analysis). KumoRFM handles the relational side automatically with its pre-trained graph transformer.