A two-tower model uses separate encoders for users and items in recommendation systems. The user tower maps user features and interaction history to an embedding vector. The item tower maps item features to an embedding vector in the same space. Relevance is scored by the dot product between the two embeddings. This simple architecture enables recommendations at massive scale because item embeddings can be precomputed and indexed for sub-millisecond retrieval.
Most large-scale recommenders at companies like YouTube, Netflix, Spotify, Amazon, and Pinterest use some variant of two-tower retrieval. The architecture scales to billions of users and millions of items because encoding and scoring are decoupled.
Why two towers instead of one
A single cross-attention model that jointly processes user and item features produces better scores but cannot scale at serving time. To recommend from 10 million items, you would need 10 million forward passes per user request. A two-tower model requires exactly one forward pass (the user tower) plus an approximate nearest neighbor (ANN) lookup.
- One model: score = f(user_features, item_features). Must run for every candidate item. O(n) at serving time.
- Two towers: user_emb = f_user(user_features), item_emb = f_item(item_features). Score = dot(user_emb, item_emb). Item embeddings precomputed. O(1) + ANN at serving time.
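The decoupling above can be sketched end to end in NumPy. The tower functions here are random-projection stand-ins (hypothetical, just to make the shapes concrete), and the brute-force `argsort` stands in for a real ANN index:

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in towers: any feature -> embedding map works for this sketch.
W_item = rng.normal(size=(32, 128))
W_user = rng.normal(size=(16, 128))
f_item = lambda feats: feats @ W_item  # item tower (stub)
f_user = lambda feats: feats @ W_user  # user tower (stub)

# Offline: encode the whole catalog once.
item_features = rng.normal(size=(10_000, 32))
item_embs = f_item(item_features)          # (10_000, 128), precomputed

# Online: one forward pass for the user, then a nearest-neighbor lookup.
user_emb = f_user(rng.normal(size=(16,)))  # (128,)
scores = item_embs @ user_emb              # dot product against every item
top10 = np.argsort(-scores)[:10]           # an ANN index would replace this
```

The per-request cost is one matrix-vector product for the user tower plus the lookup; nothing about the 10,000 items is recomputed online.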
GNN-enhanced two-tower model
Standard two towers encode users and items from features alone. GNN-enhanced towers also incorporate the interaction graph:
```python
import torch
from torch_geometric.nn import SAGEConv


class GNNTower(torch.nn.Module):
    """Shared architecture for user and item towers."""

    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.conv1 = SAGEConv(in_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, out_dim)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)


class TwoTowerGNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.user_tower = GNNTower(16, 64, 128)
        self.item_tower = GNNTower(32, 64, 128)

    def encode_users(self, user_x, user_edge_index):
        return self.user_tower(user_x, user_edge_index)

    def encode_items(self, item_x, item_edge_index):
        return self.item_tower(item_x, item_edge_index)

    def score(self, user_emb, item_emb):
        return (user_emb * item_emb).sum(dim=-1)  # dot product
```

Each tower is a GNN that encodes structural information. Users learn from their purchased items; items learn from their user base.
Training with contrastive loss
Two-tower models are trained with contrastive learning: for each user, push their embedding close to purchased items (positives) and far from non-purchased items (negatives):
```python
import torch
import torch.nn.functional as F


def contrastive_loss(user_emb, pos_item_emb, neg_item_embs, temperature=0.1):
    # Positive score: user's actual purchase
    pos_score = (user_emb * pos_item_emb).sum(dim=-1)
    # Negative scores: random non-purchased items
    neg_scores = (user_emb.unsqueeze(1) * neg_item_embs).sum(dim=-1)
    # InfoNCE loss: maximize pos_score relative to negatives
    logits = torch.cat([pos_score.unsqueeze(1), neg_scores], dim=1)
    labels = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)
    return F.cross_entropy(logits / temperature, labels)


# For each user:
# - Sample 1 positive item (actual purchase)
# - Sample K negative items (non-purchased)
# - Push user embedding toward positive, away from negatives
```

InfoNCE contrastive loss. The temperature controls the sharpness of the distribution over candidates.
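A minimal training step under toy assumptions — linear layers stand in for the GNN towers, and the batch is random — showing how the sampled positives and negatives flow through the InfoNCE loss:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy towers; the real towers would be the GNN encoders above.
user_tower = torch.nn.Linear(16, 128)
item_tower = torch.nn.Linear(32, 128)
opt = torch.optim.Adam(
    list(user_tower.parameters()) + list(item_tower.parameters()), lr=1e-3
)

B, K = 8, 5  # batch of 8 users, 5 sampled negatives per user
user_emb = user_tower(torch.randn(B, 16))    # (B, 128)
pos_emb = item_tower(torch.randn(B, 32))     # (B, 128) purchased items
neg_emb = item_tower(torch.randn(B, K, 32))  # (B, K, 128) non-purchased items

pos = (user_emb * pos_emb).sum(-1, keepdim=True)  # (B, 1)
neg = (user_emb.unsqueeze(1) * neg_emb).sum(-1)   # (B, K)
logits = torch.cat([pos, neg], dim=1) / 0.1       # temperature = 0.1
loss = F.cross_entropy(logits, torch.zeros(B, dtype=torch.long))

opt.zero_grad()
loss.backward()
opt.step()
```

The label for every row is index 0 — the positive — so minimizing cross-entropy pushes each user's dot product with the purchased item above its dot products with the K negatives.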
Enterprise example: e-commerce product discovery
An e-commerce platform with 50 million users and 5 million products needs sub-50ms recommendation latency:
- Offline: train GNN two-tower model on the user-product interaction graph. Compute all 5M product embeddings. Index them with FAISS.
- Online (per request): compute user embedding (1 GNN forward pass with neighbor sampling, ~10ms). Query FAISS for top 100 nearest product embeddings (~5ms). Re-rank top 100 with a cross-attention model (~20ms).
- Total latency: ~35ms for personalized recommendations from 5 million products.
The GNN towers ensure that a user who bought running shoes gets recommendations reflecting collaborative signal: other users who bought those same shoes also bought running socks, fitness trackers, and hydration packs.