
ClusterGCNConv: Scaling GNNs Through Graph Partitioning

ClusterGCNConv solves the scalability problem differently from SAGEConv: instead of sampling neighbors per node, it partitions the graph into clusters and trains on subgraphs. This is more compute-efficient and enables training deep GCNs on million-node graphs.


TL;DR

  • ClusterGCN partitions the graph into clusters (via METIS or similar), then trains a GCN on clusters as mini-batches. Memory is bounded by cluster size, not graph size.
  • More compute-efficient than SAGEConv's per-node sampling: all nodes in a cluster share computation. But inter-cluster edges are lost within each batch.
  • Fix for edge loss: randomly combine multiple small clusters per mini-batch. This recovers inter-cluster edges and reduces gradient bias.
  • Use ClusterGCN for large homogeneous graphs (100K+ nodes) where full-batch training is infeasible. The partitioning strategy works with any GNN layer.
  • KumoRFM uses production-optimized graph partitioning and sampling strategies that combine the efficiency of ClusterGCN with the accuracy of full-graph training.

Original Paper

Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks

Chiang et al. (2019). KDD 2019


What ClusterGCNConv does

ClusterGCN follows a two-phase approach:

  1. Partition: Split the graph into K clusters using METIS or spectral clustering. This is done once before training.
  2. Train: At each step, sample one or more clusters, extract the subgraph (nodes + intra-cluster edges), and run standard GCN on the subgraph.

Because each cluster is small, the entire subgraph fits in GPU memory. The model trains on different clusters in each epoch, eventually seeing all nodes and most edges.
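The two phases can be illustrated with a toy pure-Python sketch. The graph, the contiguous partition, and the `subgraph` helper below are all hypothetical stand-ins; a real pipeline would use METIS to minimize the number of cut edges:

```python
# Toy illustration of the two-phase idea (hypothetical contiguous
# partition stands in for METIS).
edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 0), (1, 4)]
num_nodes, num_parts = 6, 2

# Phase 1: assign each node to a cluster (done once, before training).
cluster = {n: 0 if n < 3 else 1 for n in range(num_nodes)}

# Phase 2: for a sampled cluster, keep only intra-cluster edges.
def subgraph(c):
    nodes = [n for n in range(num_nodes) if cluster[n] == c]
    intra = [(u, v) for u, v in edges if cluster[u] == c and cluster[v] == c]
    return nodes, intra

print(subgraph(0))  # ([0, 1, 2], [(0, 1), (1, 2)])
```

Note that the three edges crossing the partition boundary, such as `(2, 3)`, appear in neither subgraph; this is exactly the edge loss the stochastic multi-cluster batching discussed below compensates for.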

PyG implementation

cluster_gcn.py
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.loader import ClusterData, ClusterLoader

# Step 1: Partition graph into clusters
cluster_data = ClusterData(data, num_parts=1000)
loader = ClusterLoader(cluster_data, batch_size=20, shuffle=True)
# batch_size=20 means 20 clusters per mini-batch

class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden)
        self.conv2 = GCNConv(hidden, out_channels)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.5, training=self.training)
        return self.conv2(x, edge_index)

model = GCN(dataset.num_features, 256, dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()
for batch in loader:
    # batch is a single subgraph merging ~20 clusters
    optimizer.zero_grad()
    out = model(batch.x, batch.edge_index)
    loss = F.cross_entropy(out[batch.train_mask], batch.y[batch.train_mask])
    loss.backward()
    optimizer.step()

ClusterData partitions once; ClusterLoader samples cluster groups each step. The GCN model itself is unchanged; the scaling comes from the data loader.

When to use ClusterGCNConv

  • Large homogeneous graphs. Reddit (232K nodes), OGB-Products (2.4M nodes) and similar large single-type graphs.
  • Deep GCN training. ClusterGCN was originally motivated by training deep (5+ layer) GCNs. The bounded memory per batch makes deep models feasible.
  • When compute efficiency matters. SAGEConv's neighbor sampling has redundant computation for overlapping neighborhoods. ClusterGCN avoids this by processing entire subgraphs.

When not to use ClusterGCNConv

  • When inter-cluster edges are critical. Graph structures where important connections span cluster boundaries (e.g., fraud networks with long-range connections) lose signal from cluster partitioning.
  • Inductive learning. The partitioning is fixed. New nodes require re-partitioning. SAGEConv handles new nodes naturally.

How KumoRFM builds on this

KumoRFM's training infrastructure combines insights from both ClusterGCN and neighbor sampling:

  • Intelligent partitioning that respects graph structure and minimizes edge loss across partition boundaries
  • Hybrid sampling that combines cluster-level efficiency with node-level flexibility for new entities
  • Production-scale optimization that trains on billion-node graphs with bounded memory and predictable latency

Frequently asked questions

What is ClusterGCNConv in PyTorch Geometric?

ClusterGCNConv implements the Cluster-GCN approach from Chiang et al. (2019). It first partitions the graph into clusters using an algorithm like METIS, then trains GCN on each cluster (or small groups of clusters) as mini-batches. This bounds memory usage and enables training on graphs that do not fit in GPU memory.

How does ClusterGCN differ from SAGEConv's neighbor sampling?

SAGEConv samples neighbors per node (node-wise sampling), which can lead to redundant computation across nodes in the same neighborhood. ClusterGCN partitions the entire graph into clusters (subgraph-wise sampling), processing all nodes in a cluster together. ClusterGCN is more compute-efficient but may miss inter-cluster edges.
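The redundancy difference can be made concrete with a toy pure-Python count (the graph, `node_wise_cost`, and `subgraph_cost` are illustrative constructions, not PyG APIs). Node-wise sampling expands a separate neighborhood tree per target node, so shared neighbors are recomputed; subgraph-wise training computes each node once per layer:

```python
# Toy cost count for a 2-layer GNN over three target nodes.
adj = {0: [1, 2], 1: [0, 2], 2: [0, 1], 3: [2]}  # small toy graph

def node_wise_cost(targets, hops=2):
    # Node-wise sampling: count nodes visited across all per-target trees.
    total = 0
    for t in targets:
        frontier = [t]
        for _ in range(hops):
            total += len(frontier)
            frontier = [n for f in frontier for n in adj[f]]
        total += len(frontier)
    return total

def subgraph_cost(targets, hops=2):
    # Subgraph-wise (ClusterGCN): each node computed once per layer.
    return len(targets) * (hops + 1)

print(node_wise_cost([0, 1, 2]))  # 21
print(subgraph_cost([0, 1, 2]))   # 9
```

Even on this four-node graph, per-target expansion touches more than twice as many node states as processing the whole subgraph at once; the gap grows with depth and neighborhood overlap.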

What is the stochastic partition problem?

When training on a single cluster, edges between clusters are lost, which biases gradients. ClusterGCN addresses this by randomly combining multiple small clusters per mini-batch, recovering some inter-cluster edges. The more clusters per batch, the less bias, at the cost of larger batch size.
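A toy pure-Python measurement shows the recovery effect (the ring graph, partition, and `edges_kept` helper are hypothetical): a batch keeps every edge whose two endpoints both fall in one of its chosen clusters, so combining more clusters keeps more inter-cluster edges.

```python
# Ring of 8 nodes plus two chords, split into 4 clusters of 2 nodes.
edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 0),
         (0, 4), (2, 6)]
cluster = {n: n // 2 for n in range(8)}

def edges_kept(chosen_clusters):
    # An edge survives only if both endpoints land in the mini-batch.
    return [(u, v) for u, v in edges
            if cluster[u] in chosen_clusters and cluster[v] in chosen_clusters]

print(len(edges_kept({0})))           # 1 edge with a single cluster
print(len(edges_kept({0, 1, 2, 3})))  # all 10 edges when combining all four
```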

How many clusters should I partition into?

Partition into 1000-10000 clusters for large graphs and use 10-20 clusters per mini-batch. The number depends on GPU memory: each mini-batch should fit comfortably. More clusters means smaller batches but more inter-cluster edge loss.
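The sizing is back-of-envelope arithmetic. A sketch with hypothetical numbers in the spirit of the OGB-Products scale mentioned above:

```python
# Hypothetical sizing: 2.4M nodes, 1500 partitions, 20 clusters per batch.
num_nodes = 2_400_000
num_parts = 1_500
clusters_per_batch = 20

nodes_per_cluster = num_nodes // num_parts
nodes_per_batch = nodes_per_cluster * clusters_per_batch
print(nodes_per_cluster, nodes_per_batch)  # 1600 32000
```

A ~32K-node subgraph (plus features and activations) is what must fit in GPU memory per step; tune `num_parts` and `clusters_per_batch` until it does.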

Can ClusterGCN work with layers other than GCN?

Yes. The clustering strategy is independent of the GNN layer. You can use GATConv, SAGEConv, or any other layer within each cluster. PyG's ClusterData and ClusterLoader handle the partitioning; the layer inside is your choice.
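Because each mini-batch is just a subgraph `(x, edge_index)`, any module with that signature plugs in. As a minimal sketch, the hypothetical `MeanConv` below implements mean-neighbor aggregation in plain PyTorch and could stand in for GATConv or SAGEConv inside the cluster loop:

```python
import torch

class MeanConv(torch.nn.Module):
    """Hypothetical mean-aggregation layer; any (x, edge_index) module works."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        src, dst = edge_index
        out = torch.zeros_like(x)
        out.index_add_(0, dst, x[src])                   # sum neighbor features
        deg = torch.zeros(x.size(0)).index_add_(
            0, dst, torch.ones(dst.size(0))).clamp(min=1)
        return self.lin(out / deg.unsqueeze(-1))         # mean, then transform

# One cluster subgraph: 4 nodes, 8 features, a directed 4-cycle.
x = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
out = MeanConv(8, 16)(x, edge_index)
print(out.shape)  # torch.Size([4, 16])
```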

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.