Neighbor sampling randomly selects a fixed number of neighbors per node at each GNN layer, preventing the exponential neighborhood expansion that makes full-graph computation intractable on large graphs. Introduced by GraphSAGE (Hamilton et al., 2017), it is the most widely used scalability technique in graph neural networks. Without neighbor sampling, a 2-layer GNN on an enterprise graph with average degree 100 would touch 10,000 nodes for each target node. With neighbor sampling (k=15 per layer), it touches 225. This makes the difference between feasible training and out-of-memory errors.
Why it matters for enterprise data
Enterprise graphs have extreme degree variance. A popular product may have millions of purchase edges, a high-volume merchant millions of transactions, and a well-connected customer thousands of interactions. Without neighbor sampling, computing the embedding for any of these high-degree nodes would require loading their entire neighborhood into memory.
Neighbor sampling bounds the computational cost per node to a constant, independent of actual degree. This makes training time predictable and memory usage controllable, which are requirements for production ML systems.
How neighbor sampling works
For a 2-layer GNN with neighbor sampling [k1, k2]:
- Start: Select a batch of target nodes (e.g., 512 customers to predict churn for)
- Hop 1 sampling: For each target node, randomly sample k1 (e.g., 15) of its neighbors
- Hop 2 sampling: For each sampled 1-hop neighbor, randomly sample k2 (e.g., 10) of its neighbors
- Compute: Run message passing inward from the outermost hop over the sampled subgraph, so the final layer produces embeddings for the target nodes
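The steps above can be sketched in plain Python. This is a toy illustration of the fan-out bound, not how production loaders work (they operate on compressed edge arrays); `adj`, `sample_neighbors`, and `sample_two_hop` are hypothetical names:

```python
import random

def sample_neighbors(adj, node, k):
    """Uniformly sample up to k neighbors of a node."""
    nbrs = adj[node]
    return list(nbrs) if len(nbrs) <= k else random.sample(nbrs, k)

def sample_two_hop(adj, targets, k1, k2):
    """Return the sampled 2-hop computation graph for a batch of targets."""
    hop1 = {t: sample_neighbors(adj, t, k1) for t in targets}
    hop2 = {n: sample_neighbors(adj, n, k2)
            for nbrs in hop1.values() for n in nbrs}
    return hop1, hop2

# Toy graph: node 0 has 100 neighbors, but we only ever touch k1 of them.
adj = {0: list(range(1, 101))}
for n in range(1, 101):
    adj[n] = [0, (n % 100) + 1]

hop1, hop2 = sample_two_hop(adj, targets=[0], k1=15, k2=10)
assert len(hop1[0]) == 15                     # bounded despite degree 100
assert all(len(v) <= 10 for v in hop2.values())
```

The key property is visible in the assertions: the work per target node is capped by k1 and k1*k2, regardless of the true degrees.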
from torch_geometric.loader import NeighborLoader

# Create a neighbor-sampled dataloader
loader = NeighborLoader(
    data,
    num_neighbors=[15, 10],       # 15 neighbors at hop 1, 10 at hop 2
    batch_size=512,               # 512 target nodes per batch
    input_nodes=data.train_mask,  # which nodes to predict on
    shuffle=True,
)
# Each batch contains:
# - batch.x: features for all sampled nodes
# - batch.edge_index: edges in the sampled subgraph
# - batch.batch_size: number of target nodes (first 512 nodes)
# - batch.n_id: original node IDs for mapping back
model = GNN(in_channels=data.num_features, hidden_channels=64, out_channels=7)

for batch in loader:
    optimizer.zero_grad()
    out = model(batch.x, batch.edge_index)
    # Only compute loss on target nodes (first batch_size nodes)
    loss = criterion(out[:batch.batch_size], batch.y[:batch.batch_size])
    loss.backward()
    optimizer.step()

NeighborLoader handles all the sampling logic. The first batch.batch_size nodes in each batch are the targets, and loss is computed only on them.
Concrete example: churn prediction at scale
A telecom company with 50M customers and 2B call/text edges:
- Full 2-hop neighborhood for a popular customer: ~1M nodes
- Sampled 2-hop neighborhood (k=[15,10]): ~150 nodes
- Mini-batch of 512 targets with sampled neighborhoods: ~50K total nodes
- GPU memory: ~2 GB (vs. ~500 GB for full neighborhoods)
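The node counts above follow directly from the sampling fan-out. A quick upper-bound calculation (before the loader deduplicates nodes sampled by multiple targets, which is what brings the batch down to roughly 50K):

```python
# Upper-bound node counts for num_neighbors = [15, 10]
k1, k2 = 15, 10
per_target = 1 + k1 + k1 * k2   # target + hop-1 + hop-2 = 166 ("~150 nodes")
per_batch = 512 * per_target    # 84,992 before deduplication

print(per_target, per_batch)
```

Overlap between neighborhoods is why real batches are smaller than this bound: in a dense customer graph, many of the 512 targets share sampled neighbors.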
Each epoch samples different random neighbors, so over multiple epochs, each node effectively sees its full neighborhood in expectation. Training completes in hours instead of being impossible.
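The "full neighborhood in expectation" claim can be quantified. If each epoch independently draws 15 of a node's 100 neighbors uniformly, the chance that a specific neighbor is never sampled shrinks geometrically with the number of epochs (a simplification that ignores within-epoch dependence between draws):

```python
# Probability a specific neighbor of a degree-100 node is never sampled,
# when each epoch selects 15 of its 100 neighbors uniformly at random.
miss_per_epoch = 1 - 15 / 100
for epochs in (1, 5, 10, 20):
    print(epochs, round(miss_per_epoch ** epochs, 3))
# After 10 epochs the miss probability is already below 20%.
```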
Limitations and what comes next
- Variance: Different random samples produce different stochastic gradients each epoch. This extra variance can slow convergence compared to full-neighbor computation, but it typically does not affect final accuracy.
- Rare neighbors matter: Uniform random sampling may miss rare but important neighbors (e.g., the one high-risk transaction among 10,000 normal ones). Importance sampling assigns higher probability to informative neighbors.
- Redundant computation: Neighboring target nodes share many sampled neighbors, leading to redundant feature lookups and message passing. Layer-wise sampling (sampling once per layer for all nodes) reduces this.
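The importance-sampling idea can be illustrated with weighted draws. This sketch uses hypothetical weights and Python's `random.choices` (which samples with replacement, a simplification relative to real samplers):

```python
import random

def sample_important(neighbors, weights, k):
    """Draw k neighbors with probability proportional to their weight.

    Sketch only: samples with replacement, and the weights themselves
    (e.g., edge risk scores) would come from the application.
    """
    return random.choices(neighbors, weights=weights, k=k)

# One flagged transaction among 10,000 normal ones. Under uniform sampling
# it almost never appears in a draw of 15; a large weight makes it likely.
neighbors = ["flagged_txn"] + [f"txn_{i}" for i in range(9999)]
weights = [1000.0] + [1.0] * 9999

sampled = sample_important(neighbors, weights, k=15)
assert len(sampled) == 15
```

With these weights, each draw picks the flagged transaction with probability 1000/10999 (about 9%), versus 1/10000 under uniform sampling, so it shows up in most sampled neighborhoods instead of almost never.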