Global pooling summarizes an entire graph into a single fixed-size vector. After message passing produces a representation for each node, global pooling aggregates all node representations into one graph-level vector. A molecular graph with 30 atoms and a molecular graph with 100 atoms both produce the same-size output vector. This fixed-size representation can then be fed to a classifier or regressor for graph-level predictions.
Without global pooling, GNN outputs are per-node. You need global pooling whenever the prediction is about the entire graph: molecular toxicity, protein function, circuit performance, or subgraph classification.
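A minimal pure-PyTorch sketch of the idea (the 64-dimensional hidden size and random features are placeholders): graphs of different sizes collapse to the same output shape once the node dimension is averaged away.

```python
import torch

hidden_dim = 64
small = torch.randn(30, hidden_dim)   # 30-atom molecule after message passing
large = torch.randn(100, hidden_dim)  # 100-atom molecule

# Mean pooling = average over the node dimension
emb_small = small.mean(dim=0)  # [64]
emb_large = large.mean(dim=0)  # [64]
```

Both embeddings have shape [hidden_dim] regardless of how many nodes went in, which is exactly what lets a single downstream classifier handle molecules of any size.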
Basic pooling methods
from torch_geometric.nn import global_mean_pool, global_add_pool, global_max_pool
# After GNN message passing:
# x.shape: [total_nodes_in_batch, hidden_dim]
# batch: [total_nodes_in_batch] mapping each node to its graph
# Mean: average all node embeddings per graph
graph_emb = global_mean_pool(x, batch) # [num_graphs, hidden_dim]
# Sum: sum all node embeddings per graph
graph_emb = global_add_pool(x, batch) # [num_graphs, hidden_dim]
# Max: element-wise max across all nodes per graph
graph_emb = global_max_pool(x, batch) # [num_graphs, hidden_dim]
# Concatenate multiple for richer representation
import torch
graph_emb = torch.cat([
    global_mean_pool(x, batch),
    global_add_pool(x, batch),
    global_max_pool(x, batch),
], dim=-1)  # [num_graphs, 3 * hidden_dim]

Three basic pooling methods. Concatenating all three yields a richer representation at the cost of a 3x larger embedding.
Advanced pooling methods
- Set2Set: uses an LSTM with attention to iteratively read out the graph; its output is twice the input dimension. Most expressive but slowest, so best suited to small-to-medium graphs where quality matters most.
- SortPooling: sorts nodes by their feature values, keeps the top-k, then applies a 1D CNN. Captures the “most important” nodes while imposing a consistent node ordering.
- Virtual node: adds a special node connected to every other node. After message passing, this virtual node has aggregated information from all real nodes. Its embedding serves as the graph representation.
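The virtual-node idea can be sketched with plain tensor ops before message passing begins. This is a hedged sketch, assuming zero-initialized features and bidirectional edges for the virtual node (one common choice, not the only one); add_virtual_node is a hypothetical helper name:

```python
import torch

def add_virtual_node(x, edge_index):
    """Append one virtual node connected to every real node (both directions)."""
    num_nodes = x.size(0)
    v = num_nodes  # index of the new virtual node
    # Zero-initialize the virtual node's features (assumption, one common choice)
    x = torch.cat([x, x.new_zeros(1, x.size(1))], dim=0)
    real = torch.arange(num_nodes)
    virt = torch.full((num_nodes,), v, dtype=torch.long)
    extra = torch.cat([
        torch.stack([real, virt]),  # real -> virtual edges
        torch.stack([virt, real]),  # virtual -> real edges
    ], dim=1)
    edge_index = torch.cat([edge_index, extra], dim=1)
    return x, edge_index, v

x = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
x, edge_index, v = add_virtual_node(x, edge_index)
# After message passing, x[v] would serve as the graph representation
```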
Enterprise example: molecular property prediction
A pharmaceutical company predicts molecular properties (solubility, toxicity, binding affinity) from molecular graphs:
import torch
from torch_geometric.nn import GINConv, global_add_pool
from torch.nn import Sequential, Linear, ReLU
class MoleculeClassifier(torch.nn.Module):
    def __init__(self, num_features, hidden_dim, num_classes):
        super().__init__()
        nn1 = Sequential(Linear(num_features, hidden_dim), ReLU(),
                         Linear(hidden_dim, hidden_dim))
        nn2 = Sequential(Linear(hidden_dim, hidden_dim), ReLU(),
                         Linear(hidden_dim, hidden_dim))
        self.conv1 = GINConv(nn1)
        self.conv2 = GINConv(nn2)
        self.pool = global_add_pool  # sum pooling preserves molecule-size information
        self.classifier = Linear(hidden_dim, num_classes)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        # Pool: variable-size molecule -> fixed-size vector
        graph_emb = self.pool(x, batch)
        return self.classifier(graph_emb)

# Input: batch of molecular graphs (different sizes)
# Output: [batch_size, num_classes] predictions per molecule

GIN + global_add_pool + classifier. The pooling step bridges variable-size molecules to fixed-size predictions.
The batch vector explained
PyG processes mini-batches by concatenating multiple graphs into one large disconnected graph. The batch vector tracks which node belongs to which original graph:
- Graph 0 has 3 nodes: batch = [0, 0, 0, ...]
- Graph 1 has 4 nodes: batch = [..., 1, 1, 1, 1, ...]
- Graph 2 has 2 nodes: batch = [..., 2, 2]
- Full batch vector: [0, 0, 0, 1, 1, 1, 1, 2, 2]
Global pooling uses this vector to aggregate per graph: global_mean_pool(x, batch) averages nodes 0-2 for graph 0, nodes 3-6 for graph 1, and nodes 7-8 for graph 2, producing three graph-level vectors.
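The same arithmetic can be reproduced with plain PyTorch index_add_ calls; a sketch of what global_mean_pool computes for the batch vector above (one scalar feature per node, just for readability):

```python
import torch

x = torch.arange(9, dtype=torch.float).unsqueeze(1)  # 9 nodes, 1 feature each
batch = torch.tensor([0, 0, 0, 1, 1, 1, 1, 2, 2])
num_graphs = 3

# Sum node features per graph, then divide by each graph's node count
sums = torch.zeros(num_graphs, x.size(1)).index_add_(0, batch, x)
counts = torch.zeros(num_graphs).index_add_(0, batch, torch.ones(len(batch)))
means = sums / counts.unsqueeze(1)
print(means.squeeze(1))  # tensor([1.0000, 4.5000, 7.5000])
```

Graph 0 averages nodes 0-2 (features 0, 1, 2), graph 1 averages nodes 3-6, and graph 2 averages nodes 7-8, matching the per-graph means.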