
Churn Prediction: GATConv on SaaS User-Activity Graphs

SaaS companies lose 5-7% of revenue annually to churn. Flat-table models miss the social and behavioral context that predicts disengagement. Here is how to build a GNN that sees the full user-activity graph.


TL;DR

  • SaaS churn is a graph problem. Users exist in context: their team, the features they adopt, the support tickets they file. A user whose team has gone quiet is at higher risk than their individual metrics suggest.
  • GATConv learns attention weights per neighbor, distinguishing active teammates (retention signal) from dormant accounts (noise). GCNConv treats all neighbors equally and misses this.
  • On RelBench churn benchmarks, GNNs reach 75.83 AUROC vs 62.44 for LightGBM. The relational context provides 13+ points of lift.
  • The PyG model is ~35 lines, but production churn systems also need temporal feature pipelines, retraining automation, and integration with CRM tools.
  • KumoRFM predicts churn with one PQL query (76.71 AUROC zero-shot), automatically constructing the user-activity graph from your SaaS database.

The business problem

For SaaS companies, reducing churn by just 1% can increase company valuation by 12%. Customer acquisition costs 5-25x more than retention, making churn prediction one of the highest-ROI ML applications. The goal: identify at-risk accounts 30-90 days before cancellation, giving customer success teams time to intervene.

Traditional churn models use features like login frequency, feature adoption percentage, support ticket count, and contract renewal date. These features describe the user in isolation. They miss the relational context: is the user's entire team disengaging? Are they the only active user on a 50-seat account? Did their champion (the power user who drove adoption) just leave?

Why flat ML fails

Flat-table churn models (logistic regression, XGBoost) operate on per-user feature vectors. They cannot capture:

  • Social anchoring: Users with active teammates churn 60% less. But “number of active teammates” is a lossy feature that misses who those teammates are and what they do.
  • Adoption depth: Two users who each use three features are not interchangeable: which features they use, and in which combinations, matters. The feature-usage graph captures which feature clusters drive stickiness.
  • Contagion effects: Churn is social. When a team lead churns, their reports follow within 60 days. Flat models cannot propagate this risk signal through the graph.
  • Temporal patterns: The sequence of feature adoption and disadoption matters. A user who tried a feature and stopped is different from one who never tried it.

The relational schema

schema.txt
Node types:
  User      (id, role, signup_date, plan_tier)
  Feature   (id, category, complexity, release_date)
  Account   (id, seats, mrr, contract_end)
  Ticket    (id, severity, resolution_time, csat_score)

Edge types:
  User    --[uses]-->        Feature  (frequency, last_used)
  User    --[belongs_to]-->  Account
  User    --[filed]-->       Ticket
  User    --[collaborates]--> User    (shared_docs, messages)

Four node types capture the full context: who the user is, what they use, who they work with, and how they interact with support.

PyG architecture: GATConv with attention

We use GATConv because neighbor importance varies dramatically. An active teammate sending daily messages is a much stronger retention signal than a dormant user on the same account. Attention weights let the model learn this distinction automatically.

churn_model.py
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv, HeteroConv, Linear

# Forward edges from the schema plus reverse edges (e.g. added with
# torch_geometric.transforms.ToUndirected). Without the reverse edges,
# feature, account, and ticket information never flows back into the
# user embeddings that the classifier reads.
EDGE_TYPES = [
    ('user', 'uses', 'feature'),
    ('user', 'belongs_to', 'account'),
    ('user', 'filed', 'ticket'),
    ('user', 'collaborates', 'user'),
    ('feature', 'rev_uses', 'user'),
    ('account', 'rev_belongs_to', 'user'),
    ('ticket', 'rev_filed', 'user'),
]

class ChurnGNN(torch.nn.Module):
    def __init__(self, hidden_dim=64, heads=4):
        super().__init__()
        # Lazy linear encoders project each node type's raw features
        # into a shared hidden dimension.
        self.encoders = torch.nn.ModuleDict({
            node: Linear(-1, hidden_dim)
            for node in ('user', 'feature', 'account', 'ticket')
        })
        # Two hetero message-passing layers. add_self_loops=False is
        # required for GATConv on bipartite edge types; (-1, -1) lets
        # source and destination dimensions be inferred lazily.
        self.convs = torch.nn.ModuleList([
            HeteroConv({
                edge: GATConv((-1, -1), hidden_dim // heads,
                              heads=heads, add_self_loops=False)
                for edge in EDGE_TYPES
            }, aggr='sum')
            for _ in range(2)
        ])
        self.classifier = Linear(hidden_dim, 1)

    def forward(self, x_dict, edge_index_dict):
        x_dict = {k: self.encoders[k](v) for k, v in x_dict.items()}
        for conv in self.convs:
            x_dict = {k: F.elu(v) for k, v in
                      conv(x_dict, edge_index_dict).items()}
        return torch.sigmoid(
            self.classifier(x_dict['user']).squeeze(-1))

HeteroConv wrapping GATConv per edge type. Attention weights automatically learn which teammates, features, and tickets are most predictive of churn.

Training considerations

  • Label definition: Define churn as “no login for 30 consecutive days” or “cancellation within 90 days.” The choice affects model behavior significantly.
  • Temporal split: Train on users who churned/stayed before time T, predict for users active at T. Never use future activity as features.
  • Class imbalance: Monthly churn rates of 2-5% mean 95%+ of labels are negative. Use focal loss or SMOTE-style oversampling on the graph.
  • Feature engineering: Edge features (usage frequency, recency) carry strong signal. Encode them as edge attributes in PyG's HeteroData.
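On the class-imbalance point, a minimal binary focal loss (the standard Lin et al. formulation; the alpha and gamma values here are common defaults, not tuned for churn) can replace plain BCE:

```python
import torch

def focal_loss(prob, target, alpha=0.25, gamma=2.0):
    """Binary focal loss on predicted probabilities.

    Down-weights well-classified (mostly negative) examples so the
    rare churners contribute a larger share of the gradient.
    """
    prob = prob.clamp(1e-6, 1 - 1e-6)
    pt = torch.where(target.bool(), prob, 1 - prob)        # prob of true class
    alpha_t = alpha * target + (1 - alpha) * (1 - target)  # class weighting
    return (-alpha_t * (1 - pt) ** gamma * pt.log()).mean()

# A confident correct prediction contributes far less loss than an
# uncertain one, which is the point of the (1 - pt)^gamma factor:
target = torch.tensor([1.0, 0.0])
easy = focal_loss(torch.tensor([0.9, 0.1]), target)
hard = focal_loss(torch.tensor([0.6, 0.4]), target)
```

Because the model outputs sigmoid probabilities, this drops in wherever `binary_cross_entropy` would otherwise be called.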

Expected performance

  • Logistic regression (flat): ~58 AUROC
  • LightGBM (flat-table): 62.44 AUROC
  • GNN (GATConv heterogeneous): 75.83 AUROC
  • KumoRFM (zero-shot): 76.71 AUROC

Or use KumoRFM in one line

KumoRFM PQL
PREDICT is_churned FOR user
USING user, feature, account, ticket

One PQL query. KumoRFM auto-constructs the user-activity graph, handles temporal dynamics, and outputs churn probabilities per user.

KumoRFM replaces the graph construction, model architecture, training loop, and temporal feature engineering with a single query. It achieves 76.71 AUROC zero-shot, slightly exceeding hand-tuned GATConv models, in minutes instead of months.

Frequently asked questions

Why are GNNs better than logistic regression for churn prediction?

Logistic regression sees each user's features in isolation. GNNs see each user in the context of their team, the features they use, and the support interactions they have had. A user whose entire team has gone quiet is at much higher churn risk than their individual metrics suggest.

Why use GATConv instead of GCNConv for churn?

GATConv learns attention weights per neighbor, so it can distinguish between a user's daily active teammates (strong retention signal) and dormant accounts on the same plan (weak signal). GCNConv treats all neighbors equally, diluting the signal from high-value connections.

How do you define the user-activity graph for SaaS churn?

Nodes are users, features, support tickets, and accounts. Edges connect users to features they use (with frequency), users to their account, users to support tickets they filed, and users to other users they collaborate with. The graph structure captures adoption depth and social anchoring.

What is the right prediction window for SaaS churn?

Typically 30-90 days. Predicting churn 30 days out gives the customer success team time to intervene. Shorter windows (7 days) are more accurate but leave less time to act. Longer windows (180 days) are less accurate but enable proactive engagement programs.
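The 30-day definition can be made concrete with a small labeling sketch (the events frame, its column names, and the dates are all hypothetical; pandas assumed):

```python
import pandas as pd

# Hypothetical activity log; in practice this comes from your
# product analytics warehouse.
events = pd.DataFrame({
    'user_id': [1, 1, 2, 3],
    'login_at': pd.to_datetime(
        ['2024-01-05', '2024-03-20', '2024-01-10', '2024-03-28']),
})

cutoff = pd.Timestamp('2024-03-01')
window = pd.Timedelta(days=30)

# Features may only use activity strictly before the cutoff; the
# label only looks at the [cutoff, cutoff + window) interval.
past = events[events.login_at < cutoff]
future = events[(events.login_at >= cutoff) &
                (events.login_at < cutoff + window)]

# Label every user active before the cutoff; churned = no login
# inside the prediction window.
labels = (past[['user_id']].drop_duplicates()
          .assign(is_churned=lambda d: ~d.user_id.isin(future.user_id)))
```

User 3 receives no label at all: with no pre-cutoff activity there is nothing to predict from, which is exactly the leakage the temporal split guards against.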

How does KumoRFM handle churn prediction differently?

KumoRFM automatically constructs a temporal heterogeneous graph from your SaaS database, capturing user-feature, user-account, and user-user relationships with temporal dynamics. It predicts churn with one PQL query, no graph construction or model training required.

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.