
Two-Tower Models: Separate Encoders for Users and Items

A two-tower model encodes users and items independently, enabling scalable recommendations via precomputed embeddings and approximate nearest neighbor search. GNNs make the towers structure-aware.


TL;DR

  1. A two-tower model has separate encoders (towers) for users and items. Each produces an embedding. Relevance = dot product between embeddings. This decouples encoding from scoring.
  2. Scalability: precompute all item embeddings offline. At serving time, compute only the user embedding and find nearest items via ANN search. Sub-10ms latency at millions of items.
  3. GNN-enhanced towers use message passing on the user-item interaction graph. User embeddings reflect purchase history; item embeddings reflect their user base. Collaborative filtering as architecture.
  4. Training uses contrastive learning with negative sampling: push user embeddings close to purchased items, far from non-purchased items.
  5. Cold start remains the key challenge. New entities have no graph neighborhood. Feature-based initialization and content-based fallback are standard solutions.

A two-tower model uses separate encoders for users and items in recommendation systems. The user tower maps user features and interaction history to an embedding vector. The item tower maps item features to an embedding vector in the same space. Relevance is scored by the dot product between the two embeddings. This simple architecture enables recommendations at massive scale because item embeddings can be precomputed and indexed for sub-millisecond retrieval.

Every major tech company uses some variant of two-tower models for retrieval: YouTube, Netflix, Spotify, Amazon, Pinterest. The architecture scales to billions of users and millions of items because encoding and scoring are decoupled.

Why two towers instead of one

A single cross-attention model that jointly processes user and item features produces better scores but cannot scale. To recommend from 10 million items, you would need 10 million forward passes per user request. A two-tower model requires exactly 1 forward pass (the user tower) plus an ANN lookup.

  • One model: score = f(user_features, item_features). Must run for every candidate item. O(n) at serving time.
  • Two towers: user_emb = f_user(user_features), item_emb = f_item(item_features). Score = dot(user_emb, item_emb). Item embeddings precomputed. O(1) + ANN at serving time.
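The cost difference can be sketched with plain linear layers standing in for the two architectures (toy sizes; `joint`, `f_user`, and `f_item` are illustrative stand-ins, not real ranking models):

```python
import torch

torch.manual_seed(0)
n_items, d = 1_000, 32

# Single joint model: must process every (user, item) pair.
joint = torch.nn.Sequential(torch.nn.Linear(2 * d, 1))
user_feat = torch.randn(d)
item_feats = torch.randn(n_items, d)
pairs = torch.cat([user_feat.expand(n_items, d), item_feats], dim=1)
joint_scores = joint(pairs).squeeze(-1)   # O(n) model evaluations per request

# Two towers: item embeddings computed once (offline in practice),
# then one user forward pass plus a single matrix-vector product.
f_user = torch.nn.Linear(d, 16)
f_item = torch.nn.Linear(d, 16)
item_embs = f_item(item_feats)            # precomputed offline
two_tower_scores = f_user(user_feat) @ item_embs.t()
```

Both produce one score per item, but only the two-tower path avoids running a model over every candidate at serving time.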

GNN-enhanced two-tower model

Standard two towers encode users and items from features alone. GNN-enhanced towers also incorporate the interaction graph:

gnn_two_tower.py
import torch
from torch_geometric.nn import SAGEConv

class GNNTower(torch.nn.Module):
    """Shared architecture for user and item towers."""
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.conv1 = SAGEConv(in_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, out_dim)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)

class TwoTowerGNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.user_tower = GNNTower(16, 64, 128)
        self.item_tower = GNNTower(32, 64, 128)

    def encode_users(self, user_x, user_edge_index):
        return self.user_tower(user_x, user_edge_index)

    def encode_items(self, item_x, item_edge_index):
        return self.item_tower(item_x, item_edge_index)

    def score(self, user_emb, item_emb):
        return (user_emb * item_emb).sum(dim=-1)  # dot product

Each tower is a GNN that encodes structural information. Users learn from their purchased items; items learn from their user base.

Training with contrastive loss

Two-tower models are trained with contrastive learning: for each user, push their embedding close to purchased items (positives) and far from non-purchased items (negatives):

contrastive_training.py
import torch
import torch.nn.functional as F

def contrastive_loss(user_emb, pos_item_emb, neg_item_embs, temperature=0.1):
    # Positive score: user's actual purchase
    pos_score = (user_emb * pos_item_emb).sum(dim=-1)

    # Negative scores: random non-purchased items
    neg_scores = (user_emb.unsqueeze(1) * neg_item_embs).sum(dim=-1)

    # InfoNCE loss: maximize pos_score relative to negatives
    # (the positive sits in column 0, so all labels are 0)
    logits = torch.cat([pos_score.unsqueeze(1), neg_scores], dim=1)
    labels = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)
    return F.cross_entropy(logits / temperature, labels)

# For each user:
# - Sample 1 positive item (actual purchase)
# - Sample K negative items (non-purchased)
# - Push user embedding toward positive, away from negatives

InfoNCE contrastive loss. The temperature hyperparameter controls how sharply the softmax distribution concentrates on the highest-scoring candidates.
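One simple way to draw the K negatives per user is uniform sampling over the item catalog while masking out each user's positive; `sample_negatives` is an illustrative helper (production systems often prefer in-batch or popularity-weighted negatives):

```python
import torch

def sample_negatives(pos_item_ids, num_items, k):
    """Uniformly sample k negative item ids per user, resampling any
    draw that collides with that user's positive item."""
    neg = torch.randint(0, num_items, (pos_item_ids.size(0), k))
    collisions = neg == pos_item_ids.unsqueeze(1)
    while collisions.any():
        neg[collisions] = torch.randint(0, num_items, (int(collisions.sum()),))
        collisions = neg == pos_item_ids.unsqueeze(1)
    return neg

# 3 users, each with one positive purchase; draw 5 negatives apiece.
pos = torch.tensor([3, 7, 1])
negs = sample_negatives(pos, num_items=100, k=5)
```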

Enterprise example: e-commerce product discovery

An e-commerce platform with 50 million users and 5 million products needs sub-50ms recommendation latency:

  1. Offline: train GNN two-tower model on the user-product interaction graph. Compute all 5M product embeddings. Index them with FAISS.
  2. Online (per request): compute user embedding (1 GNN forward pass with neighbor sampling, ~10ms). Query FAISS for top 100 nearest product embeddings (~5ms). Re-rank top 100 with a cross-attention model (~20ms).
  3. Total latency: ~35ms for personalized recommendations from 5 million products.
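The retrieval step above can be prototyped with exact inner-product search as a stand-in for the ANN index (a FAISS index would replace the matrix product and top-k at the 5M scale; sizes here are toy):

```python
import torch

torch.manual_seed(0)
num_products, emb_dim = 5_000, 128   # toy stand-in for the 5M catalog

# Offline: precompute and "index" all product embeddings.
product_embs = torch.randn(num_products, emb_dim)

# Online, per request: one user embedding, then top-100 by inner product.
# Exact search here; FAISS/ScaNN/HNSW approximate this at scale.
user_emb = torch.randn(emb_dim)
scores = product_embs @ user_emb
top100 = scores.topk(100).indices    # candidates for cross-attention re-ranking
```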

The GNN towers ensure that a user who bought running shoes gets recommendations reflecting collaborative signal: other users who bought those same shoes also bought running socks, fitness trackers, and hydration packs.

Frequently asked questions

What is a two-tower model?

A two-tower model uses two separate neural network encoders: one for users and one for items. Each encoder produces an embedding vector. The relevance score is the dot product (or cosine similarity) between user and item embeddings. This architecture enables precomputing all item embeddings offline and serving recommendations via fast approximate nearest neighbor search.

Why use two towers instead of one model?

Scalability. A single cross-attention model that scores every (user, item) pair requires O(users x items) forward passes. A two-tower model computes user and item embeddings independently, then scores with a dot product. Item embeddings are precomputed once. At serving time, only the user embedding is computed live, then matched against the precomputed item index.

How do GNNs improve two-tower models?

Standard two-tower models encode users and items from features alone. GNN-enhanced two-tower models also incorporate graph structure: a user's embedding reflects their interaction history (purchased items), and an item's embedding reflects its user base. This collaborative signal significantly improves recommendation quality.

What is the cold start problem in two-tower models?

New users with no interaction history and new items with no user engagement have poor embeddings because the GNN has no edges to propagate through. Solutions include feature-based initialization (use demographics/metadata), content-based fallback (encode item descriptions), and few-shot propagation from similar entities.
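A feature-based fallback can be as simple as an MLP that projects raw metadata into the shared embedding space. `FeatureFallback` below is a hypothetical module, assumed to be trained to approximate the GNN tower's output (e.g., by regressing onto the embeddings of well-connected nodes):

```python
import torch

class FeatureFallback(torch.nn.Module):
    """Project raw features into the shared embedding space for
    cold-start entities that have no graph neighborhood yet."""
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.proj = torch.nn.Sequential(
            torch.nn.Linear(in_dim, out_dim), torch.nn.ReLU(),
            torch.nn.Linear(out_dim, out_dim))

    def forward(self, x):
        return self.proj(x)

# A brand-new item with metadata features but zero interactions
# still gets an embedding usable in the same ANN index.
fallback = FeatureFallback(in_dim=32, out_dim=128)
new_item_emb = fallback(torch.randn(1, 32))
```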

What is approximate nearest neighbor (ANN) search?

ANN algorithms (FAISS, ScaNN, HNSW) find the closest item embeddings to a user embedding without comparing against every item. They use indexing structures that trade a small accuracy loss for massive speed gains. At scale (millions of items), ANN search returns top-100 candidates in single-digit milliseconds.

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.