Graph coarsening reduces a graph by merging groups of similar nodes into supernodes while preserving the graph's structural properties. A graph of 100,000 nodes can be coarsened to 10,000, then to 1,000, then to 100. At each level, the coarsened graph retains the community structure, connectivity patterns, and spectral properties of the original. This enables hierarchical GNN architectures that capture multi-scale structural patterns.
Why coarsen?
- Graph-level prediction: Classifying an entire molecule or graph requires compressing all node representations into one vector. Hierarchical coarsening does this more effectively than global mean/sum pooling.
- Multi-scale patterns: Local motifs (triangles, functional groups) and global patterns (communities, graph diameter) exist at different scales. Hierarchical processing captures both.
- Computational efficiency: Coarsening reduces the number of nodes and edges, making subsequent GNN layers faster while retaining important structural information.
Coarsening methods
Score-based pooling (TopKPool, SAGPool)
Learn a score for each node indicating its importance. Keep the top-scoring fraction of nodes, drop the rest, and rewire edges among the survivors.
```python
import torch
from torch_geometric.nn import TopKPooling, SAGEConv, global_mean_pool

class HierarchicalGNN(torch.nn.Module):
    def __init__(self, in_channels, hidden, out_channels):
        super().__init__()
        # Level 1: full resolution
        self.conv1 = SAGEConv(in_channels, hidden)
        self.pool1 = TopKPooling(hidden, ratio=0.5)  # keep 50% of nodes
        # Level 2: half resolution
        self.conv2 = SAGEConv(hidden, hidden)
        self.pool2 = TopKPooling(hidden, ratio=0.5)  # keep 50% of the remainder (25% of the original)
        # Level 3: quarter resolution
        self.conv3 = SAGEConv(hidden, hidden)
        self.classifier = torch.nn.Linear(hidden, out_channels)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index).relu()
        x, edge_index, _, batch, _, _ = self.pool1(x, edge_index, batch=batch)
        x = self.conv2(x, edge_index).relu()
        x, edge_index, _, batch, _, _ = self.pool2(x, edge_index, batch=batch)
        x = self.conv3(x, edge_index).relu()
        x = global_mean_pool(x, batch)  # final graph representation
        return self.classifier(x)
```

Three-level hierarchical GNN: each TopKPooling layer drops half of the remaining nodes, creating a pyramid of increasingly abstract representations.
Clustering-based pooling (DiffPool)
Learn a soft assignment matrix that maps each node to a cluster. Supernodes are weighted combinations of original nodes. Edges between supernodes are derived from inter-cluster edge density.
DiffPool is more expressive than score-based methods because it learns the grouping structure jointly with the GNN. The trade-off: it requires O(n^2) memory for the assignment matrix, limiting scalability to graphs with fewer than ~5,000 nodes.
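The coarsening step itself is just two matrix products. Below is a minimal numpy sketch of that step: the assignment logits are hand-set for illustration (in DiffPool they are produced by a learned GNN), and `diffpool_coarsen` is a hypothetical helper name, not a library API.

```python
import numpy as np

def softmax(z, axis=-1):
    e = np.exp(z - z.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def diffpool_coarsen(X, A, S_logits):
    """One DiffPool-style coarsening step.
    X: (n, d) node features; A: (n, n) adjacency;
    S_logits: (n, k) raw cluster scores (learned by a GNN in practice)."""
    S = softmax(S_logits, axis=1)  # soft assignment: each row sums to 1
    X_coarse = S.T @ X             # supernode features: weighted combinations of nodes
    A_coarse = S.T @ A @ S         # supernode edges: inter-cluster edge density
    return X_coarse, A_coarse, S

# Toy path graph 0-1-2-3 with two natural clusters {0, 1} and {2, 3}
A = np.array([[0, 1, 0, 0],
              [1, 0, 1, 0],
              [0, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)
X = np.eye(4)
S_logits = np.array([[5, 0], [5, 0], [0, 5], [0, 5]], dtype=float)

Xc, Ac, S = diffpool_coarsen(X, A, S_logits)
print(Ac.shape)  # (2, 2): 4 nodes coarsened to 2 supernodes
```

Note that `S` is a dense n-by-k matrix: this is exactly the quadratic memory cost mentioned above when the number of clusters scales with n.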
Edge contraction
Select edges to contract (merge their endpoint nodes). The merged supernode inherits the union of both nodes' edges. Selection criteria: merge nodes with similar representations, merge along high-weight edges, or choose pairs via graph matching algorithms.
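A single contraction can be sketched in a few lines of plain Python. This is an illustrative sketch, not a library routine: `contract_edge` is a hypothetical helper, and edges are stored as sorted tuples in a set.

```python
def contract_edge(edges, u, v):
    """Contract edge (u, v): merge v into u. The supernode u inherits
    the union of both endpoints' edges; the contracted self-loop is dropped."""
    merged = set()
    for a, b in edges:
        a = u if a == v else a  # redirect endpoints of v to u
        b = u if b == v else b
        if a != b:              # drop the self-loop created by the contraction
            merged.add((min(a, b), max(a, b)))
    return merged

# Path graph 0-1-2-3; contract edge (1, 2)
edges = {(0, 1), (1, 2), (2, 3)}
print(sorted(contract_edge(edges, 1, 2)))  # [(0, 1), (1, 3)]
```

After the contraction, the former neighbors of both endpoints attach to the single supernode, so the path 0-1-2-3 coarsens to the path 0-1-3.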
Preserving structure during coarsening
Good coarsening preserves:
- Community structure: Merges should group nodes within the same community, not across community boundaries.
- Spectral properties: The eigenvalues of the coarsened graph Laplacian should approximate the original's.
- Important motifs: Triangles, cliques, and cycles in the original should have counterparts in the coarsened graph.
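The spectral criterion above can be checked directly. The numpy sketch below coarsens a 6-cycle with a hard partition matrix P (so A_coarse = P^T A P) and verifies that both Laplacians have a zero eigenvalue and a positive algebraic connectivity, i.e. connectivity survives the coarsening. The pairing of nodes is hand-chosen for illustration.

```python
import numpy as np

def laplacian(A):
    A = A - np.diag(np.diag(A))        # drop self-loops created by coarsening
    return np.diag(A.sum(axis=1)) - A

# 6-cycle adjacency matrix
n = 6
A = np.zeros((n, n))
for i in range(n):
    A[i, (i + 1) % n] = A[(i + 1) % n, i] = 1.0

# Hard partition merging adjacent pairs {0,1}, {2,3}, {4,5}
P = np.zeros((6, 3))
for node, cluster in enumerate([0, 0, 1, 1, 2, 2]):
    P[node, cluster] = 1.0
A_coarse = P.T @ A @ P                 # coarsened (weighted) adjacency

ev_full = np.sort(np.linalg.eigvalsh(laplacian(A)))
ev_coarse = np.sort(np.linalg.eigvalsh(laplacian(A_coarse)))
# Both spectra start at 0; a positive second eigenvalue (algebraic
# connectivity) in both confirms the coarsened graph stays connected.
print(ev_full[1] > 1e-9, ev_coarse[1] > 1e-9)  # True True
```

More careful spectral coarsening methods choose the partition to make the leading coarse eigenvalues track the original's; this sketch only demonstrates the check itself.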