
Scaling PyG to Billion-Node Graphs

Your enterprise graph has 2 billion nodes and 50 billion edges. PyG can train on it, but only if you rearchitect your pipeline from storage to serving. Here is the playbook.


TL;DR

  • A billion-node graph with 128-dim features needs 600+ GB of memory. No single machine handles this. You need distributed storage, sampling, and training.
  • Graph partitioning (METIS or hash-based) splits the graph across machines. Each partition stores a subset of nodes and their edges. Cross-partition edges require network communication.
  • Distributed NeighborLoader samples across partitions, fetching remote features on demand. Network latency becomes the bottleneck, not GPU compute.
  • Feature compression (quantization, dimensionality reduction) can reduce memory 4-8x with minimal accuracy loss. This is often the cheapest scaling win.

The billion-node memory wall

Let's do the math for a realistic enterprise graph:

  • 2B nodes x 128 float32 features = 1,024 GB (features alone)
  • 50B edges x 2 int64 indices = 800 GB (edge index)
  • Edge features, timestamps, masks = 200+ GB additional
  • Total: 2+ TB just for storage, before any training computation

A single A100 GPU has 80 GB of memory. Even a machine with 2 TB of CPU RAM struggles to load this graph. You need a distributed architecture.
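The arithmetic above is easy to sanity-check in a few lines. This sketch assumes float32 features and int64 edge indices (PyG's defaults) and uses decimal GB to match the figures:

```python
# Back-of-envelope memory estimate for the graph above
GB = 1_000_000_000  # decimal GB, matching the figures in the text

num_nodes, feat_dim = 2_000_000_000, 128
num_edges = 50_000_000_000

feature_bytes = num_nodes * feat_dim * 4   # float32 = 4 bytes per value
edge_index_bytes = num_edges * 2 * 8       # two int64 indices per edge

print(feature_bytes // GB)      # 1024 GB for features
print(edge_index_bytes // GB)   # 800 GB for the edge index
```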

Step 1: Graph partitioning

Split the graph across multiple machines. Two approaches:

partitioning.py
# Hash-based partitioning (simple, fast, O(1) per node)
num_partitions = 16
partition_id = node_ids % num_partitions

# METIS partitioning (locality-preserving, slower).
# PyG wraps METIS in ClusterData, which minimizes cross-partition edges
# and so preserves local neighborhoods within partitions:
from torch_geometric.loader import ClusterData
cluster_data = ClusterData(data, num_parts=num_partitions)
# Trade-off: O(V + E) preprocessing time for better locality

Hash partitioning is O(1) and trivially parallelizable. METIS partitioning preserves graph locality but requires O(V+E) preprocessing. For billion-node graphs, hash partitioning is usually the only practical option.

The critical metric is the edge cut ratio: the fraction of edges that cross partition boundaries. Hash partitioning cuts ~(1 - 1/P) of edges (93.75% with 16 partitions). METIS can reduce this to 10-30%, dramatically reducing network communication.
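Measuring the edge cut ratio of any partition assignment takes only a few lines. A minimal sketch (the `edge_cut_ratio` helper is illustrative, not a PyG API):

```python
import torch

def edge_cut_ratio(edge_index: torch.Tensor, part: torch.Tensor) -> float:
    # Fraction of edges whose two endpoints fall in different partitions
    src, dst = edge_index
    return (part[src] != part[dst]).float().mean().item()

# Hash partitioning on uniformly random edges approaches 1 - 1/P
num_nodes, num_parts = 10_000, 16
edge_index = torch.randint(num_nodes, (2, 200_000))
part = torch.arange(num_nodes) % num_parts   # hash-based assignment
print(edge_cut_ratio(edge_index, part))      # ~0.9375
```

Running the same function on a METIS assignment instead of the hash one is the quickest way to quantify how much network traffic locality-aware partitioning saves you.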

Step 2: Distributed feature storage

Node features live on the machine that owns each partition. When neighbor sampling crosses partition boundaries, features must be fetched over the network. This is the bottleneck.

  • Feature compression: Quantize float32 to float16 or int8. Reduces memory and network transfer 2-4x with less than 1% accuracy loss on most tasks.
  • Feature caching: Cache frequently accessed node features (hub nodes) on every machine. A small cache (1% of nodes) can serve 30-50% of feature lookups.
  • Dimensionality reduction: PCA or learned projections can reduce 128-dim features to 32 dims, cutting memory 4x. Train the projector on a subsample first.
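The quantization bullet above can be sketched as a per-feature affine float32-to-int8 quantizer, which gives the 4x reduction. Helper names here are illustrative, not a PyG API:

```python
import torch

def quantize_int8(x: torch.Tensor):
    # Per-feature affine quantization: float32 -> int8 (4x smaller)
    lo, hi = x.min(dim=0).values, x.max(dim=0).values
    scale = (hi - lo).clamp(min=1e-8) / 255.0
    q = ((x - lo) / scale).round().clamp(0, 255) - 128
    return q.to(torch.int8), scale, lo

def dequantize(q: torch.Tensor, scale: torch.Tensor, lo: torch.Tensor):
    # Invert the affine mapping back to float32
    return (q.float() + 128) * scale + lo

feats = torch.randn(1000, 128)
q, scale, lo = quantize_int8(feats)
recovered = dequantize(q, scale, lo)
print((feats - recovered).abs().max())  # small per-value rounding error
```

Store `q` in the feature store and dequantize after the network fetch; the scale and offset tensors are tiny (one value per feature dimension).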

Step 3: Distributed neighbor sampling

dist_sampling.py
# PyG's distributed sampling (conceptual)
# Note: torch_geometric.distributed was deprecated in PyG 2.7.
# For current best practices, see PyG distributed training tutorials.
from torch_geometric.distributed import DistNeighborLoader

# Each worker samples from its local partition
# Cross-partition neighbors trigger RPC calls
loader = DistNeighborLoader(
    graph_store=dist_graph_store,
    feature_store=dist_feature_store,
    num_neighbors=[10, 5],
    batch_size=512,
    # Sampling happens locally, features fetched remotely
)

Distributed sampling adds network latency per batch. Reduce fanout and increase batch size to amortize the overhead.

Step 4: Training at scale

With distributed sampling in place, training uses data-parallel GNN updates:

  • Each GPU receives sampled subgraphs from the distributed loader
  • Forward pass and loss computation happen locally on each GPU
  • Gradients are synchronized across GPUs via AllReduce
  • Model parameters are updated synchronously

Scale batch size linearly with GPU count and use learning rate warmup. GNN training is surprisingly sensitive to batch size changes because each batch sees a different subgraph topology.
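The linear-scaling-plus-warmup recipe from the paragraph above can be sketched as a small schedule helper (names and defaults are illustrative):

```python
def warmup_lr(step: int, base_lr: float, warmup_steps: int, world_size: int) -> float:
    # Linear scaling rule: target LR = base LR x number of GPUs,
    # ramped up linearly over the warmup steps to avoid early divergence
    target = base_lr * world_size
    if step < warmup_steps:
        return target * (step + 1) / warmup_steps
    return target

# 8 GPUs, base LR 1e-3 -> target 8e-3, reached after 100 warmup steps
print(warmup_lr(0, 1e-3, 100, 8))    # 8e-05
print(warmup_lr(100, 1e-3, 100, 8))  # 0.008
```

Set the optimizer's learning rate from this helper at each step; the warmup matters most right after you scale the batch size up with GPU count.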

What breaks at billion scale

  • Stragglers: One partition with more hub nodes takes longer to sample, stalling all other GPUs. Use load balancing based on node degree, not just node count.
  • Graph updates: Adding new nodes or edges to a distributed graph requires re-partitioning or incremental update protocols. Most teams rebuild nightly, which limits freshness.
  • Debugging: Distributed graph bugs are non-reproducible because sampling is stochastic and network-dependent. Log subgraph checksums per batch for reproducibility.
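Degree-based load balancing from the first bullet can be sketched as greedy bin packing over node degrees, assigning each node to the partition with the least accumulated work (illustrative sketch, not a PyG utility):

```python
import torch

def degree_balanced_partition(degrees: torch.Tensor, num_parts: int) -> torch.Tensor:
    # Greedy bin packing: place highest-degree nodes first into the
    # partition with the smallest accumulated degree (sampling work),
    # not the smallest node count
    load = torch.zeros(num_parts)
    part = torch.empty(len(degrees), dtype=torch.long)
    for node in torch.argsort(degrees, descending=True):
        target = torch.argmin(load)
        part[node] = target
        load[target] += degrees[node]
    return part

# Two hub nodes (degree 100 and 50) end up in different partitions,
# while the low-degree nodes pile onto the lighter one
degrees = torch.tensor([100, 1, 1, 1, 50, 1, 1, 1])
print(degree_balanced_partition(degrees, 2))
```

The Python loop is O(V log V), so run it once offline as part of the partitioning job, not per batch.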

Frequently asked questions

Can PyG handle billion-node graphs?

Yes, but not in a single process. Billion-node graphs require distributed processing: graph partitioning across machines, remote feature stores, and distributed neighbor sampling. PyG provides the primitives (DistNeighborLoader, FeatureStore, GraphStore) but you must assemble the infrastructure yourself. Note: torch_geometric.distributed was deprecated in PyG 2.7.

How much memory does a billion-node graph need?

A billion nodes with 128-dim float32 features requires 512 GB for features alone. Edge indices for 10B edges (two int64 values per edge) add another ~160 GB. Total: 600+ GB minimum, far beyond any single GPU. You need distributed storage and sampling.

Should I partition the graph or use distributed sampling?

Use distributed sampling (DistNeighborLoader) for node-level tasks on a single massive graph. Use graph partitioning (METIS, random hash) when you need to distribute training compute across GPUs. Most production systems use both: partition for storage, sample within partitions for training.
