Class imbalance occurs when one class vastly outnumbers another in the training data. In enterprise graph tasks, this is the norm, not the exception. Fraud accounts for 0.1-1% of transactions. Customer churn is 5-15%. Loan default is 1-5%. A GNN trained with standard cross-entropy loss will learn to predict the majority class for everything, achieving high accuracy while being completely useless.
Why graphs make imbalance worse
Standard class imbalance techniques from tabular ML (SMOTE, oversampling) do not transfer directly to graphs because of message passing. In a GNN, each node's representation is a function of its neighborhood. Minority-class nodes are typically surrounded by majority-class neighbors, so the aggregation step actively pushes minority representations toward the majority distribution.
Consider a fraudulent account node. It connects to merchants, other accounts, and transactions. Most of those neighbors are legitimate. After two layers of message passing, the fraudulent account's representation is dominated by legitimate-account signals. The model learns representations that look similar for both classes.
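The dilution effect is easy to see numerically. A minimal sketch (a toy one-dimensional feature, not any particular GNN library's layer) with mean aggregation:

```python
import torch

# toy setup: one fraud node (feature +1) with nine legitimate neighbors (-1)
fraud = torch.tensor([[1.0]])
neighbors = -torch.ones(9, 1)

# GCN-style mean aggregation over the node itself plus its neighbors
h = torch.cat([fraud, neighbors], dim=0).mean(dim=0)
# h = -0.8: a single layer already pulls the fraud node most of the
# way toward the legitimate-class signal
```

Stack a second layer and the fraud node's own contribution shrinks further, which is why the techniques below intervene in the loss, the sampling, or the aggregation itself.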
Technique 1: Loss reweighting
The simplest approach: modify the loss function to pay more attention to minority-class errors.
```python
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    """Focal loss: down-weight easy examples, focus on hard ones."""
    ce = F.cross_entropy(logits, targets, reduction='none')
    pt = torch.exp(-ce)  # probability assigned to the correct class
    loss = alpha * (1 - pt) ** gamma * ce
    return loss.mean()

# Compare with standard cross-entropy:
#   easy majority example: pt=0.95, weight=(0.05)^2=0.0025 -> nearly ignored
#   hard minority example: pt=0.30, weight=(0.70)^2=0.49   -> ~200x more weight
```

Focal loss automatically focuses on the examples the model gets wrong, which are disproportionately minority-class examples.
Class-balanced loss
Weight each class inversely proportional to its frequency. If fraud is 0.1% of samples, its loss gets 1000x the weight. Simple and effective, but can be unstable with extreme ratios. A softer version uses the effective number of samples: weight = (1 - beta) / (1 - beta^n), where n is the class count.
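The effective-number weighting can be sketched in a few lines (the normalization to sum to the number of classes is a common convention, not the only choice):

```python
import torch

def class_balanced_weights(class_counts, beta=0.999):
    """Effective-number-of-samples class weights.

    weight_c = (1 - beta) / (1 - beta^n_c), then normalized so the
    weights sum to the number of classes.
    """
    counts = torch.tensor(class_counts, dtype=torch.float)
    effective_num = 1.0 - torch.pow(beta, counts)
    weights = (1.0 - beta) / effective_num
    return weights / weights.sum() * len(class_counts)

# 99,000 legitimate vs 1,000 fraud samples
w = class_balanced_weights([99000, 1000])
# pass to the loss: F.cross_entropy(logits, targets, weight=w)
```

Note how beta controls the softening: as beta approaches 1, the weights approach plain inverse frequency; smaller beta compresses the ratio between classes and stabilizes training under extreme imbalance.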
Technique 2: Graph-aware oversampling
Standard SMOTE creates synthetic samples by interpolating between existing minority samples in feature space. But graph nodes are not independent: each has a neighborhood. Graph-aware oversampling must create both the synthetic node and its edges.
- GraphSMOTE: Generates synthetic minority nodes in the GNN's embedding space (after encoding) and adds edges based on similarity to existing minority nodes. This preserves graph structure.
- GraphENS: Creates ego-networks (subgraphs) for synthetic minority nodes by mixing the neighborhoods of multiple minority samples. This is more expensive but captures structural patterns.
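The GraphSMOTE idea can be sketched as follows. This is illustrative only: the published method trains an edge generator to decide connectivity, whereas this sketch substitutes a simple cosine-similarity top-k rule.

```python
import torch
import torch.nn.functional as F

def graph_smote_sketch(emb, minority_idx, num_new, k=2):
    """Interpolate between random pairs of minority-node embeddings to
    create synthetic nodes, then wire each synthetic node to its k most
    similar existing minority nodes."""
    a = minority_idx[torch.randint(len(minority_idx), (num_new,))]
    b = minority_idx[torch.randint(len(minority_idx), (num_new,))]
    lam = torch.rand(num_new, 1)
    synth = lam * emb[a] + (1 - lam) * emb[b]  # SMOTE-style interpolation

    # connect each synthetic node to its k most similar minority nodes
    sim = F.cosine_similarity(synth.unsqueeze(1),
                              emb[minority_idx].unsqueeze(0), dim=2)
    topk = sim.topk(k, dim=1).indices                   # (num_new, k)
    new_ids = emb.size(0) + torch.arange(num_new)       # ids after existing nodes
    src = new_ids.repeat_interleave(k)
    dst = minority_idx[topk.reshape(-1)]
    return synth, torch.stack([src, dst])               # embeddings, edge index
```

Because the interpolation happens in the GNN's embedding space rather than raw feature space, the synthetic nodes already reflect neighborhood information encoded by earlier layers.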
Technique 3: Architectural modifications
Some approaches modify the GNN itself:
- Decoupled aggregation: Separate the aggregation for same-class and different-class neighbors. This prevents majority neighbors from drowning out the minority signal.
- Rebalanced neighbor sampling: During mini-batch training, oversample minority-class neighbors. Instead of uniform neighbor sampling, sample minority neighbors with higher probability so they are better represented in the aggregation.
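Rebalanced neighbor sampling reduces to a weighted draw instead of a uniform one. A minimal sketch (the `boost` factor and helper name are illustrative):

```python
import torch

def rebalanced_neighbor_sample(neighbors, labels, num_samples,
                               minority_class=1, boost=9.0):
    """Sample num_samples neighbor ids, giving minority-class neighbors
    boost-times the selection weight of majority-class neighbors."""
    weights = torch.ones(len(neighbors))
    weights[labels[neighbors] == minority_class] = boost
    idx = torch.multinomial(weights, num_samples, replacement=True)
    return neighbors[idx]

# node ids 0-9; node 7 is the lone minority-class neighbor
labels = torch.zeros(10, dtype=torch.long)
labels[7] = 1
sampled = rebalanced_neighbor_sample(torch.arange(10), labels, num_samples=5)
# with boost=9.0, node 7 is drawn ~50% of the time instead of 10%
```

The same weighting slots into samplers that already draw neighbors with `torch.multinomial`-style logic; the only change from uniform sampling is the per-neighbor weight vector.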
Evaluation: throw away accuracy
With 0.1% fraud rate, a model that predicts “not fraud” for everything gets 99.9% accuracy. Use these instead:
- AUROC: Measures ranking quality across all thresholds. How well does the model rank fraud above non-fraud?
- AUPRC: Measures precision-recall trade-off. More sensitive to minority-class performance than AUROC.
- F1 at optimal threshold: The threshold that maximizes F1 gives the best balance of precision and recall for the minority class.
- Recall at fixed precision: Business-relevant metric. “At 90% precision, how much fraud do we catch?”
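All four metrics are available through scikit-learn. A sketch of a combined report (the function name and dictionary keys are illustrative):

```python
import numpy as np
from sklearn.metrics import (roc_auc_score, average_precision_score,
                             precision_recall_curve)

def imbalance_report(y_true, y_score, target_precision=0.9):
    """Evaluate scores for the positive (minority) class."""
    auroc = roc_auc_score(y_true, y_score)
    auprc = average_precision_score(y_true, y_score)
    precision, recall, _ = precision_recall_curve(y_true, y_score)
    # F1 at the best threshold along the precision-recall curve
    f1 = 2 * precision * recall / np.clip(precision + recall, 1e-12, None)
    # recall at fixed precision: best recall among points meeting the bar
    ok = precision >= target_precision
    recall_at_p = recall[ok].max() if ok.any() else 0.0
    return {"auroc": auroc, "auprc": auprc,
            "best_f1": f1.max(), "recall_at_precision": recall_at_p}
```

Report AUPRC and recall-at-fixed-precision alongside AUROC: with a 0.1% positive rate, AUROC can look respectable while the precision-recall numbers expose a model that flags far too many false positives to be deployable.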