Subgraph sampling extracts small, connected subgraphs from a large graph to create mini-batches for GNN training, enabling message passing on graphs that are too large to process at once. Full-graph training loads all nodes, edges, and features into GPU memory simultaneously. For enterprise relational databases with millions of customers and billions of transactions, this is impossible. Subgraph sampling divides the problem into manageable pieces while preserving enough local structure for effective learning.
Why it matters for enterprise data
Enterprise scale demands it. Consider a retail company with:
- 10M customer nodes
- 500M transaction edges
- 2M product nodes
- Each node has 50-100 features
Storing all node features alone requires ~5 GB (12M nodes × 100 float32 features). The adjacency structure adds another ~8 GB. A single forward pass through a 2-layer GNN with 128-dimensional hidden states generates intermediate per-edge tensors exceeding 100 GB. No single GPU can handle this.
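The arithmetic behind these estimates can be sketched directly (a back-of-envelope calculation assuming float32 features, int64 edge indices, and the node counts above):

```python
# Back-of-envelope memory estimates for the retail graph described above.
# Assumes float32 (4 B) features and int64 (8 B) edge indices.
num_nodes = 10_000_000 + 2_000_000   # customers + products
num_edges = 500_000_000              # transactions
feat_dim = 100                       # upper end of the 50-100 feature range
hidden_dim = 128

feature_bytes = num_nodes * feat_dim * 4
adjacency_bytes = num_edges * 2 * 8             # source + target index per edge
per_edge_messages = num_edges * hidden_dim * 4  # one layer of edge messages

gb = 1024 ** 3
print(f"features:  {feature_bytes / gb:.1f} GB")
print(f"adjacency: {adjacency_bytes / gb:.1f} GB")
print(f"messages:  {per_edge_messages / gb:.1f} GB per layer")
```

Message tensors dominate: materializing one 128-dimensional message per edge costs hundreds of gigabytes per layer, which is why sampling is unavoidable at this scale.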
Subgraph sampling reduces each mini-batch to ~10,000 nodes with their local edges, fitting comfortably in 2-4 GB of GPU memory. Training iterates over many subgraphs to cover the entire graph.
Main subgraph sampling methods
ClusterGCN
Partition the graph into clusters using graph partitioning (METIS). Each mini-batch contains one or a few clusters. Internal edges are preserved; cross-cluster edges are dropped. Simple and memory-efficient.
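The edge-dropping step can be sketched in a few lines. This is an illustrative toy: the `cluster_of` assignment stands in for what METIS would produce on a real graph.

```python
# Minimal sketch of ClusterGCN-style edge filtering: given a node-to-cluster
# assignment (in practice produced by METIS), keep only intra-cluster edges.
# The toy graph and `cluster_of` assignment here are illustrative.
edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 0)]
cluster_of = {0: 0, 1: 0, 2: 0, 3: 1, 4: 1, 5: 1}

def intra_cluster_edges(edges, cluster_of):
    """Drop edges whose endpoints fall in different clusters."""
    return [(u, v) for u, v in edges if cluster_of[u] == cluster_of[v]]

print(intra_cluster_edges(edges, cluster_of))
# → [(0, 1), (1, 2), (3, 4), (4, 5)]  -- edges (2, 3) and (5, 0) cross clusters
```

A good partitioner minimizes exactly these cross-cluster edges, so the information lost per mini-batch stays small.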
GraphSAINT
Sample subgraphs via random walks, random node selection, or random edge selection. Apply importance-based normalization to correct for sampling bias. More flexible than ClusterGCN and produces less biased gradients.
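The bias-correction idea can be sketched with inverse-probability weighting. The sampling probabilities and weights below illustrate the principle, not the exact estimators from the GraphSAINT paper:

```python
import random

# Toy sketch of GraphSAINT-style bias correction via node sampling.
# High-degree nodes are sampled more often; loss terms are reweighted by
# 1 / p_v so each node contributes equally in expectation, keeping gradient
# estimates unbiased. Illustrative, not the paper's exact estimator.
random.seed(0)
nodes = list(range(6))
degree = {0: 5, 1: 1, 2: 1, 3: 4, 4: 2, 5: 3}
total = sum(degree.values())
# Inclusion probability proportional to degree, capped at 1.
p = {v: min(1.0, 4 * degree[v] / total) for v in nodes}

def sample_subgraph():
    """Independently include each node with probability p[v]."""
    return [v for v in nodes if random.random() < p[v]]

def loss_weight(v):
    """Inverse-probability weight: rarely sampled nodes count more."""
    return 1.0 / p[v]
```

Averaged over many sampled subgraphs, each node's weighted contribution converges to 1, which is exactly the sense in which the reweighted gradients are unbiased.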
ShaDow-GNN
For each target node, extract its complete k-hop subgraph. Each mini-batch contains the subgraphs for a batch of target nodes. Preserves the exact computation that full-graph training would perform for each target node.
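The extraction step amounts to a bounded-depth BFS followed by taking the induced edges. A minimal sketch on a toy adjacency list:

```python
from collections import deque

# Sketch of ShaDow-GNN-style extraction: for one target node, collect its
# full k-hop neighborhood and the edges among those nodes. Toy graph below.
adj = {0: [1, 2], 1: [0, 3], 2: [0], 3: [1, 4], 4: [3, 5], 5: [4]}

def k_hop_subgraph(target, k, adj):
    """BFS out to depth k; return the reached nodes and the induced edges."""
    depth = {target: 0}
    queue = deque([target])
    while queue:
        u = queue.popleft()
        if depth[u] == k:
            continue  # do not expand beyond k hops
        for v in adj[u]:
            if v not in depth:
                depth[v] = depth[u] + 1
                queue.append(v)
    nodes = set(depth)
    edges = [(u, v) for u in nodes for v in adj[u] if v in nodes]
    return nodes, edges

nodes, edges = k_hop_subgraph(0, 2, adj)
print(sorted(nodes))  # → [0, 1, 2, 3]
```

Because a 2-layer GNN on this subgraph sees exactly the nodes and edges that full-graph training would use for node 0, the per-target computation is preserved, at the cost of redundant work when targets share neighborhoods.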
```python
from torch_geometric.loader import ClusterData, ClusterLoader, NeighborLoader

# ClusterGCN: partition graph into 1000 clusters
cluster_data = ClusterData(data, num_parts=1000)
loader = ClusterLoader(cluster_data, batch_size=20, shuffle=True)

# Each batch = 20 clusters merged into one subgraph
for batch in loader:
    # batch.x: node features for this subgraph
    # batch.edge_index: edges within these clusters
    out = model(batch.x, batch.edge_index)
    loss = criterion(out[batch.train_mask], batch.y[batch.train_mask])
    loss.backward()       # standard backprop
    optimizer.step()
    optimizer.zero_grad()

# Alternative: NeighborLoader for subgraph-like sampling
loader = NeighborLoader(
    data,
    num_neighbors=[25, 15],  # sample 25 neighbors at hop 1, 15 at hop 2
    batch_size=512,          # 512 target nodes per batch
    input_nodes=data.train_mask,
)

for batch in loader:
    out = model(batch.x, batch.edge_index)
    # The first batch_size nodes are the targets; the rest are sampled context
    loss = criterion(out[:batch.batch_size], batch.y[:batch.batch_size])
```

ClusterLoader partitions the graph upfront; NeighborLoader samples dynamically per batch. Both create manageable subgraphs from enterprise-scale graphs.
Limitations and what comes next
- Cross-subgraph edges are lost: Edges connecting nodes in different subgraphs are dropped during training. This means some structural information is never seen. Overlapping partitions or historical embeddings mitigate this.
- Sampling bias: Random sampling methods may over-represent high-degree nodes or under-represent rare patterns. Importance sampling (GraphSAINT) corrects this but adds complexity.
- Feature staleness: Some methods cache node embeddings from previous iterations for nodes not in the current subgraph. These cached embeddings become stale as model weights update, introducing approximation error.
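The staleness mechanism can be made concrete with a minimal cache sketch, as used by historical-embedding methods such as GNNAutoScale. The class and method names here are illustrative, not any library's actual API:

```python
# Minimal sketch of a historical-embedding cache: out-of-batch neighbors
# reuse embeddings computed in earlier iterations, and their staleness
# grows until the node is next included in a mini-batch.
# Names and structure are illustrative, not a specific library's API.
class EmbeddingCache:
    def __init__(self):
        self.store = {}  # node id -> (embedding, step it was computed at)

    def update(self, node, embedding, step):
        """Record a freshly computed embedding for an in-batch node."""
        self.store[node] = (embedding, step)

    def lookup(self, node, step):
        """Return a cached embedding and how many steps old it is."""
        embedding, cached_step = self.store[node]
        return embedding, step - cached_step

cache = EmbeddingCache()
cache.update(7, [0.1, 0.2], step=3)
emb, staleness = cache.lookup(7, step=10)
print(staleness)  # → 7
```

The approximation error these methods tolerate is governed by that staleness gap: the more steps since a cached embedding was refreshed, the further it drifts from what the current weights would produce.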