
Subgraph Sampling: Extracting Local Neighborhoods for Scalable Training

Enterprise graphs have millions of nodes and billions of edges. Subgraph sampling makes GNN training possible on these graphs by extracting manageable mini-batches that preserve local structure.


TL;DR

  1. Subgraph sampling extracts small connected subgraphs from large graphs for mini-batch training. Each subgraph preserves local structure while fitting in GPU memory.
  2. Main methods: ClusterGCN (partition graph, train on partitions), GraphSAINT (random walk/node/edge sampling with normalization), ShaDow-GNN (k-hop extraction around targets).
  3. Necessary for enterprise data: graphs with millions of nodes and billions of edges cannot fit in GPU memory. Subgraph sampling enables training on graphs 100x larger than memory allows.
  4. Accuracy trade-off: cross-subgraph edges are lost. With proper normalization, accuracy loss is typically 1-2% compared to full-graph training, while enabling dramatically larger graphs.
  5. In PyG: ClusterLoader for partition-based sampling, NeighborLoader with large neighborhood sizes for subgraph-like behavior, or the GraphSAINT samplers (GraphSAINTNodeSampler, GraphSAINTEdgeSampler, GraphSAINTRandomWalkSampler) in torch_geometric.loader.

Subgraph sampling extracts small, connected subgraphs from a large graph to create mini-batches for GNN training, enabling message passing on graphs that are too large to process at once. Full-graph training loads all nodes, edges, and features into GPU memory simultaneously. For enterprise relational databases with millions of customers and billions of transactions, this is impossible. Subgraph sampling divides the problem into manageable pieces while preserving enough local structure for effective learning.

Why it matters for enterprise data

Enterprise scale demands it. Consider a retail company with:

  • 10M customer nodes
  • 500M transaction edges
  • 2M product nodes
  • Each node has 50-100 features

Storing all node features alone requires roughly 5 GB at float32 (12M nodes × ~100 features × 4 bytes). The adjacency structure adds another ~8 GB of int64 edge indices. A single forward pass through a 2-layer GNN with 128-dimensional hidden states generates intermediate message tensors exceeding 100 GB (500M edges × 128 dimensions × 4 bytes per layer). No single GPU can handle this.
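The arithmetic behind these figures is easy to reproduce. This is a back-of-envelope sketch, not a measurement; the dtype and feature-width assumptions (float32 features and activations, int64 edge indices, per-edge message tensors materialized by the GNN) are ours:

```python
# Rough GPU-memory estimate for the retail graph described above.
# Assumptions: float32 features/activations, int64 edge indices,
# and one materialized message tensor per edge per GNN layer.
GB = 1024 ** 3

num_nodes = 10_000_000 + 2_000_000   # customers + products
num_edges = 500_000_000              # transaction edges
num_features = 100
hidden_dim = 128

features_gb = num_nodes * num_features * 4 / GB   # raw float32 features
adjacency_gb = num_edges * 2 * 8 / GB             # int64 (src, dst) pairs
messages_gb = num_edges * hidden_dim * 4 / GB     # edge messages, one layer

print(f"features:  {features_gb:6.1f} GB")
print(f"adjacency: {adjacency_gb:6.1f} GB")
print(f"messages:  {messages_gb:6.1f} GB per layer")
```

Note that the per-edge message tensors, not the raw features, dominate: that is why shrinking the edge set per mini-batch is the lever that matters.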

Subgraph sampling reduces each mini-batch to ~10,000 nodes with their local edges, fitting comfortably in 2-4 GB of GPU memory. Training iterates over many subgraphs to cover the entire graph.

Main subgraph sampling methods

ClusterGCN

Partition the graph into clusters using graph partitioning (METIS). Each mini-batch contains one or a few clusters. Internal edges are preserved; cross-cluster edges are dropped. Simple and memory-efficient.

GraphSAINT

Sample subgraphs via random walks, random node selection, or random edge selection. Apply importance-based normalization to correct for sampling bias. More flexible than ClusterGCN and produces less biased gradients.
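To make the normalization concrete, here is a minimal pure-Python sketch of GraphSAINT-style node sampling. It is deliberately simplified: the helper names are ours, and the real method derives per-edge aggregator normalization from the same occurrence counts in addition to the per-node loss weights shown here:

```python
import random
from collections import Counter

def induced_subgraph(edges, keep):
    """Keep only edges whose endpoints are both in the sampled node set."""
    return [(u, v) for u, v in edges if u in keep and v in keep]

def loss_weights(num_nodes, sample_size, num_trials=400, seed=0):
    """Estimate each node's inclusion probability p_v by repeated sampling,
    then weight its loss by 1/p_v to debias the mini-batch gradient."""
    rng = random.Random(seed)
    counts = Counter()
    for _ in range(num_trials):
        keep = set(rng.sample(range(num_nodes), sample_size))
        counts.update(keep)
    # p_v ~= counts[v] / num_trials; clamp so unseen nodes don't divide by 0
    return {v: num_trials / max(counts[v], 1) for v in range(num_nodes)}

# Ring graph: node i connects to i+1, both directions
n = 20
edges = [(i, (i + 1) % n) for i in range(n)] + [((i + 1) % n, i) for i in range(n)]
w = loss_weights(n, sample_size=5)
# Uniform sampling => p_v = 5/20, so every weight is close to 4
```

With a biased sampler (e.g. degree-proportional or random-walk sampling), the estimated weights differ per node, and applying them is what keeps the expected mini-batch gradient aligned with the full-graph gradient.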

ShaDow-GNN

For each target node, extract its complete k-hop subgraph. Each mini-batch contains the subgraphs for a batch of target nodes. Preserves the exact computation that full-graph training would perform for each target node.
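The extraction step can be illustrated with a small BFS sketch. This is pure-Python and illustrative (the function name is ours); for real workloads PyG provides a ShaDowKHopSampler loader:

```python
from collections import deque

def k_hop_subgraph(target, k, adj):
    """BFS out to k hops from `target`, then induce the subgraph
    containing every edge between the collected nodes."""
    dist = {target: 0}
    queue = deque([target])
    while queue:
        u = queue.popleft()
        if dist[u] == k:          # don't expand past the k-hop frontier
            continue
        for v in adj.get(u, []):
            if v not in dist:
                dist[v] = dist[u] + 1
                queue.append(v)
    nodes = set(dist)
    edges = [(u, v) for u in nodes for v in adj.get(u, []) if v in nodes]
    return nodes, edges

# Path graph 0-1-2-3-4 (undirected adjacency lists)
adj = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}
nodes, edges = k_hop_subgraph(target=2, k=1, adj=adj)
# nodes == {1, 2, 3}; edges keep both directions between 1-2 and 2-3
```

Because every edge inside the k-hop ball is retained, running a k-layer GNN on this subgraph computes exactly the embedding the full graph would produce for the target node.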

subgraph_sampling.py
from torch_geometric.loader import ClusterData, ClusterLoader, NeighborLoader

# ClusterGCN: partition graph into 1000 clusters
cluster_data = ClusterData(data, num_parts=1000)
loader = ClusterLoader(cluster_data, batch_size=20, shuffle=True)
# Each batch = 20 clusters merged into one subgraph

for batch in loader:
    # batch.x: node features for this subgraph
    # batch.edge_index: edges within these clusters
    out = model(batch.x, batch.edge_index)
    loss = criterion(out[batch.train_mask], batch.y[batch.train_mask])
    loss.backward()        # standard backprop on the subgraph
    optimizer.step()
    optimizer.zero_grad()

# Alternative: NeighborLoader for subgraph-like sampling
loader = NeighborLoader(
    data,
    num_neighbors=[25, 15],  # sample 25 neighbors at hop 1, 15 at hop 2
    batch_size=512,           # 512 target nodes per batch
    input_nodes=data.train_mask,
)

for batch in loader:
    out = model(batch.x, batch.edge_index)
    loss = criterion(out[:batch.batch_size], batch.y[:batch.batch_size])

ClusterLoader partitions the graph upfront. NeighborLoader samples dynamically per batch. Both create manageable subgraphs from enterprise-scale graphs.

Limitations and what comes next

  1. Cross-subgraph edges are lost: Edges connecting nodes in different subgraphs are dropped during training. This means some structural information is never seen. Overlapping partitions or historical embeddings mitigate this.
  2. Sampling bias: Random sampling methods may over-represent high-degree nodes or under-represent rare patterns. Importance sampling (GraphSAINT) corrects this but adds complexity.
  3. Feature staleness: Some methods cache node embeddings from previous iterations for nodes not in the current subgraph. These cached embeddings become stale as model weights update, introducing approximation error.

Frequently asked questions

What is subgraph sampling?

Subgraph sampling extracts small, connected subgraphs from a large graph to create mini-batches for GNN training. Instead of processing the entire graph at once (which may not fit in memory), the model trains on subgraphs that capture local structure. Each subgraph contains a target node and its k-hop neighborhood, or a cluster of related nodes.

How is subgraph sampling different from neighbor sampling?

Neighbor sampling (NeighborLoader) starts from target nodes and samples a fixed number of neighbors per layer, creating computation trees. Subgraph sampling (ClusterLoader, GraphSAINT) extracts entire connected subgraphs that include all internal edges. Neighbor sampling is node-centric; subgraph sampling is graph-centric. Subgraph methods preserve more graph structure but may include unnecessary nodes.

Why is subgraph sampling necessary for enterprise data?

Enterprise relational databases can have millions of nodes and billions of edges. Full-graph training requires loading everything into GPU memory, which is impossible for large graphs. Subgraph sampling creates manageable mini-batches (thousands of nodes each) that fit in GPU memory while preserving enough local structure for effective message passing.

What are the main subgraph sampling methods?

ClusterGCN: partition the graph using METIS, train on one partition per mini-batch. GraphSAINT: sample subgraphs via random walks, nodes, or edges with importance-based normalization. ShaDow-GNN: extract k-hop subgraphs around target nodes. Each method trades off between computational efficiency, memory usage, and approximation quality.

Does subgraph sampling affect model accuracy?

Subgraph sampling introduces approximation error because each mini-batch sees only a fraction of the graph. Edges between subgraphs are lost. However, with proper normalization (GraphSAINT) or sufficient overlap between partitions, the accuracy loss is minimal. On most enterprise tasks, subgraph-sampled training matches full-graph training within 1-2% accuracy while enabling graphs 100x larger.

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.