The business problem
Global card fraud losses exceeded $30 billion in 2023 and are projected to reach $40 billion by 2027. Every basis point of improvement in detection saves millions. The challenge is not just catching fraud but catching it before the transaction completes, in under 100 milliseconds.
Traditional models see each transaction in isolation: amount, time, merchant category, velocity counters. They miss the relational signal. A $50 coffee purchase looks normal until you see that the card was used at a gas station 300 miles away 10 minutes ago, the merchant shares a terminal ID with three other flagged merchants, and the IP address is associated with a device used in five other fraud cases this week.
Why flat ML fails
Flat-table models like XGBoost or logistic regression operate on hand-engineered features derived from a single row of data. You can add velocity features (transactions per hour), aggregate features (average spend at this merchant), and even some network-derived features (degree of separation from known fraud). But these features are:
- Static snapshots that miss evolving patterns
- Manually engineered, requiring domain expertise to define and maintain
- Lossy, collapsing rich relational structure into a few numbers
- One-hop at best, missing the multi-hop patterns that define fraud rings
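For concreteness, here is a minimal sketch of the kind of hand-engineered velocity feature flat-table models depend on. Function and field names are illustrative, not from any particular production system:

```python
from collections import defaultdict, deque

def velocity_features(transactions, window_s=3600):
    """For each transaction, count prior transactions on the same card
    within the last hour -- a typical hand-engineered fraud feature."""
    recent = defaultdict(deque)  # card_id -> timestamps inside window
    feats = []
    for txn in sorted(transactions, key=lambda t: t["ts"]):
        q = recent[txn["card"]]
        while q and txn["ts"] - q[0] > window_s:
            q.popleft()  # evict timestamps older than the window
        feats.append({"txn_id": txn["id"], "txns_last_hour": len(q)})
        q.append(txn["ts"])
    return feats

txns = [
    {"id": 1, "card": "A", "ts": 0},
    {"id": 2, "card": "A", "ts": 600},    # 10 minutes later, same card
    {"id": 3, "card": "A", "ts": 5000},   # outside the 1-hour window
]
print(velocity_features(txns))
```

Counters like this capture one-hop, fixed-window signal only; shared terminals, shared devices, and other multi-hop links are invisible to them.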
On RelBench fraud benchmarks, LightGBM with extensive feature engineering achieves 62.44 AUROC. A GNN that directly operates on the transaction graph achieves 75.83 AUROC, a gap that no amount of feature engineering can close because the signal lives in the graph structure itself.
The relational schema
A banking fraud graph typically includes these entities and relationships:
Node types:
Account (id, balance, account_age, type)
Merchant (id, category, avg_txn, terminal_count)
Device (id, os, ip_hash, first_seen)
Transaction (id, amount, timestamp, is_fraud)
Edge types:
Account --[sends_to]--> Account
Account --[transacts_at]--> Merchant
Account --[uses]--> Device
Transaction --[from]--> Account
Transaction --[to]--> Merchant

Four node types, five edge types. Flat-table models collapse this structure into a single row. GNNs preserve it in full.
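Before reaching for PyG, the schema can be sketched with plain Python dicts keyed the way PyG's `HeteroData` keys them: node features by type, edge lists by (source, relation, destination) triples. All IDs and feature values below are illustrative:

```python
# Node feature tables, one dict per node type (toy values,
# field names follow the schema above).
x_dict = {
    "account":     {"a1": {"balance": 1200.0, "account_age": 340},
                    "a2": {"balance": 80.0, "account_age": 12}},
    "merchant":    {"m1": {"category": "fuel", "terminal_count": 3}},
    "device":      {"d1": {"os": "android", "ip_hash": "9f2c"}},
    "transaction": {"t1": {"amount": 50.0, "is_fraud": 0}},
}

# Edge lists keyed by (source_type, relation, destination_type),
# mirroring PyG's edge_index_dict convention.
edge_dict = {
    ("account", "sends_to", "account"):      [("a1", "a2")],
    ("account", "transacts_at", "merchant"): [("a1", "m1")],
    ("account", "uses", "device"):           [("a1", "d1")],
    ("transaction", "from", "account"):      [("t1", "a1")],
    ("transaction", "to", "merchant"):       [("t1", "m1")],
}

for (src, rel, dst), edges in edge_dict.items():
    print(f"{src} --[{rel}]--> {dst}: {len(edges)} edge(s)")
```

The (src, relation, dst) triples are exactly the keys the `HeteroConv` model in the next section dispatches on.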
PyG architecture: HeteroConv + GATConv
The heterogeneous schema demands HeteroConv, which wraps a separate GNN layer per edge type. Inside each edge type, we use GATConv so the model can learn which neighbors matter most. A transaction to a flagged merchant should carry more weight than a transaction to a grocery store.
```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import HeteroConv, GATConv, Linear


class FraudGNN(torch.nn.Module):
    def __init__(self, hidden_dim=64, heads=4):
        super().__init__()
        # Project each node type to a shared dimension
        # (Linear(-1, ...) infers the input size lazily per type)
        self.account_lin = Linear(-1, hidden_dim)
        self.merchant_lin = Linear(-1, hidden_dim)
        self.device_lin = Linear(-1, hidden_dim)

        # Layer 1: type-specific GAT convolutions.
        # add_self_loops=False: self-loops are ill-defined on
        # bipartite edge types like account -> merchant.
        self.conv1 = HeteroConv({
            ('account', 'sends_to', 'account'): GATConv(
                hidden_dim, hidden_dim // heads, heads=heads,
                add_self_loops=False),
            ('account', 'transacts_at', 'merchant'): GATConv(
                hidden_dim, hidden_dim // heads, heads=heads,
                add_self_loops=False),
            ('account', 'uses', 'device'): GATConv(
                hidden_dim, hidden_dim // heads, heads=heads,
                add_self_loops=False),
        }, aggr='sum')

        # Layer 2: second hop, same edge types
        self.conv2 = HeteroConv({
            ('account', 'sends_to', 'account'): GATConv(
                hidden_dim, hidden_dim // heads, heads=heads,
                add_self_loops=False),
            ('account', 'transacts_at', 'merchant'): GATConv(
                hidden_dim, hidden_dim // heads, heads=heads,
                add_self_loops=False),
            ('account', 'uses', 'device'): GATConv(
                hidden_dim, hidden_dim // heads, heads=heads,
                add_self_loops=False),
        }, aggr='sum')

        self.classifier = Linear(hidden_dim, 1)

    def forward(self, x_dict, edge_index_dict):
        # Encode node features per type
        x_dict['account'] = self.account_lin(x_dict['account'])
        x_dict['merchant'] = self.merchant_lin(x_dict['merchant'])
        x_dict['device'] = self.device_lin(x_dict['device'])

        # Message passing (2 hops)
        x_dict = self.conv1(x_dict, edge_index_dict)
        x_dict = {k: F.elu(v) for k, v in x_dict.items()}
        x_dict = self.conv2(x_dict, edge_index_dict)

        # Score accounts; a transaction inherits the fraud score
        # of its source account
        return self.classifier(x_dict['account']).squeeze(-1)
```

Roughly 40 lines of model code, and that is only the model. You still need graph construction, mini-batch sampling with NeighborLoader, a class-imbalance-aware loss, and serving infrastructure.
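NeighborLoader's core idea, sampling a bounded number of neighbors per hop around seed nodes, can be sketched in plain Python. The real loader operates on tensors and handles heterogeneous edge types; this toy version is homogeneous and the function name is hypothetical:

```python
import random

def sample_subgraph(adj, seeds, fanouts, seed=0):
    """Sample up to fanouts[h] neighbors per node at hop h,
    mimicking NeighborLoader's num_neighbors=[f1, f2]."""
    rng = random.Random(seed)
    nodes, frontier = set(seeds), list(seeds)
    edges = []
    for fanout in fanouts:
        nxt = []
        for u in frontier:
            nbrs = adj.get(u, [])
            for v in rng.sample(nbrs, min(fanout, len(nbrs))):
                edges.append((u, v))
                if v not in nodes:
                    nodes.add(v)
                    nxt.append(v)
        frontier = nxt
    return nodes, edges

# Toy adjacency list; node 0 is the seed transaction's account.
adj = {0: [1, 2, 3], 1: [4], 2: [5, 6], 3: [], 4: [], 5: [], 6: []}
nodes, edges = sample_subgraph(adj, seeds=[0], fanouts=[2, 2])
print(sorted(nodes))
```

Bounding the fanout per hop keeps each mini-batch subgraph small regardless of how dense the full transaction graph is, which is what makes training feasible at production scale.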
Training and evaluation
Training a fraud GNN involves several additional considerations beyond the model code:
- Class imbalance: Use focal loss or weighted BCE. Fraud is typically 0.1-0.5% of transactions.
- Mini-batch sampling: Use PyG's NeighborLoader to sample 2-hop subgraphs. Full-batch training is infeasible on production-scale transaction graphs.
- Temporal splitting: Never leak future information. Train on transactions before time T, validate on T to T+1, test on T+1 to T+2.
- Feature encoding: Categorical features (merchant category, device OS) need embedding layers. Numerical features (amount, balance) need normalization.
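On the class-imbalance point: focal loss down-weights easy examples by a (1 - p_t)^gamma factor so the rare fraud class dominates the gradient. A minimal binary version, with typical default values for gamma and alpha rather than tuned ones:

```python
import math

def binary_focal_loss(p, y, gamma=2.0, alpha=0.25, eps=1e-7):
    """Focal loss for one example: -alpha_t * (1 - p_t)^gamma * log(p_t),
    where p_t is the predicted probability of the true class."""
    p = min(max(p, eps), 1 - eps)          # clamp for numerical safety
    p_t = p if y == 1 else 1 - p
    alpha_t = alpha if y == 1 else 1 - alpha
    return -alpha_t * (1 - p_t) ** gamma * math.log(p_t)

# A confident correct prediction contributes far less than a miss.
easy = binary_focal_loss(0.95, 1)   # well-classified fraud
hard = binary_focal_loss(0.05, 1)   # missed fraud
print(easy, hard)
```

With gamma=2, the well-classified example is suppressed by a factor of (0.05)^2 = 0.0025, so the loss is dominated by the hard, misclassified fraud cases, exactly what you want at a 0.1-0.5% positive rate.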
Expected performance
On RelBench fraud-related benchmarks, the performance hierarchy is clear:
- LightGBM (flat-table): 62.44 AUROC
- GNN (hand-tuned HeteroConv): 75.83 AUROC
- KumoRFM (zero-shot): 76.71 AUROC
The 13+ point gap between flat-table and GNN represents the structural signal that lives in the graph. The additional 0.88 points from KumoRFM come from its pre-trained relational graph transformer, which has learned patterns across many relational datasets.
Or use KumoRFM in one line
KumoRFM replaces the entire pipeline above with a single Predictive Query:
```
PREDICT is_fraud FOR transaction
USING account, merchant, device, transaction
```

Two lines of PQL. KumoRFM auto-constructs the heterogeneous graph, selects the architecture, trains with temporal awareness, and serves predictions via API.

No graph construction. No architecture selection. No training loop. No serving infrastructure. KumoRFM's pre-trained relational graph transformer handles heterogeneous schemas, temporal dynamics, and class imbalance automatically, reaching 76.71 AUROC on the same benchmark, slightly ahead of the hand-tuned GNN baseline, in minutes instead of months.