The fundamental problem
In standard deep learning, a batch of 32 images is a tensor of shape [32, 3, 224, 224]. Every image has the same dimensions. In graph learning, a batch of 32 graphs might contain graphs with 10 nodes and graphs with 10,000 nodes. You cannot stack them into a regular tensor.
PyG solves this by merging: all graphs in the batch are combined into one large disconnected graph (the equivalent of stacking adjacency matrices along the diagonal), and a batch vector tracks which nodes belong to which original graph.
Graph-level batching
For tasks like molecular property prediction (graph classification), each sample is a complete graph. PyG’s DataLoader handles this automatically:
from torch_geometric.loader import DataLoader
# dataset contains many small graphs (e.g., molecules)
loader = DataLoader(dataset, batch_size=32, shuffle=True)
for batch in loader:
    # batch.x: all nodes from all 32 graphs, concatenated
    # batch.edge_index: all edges, with node indices offset
    # batch.batch: [0,0,0,...,1,1,...,31,31] mapping nodes to graphs
    out = model(batch.x, batch.edge_index, batch.batch)
    # out shape: [32, num_classes] (one prediction per graph)
    loss = criterion(out, batch.y)

PyG offsets edge indices automatically: graph 0's nodes are 0..n0-1, graph 1's nodes are n0..n0+n1-1, and so on. The batch vector tracks the mapping.
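The model call above glosses over the readout step that turns per-node features into one output per graph. A minimal plain-Python sketch of mean pooling driven by the batch vector (an illustration of what PyG's global_mean_pool does, not its actual implementation):

```python
# Mean-pool node features into one vector per graph, using the batch vector.
# Plain-Python sketch; PyG does this with a scatter-mean on the GPU.

def global_mean_pool(x, batch, num_graphs):
    """x: per-node feature vectors; batch: graph id of each node."""
    dim = len(x[0])
    sums = [[0.0] * dim for _ in range(num_graphs)]
    counts = [0] * num_graphs
    for feat, g in zip(x, batch):
        counts[g] += 1
        for j, v in enumerate(feat):
            sums[g][j] += v
    return [[s / counts[g] for s in sums[g]] for g in range(num_graphs)]

# Two graphs: graph 0 owns nodes 0-1, graph 1 owns node 2.
x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
batch = [0, 0, 1]
print(global_mean_pool(x, batch, 2))  # [[2.0, 3.0], [5.0, 6.0]]
```

Because the batch vector is the only thing separating graphs, the same pooled-readout pattern works for any batch size.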
How merging works internally
When PyG batches three graphs with 5, 3, and 7 nodes:
- Node features are concatenated: shape [15, feat_dim]
- Edge indices for graph 1 (3 nodes) are offset by 5; those for graph 2 (7 nodes) by 5 + 3 = 8
- The batch vector is [0,0,0,0,0, 1,1,1, 2,2,2,2,2,2,2]
This means the merged graph has no edges between the original graphs. Message passing within the merged graph is mathematically identical to processing each graph independently, but it leverages GPU parallelism by operating on one large sparse matrix.
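The offsetting and batch-vector construction described above can be sketched in a few lines of plain Python (a simplified stand-in for PyG's Batch.from_data_list, not its actual implementation; each graph here is just a node count plus an edge list):

```python
# Merge graphs into one disjoint graph: offset edge indices by the number
# of nodes packed so far, and record each node's graph id in the batch vector.

def merge_graphs(graphs):
    edges, batch, offset = [], [], 0
    for gid, (num_nodes, edge_list) in enumerate(graphs):
        # Shift this graph's node ids past all previously packed nodes.
        edges += [(src + offset, dst + offset) for src, dst in edge_list]
        batch += [gid] * num_nodes
        offset += num_nodes
    return edges, batch

graphs = [
    (5, [(0, 1), (1, 2)]),  # graph 0
    (3, [(0, 2)]),          # graph 1: offset by 5
    (7, [(0, 6)]),          # graph 2: offset by 5 + 3 = 8
]
edges, batch = merge_graphs(graphs)
print(edges)  # [(0, 1), (1, 2), (5, 7), (8, 14)]
print(batch)  # [0,0,0,0,0, 1,1,1, 2,2,2,2,2,2,2]
```

Note that no edge ever crosses an offset boundary, which is exactly why message passing on the merged graph matches processing each graph independently.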
Node-level batching
For tasks like node classification on a single large graph, you do not batch multiple graphs. Instead, you sample subgraphs around seed nodes using NeighborLoader:
from torch_geometric.loader import NeighborLoader
loader = NeighborLoader(
    data,                    # one large graph
    num_neighbors=[15, 10],  # up to 15 first-hop, 10 second-hop neighbors
    batch_size=1024,         # seed nodes per batch
    input_nodes=train_mask,  # sample seeds from the training set
)
for batch in loader:
    out = model(batch.x, batch.edge_index)
    # Only compute loss on seed nodes (first batch_size entries)
    loss = criterion(
        out[:batch.batch_size],
        batch.y[:batch.batch_size],
    )

Critical: the loss is computed only on the first batch_size nodes (the seeds). The remaining nodes are context that exists solely to supply messages during message passing.
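The fanout list [15, 10] means each seed pulls in up to 15 first-hop neighbors, and each of those pulls in up to 10 second-hop neighbors. A toy sketch of that hop-by-hop expansion (illustrative only; NeighborLoader's real sampler is compiled code and handles features, edge remapping, and much more):

```python
import random

# Toy neighbor sampling: expand seeds hop by hop, capping how many
# neighbors each frontier node contributes per hop.

def sample_neighbors(adj, seeds, fanouts, rng=random.Random(0)):
    nodes = set(seeds)
    frontier = list(seeds)
    for fanout in fanouts:  # e.g. [15, 10]
        next_frontier = []
        for node in frontier:
            neighbors = adj.get(node, [])
            picked = rng.sample(neighbors, min(fanout, len(neighbors)))
            next_frontier += [n for n in picked if n not in nodes]
            nodes.update(picked)
        frontier = next_frontier
    return nodes

# Tiny graph: edges 0-1, 0-2, 1-3, 2-4
adj = {0: [1, 2], 1: [0, 3], 2: [0, 4], 3: [1], 4: [2]}
print(sorted(sample_neighbors(adj, seeds=[0], fanouts=[2, 2])))  # [0, 1, 2, 3, 4]
```

The returned node set is the subgraph the model sees; only the seeds contribute to the loss.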
Variable-size batch problems
In graph-level batching, batch computation time varies with the total number of nodes and edges, not the number of graphs. A batch of 32 small molecules (100 nodes total) processes roughly 100x faster than a batch of 32 protein graphs (10,000 nodes total). This causes:
- GPU underutilization: Small batches leave the GPU idle. Large batches cause OOM errors.
- Training instability: Gradient magnitude varies with batch size, causing learning rate sensitivity.
- Inference latency spikes: In serving, a batch with one large graph can spike p99 latency by 10x.
Solutions
- Bucket by size: Sort graphs by node count and create batches from similar-sized graphs. This keeps GPU utilization consistent.
- Dynamic batching: Set a maximum total-nodes budget per batch instead of a fixed graph count. Pack as many graphs as fit within the budget.
- Padding: Pad all graphs to the maximum size in the batch. Wastes memory but makes computation predictable and enables torch.compile optimizations.
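The dynamic-batching idea can be sketched as a greedy packer over a node budget (a hypothetical helper, not part of PyG's API):

```python
# Pack graphs into batches under a total-node budget instead of a fixed
# graph count. Hypothetical helper; returns batches of dataset indices.

def pack_by_node_budget(node_counts, max_nodes):
    batches, current, current_nodes = [], [], 0
    for idx, n in enumerate(node_counts):
        # Flush the current batch if adding this graph would blow the budget.
        if current and current_nodes + n > max_nodes:
            batches.append(current)
            current, current_nodes = [], 0
        current.append(idx)
        current_nodes += n
    if current:
        batches.append(current)
    return batches

sizes = [100, 900, 50, 300, 800, 20]
print(pack_by_node_budget(sizes, max_nodes=1000))  # [[0, 1], [2, 3], [4, 5]]
```

Sorting node_counts by size before packing combines this with the bucketing idea: similar-sized graphs land in the same batch and the budget is used more fully.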
What breaks in production
- Index arithmetic errors: Manual edge index offsetting is error-prone. Always use PyG’s built-in Batch.from_data_list() instead of manual concatenation.
- Memory fragmentation: Variable-size batches cause GPU memory fragmentation over long training runs. Use torch.cuda.empty_cache() periodically or set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True.
- Multi-worker serialization: PyG Data objects are not pickle-efficient by default. For large batches with num_workers > 0, pin_memory=True and persistent_workers=True significantly reduce data loading overhead.