The business problem
For SaaS companies, reducing churn by just 1% can increase company valuation by 12%. Customer acquisition costs 5-25x more than retention, making churn prediction one of the highest-ROI ML applications. The goal: identify at-risk accounts 30-90 days before cancellation, giving customer success teams time to intervene.
Traditional churn models use features like login frequency, feature adoption percentage, support ticket count, and contract renewal date. These features describe the user in isolation. They miss the relational context: is the user's entire team disengaging? Are they the only active user on a 50-seat account? Did their champion (the power user who drove adoption) just leave?
Why flat ML fails
Flat-table churn models (logistic regression, XGBoost) operate on per-user feature vectors. They cannot capture:
- Social anchoring: Users with active teammates churn 60% less. But “number of active teammates” is a lossy feature that misses who those teammates are and what they do.
- Adoption depth: Two users who each use 3 features are not interchangeable if they use 3 different features. The feature-usage graph captures which feature clusters drive stickiness, not just how many features a user touches.
- Contagion effects: Churn is social. When a team lead churns, their reports follow within 60 days. Flat models cannot propagate this risk signal through the graph.
- Temporal patterns: The sequence of feature adoption and disadoption matters. A user who tried a feature and stopped is different from one who never tried it.
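A toy illustration of the social-anchoring point, with hypothetical users and counts: two users whose flat feature vectors are identical, yet whose neighborhoods tell opposite stories.

```python
# Flat features: Alice and Bob are indistinguishable to a tabular model.
flat = {
    "alice": {"logins_30d": 12, "active_teammates": 2},
    "bob":   {"logins_30d": 12, "active_teammates": 2},
}
assert flat["alice"] == flat["bob"]

# Graph context: *who* those teammates are, and how engaged they are.
collaborates = {
    "alice": ["champion", "power_user"],  # anchored to the account's core
    "bob":   ["lurker_1", "lurker_2"],    # teammates barely above "active"
}
logins_30d = {"champion": 60, "power_user": 45, "lurker_1": 2, "lurker_2": 3}

def teammate_engagement(user):
    # Mean teammate activity: a signal a GNN can aggregate over the
    # collaboration edges, but which a scalar count throws away.
    peers = collaborates[user]
    return sum(logins_30d[p] for p in peers) / len(peers)

print(teammate_engagement("alice"), teammate_engagement("bob"))  # 52.5 2.5
```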
The relational schema
Node types:
- User (id, role, signup_date, plan_tier)
- Feature (id, category, complexity, release_date)
- Account (id, seats, mrr, contract_end)
- Ticket (id, severity, resolution_time, csat_score)

Edge types:
- User --[uses]--> Feature (frequency, last_used)
- User --[belongs_to]--> Account
- User --[filed]--> Ticket
- User --[collaborates]--> User (shared_docs, messages)

Four node types capture the full context: who the user is, what they use, who they work with, and how they interact with support.
PyG architecture: GATConv with attention
We use GATConv because neighbor importance varies dramatically. An active teammate sending daily messages is a much stronger retention signal than a dormant user on the same account. Attention weights let the model learn this distinction automatically.
```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv, HeteroConv, Linear


class ChurnGNN(torch.nn.Module):
    def __init__(self, hidden_dim=64, heads=4):
        super().__init__()
        # Lazy linear layers project each node type onto a shared hidden size.
        self.user_lin = Linear(-1, hidden_dim)
        self.feature_lin = Linear(-1, hidden_dim)
        self.account_lin = Linear(-1, hidden_dim)
        self.ticket_lin = Linear(-1, hidden_dim)

        # We predict on user nodes, so messages must flow *into* them: the
        # convs operate on the reverse edge types (added via T.ToUndirected()
        # on the HeteroData). add_self_loops=False because self-loops are
        # ill-defined on bipartite edge types.
        def make_conv():
            return HeteroConv({
                ('feature', 'rev_uses', 'user'): GATConv(
                    hidden_dim, hidden_dim // heads, heads=heads,
                    add_self_loops=False),
                ('account', 'rev_belongs_to', 'user'): GATConv(
                    hidden_dim, hidden_dim // heads, heads=heads,
                    add_self_loops=False),
                ('ticket', 'rev_filed', 'user'): GATConv(
                    hidden_dim, hidden_dim // heads, heads=heads,
                    add_self_loops=False),
                ('user', 'collaborates', 'user'): GATConv(
                    hidden_dim, hidden_dim // heads, heads=heads,
                    add_self_loops=False),
            }, aggr='sum')

        self.conv1 = make_conv()
        self.conv2 = make_conv()
        self.classifier = Linear(hidden_dim, 1)

    def forward(self, x_dict, edge_index_dict):
        x_dict = {
            'user': self.user_lin(x_dict['user']),
            'feature': self.feature_lin(x_dict['feature']),
            'account': self.account_lin(x_dict['account']),
            'ticket': self.ticket_lin(x_dict['ticket']),
        }
        # HeteroConv only returns destination node types ('user' here), so
        # merge its output back into x_dict to keep the source embeddings
        # available for the second layer.
        out = {k: F.elu(v) for k, v in
               self.conv1(x_dict, edge_index_dict).items()}
        x_dict = {**x_dict, **out}
        x_dict = {**x_dict, **self.conv2(x_dict, edge_index_dict)}
        return torch.sigmoid(self.classifier(x_dict['user']).squeeze(-1))
```

HeteroConv wraps one GATConv per edge type. Messages flow along the reverse edges into user nodes, so the attention weights learn which teammates, features, and tickets are most predictive of churn.
Training considerations
- Label definition: Define churn as “no login for 30 consecutive days” or “cancellation within 90 days.” The choice affects model behavior significantly.
- Temporal split: Train on users who churned/stayed before time T, predict for users active at T. Never use future activity as features.
- Class imbalance: Monthly churn rates of 2-5% mean 95%+ of labels are negative. Use focal loss, positive-class weighting, or graph-aware oversampling (GraphSMOTE-style methods); vanilla SMOTE does not transfer directly to graph data.
- Feature engineering: Edge features (usage frequency, recency) carry strong signal. Encode them as edge attributes in PyG's HeteroData.
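For the class-imbalance point, a minimal focal-loss sketch in plain PyTorch (the standard formulation from the focal-loss literature; the alpha and gamma values are illustrative defaults, not tuned for this task):

```python
import torch

def focal_loss(probs, targets, alpha=0.25, gamma=2.0):
    """Focal loss on probabilities in (0, 1); down-weights easy examples."""
    probs = probs.clamp(1e-7, 1 - 1e-7)
    # pt: predicted probability of the true class for each example.
    pt = probs * targets + (1 - probs) * (1 - targets)
    # alpha_t: per-class weight (alpha for positives, 1 - alpha otherwise).
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    # (1 - pt)^gamma shrinks the loss on well-classified examples.
    return (-alpha_t * (1 - pt) ** gamma * pt.log()).mean()

# Being confidently wrong on a churner hurts far more than easy negatives.
y = torch.tensor([1., 0., 0., 0.])
good = focal_loss(torch.tensor([0.9, 0.1, 0.1, 0.1]), y)
bad = focal_loss(torch.tensor([0.1, 0.1, 0.1, 0.1]), y)
assert bad > good
```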
Expected performance
- Logistic regression (flat-table): ~58 AUROC
- LightGBM (flat-table): 62.44 AUROC
- GNN (GATConv heterogeneous): 75.83 AUROC
- KumoRFM (zero-shot): 76.71 AUROC
Or use KumoRFM in one line
```
PREDICT is_churned FOR user
USING user, feature, account, ticket
```

One PQL query. KumoRFM auto-constructs the user-activity graph, handles temporal dynamics, and outputs churn probabilities per user.
KumoRFM replaces the graph construction, model architecture, training loop, and temporal feature engineering with a single query. It achieves 76.71 AUROC zero-shot, slightly exceeding hand-tuned GATConv models, in minutes instead of months.