Data-parallel GNN training
The standard approach: replicate the model on each GPU, feed each GPU different sampled mini-batches, and synchronize gradients after each step.
```python
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch_geometric.loader import NeighborLoader

def train(rank, world_size, data, model):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    model = model.to(rank)
    model = DDP(model, device_ids=[rank])

    # Each GPU gets its own sampler
    loader = NeighborLoader(
        data,
        num_neighbors=[15, 10],
        batch_size=1024,
        input_nodes=data.train_mask,  # boolean mask selecting training seed nodes
        num_workers=4,
    )

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    for epoch in range(100):
        model.train()
        for batch in loader:
            batch = batch.to(rank)
            out = model(batch.x, batch.edge_index)
            # Only the first batch_size rows are seed nodes; the rest
            # are sampled neighbors and carry no labels for this step.
            loss = F.cross_entropy(
                out[:batch.batch_size],
                batch.y[:batch.batch_size],
            )
            loss.backward()  # gradients synced across GPUs by DDP
            optimizer.step()
            optimizer.zero_grad()

    dist.destroy_process_group()

# Launch with torchrun:
# torchrun --nproc_per_node=4 train.py
```

DDP handles gradient synchronization automatically. Each GPU runs independent neighbor sampling, so batches have different subgraph topologies.
Scaling challenges specific to GNNs
1. Sampling overhead scales with GPUs
Each GPU runs its own NeighborLoader. With 4 GPUs and fanout [15, 10], you are doing 4x the neighbor sampling on the CPU. CPU cores become the bottleneck before GPU compute does.
- Allocate 4-8 CPU cores per GPU for sampling workers
- Use pin_memory=True and persistent_workers=True
- Consider GPU-based sampling (PyG 2.6+) to offload from CPU
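To see why sampling dominates, a back-of-the-envelope bound helps: with layer-wise fanouts, each seed node expands multiplicatively per hop. A small sketch (the `max_sampled_nodes` helper is illustrative, not a PyG API):

```python
def max_sampled_nodes(batch_size, fanout):
    """Worst-case number of nodes touched per mini-batch by layer-wise
    neighbor sampling: every seed expands by the running product of the
    fanouts at each hop (assumes no neighbor overlap, so an upper bound)."""
    total = batch_size
    per_seed = 1
    for f in fanout:
        per_seed *= f
        total += batch_size * per_seed
    return total

# Fanout [15, 10] with batch size 1024:
# 1024 * (1 + 15 + 15 * 10) = 169,984 nodes per batch, per GPU.
print(max_sampled_nodes(1024, [15, 10]))  # 169984
```

With 4 GPUs each running their own loader, the CPU side has to materialize up to four such subgraphs per step, which is why sampling workers saturate cores long before GPU compute saturates.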
2. Batch size sensitivity
In CNN training, doubling batch size has predictable effects (scale learning rate linearly). In GNN training, doubling batch size changes the subgraph topology distribution. Larger batches include more diverse neighborhoods, which changes the gradient landscape.
```python
# Linear scaling rule with warmup
base_lr = 0.001
base_batch = 1024
effective_batch = base_batch * world_size  # world_size GPUs, one batch each
scaled_lr = base_lr * (effective_batch / base_batch)

# Warm up over the first 5% of training
warmup_epochs = max(1, int(0.05 * total_epochs))
scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer,  # an optimizer configured with lr=scaled_lr
    start_factor=1.0 / world_size,
    total_iters=warmup_epochs * steps_per_epoch,
)
```

Warmup is critical for GNNs. Without it, the large initial learning rate causes divergence because the first batches have noisy gradient estimates.
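Concretely, this schedule ramps the learning rate linearly from `scaled_lr / world_size` up to `scaled_lr` over the warmup window. A standalone sketch of that ramp (pure Python, mirroring LinearLR's behavior rather than calling it):

```python
def warmup_lr(step, warmup_steps, scaled_lr, world_size):
    """Learning rate at a given step under linear warmup, matching
    LinearLR(start_factor=1/world_size) applied to a base lr of scaled_lr."""
    if step >= warmup_steps:
        return scaled_lr
    start = scaled_lr / world_size
    return start + (scaled_lr - start) * (step / warmup_steps)

# 4 GPUs, scaled lr 0.004: starts at 0.001 and reaches 0.004 after warmup.
print(warmup_lr(0, 100, 0.004, 4))    # 0.001
print(warmup_lr(100, 100, 0.004, 4))  # 0.004
```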
3. Graph storage and feature fetching
When the graph is too large for one machine, its storage must be partitioned across several. Each GPU then fetches features from remote partitions, adding network latency; this feature fetching is the dominant bottleneck in distributed GNN training.
Distributed graph storage
For graphs that do not fit on a single machine:
- Shared memory: Store the graph in shared memory (torch.multiprocessing) so all GPUs on the same machine access the same copy. This avoids 4x memory for 4 GPUs.
- UVA (Unified Virtual Addressing): Store the graph in CPU memory and access it from GPU via UVA. Slower than GPU memory but handles graphs up to CPU RAM size.
- Remote storage: Partition the graph across machines. Use RPC for cross-partition feature fetching. Highest capacity but highest latency.
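As a minimal sketch of the shared-memory option, assuming the full feature matrix fits in host RAM: a torch tensor can be moved into shared memory in place, so trainer processes forked afterwards map the same pages instead of each holding a copy. The sizes below are hypothetical.

```python
import torch

# Hypothetical sizes: 100k nodes, 128-dim float32 features (~51 MB).
feats = torch.randn(100_000, 128)
feats.share_memory_()  # move the storage into shared memory, in place
assert feats.is_shared()
# Processes started after this point (e.g. one trainer per GPU) receive
# a handle to the same storage, so 4 GPUs do not mean 4x feature memory.
```

PyG `Data` objects expose the same idea via `data.share_memory_()`, which applies this to every tensor attribute.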
What breaks in production
- GPU utilization below 50%: Common in distributed GNN training because GPUs wait for CPU sampling. Monitor GPU utilization and add CPU cores or switch to GPU sampling if low.
- Non-reproducibility: Stochastic sampling + distributed execution = different results every run. Set all seeds (torch, numpy, NeighborLoader) and use deterministic algorithms for reproducible experiments.
- Stragglers: One GPU with a large subgraph (many hub nodes) stalls all others during gradient sync. Use dynamic batch sizing based on subgraph size, not fixed seed count.
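For the reproducibility point, one hedged sketch of a seeding helper (the `seed_everything` name and the rank-offset convention are ours, not a library API):

```python
import random
import numpy as np
import torch

def seed_everything(seed, rank=0):
    """Seed every RNG that feeds sampling; offsetting by rank keeps each
    GPU's sampling stream distinct but reproducible across runs."""
    s = seed + rank
    random.seed(s)
    np.random.seed(s)
    torch.manual_seed(s)  # also seeds the CUDA RNGs on current builds
    torch.use_deterministic_algorithms(True, warn_only=True)
    return s

# Same seed on the same rank -> identical random streams.
```

If the sampling workers themselves must be deterministic, pass the returned per-rank seed through the loader's `worker_init_fn` or `generator` (NeighborLoader forwards these DataLoader arguments).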