
Speeding Up Graph Learning with torch.compile

A practical guide to optimizing PyG-based GNN training: eliminating re-compilations, graph breaks, and host-device synchronizations for 30-35% speedups.

Akihiro Nitta and Matthias Fey
01

Why Compilation Matters for Graph Learning

PyTorch 2.0 introduced torch.compile, a JIT compiler that fuses GPU kernels, eliminates Python overhead, and can deliver significant speedups on standard deep learning workloads. But graph neural networks (GNNs) present unique challenges that make compilation far less straightforward than calling torch.compile(model) and moving on.

The core difficulty is irregular data. In GNNs, every mini-batch has a different number of nodes, edges, and neighbors. Neighbor sampling produces subgraphs of varying shapes across iterations. This variability triggers re-compilations, graph breaks, and host-device synchronizations that can erase any performance gains from compilation.

At Kumo, the production stack for Relational Deep Learning (RDL) constructs a large-scale heterogeneous graph from relational database tables. Each table becomes a node type. Each row becomes a node. Primary key-foreign key pairs define the edges. The system then uses PyTorch Geometric (PyG) for message passing, PyTorch Frame for multi-modal feature encoding (numerical, categorical, text, images, timestamps), and PyTorch Lightning for training orchestration.

This guide covers the specific optimizations that produced 30-35% training speedups on real workloads, without sacrificing model accuracy. Every technique is applicable to any PyG-based pipeline.

02

Avoiding Unnecessary Re-Compilations

The single biggest source of wasted time with torch.compile in graph learning is re-compilation. By default, PyTorch compiles an optimized kernel for the exact input shapes it sees. When the next mini-batch has different shapes (which happens every iteration in GNNs due to neighbor sampling), PyTorch re-compiles from scratch. This makes training slower than eager mode, not faster.

Problem 1: Dynamic input shapes

In a standard image pipeline, every batch has the same tensor dimensions. In graph learning, the number of sampled nodes and edges changes per batch. Without intervention, each new shape triggers a full re-compilation.

The fix: use torch.compile(dynamic=True). This tells the compiler to generate kernels that handle variable shapes from the start, using symbolic shape analysis instead of hard-coding dimensions. The compiled kernel works for any input size, eliminating re-compilation entirely.
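A minimal sketch of the idea, using a toy MLP as a stand-in for a GNN encoder (backend="eager" is used here only so the snippet runs without a C++ toolchain; drop it in real training to use the default inductor backend):

```python
import torch
import torch.nn as nn

# Toy model standing in for a GNN encoder (hypothetical example).
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

# dynamic=True asks the compiler to use symbolic shapes, so one compiled
# kernel serves every batch size instead of re-compiling per shape.
# backend="eager" keeps this sketch runnable without a C++ toolchain.
compiled = torch.compile(model, dynamic=True, backend="eager")

# Mini-batches of different sizes, as neighbor sampling would produce.
for num_nodes in (128, 57, 301):
    x = torch.randn(num_nodes, 16)
    out = compiled(x)
    assert out.shape == (num_nodes, 4)
```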

Problem 2: Learning rate schedulers

A subtler re-compilation trigger: torch.compile treats floating-point scalar arguments as compile-time constants. When a learning rate scheduler updates the LR (a float) at each step, the compiler sees a "new" constant and re-compiles.

The fix: wrap the learning rate in a tensor.

# Before (triggers re-compilation on every LR update)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# After (LR is a tensor, not a compile-time constant)
optimizer = torch.optim.Adam(model.parameters(), lr=torch.tensor(0.01))

By converting the learning rate to a tensor, PyTorch treats it as a runtime value rather than a constant, so scheduler updates no longer trigger re-compilation.
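A minimal sketch of the pattern. The key point is that a scheduler-style update should mutate the tensor in place (fill_) rather than replace it with a Python float, so the compiled step keeps reading the same tensor:

```python
import torch

params = [torch.nn.Parameter(torch.randn(4))]

# LR stored as a 0-dim tensor: the compiled optimizer step reads its
# value at runtime instead of baking it in as a constant.
opt = torch.optim.Adam(params, lr=torch.tensor(0.01))

params[0].grad = torch.randn(4)
opt.step()

# Scheduler-style update: mutate the tensor in place, do not replace it
# with a float, so no re-compilation is triggered.
opt.param_groups[0]["lr"].fill_(0.005)
opt.step()
```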

03

Graph Breaks and CUDA Graphs

Even after eliminating re-compilations, torch.compile can still underperform if the computation graph is fragmented. Two issues dominate here: graph breaks and the question of whether to enable CUDA Graphs.

Graph breaks

A graph break occurs when torch.compile encounters an operation it cannot trace through. The compiler splits the computation into separate subgraphs, each compiled independently. This prevents kernel fusion across the break boundary, reducing the optimization surface.

In PyG, a common source of graph breaks was the use of tuple keys in ModuleDict lookups for heterogeneous graphs. Converting tuple keys to string keys resolved these breaks, allowing the compiler to trace through the full forward pass without interruption.
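A sketch of the conversion (the "__" join convention here is illustrative; PyG applies a similar flattening internally for heterogeneous edge types):

```python
import torch.nn as nn

edge_types = [("user", "follows", "user"), ("user", "writes", "post")]

def key(edge_type):
    # nn.ModuleDict only accepts string keys, and the compiler traces
    # string lookups cleanly, so flatten the tuple into one string.
    return "__".join(edge_type)

convs = nn.ModuleDict({key(et): nn.Linear(16, 16) for et in edge_types})

# Lookups at runtime use the same flattening.
layer = convs[key(("user", "follows", "user"))]
```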

The PyG library has been progressively updated to be more torch.compile-friendly, so staying on the latest release picks up these fixes automatically.

  1. Identify breaks: use TORCH_LOGS to find where the compiler splits the graph into separate subgraphs.
  2. Fix the source code: convert incompatible patterns (tuple keys, unsupported ops) to compiler-friendly alternatives.
  3. Verify fusion: confirm the full forward pass compiles as a single graph for maximum kernel fusion.

CUDA Graphs: when not to use them

CUDA Graphs record a sequence of GPU kernels and replay them as a single unit, eliminating per-kernel launch overhead. For fixed-shape workloads (image classification, transformer inference with fixed sequence lengths), this delivers substantial speedups.

For GNN training with neighbor sampling, CUDA Graphs are counterproductive. The number of nodes in each mini-batch of subgraphs differs across iterations. CUDA Graphs must re-record the kernel sequence whenever shapes change, which increases memory usage and introduces slowdowns from repeated recording.

The Kumo pipeline explicitly disables CUDA Graphs:

# CUDA Graphs disabled: mini-batch shapes vary per iteration
model = torch.compile(model, dynamic=True)
# Do NOT use: torch.compile(model, mode="reduce-overhead")
# "reduce-overhead" enables CUDA Graphs internally
04

Eliminating Host-Device Synchronization

The most impactful optimization category is removing unnecessary synchronizations between the CPU (host) and GPU (device). Each synchronization forces the GPU to finish all queued work before the CPU can proceed, creating a pipeline bubble. In practice, these synchronization points can account for more of the training time than the computation itself.

Problem 1: Metric accumulation on CPU

A common pattern in training loops is accumulating metrics like accuracy on the CPU:

# Before: forces GPU sync every iteration
metric += int((label == pred).sum())

# After: keep accumulation on GPU, sync once at epoch end
metric += (label == pred).sum()

The int() call forces a device-to-host transfer, which blocks until the GPU finishes all pending kernels. Removing it keeps the running total on the GPU as a tensor. You only transfer the final accumulated value to CPU once per epoch.
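Put together, a sketch of the loop-level pattern (toy tensors; in real training, label and pred come from each mini-batch):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Running totals live on the device; nothing crosses to the host mid-epoch.
correct = torch.zeros((), dtype=torch.long, device=device)
total = 0

for _ in range(3):  # stand-in for the mini-batch loop
    label = torch.randint(0, 2, (64,), device=device)
    pred = torch.randint(0, 2, (64,), device=device)
    correct += (label == pred).sum()  # stays a device tensor
    total += label.numel()

# Single device-to-host transfer per epoch.
accuracy = correct.item() / total
```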

Problem 2: torch.repeat_interleave without output_size

The function torch.repeat_interleave is commonly used in GNN message passing to expand node features along edges. Without the output_size argument, PyTorch must synchronize to calculate the output tensor size before allocating memory and launching kernels.

# Before: triggers synchronization to compute output size
expanded = torch.repeat_interleave(x, repeats, dim=0)

# After: provide output_size to avoid sync
expanded = torch.repeat_interleave(x, repeats, dim=0,
                                    output_size=known_size)

When you know the output size in advance (which is common in GNN pipelines where edge counts are known from the graph structure), passing it directly lets PyTorch allocate memory and launch kernels without waiting.
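A small sketch: when repeats holds per-node edge counts, the output size is just the total number of edges, which the host already knows from the graph structure:

```python
import torch

x = torch.randn(3, 8)              # 3 nodes, 8 features each
repeats = torch.tensor([2, 1, 3])  # edges per node
num_edges = 6                      # known from the graph structure

# output_size lets PyTorch allocate the result without reading repeats
# back to the host (which would force a sync when repeats lives on GPU).
expanded = torch.repeat_interleave(x, repeats, dim=0, output_size=num_edges)
assert expanded.shape == (6, 8)
```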

Problem 3: Third-party library synchronizations

Even well-maintained libraries can introduce hidden synchronizations. Older versions of torchmetrics, for example, introduced device syncs in their metric computation. Upgrading to newer versions resolved these issues. The lesson: profile before and after library upgrades, and pin versions that are sync-free.

05

The Full Optimization Pipeline

Applying these techniques is not a single switch. Each optimization addresses a different layer of the stack, and they compound. Here is the sequence used in the Kumo RDL pipeline:

  1. Enable torch.compile: use dynamic=True to handle variable subgraph sizes from neighbor sampling.
  2. Fix re-compilations: wrap the learning rate in a tensor and ensure no other scalars are treated as compile-time constants.
  3. Resolve graph breaks: update PyG, convert tuple keys to strings in ModuleDict, and verify single-graph compilation.
  4. Disable CUDA Graphs: avoid reduce-overhead mode; variable mini-batch shapes make CUDA Graphs counterproductive.
  5. Remove host-device syncs: accumulate metrics on the GPU, provide output_size arguments, and upgrade torchmetrics.

The RDL architecture itself consists of three stages that all benefit from these optimizations:

  1. Feature encoding with PyTorch Frame: encodes multi-modal column data (numerical, categorical, text, images, timestamps) into dense embeddings.
  2. Message passing with PyG: performs heterogeneous message passing across the relational graph, propagating information between connected tables.
  3. Training orchestration with PyTorch Lightning: manages the training loop, distributed training, and checkpointing.

Each stage has different compilation characteristics. Feature encoding involves standard tensor operations that compile cleanly. Message passing has the variable-shape challenges described above. Training orchestration wraps the outer loop and benefits from sync removal.

06

Benchmark Results

All benchmarks were run on an AWS g6.4xlarge instance using PyTorch 2.7.0 with CUDA 12.8. The dataset is the Kaggle H&M recommendation challenge, a real-world e-commerce dataset with users, items, and transactions. Three tasks were evaluated, each representing a different prediction type:

Task descriptions for the H&M benchmark dataset:

  Task                 Type                  Description
  user-churn           Node Classification   Predict whether a customer will churn within one week.
  item-sales           Node Regression       Estimate weekly article sales volume.
  user-item-purchase   Link Prediction       Forecast which articles a customer will purchase over seven days.

Speedup results

Across all three tasks, torch.compile with the full optimization pipeline delivered consistent speedups of 30% to 35% compared to eager mode execution. Critically, there was no reduction in model predictive accuracy on any task. The compiled models produced identical predictions to their eager-mode counterparts.

Training speedup with torch.compile (higher is better); all tasks show 30-35% improvement:

  Task                 Mode                            Speedup   Accuracy Impact
  user-churn           torch.compile + optimizations   ~30-35%   None
  item-sales           torch.compile + optimizations   ~30-35%   None
  user-item-purchase   torch.compile + optimizations   ~30-35%   None

The speedups are consistent across task types (classification, regression, link prediction), which indicates the optimizations operate at the infrastructure level rather than being task-specific. Whether you are predicting churn, forecasting sales, or generating recommendations, the same compilation pipeline delivers the same performance gains.

What the numbers mean in practice

A 35% training speedup translates directly to cost and iteration speed. A training run that previously took 10 hours now takes 6.5 hours. Over hundreds of experiments during model development, this compounds into days of saved GPU time and faster experimentation cycles.

07

Practical Checklist

If you are running a PyG-based GNN pipeline and want to apply these optimizations, here is the concrete checklist:

Step 1: Enable compilation with dynamic shapes

model = torch.compile(model, dynamic=True)

Do not use mode="reduce-overhead" (it enables CUDA Graphs, which hurts variable-shape workloads).

Step 2: Fix learning rate re-compilation

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=torch.tensor(0.01)  # tensor, not float
)

Step 3: Eliminate graph breaks

  • Update PyG to the latest version (many graph breaks have been fixed upstream).
  • Convert tuple keys to string keys in any custom ModuleDict usage.
  • Set TORCH_LOGS="graph_breaks" to find remaining breaks in your code.

Step 4: Remove host-device synchronizations

  • Replace int(tensor) and float(tensor) calls in training loops with tensor-native accumulation.
  • Pass output_size to torch.repeat_interleave whenever the size is known.
  • Upgrade torchmetrics and other metric libraries.

Step 5: Profile and verify

  • Use the PyTorch Profiler to confirm synchronization points are eliminated.
  • Compare eager vs. compiled training loss curves to verify numerical equivalence.
  • Measure wall-clock time per epoch, not per step (compilation overhead is amortized over the epoch).
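A minimal profiling harness for the verification step (a CPU-only sketch; on GPU, add ProfilerActivity.CUDA to activities and look for cudaStreamSynchronize or cudaMemcpyDtoH rows between steps, which indicate remaining syncs):

```python
import torch
from torch.profiler import profile, ProfilerActivity

x = torch.randn(1024, 1024)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    y = x @ x
    total = (y > 0).sum()  # stays a tensor; no .item() inside the region

# Inspect the trace: a sync-free loop shows no device-to-host
# transfers or stream synchronizations between steps.
report = prof.key_averages().table(sort_by="cpu_time_total", row_limit=5)
print(report)
```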
