Graph Mini-Batching: Why Batching Graphs is Different from Batching Images

You cannot stack graphs into a tensor the way you stack images. Graphs have different sizes, different numbers of edges, and different structures. PyG solves this with an elegant trick: merge the batch into a single disconnected graph.


TL;DR

  1. Images have fixed sizes (224x224) and stack into regular tensors. Graphs have variable sizes (10 nodes or 10,000 nodes) and cannot be stacked. Standard batching does not work.
  2. PyG's solution: merge all graphs in a batch into one large disconnected graph. Concatenate node features, offset edge indices, and add a batch vector mapping each node to its source graph.
  3. The batch vector enables per-graph operations: global_add_pool(x, batch) produces one embedding per graph by summing nodes within each graph. Message passing works unchanged because the merged graphs are disconnected.
  4. For datasets of many small graphs (molecules, circuits): use PyG's DataLoader with Batch.from_data_list(). For one large graph (social network): use NeighborLoader for subgraph sampling.
  5. This batching is transparent to GNN layers. GCNConv, GATConv, and all other layers process the batched graph exactly as they would a single graph. The batch vector is only used for readout.

Batching is the first practical challenge anyone encounters when training GNNs. Images all have the same shape, so you stack them into a tensor: [batch_size, channels, height, width]. Graphs have different numbers of nodes and edges. Graph A might have 15 nodes and 20 edges. Graph B might have 150 nodes and 400 edges. You cannot stack these into a regular tensor.

The PyG solution: disjoint graph batching

PyG merges all graphs in a mini-batch into a single large graph where the individual graphs are disconnected components. Three operations:

  1. Concatenate node features: stack all node feature matrices vertically
  2. Offset edge indices: graph B's node indices are shifted by the number of nodes in graph A, so edges in B reference the correct nodes in the combined graph
  3. Create batch vector: a vector mapping each node to its source graph index
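Before looking at the real loader, the three operations can be spelled out in plain Python. This is a toy sketch of the merging logic, not PyG's implementation; the `batch_graphs` helper and the dict-based graph representation are made up for illustration.

```python
# Toy sketch of disjoint batching: the three operations on plain lists.
# Not PyG's implementation -- just the idea, made explicit.

def batch_graphs(graphs):
    """Merge graphs (each a dict with 'x' node features and 'edges')
    into one disconnected graph plus a batch vector."""
    x, edges, batch = [], [], []
    offset = 0
    for gid, g in enumerate(graphs):
        x.extend(g["x"])                          # 1. concatenate node features
        edges.extend((s + offset, t + offset)     # 2. offset edge indices
                     for s, t in g["edges"])
        batch.extend([gid] * len(g["x"]))         # 3. build the batch vector
        offset += len(g["x"])
    return x, edges, batch

# Graph A: 3 nodes, graph B: 2 nodes
graph_a = {"x": [[1.0], [2.0], [3.0]], "edges": [(0, 1), (1, 2)]}
graph_b = {"x": [[4.0], [5.0]], "edges": [(0, 1)]}

x, edges, batch = batch_graphs([graph_a, graph_b])
print(batch)   # [0, 0, 0, 1, 1]
print(edges)   # [(0, 1), (1, 2), (3, 4)]  -- B's edge shifted by 3
```

Note how graph B's edge (0, 1) becomes (3, 4): its node indices are offset by graph A's three nodes, so the edge still connects the same two nodes inside the merged graph.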
graph_batching.py
from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_add_pool

# Dataset of small graphs (e.g., molecules); `dataset` is any PyG dataset
# and `model` a GNN with two conv layers, both defined elsewhere
loader = DataLoader(dataset, batch_size=32, shuffle=True)

for batch in loader:
    # batch.x: [total_nodes, features]  (all 32 graphs' nodes)
    # batch.edge_index: [2, total_edges]  (all edges, offset)
    # batch.batch: [total_nodes]  (0,0,...,1,1,...,31,31,...)

    # GNN layers process the merged graph
    x = model.conv1(batch.x, batch.edge_index)
    x = model.conv2(x, batch.edge_index)

    # Per-graph readout using batch vector
    graph_embeddings = global_add_pool(x, batch.batch)
    # graph_embeddings: [32, hidden_dim]  (one per graph)

32 molecular graphs become one disconnected graph. GNN layers are unaware of the batching. The batch vector separates them for graph-level pooling.

Why disconnected merging works

The key insight: because the merged graphs are disconnected (no edges between them), message passing within each component is identical to processing each graph independently. Node 5 in graph A only receives messages from its neighbors in graph A. It never receives messages from nodes in graph B because there are no connecting edges.
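This equivalence can be checked with a toy one-hop sum-aggregation step (a hand-rolled stand-in for a GNN layer, not PyG code): running it on two graphs separately gives the same result as running it once on their disjoint merge.

```python
# Toy sum-aggregation message passing (one hop) on an edge list.
# Illustrates that in a disconnected merged graph, each node only
# aggregates from its own component. Not PyG code -- just the idea.

def propagate(x, edges, num_nodes):
    """h[t] = sum of x[s] over directed edges (s, t)."""
    h = [0.0] * num_nodes
    for s, t in edges:
        h[t] += x[s]
    return h

# Graph A (nodes 0-2) and graph B (nodes 0-1), processed separately
a = propagate([1.0, 2.0, 3.0], [(0, 1), (1, 2)], 3)
b = propagate([4.0, 5.0], [(0, 1)], 2)

# Same graphs merged, with B's indices offset by 3, processed as one
merged = propagate([1.0, 2.0, 3.0, 4.0, 5.0],
                   [(0, 1), (1, 2), (3, 4)], 5)

print(a + b == merged)  # True: no edges cross components, so results match
```

Because no edge connects the components, no term in any node's sum ever involves a node from the other graph; the same argument applies hop after hop, so deep GNNs are unaffected too.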

Graph-level pooling with the batch vector

After message passing, you need one embedding per graph (for graph-level tasks like molecular property prediction). The batch vector enables this:

  • global_add_pool(x, batch): sum all node embeddings within each graph
  • global_mean_pool(x, batch): average all node embeddings within each graph
  • global_max_pool(x, batch): take the max across node embeddings within each graph

The batch vector acts as a GROUP BY: it groups nodes by their source graph and applies the pooling function within each group.
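The GROUP BY analogy is easy to make concrete. Here is a toy, pure-Python version of what global_add_pool computes (PyG's actual kernel is a vectorized scatter-add on tensors; the `add_pool` name and scalar embeddings are simplifications for illustration):

```python
# Toy version of global_add_pool: group node embeddings by the batch
# vector and sum within each group. Plain Python, not PyG's kernel.

def add_pool(x, batch, num_graphs):
    pooled = [0.0] * num_graphs
    for value, gid in zip(x, batch):
        pooled[gid] += value    # accumulate into the node's source graph
    return pooled

# 3 nodes in graph 0, 2 nodes in graph 1 (scalar embeddings for brevity)
x = [1.0, 2.0, 3.0, 4.0, 5.0]
batch = [0, 0, 0, 1, 1]

print(add_pool(x, batch, 2))  # [6.0, 9.0]
```

Swapping the `+=` accumulation for a running max or a mean gives the max- and mean-pool variants with the same grouping structure.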

Large single graph: neighbor sampling

The disjoint-graph approach works for datasets of many small graphs (molecules, circuits). For a single large graph (social network with 100M nodes), you need a different strategy:

  • NeighborLoader: samples a fixed number of neighbors at each hop for a batch of target nodes. Each mini-batch is a subgraph.
  • ClusterLoader: pre-partitions the graph into clusters and samples entire clusters per batch.
  • GraphSAINTSampler: samples random subgraphs with importance weighting.

These produce mini-batches that fit in GPU memory while preserving enough neighborhood context for message passing. The tradeoff: you see a subset of each node's neighborhood, introducing sampling variance that diminishes with larger samples.
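The core idea behind NeighborLoader's per-hop fanout can be sketched in a few lines. This is a hypothetical one-hop sampler written for illustration; PyG's real implementation handles multiple hops, tensors, and feature slicing.

```python
import random

# Toy one-hop neighbor sampler in the spirit of NeighborLoader:
# for each seed node, keep at most `fanout` randomly chosen neighbors.
# A sketch of the idea, not PyG's implementation.

def sample_subgraph(adj, seeds, fanout, rng):
    nodes, edges = set(seeds), []
    for t in seeds:
        neighbors = adj.get(t, [])
        picked = rng.sample(neighbors, min(fanout, len(neighbors)))
        for s in picked:
            nodes.add(s)            # sampled neighbors join the subgraph
            edges.append((s, t))    # keep only the sampled in-edges
    return nodes, edges

# Adjacency list of a small graph: node -> in-neighbors
adj = {0: [1, 2, 3, 4], 1: [0], 5: [0, 6]}
rng = random.Random(0)

nodes, edges = sample_subgraph(adj, seeds=[0, 5], fanout=2, rng=rng)
print(len(edges))  # 4: at most two sampled in-edges per seed
```

Repeating this per hop (sampling neighbors of the newly added nodes) yields the multi-hop subgraphs that NeighborLoader feeds to the model; the fanout caps memory regardless of how many neighbors a hub node really has.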

Frequently asked questions

Why can't you batch graphs the same way as images?

Images have a fixed size (e.g., 224x224x3). You stack them into a tensor of shape [batch_size, 3, 224, 224]. Graphs have variable sizes: graph A has 10 nodes and 15 edges, graph B has 50 nodes and 80 edges. You cannot stack them into a regular tensor, and padding every graph to the size of the largest one wastes memory and complicates computation.
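A quick back-of-envelope calculation (with made-up batch sizes) shows how much a padding strategy would waste compared to disjoint batching:

```python
# Back-of-envelope cost of padding: pad every graph in a batch to the
# largest node count, versus disjoint batching, which stores only real
# nodes. The sizes below are made up for illustration.

sizes = [10, 50, 15, 12]    # nodes per graph in one batch
features = 64               # feature dimension

padded = max(sizes) * len(sizes) * features   # dense [batch, max_nodes, feat]
disjoint = sum(sizes) * features              # concatenated [total_nodes, feat]

print(padded, disjoint)       # 12800 5568
print(1 - disjoint / padded)  # ~0.56: over half the padded tensor is zeros
```

The waste grows with the size spread: one large graph in a batch forces every small graph to be padded up to it, while the disjoint merge pays only for nodes that exist.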

How does PyG batch graphs?

PyG merges all graphs in a batch into a single large disconnected graph. The node features are concatenated, edge indices are offset (so graph B's nodes start after graph A's), and a batch vector tracks which node belongs to which graph. This creates a single sparse graph that can be processed by standard GNN layers, with the batch vector enabling per-graph operations (like graph-level pooling).

What is the batch vector?

The batch vector is a 1D tensor of length total_nodes that maps each node to its source graph. If graph 0 has 10 nodes and graph 1 has 15 nodes, the batch vector is [0,0,0,...,0,1,1,1,...,1] (10 zeros followed by 15 ones). This enables per-graph operations: global_add_pool(x, batch) sums node features within each graph to produce graph-level embeddings.

Does graph mini-batching work for large single graphs?

For a single large graph (like a social network), you cannot batch multiple copies. Instead, you use neighbor sampling: sample a fixed number of neighbors at each hop and create subgraphs around target nodes. PyG's NeighborLoader implements this. Each mini-batch is a collection of subgraphs, not complete graphs. This is different from the disjoint-graph batching used for molecular datasets.

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.