Original Paper
Neural Message Passing for Quantum Chemistry
Gilmer et al., ICML 2017
What NNConv does
NNConv uses a neural network to generate weight matrices from edge features:
- For each edge (i, j), pass edge features through an MLP to produce a weight matrix W_ij
- Multiply the source node's features by this edge-specific weight matrix: m_ij = W_ij * h_j
- Aggregate messages across all neighbors: h_i' = sum(m_ij)
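The three steps above can be sketched in plain PyTorch without PyG. This is an illustrative sketch, not PyG's internal implementation; the function name and tensor layout are assumptions:

```python
import torch

def nnconv_message_pass(h, edge_index, edge_attr, edge_mlp, out_dim):
    # h: [N, in_dim] node features; edge_index: [2, E]; edge_attr: [E, edge_dim]
    src, dst = edge_index
    E, in_dim = edge_attr.size(0), h.size(1)
    # 1) The edge MLP emits one flattened weight matrix per edge
    W = edge_mlp(edge_attr).view(E, out_dim, in_dim)    # [E, out_dim, in_dim]
    # 2) Edge-specific message: m_ij = W_ij @ h_j (batched matmul over edges)
    m = torch.bmm(W, h[src].unsqueeze(-1)).squeeze(-1)  # [E, out_dim]
    # 3) Sum-aggregate messages at each destination node
    out = torch.zeros(h.size(0), out_dim).index_add_(0, dst, m)
    return out
```

Nodes with no incoming edges simply keep a zero aggregate; PyG's NNConv additionally adds a learned root-weight term for the node's own features.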
The math (simplified)
# Edge-conditioned weight matrix
W_ij = NN(e_ij)        # NN maps edge features to a weight matrix
                       # W_ij shape: [out_channels, in_channels]

# Edge-specific message
m_ij = W_ij · h_j

# Aggregation (sum over neighbors, plus a self-loop term)
h_i' = W_self · h_i + Σ_j m_ij
Where:
NN = neural network (MLP) mapping edge features to weight matrices
e_ij = edge feature vector (bond type, distance, etc.)
W_self = self-loop (root) weight matrix

The key difference from GINEConv: edge features generate the entire weight matrix, not just an additive modifier. This is maximally expressive for edge conditioning.
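To make the additive-vs-multiplicative distinction concrete, here is a small sketch (the helper names are made up for illustration): a GINE-style message shifts h_j by an edge-dependent offset, while an NNConv-style message applies an edge-dependent linear map to h_j.

```python
import torch

torch.manual_seed(0)
in_dim = out_dim = 4
edge_dim = 3
e_single = torch.randn(edge_dim)  # stand-in for a "single bond" feature
e_double = torch.randn(edge_dim)  # stand-in for a "double bond" feature

# GINE-style: the edge feature adds a fixed offset to the message --
# the same offset regardless of the neighbor's features h_j
lin_e = torch.nn.Linear(edge_dim, in_dim)
def gine_msg(h_j, e):             # message before the nonlinearity
    return h_j + lin_e(e)

# NNConv-style: the edge feature selects an entire linear map,
# so how h_j is transformed changes with the bond type
nn_w = torch.nn.Linear(edge_dim, out_dim * in_dim)
def nnconv_msg(h_j, e):
    W = nn_w(e).view(out_dim, in_dim)
    return W @ h_j
```

Switching bond type in the GINE-style message moves every neighbor's message by the same vector; in the NNConv-style message, the change depends on h_j itself.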
PyG implementation
import torch
import torch.nn.functional as F
from torch_geometric.nn import NNConv, global_add_pool

class MPNN(torch.nn.Module):
    def __init__(self, node_dim, edge_dim, hidden, out_channels):
        super().__init__()
        # NN that generates weight matrices from edge features
        nn1 = torch.nn.Sequential(
            torch.nn.Linear(edge_dim, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, node_dim * hidden),  # output = flattened weight matrix
        )
        nn2 = torch.nn.Sequential(
            torch.nn.Linear(edge_dim, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, hidden * hidden),
        )
        self.conv1 = NNConv(node_dim, hidden, nn1, aggr='add')
        self.conv2 = NNConv(hidden, hidden, nn2, aggr='add')
        self.classifier = torch.nn.Linear(hidden, out_channels)

    def forward(self, x, edge_index, edge_attr, batch):
        x = F.relu(self.conv1(x, edge_index, edge_attr))
        x = F.relu(self.conv2(x, edge_index, edge_attr))
        x = global_add_pool(x, batch)
        return self.classifier(x)

# QM9: predict molecular properties
model = MPNN(node_dim=11, edge_dim=4, hidden=64, out_channels=1)

The NN output size must be in_channels * out_channels (a flattened weight matrix). For a 64→64 layer with 4-dim edge features, the NN outputs 64 * 64 = 4096 values per edge.
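A quick sanity check of that shape requirement, using a standalone copy of the second-layer edge network (mirroring nn2 above, assuming hidden=64 and 4-dim edge features):

```python
import torch

edge_dim, hidden = 4, 64
edge_nn = torch.nn.Sequential(
    torch.nn.Linear(edge_dim, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, hidden * hidden),  # 4096 values per edge
)
e = torch.randn(10, edge_dim)              # a batch of 10 edges
W = edge_nn(e).view(10, hidden, hidden)    # one 64x64 weight matrix per edge
```

Note that the final linear layer alone holds 128 * 4096 weights plus 4096 biases, so the edge network dominates the layer's parameter count.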
When to use NNConv
- Quantum chemistry (QM9, QM7). Molecular properties depend critically on bond types and distances. NNConv's edge-conditioned transformation captures this.
- Drug discovery. Molecular binding affinity, toxicity, and ADMET properties depend on specific bond configurations that NNConv models naturally.
- When edge features should fully condition the transformation. If a single bond and a double bond should produce fundamentally different transformations (not just additive modifiers), NNConv is the natural fit.
When not to use NNConv
- Large graphs. Generating a weight matrix per edge is expensive. On graphs with millions of edges, the memory cost is prohibitive. Use GINEConv.
- When edge features are simple modifiers. If edge features just add context (timestamps, weights) rather than fundamentally changing the transformation, GINEConv or TransformerConv is more efficient.
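A rough back-of-envelope for that memory cost (a sketch; it counts only the materialized per-edge weight matrices and ignores autograd buffers and the edge-MLP activations):

```python
def nnconv_edge_memory_gb(num_edges, in_dim, out_dim, bytes_per_float=4):
    # Each edge materializes its own [out_dim, in_dim] weight matrix
    return num_edges * in_dim * out_dim * bytes_per_float / 1e9

# 2M edges with a 64->64 layer: ~33 GB just for the per-edge matrices
print(nnconv_edge_memory_gb(2_000_000, 64, 64))
```

Compare GINEConv, whose per-edge cost is a single feature vector (64 values per edge here) rather than a 64x64 matrix.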