What GINConv does
GINConv answers a fundamental question: what is the most powerful aggregation a message-passing GNN can perform? The answer is sum aggregation followed by an MLP, which can learn the injective function needed for maximal expressiveness:
- Sum all neighbor feature vectors (preserving the multiset)
- Add the node's own features (scaled by a learnable parameter)
- Pass through a multi-layer perceptron (MLP)
This simple design achieves the theoretical maximum expressiveness for any message-passing GNN, matching the discriminative power of the 1-dimensional Weisfeiler-Leman (1-WL) graph isomorphism test.
The math (simplified)
h_i' = MLP( (1 + epsilon) · h_i + Σ_{j ∈ N(i)} h_j )
Where:
- Σ_{j ∈ N(i)} h_j = sum of all neighbor features (not mean, not max)
- epsilon = learnable scalar (or fixed at 0)
- MLP = multi-layer perceptron (at least 2 layers)
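The update can be sketched in a few lines of plain PyTorch for a single node (a minimal illustration only, not the PyG implementation; the tiny graph and the MLP weights here are made up):

```python
import torch

# Toy graph: node 0 has neighbors 1 and 2, each node has a 2-dim feature
h = torch.tensor([[1.0, 0.0],
                  [0.0, 1.0],
                  [2.0, 1.0]])
eps = 0.0  # learnable in GIN; fixed here for illustration

# Stand-in 2-layer MLP (random weights, purely for the sketch)
mlp = torch.nn.Sequential(
    torch.nn.Linear(2, 4),
    torch.nn.ReLU(),
    torch.nn.Linear(4, 4),
)

neighbor_sum = h[1] + h[2]                      # Σ_{j ∈ N(0)} h_j
h0_new = mlp((1 + eps) * h[0] + neighbor_sum)   # updated feature for node 0
print(h0_new.shape)  # torch.Size([4])
```

Note that the sum happens before the MLP: the MLP transforms the aggregated multiset, it does not transform each neighbor individually.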
Why sum? Consider three multisets:

| Multiset | mean | max | sum |
|---|---|---|---|
| {1, 1, 1} | 1 | 1 | 3 |
| {1, 1} | 1 | 1 | 2 |
| {1} | 1 | 1 | 1 |

Mean and max cannot distinguish these; sum can. Sum aggregation preserves multiset cardinality and composition, and the MLP provides the injective mapping needed for WL-equivalent expressiveness.
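The collision is easy to verify directly (a quick sanity check in plain Python):

```python
multisets = [[1, 1, 1], [1, 1], [1]]

# mean and max collapse all three multisets to the same value; sum does not
means = {sum(m) / len(m) for m in multisets}  # {1.0}
maxes = {max(m) for m in multisets}           # {1}
sums = {sum(m) for m in multisets}            # {1, 2, 3}

print(means, maxes, sums)
```

Only the sum-based readout keeps three distinct values, so only it can tell the three neighborhoods apart.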
PyG implementation
```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GINConv, global_add_pool

class GIN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=5):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers):
            in_dim = in_channels if i == 0 else hidden_channels
            # MLP with 2 layers (critical for expressiveness)
            mlp = torch.nn.Sequential(
                torch.nn.Linear(in_dim, hidden_channels),
                torch.nn.ReLU(),
                torch.nn.Linear(hidden_channels, hidden_channels),
            )
            self.convs.append(GINConv(mlp))
        self.classifier = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch):
        for conv in self.convs:
            x = conv(x, edge_index)
            x = F.relu(x)
        # Graph-level readout
        x = global_add_pool(x, batch)  # sum pooling (not mean!)
        return self.classifier(x)
```
```python
# Usage on a molecular dataset
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader

dataset = TUDataset(root='/tmp/MUTAG', name='MUTAG')
loader = DataLoader(dataset, batch_size=32, shuffle=True)
model = GIN(dataset.num_features, 64, dataset.num_classes)

for batch in loader:
    out = model(batch.x, batch.edge_index, batch.batch)
```

Note: use global_add_pool (sum) for graph-level readout, not global_mean_pool. Sum readout preserves the expressiveness that GIN provides.
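The readout distinction can be illustrated without PyG: global_add_pool is essentially a scatter-sum over the batch-assignment vector (a minimal re-implementation for illustration; use PyG's version in practice):

```python
import torch

# Two graphs in one batch: graph 0 has 3 nodes, graph 1 has 1 node,
# and every node carries the identical feature [1.0]
x = torch.ones(4, 1)
batch = torch.tensor([0, 0, 0, 1])  # node-to-graph assignment

num_graphs = int(batch.max()) + 1
sum_pool = torch.zeros(num_graphs, 1).index_add_(0, batch, x)
counts = torch.zeros(num_graphs, 1).index_add_(0, batch, torch.ones_like(x))
mean_pool = sum_pool / counts

print(sum_pool.squeeze())   # tensor([3., 1.])  -> graphs distinguished
print(mean_pool.squeeze())  # tensor([1., 1.])  -> graphs collide
```

Mean readout erases the size difference between the two graphs; sum readout keeps it, mirroring the multiset argument above.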
When to use GINConv
- Graph classification. Molecules, proteins, chemical compounds. Tasks where you need to distinguish structurally different graphs. GIN is the standard baseline for graph classification benchmarks.
- When structural discrimination matters. If your task requires telling apart graphs with similar summary statistics but different structures (e.g., a ring vs a chain with the same number of nodes), GIN is the right choice.
- Graph-level pre-training. GINEConv (GIN with edge features) is the standard backbone for graph self-supervised learning strategies like those in Hu et al. (2019).
- Theoretical research. When you need a provably maximally expressive baseline to compare against new architectures.
When not to use GINConv
1. Node classification on large graphs
Sum aggregation amplifies high-degree node features, which can cause numerical instability on graphs with power-law degree distributions. GCNConv's degree normalization or GATConv's softmax normalization handle this more gracefully.
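The degree effect is easy to see numerically (a toy illustration; the unit feature value is arbitrary):

```python
# Every neighbor contributes the identical feature 1.0: sum aggregation
# grows linearly with degree, while mean aggregation stays bounded
results = {}
for degree in [2, 100, 10_000]:
    neighbors = [1.0] * degree
    results[degree] = (sum(neighbors), sum(neighbors) / degree)

print(results)
# sums: 2.0, 100.0, 10000.0 -> hub nodes produce very large activations
# means: all 1.0 -> bounded regardless of degree
```

On a power-law graph a few hubs can dominate the activation scale, which is why normalized aggregators tend to train more stably there.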
2. When edge features matter
GINConv ignores edge features entirely. For molecular graphs with bond types, knowledge graphs with relation types, or transaction networks with amounts, use GINEConv or NNConv instead.
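For reference, GINE modifies the message to inject edge features before the sum, roughly ReLU(h_j + e_ij) in place of the bare h_j (a single-node sketch of the update as implemented in PyG's GINEConv; shapes and weights here are illustrative):

```python
import torch

h = torch.randn(3, 4)   # node features; node 0 has neighbors 1 and 2
e = torch.randn(2, 4)   # edge features for edges (1->0) and (2->0)
eps = 0.0

mlp = torch.nn.Sequential(
    torch.nn.Linear(4, 4),
    torch.nn.ReLU(),
    torch.nn.Linear(4, 4),
)

# GINE message: ReLU(h_j + e_ij) instead of the bare h_j used by GIN
messages = torch.relu(h[1] + e[0]) + torch.relu(h[2] + e[1])
h0_gine = mlp((1 + eps) * h[0] + messages)
```

This requires edge features to live in (or be projected to) the same dimension as node features, which PyG handles via an optional edge-feature projection.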
3. Heterogeneous graphs
GINConv applies the same MLP to all node types. Enterprise relational databases with multiple table types need type-specific transformations. Use HGTConv or HeteroConv.