
Serving GNN Predictions in Production

Training a GNN is the easy part. Serving it reliably in production while the graph changes underneath you is where the real complexity lives.


TL;DR

  • GNN inference has two phases: subgraph construction (neighbor sampling) and model forward pass. Subgraph construction is usually the bottleneck, taking 10-50ms vs 2-5ms for the model.
  • Batch inference (precompute all predictions periodically) is the most practical and cost-effective serving pattern for most use cases.
  • Precompute and cache node embeddings. At serving time, fetch cached embeddings and run only the final prediction layer for fast lookups.
  • Graph staleness is the hidden serving challenge: the graph used for inference may be minutes or hours behind the live database. Define your freshness SLA and build refresh pipelines accordingly.
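The embedding-cache pattern above can be sketched in a few lines. This is a minimal illustration, not a real API: `embedding_cache`, `head`, and `encode_fn` are all placeholder names, and `head` stands in for the model's final prediction layer (e.g. a `torch.nn.Linear`).

```python
def refresh_embedding_cache(node_ids, encode_fn, cache):
    """Offline step: recompute embeddings with the full GNN encoder
    on a schedule and write them into the cache."""
    for node_id in node_ids:
        cache[node_id] = encode_fn(node_id)


def predict_from_cache(node_id, embedding_cache, head, fallback):
    """Online step: fetch the cached embedding and run only the final
    prediction layer. `fallback` (e.g. a zero vector) covers cache misses."""
    emb = embedding_cache.get(node_id)
    if emb is None:
        emb = fallback  # cold-cache fallback; see "What breaks in production"
    return head(emb)
```

The offline step does the expensive multi-hop message passing; the online step is a key-value lookup plus one matrix multiply.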

The serving challenge

GNN inference is fundamentally different from standard ML serving. A tabular model takes a feature vector and produces a prediction. A GNN takes a node, constructs its neighborhood subgraph, fetches features for all nodes in that subgraph, and then runs the model.

This means GNN inference involves a graph database query (neighbor lookup), a feature store query (feature retrieval), and a model forward pass. Each adds latency and failure modes.
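Those three stages compose into a single online-inference path. The sketch below is schematic: `sample_neighbors`, `fetch_features`, and `model` are placeholders for your graph store, feature store, and network, not a real API.

```python
def predict_online(node_id, sample_neighbors, fetch_features, model):
    # 1) Graph database query: sample the k-hop neighborhood of the target.
    #    Returns the subgraph's node ids (target first) and its edges.
    sub_nodes, edge_index = sample_neighbors(node_id)
    # 2) Feature store query: fetch features for every node in the subgraph.
    x = fetch_features(sub_nodes)
    # 3) Model forward pass: the target node's prediction sits at position 0.
    return model(x, edge_index)[0]
```

Each stage adds its own latency budget and its own failure modes, which is why the serving patterns below try to move stages 1 and 2 offline.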

Batch inference

The simplest serving pattern: precompute predictions for all nodes (or all nodes that matter) on a schedule.

batch_inference.py
import torch
from torch_geometric.loader import NeighborLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()
predictions = {}

# Score all nodes in batches
loader = NeighborLoader(
    data, num_neighbors=[10, 5],
    batch_size=4096, input_nodes=all_nodes,
)

with torch.no_grad():
    for batch in loader:
        batch = batch.to(device)
        out = model(batch.x, batch.edge_index)
        # Only the first batch_size nodes are seeds; the rest are sampled neighbors.
        preds = out[:batch.batch_size].cpu()
        for i, node_id in enumerate(batch.n_id[:batch.batch_size]):
            predictions[node_id.item()] = preds[i]

# Store in Redis/DynamoDB for online lookup
store_predictions(predictions)

Batch inference runs on a schedule (hourly, daily). Predictions are precomputed and stored in a key-value store for instant retrieval.
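The `store_predictions` call in the listing is left undefined. One way it might look: serialize each prediction as JSON under a namespaced key. The dict-backed `InMemoryKV` below is a stand-in for a real Redis or DynamoDB client (any object with `set`/`get`), purely for illustration.

```python
import json


class InMemoryKV:
    """Stand-in for a Redis/DynamoDB client with the same set/get shape."""

    def __init__(self):
        self.data = {}

    def set(self, key, value):
        self.data[key] = value

    def get(self, key):
        return self.data.get(key)


def store_predictions(predictions, kv_client, prefix="pred:"):
    """Write each prediction under a namespaced key as a JSON array."""
    for node_id, pred in predictions.items():
        values = pred.tolist() if hasattr(pred, "tolist") else pred
        kv_client.set(f"{prefix}{node_id}", json.dumps(values))


def load_prediction(node_id, kv_client, prefix="pred:"):
    """Online lookup: a single key fetch, no model involved."""
    raw = kv_client.get(f"{prefix}{node_id}")
    return None if raw is None else json.loads(raw)
```

With a real Redis client you would batch the writes into a pipeline, but the key layout stays the same.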

When batch works

  • Recommendation scores refreshed daily
  • Credit risk scores updated hourly
  • Customer churn probabilities recomputed weekly
  • Any use case where predictions that are hours (or days) old are acceptable

Scaling serving throughput

  • Model compilation: Use torch.compile with mode="max-autotune" for serving. One-time compilation cost, typically 40-50% faster inference thereafter.
  • Request batching: Collect incoming requests for 5-10ms and batch them into a single GPU forward pass. This amortizes GPU kernel launch overhead across multiple predictions.
  • Quantization: INT8 quantization cuts model size and inference time by ~2x with minimal accuracy loss. Use torch.quantization or ONNX Runtime quantization.
  • CPU serving: For latency-insensitive batch predictions, CPU inference avoids GPU provisioning costs. Modern CPUs with AVX-512 can serve simple GNN models at 1000+ predictions/second.
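The request-batching idea can be sketched as a queue drainer: collect requests until either the batch is full or a small time budget expires, then hand the whole batch to one forward pass. A minimal sketch under those assumptions (the `request_queue` is a standard `queue.Queue` your request handlers push into):

```python
import queue
import time


def collect_batch(request_queue, max_batch=32, max_wait_s=0.005):
    """Drain up to max_batch requests, waiting at most max_wait_s total.

    Returns early if the queue runs dry after the budget expires, so a
    lone request is never delayed by more than max_wait_s.
    """
    batch = []
    deadline = time.monotonic() + max_wait_s
    while len(batch) < max_batch:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        try:
            batch.append(request_queue.get(timeout=remaining))
        except queue.Empty:
            break
    return batch
```

The serving loop then runs `model(collect_batch(q))`-style iterations, amortizing kernel launch overhead across every request in the batch.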

What breaks in production

  • Cold cache: After a deployment, the embedding cache is empty. The first predictions are either slow (cache miss triggers full inference) or wrong (stale fallback). Pre-warm the cache before routing traffic.
  • Graph-model version skew: The graph structure changes (new node types, schema migrations) but the model expects the old structure. Version the graph schema alongside the model and validate compatibility at deployment time.
  • Feature store latency spikes: A feature store outage or slowdown cascades into GNN serving latency. Set aggressive timeouts and fall back to default features (zeros or population means) rather than failing the request.
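The timeout-and-fallback behavior for feature fetches can be sketched with a bounded wait. `fetch_fn` is a placeholder for your feature store call; on timeout the request gets default features (here, zero vectors) instead of an error.

```python
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import TimeoutError as FuturesTimeout


def fetch_features_with_fallback(fetch_fn, node_ids, feature_dim, timeout_s=0.05):
    """Fetch features with an aggressive timeout; on a slow or failed
    feature store, return zero vectors so the request still completes."""
    executor = ThreadPoolExecutor(max_workers=1)
    future = executor.submit(fetch_fn, node_ids)
    try:
        return future.result(timeout=timeout_s)
    except FuturesTimeout:
        # Degraded but available: default features for every requested node.
        return [[0.0] * feature_dim for _ in node_ids]
    finally:
        executor.shutdown(wait=False)
```

Population means often beat zeros as the fallback; either way, log the fallback rate so a degraded feature store is visible in your metrics rather than silently lowering prediction quality.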

Frequently asked questions

Can I export a PyG model to ONNX?

Partially. Simple GNN models (GCNConv, SAGEConv with fixed graph structure) can be exported to ONNX. However, PyG's scatter operations and dynamic graph construction don't have direct ONNX equivalents. For production serving, TorchScript or torch.compile with AOTInductor are more reliable options.

What latency should I expect for batch GNN inference serving?

For batch inference with precomputed embeddings stored in a key-value store, expect sub-1ms serving latency per prediction (a simple lookup). The batch computation itself depends on graph size — scoring all nodes in a million-node graph typically takes minutes to hours depending on model complexity and hardware.

How often should I refresh batch predictions?

It depends on your freshness requirements. Daily refreshes work for recommendations and risk scores. Hourly refreshes work for more time-sensitive use cases. The right cadence depends on how quickly your underlying data changes and your business tolerance for staleness.

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.