
PyG + Snowflake: Graph ML on Your Data Warehouse

Your enterprise data lives in Snowflake. Your ML team wants to use graph neural networks. Bridging these two worlds requires careful ETL, temporal correctness, and an understanding of what Snowflake does well and what it does not.


TL;DR

  • Snowflake stores relational data that maps naturally to graphs: tables are node types, foreign keys are edge types. But extracting this graph into PyG requires a non-trivial ETL pipeline.
  • The extraction pipeline: SQL queries to export tables, Parquet as the intermediate format, Python to convert to PyG HeteroData. Snowpark can run preprocessing inside Snowflake.
  • Temporal correctness is critical: use Snowflake's Time Travel or window functions to create point-in-time feature snapshots. Without this, you get temporal leakage.
  • Data warehouse compute (Snowflake) and ML compute (GPU) are fundamentally different workloads. Do not try to run GNN training inside Snowflake. Extract and load.

The Snowflake-to-PyG pipeline

Most enterprise ML teams store their data in Snowflake. Getting that data into PyG format requires four steps, each with production pitfalls.

Step 1: Schema discovery

Identify which Snowflake tables map to node types and which foreign key relationships map to edge types. For a typical e-commerce deployment:

schema_discovery.sql
-- Discover tables and their relationships
SELECT
    tc.table_name AS source_table,
    kcu.column_name AS fk_column,
    ccu.table_name AS target_table,
    ccu.column_name AS target_column
FROM information_schema.table_constraints tc
JOIN information_schema.key_column_usage kcu
    ON tc.constraint_name = kcu.constraint_name
JOIN information_schema.constraint_column_usage ccu
    ON tc.constraint_name = ccu.constraint_name
WHERE tc.constraint_type = 'FOREIGN KEY';

-- Result:
-- orders -> customers (customer_id)
-- orders -> products (product_id)
-- reviews -> customers (customer_id)
-- reviews -> products (product_id)

Not all Snowflake schemas have formal foreign key constraints. Many use implicit joins. Document your join relationships manually if constraints are missing.
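Whether the relationships come from `information_schema` or from manual documentation, it helps to pin them down in code. A minimal sketch (the `FOREIGN_KEYS` mapping and `edge_types` helper are hypothetical names; the tables mirror the e-commerce example above):

```python
# Map each foreign key relationship to its target table:
# (source table, FK column) -> target table.
# These names mirror the example schema above; adapt to your own tables.
FOREIGN_KEYS = {
    ("orders", "customer_id"): "customers",
    ("orders", "product_id"): "products",
    ("reviews", "customer_id"): "customers",
    ("reviews", "product_id"): "products",
}

def edge_types(fk_map):
    """Derive PyG-style (src, relation, dst) edge-type triples from the FK map."""
    return [
        (src, f"has_{fk.removesuffix('_id')}", dst)
        for (src, fk), dst in fk_map.items()
    ]
```

Keeping this mapping explicit in version control doubles as documentation for schemas that lack formal constraints.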

Step 2: Feature extraction

feature_extraction.py
import snowflake.connector
import pandas as pd

conn = snowflake.connector.connect(...)

# Extract node features
customers = pd.read_sql("""
    SELECT customer_id, age, tenure_days, lifetime_value,
           city, segment
    FROM customers
""", conn)

# Extract edges with timestamps
orders = pd.read_sql("""
    SELECT order_id, customer_id, product_id,
           order_date, amount
    FROM orders
    WHERE order_date < '2026-03-01'  -- temporal cutoff
""", conn)

# Export to Parquet for efficient loading
customers.to_parquet("customers.parquet")
orders.to_parquet("orders.parquet")

Always apply a temporal cutoff at extraction time. This is your first defense against temporal leakage.
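As a cheap second line of defense, assert the cutoff again after extraction. A sketch (the `assert_temporal_cutoff` helper is hypothetical; the `order_date` column matches the query above):

```python
import pandas as pd

def assert_temporal_cutoff(df: pd.DataFrame, ts_col: str, cutoff: str) -> None:
    """Fail loudly if any extracted row is at or after the cutoff timestamp."""
    ts = pd.to_datetime(df[ts_col])
    leaked = int((ts >= pd.Timestamp(cutoff)).sum())
    if leaked:
        raise ValueError(f"{leaked} rows in '{ts_col}' at or after {cutoff}: temporal leakage")

# Example: these orders all predate the cutoff, so the check passes silently
orders = pd.DataFrame({"order_date": ["2026-01-15", "2026-02-20"]})
assert_temporal_cutoff(orders, "order_date", "2026-03-01")
```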

Step 3: Convert to PyG HeteroData

Convert extracted DataFrames into PyG’s HeteroData format with proper feature encoding and edge construction. See the heterogeneous graphs guide for the full conversion pipeline.
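The core of that conversion is remapping warehouse IDs to contiguous 0..N-1 indices and stacking each foreign key pair into an edge index. A minimal sketch with pandas and NumPy (the `build_edge_index` helper is a hypothetical name; the resulting `(2, num_edges)` array is what you would assign to a `HeteroData` edge type as a torch tensor):

```python
import numpy as np
import pandas as pd

def build_edge_index(edges: pd.DataFrame, src_ids: pd.Series, dst_ids: pd.Series,
                     src_col: str, dst_col: str) -> np.ndarray:
    """Map raw warehouse IDs to contiguous indices; return a (2, E) edge index."""
    src_map = {v: i for i, v in enumerate(src_ids)}
    dst_map = {v: i for i, v in enumerate(dst_ids)}
    src = edges[src_col].map(src_map).to_numpy()
    dst = edges[dst_col].map(dst_map).to_numpy()
    return np.stack([src, dst])

# Toy stand-ins for the Parquet exports from Step 2
customers = pd.DataFrame({"customer_id": ["C7", "C9"]})
products = pd.DataFrame({"product_id": ["P1", "P2", "P3"]})
orders = pd.DataFrame({"customer_id": ["C9", "C7"], "product_id": ["P3", "P1"]})

edge_index = build_edge_index(orders, customers["customer_id"],
                              products["product_id"],
                              "customer_id", "product_id")
# In the real pipeline:
# data["customer", "orders", "product"].edge_index = torch.from_numpy(edge_index)
```

Row order in the node-feature matrix must match the index mapping used here, so derive both from the same DataFrame.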

Step 4: Temporal correctness with Time Travel

Snowflake’s Time Travel feature lets you query data as it existed at a past point in time, within the retention window. This is invaluable for creating temporally correct training data:

time_travel.sql
-- Get customer features as they were on 2026-01-01
SELECT * FROM customers
AT (TIMESTAMP => '2026-01-01 00:00:00'::TIMESTAMP);

-- Or use BEFORE for pre-event snapshots
SELECT * FROM orders
BEFORE (STATEMENT => '<query_id>');  -- ID of the statement to snapshot before

Time Travel retention is capped at 90 days even on Enterprise edition (and one day on Standard). For longer historical training data, maintain your own snapshot tables.
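Beyond the retention window, the same point-in-time logic applies to your own snapshot tables. A pandas sketch of the as-of pattern (the `as_of_snapshot` helper is hypothetical, and it assumes a change-history table with one row per feature update; in SQL this is `ROW_NUMBER() OVER (PARTITION BY key ORDER BY ts DESC)` filtered to 1):

```python
import pandas as pd

def as_of_snapshot(history: pd.DataFrame, key: str,
                   ts_col: str, as_of: str) -> pd.DataFrame:
    """Latest row per entity strictly before `as_of`."""
    before = history[pd.to_datetime(history[ts_col]) < pd.Timestamp(as_of)]
    # After sorting ascending by timestamp, .last() keeps the newest row per key
    return before.sort_values(ts_col).groupby(key, as_index=False).last()

history = pd.DataFrame({
    "customer_id": ["C1", "C1", "C2"],
    "updated_at": ["2025-11-01", "2026-02-01", "2025-12-15"],
    "lifetime_value": [100.0, 250.0, 80.0],
})
snap = as_of_snapshot(history, "customer_id", "updated_at", "2026-01-01")
# C1's value as of 2026-01-01 is 100.0; the 2026-02-01 update is excluded
```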

What breaks in production

  • Schema drift: Snowflake schemas change (columns added, renamed, deleted). Your extraction pipeline breaks silently, producing graphs with missing or wrong features. Add schema validation checks before each extraction.
  • Warehouse contention: ML extraction queries compete with analytics queries for Snowflake compute. Use a dedicated warehouse for ML workloads with auto-suspend to control costs.
  • Data freshness lag: The time between Snowflake data update and PyG model retraining determines prediction staleness. Automate the extraction-training pipeline with Airflow or similar.
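The schema-drift check from the first bullet can be as simple as comparing expected columns and dtypes before each extraction run. A sketch (the `EXPECTED` registry and `validate_schema` helper are hypothetical; the column names follow the customers example above):

```python
import pandas as pd

# Expected columns and pandas dtypes per table, kept in version control
EXPECTED = {
    "customers": {"customer_id": "object", "age": "int64",
                  "tenure_days": "int64", "lifetime_value": "float64"},
}

def validate_schema(df: pd.DataFrame, table: str) -> None:
    """Raise before training if expected columns are missing or retyped."""
    expected = EXPECTED[table]
    missing = set(expected) - set(df.columns)
    if missing:
        raise ValueError(f"{table}: missing columns {sorted(missing)}")
    wrong = {c: str(df[c].dtype) for c in expected if str(df[c].dtype) != expected[c]}
    if wrong:
        raise ValueError(f"{table}: dtype drift {wrong}")
```

Run this on every extraction so a renamed or retyped column fails the pipeline instead of silently degrading the graph.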

Frequently asked questions

Can I train a GNN directly on Snowflake data?

Not directly. PyG requires in-memory tensors. You must extract data from Snowflake, convert it to PyG format (HeteroData), and load it into memory or a feature store. For large datasets, use Snowpark to run preprocessing inside Snowflake before extracting.

How do I extract a graph from Snowflake tables?

Each table becomes a node type. Foreign key (JOIN) relationships become edge types. Export each table's rows as node features and each foreign key relationship as an edge list. Use COPY INTO to export to Parquet, then read the Parquet files in Python and build a HeteroData object (or load pre-saved tensors with torch.load).

Can KumoRFM connect to Snowflake natively?

Yes. KumoRFM connects to Snowflake via native integration. You point it at your Snowflake tables, define a prediction task in PQL, and it builds the heterogeneous temporal graph automatically. No data export, no PyG code, no ETL pipeline.

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.