Setting up PyG on Databricks
Databricks ML Runtime includes PyTorch but not PyG. Install it at cluster startup or in a notebook:
# Install PyG on Databricks GPU cluster
%pip install torch_geometric
%pip install pyg_lib torch_scatter torch_sparse -f \
https://data.pyg.org/whl/torch-2.3.0+cu121.html
# Verify GPU access
import torch
print(f"GPU available: {torch.cuda.is_available()}")
print(f"Device: {torch.cuda.get_device_name(0)}")Pin the PyG version to match your PyTorch version. Mismatched versions cause silent errors in scatter operations.
From Delta Lake to PyG
The data transfer pipeline has three stages:
Stage 1: Spark preprocessing
Use Spark for heavy-lifting: joins, aggregations, filtering. Spark handles billion-row tables that would never fit in Pandas.
from pyspark.sql import functions as F
# Read Delta tables
customers = spark.read.table("catalog.schema.customers")
orders = spark.read.table("catalog.schema.orders")
# Temporal filter
orders_filtered = orders.filter(
    F.col("order_date") < "2026-03-01"
)
# Feature engineering in Spark (fast, distributed)
customer_features = customers.join(
    orders_filtered.groupBy("customer_id").agg(
        F.count("*").alias("order_count"),
        F.avg("amount").alias("avg_order"),
        F.max("order_date").alias("last_order"),
    ),
    on="customer_id",
    how="left",
)

Stage 2: Transfer to PyG
import torch
from torch_geometric.data import HeteroData
# For tables < 50M rows: toPandas()
customers_pd = customer_features.toPandas()
orders_pd = orders_filtered.select(
    "customer_id", "product_id"
).toPandas()
# Build HeteroData
data = HeteroData()
data["customer"].x = torch.tensor(
customers_pd[feature_cols].values, dtype=torch.float32
)
# Build edge index from foreign keys: map raw IDs to contiguous
# node indices (customer_id_map / product_id_map are dicts built
# from the node tables). Both endpoints come from orders_pd, one
# row per edge, so src and dst align row by row.
src = orders_pd["customer_id"].map(customer_id_map)
dst = orders_pd["product_id"].map(product_id_map)
data["customer", "purchased", "product"].edge_index = \
torch.tensor([src.values, dst.values], dtype=torch.long)toPandas() collects all data to the driver node. For tables above 50M rows, use Petastorm or write to Parquet and load with PyArrow.
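The code above assumes `customer_id_map` and `product_id_map` already exist. A minimal sketch of how such maps can be built, with plain dicts (the helper name is this example's, not a PyG API):

```python
def build_id_map(raw_ids):
    """Map raw (possibly sparse) IDs to contiguous 0..N-1 node indices.

    dict.fromkeys deduplicates while preserving first-seen order, so
    the mapping is deterministic for a fixed input order.
    """
    return {raw: idx for idx, raw in enumerate(dict.fromkeys(raw_ids))}

customer_id_map = build_id_map(["C7", "C2", "C7", "C9"])
print(customer_id_map)  # → {'C7': 0, 'C2': 1, 'C9': 2}

# Sanity-check before building edge_index: every foreign key in the
# edge table must resolve, otherwise .map() produces NaNs that break
# the torch.long conversion.
edge_customers = ["C2", "C9", "C7"]
assert all(c in customer_id_map for c in edge_customers)
```

In practice you would build the map from the node DataFrame's ID column, so node row order and tensor row order agree.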
Stage 3: Train on GPU
Train the GNN on Databricks GPU nodes. Use MLflow for experiment tracking and model registry:
import mlflow
mlflow.pytorch.autolog()
with mlflow.start_run():
    model = train_gnn(data, config)
    # Log to MLflow Model Registry
    mlflow.pytorch.log_model(model, "gnn_model")
# Write predictions back to Delta Lake
predictions = inference(model, data)
spark.createDataFrame(predictions).write.format("delta") \
    .saveAsTable("catalog.schema.gnn_predictions")

What breaks in production
- Driver OOM: toPandas() loads everything onto the driver node. A 100M-row table with 50 columns can exceed driver memory. Increase driver memory or use Arrow-based transfer.
- Cluster startup time: GPU clusters with PyG dependencies take 5-10 minutes to start. Use cluster pools to keep warm instances available.
- Version conflicts: Databricks ML Runtime pins PyTorch versions. PyG requires specific PyTorch versions. Check compatibility before upgrading either.
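A back-of-the-envelope check before calling toPandas() catches the driver-OOM case early. This is a rough sketch, assuming 8 bytes per numeric cell and ignoring string columns and pandas overhead (both only make the real number larger):

```python
def estimated_pandas_gb(n_rows: int, n_cols: int, bytes_per_cell: int = 8) -> float:
    """Lower-bound estimate of a collected DataFrame's size in driver memory."""
    return n_rows * n_cols * bytes_per_cell / 1e9

# The 100M-row, 50-column table from the bullet above:
print(estimated_pandas_gb(100_000_000, 50))  # → 40.0 (GB)
```

If the estimate approaches your driver's memory, skip toPandas() and take the Parquet/PyArrow path instead.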