Why you need neighbor sampling
A full-graph forward pass on a 50M-node graph with 128-dim float32 features needs ~25 GB just for the feature matrix (50M x 128 x 4 bytes). Add edge indices, intermediate activations, and gradients, and you need 100+ GB, more than a single GPU typically offers.
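The 25 GB figure is simply nodes x features x bytes per element. A small helper makes it easy to redo the estimate for your own graph. It assumes float32 features and int64 COO edge indices (PyG's default `edge_index` dtype); the 1B-edge count and the 256-dim hidden layer in the example are hypothetical.

```python
def dense_gb(rows: int, cols: int, bytes_per_elem: int = 4) -> float:
    """Size of a dense rows x cols tensor in GB (10**9 bytes); float32 by default."""
    return rows * cols * bytes_per_elem / 1e9


def edge_index_gb(num_edges: int, bytes_per_elem: int = 8) -> float:
    """Size of a 2 x num_edges COO edge index in GB; int64 by default."""
    return 2 * num_edges * bytes_per_elem / 1e9


num_nodes = 50_000_000
features = dense_gb(num_nodes, 128)      # input feature matrix
edges = edge_index_gb(1_000_000_000)     # hypothetical 1B edges
hidden = dense_gb(num_nodes, 256)        # one hypothetical 256-dim activation
print(f"features {features:.1f} GB, edges {edges:.1f} GB, one hidden layer {hidden:.1f} GB")
# features 25.6 GB, edges 16.0 GB, one hidden layer 51.2 GB
```

Every hidden layer kept for the backward pass adds another activation-sized tensor, so a couple of layers plus gradients quickly pass 100 GB under these assumptions.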
Neighbor sampling solves this by training on subgraphs. For each batch of “seed” nodes (the nodes you want predictions for), it samples a fixed number of neighbors at each GNN layer, creating a small computation tree that fits in memory.
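To make the computation tree concrete, here is a simplified pure-Python sketch of fanout-based sampling over an adjacency dict. It is not PyG's implementation (which runs in compiled code and also builds the subgraph's edge list); it only tracks which nodes are first reached at each hop, drawing neighbors uniformly without replacement and expanding only newly reached nodes. The function name `sample_hops` is hypothetical.

```python
import random
from typing import Dict, List, Sequence, Set


def sample_hops(
    adj: Dict[int, List[int]],
    seeds: Sequence[int],
    fanouts: Sequence[int],
    rng: random.Random,
) -> List[Set[int]]:
    """Return the nodes first reached at each hop; hops[0] holds the seeds.

    Each frontier node keeps at most `fanout` neighbors, drawn uniformly
    without replacement; only newly reached nodes are expanded at the next hop.
    """
    visited: Set[int] = set(seeds)
    hops: List[Set[int]] = [set(seeds)]
    for fanout in fanouts:
        reached: Set[int] = set()
        for node in sorted(hops[-1]):
            nbrs = adj.get(node, [])
            picked = nbrs if len(nbrs) <= fanout else rng.sample(nbrs, fanout)
            for nbr in picked:
                if nbr not in visited:
                    visited.add(nbr)
                    reached.add(nbr)
        hops.append(reached)
    return hops


# Toy graph: hub 0 links to nodes 1..100, and each node i also links to 100 + i
adj = {0: list(range(1, 101))}
for i in range(1, 101):
    adj[i] = [0, 100 + i]

hops = sample_hops(adj, seeds=[0], fanouts=[5, 3], rng=random.Random(0))
print([len(h) for h in hops])  # [1, 5, 5]
```

Even though the hub has 100 neighbors, the first hop keeps only 5 of them, which is exactly how the fanout bounds memory regardless of degree.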
NeighborLoader basics
from torch_geometric.loader import NeighborLoader

# Sample up to 15 neighbors per seed at hop 1, then up to 10 per node at hop 2
loader = NeighborLoader(
    data,
    num_neighbors=[15, 10],       # fanout per hop
    batch_size=1024,              # seed nodes per batch
    input_nodes=data.train_mask,  # which nodes to train on (bool mask or index tensor)
    shuffle=True,
    num_workers=4,
)

# model, optimizer, and criterion are assumed to be defined elsewhere
model.train()
for batch in loader:
    # batch is a deduplicated subgraph containing:
    # - up to 1024 seed nodes (the targets), stored first
    # - up to ~15K sampled 1-hop neighbors (1024 x 15)
    # - up to ~150K sampled 2-hop neighbors (15K x 10)
    optimizer.zero_grad()
    out = model(batch.x, batch.edge_index)
    loss = criterion(out[:batch.batch_size], batch.y[:batch.batch_size])
    loss.backward()
    optimizer.step()
The key detail: compute the loss only on the first batch_size nodes (the seeds). The remaining nodes are context for message passing.
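The slicing works because the sampled subgraph is laid out with the seeds first. The following stdlib sketch of that relabeling step (hypothetical function name, independent of PyG's internals) maps global node IDs to local indices with seeds at positions 0 to batch_size - 1, so the first batch_size rows of the model output line up with the seeds.

```python
from typing import Dict, List, Sequence, Tuple


def relabel_seeds_first(
    seeds: Sequence[int],
    sampled_edges: Sequence[Tuple[int, int]],
) -> Tuple[List[int], List[Tuple[int, int]]]:
    """Assign local IDs with seeds first, then other nodes in order of appearance.

    sampled_edges holds (src, dst) pairs in global IDs; messages flow src -> dst.
    Returns (n_id, local_edges), where n_id[i] is the global ID of local node i.
    """
    local: Dict[int, int] = {}
    n_id: List[int] = []
    for node in list(seeds) + [u for edge in sampled_edges for u in edge]:
        if node not in local:
            local[node] = len(n_id)
            n_id.append(node)
    local_edges = [(local[src], local[dst]) for src, dst in sampled_edges]
    return n_id, local_edges


seeds = [42, 7]
sampled = [(3, 42), (9, 42), (5, 7), (3, 7), (11, 3)]  # hop-1 and hop-2 edges
n_id, local_edges = relabel_seeds_first(seeds, sampled)
batch_size = len(seeds)
print(n_id[:batch_size])  # [42, 7]: the rows the loss is computed on
print(local_edges)        # [(2, 0), (3, 0), (4, 1), (2, 1), (5, 2)]
```

Node 3 is shared by both seeds but appears only once in `n_id`, which is the deduplication that keeps real batches below the worst-case counts.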
The neighborhood explosion problem
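Neighbor sampling still grows exponentially across layers: it caps each node's expansion at the fanout, but the caps multiply. With fanout [15, 10] and a 2-layer model, each seed can reach up to 15 one-hop and 15 x 10 = 150 two-hop nodes, 166 including itself. A batch of 1024 seeds can therefore reference up to 1024 x 166 = 169,984 nodes (153,600 of them at hop 2), though overlap between neighborhoods reduces this in practice.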
This exponential growth is a major reason GNN models rarely use more than 3 layers in production. With a 4-layer model and fanout [15, 10, 10, 5], each seed can reach up to 15 x 10 x 10 x 5 = 7,500 nodes at the fourth hop alone, or 9,166 across all hops including itself. A batch of 1024 seeds could then reference up to about 9.4M nodes, which defeats the purpose of sampling.
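The bounds above are running products of the fanout list. A few lines of Python reproduce them and let you compare schedules before launching a job. They are worst-case counts that ignore overlap between neighborhoods, so real batches are smaller; the helper names are hypothetical.

```python
from itertools import accumulate
from operator import mul
from typing import List, Sequence


def per_hop_bound(fanouts: Sequence[int]) -> List[int]:
    """Worst-case number of nodes reached at each hop, per seed."""
    return list(accumulate(fanouts, mul))


def batch_bound(fanouts: Sequence[int], batch_size: int) -> int:
    """Worst-case nodes in one batch, seeds included, assuming no overlap."""
    return batch_size * (1 + sum(per_hop_bound(fanouts)))


for fanouts in ([15, 10], [15, 10, 10, 5], [20, 10, 5], [12, 12, 12]):
    print(fanouts, per_hop_bound(fanouts), batch_bound(fanouts, 1024))
```

Because the last hop dominates the sum, trimming the deepest fanout is usually the cheapest way to shrink a batch.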
Tuning fanout for your graph
- Start low: [10, 5] gives decent accuracy and fast training. Profile memory and accuracy before increasing.
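- Asymmetric fanout: Sample more at the first hop (closest neighbors matter most) and fewer at deeper hops. [20, 10, 5] gives the first hop 20 samples yet reaches at most 1,220 nodes per seed, versus 1,884 for [12, 12, 12], so it is often the better trade-off; confirm on your own validation set.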
- Degree-aware sampling: Neighbors of hubs (nodes with 1000+ neighbors) are sampled uniformly by default, so the few sampled slots go to random edges regardless of recency or relevance. Use weighted (importance) sampling to prioritize recent or high-weight edges.
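One standard way to do weighted sampling without replacement is the Efraimidis-Spirakis key trick: give each edge the key u^(1/w) for a uniform random u and keep the k largest keys. The sketch below pairs it with an exponential recency weight. The decay constant `tau`, the function names, and the toy hub are hypothetical, and this illustrates the technique rather than how any particular library implements it.

```python
import heapq
import math
import random
from typing import List, Optional, Sequence


def weighted_sample(
    neighbors: Sequence[int],
    weights: Sequence[float],
    k: int,
    rng: Optional[random.Random] = None,
) -> List[int]:
    """Pick up to k distinct neighbors, favoring higher weights.

    Efraimidis-Spirakis: each neighbor gets key u ** (1 / w) with
    u ~ Uniform[0, 1); the k largest keys win. Zero-weight neighbors are skipped.
    """
    rng = rng or random.Random()
    keyed = [
        (rng.random() ** (1.0 / w), n)
        for n, w in zip(neighbors, weights)
        if w > 0
    ]
    return [n for _, n in heapq.nlargest(k, keyed)]


def recency_weights(edge_times: Sequence[float], now: float, tau: float) -> List[float]:
    """Exponential decay: an edge that is tau time units old gets weight 1/e."""
    return [math.exp(-(now - t) / tau) for t in edge_times]


# Hypothetical hub: 2,000 neighbors whose edges have timestamps 0..1999
hub_neighbors = list(range(2000))
edge_times = [float(t) for t in hub_neighbors]
weights = recency_weights(edge_times, now=2000.0, tau=100.0)
picked = weighted_sample(hub_neighbors, weights, k=10, rng=random.Random(0))
print(sorted(picked))  # expected to skew toward recent (high-numbered) neighbors
```

Smaller `tau` concentrates the fanout on the newest edges; larger `tau` drifts back toward uniform sampling.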
Temporal neighbor sampling
In production, you are predicting the future using the past. If your graph includes edges from the future (events that have not happened yet at prediction time), your model will learn to cheat.
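from torch_geometric.loader import NeighborLoader

# Attach a timestamp to every edge (shape [num_edges]).
# edge_timestamps, seed_nodes, and pred_timestamps are placeholders for your own tensors.
data.edge_time = edge_timestamps  # e.g., Unix timestamps

# Only sample edges whose timestamp is <= the seed's prediction time
loader = NeighborLoader(
    data,
    num_neighbors=[15, 10],
    batch_size=1024,
    input_nodes=seed_nodes,      # nodes to predict for
    input_time=pred_timestamps,  # prediction time of each seed, aligned with input_nodes
    time_attr="edge_time",       # enables temporal filtering
)
Temporal sampling ensures the model only sees edges that existed at prediction time. Without it, offline accuracy is artificially inflated by leaked future information and will not hold up in production.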
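The core of temporal filtering is a per-edge time check before sampling. In this simplified pure-Python sketch (hypothetical names, not PyG's implementation), every hop uses the seed's prediction time as the cutoff, so no sampled edge can postdate the prediction. It does not additionally require timestamps to decrease along each path.

```python
import random
from typing import Dict, List, Sequence, Tuple

# Adjacency: node -> list of (neighbor, edge_timestamp) for incoming edges
TemporalAdj = Dict[int, List[Tuple[int, float]]]


def sample_before(adj: TemporalAdj, node: int, cutoff: float,
                  fanout: int, rng: random.Random) -> List[Tuple[int, float]]:
    """Uniformly pick up to `fanout` edges of `node` whose timestamp is <= cutoff."""
    valid = [(nbr, t) for nbr, t in adj.get(node, []) if t <= cutoff]
    return valid if len(valid) <= fanout else rng.sample(valid, fanout)


def temporal_sample(adj: TemporalAdj, seed: int, pred_time: float,
                    fanouts: Sequence[int], rng: random.Random) -> List[Tuple[int, int, float]]:
    """Multi-hop sampling where every hop uses the seed's prediction time as cutoff.

    Returns (src, dst, timestamp) edges; none can postdate pred_time.
    """
    edges: List[Tuple[int, int, float]] = []
    seen = {seed}
    frontier = [seed]
    for fanout in fanouts:
        next_frontier: List[int] = []
        for dst in frontier:
            for src, t in sample_before(adj, dst, pred_time, fanout, rng):
                edges.append((src, dst, t))
                if src not in seen:
                    seen.add(src)
                    next_frontier.append(src)
        frontier = next_frontier
    return edges


adj: TemporalAdj = {
    0: [(1, 5.0), (2, 15.0), (3, 8.0)],  # the edge from node 2 happens after t=10
    1: [(4, 2.0), (5, 20.0)],
    3: [(6, 9.0)],
}
edges = temporal_sample(adj, seed=0, pred_time=10.0, fanouts=[5, 5], rng=random.Random(0))
print(edges)  # [(1, 0, 5.0), (3, 0, 8.0), (4, 1, 2.0), (6, 3, 9.0)]
```

Nodes 2 and 5 never enter the subgraph because their only edges happen after the prediction time; an unfiltered sampler would have leaked them into training.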
Alternative sampling strategies
- ClusterGCN (ClusterData + ClusterLoader): Partitions the graph into clusters via METIS and trains on batches of whole clusters. Faster per batch, but edges between clusters that land in different batches are dropped. Best for homogeneous graphs where cluster structure is meaningful.
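- GraphSAINT: Samples whole subgraphs via random walks or node/edge sampling, then applies normalization coefficients to correct the bias that sampling introduces. This gives it stronger theoretical guarantees than plain NeighborLoader sampling, but it needs an extra pre-sampling pass to estimate the coefficients (sample_coverage in PyG's GraphSAINT samplers).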
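- ShaDow-GNN: Extracts a bounded-size subgraph around each target node (e.g., its k-hop neighborhood, optionally subsampled) and runs a GNN of any depth on that subgraph alone, decoupling model depth from receptive-field size. It keeps every edge among the extracted nodes, but extracting one subgraph per target node adds preprocessing cost.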
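The cross-cluster edge loss in ClusterGCN is easy to measure once you have a partition. This sketch assumes a precomputed node-to-cluster map (METIS would produce it in practice; the toy partition here is hand-written) and keeps only edges whose endpoints both fall in the clusters drawn for a batch. The function name is hypothetical.

```python
from typing import Dict, List, Sequence, Set, Tuple


def cluster_batch_edges(
    edges: Sequence[Tuple[int, int]],
    part: Dict[int, int],
    batch_clusters: Set[int],
) -> Tuple[List[Tuple[int, int]], int]:
    """Keep edges with both endpoints in the batch's clusters.

    Returns (kept_edges, dropped), where dropped counts edges that touch the
    batch but lead to a cluster outside it.
    """
    kept: List[Tuple[int, int]] = []
    dropped = 0
    for u, v in edges:
        u_in = part[u] in batch_clusters
        v_in = part[v] in batch_clusters
        if u_in and v_in:
            kept.append((u, v))
        elif u_in or v_in:
            dropped += 1
    return kept, dropped


# Toy partition: clusters {0, 1}, {2, 3}, {4, 5}
part = {0: 0, 1: 0, 2: 1, 3: 1, 4: 2, 5: 2}
edges = [(0, 1), (2, 3), (4, 5), (1, 2), (3, 4), (0, 5)]

print(cluster_batch_edges(edges, part, {0}))     # ([(0, 1)], 2)
print(cluster_batch_edges(edges, part, {0, 1}))  # ([(0, 1), (2, 3), (1, 2)], 2)
```

With one cluster per batch, edge (1, 2) is lost; drawing clusters 0 and 1 together recovers it, which is the idea behind Cluster-GCN's stochastic multiple-partition batching.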
What breaks in production
- Stale samples: If you precompute sampled subgraphs for speed, they become stale as the graph updates. Sample fresh each epoch or implement incremental re-sampling.
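- Hub node bottleneck: Nodes with millions of edges (e.g., a popular product with 10M purchases) still cause trouble under a fanout cap. They appear in a large share of sampled subgraphs, and any step that takes full neighborhoods (a -1 entry in num_neighbors, full-neighborhood evaluation) pulls in all of their edges. Cap maximum degree or use importance-weighted sampling.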
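- Worker contention: With num_workers > 0, NeighborLoader samples in background CPU processes. Too many workers compete with each other and the training loop for cores, while too few leave the GPU waiting on batches. Set pin_memory=True to speed host-to-GPU copies, and tune num_workers against your CPU core count by profiling.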
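Capping maximum degree can be as simple as truncating each adjacency list before building the loader. This sketch keeps the most recent edges of each node; the (neighbor, timestamp) adjacency format, the `cap_degree` name, and the `max_degree` value are assumptions for illustration.

```python
from typing import Dict, List, Tuple

# Adjacency: node -> list of (neighbor, edge_timestamp)
Adj = Dict[int, List[Tuple[int, float]]]


def cap_degree(adj: Adj, max_degree: int) -> Adj:
    """Keep at most max_degree edges per node, preferring the most recent ones."""
    capped: Adj = {}
    for node, nbrs in adj.items():
        if len(nbrs) <= max_degree:
            capped[node] = list(nbrs)
        else:
            newest_first = sorted(nbrs, key=lambda edge: edge[1], reverse=True)
            capped[node] = newest_first[:max_degree]
    return capped


# Hypothetical hub with 10,000 purchases at times 1..10000, plus a small node
adj: Adj = {
    0: [(i, float(i)) for i in range(1, 10_001)],
    1: [(0, 1.0)],
}
capped = cap_degree(adj, max_degree=1000)
print(len(capped[0]), min(t for _, t in capped[0]))  # 1000 9001.0
```

If you also use temporal sampling, cap per prediction snapshot rather than once globally: keeping only the globally newest edges can leave seeds with early prediction times without any valid neighbors.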