How torch.compile helps GNNs
GNN forward passes consist of many small operations: sparse matrix multiplications, scatter-add aggregations, activation functions, and dropout. Each operation launches a separate GPU kernel with Python overhead between them. On a 2-layer GCN with 128-dim hidden channels, Python overhead can be 40% of total time.
torch.compile traces the computation graph and fuses these operations into fewer, larger kernels. It also eliminates Python interpreter overhead entirely for the traced portion.
Basic usage
import torch
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
class GCN(torch.nn.Module):
    def __init__(self, in_ch, hidden_ch, out_ch):
        super().__init__()
        self.conv1 = GCNConv(in_ch, hidden_ch)
        self.conv2 = GCNConv(hidden_ch, out_ch)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.5, training=self.training)
        return self.conv2(x, edge_index)
model = GCN(1433, 128, 7)

# One line to compile
model = torch.compile(model)

# First forward pass triggers compilation (30-120s)
# Subsequent passes are 30-35% faster

torch.compile is a single line. The compilation cost is amortized over the full training run.
Compilation modes
- mode="default": Best for most GNN training. Balances compilation time with speedup. Handles most PyG layers correctly.
- mode="reduce-overhead": Uses CUDA graphs to eliminate kernel launch overhead. Better for small models where launch overhead dominates. Requires static shapes.
- mode="max-autotune": Tries many kernel implementations and picks the fastest. Compilation takes 10-30 minutes but gives an additional 5-10% speedup. Use for production inference kernels.
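As a minimal sketch of selecting a mode, here is each option passed explicitly (a plain nn.Linear stands in for a GNN; the mode strings are the ones described above):

```python
import torch

# Stand-in model; any nn.Module is compiled the same way.
model = torch.nn.Linear(128, 7)

# Pick a mode based on the trade-offs above:
compiled_default = torch.compile(model)                          # mode="default"
compiled_low_ovh = torch.compile(model, mode="reduce-overhead")  # CUDA graphs, static shapes
compiled_tuned = torch.compile(model, mode="max-autotune")       # slow to compile, fastest kernels
```

Compilation itself is lazy in every mode: nothing happens until the first forward pass.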
Common pitfalls with PyG
1. Dynamic shapes trigger recompilation
Each time torch.compile sees a new tensor shape, it recompiles. With NeighborLoader, every batch has a different number of sampled nodes, triggering recompilation every batch.
# Problem: different batch sizes trigger recompilation
# Solution: pad batches to fixed sizes
def pad_batch(batch, max_nodes=10000):
    """Pad batch to a fixed size for torch.compile compatibility."""
    n = batch.x.size(0)
    if n < max_nodes:
        # Match the dtype and device of the existing features
        pad = torch.zeros(max_nodes - n, batch.x.size(1),
                          dtype=batch.x.dtype, device=batch.x.device)
        batch.x = torch.cat([batch.x, pad])
        # Mark padded nodes so they're ignored in the loss
        batch.pad_mask = torch.cat([
            torch.ones(n, dtype=torch.bool, device=batch.x.device),
            torch.zeros(max_nodes - n, dtype=torch.bool, device=batch.x.device),
        ])
    return batch
# Or: use dynamic=True (PyTorch 2.1+)
model = torch.compile(model, dynamic=True)

dynamic=True tells the compiler to generate shape-generic code. It's slower than static shapes but avoids recompilation.
2. SparseTensor compatibility
PyG’s SparseTensor format (used by some layers for efficiency) is not fully compatible with torch.compile. Stick to edge_index (COO format) when using compilation.
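For reference, a minimal pure-torch sketch of the COO layout that edge_index uses (no PyG dependency; the adjacency matrix is illustrative):

```python
import torch

# edge_index is a [2, num_edges] long tensor: row 0 = source nodes,
# row 1 = target nodes. One way to build it from a dense adjacency:
adj = torch.tensor([
    [0, 1, 0],
    [1, 0, 1],
    [0, 1, 0],
])
edge_index = adj.nonzero().t().contiguous()  # shape [2, 4]
```

Layers that take edge_index directly trace cleanly, whereas SparseTensor's custom class hits unsupported operations in the compiler.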
3. Heterogeneous models
HeteroConv and typed dispatch create dynamic control flow that torch.compile struggles with. For heterogeneous models, compile individual message passing functions rather than the full model.
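One hedged sketch of the per-submodule approach, using plain Linear layers as stand-ins for per-type message passing modules (HeteroModel, user_conv, and item_conv are hypothetical names, not PyG APIs):

```python
import torch
import torch.nn as nn

class HeteroModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.user_conv = nn.Linear(16, 16)  # stand-in for a user-type conv
        self.item_conv = nn.Linear(16, 16)  # stand-in for an item-type conv

    def forward(self, x_user, x_item):
        # The typed dispatch stays in eager mode; only the inner
        # per-type computations run as compiled code.
        return self.user_conv(x_user), self.item_conv(x_item)

model = HeteroModel()
# Compile the submodules individually instead of the full model
model.user_conv = torch.compile(model.user_conv)
model.item_conv = torch.compile(model.item_conv)
```

This keeps the dynamic control flow out of the traced region while still fusing the dense math inside each message passing function.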
Benchmark results
- GCNConv (Cora, 2 layers): 34% faster training, 48% faster inference. The highest gains because GCN is the simplest layer with the most fusible operations.
- GATConv (Cora, 2 layers): 28% faster training, 42% faster inference. Attention computation limits fusion opportunities.
- SAGEConv (Reddit, 3 layers + NeighborLoader): 22% faster training (dynamic shapes reduce gains), 38% faster inference with static padding.
How KumoRFM optimizes performance
KumoRFM goes beyond torch.compile with custom CUDA kernels:
- Fused sparse attention kernels for relational graph transformers
- Optimized temporal neighbor sampling on GPU (not CPU)
- Custom scatter kernels for heterogeneous aggregation
- Automatic mixed precision with per-layer precision tuning
These optimizations deliver 3-5x throughput over compiled vanilla PyG on the same hardware, enabling real-time inference on graphs that would otherwise require batch processing.