Self-supervised learning pre-trains GNNs without task labels by creating training signal from the graph structure itself. The key insight: a graph contains far more information in its structure (edges, neighborhoods, communities) than in its labels. SSL extracts this structural information into general-purpose node representations that can then be fine-tuned for any downstream task with very few labels.
This is the graph equivalent of how BERT learns language by predicting masked words, or how CLIP learns vision by matching images to captions. The pretext task (what the model trains on) is different from the downstream task (what you actually care about), but the representations transfer.
Three SSL paradigms
Contrastive (most popular)
Create two augmented views of the same graph and train the model to recognize them as the same while distinguishing them from views of different graphs:
```python
# Contrastive SSL: GraphCL / GCA style
# 1. Create two augmented views of the same graph
view1 = augment(graph)  # e.g. drop 20% of edges, mask 10% of features
view2 = augment(graph)  # different random drops/masks

# 2. Encode both views with the same GNN
z1 = gnn_encoder(view1)  # [num_nodes, hidden_dim]
z2 = gnn_encoder(view2)  # [num_nodes, hidden_dim]

# 3. Contrastive loss: the same node in both views is a positive pair;
#    different nodes are negative pairs. InfoNCE pushes positives
#    together and negatives apart.
loss = info_nce_loss(z1, z2)

# After pre-training: freeze the encoder and fine-tune a classifier
# with just 100 labeled nodes instead of 10,000
```

Contrastive SSL learns by matching different views of the same graph. Augmentation quality directly impacts representation quality.
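The `info_nce_loss` call above is a placeholder. A minimal implementation, assuming row `i` of each view's embedding matrix corresponds to the same node (the temperature of 0.5 is a common default, not something the sketch above specifies):

```python
import torch
import torch.nn.functional as F

def info_nce_loss(z1, z2, temperature=0.5):
    # Cosine similarity: L2-normalize, then dot product
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)
    sim = z1 @ z2.t() / temperature        # [N, N] similarity matrix
    # The positive pair for node i is (z1[i], z2[i]): the diagonal.
    # Cross-entropy with target i treats every other column as a negative.
    targets = torch.arange(z1.size(0))
    return F.cross_entropy(sim, targets)

# Sanity check: well-aligned views should score a lower loss
# than views of unrelated embeddings
torch.manual_seed(0)
z = torch.randn(64, 32)
loss_aligned = info_nce_loss(z, z + 0.01 * torch.randn_like(z))
loss_mismatched = info_nce_loss(z, torch.randn(64, 32))
```

Because every other node in the batch serves as a negative, larger graphs (or batches) give more negatives for free, which is one reason this loss scales well.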
Generative (masking)
Mask parts of the graph and train the model to reconstruct them:
- Feature masking: zero out 15% of node features, predict the original values from context
- Edge masking: remove 15% of edges, predict whether they exist from node embeddings
- Combined: mask both features and edges for richer self-supervision
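Edge masking in particular reduces to link prediction on the held-out edges. A minimal sketch in plain PyTorch, where `z` stands in for embeddings from any GNN encoder run on the graph with the masked edges removed; the dot-product decoder and the uniform random negatives are illustrative choices (PyG's `torch_geometric.utils.negative_sampling` does the latter properly):

```python
import torch
import torch.nn.functional as F

def edge_masking_loss(z, pos_edges, neg_edges):
    """z: node embeddings from the encoder.
    pos_edges: the held-out (masked) edges, shape [2, E].
    neg_edges: sampled node pairs with no edge, shape [2, E]."""
    def score(edges):
        # Dot-product decoder: higher score = more likely an edge
        return (z[edges[0]] * z[edges[1]]).sum(dim=1)

    logits = torch.cat([score(pos_edges), score(neg_edges)])
    labels = torch.cat([torch.ones(pos_edges.size(1)),
                        torch.zeros(neg_edges.size(1))])
    return F.binary_cross_entropy_with_logits(logits, labels)

# Toy usage with random embeddings and edges
torch.manual_seed(0)
z = torch.randn(10, 16)
pos = torch.randint(0, 10, (2, 20))   # 20 "masked" edges
neg = torch.randint(0, 10, (2, 20))   # 20 negative samples
loss = edge_masking_loss(z, pos, neg)
```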
Predictive
Predict structural properties without masking:
- Predict node degree from features
- Predict local clustering coefficient
- Predict which subgraph pattern (motif) each node belongs to
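For the degree pretext task, the targets come for free from the edge list. PyG's `torch_geometric.utils.degree` computes the same thing; a plain-PyTorch version for illustration:

```python
import torch

def degree_targets(edge_index, num_nodes):
    # Count outgoing edges per node; these become free regression labels
    deg = torch.zeros(num_nodes)
    deg.scatter_add_(0, edge_index[0], torch.ones(edge_index.size(1)))
    return deg

# Toy graph: node 0 connects to nodes 1 and 2 (edges stored both ways)
edge_index = torch.tensor([[0, 0, 1, 2],
                           [1, 2, 0, 0]])
deg = degree_targets(edge_index, num_nodes=3)  # tensor([2., 1., 1.])
```

A common setup is a linear head on the GNN embeddings trained with MSE against `log1p(deg)`, since raw degrees are heavy-tailed.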
Enterprise example: fraud detection with 0.01% labels
A payment processor has 100 million transactions but only 10,000 confirmed fraud labels (0.01%). Training a GNN from scratch on those 10,000 labels alone wastes most of the structural information in the 100-million-edge graph.
Self-supervised pre-training:
- Pre-train the GNN on the full 100M transaction graph using contrastive SSL (no labels needed)
- The model learns: typical transaction patterns, normal account behavior, common network structures
- Fine-tune with the 10,000 fraud labels
- The fine-tuned model knows what “normal” looks like (from SSL) and what “fraud” looks like (from labels)
Result: the SSL pre-trained model achieves 15-25% higher AUROC than a model trained from scratch on only the 10,000 labels, because it leverages the structural patterns in the other 99.99% of the data.
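The fine-tuning step in this workflow is often just a linear probe: the pre-trained encoder is frozen, so its node embeddings `z` are fixed, and only a small classifier is trained on the scarce labels. A sketch under that assumption (the toy data below is synthetic, not the payment graph):

```python
import torch

def linear_probe(z, y, labeled_idx, num_classes, epochs=200, lr=0.1):
    """Train only a linear classifier on the frozen embeddings z,
    using just the few labeled nodes."""
    head = torch.nn.Linear(z.size(1), num_classes)
    opt = torch.optim.Adam(head.parameters(), lr=lr)
    for _ in range(epochs):
        opt.zero_grad()
        loss = torch.nn.functional.cross_entropy(
            head(z[labeled_idx]), y[labeled_idx])
        loss.backward()
        opt.step()
    return head

# Toy usage: two well-separated clusters, only 4 labeled nodes
torch.manual_seed(0)
z = torch.cat([torch.randn(50, 8) + 3, torch.randn(50, 8) - 3])
y = torch.cat([torch.zeros(50, dtype=torch.long),
               torch.ones(50, dtype=torch.long)])
head = linear_probe(z, y, labeled_idx=torch.tensor([0, 1, 50, 51]),
                    num_classes=2)
acc = (head(z).argmax(dim=1) == y).float().mean()
```

The point of the toy example: if SSL pre-training already separates the classes in embedding space, a handful of labels is enough for the probe to generalize to all nodes.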
SSL in PyG
```python
# Feature masking SSL in PyG
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class MaskedAutoencoder(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        self.encoder = GCNConv(in_dim, hidden_dim)
        self.decoder = torch.nn.Linear(hidden_dim, in_dim)

    def forward(self, x, edge_index, mask_rate=0.15):
        # Randomly select nodes to mask
        mask = torch.rand(x.size(0), device=x.device) < mask_rate
        x_masked = x.clone()
        x_masked[mask] = 0  # zero out masked nodes

        # Encode with masked input
        z = self.encoder(x_masked, edge_index).relu()

        # Decode: predict original features
        x_pred = self.decoder(z)

        # Loss: reconstruction error only on masked nodes
        loss = F.mse_loss(x_pred[mask], x[mask])
        return loss, z
```

Feature masking autoencoder. The model learns to predict masked features from their graph context.
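Pre-training is then a plain loop over masking rounds. To keep the sketch self-contained (no PyG required), it uses one propagation step over a ring graph's normalized adjacency as a stand-in for the GCN encoder, and smooth features so that masked nodes really are predictable from their neighbors; in practice you would run the same loop over the `MaskedAutoencoder` above on your own graph:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, D, H = 100, 8, 32
# Smooth features: each node's features resemble its neighbors'
pos = torch.arange(N, dtype=torch.float) / N
x = torch.stack([torch.sin(2 * torch.pi * (k + 1) * pos)
                 for k in range(D)], dim=1)

# Ring graph, row-normalized adjacency (stand-in for GCN propagation)
idx = torch.arange(N)
A = torch.zeros(N, N)
A[idx, (idx + 1) % N] = 0.5
A[idx, (idx - 1) % N] = 0.5

enc = torch.nn.Linear(D, H)
dec = torch.nn.Linear(H, D)
opt = torch.optim.Adam(list(enc.parameters()) + list(dec.parameters()),
                       lr=1e-2)

losses = []
for _ in range(300):
    mask = torch.rand(N) < 0.15
    x_masked = x.clone()
    x_masked[mask] = 0                  # hide 15% of node features
    z = (A @ enc(x_masked)).relu()      # masked nodes see only neighbors
    loss = F.mse_loss(dec(z)[mask], x[mask])
    opt.zero_grad()
    loss.backward()
    opt.step()
    losses.append(loss.item())
```

The reconstruction loss falls only because neighbors carry information about the masked node; with independent random features it would plateau at the feature variance, which is a useful sanity check that the pretext task is actually using graph structure.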