
Production Deployment: Taking GNN Models from Notebook to Production Serving

Your GNN achieves 85% AUROC in a Jupyter notebook. Getting it into production reliably is where the real complexity lives. Here is the full production deployment pipeline.


TL;DR

  1. Production GNN deployment requires solving four problems research ignores: graph construction pipeline, inference latency, feature freshness, and model updating on evolving graphs.
  2. Batch deployment (pre-compute all embeddings periodically, serve from cache) is the most common production pattern. Pre-computed embeddings reduce serving to a single lookup + lightweight scoring model.
  3. The graph construction pipeline is often harder than the model: ETL from source systems, incremental graph updates, feature engineering, temporal correctness, and monitoring for data drift.
  4. Managed ML platforms can eliminate much of this complexity by handling graph construction, model training, and serving infrastructure under the hood.

Production GNN deployment requires solving problems that research notebooks never encounter: building and updating the graph from live data, meeting latency budgets, keeping features fresh, and retraining models on an evolving graph. This gap between research and production is why most GNN projects stall after the proof-of-concept stage. Here is the full pipeline for getting a GNN into production.

The production pipeline

  1. Graph construction: ETL from source databases into a graph store. Map tables to node types, foreign keys to edge types, and row features to node features.
  2. Feature engineering: Compute node and edge features from raw data. Handle missing values, normalize numerics, encode categoricals.
  3. Training: Train the GNN on historical data with temporal splits. Validate on a held-out time period.
  4. Embedding generation: Run inference on the full graph to produce embeddings for all nodes.
  5. Serving: Store embeddings in a low-latency key-value store. Serve predictions via API.
  6. Monitoring: Track prediction quality, data drift, graph statistics, and latency.
  7. Retraining: Periodically retrain on fresh data and refresh embeddings.

Deployment pattern 1: Batch inference

Pre-compute embeddings for all nodes on a schedule (hourly, daily):

batch_inference.py
# Batch inference pipeline
import numpy as np
import redis
import torch
from fastapi import FastAPI, HTTPException

app = FastAPI()
cache = redis.Redis()

# 1. Load the latest graph snapshot
# (load_graph_from_warehouse, now, load_model, scoring_model are
# pipeline-specific helpers.)
graph = load_graph_from_warehouse(timestamp=now())

# 2. Run GNN inference on all nodes
model = load_model('production_v3')
model.eval()
with torch.no_grad():
    embeddings = model.encode(graph)  # [num_nodes, dim] float32

# 3. Store in low-latency serving layer
for node_id, embedding in zip(graph.node_ids, embeddings):
    cache.set(f"emb:{node_id}", embedding.cpu().numpy().tobytes())

# 4. Scoring API reads pre-computed embeddings
@app.get("/predict/churn/{customer_id}")
def predict_churn(customer_id: str):
    raw = cache.get(f"emb:{customer_id}")
    if raw is None:
        raise HTTPException(status_code=404, detail="unknown customer")
    emb = np.frombuffer(raw, dtype=np.float32)
    score = scoring_model.predict(emb.reshape(1, -1))[0]
    return {"churn_probability": float(score)}

Batch inference pre-computes all embeddings, so the serving API does a single key-value lookup plus a lightweight scoring call per prediction.

Batch inference is simple and fast at serving time, but embeddings can be stale (up to the batch interval old). This is acceptable for daily recommendation refreshes but not for real-time fraud detection.

Graph construction is the hard part

In practice, the graph construction pipeline consumes more engineering effort than the model itself:

  • ETL complexity: Joining 10-20 source tables, handling schema changes, resolving foreign key inconsistencies.
  • Incremental updates: New transactions, new customers, changed relationships must be reflected in the graph without full reconstruction.
  • Temporal correctness: The graph at prediction time T must only contain data from before T. This requires versioned graph snapshots or time-filtered queries.
  • Monitoring: Graph statistics (node count, average degree, feature distributions) must be tracked for data drift that can silently degrade model performance.
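
The temporal-correctness requirement can be sketched as a time-filtered view: given per-edge timestamps, a snapshot at prediction time T keeps only edges created strictly before T. A minimal sketch (the edge arrays and timestamps are illustrative):

```python
import numpy as np

def snapshot_edges(edge_src, edge_dst, edge_time, t):
    """Return only the edges that existed strictly before time t."""
    mask = edge_time < t
    return edge_src[mask], edge_dst[mask]

# Illustrative edge list with per-edge creation timestamps.
src = np.array([0, 1, 2, 3])
dst = np.array([1, 2, 3, 0])
ts = np.array([100, 200, 300, 400])

# A snapshot at t=250 must not see the edges created at 300 or 400,
# otherwise future information leaks into training examples.
s, d = snapshot_edges(src, dst, ts, t=250)
print(s, d)  # only edges (0->1) and (1->2) survive
```

The same filter applies at training time: each training example labeled at time T must be computed over the T-filtered graph, not the latest one.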

Model optimization for production

  • Distillation: Train a large teacher, distill to a small student for serving. 5-50x model compression with 95-99% accuracy retention.
  • Quantization: Convert FP32 embeddings to INT8. 4x memory reduction, 2-3x speedup with minimal accuracy loss.
  • Neighbor sampling budget: In production, sample fewer neighbors (5-10 instead of 25) for latency. Accuracy drops 1-2% but latency improves 3-5x.
  • TorchScript / ONNX export: Export the model for optimized runtime inference without Python overhead.
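
The quantization bullet can be sketched with symmetric per-tensor quantization of FP32 embeddings to INT8. This is a minimal illustration (production systems often use per-channel scales and calibration data):

```python
import numpy as np

def quantize_int8(emb: np.ndarray):
    """Symmetric quantization: map FP32 values into [-127, 127]."""
    scale = float(np.abs(emb).max()) / 127.0
    q = np.clip(np.round(emb / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q: np.ndarray, scale: float) -> np.ndarray:
    return q.astype(np.float32) * scale

rng = np.random.default_rng(0)
emb = rng.standard_normal((1000, 128)).astype(np.float32)

q, scale = quantize_int8(emb)
recon = dequantize(q, scale)

print(emb.nbytes // q.nbytes)            # 4x memory reduction (4 bytes -> 1 byte)
print(float(np.abs(emb - recon).max()))  # worst-case rounding error <= scale / 2
```

The serving layer stores the INT8 tensor plus one float scale per embedding table, and dequantizes (or scores directly in INT8) at request time.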

The managed alternative

Building the full production GNN pipeline (graph construction, feature engineering, training, serving, monitoring, retraining) requires 6-12 months of ML engineering. Managed ML platforms can eliminate much of this by handling the end-to-end pipeline:

  • Connect your database (Snowflake, BigQuery, Databricks)
  • Define the prediction task in a high-level query language
  • Get predictions via API or batch export

Managed platforms handle graph construction, model training, and serving infrastructure under the hood, reducing months of engineering to days.

Frequently asked questions

What are the main challenges of deploying GNNs in production?

Four main challenges: (1) Graph construction: keeping the graph up-to-date as new data arrives. (2) Inference latency: neighborhood expansion makes GNN inference slower than MLP inference. (3) Feature freshness: node features must reflect the latest data. (4) Model updating: retraining on an evolving graph without downtime.

How fast can production GNN inference be?

With pre-computed node embeddings and incremental updates, GNN inference for a single node can be under 10ms. Without pre-computation (full neighborhood expansion per request), latency is 50-500ms depending on graph size and model depth. Batch inference can process millions of nodes per minute on GPUs.

Should I use real-time or batch GNN inference?

Batch inference (pre-compute all embeddings periodically) for most use cases: recommendations, risk scoring, customer segmentation. Real-time inference (compute per request) for latency-sensitive decisions: fraud detection on individual transactions, real-time personalization. Many production systems use a hybrid: batch embeddings + real-time lightweight scoring.
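
The hybrid pattern above can be sketched as a daily batch job that caches embeddings, plus a per-request scorer that combines the cached (stale) embedding with fresh request-time features. The cache contents, feature layout, and weights below are hypothetical:

```python
import numpy as np

# Batch side (runs daily): embeddings cached by node id.
embedding_cache = {
    "cust_42": np.array([0.1, -0.3, 0.7], dtype=np.float32),
}

# Hypothetical lightweight scorer: logistic regression over
# [cached embedding || fresh request-time features].
weights = np.array([0.5, -1.0, 0.8, 2.0], dtype=np.float32)  # last weight: txn amount
bias = -1.0

def score(customer_id: str, txn_amount: float) -> float:
    emb = embedding_cache[customer_id]     # stale, up to one batch interval old
    features = np.append(emb, txn_amount)  # fresh at request time
    logit = float(features @ weights + bias)
    return 1.0 / (1.0 + np.exp(-logit))    # sigmoid -> probability

p = score("cust_42", txn_amount=1.5)
print(round(p, 3))
```

Only the cheap linear scorer runs per request; the expensive GNN runs once per batch interval, which is what keeps the hybrid pattern inside real-time latency budgets.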

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.