The business problem
Customer lifetime value determines how much to spend on acquisition, which customers to prioritize for retention, and how to allocate service resources. A customer who will spend $50K over 3 years justifies 100x the acquisition cost of one who will spend $500. Yet most companies use simple heuristics (RFM scoring, average order value) that miss the behavioral patterns distinguishing high-value trajectories from one-time buyers.
The difference between a $10K and a $100K customer is often not visible in their first purchase. It is in what they bought, what similar customers did next, and how their engagement pattern matches known high-value trajectories. This is graph-structured information.
Why flat ML fails
- RFM is lossy: Recency, frequency, and monetary value collapse the entire purchase history into three numbers. Two customers with identical RFM scores can have radically different trajectories depending on what they bought and how their behavior compares to similar customers.
- No cross-customer signal: Flat models predict each customer independently. GNNs leverage collaborative signal: if customers who buy product A then product B become high-value, a new customer who just bought A gets a higher LTV estimate.
- Category blindness: A customer who buys across 5 categories has different expansion potential than one concentrated in 1 category. The purchase graph naturally captures this breadth.
- Early-stage prediction: After just 1-2 purchases, individual features are sparse. Graph context from similar customers fills the gap.
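The lossiness claim is easy to demonstrate: two customers can be RFM-identical while differing in exactly the category breadth the purchase graph preserves. A minimal sketch with made-up purchase logs:

```python
from datetime import date

# Hypothetical purchase logs: (date, product_category, amount).
cust_a = [(date(2024, 1, 5), "electronics", 200.0),
          (date(2024, 6, 10), "electronics", 200.0)]
cust_b = [(date(2024, 1, 5), "electronics", 200.0),
          (date(2024, 6, 10), "kitchen", 200.0)]

def rfm(purchases, today=date(2024, 7, 1)):
    """Collapse a purchase history into (recency, frequency, monetary)."""
    recency = (today - max(d for d, _, _ in purchases)).days
    frequency = len(purchases)
    monetary = sum(amt for _, _, amt in purchases)
    return recency, frequency, monetary

# Identical RFM scores...
print(rfm(cust_a) == rfm(cust_b))                              # True
# ...but different category breadth, which RFM cannot see.
print({c for _, c, _ in cust_a} == {c for _, c, _ in cust_b})  # False
```

Any model trained only on the three RFM numbers must assign these two customers the same LTV; a graph model that sees the Product and Category nodes does not.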
The relational schema
Node types:
Customer (id, signup_date, channel, geo)
Product (id, category, price, margin)
Category (id, name, avg_basket_size)
Edge types:
Customer --[purchased]--> Product (amount, timestamp, qty)
Product --[belongs_to]--> Category
Customer --[co_shopper]--> Customer (shared_items_count)

Three node types, three edge types. The co_shopper edges connect customers with overlapping purchase histories, enabling collaborative LTV prediction.
PyG architecture: SAGEConv + regression head
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, HeteroConv, Linear
class LTVGNN(torch.nn.Module):
    def __init__(self, hidden_dim=128):
        super().__init__()
        self.customer_lin = Linear(-1, hidden_dim)
        self.product_lin = Linear(-1, hidden_dim)
        self.category_lin = Linear(-1, hidden_dim)
        # Reverse ('rev_purchased') edges let purchase signal flow back
        # into customer embeddings; without them, customers would only
        # receive messages over co_shopper edges. Add them to the data
        # with torch_geometric.transforms.ToUndirected().
        self.conv1 = HeteroConv({
            ('customer', 'purchased', 'product'): SAGEConv(
                hidden_dim, hidden_dim),
            ('product', 'rev_purchased', 'customer'): SAGEConv(
                hidden_dim, hidden_dim),
            ('product', 'belongs_to', 'category'): SAGEConv(
                hidden_dim, hidden_dim),
            ('customer', 'co_shopper', 'customer'): SAGEConv(
                hidden_dim, hidden_dim),
        }, aggr='mean')
        self.conv2 = HeteroConv({
            ('customer', 'purchased', 'product'): SAGEConv(
                hidden_dim, hidden_dim),
            ('product', 'rev_purchased', 'customer'): SAGEConv(
                hidden_dim, hidden_dim),
            ('product', 'belongs_to', 'category'): SAGEConv(
                hidden_dim, hidden_dim),
            ('customer', 'co_shopper', 'customer'): SAGEConv(
                hidden_dim, hidden_dim),
        }, aggr='mean')
        # Regression head for LTV prediction
        self.regressor = torch.nn.Sequential(
            Linear(hidden_dim, 64),
            torch.nn.ReLU(),
            Linear(64, 1),
        )

    def forward(self, x_dict, edge_index_dict):
        # Project each node type's raw features into a shared hidden space.
        x_dict['customer'] = self.customer_lin(x_dict['customer'])
        x_dict['product'] = self.product_lin(x_dict['product'])
        x_dict['category'] = self.category_lin(x_dict['category'])
        x_dict = {k: F.relu(v) for k, v in
                  self.conv1(x_dict, edge_index_dict).items()}
        x_dict = self.conv2(x_dict, edge_index_dict)
        # Predict 12-month forward LTV (one scalar per customer)
        return self.regressor(x_dict['customer']).squeeze(-1)

HeteroConv runs a separate SAGEConv per edge type and mean-aggregates the results, so purchase and co-shopper signals both feed each customer's embedding. The regression head outputs continuous LTV predictions. Train with Huber loss for robustness to outliers.
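The Huber recommendation is easy to see on a toy comparison (made-up numbers, not the training loop): with one badly mispredicted whale customer, MSE is dominated by that single error while Huber is not.

```python
import torch
import torch.nn.functional as F

# Four customers' log-spend targets; the model missed the last one badly.
pred = torch.tensor([3.0, 3.5, 4.0, 4.2])
target = torch.tensor([3.1, 3.4, 4.1, 9.0])   # last customer is a whale

mse = F.mse_loss(pred, target)
huber = F.huber_loss(pred, target, delta=1.0)
# Past |error| = delta, Huber grows linearly instead of quadratically,
# so the outlier contributes 4.3 to the sum instead of 23.04.
print(float(mse), float(huber))   # ~5.77 vs ~1.08
```

With `delta=1.0` on log-transformed targets, a customer has to be off by more than a full log-unit before the loss switches to its linear regime.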
Training considerations
- Label definition: 12-month forward revenue is the standard LTV target. Use log-transform to handle the heavy-tailed distribution of customer spend.
- Loss function: Huber loss or quantile regression. MSE is sensitive to high-value outliers. Quantile regression gives prediction intervals for segmentation.
- Temporal split: Train on customers who joined before time T using purchases before T. Predict LTV for the 12 months after T. Validate on the next cohort.
- Co-shopper edges: Build from purchase overlap (Jaccard similarity on purchased product sets). Threshold at similarity > 0.1 to keep the graph manageable.
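The co-shopper construction from the last bullet can be sketched in plain Python (toy baskets, hypothetical customer and product ids):

```python
from itertools import combinations

# Hypothetical purchase histories: customer id -> set of product ids.
baskets = {
    "c1": {"p1", "p2", "p3"},
    "c2": {"p2", "p3", "p4"},
    "c3": {"p9"},
}

def jaccard(a, b):
    """Overlap of two purchase sets: |intersection| / |union|."""
    return len(a & b) / len(a | b)

# Keep an edge only when overlap clears the 0.1 threshold.
co_shopper_edges = [
    (u, v, jaccard(baskets[u], baskets[v]))
    for u, v in combinations(baskets, 2)
    if jaccard(baskets[u], baskets[v]) > 0.1
]
# c1 and c2 share {p2, p3} out of 4 distinct products: Jaccard = 0.5.
# c3 overlaps with no one and stays isolated in the co_shopper graph.
print(co_shopper_edges)   # [('c1', 'c2', 0.5)]
```

At production scale the all-pairs loop is replaced by an inverted index or a sparse matrix product, but the thresholding logic is the same; the similarity score can also be kept as an edge weight.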
Expected performance
LTV is a regression task. The right metrics are Spearman rank correlation and normalized MAE, not AUROC:
- RFM heuristic: ~0.35 Spearman correlation
- LightGBM (flat-table): ~0.52 Spearman correlation
- GNN (SAGEConv regression): ~0.68 Spearman correlation
- KumoRFM (zero-shot): ~0.70 Spearman correlation
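Both metrics are cheap to compute without extra dependencies. A minimal sketch with toy predictions (no-ties Spearman for illustration; real evaluations should use a ties-aware implementation such as `scipy.stats.spearmanr`):

```python
def ranks(xs):
    """Rank positions of each value (assumes no ties)."""
    order = sorted(range(len(xs)), key=lambda i: xs[i])
    r = [0] * len(xs)
    for pos, i in enumerate(order):
        r[i] = pos
    return r

def spearman(pred, actual):
    # Spearman rho = Pearson correlation of the ranks.
    rp, ra = ranks(pred), ranks(actual)
    n = len(pred)
    mp, ma = sum(rp) / n, sum(ra) / n
    cov = sum((a - mp) * (b - ma) for a, b in zip(rp, ra))
    var_p = sum((a - mp) ** 2 for a in rp)
    var_a = sum((b - ma) ** 2 for b in ra)
    return cov / (var_p * var_a) ** 0.5

# Hypothetical predicted vs realized 12-month spend for five customers.
pred = [120.0, 800.0, 90.0, 3000.0, 400.0]
actual = [100.0, 950.0, 60.0, 2500.0, 500.0]

rho = spearman(pred, actual)   # 1.0: ordering is perfect despite errors
nmae = sum(abs(p - a) for p, a in zip(pred, actual)) / sum(actual)
```

This is why Spearman is the headline metric for LTV: acquisition and retention decisions rank customers against each other, so getting the ordering right matters more than nailing the dollar amounts, which normalized MAE tracks separately.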
Or use KumoRFM in one line
PREDICT total_spend_12m FOR customer
USING customer, product, category, purchase

One PQL query. KumoRFM constructs the purchase graph, captures cross-customer patterns, and outputs LTV predictions per customer.
KumoRFM replaces purchase graph construction, co-shopper edge computation, model training, and log-transform handling with a single query. It reaches roughly 0.70 Spearman correlation zero-shot, automatically discovering the collaborative patterns that drive LTV prediction.