
Mini-Batching on Graphs: Why It's Different

In image deep learning, batching is trivial: stack tensors. In graph deep learning, batching is a design decision that affects accuracy, memory, and throughput. Here is what you need to know.


TL;DR

  • Graphs have irregular structure (different numbers of nodes and edges), so standard tensor stacking doesn't work. PyG solves this with the Batch object.
  • Graph-level batching merges multiple small graphs into one disconnected graph. Node-level batching samples subgraphs from one large graph. They use different loaders.
  • The batch vector maps each node to its source graph. This is essential for graph-level pooling and is the most common source of indexing bugs in production.
  • Variable-size batches cause GPU utilization swings. Pad or bucket graphs by size to keep batch computation consistent.

The fundamental problem

In standard deep learning, a batch of 32 images is a tensor of shape [32, 3, 224, 224]. Every image has the same dimensions. In graph learning, a batch of 32 graphs might contain graphs with 10 nodes and graphs with 10,000 nodes. You cannot stack them into a regular tensor.

PyG solves this with a technique called graph merging: all graphs in the batch become one large disconnected graph, and a batch vector tracks which nodes belong to which original graph.

Graph-level batching

For tasks like molecular property prediction (graph classification), each sample is a complete graph. PyG’s DataLoader handles this automatically:

graph_batching.py
from torch_geometric.loader import DataLoader

# dataset contains many small graphs (e.g., molecules)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

for batch in loader:
    # batch.x: all nodes from all 32 graphs, concatenated
    # batch.edge_index: all edges, with node indices offset
    # batch.batch: [0,0,0,...,1,1,...,31,31] mapping nodes to graphs

    out = model(batch.x, batch.edge_index, batch.batch)
    # out shape: [32, num_classes] (one prediction per graph)
    loss = criterion(out, batch.y)

PyG offsets edge indices automatically. Graph 0's nodes occupy indices 0..n0-1, graph 1's nodes occupy n0..n0+n1-1, and so on. The batch vector tracks the mapping.

How merging works internally

When PyG batches three graphs with 5, 3, and 7 nodes:

  • Node features are concatenated: shape [15, feat_dim]
  • Edge indices for graph 1 are offset by 5, graph 2 by 8
  • The batch vector is [0,0,0,0,0, 1,1,1, 2,2,2,2,2,2,2]

This means the merged graph has no edges between the original graphs. Message passing within the merged graph is mathematically identical to processing each graph independently, but it leverages GPU parallelism by operating on one large sparse matrix.
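The offset arithmetic above can be reproduced by hand. This is a minimal pure-Python sketch of what Batch.from_data_list does internally, for illustration only; in practice, PyG handles all of this for you:

```python
# Hand-rolled sketch of PyG-style graph merging (illustration only;
# use Batch.from_data_list in real code).
def merge_graphs(graphs):
    """graphs: list of (num_nodes, edge_list) where edge_list is [(src, dst), ...]."""
    merged_edges, batch_vector = [], []
    offset = 0
    for graph_id, (num_nodes, edges) in enumerate(graphs):
        # Shift this graph's node indices past all previous graphs' nodes.
        merged_edges += [(s + offset, d + offset) for s, d in edges]
        # One batch-vector entry per node, pointing back to the source graph.
        batch_vector += [graph_id] * num_nodes
        offset += num_nodes
    return merged_edges, batch_vector

# Three graphs with 5, 3, and 7 nodes (edge lists kept tiny for clarity).
graphs = [(5, [(0, 1)]), (3, [(0, 2)]), (7, [(1, 6)])]
edges, batch = merge_graphs(graphs)
print(edges)  # [(0, 1), (5, 7), (9, 14)]
print(batch)  # [0, 0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2]
```

Note that graph 1's edge (0, 2) becomes (5, 7) and graph 2's edge (1, 6) becomes (9, 14), matching the offsets of 5 and 8 described above.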

Node-level batching

For tasks like node classification on a single large graph, you do not batch multiple graphs. Instead, you sample subgraphs around seed nodes using NeighborLoader:

node_batching.py
from torch_geometric.loader import NeighborLoader

loader = NeighborLoader(
    data,                       # one large graph
    num_neighbors=[15, 10],
    batch_size=1024,            # seed nodes per batch
    input_nodes=train_mask,
)

for batch in loader:
    out = model(batch.x, batch.edge_index)
    # Only compute loss on seed nodes (first batch_size entries)
    loss = criterion(
        out[:batch.batch_size],
        batch.y[:batch.batch_size],
    )

Critical: loss is only computed on the first batch_size nodes (the seeds). The rest are context nodes for message passing.

Variable-size batch problems

In graph-level batching, batch computation time varies with the total number of nodes and edges, not the number of graphs. A batch of 32 small molecules (100 nodes total) is 100x faster than a batch of 32 protein graphs (10,000 nodes total). This causes:

  • GPU underutilization: Small batches leave the GPU idle. Large batches cause OOM errors.
  • Training instability: Gradient magnitude varies with batch size, causing learning rate sensitivity.
  • Inference latency spikes: In serving, a batch with one large graph can spike p99 latency by 10x.

Solutions

  • Bucket by size: Sort graphs by node count and create batches from similar-sized graphs. This keeps GPU utilization consistent.
  • Dynamic batching: Set a maximum total-nodes budget per batch instead of a fixed graph count. Pack as many graphs as fit within the budget.
  • Padding: Pad all graphs to the maximum size in the batch. Wastes memory but makes computation predictable and enables torch.compile optimizations.
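The dynamic-batching idea can be sketched in a few lines. The helper below is hypothetical (not a PyG API); a real loader would wrap this logic in a custom batch sampler:

```python
# Sketch of a node-budget batcher (hypothetical helper, not part of PyG).
def pack_by_node_budget(node_counts, max_nodes):
    """Greedily pack graph indices into batches whose total node count
    stays within max_nodes. Sorting first keeps batches size-homogeneous,
    which also covers the bucketing strategy above."""
    order = sorted(range(len(node_counts)), key=lambda i: node_counts[i])
    batches, current, current_nodes = [], [], 0
    for i in order:
        if current and current_nodes + node_counts[i] > max_nodes:
            batches.append(current)
            current, current_nodes = [], 0
        current.append(i)
        current_nodes += node_counts[i]
    if current:
        batches.append(current)
    return batches

sizes = [10, 10000, 50, 30, 9000, 20]
print(pack_by_node_budget(sizes, max_nodes=10000))
# → [[0, 5, 3, 2, 4], [1]]: the five small graphs share one batch,
#   the 10,000-node graph gets its own.
```

A batch then always costs roughly the same number of node updates, which keeps GPU utilization and p99 latency predictable.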

What breaks in production

  • Index arithmetic errors: Manual edge index offsetting is error-prone. Always use PyG’s built-in Batch.from_data_list() instead of manual concatenation.
  • Memory fragmentation: Variable-size batches cause GPU memory fragmentation over long training runs. Use torch.cuda.empty_cache() periodically or set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True.
  • Multi-worker serialization: PyG Data objects are not pickle-efficient by default. For large batches with num_workers > 0, pin_memory=True and persistent_workers=True significantly reduce data loading overhead.
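Putting the loader settings above together, a production-leaning configuration might look like this (a sketch, assuming a PyG dataset of small graphs is in scope):

```python
from torch_geometric.loader import DataLoader

loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,            # parallel collation of Data objects
    pin_memory=True,          # faster host-to-GPU transfer
    persistent_workers=True,  # avoid re-spawning workers each epoch
)
```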

Frequently asked questions

Why can't I use standard PyTorch DataLoader with graphs?

Standard DataLoader stacks tensors of the same shape. Graphs have different numbers of nodes and edges, so they can't be stacked. PyG's Batch object solves this by merging graphs into a single disconnected graph and tracking which nodes belong to which original graph via the batch vector.

What is the batch vector in PyG?

The batch vector is a tensor that maps each node to its source graph within a mini-batch. If you batch 3 graphs with 10, 15, and 12 nodes, the batch vector has 37 entries: [0,0,...,0,1,1,...,1,2,2,...,2]. Use it for graph-level pooling: scatter_mean(x, batch.batch, dim=0) gives one vector per graph.
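The pooling step can be reproduced without torch_scatter. This is a pure-Python sketch of a scatter-mean over the batch vector, for illustration only:

```python
# Pure-Python scatter-mean over a batch vector (illustration only;
# use torch_scatter.scatter_mean or PyG's global_mean_pool in practice).
def scatter_mean(x, batch, num_graphs):
    """x: list of per-node feature vectors; batch: node -> graph id."""
    dim = len(x[0])
    sums = [[0.0] * dim for _ in range(num_graphs)]
    counts = [0] * num_graphs
    for features, graph_id in zip(x, batch):
        counts[graph_id] += 1
        for j, v in enumerate(features):
            sums[graph_id][j] += v
    return [[s / c for s in row] for row, c in zip(sums, counts)]

# Two graphs: nodes 0-1 belong to graph 0, node 2 to graph 1.
x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(scatter_mean(x, [0, 0, 1], num_graphs=2))  # [[2.0, 3.0], [5.0, 6.0]]
```

Each graph's node features are averaged independently, producing one vector per graph regardless of how many nodes each graph contributed.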

How is node-level batching different from graph-level batching?

Graph-level batching (for graph classification) merges multiple small graphs into one disconnected graph using PyG's Batch. Node-level batching (for node classification on one large graph) uses NeighborLoader to sample subgraphs around seed nodes. They use different loaders and different loss computation patterns.

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.