Production GNN deployment requires solving problems that research notebooks never encounter: building and updating the graph from live data, meeting latency budgets, keeping features fresh, and retraining models on an evolving graph. This gap between research and production is why most GNN projects stall after the proof-of-concept stage. Here is the full pipeline for getting a GNN into production.
The production pipeline
- Graph construction: ETL from source databases into a graph store. Map tables to node types, foreign keys to edge types, and row features to node features.
- Feature engineering: Compute node and edge features from raw data. Handle missing values, normalize numerics, encode categoricals.
- Training: Train the GNN on historical data with temporal splits. Validate on a held-out time period.
- Embedding generation: Run inference on the full graph to produce embeddings for all nodes.
- Serving: Store embeddings in a low-latency key-value store. Serve predictions via API.
- Monitoring: Track prediction quality, data drift, graph statistics, and latency.
- Retraining: Periodically retrain on fresh data and refresh embeddings.
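The temporal split mentioned in the training step is worth making concrete: validation data must come strictly after the training cutoff, or the model silently leaks future information. A minimal sketch, assuming events carry an ISO-8601 timestamp string under a hypothetical `ts` key (ISO strings compare correctly as plain strings):

```python
def temporal_split(events, train_end, val_end):
    """Split events by time so validation never sees the training period.

    `events` is a list of dicts with an ISO-8601 `ts` field; `train_end`
    and `val_end` are ISO-8601 cutoff strings (hypothetical names).
    """
    # Training: everything strictly before the cutoff
    train = [e for e in events if e["ts"] < train_end]
    # Validation: the held-out window between the two cutoffs
    val = [e for e in events if train_end <= e["ts"] < val_end]
    return train, val
```

The same cutoff must also be applied when building the training graph itself, not just when selecting labels.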
Deployment pattern 1: Batch inference
Pre-compute embeddings for all nodes on a schedule (hourly, daily):
```python
# Batch inference pipeline
import numpy as np
import torch

# 1. Load the latest graph snapshot
graph = load_graph_from_warehouse(timestamp=now())

# 2. Run GNN inference on all nodes
model = load_model('production_v3')
with torch.no_grad():
    embeddings = model.encode(graph)  # embeddings for every node

# 3. Store in a low-latency serving layer
for node_id, embedding in zip(graph.node_ids, embeddings):
    redis.set(f"emb:{node_id}", embedding.numpy().tobytes())

# 4. The scoring API reads pre-computed embeddings
@app.get("/predict/churn/{customer_id}")
def predict_churn(customer_id: str):
    emb = redis.get(f"emb:{customer_id}")
    # dtype must match the stored embedding precision (FP32 here)
    score = scoring_model.predict(np.frombuffer(emb, dtype=np.float32))
    return {"churn_probability": float(score)}
```

Batch inference pre-computes every embedding, so the serving API does a single key-value lookup per prediction.
Batch inference is simple and fast at serving time, but embeddings can be up to one batch interval stale. That staleness is acceptable for daily recommendation refreshes but not for real-time fraud detection.
Graph construction is the hard part
In practice, the graph construction pipeline consumes more engineering effort than the model itself:
- ETL complexity: Joining 10-20 source tables, handling schema changes, resolving foreign key inconsistencies.
- Incremental updates: New transactions, new customers, changed relationships must be reflected in the graph without full reconstruction.
- Temporal correctness: The graph at prediction time T must only contain data from before T. This requires versioned graph snapshots or time-filtered queries.
- Monitoring: Graph statistics (node count, average degree, feature distributions) must be tracked for data drift that can silently degrade model performance.
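The monitoring point can be sketched with a simple drift check on one graph statistic. The helper names below are hypothetical, not from any library; a production system would track many statistics (node counts, degree distribution, feature means) the same way:

```python
def avg_degree(num_nodes: int, num_edges: int) -> float:
    # Average degree of the graph snapshot: edges per node
    return num_edges / num_nodes

def has_drifted(current: float, baseline: float, tol: float = 0.2) -> bool:
    """Flag a statistic that moved more than `tol` (relative) from baseline."""
    return abs(current - baseline) / baseline > tol
```

A drift alert like this often fires days before model metrics degrade, because upstream ETL bugs (dropped joins, schema changes) show up first as missing edges.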
Model optimization for production
- Distillation: Train a large teacher, distill to a small student for serving. 5-50x model compression with 95-99% accuracy retention.
- Quantization: Convert FP32 embeddings to INT8. 4x memory reduction, 2-3x speedup with minimal accuracy loss.
- Neighbor sampling budget: In production, sample fewer neighbors (5-10 instead of 25) for latency. Accuracy drops 1-2% but latency improves 3-5x.
- TorchScript / ONNX export: Export the model for optimized runtime inference without Python overhead.
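The quantization bullet above can be illustrated with symmetric INT8 quantization of an embedding vector, which is where the 4x memory reduction comes from (1 byte per value instead of 4). This is a minimal NumPy sketch with hypothetical helper names, assuming the embedding is not all zeros:

```python
import numpy as np

def quantize_int8(emb: np.ndarray):
    """Symmetric per-vector INT8 quantization of an FP32 embedding."""
    # Map the largest absolute value to 127; assumes emb is not all zeros
    scale = np.abs(emb).max() / 127.0
    q = np.round(emb / scale).astype(np.int8)
    return q, scale

def dequantize(q: np.ndarray, scale: float) -> np.ndarray:
    # Recover an approximate FP32 embedding for scoring
    return q.astype(np.float32) * scale
```

Storing the INT8 values plus one FP32 scale per vector cuts the serving store's memory roughly 4x, at the cost of a small reconstruction error.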
The managed alternative
Building the full production GNN pipeline (graph construction, feature engineering, training, serving, monitoring, retraining) typically takes 6-12 months of ML engineering. Managed ML platforms can eliminate much of this by handling the pipeline end to end:
- Connect your database (Snowflake, BigQuery, Databricks)
- Define the prediction task in a high-level query language
- Get predictions via API or batch export
Managed platforms handle graph construction, model training, and serving infrastructure under the hood, reducing months of engineering to days.