
Drug Discovery: NNConv for Molecular Property Prediction

Bringing a drug to market costs $2.6B and takes 10-15 years. 90% of candidates fail in clinical trials. Here is how to build a GNN that predicts molecular properties directly from the molecular graph, accelerating the discovery pipeline.


TL;DR

  • Molecules are naturally graphs: atoms are nodes, bonds are edges. GNNs operate directly on this structure, learning which substructures predict properties like toxicity, solubility, and binding affinity.
  • NNConv generates bond-specific weight matrices, so single bonds, double bonds, and aromatic bonds produce different message transformations, which is critical for chemical accuracy.
  • GNN molecular models outperform fingerprint-based methods on MoleculeNet benchmarks: ~0.82 AUROC on Tox21, ~0.80 on HIV, and ~0.55 RMSE on ESOL vs ~0.90 for Random Forest.
  • The PyG model is ~35 lines using NNConv + global pooling. Production drug discovery pipelines also need virtual screening, ADMET filtering, and synthetic accessibility scoring.
  • For relational drug discovery data (clinical trials, assay results, patient outcomes), KumoRFM predicts outcomes with one PQL query. Molecular graph tasks use PyG directly.

The business problem

Drug discovery is one of the most expensive and failure-prone endeavors in industry. The average cost to bring a drug to market is $2.6 billion, and the process takes 10-15 years. Approximately 90% of drug candidates that enter clinical trials fail. Early identification of problematic molecular properties (poor solubility, liver toxicity, low bioavailability) can eliminate bad candidates years earlier, saving hundreds of millions per avoided failure.

Traditional computational chemistry uses molecular fingerprints (ECFP, MACCS) and descriptors (molecular weight, LogP, TPSA) as fixed representations. These hand-engineered features lose structural information and cannot capture subtle 3D interactions. GNNs learn task-specific representations directly from the molecular graph.
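For contrast, these fixed representations take only a few lines of RDKit (a sketch; ECFP4 corresponds to a Morgan fingerprint with radius 2):

```python
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors

mol = Chem.MolFromSmiles('CC(=O)OC1=CC=CC=C1C(=O)O')  # aspirin

# ECFP4: local substructures hashed into a fixed 2048-bit vector
fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=2048)

# Classic descriptors: scalar summaries of the whole molecule
desc = {
    'MolWt': Descriptors.MolWt(mol),
    'LogP': Descriptors.MolLogP(mol),
    'TPSA': Descriptors.TPSA(mol),
}
print(fp.GetNumOnBits(), desc)
```

The fingerprint has a fixed length regardless of molecule size, which is exactly the information bottleneck (and hash-collision risk) that graph-based models avoid.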

Why flat ML fails

  • Information loss: Molecular fingerprints are fixed-length bit vectors that hash substructures. Two structurally different molecules can produce identical fingerprints (hash collisions). GNNs maintain the full structural information.
  • No bond awareness: Flat features treat all bonds equally. The difference between a single bond and an aromatic bond in a specific position can determine whether a molecule is a potent drug or an inert compound.
  • No learned representations: Fingerprints are static. GNNs learn task-specific features: different molecular substructures matter for toxicity vs solubility vs binding affinity.
  • No transfer learning: Fingerprint-based models train from scratch for each task. GNN representations transfer across tasks, enabling multi-task learning and few-shot prediction.

The molecular graph

molecule_graph.txt
Node features (per atom):
  Atomic number, degree, formal charge, hybridization,
  aromaticity, num_hydrogens, is_in_ring

Edge features (per bond):
  Bond type (single=1, double=2, triple=3, aromatic=4),
  conjugated, in_ring, stereo_type

Global features (per molecule):
  Molecular weight, LogP, TPSA, num_rotatable_bonds

# Example: Aspirin (C9H8O4)
# 13 atoms (nodes), 13 bonds (edges)
# Bond types: 6 aromatic, 5 single, 2 double

Atoms are nodes, bonds are edges. RDKit converts SMILES strings to PyG Data objects with node features, edge features, and edge indices.

PyG architecture: NNConv + global pooling

drug_discovery_model.py
import torch
import torch.nn.functional as F
from torch_geometric.nn import NNConv, global_mean_pool, Linear

class MoleculeGNN(torch.nn.Module):
    def __init__(self, node_dim, edge_dim, hidden_dim=128):
        super().__init__()
        self.node_lin = Linear(node_dim, hidden_dim)

        # NNConv: edge-conditioned message passing
        # nn generates weight matrix from bond features
        nn1 = torch.nn.Sequential(
            Linear(edge_dim, hidden_dim * hidden_dim))
        self.conv1 = NNConv(hidden_dim, hidden_dim, nn1)

        nn2 = torch.nn.Sequential(
            Linear(edge_dim, hidden_dim * hidden_dim))
        self.conv2 = NNConv(hidden_dim, hidden_dim, nn2)

        # Global pooling: molecule-level prediction
        self.head = torch.nn.Sequential(
            Linear(hidden_dim, 64),
            torch.nn.ReLU(),
            Linear(64, 1),
        )

    def forward(self, x, edge_index, edge_attr, batch):
        x = F.relu(self.node_lin(x))
        x = F.relu(self.conv1(x, edge_index, edge_attr))
        x = self.conv2(x, edge_index, edge_attr)

        # Pool atom embeddings to molecule embedding
        x = global_mean_pool(x, batch)

        return self.head(x).squeeze(-1)

# Data preparation with RDKit
# from rdkit import Chem
# mol = Chem.MolFromSmiles('CC(=O)OC1=CC=CC=C1C(=O)O')
# Convert to PyG Data object with atom/bond features

NNConv generates bond-specific transformations from edge features. Global mean pooling aggregates atom-level embeddings into a molecule-level representation for property prediction.

Training considerations

  • Data sources: MoleculeNet provides standardized benchmarks (ESOL, FreeSolv, Tox21, HIV). ZINC and ChEMBL provide larger datasets for pre-training.
  • Multi-task learning: Predict multiple ADMET properties simultaneously. Shared molecular representations improve performance on each individual task.
  • Scaffold splitting: Use Bemis-Murcko scaffold splits (not random) for evaluation. This tests generalization to novel molecular scaffolds, which matters for real drug discovery.
  • 3D structure: For binding affinity prediction, add 3D coordinates as node features or use SchNet/DimeNet architectures that incorporate inter-atomic distances.
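A scaffold split along these lines can be sketched with RDKit's Murcko scaffold utilities (a simplified version of what DeepChem-style MoleculeNet splitters do; the function name is illustrative):

```python
from collections import defaultdict
from rdkit.Chem.Scaffolds import MurckoScaffold

def scaffold_split(smiles_list, train_frac=0.8):
    """Group molecules by Bemis-Murcko scaffold, then greedily fill the
    train split so no scaffold is shared between train and test."""
    groups = defaultdict(list)
    for idx, smi in enumerate(smiles_list):
        scaf = MurckoScaffold.MurckoScaffoldSmiles(smiles=smi)
        groups[scaf].append(idx)
    train, test = [], []
    cutoff = train_frac * len(smiles_list)
    # Largest scaffold groups first, a common heuristic
    for indices in sorted(groups.values(), key=len, reverse=True):
        (train if len(train) + len(indices) <= cutoff else test).extend(indices)
    return train, test
```

Because whole scaffold groups land on one side of the split, test performance reflects generalization to unseen ring systems rather than memorized neighbors.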

Expected performance

On MoleculeNet benchmarks, GNNs consistently outperform fingerprint-based methods:

  • ESOL (solubility): GNN RMSE ~0.55, Random Forest RMSE ~0.90
  • Tox21 (toxicity): GNN AUROC ~0.82, SVM AUROC ~0.75
  • HIV (activity): GNN AUROC ~0.80, RF AUROC ~0.72

Or use KumoRFM for relational drug data

KumoRFM PQL
PREDICT trial_success FOR compound
USING compound, assay_result, clinical_trial, patient_outcome

For relational drug discovery data (trials, assays, patient outcomes), KumoRFM predicts outcomes from database tables. For molecular graph tasks, use PyG directly.

Drug discovery involves both molecular graph tasks (property prediction, where PyG excels) and relational data tasks (clinical trial prediction, assay result analysis). KumoRFM handles the relational side automatically with its pre-trained graph transformer.

Frequently asked questions

Why are GNNs the standard approach for molecular property prediction?

Molecules are naturally graphs: atoms are nodes and bonds are edges. Traditional approaches (fingerprints, descriptors) lose structural information. GNNs operate directly on the molecular graph, learning which substructures (functional groups, ring systems) predict properties like solubility, toxicity, and binding affinity.

What is NNConv and why use it for molecules?

NNConv (Neural Network Convolution) uses a neural network to generate edge-specific weight matrices based on edge features. For molecules, different bond types (single, double, aromatic) should produce different message transformations. NNConv learns these bond-specific transformations automatically, which is critical for chemical accuracy.

What molecular properties can GNNs predict?

GNNs predict ADMET properties (Absorption, Distribution, Metabolism, Excretion, Toxicity), solubility, binding affinity to target proteins, synthetic accessibility, and drug-likeness. Each is a regression or classification task on the molecular graph. Multi-task models that predict multiple properties simultaneously often outperform single-task models.

How do you represent a molecule as a graph in PyG?

Atoms become nodes with features: atomic number, degree, formal charge, hybridization, aromaticity, and hydrogen count. Bonds become edges with features: bond type (single/double/triple/aromatic), conjugation, ring membership, and stereochemistry. RDKit converts SMILES strings to these graph representations.

Can KumoRFM be used for drug discovery?

KumoRFM is designed for relational/tabular data rather than molecular graphs. However, for drug discovery workflows that involve relational data (clinical trials, patient outcomes, assay results), KumoRFM can predict outcomes from these relational tables with a single PQL query.

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.