
PyG + Databricks: Graph ML on the Lakehouse

Databricks gives you Spark, Delta Lake, Unity Catalog, and GPU clusters. PyG gives you graph neural networks. Combining them requires bridging the Spark-PyTorch gap without losing data governance or training speed.


TL;DR

  • Databricks ML Runtime supports PyG installation on GPU clusters. Use Spark for preprocessing and PyG for training in the same notebook.
  • The main challenge is the Spark-to-PyG data transfer. toPandas() works for small tables but hits memory limits above ~50M rows. Use Arrow or Petastorm for larger datasets.
  • Delta Lake's time travel and versioning provide temporal correctness for training data. Use VERSION AS OF or TIMESTAMP AS OF for point-in-time feature snapshots.
  • Unity Catalog provides data governance (who can access what) that extends to your GNN training pipeline. Critical for regulated industries.

Setting up PyG on Databricks

Databricks ML Runtime includes PyTorch but not PyG. Install it at cluster startup or in a notebook:

Databricks notebook
# Install PyG on Databricks GPU cluster
%pip install torch_geometric
%pip install pyg_lib torch_scatter torch_sparse -f \
    https://data.pyg.org/whl/torch-2.3.0+cu121.html

# Verify GPU access
import torch
print(f"GPU available: {torch.cuda.is_available()}")
print(f"Device: {torch.cuda.get_device_name(0)}")

Pin the PyG version to match your PyTorch version. Mismatched versions cause silent errors in scatter operations.
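The wheel index URL in the install command encodes the torch build. A minimal sketch of picking the right index from the installed torch version — the helper is hypothetical, assuming the `major.minor.patch+cuXXX` format that `torch.__version__` reports on CUDA builds:

```python
def pyg_wheel_index(torch_version: str) -> str:
    """Map a torch version string to the matching data.pyg.org wheel index.

    Hypothetical helper: assumes versions like "2.3.0+cu121"; builds
    without a CUDA suffix fall back to the CPU wheel index.
    """
    if "+" in torch_version:
        base, tag = torch_version.split("+", 1)
    else:
        base, tag = torch_version, "cpu"
    return f"https://data.pyg.org/whl/torch-{base}+{tag}.html"

# On a cluster you would pass torch.__version__ here
print(pyg_wheel_index("2.3.0+cu121"))
# https://data.pyg.org/whl/torch-2.3.0+cu121.html
```

Checking this in the notebook before installing avoids the silent-mismatch failure mode entirely.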

From Delta Lake to PyG

The data transfer pipeline has three stages:

Stage 1: Spark preprocessing

Use Spark for the heavy lifting: joins, aggregations, filtering. Spark handles billion-row tables that would never fit in Pandas.

spark_preprocessing.py
from pyspark.sql import functions as F

# Read Delta tables
customers = spark.read.table("catalog.schema.customers")
orders = spark.read.table("catalog.schema.orders")

# Temporal filter
orders_filtered = orders.filter(
    F.col("order_date") < "2026-03-01"
)

# Feature engineering in Spark (fast, distributed)
customer_features = customers.join(
    orders_filtered.groupBy("customer_id").agg(
        F.count("*").alias("order_count"),
        F.avg("amount").alias("avg_order"),
        F.max("order_date").alias("last_order"),
    ),
    on="customer_id",
    how="left",
)
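The date filter above fixes the cutoff, but Delta time travel can pin the snapshot itself for point-in-time correctness. A sketch that builds the time-travel SQL to pass to `spark.sql()` — `TIMESTAMP AS OF` and `VERSION AS OF` are standard Delta Lake clauses; the helper function and table name are illustrative:

```python
from typing import Optional

def snapshot_query(table: str, timestamp: Optional[str] = None,
                   version: Optional[int] = None) -> str:
    """Build a Delta time-travel query for a point-in-time snapshot.

    Hypothetical helper: wraps the standard Delta Lake
    TIMESTAMP AS OF / VERSION AS OF clauses.
    """
    if timestamp is not None:
        return f"SELECT * FROM {table} TIMESTAMP AS OF '{timestamp}'"
    if version is not None:
        return f"SELECT * FROM {table} VERSION AS OF {version}"
    raise ValueError("provide timestamp or version")

# On Databricks: spark.sql(snapshot_query("catalog.schema.orders",
#                                         timestamp="2026-03-01"))
print(snapshot_query("catalog.schema.orders", timestamp="2026-03-01"))
# SELECT * FROM catalog.schema.orders TIMESTAMP AS OF '2026-03-01'
```

This guarantees the features reflect the table exactly as it existed at the cutoff, even if rows were later updated or deleted.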

Stage 2: Transfer to PyG

spark_to_pyg.py
import torch
from torch_geometric.data import HeteroData

# For tables < 50M rows: toPandas()
customers_pd = customer_features.toPandas()
orders_pd = orders_filtered.select(
    "customer_id", "product_id"
).toPandas()

# Map raw IDs to contiguous node indices (PyG requires 0..N-1 indexing)
customer_id_map = {c: i for i, c in enumerate(customers_pd["customer_id"])}
product_id_map = {p: i for i, p in enumerate(orders_pd["product_id"].unique())}

# Build HeteroData
feature_cols = ["order_count", "avg_order"]  # numeric features from Stage 1
data = HeteroData()
data["customer"].x = torch.tensor(
    customers_pd[feature_cols].fillna(0).values, dtype=torch.float32
)

# Build edge index from foreign keys: one edge per order row, so
# src and dst both come from orders_pd and have equal length
src = orders_pd["customer_id"].map(customer_id_map)
dst = orders_pd["product_id"].map(product_id_map)
data["customer", "purchased", "product"].edge_index = \
    torch.tensor([src.values, dst.values], dtype=torch.long)

toPandas() collects all data to the driver node. For tables above 50M rows, use Petastorm or write to Parquet and load with PyArrow.

Stage 3: Train on GPU

Train the GNN on Databricks GPU nodes. Use MLflow for experiment tracking and model registry:

training.py
import mlflow

mlflow.pytorch.autolog()

with mlflow.start_run():
    model = train_gnn(data, config)

    # Log to MLflow Model Registry
    mlflow.pytorch.log_model(model, "gnn_model")

    # Write predictions back to Delta Lake
    predictions = inference(model, data)
    spark.createDataFrame(predictions).write.format("delta") \
        .saveAsTable("catalog.schema.gnn_predictions")

What breaks in production

  • Driver OOM: toPandas() loads everything onto the driver node. A 100M-row table with 50 columns can exceed driver memory. Increase driver memory or use Arrow-based transfer.
  • Cluster startup time: GPU clusters with PyG dependencies take 5-10 minutes to start. Use cluster pools to keep warm instances available.
  • Version conflicts: Databricks ML Runtime pins PyTorch versions. PyG requires specific PyTorch versions. Check compatibility before upgrading either.
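Arrow-based transfer is a config change, not a code change. A sketch of the relevant Spark settings — these are real Spark config keys, but the batch size is illustrative and `spark` is the notebook's session object:

```python
# Enable Arrow for toPandas(): columnar transfer with far less
# serialization overhead than the default row-based path
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

# Cap rows per Arrow batch to bound peak memory during conversion
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "100000")
```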

Frequently asked questions

Can I train PyG models inside Databricks?

Yes. Databricks ML Runtime includes PyTorch, and you can pip install torch_geometric on any GPU cluster. Use Spark for data preprocessing and Petastorm or Spark's toPandas() to convert Delta tables to PyG format. Train on GPU nodes in the same cluster.

How do I load Delta Lake tables into PyG?

Use Spark to read Delta tables, filter and preprocess, then convert to Pandas DataFrames with toPandas(). From Pandas, construct PyG HeteroData objects with torch tensors. For large tables, use Petastorm or Arrow-based transfer to avoid the Pandas memory overhead.

Does KumoRFM integrate with Databricks?

Yes. KumoRFM connects to Databricks via Unity Catalog. It reads Delta tables directly, builds the graph automatically, and writes predictions back as Delta tables. No PyG code, no Spark-to-PyG conversion needed.

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.