Molecules are graphs, and graph neural networks are the natural architecture for predicting their properties. Every molecule has atoms (nodes) connected by chemical bonds (edges). Atom features encode element type, formal charge, and hybridization state. Bond features encode bond order (single, double, triple, aromatic) and geometry. GNNs process this graph directly, learning which structural patterns correlate with properties like toxicity, solubility, and binding affinity.
This has transformed drug discovery. Instead of synthesizing and testing thousands of candidate compounds in wet labs (months, millions of dollars), a trained GNN screens millions of virtual compounds in hours, prioritizing the most promising candidates for physical testing.
From SMILES strings to molecular graphs
Molecules are typically stored as SMILES strings (e.g., CC(=O)Oc1ccccc1C(=O)O for aspirin). To build a graph, parse the SMILES into atoms and bonds:
- Node features: atomic number, degree, formal charge, number of hydrogens, aromaticity, hybridization (sp, sp2, sp3), and whether the atom is in a ring
- Edge features: bond type (single, double, triple, aromatic), conjugation, ring membership, stereochemistry (E/Z, R/S)
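To make the featurization concrete, here is a hand-built graph for ethanol. This is an illustrative sketch only: a real pipeline would parse the SMILES with a cheminformatics toolkit such as RDKit, and the feature vectors here are deliberately reduced to a few of the fields listed above.

```python
# Hand-built molecular graph for ethanol ("CCO"), heavy atoms only.
# Node features per atom (simplified): [atomic_number, degree, num_hydrogens, in_ring]
atoms = [
    [6, 1, 3, 0],  # C of the CH3 group
    [6, 2, 2, 0],  # C of the CH2 group
    [8, 1, 1, 0],  # O of the OH group
]

# Bonds as directed pairs (PyG stores each bond in both directions)
edge_index = [(0, 1), (1, 0), (1, 2), (2, 1)]

# Edge features per directed edge: one-hot over [single, double, triple, aromatic]
edge_attr = [[1, 0, 0, 0] for _ in edge_index]

print(len(atoms), len(edge_index))  # 3 atoms, 4 directed edges
```

Storing both directions of every bond is what makes `edge_index` for a 14-bond molecule contain 28 columns, as in the ESOL example below.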
from torch_geometric.datasets import MoleculeNet
# Load a standard molecular benchmark
dataset = MoleculeNet(root='data/', name='ESOL')
# ESOL: 1,128 molecules, predict aqueous solubility
mol = dataset[0]
print(mol)
# Data(x=[13, 9], edge_index=[2, 28], edge_attr=[28, 3], y=[1, 1])
# 13 atoms, 28 bonds (directed), 9 atom features, 3 bond features
# y = solubility value (regression target)

Each molecule becomes a PyG Data object: x holds atom features, edge_index holds bond connectivity, edge_attr holds bond features, and y holds the target property.
How GNNs learn molecular properties
The architecture has three stages:
Stage 1: atom-level message passing
Each atom exchanges information with its bonded neighbors through message passing. After layer 1, each atom knows its immediate chemical environment. After layer 3, each atom encodes information from atoms up to 3 bonds away, which captures functional group context.
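The growth of this receptive field can be watched directly in a toy sketch. The following uses plain sum aggregation with no learned weights (a real GNN layer would transform the messages), on the ethanol chain C-C-O, with one-hot starting features over {C, O}:

```python
# Minimal message-passing sketch: new state = own state + sum of neighbor states.
neighbors = {0: [1], 1: [0, 2], 2: [1]}  # ethanol heavy-atom graph: C-C-O
h = {0: [1, 0], 1: [1, 0], 2: [0, 1]}    # one-hot over {C, O}; atom 2 is the oxygen

def message_passing_layer(h, neighbors):
    """One round of sum-aggregation message passing (no learned weights)."""
    new_h = {}
    for v, nbrs in neighbors.items():
        agg = [sum(vals) for vals in zip(*(h[u] for u in nbrs))]
        new_h[v] = [a + b for a, b in zip(h[v], agg)]
    return new_h

h1 = message_passing_layer(h, neighbors)
print(h1[0])  # [2, 0] -- after 1 layer, atom 0 knows nothing about the oxygen
h2 = message_passing_layer(h1, neighbors)
print(h2[0])  # [4, 1] -- after 2 layers, the oxygen 2 bonds away has reached atom 0
```

The oxygen component of atom 0's state stays zero until the number of layers reaches its bond distance to the oxygen, which is exactly the receptive-field argument above.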
Stage 2: graph-level readout
After message passing, we need a single vector representing the entire molecule. A readout function pools over all atom embeddings:
- Sum pooling: preserves molecular size information (larger molecules have larger sums)
- Mean pooling: normalizes by size, better for size-invariant properties
- Attention pooling: learns which atoms matter most for the target property
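The three readout styles can be sketched over a toy set of atom embeddings. Note the attention weights below are made up for illustration; in a real model they come from a learned scoring network:

```python
# Three readout styles over per-atom embeddings (3 atoms, dim 2).
atom_embeddings = [[1.0, 2.0], [3.0, 0.0], [2.0, 2.0]]

sum_pool = [sum(col) for col in zip(*atom_embeddings)]
mean_pool = [s / len(atom_embeddings) for s in sum_pool]

attn = [0.7, 0.1, 0.2]  # hypothetical learned weights, summing to 1
attn_pool = [sum(w * emb[i] for w, emb in zip(attn, atom_embeddings))
             for i in range(2)]

print(sum_pool)   # [6.0, 4.0] -- magnitude grows with molecule size
print(mean_pool)  # size-normalized version of the same vector
```

Adding a fourth atom would change sum_pool but could leave mean_pool nearly unchanged, which is why sum pooling suits size-dependent properties and mean pooling suits size-invariant ones.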
Stage 3: property prediction
The molecular embedding feeds into a classification head (is this molecule toxic: yes/no) or regression head (what is its solubility value). Multi-task models predict multiple properties simultaneously, sharing the molecular representation.
import torch
from torch_geometric.nn import GINConv, global_add_pool
from torch.nn import Sequential, Linear, ReLU, BatchNorm1d

class MolecularGNN(torch.nn.Module):
    def __init__(self, num_features, hidden_dim, num_classes):
        super().__init__()
        # GIN layers for maximum expressiveness
        self.conv1 = GINConv(Sequential(
            Linear(num_features, hidden_dim), BatchNorm1d(hidden_dim), ReLU(),
            Linear(hidden_dim, hidden_dim), ReLU()
        ))
        self.conv2 = GINConv(Sequential(
            Linear(hidden_dim, hidden_dim), BatchNorm1d(hidden_dim), ReLU(),
            Linear(hidden_dim, hidden_dim), ReLU()
        ))
        self.conv3 = GINConv(Sequential(
            Linear(hidden_dim, hidden_dim), BatchNorm1d(hidden_dim), ReLU(),
            Linear(hidden_dim, hidden_dim), ReLU()
        ))
        self.classifier = Linear(hidden_dim, num_classes)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index)
        x = self.conv2(x, edge_index)
        x = self.conv3(x, edge_index)
        x = global_add_pool(x, batch)  # graph-level readout
        return self.classifier(x)

GIN (Graph Isomorphism Network) is a popular choice for molecular property prediction because its sum aggregation is maximally expressive for distinguishing molecular structures.
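The expressiveness claim has a concrete, library-free illustration: sum aggregation distinguishes neighbor multisets that mean (or max) aggregation collapses, which is the core of the GIN argument.

```python
# Why GIN aggregates with sum: an atom with two identical neighbors
# and an atom with one such neighbor present different multisets,
# but mean aggregation maps them to the same value.
multiset_a = [1.0, 1.0]  # two neighbors, each with feature 1.0
multiset_b = [1.0]       # one neighbor with feature 1.0

mean_a = sum(multiset_a) / len(multiset_a)
mean_b = sum(multiset_b) / len(multiset_b)
sum_a, sum_b = sum(multiset_a), sum(multiset_b)

print(mean_a == mean_b)  # True  -- mean cannot tell the two apart
print(sum_a == sum_b)    # False -- sum can
```

In molecular terms: mean aggregation can confuse an atom bonded to two carbons with an atom bonded to one, while sum aggregation keeps the count, and counts of identical neighborhoods are exactly what distinguish many isomers.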
GNNs vs traditional cheminformatics
Before GNNs, molecular ML relied on molecular fingerprints: fixed-length bit vectors encoding the presence or absence of predefined substructures. The most common is the Morgan/ECFP fingerprint, which hashes circular neighborhoods of each atom.
GNNs have three advantages over fingerprints:
- Learned substructures: fingerprints use fixed hash functions; GNNs learn which substructural patterns matter for each specific task
- Continuous representations: fingerprints are binary (present/absent); GNN embeddings are continuous vectors that capture degree of structural similarity
- Global context: fingerprints encode local neighborhoods independently; GNN message passing propagates information across the entire molecule, capturing long-range interactions
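For contrast with the learned approach, here is a heavily simplified sketch of the circular-neighborhood hashing behind Morgan/ECFP fingerprints. Real ECFP uses canonical atom invariants and a production toolkit such as RDKit; only the core idea (hash growing neighborhoods into a fixed-length bit vector) is shown, and bit positions vary across runs because Python string hashing is randomized.

```python
# Simplified Morgan/ECFP-style fingerprint: hash each atom's circular
# neighborhood (radius 0, 1, 2, ...) into a fixed-length bit vector.
N_BITS = 64

def circular_fingerprint(atom_types, neighbors, radius=2):
    bits = [0] * N_BITS
    ids = list(atom_types)  # radius-0 identifiers: the atom types themselves
    for _ in range(radius + 1):
        for ident in ids:
            bits[hash(ident) % N_BITS] = 1  # set the bit for this substructure
        # grow each identifier by hashing in its (sorted) neighbor identifiers
        ids = [hash((ids[v], tuple(sorted(ids[u] for u in neighbors[v]))))
               for v in range(len(ids))]
    return bits

# Ethanol heavy atoms: C-C-O
fp = circular_fingerprint(["C", "C", "O"], {0: [1], 1: [0, 2], 2: [1]})
print(sum(fp))  # number of bits set
```

The fixed hash function is the key limitation the bullets above describe: every substructure lands on a predetermined bit regardless of the prediction task, whereas a GNN is free to reshape its representation around whatever the task rewards.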
Real-world impact
Graph-based molecular ML has produced tangible results in drug discovery:
- Virtual screening: screen millions of compounds in hours instead of months. GNNs predict binding affinity to target proteins, ranking candidates for synthesis.
- Toxicity prediction: identify toxic compounds before synthesis, reducing animal testing and failed clinical trials.
- Molecular generation: graph-based generative models propose novel molecules with desired properties, expanding the search beyond known chemical space.
- Retrosynthesis: predict synthesis routes for target molecules, planning how to actually manufacture a drug candidate.
Batching molecular graphs in PyG
Unlike images, which are typically resized to a fixed resolution (e.g., 224x224), molecular graphs vary in size (10 to 100+ atoms). PyG handles this through graph mini-batching: multiple molecules are combined into a single disconnected graph. The batch vector tracks which atoms belong to which molecule, enabling global_add_pool to produce per-molecule embeddings efficiently.
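The mechanics of this batching can be sketched without the library. The two ingredients are an index offset for each appended molecule's edges and a scatter-add over the batch vector (atom features are one-dimensional here for brevity):

```python
# PyG-style mini-batching by hand: molecule A (2 atoms) and molecule B
# (3 atoms) concatenated into one disconnected graph.
x = [1.0, 2.0, 3.0, 4.0, 5.0]
batch = [0, 0, 1, 1, 1]  # which molecule each atom belongs to

# B's edges must be shifted by A's atom count when concatenating
edges_a = [(0, 1)]
edges_b = [(0, 1), (1, 2)]
offset = 2  # number of atoms in A
edges = edges_a + [(u + offset, v + offset) for u, v in edges_b]

def global_add_pool(x, batch):
    """Scatter-add atom features into one slot per molecule."""
    out = [0.0] * (max(batch) + 1)
    for feat, b in zip(x, batch):
        out[b] += feat
    return out

print(edges)                      # [(0, 1), (2, 3), (3, 4)]
print(global_add_pool(x, batch))  # [3.0, 12.0] -- one embedding per molecule
```

Because the combined graph is disconnected, message passing never leaks information between molecules, and the batch vector is all the readout needs to separate them again.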