
Neighbor Sampling Strategies for Large Graphs

Your graph has 50 million nodes. Your GPU has 24 GB of memory. Neighbor sampling bridges the gap, but the default settings will waste 80% of your compute budget. Here is how to tune it.


TL;DR

  • NeighborLoader samples a fixed number of neighbors per node per layer, creating mini-batches that fit in GPU memory. This is the standard way to train GNNs on graphs above ~1M nodes.
  • Fanout (neighbors per layer) is the key tuning parameter. Start with [15, 10] for 2-layer models. Memory grows multiplicatively with each additional layer.
  • The neighborhood explosion problem: with fanout [15, 10], each seed node pulls in up to 165 nodes (15 first-hop plus 15 × 10 second-hop). A batch of 1024 seeds can touch ~170K nodes. This is the real memory bottleneck.
  • Use temporal sampling for production models: only sample neighbors whose edges existed before the prediction timestamp. Without this, you get temporal leakage.

Why you need neighbor sampling

A full-graph forward pass on a 50M-node graph with 128-dim features requires ~25 GB just for the feature matrix. Add edge indices, intermediate activations, and gradients, and you need 100+ GB. No single GPU handles this.
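That arithmetic is worth sanity-checking yourself. A quick back-of-the-envelope helper (plain Python, using the article's illustrative numbers: 50M nodes, 128-dim float32 features, a sampled batch of ~150K nodes):

```python
def feature_matrix_gb(num_nodes: int, dim: int, bytes_per_value: int = 4) -> float:
    """Size of a float32 node feature matrix in GB (1 GB = 1e9 bytes)."""
    return num_nodes * dim * bytes_per_value / 1e9

full_graph = feature_matrix_gb(50_000_000, 128)  # ~25.6 GB for features alone
batch = feature_matrix_gb(150_000, 128)          # a sampled batch: ~0.08 GB

print(f"full graph: {full_graph:.1f} GB, sampled batch: {batch:.2f} GB")
```

The gap between those two numbers, roughly 300x, is the entire case for neighbor sampling: the sampled batch leaves headroom for activations and gradients on a 24 GB GPU, while the full graph does not even fit on its own.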

Neighbor sampling solves this by training on subgraphs. For each batch of “seed” nodes (the nodes you want predictions for), it samples a fixed number of neighbors at each GNN layer, creating a small computation tree that fits in memory.

NeighborLoader basics

neighbor_loader.py
from torch_geometric.loader import NeighborLoader

# Sample 15 neighbors at layer 1, 10 at layer 2
loader = NeighborLoader(
    data,
    num_neighbors=[15, 10],  # fanout per layer
    batch_size=1024,          # seed nodes per batch
    input_nodes=train_mask,   # which nodes to train on
    shuffle=True,
    num_workers=4,
)

optimizer = torch.optim.Adam(model.parameters())

for batch in loader:
    # batch is a subgraph containing:
    # - 1024 seed nodes (the targets)
    # - ~15K 1-hop neighbors
    # - ~150K 2-hop neighbors (sampled)
    optimizer.zero_grad()
    out = model(batch.x, batch.edge_index)
    loss = criterion(out[:batch.batch_size], batch.y[:batch.batch_size])
    loss.backward()
    optimizer.step()

The key detail: only compute loss on the first batch_size nodes (the seeds). The remaining nodes are context for message passing.

The neighborhood explosion problem

Neighbor sampling has exponential blowup across layers. With fanout [15, 10] and a 2-layer model, each seed node can touch up to 15 + 15 × 10 = 165 unique nodes. A batch of 1024 seeds can reference up to ~169,000 nodes (though overlap reduces this in practice).

This exponential growth is why GNN models rarely use more than 3 layers in production. A 4-layer model with fanout [15, 10, 10, 5] means each seed potentially touches 15 × 10 × 10 × 5 = 7,500 nodes at the deepest hop alone, and over 9,000 counting every hop. A batch of 1024 seeds could reference more than 7.6M nodes, which defeats the purpose of sampling.
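The worst-case blowup for any fanout list is easy to compute, counting every hop rather than just the deepest one (ignoring neighbor overlap, which only ever shrinks the real number):

```python
def nodes_per_seed(fanout: list) -> int:
    """Worst-case unique nodes one seed can pull in, summed across all hops.

    fanout[i] is the number of neighbors sampled per node at hop i+1, as in
    NeighborLoader's num_neighbors argument. Overlap between neighborhoods
    is ignored, so this is an upper bound.
    """
    total, layer = 0, 1
    for f in fanout:
        layer *= f   # nodes added at this hop: product of fanouts so far
        total += layer
    return total

print(nodes_per_seed([15, 10]))         # 15 + 150 = 165 per seed
print(nodes_per_seed([15, 10, 10, 5]))  # 15 + 150 + 1500 + 7500 = 9165
```

Multiplying by your batch size gives the worst-case subgraph size, which is a useful first check before committing a fanout to a training run.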

Tuning fanout for your graph

  • Start low: [10, 5] gives decent accuracy and fast training. Profile memory and accuracy before increasing.
  • Asymmetric fanout: Sample more at the first hop (closest neighbors matter most) and fewer at deeper hops. [20, 10, 5] often matches or beats a uniform [12, 12, 12] while sampling fewer nodes per seed.
  • Degree-aware sampling: Nodes with 1000+ neighbors (hubs) get sampled uniformly by default, wasting quota on random edges. Use importance sampling to prioritize recent or high-weight edges.
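One cheap way to get a degree-aware effect is to truncate hub adjacency lists before the graph ever reaches the sampler. Here is a minimal sketch in plain Python; `cap_neighbors` is a hypothetical helper, not a PyG API, and "weight" stands in for whatever importance signal you have (edge recency, purchase count, etc.):

```python
def cap_neighbors(adj: dict, max_degree: int) -> dict:
    """Truncate each node's neighbor list to its max_degree highest-weight edges.

    adj maps node -> list of (neighbor, weight) pairs. Nodes at or under the
    cap are left untouched; hubs keep only their most important edges, so
    uniform sampling afterwards no longer wastes quota on random edges.
    """
    capped = {}
    for node, nbrs in adj.items():
        if len(nbrs) <= max_degree:
            capped[node] = list(nbrs)
        else:
            capped[node] = sorted(nbrs, key=lambda e: e[1], reverse=True)[:max_degree]
    return capped
```

Applied to edge tensors rather than dicts, the same idea becomes a preprocessing pass over edge_index, and it composes with any loader because the sampler never sees the pruned edges.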

Temporal neighbor sampling

In production, you are predicting the future using the past. If your graph includes edges from the future (events that have not happened yet at prediction time), your model will learn to cheat.

temporal_sampling.py
from torch_geometric.loader import NeighborLoader

# Add timestamps to edges
data.edge_time = edge_timestamps  # tensor of Unix timestamps

# Only sample edges before the prediction time
loader = NeighborLoader(
    data,
    num_neighbors=[15, 10],
    batch_size=1024,
    time_attr="edge_time",      # temporal filtering
    input_time=pred_timestamps,  # when each seed is being predicted
)

Temporal sampling ensures the model only sees edges that existed at prediction time. Without this, offline accuracy will be artificially high.
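The invariant the loader enforces can be shown with a toy filter in plain Python (this sketch assumes a strict "before" cutoff; check your PyG version's docs for whether the comparison is inclusive):

```python
def visible_edges(edges, pred_time):
    """Keep only edges that existed strictly before the prediction timestamp.

    edges: list of (src, dst, timestamp) tuples.
    This is the invariant temporal sampling enforces: no edge from the
    future of the seed's prediction time may enter the sampled subgraph.
    """
    return [(s, d, t) for (s, d, t) in edges if t < pred_time]

history = [(0, 1, 100), (0, 2, 250), (0, 3, 175)]
print(visible_edges(history, 200))  # edge (0, 2, 250) is filtered out
```

A useful production check is to assert this invariant on a few sampled batches: if any sampled edge timestamp exceeds its seed's prediction time, leakage has crept in somewhere in the pipeline.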

Alternative sampling strategies

  • ClusterGCN (ClusterLoader): Partitions the graph into clusters via METIS and trains on full clusters. Faster per-batch but loses cross-cluster edges. Best for homogeneous graphs where cluster structure is meaningful.
  • GraphSAINT: Samples subgraphs using random walks or node/edge sampling with importance weights. Better theoretical guarantees than NeighborLoader but more complex to implement.
  • ShaDow-GNN: Extracts fixed-size ego-graphs per node. Provides exact (not sampled) k-hop neighborhoods but requires more preprocessing.
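To make the GraphSAINT idea concrete, here is a toy random-walk subgraph sampler in plain Python. It is a sketch of the sampling step only; real GraphSAINT also applies importance-weighted normalization to correct for sampling bias, which is omitted here:

```python
import random

def random_walk_subgraph(adj: dict, num_roots: int, walk_length: int, seed: int = 0) -> set:
    """Sample a subgraph as the union of nodes visited by short random walks.

    adj maps node -> list of neighbor nodes. num_roots walks of length
    walk_length each start from uniformly chosen root nodes; the training
    batch is the induced subgraph over all visited nodes.
    """
    rng = random.Random(seed)
    visited = set()
    for root in rng.sample(list(adj), num_roots):
        visited.add(root)
        cur = root
        for _ in range(walk_length):
            nbrs = adj.get(cur, [])
            if not nbrs:
                break  # dead end: walk stops early
            cur = rng.choice(nbrs)
            visited.add(cur)
    return visited
```

Compared with NeighborLoader's per-seed trees, walk-based subgraphs reuse every sampled node for every node in the batch, which is where the throughput advantage comes from.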

What breaks in production

  • Stale samples: If you precompute sampled subgraphs for speed, they become stale as the graph updates. Sample fresh each epoch or implement incremental re-sampling.
  • Hub node bottleneck: Nodes with millions of edges (e.g., a popular product with 10M purchases) create enormous subgraphs regardless of fanout. Cap maximum degree or use importance-weighted sampling.
  • Worker contention: NeighborLoader with num_workers > 0 can create CPU-GPU transfer bottlenecks. Profile with pin_memory=True and tune num_workers to your CPU count.

Frequently asked questions

What is neighbor sampling in PyG?

Neighbor sampling is a technique for training GNNs on graphs too large to fit in GPU memory. Instead of loading the full graph, PyG's NeighborLoader samples a fixed number of neighbors per node per layer, creating small subgraphs (mini-batches) that fit in memory. This is essential for any graph with more than ~1M nodes.

How do I choose the right fanout for NeighborLoader?

Fanout defines how many neighbors to sample per layer. Common starting points: [15, 10] for 2-layer models or [20, 15, 10] for 3-layer models. Higher fanout = better accuracy but more memory. Profile your GPU memory and reduce fanout if you hit OOM errors. For production, [10, 5] often gives 95% of the accuracy at 3x the throughput.

What is the difference between NeighborLoader and ClusterLoader?

NeighborLoader samples a fixed number of neighbors per seed node (node-centric). ClusterLoader partitions the graph into clusters and trains on entire clusters (graph-centric). NeighborLoader gives better per-node accuracy but is slower per batch. ClusterLoader is faster but introduces cluster-boundary artifacts. Use NeighborLoader for node-level tasks and ClusterLoader when throughput matters more than accuracy.

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.