The billion-node memory wall
Let us do the math for a real enterprise graph:
- 2B nodes x 128 float32 features = 1,024 GB (features alone)
- 50B edges x 2 int64 indices = 800 GB (edge index)
- Edge features, timestamps, masks = 200+ GB additional
- Total: 2+ TB just for storage, before any training computation
A single A100 GPU has 80 GB of memory. Even a machine with 2 TB of CPU RAM struggles to load this graph. You need a distributed architecture.
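The arithmetic above is easy to reproduce. A minimal sketch (sizes in decimal GB, float32 features and int64 edge indices as in the tally):

```python
def graph_memory_gb(num_nodes, num_edges, feat_dim,
                    feat_bytes=4, index_bytes=8):
    """Estimate raw storage for node features and the edge index."""
    features = num_nodes * feat_dim * feat_bytes / 1e9   # float32 features
    edge_index = num_edges * 2 * index_bytes / 1e9       # two int64 endpoints per edge
    return features, edge_index

features_gb, edges_gb = graph_memory_gb(2e9, 50e9, 128)
print(features_gb)  # 1024.0
print(edges_gb)     # 800.0
```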
Step 1: Graph partitioning
Split the graph across multiple machines. Two approaches:
```python
# Hash-based partitioning (simple, fast)
num_partitions = 16
partition_id = node_ids % num_partitions  # node_ids: LongTensor of node IDs

# METIS partitioning (locality-preserving, slower).
# METIS minimizes cross-partition edges, preserving local neighborhoods
# within partitions, at the cost of O(V + E) preprocessing time.
from torch_geometric.utils import to_networkx
import metis

# Sketch: `data` is a torch_geometric.data.Data object.
nx_graph = to_networkx(data, to_undirected=True)
edge_cuts, parts = metis.part_graph(nx_graph, nparts=num_partitions)
```

Hash partitioning is O(1) per node and trivially parallelizable. METIS partitioning preserves graph locality but requires O(V + E) preprocessing. For billion-node graphs, hash partitioning is usually the only practical option.
The critical metric is the edge cut ratio: the fraction of edges that cross partition boundaries. Hash partitioning cuts ~(1 - 1/P) of edges (93.75% with 16 partitions). METIS can reduce this to 10-30%, dramatically reducing network communication.
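The hash-partitioning figure follows directly from independence: two endpoints hash to the same partition with probability 1/P, so the expected cut ratio is 1 - 1/P. A quick Monte Carlo check (random endpoints, uniform hash):

```python
import random

def hash_cut_ratio(num_edges, num_partitions, seed=0):
    """Fraction of random edges whose endpoints hash to different partitions."""
    rng = random.Random(seed)
    cut = 0
    for _ in range(num_edges):
        # A uniform hash of a random node ID is uniform over partitions.
        u = rng.randrange(num_partitions)
        v = rng.randrange(num_partitions)
        cut += (u != v)
    return cut / num_edges

print(hash_cut_ratio(100_000, 16))  # ~0.9375, i.e. 1 - 1/16
```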
Step 2: Distributed feature storage
Node features live on the machine that owns each partition. When neighbor sampling crosses partition boundaries, features must be fetched over the network. This is the bottleneck.
- Feature compression: Quantize float32 to float16 or int8. Reduces memory and network transfer 2-4x with less than 1% accuracy loss on most tasks.
- Feature caching: Cache frequently accessed node features (hub nodes) on every machine. A small cache (1% of nodes) can serve 30-50% of feature lookups.
- Dimensionality reduction: PCA or learned projections can reduce 128-dim features to 32 dims, cutting memory 4x. Train the projector on a subsample first.
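The compression bullet above can be sketched with symmetric per-column int8 quantization; this is one simple scheme among many, shown here with NumPy:

```python
import numpy as np

def quantize_int8(x):
    """Symmetric per-column int8 quantization: store int8 values + float32 scales."""
    scale = np.abs(x).max(axis=0) / 127.0 + 1e-12  # one scale per feature column
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale.astype(np.float32)

def dequantize(q, scale):
    return q.astype(np.float32) * scale

feats = np.random.randn(1000, 128).astype(np.float32)
q, scale = quantize_int8(feats)
error = np.abs(dequantize(q, scale) - feats).max()
# 4x smaller over the wire: 1 byte per value instead of 4 (plus 128 scales).
```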
Step 3: Distributed neighbor sampling
```python
# PyG's distributed sampling (conceptual)
# Note: torch_geometric.distributed was deprecated in PyG 2.7.
# For current best practices, see PyG distributed training tutorials.
from torch_geometric.distributed import DistNeighborLoader

# Each worker samples from its local partition;
# cross-partition neighbors trigger RPC calls.
loader = DistNeighborLoader(
    graph_store=dist_graph_store,
    feature_store=dist_feature_store,
    num_neighbors=[10, 5],
    batch_size=512,
    # Sampling happens locally, features are fetched remotely.
)
```

Distributed sampling adds network latency to every batch. Reduce the fanout and increase the batch size to amortize the overhead.
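The amortization argument can be made concrete with a back-of-envelope cost model; every number below is hypothetical:

```python
def epoch_seconds(num_samples, batch_size, rpc_latency_s, compute_per_sample_s):
    """Per-epoch time: a fixed RPC latency per batch, plus compute per sample."""
    num_batches = num_samples / batch_size
    return num_batches * rpc_latency_s + num_samples * compute_per_sample_s

# Hypothetical numbers: 10 ms RPC latency per batch, 20 us compute per sample.
small = epoch_seconds(1_000_000, 512, 0.010, 20e-6)
large = epoch_seconds(1_000_000, 4096, 0.010, 20e-6)
# The 8x larger batch cuts the per-batch latency term 8x; compute is unchanged.
```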
Step 4: Training at scale
With distributed sampling in place, training uses data-parallel GNN updates:
- Each GPU receives sampled subgraphs from the distributed loader
- Forward pass and loss computation happen locally on each GPU
- Gradients are synchronized across GPUs via AllReduce
- Model parameters are updated synchronously
Scale batch size linearly with GPU count and use learning rate warmup. GNN training is surprisingly sensitive to batch size changes because each batch sees a different subgraph topology.
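The linear scaling rule with warmup can be sketched as a simple schedule; the base values below are hypothetical:

```python
def scaled_lr(base_lr, base_batch, batch_size, step, warmup_steps):
    """Linear LR scaling with linear warmup over the first warmup_steps."""
    target = base_lr * batch_size / base_batch  # linear scaling rule
    if step < warmup_steps:
        return target * (step + 1) / warmup_steps  # ramp up from ~0
    return target

# Hypothetical: base_lr=0.001 tuned at batch 512, scaled to 8 GPUs x 512 = 4096.
print(scaled_lr(0.001, 512, 4096, step=999, warmup_steps=1000))  # 0.008
```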
What breaks at billion scale
- Stragglers: One partition with more hub nodes takes longer to sample, stalling all other GPUs. Use load balancing based on node degree, not just node count.
- Graph updates: Adding new nodes or edges to a distributed graph requires re-partitioning or incremental update protocols. Most teams rebuild nightly, which limits freshness.
- Debugging: Distributed graph bugs are non-reproducible because sampling is stochastic and network-dependent. Log subgraph checksums per batch for reproducibility.
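The degree-based load balancing suggested for stragglers can be sketched as a greedy heuristic: assign nodes, heaviest first, to the partition with the smallest total degree so far.

```python
import heapq

def degree_balanced_partitions(degrees, num_partitions):
    """Greedily assign nodes (by descending degree) to the lightest partition."""
    heap = [(0, p) for p in range(num_partitions)]  # (total_degree, partition_id)
    heapq.heapify(heap)
    assignment = {}
    for node in sorted(range(len(degrees)), key=lambda n: -degrees[n]):
        load, p = heapq.heappop(heap)       # current lightest partition
        assignment[node] = p
        heapq.heappush(heap, (load + degrees[node], p))
    return assignment

# Skewed degree distribution: two hubs plus many low-degree nodes.
degrees = [1000, 900, 5, 5, 5, 5, 4, 4]
parts = degree_balanced_partitions(degrees, 2)
# The two hubs land in different partitions, balancing sampling work.
```

Real systems combine this with the hash or METIS partition from Step 1 rather than assigning nodes individually, but the balancing objective is the same.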