
Neighbor Sampling: Randomly Sampling k Neighbors Per Node Per Layer

Neighbor sampling is the most widely used scalability technique for GNNs. It caps the number of neighbors each node aggregates from, preventing the exponential neighborhood explosion that makes full-graph computation intractable on large graphs.


TL;DR

  • Neighbor sampling randomly selects k neighbors per node per GNN layer (e.g., 15 at hop 1, 10 at hop 2). This caps computation at k1*k2 = 150 nodes per target, regardless of actual node degree.
  • Without sampling, a 2-layer GNN on a graph with average degree 100 touches 10,000 nodes per target. With sampling (k=15), it touches 225. This is the difference between feasible and impossible on enterprise graphs.
  • The stochastic approximation is unbiased: the expected aggregation equals the full-neighbor aggregation. The variance from random sampling acts as regularization.
  • Introduced by GraphSAGE. In PyG, NeighborLoader(data, num_neighbors=[15, 10], batch_size=512) handles everything. The first batch_size nodes in each batch are the targets.
  • Production standard: virtually all enterprise GNN deployments use neighbor sampling. It is the bridge between GNN algorithms and real-world graph sizes.

Neighbor sampling randomly selects a fixed number of neighbors per node at each GNN layer, preventing the exponential neighborhood expansion that makes full-graph computation intractable on large graphs. Introduced by GraphSAGE (Hamilton et al., 2017), it is the most widely used scalability technique in graph neural networks. Without neighbor sampling, a 2-layer GNN on an enterprise graph with average degree 100 would touch 10,000 nodes for each target node. With neighbor sampling (k=15 per layer), it touches 225. This makes the difference between feasible training and out-of-memory errors.
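The arithmetic behind these numbers is straightforward; a quick back-of-envelope sketch in plain Python, using the figures from the paragraph above:

```python
# Back-of-envelope: 2-hop receptive-field size for a 2-layer GNN.
avg_degree = 100          # average node degree in the example graph
full = avg_degree ** 2    # full 2-hop neighborhood per target node
k1, k2 = 15, 15           # sampled fan-out per layer (k=15 at both hops)
sampled = k1 * k2         # capped 2-hop neighborhood per target

print(full)     # 10000
print(sampled)  # 225
```

The full-graph cost grows with the square of the degree (and exponentially with depth), while the sampled cost is a constant chosen up front.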

Why it matters for enterprise data

Enterprise graphs have extreme degree variance. A popular product may have millions of purchase edges. A high-volume merchant has millions of transactions. A well-connected customer has thousands of interactions. Without neighbor sampling, computing the embedding for any of these high-degree nodes would require loading their entire neighborhood into memory.

Neighbor sampling bounds the computational cost per node to a constant, independent of actual degree. This makes training time predictable and memory usage controllable, which are requirements for production ML systems.

How neighbor sampling works

For a 2-layer GNN with neighbor sampling [k1, k2]:

  1. Start: Select a batch of target nodes (e.g., 512 customers to predict churn for)
  2. Hop-1 sampling: For each target node, randomly sample k1 (e.g., 15) 1-hop neighbors
  3. Hop-2 sampling: For each sampled 1-hop neighbor, randomly sample k2 (e.g., 10) 2-hop neighbors
  4. Compute: Run message passing inward on the sampled subgraph — the first GNN layer aggregates the outermost (2-hop) nodes, and the final layer produces embeddings for the targets
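As a toy illustration of the sampling in the steps above (not PyG's actual implementation — `adj`, `sample_neighbors`, and the graph here are made up), the core logic fits in a few lines of plain Python:

```python
import random

def sample_neighbors(adj, nodes, k, rng):
    """Uniformly sample up to k neighbors per node (illustrative, not PyG)."""
    sampled = {}
    for v in nodes:
        nbrs = adj.get(v, [])
        sampled[v] = list(nbrs) if len(nbrs) <= k else rng.sample(nbrs, k)
    return sampled

# Toy graph as an adjacency dict: node -> list of neighbors.
adj = {0: [1, 2, 3, 4], 1: [0, 2], 2: [0, 1, 3], 3: [0, 2], 4: [0]}
rng = random.Random(0)

targets = [0]
hop1 = sample_neighbors(adj, targets, k=2, rng=rng)    # k1 = 2 at hop 1
frontier = {u for nbrs in hop1.values() for u in nbrs}
hop2 = sample_neighbors(adj, frontier, k=2, rng=rng)   # k2 = 2 at hop 2
```

Message passing then runs on the union of targets, hop-1, and hop-2 nodes; everything else in the graph is never touched.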
neighbor_sampling.py
from torch_geometric.loader import NeighborLoader

# Create a neighbor-sampled dataloader
loader = NeighborLoader(
    data,
    num_neighbors=[15, 10],    # 15 neighbors at hop 1, 10 at hop 2
    batch_size=512,             # 512 target nodes per batch
    input_nodes=data.train_mask,  # seed (target) nodes to sample around
    shuffle=True,
)

# Each batch contains:
# - batch.x: features for all sampled nodes
# - batch.edge_index: edges in the sampled subgraph
# - batch.batch_size: number of target nodes (first 512 nodes)
# - batch.n_id: original node IDs for mapping back

# GNN is a user-defined model (e.g., a stack of SAGEConv layers)
model = GNN(in_channels=data.num_features, hidden_channels=64, out_channels=7)

for batch in loader:
    optimizer.zero_grad()
    out = model(batch.x, batch.edge_index)
    # Only compute loss on target nodes (first batch_size nodes)
    loss = criterion(out[:batch.batch_size], batch.y[:batch.batch_size])
    loss.backward()
    optimizer.step()

NeighborLoader handles all sampling logic. The first batch.batch_size nodes in each batch are the targets. Loss is computed only on targets.
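At inference time, batch.n_id maps local batch positions back to original graph node IDs. The bookkeeping is the same with or without PyG; a minimal stdlib sketch (the names mirror the batch attributes above, but the IDs and predictions are made up):

```python
# Simulate one mini-batch's bookkeeping: local index -> original node ID.
n_id = [42, 7, 301, 55, 18, 990]   # all sampled nodes; first batch_size are targets
batch_size = 2

# Model output comes back in local order; slice off the targets, then map back.
local_preds = ["churn", "stay", "?", "?", "?", "?"]  # one entry per sampled node
target_preds = {n_id[i]: local_preds[i] for i in range(batch_size)}

print(target_preds)  # {42: 'churn', 7: 'stay'}
```

Only the first batch_size entries carry predictions you should keep; the rest are support nodes sampled purely to compute the targets' embeddings.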

Concrete example: churn prediction at scale

A telecom company with 50M customers and 2B call/text edges:

  • Full 2-hop neighborhood for a popular customer: ~1M nodes
  • Sampled 2-hop neighborhood (k=[15,10]): ~150 nodes
  • Mini-batch of 512 targets with sampled neighborhoods: ~50K total nodes
  • GPU memory: ~2 GB (vs. ~500 GB for full neighborhoods)

Each epoch samples different random neighbors, so over multiple epochs, each node effectively sees its full neighborhood in expectation. Training completes in hours instead of being impossible.
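The "full neighborhood in expectation" claim can be checked numerically: the mean of a uniform random sample is an unbiased estimator of the full-neighbor mean, so averaging over many epochs converges to it. A toy check assuming mean aggregation and invented feature values:

```python
import random

rng = random.Random(0)
features = [1.0, 3.0, 5.0, 7.0, 9.0, 11.0]  # toy neighbor features for one node
full_mean = sum(features) / len(features)   # full-neighbor aggregation = 6.0

# Sample k=2 neighbors per "epoch" and average the sampled aggregations.
k, epochs = 2, 20000
est = sum(sum(rng.sample(features, k)) / k for _ in range(epochs)) / epochs

print(full_mean)                    # 6.0
print(abs(est - full_mean) < 0.1)   # True: unbiased in expectation
```

Any single epoch's estimate is noisy (that is the variance limitation discussed below), but the noise averages out rather than pulling the model toward a wrong target.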

Limitations and what comes next

  1. Variance: Different random samples produce slightly different gradients each epoch. This variance slows convergence slightly compared to full-neighbor computation but typically does not affect final accuracy.
  2. Rare neighbors matter: Uniform random sampling may miss rare but important neighbors (e.g., the one high-risk transaction among 10,000 normal ones). Importance sampling assigns higher probability to informative neighbors.
  3. Redundant computation: Neighboring target nodes share many sampled neighbors, leading to redundant feature lookups and message passing. Layer-wise sampling (sampling once per layer for all nodes) reduces this.
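Importance sampling (limitation 2) can be sketched with the stdlib: weight each neighbor by a score so rare-but-informative neighbors are drawn more often. The neighbors and scores below are invented for illustration, and random.choices samples with replacement, unlike a production sampler:

```python
import random

rng = random.Random(0)

# One node's neighbors with illustrative importance scores
# (e.g., a risk score attached to each transaction edge).
neighbors = ["txn_a", "txn_b", "txn_c", "txn_risky"]
scores = [1.0, 1.0, 1.0, 10.0]  # the rare high-risk neighbor gets more weight

k, trials = 2, 10000
hits = sum("txn_risky" in rng.choices(neighbors, weights=scores, k=k)
           for _ in range(trials))

# Uniform sampling with replacement would include "txn_risky" in roughly
# 1 - (3/4)^2 ≈ 44% of batches; weighted sampling includes it far more often.
print(hits / trials > 0.8)  # True
```

The trade-off is bias: weighted draws must be corrected (e.g., by reweighting messages) to keep the aggregation unbiased, which is why uniform sampling remains the default.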

Frequently asked questions

What is neighbor sampling in GNNs?

Neighbor sampling randomly selects a fixed number of neighbors (e.g., 15) for each node at each GNN layer, rather than using all neighbors. This prevents the exponential growth in computation that occurs when a 2-layer GNN touches all neighbors at both hops. With degree d and 2 layers, full computation touches d^2 nodes. Neighbor sampling with k=15 touches only 15*15=225 nodes regardless of degree.

Why is neighbor sampling necessary?

In enterprise graphs, nodes can have thousands of neighbors (a popular merchant has millions of transactions). A 2-layer GNN without sampling would need to load millions of nodes for a single target node's computation. Neighbor sampling caps this at a manageable number (e.g., 225 for k=15 at 2 layers), making training and inference feasible on large graphs.

Does neighbor sampling hurt accuracy?

Minimal impact. Neighbor sampling introduces variance (different random samples each epoch) but not bias (the expected aggregation is the same as full-neighbor aggregation). With k=15-25 neighbors per layer, the stochastic approximation is close to the full computation. The variance acts as regularization, sometimes even improving generalization.

How does NeighborLoader work in PyG?

NeighborLoader takes a set of target nodes and, for each, samples k1 neighbors at hop 1, k2 neighbors at hop 2, etc. It returns a mini-batch containing all sampled nodes, their features, and the edges between them. The first batch_size nodes in the batch are the target nodes. This is the standard way to do mini-batch GNN training in PyG.

What is the neighborhood explosion problem?

Without sampling, a 2-layer GNN needs all 2-hop neighbors for each target node. If average degree is 100, that is 10,000 nodes per target. For a batch of 512 targets with overlapping neighborhoods, the actual number is lower but still can reach millions. Neighbor sampling caps per-node expansion at k, making total batch size predictable and bounded.
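The overlap effect in this last answer is easy to quantify on a toy batch: the union of sampled neighborhoods is smaller than the sum of their sizes whenever targets share neighbors (all names and numbers illustrative):

```python
# Two targets whose sampled neighborhoods overlap in two nodes.
sampled = {
    "cust_1": {"n1", "n2", "n3", "n4"},
    "cust_2": {"n3", "n4", "n5", "n6"},
}

naive_total = sum(len(s) for s in sampled.values())   # counted once per target
unique_total = len(set().union(*sampled.values()))    # after deduplication

print(naive_total, unique_total)  # 8 6
```

Loaders exploit exactly this deduplication, which is why a 512-target batch with k1*k2 = 150 lands around ~50K unique nodes rather than 512 * 150.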

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.