
Product Recommendations: Two-Tower GraphSAGE for E-Commerce

Product recommendations drive 35% of Amazon's revenue. Collaborative filtering misses cold-start items and cross-category patterns. Here is how to build a GNN recommendation engine that sees the full user-product graph.


TL;DR

  • E-commerce recommendation is a link prediction problem on a bipartite user-product graph. GNNs capture multi-hop patterns: users who bought X also viewed Y, which is frequently bought with Z.
  • Two-tower GraphSAGE produces separate user and item embeddings. At serving time, retrieve top-K items via approximate nearest neighbor search in sub-10ms.
  • GraphSAGE is inductive: new users and items get embeddings immediately from their features and connections, solving the cold-start problem that plagues collaborative filtering.
  • On RelBench recommendation benchmarks, GNNs achieve 75.83 AUROC vs 62.44 for flat-table baselines. The graph captures collaborative signal that feature engineering cannot replicate.
  • KumoRFM generates recommendations with one PQL query, automatically handling graph construction, architecture selection, and serving: 76.71 AUROC zero-shot.

The business problem

Recommendation engines drive 35% of purchases on Amazon and 75% of what people watch on Netflix. A 1% improvement in recommendation quality translates to millions in incremental revenue for large e-commerce platforms. The challenge is scale: millions of users, millions of products, and billions of interactions, all changing in real time.

Traditional collaborative filtering (matrix factorization) decomposes the user-item interaction matrix into latent factors. It works well for popular items with many interactions but fails on cold-start items, long-tail products, and cross-category recommendations. It also ignores item attributes, user demographics, and session context.

Why flat ML fails

Feature-based models (XGBoost on user-item feature pairs) add attributes but lose the collaborative signal. They treat each user-item pair independently, missing the transitive patterns that make recommendations powerful:

  • No transitivity: “Users who bought A also bought B” requires seeing the graph, not just features of A and B
  • Cold-start blindness: New items with zero interactions get zero signal, regardless of how rich their features are
  • Category silos: Cross-category patterns (camera buyers also buy tripods) require multi-hop graph traversal
  • Temporal decay: Recent interactions should matter more, but flat features lose temporal ordering
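The transitivity point can be made concrete with a toy purchase log (hypothetical data, not from any real catalog). A flat model scoring (user, item) pairs in isolation never sees that u2 reaches the tripod through other camera buyers; a two-hop traversal does:

```python
# Toy purchase log: user -> set of purchased products (hypothetical data)
purchases = {
    "u1": {"camera", "tripod"},
    "u2": {"camera"},
    "u3": {"camera", "sd_card"},
}

def two_hop_candidates(user):
    """Products reachable via 'users who bought the same items' (two hops)."""
    owned = purchases[user]
    candidates = set()
    for other, items in purchases.items():
        if other != user and owned & items:   # shares at least one item
            candidates |= items - owned       # their other purchases
    return candidates

print(two_hop_candidates("u2"))  # tripod (via u1) and sd_card (via u3)
```

A GNN learns a weighted, feature-aware version of exactly this traversal instead of a hard set operation.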

The relational schema

schema.txt
Node types:
  User     (id, age_bucket, geo, signup_date)
  Product  (id, category, brand, price, description_emb)
  Session  (id, timestamp, device, duration)

Edge types:
  User    --[purchased]-->  Product   (timestamp, quantity)
  User    --[viewed]-->     Product   (timestamp, dwell_time)
  User    --[has_session]--> Session
  Session --[contains]-->   Product   (position, click)
  Product --[co_purchased]--> Product (frequency)

Three node types, five edge types. The co_purchased edges encode collaborative signal directly in the graph.
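Before any convolutions run, raw ids from the interaction tables must become a COO `edge_index` with contiguous integer indices per node type. A minimal torch-only sketch, with a hypothetical three-row purchase log:

```python
import torch

# Hypothetical interaction log: (user_id, product_id) purchase pairs
raw_edges = [("u1", "p9"), ("u1", "p3"), ("u2", "p9")]

# Map string ids to contiguous integer indices per node type
user_ids = sorted({u for u, _ in raw_edges})
prod_ids = sorted({p for _, p in raw_edges})
user_idx = {u: i for i, u in enumerate(user_ids)}
prod_idx = {p: i for i, p in enumerate(prod_ids)}

# COO edge_index in PyG's [2, num_edges] convention: row = source, col = target
edge_index = torch.tensor(
    [[user_idx[u] for u, p in raw_edges],
     [prod_idx[p] for u, p in raw_edges]]
)
print(edge_index.shape)  # torch.Size([2, 3])
```

In a full pipeline each edge type in the schema gets its own such tensor (typically stored in a PyG `HeteroData` object keyed by `(src_type, relation, dst_type)`).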

PyG architecture: two-tower GraphSAGE

The two-tower design separates user and item encoding for efficient serving. Each tower uses SAGEConv layers to aggregate neighborhood information. The user tower aggregates from purchased and viewed products over bipartite user-item edges. The item tower aggregates over item-item co-purchase edges.

rec_model.py
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, Linear

class TwoTowerGNN(torch.nn.Module):
    def __init__(self, hidden_dim=128):
        super().__init__()
        # Shared feature projections (input dims inferred via Linear(-1, ...))
        self.user_lin = Linear(-1, hidden_dim)
        self.item_lin = Linear(-1, hidden_dim)

        # User tower: bipartite convs aggregate item -> user messages,
        # so SAGEConv takes a (source, target) channel tuple
        self.user_conv1 = SAGEConv((hidden_dim, hidden_dim), hidden_dim)
        self.user_conv2 = SAGEConv((hidden_dim, hidden_dim), hidden_dim)

        # Item tower: convs over homogeneous item-item co-purchase edges
        self.item_conv1 = SAGEConv(hidden_dim, hidden_dim)
        self.item_conv2 = SAGEConv(hidden_dim, hidden_dim)

    def encode_users(self, x_user, x_item, edge_index_i2u):
        # edge_index_i2u: item -> user edges (purchased / viewed)
        x_u = self.user_lin(x_user)
        x_i = self.item_lin(x_item)
        x_u = F.relu(self.user_conv1((x_i, x_u), edge_index_i2u))
        x_u = self.user_conv2((x_i, x_u), edge_index_i2u)
        return F.normalize(x_u, dim=-1)

    def encode_items(self, x_item, edge_index_i2i):
        # edge_index_i2i: item -> item co-purchase edges
        x = self.item_lin(x_item)
        x = F.relu(self.item_conv1(x, edge_index_i2i))
        x = self.item_conv2(x, edge_index_i2i)
        return F.normalize(x, dim=-1)

    def forward(self, x_user, x_item, edge_i2u, edge_i2i,
                pos_edges, neg_edges):
        user_emb = self.encode_users(x_user, x_item, edge_i2u)
        item_emb = self.encode_items(x_item, edge_i2i)

        # BPR loss: positive (user, item) pairs should outscore negatives
        pos_score = (user_emb[pos_edges[0]] *
                     item_emb[pos_edges[1]]).sum(dim=-1)
        neg_score = (user_emb[neg_edges[0]] *
                     item_emb[neg_edges[1]]).sum(dim=-1)
        loss = -F.logsigmoid(pos_score - neg_score).mean()
        return loss

Two-tower GraphSAGE with BPR loss. At serving time, precompute item embeddings, then retrieve top-K via FAISS or ScaNN for sub-10ms latency.
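The retrieval step can be sketched with a brute-force dot product; random tensors stand in for trained tower outputs here. On L2-normalized embeddings, cosine similarity is just an inner product, so FAISS's `IndexFlatIP` (or ScaNN) replaces exactly this matmul at catalog scale:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_items, dim, k = 1000, 128, 5

# Stand-ins for precomputed, L2-normalized tower outputs
item_emb = F.normalize(torch.randn(num_items, dim), dim=-1)
user_emb = F.normalize(torch.randn(dim), dim=-1)

# Cosine similarity == dot product on normalized vectors; an ANN index
# replaces this exhaustive scan in production
scores = item_emb @ user_emb
topk = torch.topk(scores, k)
print(topk.indices.tolist())  # item ids to recommend, best first
```

The point of the two-tower split is that `item_emb` is computed offline once; only the single `user_emb` and the top-K lookup happen on the request path.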

Training considerations

  • Negative sampling: For each positive (user, item) pair, sample 5-10 negative items. Use popularity-weighted sampling to create harder negatives.
  • Mini-batch training: Use PyG's LinkNeighborLoader to sample subgraphs around training edges. Full-graph training is infeasible at scale.
  • Temporal split: Train on interactions before time T, validate on T to T+7 days, test on T+7 to T+14. Never leak future interactions.
  • Embedding refresh: Item embeddings can be precomputed daily. User embeddings need real-time or near-real-time updates as new interactions arrive.
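The popularity-weighted negative sampling bullet can be sketched with `torch.multinomial`. The 0.75 smoothing exponent is a common word2vec-style choice, not something the pipeline above prescribes:

```python
import torch

torch.manual_seed(0)
num_items, num_pos, negs_per_pos = 100, 8, 5

# Interaction counts per item (hypothetical); popular items are sampled
# more often as negatives, which makes them harder for the model
item_pop = torch.rand(num_items) * 100

# Sampling probability proportional to popularity^0.75 (smoothed)
probs = item_pop.pow(0.75)
probs = probs / probs.sum()

# 5 negatives for each of 8 positive (user, item) pairs, with replacement
neg_items = torch.multinomial(
    probs, num_samples=num_pos * negs_per_pos, replacement=True
).view(num_pos, negs_per_pos)
print(neg_items.shape)  # torch.Size([8, 5])
```

In practice you would also filter out accidental collisions with the user's true positives before computing the BPR loss.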

Expected performance

On RelBench recommendation benchmarks:

  • Popularity baseline: ~45 AUROC
  • Matrix factorization: ~55 AUROC
  • LightGBM (flat-table): 62.44 AUROC
  • GNN (GraphSAGE two-tower): 75.83 AUROC
  • KumoRFM (zero-shot): 76.71 AUROC

Or use KumoRFM in one line

KumoRFM PQL
PREDICT product_id FOR user
USING user, product, purchase, view, session

One PQL query. KumoRFM constructs the user-product graph, trains a relational graph transformer, and serves top-K recommendations via API.

KumoRFM handles graph construction, architecture selection, negative sampling, temporal splitting, and serving automatically. It achieves 76.71 AUROC zero-shot, matching or exceeding hand-tuned two-tower models that take months to build and deploy.

Frequently asked questions

Why use GNNs instead of collaborative filtering for recommendations?

Collaborative filtering (matrix factorization) only uses the user-item interaction matrix. GNNs additionally encode item attributes, user demographics, session sequences, and multi-hop connections (users who bought X also browsed Y). This richer signal is especially valuable for cold-start users and items with few interactions.

What is a two-tower GNN architecture?

A two-tower model produces separate embeddings for users and items using two GNN encoders. At inference, you compute the user embedding once, then retrieve the top-K nearest item embeddings using approximate nearest neighbor search. This decoupled design enables sub-10ms retrieval over millions of items.

How does GraphSAGE handle new users with no purchase history?

GraphSAGE is inductive: it learns an aggregation function, not fixed embeddings. A new user with even a single browsing session can be embedded by aggregating features from the items they viewed. This is a major advantage over matrix factorization, which requires retraining to add new users.
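The cold-start mechanics reduce to neighborhood aggregation. A minimal sketch, assuming precomputed item embeddings and a hypothetical new user who viewed three items: mean-pooling their item embeddings is the same operation a trained mean-aggregator SAGEConv applies at its first layer, minus the learned weights:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
item_emb = F.normalize(torch.randn(1000, 128), dim=-1)  # precomputed item tower

# New user's only signal: three viewed items from one browsing session
viewed = torch.tensor([42, 7, 913])

# Mean aggregation over neighbor (item) embeddings, then renormalize so the
# result lives on the same unit sphere as trained user embeddings
new_user_emb = F.normalize(item_emb[viewed].mean(dim=0), dim=-1)
print(new_user_emb.shape)  # torch.Size([128])
```

The resulting vector can be fed straight into the top-K retrieval path with no retraining, which is exactly what matrix factorization cannot do.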

How many neighbors should GraphSAGE sample per hop?

Typical settings are 15-25 neighbors for the first hop and 10-15 for the second hop. More neighbors improve accuracy but increase compute. For real-time serving, use fewer neighbors (10, 5) with cached embeddings to keep latency under 50ms.
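The fanout idea can be illustrated with a pure-Python stand-in (the function name `sample_fanout` is hypothetical); PyG's `NeighborLoader` with `num_neighbors=[15, 10]` performs this per-hop capping for you, efficiently and in C++:

```python
import torch

def sample_fanout(neighbors, seeds, fanout):
    """Sample up to `fanout` neighbors per seed node (one hop)."""
    sampled = {}
    for s in seeds:
        nbrs = neighbors.get(s, [])
        if len(nbrs) <= fanout:
            sampled[s] = list(nbrs)  # keep all if under the cap
        else:
            perm = torch.randperm(len(nbrs))[:fanout]
            sampled[s] = [nbrs[i] for i in perm]
    return sampled

# Hypothetical adjacency: user u0 has 30 interactions, u1 has 3
neighbors = {"u0": [f"p{i}" for i in range(30)], "u1": ["p0", "p1", "p2"]}
out = sample_fanout(neighbors, ["u0", "u1"], fanout=15)
print(len(out["u0"]), len(out["u1"]))  # 15 3
```

Capping the fanout bounds the cost of each training or inference step regardless of how connected a hub node is, which is what makes the 10-25 ranges above workable at scale.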

Can KumoRFM generate product recommendations without building a GNN?

Yes. KumoRFM takes your relational database (users, products, purchases, views) and generates recommendations with a single PQL query. It automatically constructs the heterogeneous graph, trains a production-grade graph transformer, and serves predictions via API.

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.