
Class Imbalance: Handling Rare Events in Graph Classification

Fraud is 0.1% of transactions. Churn is 5% of customers. Default is 2% of loans. The events that matter most in enterprise prediction are also the rarest. Here is how to make GNNs learn from them.

PyTorch Geometric

TL;DR

  1. Class imbalance is amplified in GNNs because message passing aggregates neighbor features, and minority-class nodes are typically surrounded by majority-class neighbors, diluting their signal.
  2. Three families of techniques: loss reweighting (focal loss, class-balanced loss), graph-aware oversampling (GraphSMOTE, edge generation), and architectural modifications (separate aggregation per class).
  3. Focal loss is the simplest effective approach: it down-weights easy majority examples and focuses training on hard minority cases. A two-line code change.
  4. Never evaluate with accuracy. Use AUROC (ranking), AUPRC (precision-recall), and F1 at the optimal threshold. A model predicting all-majority gets 99.9% accuracy at a 0.1% fraud rate but catches zero fraud.
  5. In production, combine focal loss with temporal splits and calibrated thresholds. The threshold should be set by business cost: false-negative cost vs. false-positive cost.

Class imbalance occurs when one class vastly outnumbers another in the training data. In enterprise graph tasks, this is the norm, not the exception. Fraud accounts for 0.1-1% of transactions. Customer churn is 5-15%. Loan default is 1-5%. A GNN trained with standard cross-entropy loss will learn to predict the majority class for everything, achieving high accuracy while being completely useless.

Why graphs make imbalance worse

Standard class imbalance techniques from tabular ML (SMOTE, oversampling) do not transfer directly to graphs because of message passing. In a GNN, each node's representation is a function of its neighborhood. Minority-class nodes are typically surrounded by majority-class neighbors, so the aggregation step actively pushes minority representations toward the majority distribution.

Consider a fraudulent account node. It connects to merchants, other accounts, and transactions. Most of those neighbors are legitimate. After two layers of message passing, the fraudulent account's representation is dominated by legitimate-account signals. The model learns representations that look similar for both classes.
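The dilution effect is easy to reproduce numerically. The sketch below uses hypothetical toy features (legitimate accounts near +1, fraud near -1) and a single mean-aggregation step, the core of GCN-style message passing:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy features: legitimate accounts cluster near +1, fraud near -1.
legit = torch.randn(50, 8) * 0.1 + 1.0   # 50 legitimate neighbors
fraud = torch.randn(1, 8) * 0.1 - 1.0    # the one fraudulent account

# One round of mean aggregation over the fraud node's neighborhood (plus itself).
aggregated = torch.cat([fraud, legit]).mean(dim=0, keepdim=True)

legit_center = legit.mean(dim=0, keepdim=True)
before = F.cosine_similarity(fraud, legit_center).item()       # near -1: distinct
after = F.cosine_similarity(aggregated, legit_center).item()   # near +1: diluted
```

After one aggregation step the fraud node's representation is nearly indistinguishable from the legitimate cluster; a second layer only makes this worse.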

Technique 1: Loss reweighting

The simplest approach: modify the loss function to pay more attention to minority-class errors.

focal_loss.py
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    """Focal loss: down-weight easy examples, focus on hard ones."""
    ce = F.cross_entropy(logits, targets, reduction='none')
    pt = torch.exp(-ce)  # probability of correct class
    loss = alpha * (1 - pt) ** gamma * ce
    return loss.mean()

# Compare with standard cross-entropy:
# easy majority example: pt=0.95, weight=(0.05)^2=0.0025 -> nearly ignored
# hard minority example: pt=0.30, weight=(0.70)^2=0.49   -> 200x more weight

Focal loss automatically focuses on the examples the model gets wrong, which are disproportionately minority-class examples.
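To illustrate the "two-line change": swap `F.cross_entropy` for `focal_loss` in an otherwise standard training step. The model, optimizer, and data below are hypothetical stand-ins, and `focal_loss` repeats the definition above for self-containment:

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    # Same definition as in focal_loss.py above.
    ce = F.cross_entropy(logits, targets, reduction='none')
    pt = torch.exp(-ce)
    return (alpha * (1 - pt) ** gamma * ce).mean()

model = torch.nn.Linear(16, 2)                     # stand-in for a GNN
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x = torch.randn(256, 16)                           # synthetic node features
y = (torch.rand(256) < 0.05).long()                # ~5% minority labels

optimizer.zero_grad()
loss = focal_loss(model(x), y)                     # was: F.cross_entropy(model(x), y)
loss.backward()
optimizer.step()
```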

Class-balanced loss

Weight each class inversely proportional to its frequency. If fraud is 0.1% of samples, its loss gets 1000x the weight. Simple and effective, but can be unstable with extreme ratios. A softer version uses the effective number of samples: weight = (1 - beta) / (1 - beta^n), where n is the class count.
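A minimal sketch of the effective-number weighting, assuming hypothetical class counts (99,900 legitimate vs. 100 fraud) and the function name `effective_number_weights` for illustration:

```python
import torch
import torch.nn.functional as F

def effective_number_weights(class_counts, beta=0.999):
    """Per-class weight (1 - beta) / (1 - beta^n), normalized to mean 1."""
    n = torch.tensor(class_counts, dtype=torch.float)
    weights = (1.0 - beta) / (1.0 - beta ** n)
    return weights / weights.mean()

# Hypothetical counts: 99,900 legitimate vs 100 fraud samples (0.1% fraud).
w = effective_number_weights([99_900, 100])

logits = torch.randn(8, 2)
targets = torch.randint(0, 2, (8,))
loss = F.cross_entropy(logits, targets, weight=w)  # minority errors weighted up
```

With beta = 0.999 the fraud class gets roughly 10x the legitimate class's weight, far gentler than the 1000x of raw inverse-frequency weighting, which is what makes this variant more stable at extreme ratios.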

Technique 2: Graph-aware oversampling

Standard SMOTE creates synthetic samples by interpolating between existing minority samples in feature space. But graph nodes are not independent: each has a neighborhood. Graph-aware oversampling must create both the synthetic node and its edges.

  • GraphSMOTE: Generates synthetic minority nodes in the GNN's embedding space (after encoding) and adds edges based on similarity to existing minority nodes. This preserves graph structure.
  • GraphENS: Creates ego-networks (subgraphs) for minority nodes by mixing the neighborhood of multiple minority samples. This is more expensive but captures structural patterns.
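The GraphSMOTE idea can be sketched in a few lines. This is an illustrative simplification, not the published method: real GraphSMOTE trains an edge predictor, whereas the stand-in below wires synthetic nodes to their most cosine-similar existing nodes. All names here are hypothetical:

```python
import torch
import torch.nn.functional as F

def graph_smote_sketch(emb, minority_idx, n_new=10, k_edges=3):
    """Simplified GraphSMOTE-style oversampling (illustrative only).

    emb: [N, d] node embeddings from the GNN encoder.
    minority_idx: indices of minority-class nodes.
    Returns synthetic embeddings and edges linking them into the graph.
    """
    # 1. SMOTE step: interpolate between random pairs of minority embeddings.
    i = minority_idx[torch.randint(len(minority_idx), (n_new,))]
    j = minority_idx[torch.randint(len(minority_idx), (n_new,))]
    lam = torch.rand(n_new, 1)
    synth = lam * emb[i] + (1 - lam) * emb[j]

    # 2. Edge step: connect each synthetic node to its k most similar nodes.
    #    (GraphSMOTE proper learns an edge generator; similarity is a stand-in.)
    sim = F.normalize(synth, dim=1) @ F.normalize(emb, dim=1).T   # [n_new, N]
    nbrs = sim.topk(k_edges, dim=1).indices                       # [n_new, k]
    src = torch.arange(emb.size(0), emb.size(0) + n_new).repeat_interleave(k_edges)
    edge_index = torch.stack([src, nbrs.reshape(-1)])
    return synth, edge_index

emb = torch.randn(100, 16)                          # hypothetical encoder output
synth, new_edges = graph_smote_sketch(emb, minority_idx=torch.arange(5))
```

Because the interpolation happens in embedding space (after encoding), the synthetic nodes already reflect neighborhood structure rather than raw features alone.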

Technique 3: Architectural modifications

Some approaches modify the GNN itself:

  • Decoupled aggregation: Separate the aggregation for same-class and different-class neighbors. This prevents majority neighbors from drowning out the minority signal.
  • Rebalanced neighbor sampling: During mini-batch training, oversample minority-class neighbors. Instead of uniform neighbor sampling, sample minority neighbors with higher probability so they are better represented in the aggregation.
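Rebalanced neighbor sampling can be sketched as weighted sampling with inverse-frequency probabilities. The function and the toy label distribution below are hypothetical; a production version would plug into the sampler of a mini-batch loader:

```python
import torch

torch.manual_seed(0)

def rebalanced_neighbor_sample(neighbors, labels, class_freq, k=10):
    """Sample k neighbors, upweighting minority-class ones (illustrative sketch).

    neighbors: LongTensor of neighbor ids for one target node.
    labels: [N] class label per node; class_freq: [C] class frequencies.
    """
    # Inverse-frequency weight: a 5%-frequency class gets 20x the weight.
    w = 1.0 / class_freq[labels[neighbors]]
    probs = w / w.sum()
    idx = torch.multinomial(probs, min(k, len(neighbors)), replacement=False)
    return neighbors[idx]

labels = torch.zeros(100, dtype=torch.long)
labels[:5] = 1                                    # 5% minority nodes
class_freq = torch.tensor([0.95, 0.05])
sample = rebalanced_neighbor_sample(torch.arange(100), labels, class_freq, k=20)
# Minority neighbors appear far more often than their 5% share would suggest.
```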

Evaluation: throw away accuracy

With 0.1% fraud rate, a model that predicts “not fraud” for everything gets 99.9% accuracy. Use these instead:

  • AUROC: Measures ranking quality across all thresholds. How well does the model rank fraud above non-fraud?
  • AUPRC: Measures precision-recall trade-off. More sensitive to minority-class performance than AUROC.
  • F1 at optimal threshold: The threshold that maximizes F1 gives the best balance of precision and recall for the minority class.
  • Recall at fixed precision: Business-relevant metric. “At 90% precision, how much fraud do we catch?”
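All four metrics are a few lines with scikit-learn. The scores below are synthetic stand-ins (1,000 nodes, 1% fraud, fraud scored about two standard deviations higher), purely to show the mechanics:

```python
import numpy as np
from sklearn.metrics import (average_precision_score, precision_recall_curve,
                             roc_auc_score)

rng = np.random.default_rng(0)

# Hypothetical scores: 1,000 nodes, 1% fraud, fraud scored ~2 std higher.
y = np.zeros(1000, dtype=int)
y[:10] = 1
scores = rng.normal(size=1000) + 2.0 * y

auroc = roc_auc_score(y, scores)                 # ranking quality
auprc = average_precision_score(y, scores)       # minority-sensitive

# F1 at the optimal threshold.
prec, rec, thresholds = precision_recall_curve(y, scores)
f1 = 2 * prec * rec / np.clip(prec + rec, 1e-12, None)
best_threshold = thresholds[int(np.argmax(f1[:-1]))]   # last point has no threshold

# Recall at fixed precision: "at 90% precision, how much fraud do we catch?"
hits = prec[:-1] >= 0.90
recall_at_90p = rec[:-1][hits].max() if hits.any() else 0.0
```

In production, the final threshold should then be shifted from the F1-optimal point toward whichever error (missed fraud vs. false alarm) costs the business more.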

Frequently asked questions

Why is class imbalance especially problematic for GNNs?

GNNs aggregate neighbor information, and in imbalanced graphs, minority-class nodes are typically surrounded by majority-class neighbors. This means the aggregation step actively dilutes minority signals. A fraudulent account with 50 legitimate neighbors gets its representation pushed toward the majority class during message passing.

What is the best technique for class imbalance in GNNs?

No single technique dominates. The most reliable approach combines focal loss (to focus on hard minority examples), GraphSMOTE (to synthesize new minority nodes in the embedding space), and careful evaluation with AUROC and AUPRC rather than accuracy. The optimal combination depends on the imbalance ratio and graph structure.

Should you use accuracy to evaluate imbalanced GNN models?

Never. A model that predicts every node as the majority class achieves 99%+ accuracy when fraud rate is below 1%. Use AUROC (ranking quality), AUPRC (precision-recall for the minority class), and F1 at optimal threshold. Report all three.

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.