
Speeding Up PyG with torch.compile (30-35% Gains)

torch.compile fuses kernels and eliminates Python overhead. For GNNs, this means 30-35% faster training with one line of code, if you navigate the compatibility landmines.


TL;DR

  • torch.compile gives a 30-35% training speedup and a 40-50% inference speedup for standard GNN layers (GCNConv, SAGEConv, GATConv, GINConv).
  • The first compilation pass takes 30-120 seconds. After that, every forward/backward pass is faster. Use torch.compile for training runs longer than 5 minutes.
  • Dynamic shapes (variable batch sizes, irregular graphs) trigger recompilation. Pad batches to fixed sizes or compile with dynamic=True to handle this.
  • Heterogeneous layers and some PyG-specific operations (scatter with variable output sizes, SparseTensor) have partial support. Test before deploying.
  • KumoRFM uses highly optimized CUDA kernels that go beyond torch.compile, including custom sparse attention and temporal sampling kernels tuned for relational graph transformers.

How torch.compile helps GNNs

GNN forward passes consist of many small operations: sparse matrix multiplications, scatter-add aggregations, activation functions, and dropout. Each operation launches a separate GPU kernel with Python overhead between them. On a 2-layer GCN with 128-dim hidden channels, Python overhead can be 40% of total time.

torch.compile traces the computation graph and fuses these operations into fewer, larger kernels. It also eliminates Python interpreter overhead entirely for the traced portion.

Basic usage

compile_gnn.py
import torch
from torch_geometric.nn import GCNConv
import torch.nn.functional as F

class GCN(torch.nn.Module):
    def __init__(self, in_ch, hidden_ch, out_ch):
        super().__init__()
        self.conv1 = GCNConv(in_ch, hidden_ch)
        self.conv2 = GCNConv(hidden_ch, out_ch)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.5, training=self.training)
        return self.conv2(x, edge_index)

model = GCN(1433, 128, 7)

# One line to compile
model = torch.compile(model)

# First forward pass triggers compilation (30-120s)
# Subsequent passes are 30-35% faster

Enabling torch.compile is a one-line change, and the compilation cost is amortized over the full training run.

Compilation modes

  • mode="default": Best for most GNN training. Balances compilation time with speedup. Handles most PyG layers correctly.
  • mode="reduce-overhead": Uses CUDA graphs to eliminate kernel launch overhead. Better for small models where launch overhead dominates. Requires static shapes.
  • mode="max-autotune": Tries many kernel implementations and picks the fastest. Compilation takes 10-30 minutes but gives an additional 5-10% speedup. Use for production inference kernels.

Common pitfalls with PyG

1. Dynamic shapes trigger recompilation

Each time torch.compile sees a new tensor shape, it recompiles. With NeighborLoader, every batch has a different number of sampled nodes, triggering recompilation every batch.

fix_dynamic_shapes.py
# Problem: different batch sizes trigger recompilation
# Solution: pad batches to fixed sizes

def pad_batch(batch, max_nodes=10000):
    """Pad batch to a fixed size for torch.compile compatibility."""
    n = batch.x.size(0)
    if n < max_nodes:
        # Match the dtype and device of the features being padded.
        pad = torch.zeros(max_nodes - n, batch.x.size(1),
                          dtype=batch.x.dtype, device=batch.x.device)
        batch.x = torch.cat([batch.x, pad])
        # Mark padded nodes so they're ignored in the loss
        batch.pad_mask = torch.cat([
            torch.ones(n, dtype=torch.bool, device=batch.x.device),
            torch.zeros(max_nodes - n, dtype=torch.bool, device=batch.x.device),
        ])
    return batch

# Or: use dynamic=True (PyTorch 2.1+)
model = torch.compile(model, dynamic=True)

dynamic=True tells the compiler to generate shape-generic code. It's slower than static shapes but avoids recompilation.

2. SparseTensor compatibility

PyG’s SparseTensor format (used by some layers for efficiency) is not fully compatible with torch.compile. Stick to edge_index (COO format) when using compilation.
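For reference, edge_index is simply a [2, num_edges] LongTensor in COO layout. A minimal sketch building one from a dense adjacency; if you hold a torch_sparse SparseTensor, its .coo() method returns the row/col index vectors you can stack the same way.

```python
import torch

# Dense adjacency for a 4-node path graph (illustration only).
adj = torch.tensor([
    [0, 1, 0, 0],
    [1, 0, 1, 0],
    [0, 1, 0, 1],
    [0, 0, 1, 0],
])

# edge_index is the COO form PyG layers expect: shape [2, num_edges],
# row 0 = source nodes, row 1 = target nodes.
edge_index = adj.nonzero().t().contiguous()
print(edge_index.shape)  # torch.Size([2, 6])
```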

3. Heterogeneous models

HeteroConv and typed dispatch create dynamic control flow that torch.compile struggles with. For heterogeneous models, compile individual message passing functions rather than the full model.
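A sketch of the per-function approach, using plain Linear layers as stand-ins for the per-edge-type convolutions (with PyG's HeteroConv, the analogous ModuleDict is typically hetero_conv.convs): each message function is compiled individually while the Python-level dispatch over edge types stays eager.

```python
import torch

# Stand-in per-edge-type message functions, keyed like PyG edge types.
convs = torch.nn.ModuleDict({
    "user__buys__item": torch.nn.Linear(16, 32),
    "item__bought_by__user": torch.nn.Linear(16, 32),
})

# Compile each message-passing function individually; the dynamic
# control flow over edge types is never traced.
for key in list(convs):
    convs[key] = torch.compile(convs[key])

x = torch.randn(10, 16)
outs = {key: conv(x) for key, conv in convs.items()}
```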

Benchmark results

  • GCNConv (Cora, 2 layers): 34% faster training, 48% faster inference. The highest gains because GCN is the simplest layer with the most fusible operations.
  • GATConv (Cora, 2 layers): 28% faster training, 42% faster inference. Attention computation limits fusion opportunities.
  • SAGEConv (Reddit, 3 layers + NeighborLoader): 22% faster training (dynamic shapes reduce gains), 38% faster inference with static padding.

How KumoRFM optimizes performance

KumoRFM goes beyond torch.compile with custom CUDA kernels:

  • Fused sparse attention kernels for relational graph transformers
  • Optimized temporal neighbor sampling on GPU (not CPU)
  • Custom scatter kernels for heterogeneous aggregation
  • Automatic mixed precision with per-layer precision tuning

These optimizations deliver 3-5x throughput over compiled vanilla PyG on the same hardware, enabling real-time inference on graphs that would otherwise require batch processing.

Frequently asked questions

Does torch.compile work with PyTorch Geometric?

Yes, as of PyG 2.5+. Most GNN layers (GCNConv, SAGEConv, GATConv, GINConv) are compatible with torch.compile. Some dynamic operations (like variable-size scatter) need adjustments such as static padding or dynamic=True. Heterogeneous layers have partial support.

How much speedup does torch.compile give for GNNs?

Typical speedups are 30-35% for training and 40-50% for inference on standard GNN layers. The gains come from kernel fusion (combining multiple small operations into one GPU kernel) and eliminating Python overhead. Gains are larger for small models where Python overhead is a bigger fraction of total time.

Why does torch.compile fail on some PyG operations?

torch.compile uses graph tracing, which by default specializes on tensor shapes. PyG operations like scatter (variable output size), dynamic edge dropout, and heterogeneous batching produce dynamic shapes that break or repeatedly invalidate the trace. Pad inputs to static shapes, or compile with dynamic=True to generate shape-generic code.

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.