
Drug Discovery with Graph Neural Networks: Molecular Property Prediction

A molecule is a graph. Atoms are nodes, bonds are edges. Graph neural networks predict toxicity, solubility, and binding affinity directly from this structure, replacing years of wet-lab experimentation with seconds of computation.

TL;DR

  • Molecules are literally graphs: atoms are nodes with features (element, charge, hybridization), bonds are edges with features (type, stereochemistry). No representation conversion needed.
  • Message passing across bonds lets each atom learn its chemical environment. After 3-5 layers, an atom embedding encodes functional groups, ring membership, and nearby electronegative atoms.
  • Graph-level readout (sum/mean pooling over all atoms) produces a single molecular embedding. This feeds into classifiers for toxicity, solubility, binding affinity, and ADMET properties.
  • GNNs outperform traditional Morgan fingerprints on most MoleculeNet tasks because they learn task-specific substructure patterns rather than using hand-designed rules.
  • PyG provides molecular datasets (ZINC, QM9, MoleculeNet, OGB-Mol) with pre-processed atom and bond features. A molecular GNN can be built and trained in under 50 lines of code.

Molecules are graphs, and graph neural networks are the natural architecture for predicting their properties. Every molecule has atoms (nodes) connected by chemical bonds (edges). Atom features encode element type, formal charge, and hybridization state. Bond features encode bond order (single, double, triple, aromatic) and geometry. GNNs process this graph directly, learning which structural patterns correlate with properties like toxicity, solubility, and binding affinity.

This has transformed drug discovery. Instead of synthesizing and testing thousands of candidate compounds in wet labs (months, millions of dollars), a trained GNN screens millions of virtual compounds in hours, prioritizing the most promising candidates for physical testing.

From SMILES strings to molecular graphs

Molecules are typically stored as SMILES strings (e.g., CC(=O)Oc1ccccc1C(=O)O for aspirin). To build a graph, parse the SMILES into atoms and bonds:

  • Node features: atomic number, degree, formal charge, number of hydrogens, aromaticity, hybridization (sp, sp2, sp3), and whether the atom is in a ring
  • Edge features: bond type (single, double, triple, aromatic), conjugation, ring membership, stereochemistry (E/Z, R/S)
molecular_graph.py
from torch_geometric.datasets import MoleculeNet

# Load a standard molecular benchmark
dataset = MoleculeNet(root='data/', name='ESOL')
# ESOL: 1,128 molecules, predict aqueous solubility

mol = dataset[0]
print(mol)
# Data(x=[13, 9], edge_index=[2, 28], edge_attr=[28, 3], y=[1, 1])
# 13 atoms with 9 features each; 28 directed edges (14 bonds, each
# stored in both directions); 3 bond features per edge
# y = solubility value (regression target)

Each molecule becomes a PyG Data object. x holds atom features, edge_index holds bond connectivity, edge_attr holds bond features, y holds the target property.
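To make the feature layout concrete, here is a hand-built sketch of the same structure for formaldehyde (SMILES C=O, hydrogens made explicit). The feature columns are a simplified, hypothetical subset of what PyG's pre-processed datasets actually store:

```python
# Hand-built molecular graph for formaldehyde (SMILES: C=O) with explicit H atoms.
# Feature columns are a simplified illustration, not PyG's exact featurization.

# Node features: [atomic_number, degree, formal_charge, num_hydrogens, is_aromatic]
x = [
    [6, 3, 0, 2, 0],  # C: bonded to O and two H
    [8, 1, 0, 0, 0],  # O: double-bonded to C
    [1, 1, 0, 0, 0],  # H
    [1, 1, 0, 0, 0],  # H
]

# Directed connectivity: each bond appears in both directions (PyG convention)
edge_index = [
    [0, 1, 0, 2, 0, 3],  # source atoms
    [1, 0, 2, 0, 3, 0],  # target atoms
]

# Edge features: [bond_order, is_conjugated, is_in_ring]
edge_attr = [
    [2, 0, 0], [2, 0, 0],  # C=O double bond, both directions
    [1, 0, 0], [1, 0, 0],  # C-H single bond
    [1, 0, 0], [1, 0, 0],  # C-H single bond
]

print(len(x), "atoms,", len(edge_index[0]), "directed edges")
```

Three bonds become six directed edges, mirroring the doubled edge count seen in the ESOL example above.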

How GNNs learn molecular properties

The architecture has three stages:

Stage 1: atom-level message passing

Each atom exchanges information with its bonded neighbors through message passing. After layer 1, each atom knows its immediate chemical environment. After layer 3, each atom encodes information from atoms up to 3 bonds away, which captures functional group context.
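The growing receptive field can be seen with a toy, parameter-free message-passing step in plain Python (no learned weights): each layer replaces an atom's features with the sum of its own and its neighbors' features, so after k layers an atom has aggregated information from atoms up to k bonds away.

```python
# Toy message passing on a 4-atom chain: 0-1-2-3
# One-hot features let us see whose information each atom has received.
neighbors = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}
h = {i: [1.0 if j == i else 0.0 for j in range(4)] for i in range(4)}

def message_pass(h, neighbors):
    """One parameter-free layer: new_h[i] = h[i] + sum of neighbor features."""
    new_h = {}
    for i, nbrs in neighbors.items():
        agg = list(h[i])
        for n in nbrs:
            agg = [a + b for a, b in zip(agg, h[n])]
        new_h[i] = agg
    return new_h

for layer in range(3):
    h = message_pass(h, neighbors)

# After 3 layers, atom 0 has a nonzero entry for atom 3 (3 bonds away),
# which was zero after only 2 layers.
print(h[0])
```

Swapping the sum for a learned MLP over messages recovers the real GNN layer; the receptive-field logic is the same.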

Stage 2: graph-level readout

After message passing, we need a single vector representing the entire molecule. A readout function pools over all atom embeddings:

  • Sum pooling: preserves molecular size information (larger molecules have larger sums)
  • Mean pooling: normalizes by size, better for size-invariant properties
  • Attention pooling: learns which atoms matter most for the target property
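The three readouts differ only in how they combine atom embeddings. A plain-Python sketch, using scalar embeddings for brevity and hand-set (hypothetical) attention scores in place of learned ones:

```python
import math

# Per-atom embeddings for one molecule (scalars for brevity; real ones are vectors)
atom_embeddings = [0.5, 1.0, 2.5, 4.0]

# Sum pooling: magnitude grows with molecule size
sum_pool = sum(atom_embeddings)

# Mean pooling: size-normalized
mean_pool = sum(atom_embeddings) / len(atom_embeddings)

# Attention pooling: softmax over per-atom scores (scores hand-set here;
# a real model learns them from the atom embeddings)
scores = [0.1, 0.3, 2.0, 0.2]
exp_s = [math.exp(s) for s in scores]
weights = [e / sum(exp_s) for e in exp_s]
attn_pool = sum(w * e for w, e in zip(weights, atom_embeddings))

print(sum_pool, mean_pool, round(attn_pool, 3))
```

Note how the attention weights concentrate on the third atom (highest score), so the pooled value is pulled toward its embedding.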

Stage 3: property prediction

The molecular embedding feeds into a classification head (is this molecule toxic: yes/no) or regression head (what is its solubility value). Multi-task models predict multiple properties simultaneously, sharing the molecular representation.
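The head stage can be sketched in plain Python: one shared molecular embedding feeding a classification logit and a regression value through separate linear heads (embedding values and weights hand-set purely for illustration):

```python
# Shared molecular embedding from the readout stage (values illustrative)
mol_embedding = [0.8, -0.2, 1.5]

def linear_head(embedding, weights, bias):
    """A single linear layer: dot(embedding, weights) + bias."""
    return sum(e * w for e, w in zip(embedding, weights)) + bias

# Two task heads sharing the same embedding (weights hand-set for illustration)
toxicity_logit = linear_head(mol_embedding, [0.5, 1.0, -0.3], 0.1)
solubility = linear_head(mol_embedding, [-1.0, 0.2, 0.4], -0.5)

is_toxic = toxicity_logit > 0  # classification: threshold the logit
print(is_toxic, solubility)
```

A multi-task model is exactly this picture with more heads: the GNN layers and readout are shared, and only the per-task linear layers differ.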

molecular_gnn.py
import torch
import torch.nn.functional as F
from torch_geometric.nn import GINConv, global_add_pool
from torch.nn import Sequential, Linear, ReLU, BatchNorm1d

class MolecularGNN(torch.nn.Module):
    def __init__(self, num_features, hidden_dim, num_classes):
        super().__init__()
        # GIN layers for maximum expressiveness
        self.conv1 = GINConv(Sequential(
            Linear(num_features, hidden_dim), BatchNorm1d(hidden_dim), ReLU(),
            Linear(hidden_dim, hidden_dim), ReLU()
        ))
        self.conv2 = GINConv(Sequential(
            Linear(hidden_dim, hidden_dim), BatchNorm1d(hidden_dim), ReLU(),
            Linear(hidden_dim, hidden_dim), ReLU()
        ))
        self.conv3 = GINConv(Sequential(
            Linear(hidden_dim, hidden_dim), BatchNorm1d(hidden_dim), ReLU(),
            Linear(hidden_dim, hidden_dim), ReLU()
        ))
        self.classifier = Linear(hidden_dim, num_classes)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index)
        x = self.conv2(x, edge_index)
        x = self.conv3(x, edge_index)
        x = global_add_pool(x, batch)  # graph-level readout
        return self.classifier(x)

GIN (Graph Isomorphism Network) is a popular choice for molecular property prediction because its sum aggregation is maximally expressive for distinguishing molecular structures.
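Why sum aggregation matters can be seen with two neighborhoods that mean aggregation cannot tell apart, a minimal plain-Python illustration:

```python
# Two atoms whose neighbor features have the same average but different counts:
# atom A has two neighbors with feature 2.0; atom B has four.
neighbors_a = [2.0, 2.0]
neighbors_b = [2.0, 2.0, 2.0, 2.0]

mean_a = sum(neighbors_a) / len(neighbors_a)  # mean collapses the two cases
mean_b = sum(neighbors_b) / len(neighbors_b)

sum_a = sum(neighbors_a)  # sum preserves the size of the neighbor multiset
sum_b = sum(neighbors_b)

print(mean_a == mean_b, sum_a != sum_b)
```

Mean (and max) aggregation loses multiset counts; sum keeps them, which is the core of GIN's expressiveness argument.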

GNNs vs traditional cheminformatics

Before GNNs, molecular ML relied on molecular fingerprints: fixed-length bit vectors encoding the presence or absence of predefined substructures. The most common is the Morgan/ECFP fingerprint, which hashes circular neighborhoods of each atom.

GNNs have three advantages over fingerprints:

  • Learned substructures: fingerprints use fixed hash functions; GNNs learn which substructural patterns matter for each specific task
  • Continuous representations: fingerprints are binary (present/absent); GNN embeddings are continuous vectors that capture degree of structural similarity
  • Global context: fingerprints encode local neighborhoods independently; GNN message passing propagates information across the entire molecule, capturing long-range interactions
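For contrast with learned representations, a toy Morgan-style fingerprint can be sketched in plain Python: hash each atom's circular neighborhood into a fixed-length bit vector. Real ECFP uses richer atom invariants and iteratively refined identifiers; this is a deliberately simplified illustration.

```python
# Toy Morgan/ECFP-style fingerprint: hash the radius-0 and radius-1
# neighborhood of each atom into a fixed-length bit vector.
# Simplified illustration only; real ECFP uses richer invariants.

def toy_fingerprint(atoms, neighbors, n_bits=64):
    bits = [0] * n_bits
    for i, atom in enumerate(atoms):
        # radius 0: the atom itself
        bits[hash(("r0", atom)) % n_bits] = 1
        # radius 1: the atom plus the sorted multiset of its neighbors
        env = (atom, tuple(sorted(atoms[j] for j in neighbors[i])))
        bits[hash(("r1", env)) % n_bits] = 1
    return bits

# Ethanol heavy atoms: C-C-O
atoms = ["C", "C", "O"]
neighbors = {0: [1], 1: [0, 2], 2: [1]}
fp = toy_fingerprint(atoms, neighbors)
print(sum(fp), "bits set out of", len(fp))
```

The hash function is fixed and task-agnostic, which is exactly the limitation the bullets above describe: a GNN instead learns which neighborhoods to encode, and how strongly, for each target property.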

Real-world impact

Graph-based molecular ML has produced tangible results in drug discovery:

  • Virtual screening: screen millions of compounds in hours instead of months. GNNs predict binding affinity to target proteins, ranking candidates for synthesis.
  • Toxicity prediction: identify toxic compounds before synthesis, reducing animal testing and failed clinical trials.
  • Molecular generation: graph-based generative models propose novel molecules with desired properties, expanding the search beyond known chemical space.
  • Retrosynthesis: predict synthesis routes for target molecules, planning how to actually manufacture a drug candidate.

Batching molecular graphs in PyG

Unlike images, which can all be resized to a fixed shape (e.g., 224x224), molecular graphs vary in size (10 to 100+ atoms). PyG handles this through graph mini-batching: multiple molecules are combined into a single disconnected graph. The batch vector tracks which atoms belong to which molecule, enabling global_add_pool to produce per-molecule embeddings efficiently.
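The batch-vector mechanics can be sketched without PyG: concatenate the atom embeddings of several molecules, keep a parallel batch vector, and scatter-add into per-molecule slots. This is a plain-Python stand-in for what global_add_pool does:

```python
# Two molecules batched as one disconnected graph:
# molecule 0 has 3 atoms, molecule 1 has 2 (scalar embeddings for brevity).
atom_embeddings = [1.0, 2.0, 3.0,   # molecule 0
                   10.0, 20.0]      # molecule 1
batch = [0, 0, 0, 1, 1]            # which molecule each atom belongs to

def global_add_pool(embeddings, batch):
    """Scatter-add each atom embedding into its molecule's slot."""
    num_graphs = max(batch) + 1
    pooled = [0.0] * num_graphs
    for emb, graph_id in zip(embeddings, batch):
        pooled[graph_id] += emb
    return pooled

print(global_add_pool(atom_embeddings, batch))  # [6.0, 30.0]
```

Because no edges cross molecule boundaries, message passing on the big disconnected graph is equivalent to processing each molecule separately, just far more efficient on a GPU.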

Frequently asked questions

Why are molecules naturally represented as graphs?

A molecule is literally a graph: atoms are nodes and chemical bonds are edges. Atom features include element type, charge, and hybridization. Bond features include bond type (single, double, triple, aromatic) and stereochemistry. This representation preserves the full 2D topology of the molecule, which determines its chemical properties.

What molecular properties can GNNs predict?

GNNs predict a wide range of molecular properties: toxicity (is this compound toxic?), solubility (will it dissolve in water?), binding affinity (how strongly does it bind to a target protein?), ADMET properties (absorption, distribution, metabolism, excretion, toxicity), and synthesizability (how easy is it to manufacture?). These are classification or regression tasks on graph-level representations.

How does a GNN process a molecular graph?

Each atom starts with features encoding its element type, charge, and hybridization. Message passing layers let atoms exchange information with bonded neighbors. After 3-5 layers, each atom's embedding encodes its chemical environment (functional groups, ring membership, nearby electronegative atoms). A readout function (sum or mean pooling over all atoms) produces a single vector for the entire molecule, which feeds into a classifier or regressor.

What are the standard molecular datasets in PyG?

PyG provides several molecular benchmark datasets: MoleculeNet (a collection of 17 datasets covering various properties), ZINC (250K drug-like molecules for graph regression), QM9 (134K molecules with quantum mechanical properties), and OGB-MolPCBA/MolHIV (large-scale molecular classification from Open Graph Benchmark). These datasets come pre-processed with atom and bond features.

How do GNNs compare to traditional cheminformatics approaches?

Traditional approaches like Morgan fingerprints convert molecules to fixed-length bit vectors based on hand-designed substructure rules. GNNs learn the relevant substructures automatically from data. On MoleculeNet benchmarks, GNN-based models (particularly those with attention mechanisms and edge features) outperform fingerprint-based methods on most tasks, especially for complex properties that depend on global molecular structure.

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.