The serving challenge
GNN inference is fundamentally different from standard ML serving. A tabular model takes a feature vector and produces a prediction. A GNN takes a node, constructs its neighborhood subgraph, fetches features for all nodes in that subgraph, and then runs the model.
This means GNN inference involves a graph database query (neighbor lookup), a feature store query (feature retrieval), and a model forward pass. Each adds latency and failure modes.
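Stripped of any particular stack, that request path can be sketched as three stages. This is a minimal sketch: `fetch_neighbors`, `fetch_features`, and the `model` call stand in for a graph database client, a feature store client, and the trained model, not a real API.

```python
def gnn_predict(node_id, fetch_neighbors, fetch_features, model):
    # 1. Graph database query: neighbor lookup for the target node
    subgraph_nodes = fetch_neighbors(node_id)
    # 2. Feature store query: features for every node in the subgraph
    features = fetch_features(subgraph_nodes)
    # 3. Model forward pass over the assembled subgraph
    return model(node_id, subgraph_nodes, features)
```

Each stage has its own latency budget and failure modes, which is why the serving strategies below mostly differ in which stages they move offline.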
Batch inference
The simplest serving pattern: precompute predictions for all nodes (or all nodes that matter) on a schedule.
import torch
from torch_geometric.loader import NeighborLoader

model.eval()
predictions = {}

# Score all nodes in batches of sampled subgraphs
loader = NeighborLoader(
    data, num_neighbors=[10, 5],
    batch_size=4096, input_nodes=all_nodes,
)

with torch.no_grad():
    for batch in loader:
        out = model(batch.x.cuda(), batch.edge_index.cuda())
        # Only the first batch_size rows are seed nodes; the rest are
        # neighbors sampled for message passing
        preds = out[:batch.batch_size].cpu()
        for i, node_id in enumerate(batch.n_id[:batch.batch_size]):
            predictions[node_id.item()] = preds[i]

# Store in Redis/DynamoDB for online lookup
store_predictions(predictions)

Batch inference runs on a schedule (hourly, daily). Predictions are precomputed and stored in a key-value store for instant retrieval.
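`store_predictions` is left abstract above. One hedged sketch against a redis-py-style client (`pipeline()`, `set(..., ex=...)` are real redis-py calls, but the key format and TTL here are illustrative assumptions):

```python
import torch

def store_predictions(predictions, client, ttl_seconds=86400):
    # Batch all writes into one round trip via a pipeline
    pipe = client.pipeline()
    for node_id, pred in predictions.items():
        # Serialize each prediction tensor to raw bytes for compact storage;
        # the TTL bounds how long stale scores can be served
        pipe.set(f"pred:{node_id}", pred.numpy().tobytes(), ex=ttl_seconds)
    pipe.execute()
```

The TTL acts as a safety net: if a batch run fails, stale predictions expire rather than being served indefinitely.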
When batch works
- Recommendation scores refreshed daily
- Credit risk scores updated hourly
- Customer churn probabilities recomputed weekly
- Any use case where hour-old predictions are acceptable
Real-time inference
For fraud detection or real-time personalization, you need predictions on the latest graph state in under 100ms.
Strategy 1: Precomputed embeddings
Precompute node embeddings (intermediate GNN representations) and cache them. At inference time, fetch cached embeddings for the target node and its neighbors, then run only the final prediction layer.
# Offline: precompute embeddings for all nodes
with torch.no_grad():
    embeddings = model.get_embeddings(data.x, data.edge_index)

# Store in Redis with node_id as key
for i in range(data.num_nodes):
    redis.set(f"emb:{i}", embeddings[i].numpy().tobytes())

# Online: fast inference using cached embeddings
def predict(node_id, neighbor_ids):
    # Fetch precomputed embeddings (1-2ms from Redis)
    node_emb = get_embedding(node_id)
    neighbor_embs = [get_embedding(n) for n in neighbor_ids]
    # Aggregate with a simple mean; no message passing at request time
    context = torch.mean(torch.stack(neighbor_embs), dim=0)
    combined = torch.cat([node_emb, context])
    # Final prediction layer only (< 1ms)
    return model.predictor(combined)

Precomputed embeddings cut inference from roughly 50ms to 5ms. The tradeoff: embeddings go stale as the graph changes, so refresh them on a schedule that matches your freshness SLA.
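The predict path assumes a `get_embedding` helper. A minimal sketch, assuming embeddings were stored as float32 bytes as in the offline snippet (the client is passed explicitly here for testability, where the snippet above binds it globally, and the dimension is an illustrative choice):

```python
import numpy as np
import torch

EMB_DIM = 128  # must match the dimension used when embeddings were written

def get_embedding(node_id, client, dim=EMB_DIM):
    # Raw bytes written with .numpy().tobytes() carry no shape or dtype,
    # so both must be known out-of-band (float32 vectors of dim here)
    raw = client.get(f"emb:{node_id}")
    if raw is None:
        # Cold-cache fallback: a zero vector keeps the request alive
        return torch.zeros(dim)
    # np.frombuffer returns a read-only view; copy before wrapping
    return torch.from_numpy(np.frombuffer(raw, dtype=np.float32).copy())
```

The zero-vector fallback is a deliberate degradation choice; depending on the model, a population-mean embedding may be a safer default.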
Strategy 2: Online subgraph construction
For maximum freshness, construct the subgraph at request time from a graph database (Neo4j, Neptune, TigerGraph). This gives live predictions but adds 20-50ms of latency for the graph query.
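Whatever database serves the neighbor query, its results arrive as global node ID pairs that must be relabeled into a contiguous local index space before they can feed a PyG model. A sketch of that relabeling step (the surrounding database query is assumed, e.g. a Cypher `MATCH` returning ID pairs):

```python
import torch

def build_local_subgraph(edges):
    # Collect every node ID appearing in the fetched edges
    node_ids = sorted({n for edge in edges for n in edge})
    # Map global IDs to local positions 0..n-1
    local = {nid: i for i, nid in enumerate(node_ids)}
    # Relabel edges into a [2, num_edges] edge_index tensor
    edge_index = torch.tensor(
        [[local[s] for s, _ in edges],
         [local[d] for _, d in edges]], dtype=torch.long)
    return node_ids, edge_index
```

`node_ids` doubles as the lookup order for the feature store query, so row `i` of the feature matrix lines up with local node `i` in `edge_index`.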
Scaling serving throughput
- Model compilation: Use torch.compile with mode="max-autotune" for serving. You pay a one-time compilation cost and typically get 40-50% faster inference for every request after.
- Request batching: Collect incoming requests for 5-10ms and batch them into a single GPU forward pass. This amortizes GPU kernel launch overhead across multiple predictions.
- Quantization: INT8 quantization roughly halves model size and inference time with minimal accuracy loss. Use torch.quantization or ONNX Runtime quantization.
- CPU serving: For latency-insensitive batch predictions, CPU inference avoids GPU provisioning costs. Modern CPUs with AVX-512 can serve simple GNN models at 1000+ predictions/second.
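For the quantization point above, dynamic INT8 quantization is the lowest-effort variant. A sketch using a stand-in predictor head rather than a full GNN (message-passing layers generally need the more involved static quantization path):

```python
import torch
import torch.nn as nn

# A stand-in post-aggregation prediction head, not a real model
predictor = nn.Sequential(nn.Linear(256, 64), nn.ReLU(), nn.Linear(64, 1))

# Replace Linear layers with INT8 dynamic-quantized equivalents;
# weights are stored as int8, activations quantized on the fly
quantized = torch.ao.quantization.quantize_dynamic(
    predictor, {nn.Linear}, dtype=torch.qint8)

x = torch.randn(8, 256)
with torch.no_grad():
    out = quantized(x)  # same shape as the float model's output
```

Dynamic quantization runs on CPU, which pairs naturally with the CPU-serving option above.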
What breaks in production
- Cold cache: After a deployment, the embedding cache is empty. The first predictions are either slow (cache miss triggers full inference) or wrong (stale fallback). Pre-warm the cache before routing traffic.
- Graph-model version skew: The graph structure changes (new node types, schema migrations) but the model expects the old structure. Version the graph schema alongside the model and validate compatibility at deployment time.
- Feature store latency spikes: A feature store outage or slowdown cascades into GNN serving latency. Set aggressive timeouts and fall back to default features (zeros or population means) rather than failing the request.
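The timeout-and-fallback pattern from the last point can be sketched as follows; `fetch_fn` stands in for the real feature store call, and the dimension and timeout are illustrative numbers:

```python
from concurrent.futures import ThreadPoolExecutor, TimeoutError
import torch

FEATURE_DIM = 64  # illustrative; must match the model's expected input

_pool = ThreadPoolExecutor(max_workers=8)

def features_with_fallback(node_ids, fetch_fn, timeout_s=0.01):
    # Run the feature store call under a hard deadline
    future = _pool.submit(fetch_fn, node_ids)
    try:
        return future.result(timeout=timeout_s)
    except TimeoutError:
        future.cancel()
        # Degrade to default features (zeros here; population means
        # are often a better choice) instead of failing the request
        return torch.zeros(len(node_ids), FEATURE_DIM)
```

The request still succeeds during a feature store outage; prediction quality degrades instead, which is usually the right trade for fraud and personalization traffic.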