Batching is the first practical challenge anyone encounters when training GNNs. Images all have the same shape, so you stack them into a tensor: [batch_size, channels, height, width]. Graphs have different numbers of nodes and edges. Graph A might have 15 nodes and 20 edges. Graph B might have 150 nodes and 400 edges. You cannot stack these into a regular tensor.
The PyG solution: disjoint graph batching
PyG merges all graphs in a mini-batch into a single large graph where the individual graphs are disconnected components. Three operations:
- Concatenate node features: stack all node feature matrices vertically
- Offset edge indices: graph B's node indices are shifted by the number of nodes in graph A, so edges in B reference the correct nodes in the combined graph
- Create batch vector: a vector mapping each node to its source graph index
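The three operations above can be sketched in plain Python (lists and tuples, not PyG's actual tensor implementation; the `collate` helper name is made up for illustration):

```python
# Sketch of disjoint batching: each graph is (node_features, edge_list).
def collate(graphs):
    all_x, all_edges, batch_vec = [], [], []
    offset = 0
    for graph_id, (x, edges) in enumerate(graphs):
        # 1. Concatenate node features
        all_x.extend(x)
        # 2. Offset edge indices by the number of nodes seen so far
        all_edges.extend((src + offset, dst + offset) for src, dst in edges)
        # 3. Map each node to its source graph index
        batch_vec.extend([graph_id] * len(x))
        offset += len(x)
    return all_x, all_edges, batch_vec

# Graph A: 3 nodes, 2 edges. Graph B: 2 nodes, 1 edge.
a = ([[1.0], [2.0], [3.0]], [(0, 1), (1, 2)])
b = ([[4.0], [5.0]], [(0, 1)])
x, edges, batch = collate([a, b])
# edges -> [(0, 1), (1, 2), (3, 4)]: B's edge (0, 1) shifted by A's 3 nodes
# batch -> [0, 0, 0, 1, 1]
```

PyG's `DataLoader` does the same thing with tensors, building `batch.x`, `batch.edge_index`, and `batch.batch` in one collate step.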
```python
from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_add_pool

# Dataset of small graphs (e.g., molecules)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

for batch in loader:
    # batch.x: [total_nodes, features] (all 32 graphs' nodes)
    # batch.edge_index: [2, total_edges] (all edges, offset)
    # batch.batch: [total_nodes] (0,0,...,1,1,...,31,31,...)

    # GNN layers process the merged graph
    x = model.conv1(batch.x, batch.edge_index)
    x = model.conv2(x, batch.edge_index)

    # Per-graph readout using batch vector
    graph_embeddings = global_add_pool(x, batch.batch)
    # graph_embeddings: [32, hidden_dim] (one per graph)
```

32 molecular graphs become one disconnected graph. The GNN layers are unaware of the batching; the batch vector separates the graphs again for graph-level pooling.
Why disconnected merging works
The key insight: because the merged graphs are disconnected (no edges between them), message passing within each component is identical to processing each graph independently. Node 5 in graph A only receives messages from its neighbors in graph A. It never receives messages from nodes in graph B because there are no connecting edges.
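This equivalence can be checked directly with a toy sum-aggregation step in plain Python (the `propagate` function is an illustrative stand-in for one message-passing round, not PyG code):

```python
# One round of sum-aggregation message passing on an edge list.
def propagate(x, edges):
    out = [0.0] * len(x)
    for src, dst in edges:
        out[dst] += x[src]  # message flows src -> dst
    return out

# Graph A (nodes 0-2) and graph B (nodes 3-4), merged with offset edges
x_merged = [1.0, 2.0, 3.0, 10.0, 20.0]
edges_merged = [(0, 1), (1, 2), (3, 4)]  # no edges between A and B

merged_out = propagate(x_merged, edges_merged)

# The same graphs processed independently
a_out = propagate([1.0, 2.0, 3.0], [(0, 1), (1, 2)])
b_out = propagate([10.0, 20.0], [(0, 1)])

assert merged_out == a_out + b_out  # identical per-node results
```

Because no edge crosses the component boundary, no term in any node's aggregation ever involves a node from another graph.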
Graph-level pooling with the batch vector
After message passing, you need one embedding per graph (for graph-level tasks like molecular property prediction). The batch vector enables this:
- global_add_pool(x, batch): sum all node embeddings within each graph
- global_mean_pool(x, batch): average all node embeddings within each graph
- global_max_pool(x, batch): take the element-wise max across node embeddings within each graph
The batch vector acts as a GROUP BY: it groups nodes by their source graph and applies the pooling function within each group.
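A minimal sketch of this GROUP BY in plain Python, using scalar embeddings for brevity (the `add_pool` helper is hypothetical, standing in for `global_add_pool`):

```python
# Sum-pooling as a GROUP BY over the batch vector.
def add_pool(x, batch, num_graphs):
    sums = [0.0] * num_graphs
    for value, graph_id in zip(x, batch):
        sums[graph_id] += value  # group each node by its source graph, then sum
    return sums

x = [1.0, 2.0, 3.0, 10.0, 20.0]  # one scalar embedding per node
batch = [0, 0, 0, 1, 1]          # first 3 nodes -> graph 0, last 2 -> graph 1
pooled = add_pool(x, batch, num_graphs=2)
# pooled -> [6.0, 30.0]: one value per graph
```

PyG implements the same idea as a scatter-reduce over the embedding dimension, so the output has shape [num_graphs, hidden_dim].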
Large single graph: neighbor sampling
The disjoint-graph approach works for datasets of many small graphs (molecules, circuits). For a single large graph (social network with 100M nodes), you need a different strategy:
- NeighborLoader: samples a fixed number of neighbors at each hop for a batch of target nodes. Each mini-batch is a subgraph.
- ClusterLoader: pre-partitions the graph into clusters and samples entire clusters per batch.
- GraphSAINTSampler: samples random subgraphs with importance weighting.
These produce mini-batches that fit in GPU memory while preserving enough neighborhood context for message passing. The tradeoff: you see a subset of each node's neighborhood, introducing sampling variance that diminishes with larger samples.
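The core idea behind neighbor sampling can be sketched in a few lines of plain Python (the `sample_neighbors` function illustrates the one-hop sampling step conceptually; it is not NeighborLoader's implementation):

```python
import random

# One-hop neighbor sampling: for each target node, keep at most k neighbors.
def sample_neighbors(adj, targets, k, rng):
    sampled = {}
    for node in targets:
        neighbors = adj.get(node, [])
        if len(neighbors) > k:
            # Subsample to bound the subgraph size (this is where
            # sampling variance enters; larger k reduces it)
            neighbors = rng.sample(neighbors, k)
        sampled[node] = neighbors
    return sampled

adj = {0: [1, 2, 3, 4, 5], 1: [0], 2: [0, 3]}
out = sample_neighbors(adj, targets=[0, 2], k=2, rng=random.Random(0))
# Every sampled neighbor is a true neighbor, and at most k per target
assert all(set(v) <= set(adj[n]) and len(v) <= 2 for n, v in out.items())
```

NeighborLoader applies this recursively per hop (e.g., num_neighbors=[10, 10] for two hops), so each mini-batch is a small subgraph rooted at the target nodes.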