
Global Pooling: Summarizing an Entire Graph into a Single Vector

Global pooling compresses a graph of any size into a fixed-length vector by aggregating all node representations. This is the bridge between node-level GNN outputs and graph-level prediction tasks like molecular property prediction.

PyTorch Geometric

TL;DR

  • Global pooling aggregates all node embeddings in a graph into one fixed-size vector. This enables graph-level predictions: is this molecule toxic? What is this protein's function?
  • Three basic methods: mean (size-invariant), sum (preserves size information), max (captures extreme features). Sum is often best because graph size carries signal.
  • Advanced methods: Set2Set (attention-based, most expressive), SortPooling (sort nodes by a feature channel, apply a 1D CNN), virtual node (add a node connected to all others).
  • The batch vector in PyG maps each node to its graph in a mini-batch. Global pooling uses it to aggregate per-graph even though all graphs are concatenated.
  • Global pooling is required for graph classification and regression. Node-level and edge-level tasks do not need it.

Global pooling summarizes an entire graph into a single fixed-size vector. After message passing produces a representation for each node, global pooling aggregates all node representations into one graph-level vector. A molecular graph with 30 atoms and a molecular graph with 100 atoms both produce the same-size output vector. This fixed-size representation can then be fed to a classifier or regressor for graph-level predictions.

Without global pooling, GNN outputs are per-node. You need global pooling whenever the prediction is about the entire graph: molecular toxicity, protein function, circuit performance, or subgraph classification.

Basic pooling methods

global_pooling_basics.py
from torch_geometric.nn import global_mean_pool, global_add_pool, global_max_pool

# After GNN message passing:
# x.shape: [total_nodes_in_batch, hidden_dim]
# batch: [total_nodes_in_batch] mapping each node to its graph

# Mean: average all node embeddings per graph
graph_emb = global_mean_pool(x, batch)  # [num_graphs, hidden_dim]

# Sum: sum all node embeddings per graph
graph_emb = global_add_pool(x, batch)   # [num_graphs, hidden_dim]

# Max: element-wise max across all nodes per graph
graph_emb = global_max_pool(x, batch)   # [num_graphs, hidden_dim]

# Concatenate multiple for richer representation
import torch
graph_emb = torch.cat([
    global_mean_pool(x, batch),
    global_add_pool(x, batch),
    global_max_pool(x, batch),
], dim=-1)  # [num_graphs, 3 * hidden_dim]

Three basic pooling methods. Concatenating all three often yields a richer graph-level representation, at the cost of a 3x larger vector.

Advanced pooling methods

  • Set2Set: uses an LSTM with attention to iteratively read out the graph. Most expressive but slowest. Good for small-to-medium graphs where quality matters most.
  • SortPooling: sorts nodes by their values in the last feature channel, keeps the top-k, then applies a 1D CNN. Captures the most salient nodes while imposing a consistent node ordering.
  • Virtual node: adds a special node connected to every other node. After message passing, this virtual node has aggregated information from all real nodes. Its embedding serves as the graph representation.
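The SortPooling readout can be sketched without any framework: sort nodes by one feature channel, keep the top-k, and flatten into a fixed-size vector. This is an illustrative plain-Python sketch, not PyG's implementation (PyG's global_sort_pool does this on tensors, and the 1D CNN that follows is omitted here):

```python
# Illustrative sketch of the SortPooling readout (not PyG's implementation).
# Each node is a feature list; sort by the last feature channel, keep the
# top-k nodes, and zero-pad when the graph has fewer than k nodes.

def sort_pool(node_features, k):
    """Return a fixed-size (k * dim) flat vector for a variable-size graph."""
    dim = len(node_features[0])
    # Sort nodes by their last feature channel, descending.
    ordered = sorted(node_features, key=lambda f: f[-1], reverse=True)
    # Keep the top-k; pad with all-zero "nodes" if the graph is smaller.
    kept = ordered[:k] + [[0.0] * dim] * max(0, k - len(ordered))
    # Flatten: in real SortPooling a 1D CNN slides over this sequence.
    return [v for node in kept for v in node]

graph = [[0.1, 0.5], [0.3, 0.9], [0.2, 0.1]]  # 3 nodes, 2 features each
print(sort_pool(graph, k=2))  # → [0.3, 0.9, 0.1, 0.5]
```

Whatever the graph size, the output length is always k * dim, which is what lets a downstream CNN or MLP consume it.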

Enterprise example: molecular property prediction

A pharmaceutical company predicts molecular properties (solubility, toxicity, binding affinity) from molecular graphs:

molecular_classification.py
import torch
from torch_geometric.nn import GINConv, global_add_pool
from torch.nn import Sequential, Linear, ReLU

class MoleculeClassifier(torch.nn.Module):
    def __init__(self, num_features, hidden_dim, num_classes):
        super().__init__()
        nn1 = Sequential(Linear(num_features, hidden_dim), ReLU(),
                         Linear(hidden_dim, hidden_dim))
        nn2 = Sequential(Linear(hidden_dim, hidden_dim), ReLU(),
                         Linear(hidden_dim, hidden_dim))
        self.conv1 = GINConv(nn1)
        self.conv2 = GINConv(nn2)
        self.pool = global_add_pool  # sum preserves molecule size
        self.classifier = Linear(hidden_dim, num_classes)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        # Pool: variable-size molecule -> fixed-size vector
        graph_emb = self.pool(x, batch)
        return self.classifier(graph_emb)

# Input: batch of molecular graphs (different sizes)
# Output: [batch_size, num_classes] predictions per molecule

GIN + global_add_pool + classifier. The pooling step bridges variable-size molecules to fixed-size predictions.

The batch vector explained

PyG processes mini-batches by concatenating multiple graphs into one large disconnected graph. The batch vector tracks which node belongs to which original graph:

  • Graph 0 has 3 nodes: batch = [0, 0, 0, ...]
  • Graph 1 has 4 nodes: batch = [..., 1, 1, 1, 1, ...]
  • Graph 2 has 2 nodes: batch = [..., 2, 2]
  • Full batch vector: [0, 0, 0, 1, 1, 1, 1, 2, 2]

Global pooling uses this vector to aggregate per-graph. global_mean_pool(x, batch) averages nodes 0-2 for graph 0, nodes 3-6 for graph 1, and nodes 7-8 for graph 2, producing 3 graph-level vectors.
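The mechanics above can be shown in a few lines of plain Python. This is an illustrative sketch only; PyG's global_mean_pool does the same thing with scatter operations on GPU tensors:

```python
# Illustrative sketch of mean pooling driven by a batch vector
# (PyG's global_mean_pool uses scatter operations on tensors instead).

def global_mean_pool_py(x, batch):
    """x: list of per-node feature lists; batch: graph index per node."""
    num_graphs = max(batch) + 1
    dim = len(x[0])
    sums = [[0.0] * dim for _ in range(num_graphs)]
    counts = [0] * num_graphs
    for features, g in zip(x, batch):
        counts[g] += 1
        for j, v in enumerate(features):
            sums[g][j] += v
    return [[s / counts[g] for s in sums[g]] for g in range(num_graphs)]

x = [[1.0], [2.0], [3.0],          # graph 0: 3 nodes
     [4.0], [4.0], [4.0], [4.0],   # graph 1: 4 nodes
     [5.0], [7.0]]                 # graph 2: 2 nodes
batch = [0, 0, 0, 1, 1, 1, 1, 2, 2]
print(global_mean_pool_py(x, batch))  # → [[2.0], [4.0], [6.0]]
```

Nine concatenated node rows go in; three graph-level vectors come out, one per distinct graph index in the batch vector.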

Frequently asked questions

What is global pooling in GNNs?

Global pooling (also called graph-level readout) aggregates all node representations in a graph into a single fixed-size vector. This converts a variable-size graph (any number of nodes) into a fixed-size representation that can be fed to a classifier or regressor for graph-level prediction tasks.

What global pooling methods are available in PyG?

PyG provides: global_mean_pool (average all node vectors), global_add_pool (sum all node vectors), global_max_pool (element-wise max), global_sort_pool (sort-based pooling), and Set2Set (attention-based pooling). Mean is the default choice; sum preserves graph size information; attention-based methods are most expressive.

When do I need global pooling?

Whenever you need a graph-level prediction: classifying molecules as toxic/non-toxic, predicting protein function, scoring entire subgraphs. Node-level tasks (node classification, link prediction) do not need global pooling because predictions are made per-node or per-edge.

What is the difference between global_mean_pool and global_add_pool?

global_mean_pool averages node embeddings, making the result invariant to graph size. global_add_pool sums them, preserving size information (a graph with 100 nodes gets a larger vector magnitude than one with 10). Use sum when graph size is informative (bigger molecules have different properties than smaller ones).
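A tiny numeric example (plain Python, illustrative) makes the distinction concrete: two graphs whose nodes all carry the same feature value are indistinguishable under mean pooling but not under sum pooling:

```python
# Two graphs with identical per-node features but different sizes.
small = [1.0] * 10    # 10 nodes, one feature each
large = [1.0] * 100   # 100 nodes

mean_small, mean_large = sum(small) / len(small), sum(large) / len(large)
sum_small, sum_large = sum(small), sum(large)

print(mean_small, mean_large)  # 1.0 1.0     -> mean erases graph size
print(sum_small, sum_large)    # 10.0 100.0  -> sum preserves it
```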

What is the batch vector in PyG pooling?

When processing mini-batches of graphs, PyG concatenates all graphs into one large graph. The batch vector maps each node to its graph index (e.g., [0,0,0,1,1,1,1,2,2] means nodes 0-2 belong to graph 0, nodes 3-6 to graph 1, nodes 7-8 to graph 2). Global pooling uses this to aggregate per-graph.

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.