The business problem
Recommendation engines drive 35% of purchases on Amazon and 75% of what people watch on Netflix. A 1% improvement in recommendation quality translates to millions in incremental revenue for large e-commerce platforms. The challenge is scale: millions of users, millions of products, and billions of interactions, all changing in real time.
Traditional collaborative filtering (matrix factorization) decomposes the user-item interaction matrix into latent factors. It works well for popular items with many interactions but fails on cold-start items, long-tail products, and cross-category recommendations. It also ignores item attributes, user demographics, and session context.
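To make the matrix-factorization baseline concrete, here is a minimal sketch (toy data, gradient descent in PyTorch; real systems use ALS or SGD over sparse matrices): the interaction matrix `R` is approximated by the product of two low-rank factor matrices, and scores for unseen pairs fall out of that product.

```python
import torch

torch.manual_seed(0)

# Toy interaction matrix: 4 users x 5 items (1 = purchased)
R = torch.tensor([[1., 0., 1., 0., 0.],
                  [1., 1., 0., 0., 0.],
                  [0., 0., 1., 1., 0.],
                  [0., 0., 0., 1., 1.]])

k = 2  # number of latent factors
U = torch.randn(4, k, requires_grad=True)  # user factors
V = torch.randn(5, k, requires_grad=True)  # item factors
opt = torch.optim.Adam([U, V], lr=0.05)

for _ in range(500):
    opt.zero_grad()
    loss = ((U @ V.T - R) ** 2).mean()  # reconstruction error
    loss.backward()
    opt.step()

# Recommendation scores for every (user, item) pair
scores = (U @ V.T).detach()
```

Note the weakness the rest of this section addresses: a brand-new item is an all-zero column in `R`, so its factor vector receives no gradient signal at all.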
Why flat ML fails
Feature-based models (XGBoost on user-item feature pairs) add attributes but lose the collaborative signal. They treat each user-item pair independently, missing the transitive patterns that make recommendations powerful:
- No transitivity: “Users who bought A also bought B” requires seeing the graph, not just features of A and B
- Cold-start blindness: New items with zero interactions get zero signal, regardless of how rich their features are
- Category silos: Cross-category patterns (camera buyers also buy tripods) require multi-hop graph traversal
- Temporal decay: Recent interactions should matter more, but flat features lose temporal ordering
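The transitivity point can be made concrete with a few lines of NumPy on a hypothetical purchase matrix: squaring the bipartite adjacency yields item-item co-purchase counts, a two-hop pattern (item → user → item) that no per-pair feature vector can express.

```python
import numpy as np

# Hypothetical bipartite purchases: rows = users, cols = products
#                 camera  tripod  lens
A = np.array([[1,      1,      0],   # user 0
              [1,      0,      1],   # user 1
              [1,      1,      1]])  # user 2

# Item-item co-purchase counts = A^T A: each entry counts the users
# who bought both items, i.e. a 2-hop walk through the graph.
co = A.T @ A
print(co[0, 1])  # camera & tripod bought together by 2 users -> 2
```

A flat model scoring (user, tripod) in isolation never sees that camera buyers form this bridge; the graph makes it a single matrix product, and a GNN learns such patterns at arbitrary depth.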
The relational schema
Node types:
User (id, age_bucket, geo, signup_date)
Product (id, category, brand, price, description_emb)
Session (id, timestamp, device, duration)
Edge types:
User --[purchased]--> Product (timestamp, quantity)
User --[viewed]--> Product (timestamp, dwell_time)
User --[has_session]--> Session
Session --[contains]--> Product (position, click)
Product --[co_purchased]--> Product (frequency)

Three node types, five edge types. The co_purchased edges encode the collaborative signal directly in the graph.
PyG architecture: two-tower GraphSAGE
The two-tower design separates user and item encoding so each tower can be run independently at serving time. Each tower stacks SAGEConv layers to aggregate neighborhood information: the user tower aggregates features from the products a user purchased or viewed, and the item tower aggregates over the co_purchased item-item edges.
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, Linear

class TwoTowerGNN(torch.nn.Module):
    def __init__(self, hidden_dim=128):
        super().__init__()
        # Lazy input projections: infer raw feature dims on first call
        self.user_lin = Linear(-1, hidden_dim)
        self.item_lin = Linear(-1, hidden_dim)
        # User tower: bipartite SAGE layers aggregating item -> user
        self.user_conv1 = SAGEConv((hidden_dim, hidden_dim), hidden_dim)
        self.user_conv2 = SAGEConv((hidden_dim, hidden_dim), hidden_dim)
        # Item tower: SAGE layers over co_purchased item-item edges
        self.item_conv1 = SAGEConv(hidden_dim, hidden_dim)
        self.item_conv2 = SAGEConv(hidden_dim, hidden_dim)

    def encode_users(self, x_user, x_item, edge_index_i2u):
        # edge_index_i2u: purchased/viewed edges reversed to item -> user,
        # so messages flow from product features into user embeddings
        u = self.user_lin(x_user)
        i = self.item_lin(x_item)
        u = F.relu(self.user_conv1((i, u), edge_index_i2u))
        u = self.user_conv2((i, u), edge_index_i2u)
        return F.normalize(u, dim=-1)

    def encode_items(self, x_item, edge_index_i2i):
        # edge_index_i2i: co_purchased item-item edges
        x = self.item_lin(x_item)
        x = F.relu(self.item_conv1(x, edge_index_i2i))
        x = self.item_conv2(x, edge_index_i2i)
        return F.normalize(x, dim=-1)

    def forward(self, x_user, x_item, edge_i2u, edge_i2i,
                pos_edges, neg_edges):
        user_emb = self.encode_users(x_user, x_item, edge_i2u)
        item_emb = self.encode_items(x_item, edge_i2i)
        # BPR loss: positive pairs should score higher than negatives
        pos_score = (user_emb[pos_edges[0]] *
                     item_emb[pos_edges[1]]).sum(dim=-1)
        neg_score = (user_emb[neg_edges[0]] *
                     item_emb[neg_edges[1]]).sum(dim=-1)
        return -F.logsigmoid(pos_score - neg_score).mean()

Two-tower GraphSAGE with BPR loss. At serving time, precompute item embeddings, then retrieve top-K via FAISS or ScaNN for sub-10ms latency.
Training considerations
- Negative sampling: For each positive (user, item) pair, sample 5-10 negative items. Use popularity-weighted sampling to create harder negatives.
- Mini-batch training: Use PyG's LinkNeighborLoader to sample subgraphs around training edges. Full-graph training is infeasible at scale.
- Temporal split: Train on interactions before time T, validate on T to T+7 days, test on T+7 to T+14. Never leak future interactions.
- Embedding refresh: Item embeddings can be precomputed daily. User embeddings need real-time or near-real-time updates as new interactions arrive.
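The popularity-weighted negative sampling mentioned above can be sketched in a few lines of PyTorch (toy counts; the 0.75 smoothing exponent is a common word2vec-style choice, not mandated here):

```python
import torch

torch.manual_seed(0)
num_items, num_neg = 1000, 5

# Toy per-item interaction counts; popular items are drawn more often
# as negatives, making them harder to rank below true positives
item_counts = torch.randint(1, 100, (num_items,)).float()
probs = item_counts ** 0.75   # dampen the head of the popularity curve
probs = probs / probs.sum()

pos_items = torch.tensor([3, 17, 42])   # one positive item per training pair
neg_items = torch.multinomial(
    probs, len(pos_items) * num_neg, replacement=True
).view(len(pos_items), num_neg)         # [num_pos, num_neg]
```

In practice you would also reject any sampled negative that the user actually interacted with before feeding the pairs into the BPR loss.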
Expected performance
On RelBench recommendation benchmarks:
- Popularity baseline: ~45 AUROC
- Matrix factorization: ~55 AUROC
- LightGBM (flat-table): 62.44 AUROC
- GNN (GraphSAGE two-tower): 75.83 AUROC
- KumoRFM (zero-shot): 76.71 AUROC
Or use KumoRFM in one line
PREDICT product_id FOR user
USING user, product, purchase, view, session

One PQL query. KumoRFM constructs the user-product graph, trains a relational graph transformer, and serves top-K recommendations via API.
KumoRFM handles graph construction, architecture selection, negative sampling, temporal splitting, and serving automatically. It achieves 76.71 AUROC zero-shot, matching or exceeding hand-tuned two-tower models that take months to build and deploy.