
Multi-GPU and Distributed GNN Training

GNN training on a single GPU takes days for enterprise graphs. Multi-GPU training can cut this to hours, but graph-specific challenges (sampling, communication, stochastic topology) make distributed GNNs harder than distributed CNNs.


TL;DR

  1. Use data parallelism (DDP), not model parallelism. GNN models are small (< 10M parameters). The bottleneck is data (large graphs), not model size.
  2. Each GPU samples independent mini-batches from the graph using NeighborLoader. Gradients are synchronized via AllReduce after each step.
  3. Scale batch size linearly with GPU count and use learning rate warmup. GNNs are more sensitive to batch size changes than CNNs because each batch has different topology.
  4. The communication bottleneck is gradient sync (small, fast) plus feature fetching from distributed graph storage (large, slow). Optimize feature fetching first.

Data-parallel GNN training

The standard approach: replicate the model on each GPU, feed each GPU different sampled mini-batches, and synchronize gradients after each step.

ddp_training.py
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch_geometric.loader import NeighborLoader

def train(rank, world_size, data, model):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    model = model.to(rank)
    model = DDP(model, device_ids=[rank])

    # Split the training nodes across ranks so each GPU
    # samples a disjoint set of seed nodes
    train_idx = data.train_mask.nonzero(as_tuple=False).view(-1)
    train_idx = train_idx.split(train_idx.size(0) // world_size)[rank]

    # Each GPU gets its own sampler
    loader = NeighborLoader(
        data,
        num_neighbors=[15, 10],
        batch_size=1024,
        input_nodes=train_idx,
        num_workers=4,
    )

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(100):
        model.train()
        for batch in loader:
            batch = batch.to(rank)
            optimizer.zero_grad()
            out = model(batch.x, batch.edge_index)
            loss = F.cross_entropy(
                out[:batch.batch_size],
                batch.y[:batch.batch_size],
            )
            loss.backward()  # gradients synced by DDP
            optimizer.step()

    dist.destroy_process_group()

# Launch with torchrun
# torchrun --nproc_per_node=4 train.py

DDP handles gradient synchronization automatically. Each GPU runs independent neighbor sampling, so batches have different subgraph topologies.

Scaling challenges specific to GNNs

1. Sampling overhead scales with GPUs

Each GPU runs its own NeighborLoader. With 4 GPUs and fanout [15, 10], you are doing 4x the neighbor sampling on CPU. CPU cores become the bottleneck before GPU compute does.

  • Allocate 4-8 CPU cores per GPU for sampling workers
  • Use pin_memory=True and persistent_workers=True
  • Consider GPU-based sampling (PyG 2.6+) to offload from CPU
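To see why CPU sampling becomes the bottleneck, it helps to count the worst-case number of nodes touched per batch. A pure-Python sketch (the fanout and batch size come from the example above; the 4-GPU count and function name are illustrative):

```python
def max_sampled_nodes(batch_size, fanouts):
    """Worst-case node count for one sampled mini-batch:
    every seed node expands by the fanout at each hop."""
    total = batch_size          # seed nodes themselves
    frontier = batch_size
    for fanout in fanouts:
        frontier *= fanout      # new neighbors added at this hop
        total += frontier
    return total

per_batch = max_sampled_nodes(1024, [15, 10])  # 1024 * (1 + 15 + 150)
print(per_batch)        # 169984 nodes per mini-batch, worst case

# With 4 GPUs sampling concurrently, the CPU serves ~4x that
print(4 * per_batch)    # 679936
```

In practice deduplication and graph boundaries keep the real number lower, but the multiplicative structure is why sampling cost, not GPU compute, hits the ceiling first.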

2. Batch size sensitivity

In CNN training, doubling batch size has predictable effects (scale learning rate linearly). In GNN training, doubling batch size changes the subgraph topology distribution. Larger batches include more diverse neighborhoods, which changes the gradient landscape.

lr_scaling.py
# Linear scaling rule with warmup
base_lr = 0.001
base_batch = 1024
effective_batch = base_batch * world_size  # e.g. 4096 with 4 GPUs

scaled_lr = base_lr * (effective_batch / base_batch)

# Warm up over the first 5% of training, ramping from
# base_lr up to scaled_lr
warmup_epochs = max(1, int(0.05 * total_epochs))
scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer,  # optimizer built with lr=scaled_lr
    start_factor=1.0 / world_size,
    total_iters=warmup_epochs * steps_per_epoch,
)

Warmup is critical for GNNs. Without it, the large scaled learning rate can cause divergence early in training, when gradient estimates from freshly sampled subgraphs are still noisy.
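The combined effect of the scaling rule and warmup can be checked with plain arithmetic. This sketch (the function name and step-based warmup length are illustrative) mirrors what LinearLR with start_factor=1/world_size does to the scaled learning rate:

```python
def lr_at_step(step, base_lr=0.001, world_size=4, warmup_steps=500):
    """Linearly scaled LR with linear warmup, mirroring
    LinearLR(start_factor=1/world_size, total_iters=warmup_steps)."""
    scaled_lr = base_lr * world_size          # linear scaling rule
    if step >= warmup_steps:
        return scaled_lr                      # warmup finished
    # ramp the multiplier from 1/world_size up to 1.0
    start = 1.0 / world_size
    factor = start + (1.0 - start) * (step / warmup_steps)
    return scaled_lr * factor

print(lr_at_step(0))      # 0.001 -> starts at the single-GPU base LR
print(lr_at_step(500))    # 0.004 -> reaches the scaled LR after warmup
```

The key property: at step 0 the effective rate equals the well-tested single-GPU rate, so the first noisy batches never see the full 4x learning rate.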

3. Graph storage and feature fetching

When the graph is too large for one machine, graph storage is distributed. Each GPU must fetch features from remote partitions, adding network latency. This is the dominant bottleneck in distributed GNN training.
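A common mitigation is to cache hot features locally so repeated fetches for high-degree nodes skip the network. A minimal sketch of the idea (the FeatureCache class, the fetch_remote stub, and all sizes are illustrative, not a PyG API):

```python
from collections import OrderedDict

class FeatureCache:
    """Tiny LRU cache for node features. In a real system the
    miss path would be an RPC to the remote graph partition."""
    def __init__(self, fetch_remote, capacity=100_000):
        self.fetch_remote = fetch_remote   # callable: node_id -> feature
        self.capacity = capacity
        self.cache = OrderedDict()
        self.hits = self.misses = 0

    def get(self, node_id):
        if node_id in self.cache:
            self.cache.move_to_end(node_id)   # mark as recently used
            self.hits += 1
            return self.cache[node_id]
        self.misses += 1
        feat = self.fetch_remote(node_id)     # slow network round trip
        self.cache[node_id] = feat
        if len(self.cache) > self.capacity:
            self.cache.popitem(last=False)    # evict least recently used
        return feat

# Hub nodes appear in many sampled batches, so they hit the cache often
cache = FeatureCache(fetch_remote=lambda nid: [0.0] * 128, capacity=1000)
for nid in [7, 7, 7, 42, 7]:
    cache.get(nid)
print(cache.hits, cache.misses)   # 3 2
```

Because neighbor sampling revisits hub nodes constantly, even a modest cache can absorb a large share of remote fetches; the power-law degree distribution of most real graphs works in your favor here.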

Distributed graph storage

For graphs that do not fit on a single machine:

  • Shared memory: Store the graph in shared memory (torch.multiprocessing) so all GPUs on the same machine access the same copy. This avoids 4x memory for 4 GPUs.
  • UVA (Unified Virtual Addressing): Store the graph in CPU memory and access it from GPU via UVA. Slower than GPU memory but handles graphs up to CPU RAM size.
  • Remote storage: Partition the graph across machines. Use RPC for cross-partition feature fetching. Highest capacity but highest latency.

What breaks in production

  • GPU utilization below 50%: Common in distributed GNN training because GPUs wait for CPU sampling. Monitor GPU utilization and add CPU cores or switch to GPU sampling if low.
  • Non-reproducibility: Stochastic sampling + distributed execution = different results every run. Set all seeds (torch, numpy, NeighborLoader) and use deterministic algorithms for reproducible experiments.
  • Stragglers: One GPU with a large subgraph (many hub nodes) stalls all others during gradient sync. Use dynamic batch sizing based on subgraph size, not fixed seed count.
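The straggler fix in the last bullet can be sketched in plain Python: pack seed nodes into batches so each batch has a similar total degree (a proxy for sampled-subgraph size) rather than a fixed seed count. The function name, degree map, and budget value are all illustrative:

```python
def degree_balanced_batches(seed_nodes, degree, budget=5000):
    """Greedily pack seed nodes into batches whose total degree
    (a proxy for sampled-subgraph size) stays under a budget."""
    batches, current, current_deg = [], [], 0
    for node in seed_nodes:
        d = degree[node]
        if current and current_deg + d > budget:
            batches.append(current)        # close the full batch
            current, current_deg = [], 0
        current.append(node)
        current_deg += d
    if current:
        batches.append(current)
    return batches

# Hub nodes end up in small batches; low-degree nodes pack densely
degree = {0: 4800, 1: 10, 2: 10, 3: 4900, 4: 10}
print(degree_balanced_batches([0, 1, 2, 3, 4], degree, budget=5000))
# [[0, 1, 2], [3, 4]]
```

Batches now vary in seed count but are closer in sampled size, which evens out step times across GPUs and reduces idle waiting at the AllReduce barrier.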

Frequently asked questions

Does PyG support multi-GPU training?

Yes. PyG works with PyTorch's DistributedDataParallel (DDP) for data-parallel training across multiple GPUs. Each GPU processes different mini-batches sampled from the same graph. For very large graphs, PyG also provides distributed sampling that partitions the graph across machines.

Should I use data parallelism or model parallelism for GNNs?

Use data parallelism (DDP) in almost all cases. GNN models are small (typically < 10M parameters) so model parallelism provides no benefit. The scaling challenge is data: the graph is large, not the model. Use data parallelism with distributed sampling to handle large graphs.
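A practical consequence of "the data is large, not the model": each rank must see a disjoint slice of the training nodes so that mini-batches actually differ across GPUs. A stride-based split in plain Python (the slicing convention is one reasonable choice; with PyG you would slice the index tensor passed to input_nodes):

```python
def shard_for_rank(train_nodes, rank, world_size):
    """Give each rank every world_size-th training node.
    Strided slicing keeps the shards balanced in size."""
    return train_nodes[rank::world_size]

train_nodes = list(range(10))
shards = [shard_for_rank(train_nodes, r, 4) for r in range(4)]
print(shards)   # [[0, 4, 8], [1, 5, 9], [2, 6], [3, 7]]

# Every training node lands on exactly one rank
assert sorted(n for s in shards for n in s) == train_nodes
```

A strided split also acts as a cheap shuffle when node IDs are ordered by community, which helps keep per-rank gradient estimates comparable.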

How does batch size affect distributed GNN training?

Scale batch size linearly with GPU count (e.g., 1024 per GPU with 4 GPUs = effective batch size 4096). Use learning rate warmup for the first 5-10% of training. GNNs are more sensitive to batch size than CNNs because each batch samples a different subgraph topology.

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.