Hierarchical Pooling: Progressively Coarsening Graphs

Hierarchical pooling reduces graph size layer by layer, building a multi-resolution representation. Just as pooling in image models shrinks spatial resolution, graph pooling shrinks the number of nodes, capturing structure at multiple scales.

TL;DR

  • Hierarchical pooling progressively reduces graph size: 100 nodes to 50 to 25 to 10, then global readout. Each level captures structure at a different scale.
  • Three strategies: node selection (keep top-k by score), node clustering (merge similar nodes into super-nodes), and edge contraction (collapse edges to merge endpoints).
  • TopKPooling is simplest: learn a score per node, keep the top k%. SAGPool adds self-attention. DiffPool uses a GNN to learn soft cluster assignments end-to-end.
  • Use hierarchical pooling when graphs have meaningful hierarchy: atoms -> functional groups -> molecule, individuals -> communities -> network.
  • For small or flat graphs, global pooling is simpler and often sufficient. Hierarchical pooling adds value when multi-scale structure carries task-relevant signal.

Hierarchical pooling progressively coarsens graphs into smaller representations. Instead of jumping from all node embeddings to a single graph vector in one step, hierarchical pooling reduces the graph gradually: 100 nodes become 50, then 25, then 10, before a final global readout. Each level captures structure at a different scale, similar to how convolutional neural networks build a spatial hierarchy in images.

This is particularly valuable for graphs with natural hierarchy: molecules have atoms that form functional groups that form the whole molecule. Social networks have individuals in communities in larger communities. Hierarchical pooling captures these multi-scale patterns.

Three pooling strategies

Node selection (TopKPooling, SAGPool)

Learn a score for each node and keep the top k%. Dropped nodes are removed from the graph along with their incident edges; edges between the remaining nodes are preserved, with node indices relabeled.

topk_pooling.py
import torch
from torch_geometric.nn import GCNConv, TopKPooling, global_mean_pool

class HierarchicalGNN(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.pool1 = TopKPooling(hidden_dim, ratio=0.5)  # keep 50%
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.pool2 = TopKPooling(hidden_dim, ratio=0.5)  # keep 50% of remaining
        self.classifier = torch.nn.Linear(hidden_dim, out_dim)

    def forward(self, x, edge_index, batch):
        # Layer 1: message passing + pool
        x = self.conv1(x, edge_index).relu()
        x, edge_index, _, batch, _, _ = self.pool1(x, edge_index, batch=batch)
        # Graph now has 50% of original nodes

        # Layer 2: message passing + pool
        x = self.conv2(x, edge_index).relu()
        x, edge_index, _, batch, _, _ = self.pool2(x, edge_index, batch=batch)
        # Graph now has 25% of original nodes

        # Final global readout
        graph_emb = global_mean_pool(x, batch)
        return self.classifier(graph_emb)

Alternating GNN layers with TopKPooling. Each pool step halves the graph. Two levels of hierarchy before global readout.

Node clustering (DiffPool)

A separate GNN learns to assign nodes to clusters. Nodes in the same cluster are merged into a super-node. The coarsened graph has one node per cluster:

diffpool_concept.py
# DiffPool: learn cluster assignments with a GNN
# For each pooling level:

# 1. GNN produces node embeddings
Z = gnn_embed(x, edge_index)  # [num_nodes, hidden_dim]

# 2. Another GNN produces soft cluster assignments
S = gnn_pool(x, edge_index)   # [num_nodes, num_clusters]
S = softmax(S, dim=-1)        # soft assignment matrix

# 3. Coarsen: aggregate nodes into clusters
X_coarse = S.T @ Z            # [num_clusters, hidden_dim]
A_coarse = S.T @ A @ S        # [num_clusters, num_clusters]

# The coarsened graph has num_clusters nodes
# Repeat at next level with fewer clusters

DiffPool learns which nodes to group. The assignment GNN is trained end-to-end with the task loss.

Enterprise example: circuit classification

An electronics manufacturer classifies circuit designs as passing or failing quality checks. Circuit graphs have natural hierarchy:

  • Level 0: individual components (resistors, capacitors, transistors)
  • Level 1: functional blocks (amplifiers, filters, regulators)
  • Level 2: subsystems (power supply, signal processing, control logic)
  • Level 3: complete circuit

Hierarchical pooling with 3 coarsening levels mirrors this structure. The model first learns component-level interactions, then block-level interactions, then subsystem-level interactions, before classifying the entire circuit. Defects at any level (a bad component, a poorly designed block, a subsystem integration issue) are captured at the appropriate hierarchy level.

When to use hierarchical vs global

  • Hierarchical: large graphs (100+ nodes), natural hierarchy, multi-scale patterns matter
  • Global: small graphs (<50 nodes), no clear hierarchy, simpler to implement and tune

Frequently asked questions

What is hierarchical pooling?

Hierarchical pooling progressively reduces the size of a graph by merging or selecting nodes at each layer. A graph with 100 nodes might be coarsened to 50, then 25, then 10, before a final global pooling step. This creates a multi-resolution view of the graph, similar to how CNNs progressively reduce spatial resolution in images.

What is the difference between hierarchical and global pooling?

Global pooling collapses the entire graph into one vector in a single step. Hierarchical pooling gradually reduces graph size across multiple steps, building a hierarchy of coarsened graphs. Hierarchical pooling preserves intermediate structural information that a single global aggregation would lose.

What hierarchical pooling methods does PyG offer?

PyG provides TopKPooling (select top-k nodes by learned score), SAGPooling (self-attention graph pooling), ASAPooling (adaptive structure-aware pooling), and support for DiffPool (differentiable soft clustering). TopKPooling is simplest; DiffPool is most expressive but requires knowing max graph size.

When should I use hierarchical pooling?

Use hierarchical pooling when the graph has meaningful hierarchical structure: molecules (atoms -> functional groups -> molecule), social networks (individuals -> communities -> network), or documents (words -> sentences -> document). If the graph is small or has no hierarchy, global pooling is simpler and often sufficient.

What is DiffPool?

DiffPool (Differentiable Pooling) uses a GNN to learn a soft assignment matrix that clusters nodes into super-nodes. It is fully differentiable and end-to-end trainable. The GNN learns which nodes should be grouped together based on the downstream task. It is the most expressive but requires a fixed maximum graph size.

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.