
Customer LTV: Regression on Purchase Graphs

Knowing which customers will be worth $10K vs $100 over their lifetime changes every decision: acquisition spend, retention investment, and service levels. RFM models use 3 features. Here is how to use the entire purchase graph.


TL;DR

  1. Customer LTV is a graph regression problem. Customers, products, and categories form a purchase network where cross-customer patterns predict future value.
  2. SAGEConv with a regression head encodes each customer using their purchase history and the behavior of similar customers. GNNs do collaborative LTV prediction.
  3. On purchase prediction benchmarks, GNNs improve Spearman rank correlation by 0.15-0.20 over flat-table LightGBM. The graph captures value trajectories that RFM features miss.
  4. The PyG model is ~35 lines, but production LTV systems need integration with marketing automation, dynamic customer segmentation, and budget allocation models.
  5. KumoRFM predicts LTV with one PQL query, automatically capturing cross-customer purchase patterns without graph construction or feature engineering.

The business problem

Customer lifetime value determines how much to spend on acquisition, which customers to prioritize for retention, and how to allocate service resources. A customer who will spend $50K over 3 years justifies 10x more acquisition cost than one who will spend $500. Yet most companies use simple heuristics (RFM scoring, average order value) that miss the behavioral patterns that distinguish high-value trajectories from one-time buyers.

The difference between a $10K and a $100K customer is often not their first purchase. It is what they bought, what similar customers did next, and how their engagement pattern matches known high-value trajectories. This is graph-structured information.

Why flat ML fails

  • RFM is lossy: Recency, Frequency, Monetary collapse the entire purchase history into 3 numbers. Two customers with identical RFM scores can have radically different trajectories based on what they bought and how their behavior compares to similar customers.
  • No cross-customer signal: Flat models predict each customer independently. GNNs leverage collaborative signal: if customers who buy product A then product B become high-value, a new customer who just bought A gets a higher LTV estimate.
  • Category blindness: A customer who buys across 5 categories has different expansion potential than one concentrated in 1 category. The purchase graph naturally captures this breadth.
  • Early-stage prediction: After just 1-2 purchases, individual features are sparse. Graph context from similar customers fills the gap.

The relational schema

schema.txt
Node types:
  Customer  (id, signup_date, channel, geo)
  Product   (id, category, price, margin)
  Category  (id, name, avg_basket_size)

Edge types:
  Customer --[purchased]--> Product  (amount, timestamp, qty)
  Product  --[belongs_to]--> Category
  Customer --[co_shopper]--> Customer (shared_items_count)

Three node types. The co_shopper edges connect customers with overlapping purchase histories, enabling collaborative LTV prediction.

PyG architecture: SAGEConv + regression head

ltv_model.py
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, HeteroConv, Linear

class LTVGNN(torch.nn.Module):
    def __init__(self, hidden_dim=128):
        super().__init__()
        self.customer_lin = Linear(-1, hidden_dim)
        self.product_lin = Linear(-1, hidden_dim)
        self.category_lin = Linear(-1, hidden_dim)

        # Reverse 'purchased' edges let customer embeddings aggregate
        # product signal; without them, customers would only receive
        # co_shopper messages (add the rev edges via T.ToUndirected).
        self.conv1 = HeteroConv({
            ('customer', 'purchased', 'product'): SAGEConv(
                hidden_dim, hidden_dim),
            ('product', 'rev_purchased', 'customer'): SAGEConv(
                hidden_dim, hidden_dim),
            ('product', 'belongs_to', 'category'): SAGEConv(
                hidden_dim, hidden_dim),
            ('customer', 'co_shopper', 'customer'): SAGEConv(
                hidden_dim, hidden_dim),
        }, aggr='mean')

        self.conv2 = HeteroConv({
            ('customer', 'purchased', 'product'): SAGEConv(
                hidden_dim, hidden_dim),
            ('product', 'rev_purchased', 'customer'): SAGEConv(
                hidden_dim, hidden_dim),
            ('product', 'belongs_to', 'category'): SAGEConv(
                hidden_dim, hidden_dim),
            ('customer', 'co_shopper', 'customer'): SAGEConv(
                hidden_dim, hidden_dim),
        }, aggr='mean')

        # Regression head for LTV prediction
        self.regressor = torch.nn.Sequential(
            Linear(hidden_dim, 64),
            torch.nn.ReLU(),
            Linear(64, 1),
        )

    def forward(self, x_dict, edge_index_dict):
        x_dict['customer'] = self.customer_lin(x_dict['customer'])
        x_dict['product'] = self.product_lin(x_dict['product'])
        x_dict['category'] = self.category_lin(x_dict['category'])

        x_dict = {k: F.relu(v) for k, v in
                  self.conv1(x_dict, edge_index_dict).items()}
        x_dict = self.conv2(x_dict, edge_index_dict)

        # Predict 12-month forward LTV
        return self.regressor(x_dict['customer']).squeeze(-1)

HeteroConv aggregates purchase and co-shopper signals. The regression head outputs continuous LTV predictions. Train with Huber loss for robustness to outliers.
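The loss setup is independent of the model: log-transform the heavy-tailed spend target, then apply Huber. A minimal sketch with made-up numbers, where `pred` stands in for the regressor output:

```python
import torch

# Hypothetical 12-month spend targets: heavy-tailed, so log1p first
ltv = torch.tensor([120.0, 0.0, 5400.0, 35.0])
target = torch.log1p(ltv)

pred = torch.zeros(4, requires_grad=True)  # stand-in for model output
loss = torch.nn.HuberLoss(delta=1.0)(pred, target)
loss.backward()  # gradients flow exactly as in a real training step

# Invert the transform at inference time to get dollar-scale LTV
dollars = torch.expm1(target)
```

`HuberLoss` behaves like MSE near zero and like MAE for large residuals, so a single $100K outlier does not dominate the gradient.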

Training considerations

  • Label definition: 12-month forward revenue is the standard LTV target. Use log-transform to handle the heavy-tailed distribution of customer spend.
  • Loss function: Huber loss or quantile regression. MSE is sensitive to high-value outliers. Quantile regression gives prediction intervals for segmentation.
  • Temporal split: Train on customers who joined before time T using purchases before T. Predict LTV for the 12 months after T. Validate on the next cohort.
  • Co-shopper edges: Build from purchase overlap (Jaccard similarity on purchased product sets). Threshold at similarity > 0.1 to keep the graph manageable.
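The co_shopper construction in the last bullet can be sketched in plain Python; the customer IDs, baskets, and the 0.1 threshold are illustrative:

```python
from itertools import combinations

def jaccard(a: set, b: set) -> float:
    """Overlap of two purchased-product sets."""
    return len(a & b) / len(a | b)

# Hypothetical purchased-product sets per customer
baskets = {
    "c1": {"p1", "p2", "p3"},
    "c2": {"p2", "p3", "p4"},
    "c3": {"p9"},
}

# Keep only pairs above the similarity threshold
edges = []
for u, v in combinations(baskets, 2):
    sim = jaccard(baskets[u], baskets[v])
    if sim > 0.1:
        edges.append((u, v, sim))
# edges -> [("c1", "c2", 0.5)]; c3 shares nothing and stays isolated
```

For millions of customers, the all-pairs loop is replaced by an inverted index (product → buyers) or MinHash/LSH, but the thresholding logic is the same.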

Expected performance

LTV is a regression task. The right metrics are Spearman rank correlation and normalized MAE, not AUROC:

  • RFM heuristic: ~0.35 Spearman correlation
  • LightGBM (flat-table): ~0.52 Spearman correlation
  • GNN (SAGEConv regression): ~0.68 Spearman correlation
  • KumoRFM (zero-shot): ~0.70 Spearman correlation
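Both metrics are easy to compute directly. This sketch assumes no tied predictions (ties need average ranks, as in `scipy.stats.spearmanr`); the example values are made up:

```python
import numpy as np

def spearman(a: np.ndarray, b: np.ndarray) -> float:
    # Rank both vectors (no ties assumed), then Pearson on the ranks
    ra = np.argsort(np.argsort(a)).astype(float)
    rb = np.argsort(np.argsort(b)).astype(float)
    ra -= ra.mean()
    rb -= rb.mean()
    return float((ra @ rb) / np.sqrt((ra @ ra) * (rb @ rb)))

def normalized_mae(pred: np.ndarray, y: np.ndarray) -> float:
    # MAE scaled by mean absolute spend, comparable across cohorts
    return float(np.mean(np.abs(pred - y)) / np.mean(np.abs(y)))

y = np.array([100.0, 500.0, 50.0, 2000.0])      # actual 12-month spend
pred = np.array([150.0, 400.0, 80.0, 1500.0])   # model predictions
rho = spearman(pred, y)  # 1.0 here: the ranking is perfect
```

Rank correlation is the right headline metric because most downstream decisions (tiering, budget allocation) depend on ordering customers correctly, not on exact dollar amounts.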

Or use KumoRFM in one line

KumoRFM PQL
PREDICT total_spend_12m FOR customer
USING customer, product, category, purchase

One PQL query. KumoRFM constructs the purchase graph, captures cross-customer patterns, and outputs LTV predictions per customer.

KumoRFM replaces purchase graph construction, co-shopper edge computation, model training, and log-transform handling with a single query. It reaches ~0.70 Spearman rank correlation zero-shot, automatically discovering the collaborative patterns that drive LTV prediction.

Frequently asked questions

Why use GNNs for customer LTV instead of RFM models?

RFM (Recency, Frequency, Monetary) models use three aggregate features per customer. GNNs use the entire purchase graph: what products a customer bought, who else bought those products, what those similar customers bought next, and how purchase patterns cluster. This captures cross-customer signals that predict future value far better than individual RFM features.

What graph structure works for LTV prediction?

Nodes are customers, products, and categories. Edges connect customers to products they purchased (with amount and timestamp), products to their categories, and customers to customer segments. The graph structure captures purchase affinity patterns that drive future spend.

How do you frame LTV as a regression problem in PyG?

Each customer node gets a target value: total spend in the next 12 months. The GNN produces a continuous embedding per customer, which feeds into a regression head (linear layers). Train with MSE or Huber loss. The graph context helps the model distinguish high-value trajectories from one-time purchasers.
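The target construction described here can be sketched with pandas on a toy transactions table (dates, IDs, and amounts are made up):

```python
import pandas as pd

purchases = pd.DataFrame({
    "customer_id": ["c1", "c1", "c2"],
    "amount": [100.0, 50.0, 20.0],
    "ts": pd.to_datetime(["2024-02-01", "2024-11-15", "2025-06-01"]),
})

T = pd.Timestamp("2024-01-01")  # cutoff: features before T, label after
horizon = purchases[(purchases["ts"] > T) &
                    (purchases["ts"] <= T + pd.DateOffset(months=12))]
labels = horizon.groupby("customer_id")["amount"].sum()
# c1 spent 150.0 inside the window; c2's purchase falls outside it
```

Customers missing from `labels` (like c2 here) get a target of 0, which is exactly the heavy mass at zero that motivates the log1p transform and Huber loss above.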

How does the graph help with new customer LTV prediction?

A new customer with just one purchase can be embedded by looking at the product they bought and who else bought that product. The GNN transfers value patterns from similar customers, providing a much better LTV estimate than individual purchase history alone. This is essentially collaborative LTV prediction.

Can KumoRFM predict customer LTV from purchase data?

Yes. KumoRFM takes your customer, product, and transaction tables and predicts lifetime value with a single PQL query. It automatically constructs the purchase graph, captures cross-customer patterns, and outputs LTV predictions per customer.

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.