Hierarchical pooling progressively coarsens graphs into smaller representations. Instead of jumping from all node embeddings to a single graph vector in one step, hierarchical pooling reduces the graph gradually: 100 nodes become 50, then 25, then 10, before a final global readout. Each level captures structure at a different scale, similar to how convolutional neural networks build a spatial hierarchy in images.
This is particularly valuable for graphs with natural hierarchy: in molecules, atoms form functional groups, which in turn form the whole molecule; in social networks, individuals belong to communities nested inside larger communities. Hierarchical pooling captures these multi-scale patterns.
Three pooling strategies
Node selection (TopKPooling, SAGPool)
Learn a score for each node and keep the top k%. Dropped nodes are removed from the graph. Edges between remaining nodes are preserved.
import torch
from torch_geometric.nn import GCNConv, TopKPooling, global_mean_pool

class HierarchicalGNN(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.pool1 = TopKPooling(hidden_dim, ratio=0.5)  # keep 50%
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.pool2 = TopKPooling(hidden_dim, ratio=0.5)  # keep 50% of remaining
        self.classifier = torch.nn.Linear(hidden_dim, out_dim)

    def forward(self, x, edge_index, batch):
        # Layer 1: message passing + pool
        x = self.conv1(x, edge_index).relu()
        x, edge_index, _, batch, _, _ = self.pool1(x, edge_index, batch=batch)
        # Graph now has 50% of the original nodes

        # Layer 2: message passing + pool
        x = self.conv2(x, edge_index).relu()
        x, edge_index, _, batch, _, _ = self.pool2(x, edge_index, batch=batch)
        # Graph now has 25% of the original nodes

        # Final global readout
        graph_emb = global_mean_pool(x, batch)
        return self.classifier(graph_emb)

Alternating GNN layers with TopKPooling: each pooling step halves the graph, giving two levels of hierarchy before the global readout.
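To make the selection step concrete, here is a minimal pure-PyTorch sketch of what a TopK-style pool does internally: score nodes against a learnable vector, keep the top fraction, and drop any edge whose endpoint was removed. The `topk_pool` function and the toy graph are illustrative assumptions, not part of the PyG API.

```python
import torch

def topk_pool(x, edge_index, w, ratio=0.5):
    """Minimal node-selection pooling: keep the top `ratio` of nodes by score."""
    score = (x @ w) / w.norm()                       # one score per node
    k = max(1, int(ratio * x.size(0)))
    perm = score.topk(k).indices                     # indices of kept nodes
    # Gate kept features by squashed scores so the scorer receives gradients
    x_new = x[perm] * torch.tanh(score[perm]).unsqueeze(-1)
    # Re-index edges: keep only edges whose both endpoints survive
    node_map = torch.full((x.size(0),), -1, dtype=torch.long)
    node_map[perm] = torch.arange(k)
    src, dst = node_map[edge_index[0]], node_map[edge_index[1]]
    edge_mask = (src >= 0) & (dst >= 0)
    edge_index_new = torch.stack([src[edge_mask], dst[edge_mask]])
    return x_new, edge_index_new

# Toy path graph 0-1-2-3 with 1-dim features
x = torch.tensor([[1.0], [4.0], [3.0], [2.0]])
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
w = torch.tensor([1.0])
x_p, ei_p = topk_pool(x, edge_index, w, ratio=0.5)
# Nodes 1 and 2 (highest scores) survive; only the edge between them remains
```

Note that if the kept nodes happen not to be adjacent, the pooled graph can become disconnected; this is a known trade-off of node selection compared with clustering approaches like DiffPool.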
Node clustering (DiffPool)
A separate GNN learns to assign nodes to clusters. Nodes in the same cluster are merged into a super-node. The coarsened graph has one node per cluster:
# DiffPool: learn cluster assignments with a GNN
# For each pooling level:
# 1. GNN produces node embeddings
Z = gnn_embed(x, edge_index) # [num_nodes, hidden_dim]
# 2. Another GNN produces soft cluster assignments
S = gnn_pool(x, edge_index) # [num_nodes, num_clusters]
S = softmax(S, dim=-1) # soft assignment matrix
# 3. Coarsen: aggregate nodes into clusters
X_coarse = S.T @ Z # [num_clusters, hidden_dim]
A_coarse = S.T @ A @ S # [num_clusters, num_clusters]; A is the dense adjacency matrix
# The coarsened graph has num_clusters nodes
# Repeat at next level with fewer clusters

DiffPool learns which nodes to group. The assignment GNN is trained end-to-end with the task loss.
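The coarsening step above can be run end-to-end with dense tensors. In the sketch below, plain `torch.nn.Linear` layers stand in for the two GNNs (a real DiffPool uses message-passing layers there), and the graph is random, so all names and sizes are assumptions for illustration.

```python
import torch

torch.manual_seed(0)
num_nodes, hidden, num_clusters = 6, 8, 2

# Linear layers stand in for gnn_embed / gnn_pool (DiffPool uses GNNs here)
gnn_embed = torch.nn.Linear(hidden, hidden)
gnn_pool = torch.nn.Linear(hidden, num_clusters)

x = torch.randn(num_nodes, hidden)
A = (torch.rand(num_nodes, num_nodes) > 0.5).float()
A = ((A + A.T) > 0).float()                  # symmetric dense adjacency

Z = gnn_embed(x)                             # node embeddings [num_nodes, hidden]
S = gnn_pool(x).softmax(dim=-1)              # soft assignments, rows sum to 1
X_coarse = S.T @ Z                           # [num_clusters, hidden]
A_coarse = S.T @ A @ S                       # [num_clusters, num_clusters]
```

Because S is soft and differentiable, gradients flow through the assignment, which is how the pooling GNN can be trained by the same task loss as the rest of the model.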
Enterprise example: circuit classification
An electronics manufacturer classifies circuit designs as passing or failing quality checks. Circuit graphs have natural hierarchy:
- Level 0: individual components (resistors, capacitors, transistors)
- Level 1: functional blocks (amplifiers, filters, regulators)
- Level 2: subsystems (power supply, signal processing, control logic)
- Level 3: complete circuit
Hierarchical pooling with 3 coarsening levels mirrors this structure. The model first learns component-level interactions, then block-level interactions, then subsystem-level interactions, before classifying the entire circuit. Defects at any level (a bad component, a poorly designed block, a subsystem integration issue) are captured at the appropriate hierarchy level.
When to use hierarchical vs global
- Hierarchical: large graphs (100+ nodes), natural hierarchy, multi-scale patterns matter
- Global: small graphs (<50 nodes), no clear hierarchy, simpler to implement and tune