Graph pooling is the operation that compresses all node representations in a graph into a single fixed-size vector, enabling predictions about the graph as a whole rather than individual nodes. After GNN message passing layers compute per-node embeddings, graph pooling aggregates these embeddings into one vector that captures the global structure and features of the entire graph. This pooled representation is then fed to a classifier or regressor for graph-level tasks.
Why it matters for enterprise data
Many enterprise prediction tasks operate at an entity level that spans multiple rows across multiple tables. Predicting account-level fraud risk requires considering all of a customer's transactions, the merchants involved, and the payment methods used; together, these entities and their links form a subgraph. Graph pooling converts that entire subgraph into a single vector that feeds a fraud classifier.
Without graph pooling, you have individual node embeddings but no mechanism to produce a single prediction for the entity. Pooling bridges per-node representations and entity-level decisions.
Three approaches to graph pooling
Global pooling (flat readout)
The simplest approach: apply a permutation-invariant function across all node embeddings at once.
- Mean pooling: Average all node embeddings. Normalizes for graph size. Most common default.
- Sum pooling: Sum all node embeddings. Preserves information about graph size (a 100-node graph yields a larger-magnitude vector than a 10-node one, not a higher-dimensional one).
- Max pooling: Element-wise maximum across all node embeddings. Captures the most extreme signals.
Global pooling adds zero learnable parameters and works surprisingly well for most tasks.
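As a minimal sketch, assuming PyG's built-in readouts and a toy batch of two graphs, the three flat readouts differ only in the reduction applied within each graph:

import torch
from torch_geometric.nn import global_mean_pool, global_add_pool, global_max_pool

# Toy batch: 5 node embeddings; nodes 0-2 belong to graph 0, nodes 3-4 to graph 1
x = torch.randn(5, 64)
batch = torch.tensor([0, 0, 0, 1, 1])

mean_readout = global_mean_pool(x, batch)  # [2, 64], normalized for graph size
sum_readout = global_add_pool(x, batch)    # [2, 64], magnitude grows with node count
max_readout = global_max_pool(x, batch)    # [2, 64], element-wise extremes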
Hierarchical pooling
Progressively coarsens the graph over multiple pooling steps. At each step, groups of nodes are merged into super-nodes, reducing the graph size. Methods include:
- DiffPool: Learns a soft assignment matrix that clusters nodes into super-nodes at each level. Powerful but memory-intensive (O(n^2) per layer).
- TopKPooling: Scores each node with a learned projection and keeps only the highest-scoring fraction (a configurable ratio). Drops low-scoring nodes, reducing graph size.
- SAGPooling: Combines self-attention with top-k selection. A GNN layer computes the node scores, so the keep/drop decision reflects graph structure as well as node features.
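Below is a minimal sketch of a single hierarchical step using PyG's TopKPooling; the dimensions and random connectivity are placeholders for illustration:

import torch
from torch_geometric.nn import TopKPooling, global_mean_pool

pool = TopKPooling(in_channels=64, ratio=0.5)  # learns a scoring projection

x = torch.randn(10, 64)                     # 10 nodes with 64-dim embeddings
edge_index = torch.randint(0, 10, (2, 30))  # random edges, illustration only
batch = torch.zeros(10, dtype=torch.long)   # a single graph

# Keep the 5 highest-scoring nodes and the edges among them
x, edge_index, _, batch, perm, score = pool(x, edge_index, batch=batch)
out = global_mean_pool(x, batch)            # coarsened graph -> flat readout

In a full hierarchical model, message-passing layers and pooling steps alternate, so each coarsening level gets its own embeddings.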
Attention pooling
Learns a weighted sum of node embeddings where the weights reflect each node's importance to the graph-level prediction:
- Set2Set: Uses an LSTM-based attention mechanism to iteratively refine the graph representation.
- GlobalAttention: A single attention layer that computes importance weights for all nodes.
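A minimal sketch with PyG's GlobalAttention (recent PyG releases expose the same idea as AttentionalAggregation); the single-linear-layer gate network here is an assumption, and any module mapping embeddings to scalars would work:

import torch
from torch_geometric.nn import GlobalAttention

gate_nn = torch.nn.Linear(64, 1)    # maps each node embedding to an importance score
att_pool = GlobalAttention(gate_nn)

x = torch.randn(5, 64)
batch = torch.tensor([0, 0, 0, 1, 1])
out = att_pool(x, batch)  # [2, 64]: softmax-weighted sum of nodes per graph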
Concrete example: molecular property prediction
Consider predicting whether a molecule is toxic. The molecule is a graph:
- Atom nodes: features = [element, charge, hybridization]
- Bond edges: features = [bond_type, is_conjugated, is_aromatic]
After 3 GCNConv layers, each atom has a 64-dimensional embedding that encodes its local chemical environment. Graph pooling compresses all atom embeddings into a single 64-dimensional vector representing the entire molecule. A linear classifier then predicts toxicity from this pooled vector.
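To make the input concrete, here is a hedged sketch of how such a molecule could be packed into a PyG Data object; the toy water-like molecule and its numeric encodings are assumptions for illustration, not a standard featurization:

import torch
from torch_geometric.data import Data

# Atom features: [element (atomic number), charge, hybridization code]
x = torch.tensor([[8.0, 0.0, 3.0],   # oxygen
                  [1.0, 0.0, 0.0],   # hydrogen
                  [1.0, 0.0, 0.0]])  # hydrogen
# Two O-H bonds, each stored in both directions
edge_index = torch.tensor([[0, 1, 0, 2],
                           [1, 0, 2, 0]])
# Bond features: [bond_type, is_conjugated, is_aromatic]
edge_attr = torch.tensor([[1.0, 0.0, 0.0]] * 4)
mol = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=torch.tensor([0]))  # y: toxicity label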
PyG implementation
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool

class GraphClassifier(torch.nn.Module):
    def __init__(self, num_features, hidden_dim, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, hidden_dim)
        self.classifier = torch.nn.Linear(hidden_dim, num_classes)

    def forward(self, x, edge_index, batch):
        # Node-level message passing
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)
        # Graph-level pooling: collapse all nodes -> one vector per graph
        x = global_mean_pool(x, batch)  # shape: [num_graphs, hidden_dim]
        # Classify the entire graph
        return self.classifier(x)

# The 'batch' tensor tells PyG which nodes belong to which graph:
# in a batch of 32 graphs, batch[i] = graph index for node i, and
# global_mean_pool uses it to average nodes within each graph.

The batch tensor is key: it maps each node to its graph in the mini-batch, so global_mean_pool averages only within each graph.
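A quick usage sketch, assuming the GraphClassifier above and two toy graphs combined with PyG's Batch utility (feature dimensions are placeholders):

import torch
from torch_geometric.data import Data, Batch

g1 = Data(x=torch.randn(4, 16), edge_index=torch.tensor([[0, 1, 2], [1, 2, 3]]))
g2 = Data(x=torch.randn(6, 16), edge_index=torch.tensor([[0, 2, 4], [1, 3, 5]]))
graphs = Batch.from_data_list([g1, g2])  # graphs.batch = [0, 0, 0, 0, 1, 1, 1, 1, 1, 1]

model = GraphClassifier(num_features=16, hidden_dim=64, num_classes=2)
logits = model(graphs.x, graphs.edge_index, graphs.batch)
print(logits.shape)  # torch.Size([2, 2]): one prediction per graph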
Limitations and what comes next
- Information loss: Collapsing an entire graph into a fixed-size vector necessarily loses information. Large, complex graphs lose more than small, simple ones. Hierarchical pooling mitigates this by preserving multi-scale structure.
- Size sensitivity: Sum pooling is sensitive to graph size (larger graphs produce larger-magnitude vectors), while mean pooling discards size information entirely. Combining both (or using Set2Set) helps; see the sketch after this list.
- Node importance: Global pooling treats all nodes equally. In practice, some nodes (e.g., hub accounts, high-value transactions) are far more informative for the graph-level prediction. Attention pooling addresses this but adds computational cost.
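One cheap mitigation for the size-sensitivity trade-off, sketched below under the same PyG assumptions as above, is to concatenate mean and sum readouts so the classifier sees both size-normalized and size-aware signals (the classifier's input dimension doubles accordingly):

import torch
from torch_geometric.nn import global_mean_pool, global_add_pool

def mean_sum_readout(x, batch):
    # [num_graphs, 2 * hidden_dim]: size-normalized and size-sensitive halves
    return torch.cat([global_mean_pool(x, batch), global_add_pool(x, batch)], dim=-1)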
Graph transformers with virtual nodes offer an alternative: a special “graph node” that attends to all real nodes, effectively learning a task-specific pooling function end-to-end.