The serving challenge
GNN inference is fundamentally different from standard ML serving. A tabular model takes a feature vector and produces a prediction. A GNN takes a node, constructs its neighborhood subgraph, fetches features for all nodes in that subgraph, and then runs the model.
This means GNN inference involves a graph database query (neighbor lookup), a feature store query (feature retrieval), and a model forward pass. Each adds latency and failure modes.
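The three-stage pipeline can be sketched as a single handler. This is a minimal illustration, not a real implementation: `sample_neighbors`, `get_features`, and `model_fn` are hypothetical stand-ins for the graph database, feature store, and model server, and each is a separate latency and failure source.

```python
def predict_online(node_id, sample_neighbors, get_features, model_fn, fanout=(10, 5)):
    """One GNN prediction = neighbor lookup + feature fetch + forward pass.

    sample_neighbors(n, k) -> up to k neighbor ids of n   (graph DB query)
    get_features(n)        -> feature vector for node n   (feature store query)
    model_fn(...)          -> prediction for the seed node (model forward pass)
    """
    # 1. Graph database query: sample a bounded k-hop neighborhood
    subgraph, frontier = {node_id}, {node_id}
    for k in fanout:
        frontier = {nb for n in frontier for nb in sample_neighbors(n, k)}
        subgraph |= frontier
    # 2. Feature store query: one fetch per node in the subgraph
    feats = {n: get_features(n) for n in subgraph}
    # 3. Model forward pass over the assembled subgraph
    return model_fn(node_id, subgraph, feats)
```

Note that the subgraph size, and therefore the feature-store fan-out, grows multiplicatively with the hop fan-outs, which is why online GNN latency is dominated by the first two stages rather than the forward pass.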
Batch inference
The simplest serving pattern: precompute predictions for all nodes (or all nodes that matter) on a schedule.
```python
import torch
from torch_geometric.loader import NeighborLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()

# Score all nodes in batches
loader = NeighborLoader(
    data,
    num_neighbors=[10, 5],  # fan-out per layer
    batch_size=4096,
    input_nodes=all_nodes,
)

predictions = {}
with torch.no_grad():
    for batch in loader:
        batch = batch.to(device)
        out = model(batch.x, batch.edge_index)
        # The first batch_size rows are the seed nodes; the rest are sampled neighbors
        preds = out[:batch.batch_size].cpu()
        for i, node_id in enumerate(batch.n_id[:batch.batch_size]):
            predictions[node_id.item()] = preds[i]

# Store in Redis/DynamoDB for online lookup
store_predictions(predictions)
```

Batch inference runs on a schedule (hourly, daily). Predictions are precomputed and stored in a key-value store for instant retrieval.
When batch works
- Recommendation scores refreshed daily
- Credit risk scores updated hourly
- Customer churn probabilities recomputed weekly
- Any use case where hour-old predictions are acceptable
Scaling serving throughput
- Model compilation: Use torch.compile with mode="max-autotune" for serving. You pay a one-time compilation cost; every subsequent forward pass runs 40-50% faster.
- Request batching: Collect incoming requests for 5-10ms and batch them into a single GPU forward pass. This amortizes GPU kernel launch overhead across multiple predictions.
- Quantization: INT8 quantization cuts model size and inference time by ~2x with minimal accuracy loss. Use torch.quantization or ONNX Runtime quantization.
- CPU serving: For latency-insensitive batch predictions, CPU inference avoids GPU provisioning costs. Modern CPUs with AVX-512 can serve simple GNN models at 1000+ predictions/second.
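The request-batching idea above can be sketched as a micro-batcher: block for the first request, then gather more until a deadline or a size cap is hit, and hand the whole batch to one forward pass. This is an illustrative sketch, not a production server; `collect_microbatch` is a hypothetical helper name.

```python
import queue
import time

def collect_microbatch(q, max_wait_ms=8, max_batch=64):
    """Block for the first request, then gather more until the
    deadline or max_batch, whichever comes first."""
    batch = [q.get()]  # wait for at least one request
    deadline = time.monotonic() + max_wait_ms / 1000
    while len(batch) < max_batch:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        try:
            batch.append(q.get(timeout=remaining))
        except queue.Empty:
            break  # deadline hit with no new request
    return batch  # run one GPU forward pass over this batch
```

A serving loop would call this repeatedly, run the model once per returned batch, and route each output back to its caller. The `max_wait_ms` knob trades tail latency (every request waits up to the deadline) against throughput (larger batches amortize kernel-launch overhead).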
What breaks in production
- Cold cache: After a deployment, the embedding cache is empty. The first predictions are either slow (cache miss triggers full inference) or wrong (stale fallback). Pre-warm the cache before routing traffic.
- Graph-model version skew: The graph structure changes (new node types, schema migrations) but the model expects the old structure. Version the graph schema alongside the model and validate compatibility at deployment time.
- Feature store latency spikes: A feature store outage or slowdown cascades into GNN serving latency. Set aggressive timeouts and fall back to default features (zeros or population means) rather than failing the request.
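The timeout-and-fallback pattern for feature store spikes can be sketched as a small wrapper. This is a hedged example, not a specific feature-store client: `fetch` stands in for the real store call, and the fallback here is a zero vector (population means work the same way).

```python
from concurrent.futures import ThreadPoolExecutor

_pool = ThreadPoolExecutor(max_workers=8)
FALLBACK = [0.0] * 64  # or precomputed population means per feature

def features_with_fallback(node_id, fetch, timeout_s=0.025):
    """Fetch a node's feature vector; on timeout or error,
    degrade to default features instead of failing the request."""
    future = _pool.submit(fetch, node_id)
    try:
        return future.result(timeout=timeout_s)
    except Exception:
        future.cancel()  # best effort; an in-flight store call may still finish
        return FALLBACK
```

The key design choice is that a slow feature store degrades prediction quality for a few requests rather than breaching the latency SLO for all of them; monitor the fallback rate so silent degradation doesn't go unnoticed.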