Graph Pooling: Reducing Graph Representations to Fixed-Size Vectors

GNN layers produce per-node representations. Graph pooling collapses them into a single vector that represents the entire graph. This is necessary for any task where the prediction target is a graph, not a node.


TL;DR

  • Graph pooling aggregates all node representations into a single fixed-size vector representing the entire graph. It is the graph-level equivalent of global average pooling in CNNs.
  • Three main approaches: global pooling (sum/mean/max all nodes at once), hierarchical pooling (progressively coarsen the graph), and attention pooling (learn which nodes matter most for the graph-level task).
  • Global mean pooling is the default starting point. It is simple, parameter-free, and effective. Use attention pooling when some nodes are far more important than others for the prediction.
  • In enterprise data, graph pooling lets you make entity-level predictions from subgraphs: pool a customer's transaction subgraph for account-level risk, pool a supply chain subgraph for delay prediction.
  • In PyG: global_mean_pool(x, batch) handles batching automatically. For hierarchical pooling, use TopKPooling or SAGPooling layers between GNN layers.

Graph pooling is the operation that compresses all node representations in a graph into a single fixed-size vector, enabling predictions about the graph as a whole rather than individual nodes. After GNN message passing layers compute per-node embeddings, graph pooling aggregates these embeddings into one vector that captures the global structure and features of the entire graph. This pooled representation is then fed to a classifier or regressor for graph-level tasks.

Why it matters for enterprise data

Many enterprise prediction tasks operate at an entity level that spans multiple rows across multiple tables. Predicting account-level fraud risk requires considering all of a customer's transactions, the merchants involved, and the payment methods used. This forms a subgraph. Graph pooling converts that entire subgraph into a single vector that feeds a fraud classifier.

Without graph pooling, you have individual node embeddings but no mechanism to produce a single prediction for the entity. Pooling bridges per-node representations and entity-level decisions.

Three approaches to graph pooling

Global pooling (flat readout)

The simplest approach: apply a permutation-invariant function across all node embeddings at once.

  • Mean pooling: Average all node embeddings. Normalizes for graph size. Most common default.
  • Sum pooling: Sum all node embeddings. Preserves information about graph size (a graph with 100 nodes yields a pooled vector with a larger norm than one with 10).
  • Max pooling: Element-wise maximum across all node embeddings. Captures the most extreme signals.

Global pooling adds zero learnable parameters and works surprisingly well for most tasks.
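The three readouts above can be sketched in plain Python. This is an illustration of the math, not the PyG API; the node embeddings are a hypothetical 4-node graph with 3-dimensional features.

```python
# Hypothetical node embeddings: 4 nodes, 3 features each.
nodes = [
    [1.0, 0.0, 2.0],
    [3.0, 1.0, 0.0],
    [0.0, 2.0, 1.0],
    [2.0, 1.0, 1.0],
]

def mean_pool(embs):
    # Average each feature dimension across all nodes.
    n = len(embs)
    return [sum(col) / n for col in zip(*embs)]

def sum_pool(embs):
    # Sum each feature dimension across all nodes.
    return [sum(col) for col in zip(*embs)]

def max_pool(embs):
    # Element-wise maximum across all nodes.
    return [max(col) for col in zip(*embs)]

print(mean_pool(nodes))  # [1.5, 1.0, 1.0]
print(sum_pool(nodes))   # [6.0, 4.0, 4.0]
print(max_pool(nodes))   # [3.0, 2.0, 2.0]
```

Because each readout only involves a symmetric aggregation, reordering the nodes leaves the result unchanged, which is exactly the permutation invariance the section requires.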

Hierarchical pooling

Progressively coarsens the graph over multiple pooling steps. At each step, groups of nodes are merged into super-nodes, reducing the graph size. Methods include:

  • DiffPool: Learns a soft assignment matrix that clusters nodes into super-nodes at each level. Powerful but memory-intensive (O(n^2) per layer).
  • TopKPooling: Scores each node and keeps only the top-k fraction. Drops low-scoring nodes, reducing graph size.
  • SAGPooling: Combines self-attention with top-k selection. Learns which nodes to keep based on the graph structure.
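The score-and-keep idea behind TopKPooling can be sketched as follows. This is a minimal plain-Python illustration, assuming a fixed scoring vector p where the real layer learns p; TopKPooling also gates the kept embeddings by tanh(score) so the scoring vector receives gradients.

```python
import math

# Hypothetical 5-node graph with 2-dimensional embeddings.
embs = [[1.0, 2.0], [0.5, 0.5], [3.0, 1.0], [0.1, 0.2], [2.0, 2.0]]
p = [0.6, 0.8]  # stand-in for the learned scoring vector
p_norm = math.sqrt(sum(v * v for v in p))

# Score each node by its projection onto p.
scores = [sum(e * w for e, w in zip(emb, p)) / p_norm for emb in embs]

# Keep the top ceil(ratio * n) nodes, dropping the low scorers.
ratio = 0.5
k = math.ceil(ratio * len(embs))
kept = sorted(range(len(embs)), key=lambda i: scores[i], reverse=True)[:k]

# Gate the surviving embeddings by tanh(score) before the next GNN layer.
coarsened = [[v * math.tanh(scores[i]) for v in embs[i]] for i in sorted(kept)]
```

With ratio 0.5, the 5-node graph is reduced to 3 nodes, and a second GNN-plus-pooling stage would coarsen it again.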

Attention pooling

Learns a weighted sum of node embeddings where the weights reflect each node's importance to the graph-level prediction:

  • Set2Set: Uses an LSTM-based attention mechanism to iteratively refine the graph representation.
  • GlobalAttention: A single attention layer that computes importance weights for all nodes.
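The attention readout amounts to: score each node with a gate network, softmax-normalize the scores, and take the weighted sum of embeddings. A minimal sketch, assuming a fixed linear gate (gate_w) where the real layer learns it:

```python
import math

# Hypothetical 3-node graph with 2-dimensional embeddings.
embs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
gate_w = [2.0, -1.0]  # hypothetical gate weights

# Raw importance score per node from the gate.
raw = [sum(e * w for e, w in zip(emb, gate_w)) for emb in embs]

# Softmax over nodes (max-subtracted for numerical stability).
m = max(raw)
exps = [math.exp(r - m) for r in raw]
attn = [e / sum(exps) for e in exps]  # weights sum to 1

# Graph vector = attention-weighted sum of node embeddings.
pooled = [sum(a * emb[d] for a, emb in zip(attn, embs)) for d in range(2)]
```

Unlike mean pooling, which fixes every node's weight at 1/n, the gate lets the model concentrate the readout on the few nodes that drive the prediction.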

Concrete example: molecular property prediction

Consider predicting whether a molecule is toxic. The molecule is a graph:

  • Atom nodes: features = [element, charge, hybridization]
  • Bond edges: features = [bond_type, is_conjugated, is_aromatic]

After 3 GCNConv layers, each atom has a 64-dimensional embedding that encodes its local chemical environment. Graph pooling compresses all atom embeddings into a single 64-dimensional vector representing the entire molecule. A linear classifier then predicts toxicity from this pooled vector.

PyG implementation

graph_pooling_pyg.py
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool, global_add_pool

class GraphClassifier(torch.nn.Module):
    def __init__(self, num_features, hidden_dim, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, hidden_dim)
        self.classifier = torch.nn.Linear(hidden_dim, num_classes)

    def forward(self, x, edge_index, batch):
        # Node-level message passing
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)

        # Graph-level pooling: collapse all nodes -> one vector per graph
        x = global_mean_pool(x, batch)  # shape: [num_graphs, hidden_dim]

        # Classify entire graph
        return self.classifier(x)

# The 'batch' tensor tells PyG which nodes belong to which graph
# In a batch of 32 graphs, batch[i] = graph index for node i
# global_mean_pool uses this to average nodes within each graph

The batch tensor is key: it maps each node to its graph in the mini-batch. global_mean_pool averages only within each graph.
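What global_mean_pool computes from the batch tensor can be reproduced in a few lines of plain Python (a sketch of the semantics, not the PyG implementation):

```python
# 5 nodes across 2 graphs; batch[i] = graph index of node i.
x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
batch = [0, 0, 0, 1, 1]  # nodes 0-2 in graph 0, nodes 3-4 in graph 1

num_graphs = max(batch) + 1
sums = [[0.0, 0.0] for _ in range(num_graphs)]
counts = [0] * num_graphs

# Accumulate per-graph sums and node counts.
for emb, g in zip(x, batch):
    counts[g] += 1
    for d, v in enumerate(emb):
        sums[g][d] += v

# Divide by each graph's node count: one mean vector per graph.
pooled = [[s / counts[g] for s in sums[g]] for g in range(num_graphs)]
print(pooled)  # [[3.0, 4.0], [8.0, 9.0]]
```

The result has shape [num_graphs, hidden_dim], matching the comment in the forward pass above.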

Limitations and what comes next

  1. Information loss: Collapsing an entire graph into a fixed-size vector necessarily loses information. Large, complex graphs lose more than small, simple ones. Hierarchical pooling mitigates this by preserving multi-scale structure.
  2. Size sensitivity: Sum pooling is sensitive to graph size (larger graphs yield pooled vectors with larger norms). Mean pooling loses size information. Combining both (or using Set2Set) helps.
  3. Node importance: Global pooling treats all nodes equally. In practice, some nodes (e.g., hub accounts, high-value transactions) are far more informative for the graph-level prediction. Attention pooling addresses this but adds computational cost.
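One cheap mitigation for the size-sensitivity point is to concatenate the mean and sum readouts, so the pooled vector keeps both size-normalized features and a graph-size signal. A plain-Python sketch:

```python
def mean_sum_pool(embs):
    # Concatenate the mean readout and the sum readout; the pooled
    # dimension doubles, but size information is no longer discarded.
    n = len(embs)
    sums = [sum(col) for col in zip(*embs)]
    means = [s / n for s in sums]
    return means + sums

# Two graphs with identical per-node features but different sizes.
small = [[1.0, 1.0]] * 2
large = [[1.0, 1.0]] * 10

# The mean halves match, but the sum halves separate the two graphs.
print(mean_sum_pool(small))  # [1.0, 1.0, 2.0, 2.0]
print(mean_sum_pool(large))  # [1.0, 1.0, 10.0, 10.0]
```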

Graph transformers with virtual nodes offer an alternative: a special “graph node” that attends to all real nodes, effectively learning a task-specific pooling function end-to-end.

Frequently asked questions

What is graph pooling?

Graph pooling is the operation that compresses all node representations in a graph into a single fixed-size vector. It is the graph equivalent of global average pooling in CNNs. This pooled vector represents the entire graph and is fed to a classifier or regressor for graph-level predictions. Without pooling, you have per-node outputs but no way to make predictions about the graph as a whole.

What is the difference between global and hierarchical pooling?

Global pooling aggregates all node representations in one step (sum, mean, or max across all nodes). It is simple and effective. Hierarchical pooling progressively coarsens the graph over multiple steps, merging clusters of nodes into super-nodes. Hierarchical methods (DiffPool, SAGPool) can capture multi-scale structure but add complexity and computational cost.

When do you need graph pooling?

Graph pooling is required for graph-level tasks: graph classification (is this molecular graph toxic?), graph regression (predict a molecule's property), or any task where the prediction target is the entire graph rather than individual nodes or edges. For node-level or edge-level tasks, pooling is not needed.

How does graph pooling apply to enterprise data?

In enterprise settings, graph pooling lets you make predictions about entire subgraphs. For example, pool all nodes in a customer's transaction subgraph to predict account-level risk. Pool all nodes in a supply chain subgraph to predict delivery delay. Pool all interactions in a support ticket cluster to classify issue severity.

Which graph pooling method should I use?

Start with global mean pooling (global_mean_pool in PyG). It works well for most tasks and adds zero learnable parameters. If your graphs vary significantly in size, try combining sum and mean via Set2Set or attention pooling. Hierarchical pooling (DiffPool) helps when multi-scale structure matters, but it is harder to train and slower.

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.