
Self-Supervised Learning: Pre-Training GNNs Without Task Labels

Self-supervised learning creates training signal from the graph structure itself. Mask features and predict them. Create two views and match them. Learn general representations, then fine-tune for any task.


TL;DR

  • Self-supervised learning (SSL) trains GNNs without task labels by creating signal from graph structure: mask features and predict them, or contrast different graph views.
  • Three paradigms: contrastive (match positive pairs, push away negatives), generative (reconstruct masked features/edges), predictive (predict structural properties).
  • SSL is critical when labels are scarce. A bank has billions of transactions but thousands of fraud labels. SSL leverages the unlabeled structure, then fine-tunes with few labels.
  • Graph masking (mask features, predict from context) is the graph equivalent of BERT's masked language modeling. It teaches the model to understand structural context.
  • Foundation models like KumoRFM use SSL to pre-train on diverse relational databases, enabling zero-shot and few-shot predictions on new databases.

Self-supervised learning pre-trains GNNs without task labels by creating training signal from the graph structure itself. The key insight: a graph contains far more information in its structure (edges, neighborhoods, communities) than in its labels. SSL extracts this structural information into general-purpose node representations that can then be fine-tuned for any downstream task with very few labels.

This is the graph equivalent of how BERT learns language by predicting masked words, or how CLIP learns vision by matching images to captions. The pretext task (what the model trains on) is different from the downstream task (what you actually care about), but the representations transfer.

Three SSL paradigms

Contrastive (most popular)

Create two augmented views of the same graph and train the model to recognize them as the same while distinguishing them from views of different graphs:

contrastive_ssl.py
# Contrastive SSL: GraphCL / GCA style

# 1. Create two augmented views of the same graph
view1 = augment(graph)  # drop 20% edges, mask 10% features
view2 = augment(graph)  # different random drops/masks

# 2. Encode both views with the same GNN
z1 = gnn_encoder(view1)  # [num_nodes, hidden_dim]
z2 = gnn_encoder(view2)  # [num_nodes, hidden_dim]

# 3. Contrastive loss: same node in both views = positive pair
# Different nodes = negative pairs
# InfoNCE loss pushes positives together, negatives apart
loss = info_nce_loss(z1, z2)

# After pre-training: freeze encoder, fine-tune classifier
# with just 100 labeled nodes instead of 10,000

Contrastive SSL learns by matching different views of the same graph. Augmentation quality directly impacts representation quality.
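The `info_nce_loss` call above is left abstract. A minimal sketch of what such a loss could look like (the function name and temperature value are assumptions, not a specific library API): each node's embedding in view 1 is pulled toward the same node's embedding in view 2, while all other nodes in the batch serve as negatives.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

def info_nce_loss(z1, z2, temperature=0.5):
    """Symmetric InfoNCE: node i in view 1 matches node i in view 2."""
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)
    logits = z1 @ z2.t() / temperature        # [N, N] cosine similarities
    targets = torch.arange(z1.size(0))        # positive pairs on the diagonal
    # Cross-entropy in both directions treats all off-diagonal pairs as negatives
    return (F.cross_entropy(logits, targets) +
            F.cross_entropy(logits.t(), targets)) / 2

# Toy check: two "views" that are small perturbations of each other
z1 = torch.randn(8, 16)
z2 = z1 + 0.1 * torch.randn(8, 16)
loss = info_nce_loss(z1, z2)
```

Lower temperature sharpens the contrast between the positive pair and the negatives; values around 0.1 to 0.5 are common starting points.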

Generative (masking)

Mask parts of the graph and train the model to reconstruct them:

  • Feature masking: zero out 15% of node features, predict the original values from context
  • Edge masking: remove 15% of edges, predict whether they exist from node embeddings
  • Combined: mask both features and edges for richer self-supervision
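Edge masking can be sketched as link reconstruction (the helper names below are illustrative, not PyG API): hide a fraction of edges from the encoder, then train a dot-product decoder to score the hidden edges above random node pairs.

```python
import torch

def mask_edges(edge_index, mask_rate=0.15):
    """Split edges into visible (message passing) and masked (targets)."""
    num_edges = edge_index.size(1)
    perm = torch.randperm(num_edges)
    num_masked = int(mask_rate * num_edges)
    masked = edge_index[:, perm[:num_masked]]    # edges to reconstruct
    visible = edge_index[:, perm[num_masked:]]   # edges the encoder sees
    return visible, masked

def edge_recon_loss(z, masked_edges, neg_edges):
    """Dot-product decoder: masked edges should score high, negatives low."""
    pos = (z[masked_edges[0]] * z[masked_edges[1]]).sum(dim=1)
    neg = (z[neg_edges[0]] * z[neg_edges[1]]).sum(dim=1)
    return torch.nn.functional.binary_cross_entropy_with_logits(
        torch.cat([pos, neg]),
        torch.cat([torch.ones_like(pos), torch.zeros_like(neg)]))

# Toy graph: 10 nodes, 40 directed edges
edge_index = torch.randint(0, 10, (2, 40))
visible, masked = mask_edges(edge_index)
neg = torch.randint(0, 10, (2, masked.size(1)))  # random pairs as negatives
z = torch.randn(10, 8)                            # stand-in for GNN embeddings
loss = edge_recon_loss(z, masked, neg)
```

In practice the encoder would run message passing over `visible` only, so the masked edges are genuinely unseen at encoding time.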

Predictive

Predict structural properties without masking:

  • Predict node degree from features
  • Predict local clustering coefficient
  • Predict which subgraph pattern (motif) each node belongs to
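The appeal of predictive SSL is that the targets are free: structural properties like degree can be computed directly from the edge list. A minimal sketch of degree prediction as a pretext task (the regression head and log-transform are illustrative choices):

```python
import torch

def degree_targets(edge_index, num_nodes):
    """Node out-degree as a free structural label."""
    return torch.bincount(edge_index[0], minlength=num_nodes).float()

torch.manual_seed(0)
edge_index = torch.randint(0, 20, (2, 100))   # toy graph: 20 nodes, 100 edges
deg = degree_targets(edge_index, num_nodes=20)

# Regress log-degree from node embeddings with a linear head
z = torch.randn(20, 16)                        # stand-in for GNN output
head = torch.nn.Linear(16, 1)
loss = torch.nn.functional.mse_loss(head(z).squeeze(-1), torch.log1p(deg))
```

Clustering coefficients and motif memberships follow the same pattern: compute the target from structure once, then treat it as a supervised regression or classification signal.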

Enterprise example: fraud detection with 0.01% labels

A payment processor has 100 million transactions but only 10,000 confirmed fraud labels (0.01%). Training a GNN on 10,000 labels in a graph of 100 million edges wastes most of the structural information.

Self-supervised pre-training:

  1. Pre-train the GNN on the full 100M transaction graph using contrastive SSL (no labels needed)
  2. The model learns: typical transaction patterns, normal account behavior, common network structures
  3. Fine-tune with the 10,000 fraud labels
  4. The fine-tuned model knows what “normal” looks like (from SSL) and what “fraud” looks like (from labels)

Result: the SSL pre-trained model achieves 15-25% higher AUROC than a model trained from scratch on only the 10,000 labels, because it leverages the structural patterns in the other 99.99% of the data.
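The fine-tuning step above can be sketched with a frozen encoder and a small classification head. Everything here is a toy stand-in (random embeddings in place of the SSL-pre-trained encoder's output, synthetic labels); the point is the shape of the workflow, not the numbers.

```python
import torch

torch.manual_seed(0)

# Stand-ins: `z` plays the role of frozen embeddings from SSL pre-training,
# `y` the scarce fraud labels, `train_mask` the few labeled nodes.
z = torch.randn(1000, 64)
y = (torch.rand(1000) < 0.01).long()          # ~1% positives (toy)
train_mask = torch.zeros(1000, dtype=torch.bool)
train_mask[:100] = True                        # only 100 labeled nodes

# Fine-tune: train only a lightweight head on the frozen embeddings
classifier = torch.nn.Linear(64, 2)
opt = torch.optim.Adam(classifier.parameters(), lr=0.01)
for _ in range(50):
    opt.zero_grad()
    loss = torch.nn.functional.cross_entropy(
        classifier(z[train_mask]), y[train_mask])
    loss.backward()
    opt.step()
```

A common variant unfreezes the encoder with a smaller learning rate once the head has converged, trading some compute for extra accuracy.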

SSL in PyG

ssl_in_pyg.py
# Feature masking SSL in PyG
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class MaskedAutoencoder(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        self.encoder = GCNConv(in_dim, hidden_dim)
        self.decoder = torch.nn.Linear(hidden_dim, in_dim)

    def forward(self, x, edge_index, mask_rate=0.15):
        # Create mask
        mask = torch.rand(x.size(0), device=x.device) < mask_rate
        x_masked = x.clone()
        x_masked[mask] = 0  # zero out masked nodes

        # Encode with masked input
        z = self.encoder(x_masked, edge_index).relu()

        # Decode: predict original features for masked nodes
        x_pred = self.decoder(z)

        # Loss: reconstruction only on masked nodes
        loss = F.mse_loss(x_pred[mask], x[mask])
        return loss, z

Feature masking autoencoder. The model learns to predict masked features from their graph context.
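A minimal pre-training loop for a masked autoencoder like the one above. To keep the sketch self-contained, a plain linear layer stands in for `GCNConv` (so it runs without `torch_geometric` installed); in real use you would swap the GCN encoder back in and pass `edge_index` through.

```python
import torch
import torch.nn.functional as F

class MaskedAutoencoder(torch.nn.Module):
    """Same shape as the PyG version, with a linear encoder as a
    stand-in for GCNConv so the sketch has no PyG dependency."""
    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        self.encoder = torch.nn.Linear(in_dim, hidden_dim)
        self.decoder = torch.nn.Linear(hidden_dim, in_dim)

    def forward(self, x, mask_rate=0.15):
        mask = torch.rand(x.size(0), device=x.device) < mask_rate
        x_masked = x.clone()
        x_masked[mask] = 0                      # zero out masked nodes
        z = self.encoder(x_masked).relu()
        x_pred = self.decoder(z)
        # Reconstruction loss only on the masked nodes
        return F.mse_loss(x_pred[mask], x[mask]), z

torch.manual_seed(0)
x = torch.randn(200, 32)                        # toy node features
model = MaskedAutoencoder(32, 16)
opt = torch.optim.Adam(model.parameters(), lr=0.01)
for epoch in range(100):
    opt.zero_grad()
    loss, _ = model(x)
    loss.backward()
    opt.step()
```

After pre-training, `model.encoder` provides the embeddings you would freeze or fine-tune for the downstream task.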

Frequently asked questions

What is self-supervised learning on graphs?

Self-supervised learning (SSL) on graphs trains GNNs without task-specific labels by creating training signal from the graph structure itself. Examples: mask some node features and predict them, mask some edges and predict them, or create two views of the same graph and train the model to recognize them as the same. The resulting model learns general graph representations useful for many downstream tasks.

Why is self-supervised learning important for graphs?

Graph data is abundant but labels are scarce. A bank has billions of transactions (edges) but only thousands of confirmed fraud labels. SSL leverages the unlabeled graph structure to learn representations, then fine-tunes with the few available labels. This dramatically improves performance when labeled data is limited.

What are the main SSL approaches for graphs?

Three categories: (1) Contrastive: learn by comparing similar (positive) and dissimilar (negative) graph views. (2) Generative: learn by reconstructing masked parts of the graph (features or edges). (3) Predictive: learn by predicting structural properties (degree, clustering coefficient, subgraph patterns). Contrastive methods currently dominate.

What is graph masking?

Graph masking is a generative SSL strategy: randomly mask (remove) some node features or edges, then train the GNN to reconstruct the masked elements. This is the graph equivalent of masked language modeling (BERT). The model learns to predict missing information from context, developing rich structural understanding.

How does SSL relate to foundation models like KumoRFM?

KumoRFM is a graph foundation model pre-trained with SSL on diverse relational databases. The SSL pre-training teaches the model general patterns about relational data structure. When applied to a new database, it can make zero-shot predictions without any task-specific training, or be fine-tuned for maximum accuracy.

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.