Pre-training exposes GNNs to large datasets before fine-tuning them on specific tasks. The insight: graph structure contains enormous amounts of information that can be extracted without any task labels. A GNN pre-trained on 2 million molecules learns what ring structures mean, how branching affects properties, and how functional groups interact. This knowledge transfers to drug discovery tasks where you have only 500 labeled molecules.
This is the same paradigm that produced BERT, GPT, and CLIP. The graph version is newer but follows the same principle: train a large model on abundant data with self-supervised objectives, then adapt to specific tasks with fine-tuning.
Pre-training objectives
Node-level
- Feature masking: mask 15% of node features, predict them from graph context (like BERT for graphs)
- Context prediction: predict whether two subgraphs come from the same neighborhood
- Property prediction: predict structural properties (degree, clustering) from embeddings
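The last of these node-level objectives can be sketched in a few lines of plain PyTorch: node degree is computable from the edge list alone, so it makes a free regression target. The toy graph, embedding tensor, and linear head below are illustrative stand-ins, not a published recipe.

```python
import torch

# Toy graph: two undirected edges, 0-1 and 1-2, stored as directed pairs
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
num_nodes = 3

# Target: node degrees, computed purely from structure (no labels needed)
degree = torch.zeros(num_nodes)
degree.scatter_add_(0, edge_index[0], torch.ones(edge_index.size(1)))

# Stand-in embeddings for what a GNN encoder would produce
z = torch.randn(num_nodes, 16)
head = torch.nn.Linear(16, 1)

# Pretext loss: predict each node's degree from its embedding
loss = torch.nn.functional.mse_loss(head(z).squeeze(-1), degree)
```

Because the target is derived from the graph itself, this objective scales to arbitrarily many unlabeled graphs.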
Edge-level
- Link prediction: mask edges and predict their existence from node embeddings
- Edge property prediction: predict edge attributes from endpoint embeddings
Graph-level
- Contrastive: match augmented views of the same graph (GraphCL, GRACE)
- Graph property prediction: predict global properties (if supervised labels available)
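The contrastive objective can be sketched with a standard NT-Xent loss in the spirit of GraphCL: embeddings of two augmented views of the same graph are pulled together, views of different graphs pushed apart. Here `z1` and `z2` are stand-ins for pooled graph embeddings from a GNN encoder, and the temperature value is an illustrative default.

```python
import torch
import torch.nn.functional as F

def nt_xent(z1, z2, temperature=0.5):
    # Cosine similarity between every view-1 / view-2 pair
    z1, z2 = F.normalize(z1, dim=1), F.normalize(z2, dim=1)
    logits = z1 @ z2.t() / temperature       # [B, B] similarity matrix
    # Positives sit on the diagonal: view i of graph i matches view i'
    targets = torch.arange(z1.size(0))
    return F.cross_entropy(logits, targets)

batch = 8
z1 = torch.randn(batch, 32)   # view 1 (e.g. edge-dropped graphs)
z2 = torch.randn(batch, 32)   # view 2 (e.g. feature-masked graphs)
loss = nt_xent(z1, z2)
```

The encoder never sees a task label; it only learns that two corruptions of the same graph should land near each other in embedding space.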
```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.utils import negative_sampling


class PreTrainingModel(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        # GCNConv needs edge_index at call time, so it cannot live inside
        # nn.Sequential; keep the layers separate and wire them in encode()
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        # Multiple pre-training heads
        self.feature_decoder = torch.nn.Linear(hidden_dim, in_dim)
        self.edge_predictor = torch.nn.Bilinear(hidden_dim, hidden_dim, 1)

    def encode(self, x, edge_index):
        h = F.relu(self.conv1(x, edge_index))
        return self.conv2(h, edge_index)

    def pretrain_step(self, x, edge_index, mask_rate=0.15):
        # Mask a random subset of node features
        mask = torch.rand(x.size(0)) < mask_rate
        x_masked = x.clone()
        x_masked[mask] = 0

        # Encode the corrupted graph
        z = self.encode(x_masked, edge_index)

        # Node-level loss: reconstruct the masked features
        feat_loss = F.mse_loss(self.feature_decoder(z[mask]), x[mask])

        # Edge-level loss: score real edges against sampled negative edges
        neg_edge = negative_sampling(edge_index, num_nodes=x.size(0))
        pos_score = self.edge_predictor(z[edge_index[0]], z[edge_index[1]])
        neg_score = self.edge_predictor(z[neg_edge[0]], z[neg_edge[1]])
        edge_loss = F.binary_cross_entropy_with_logits(
            torch.cat([pos_score, neg_score]),
            torch.cat([torch.ones_like(pos_score),
                       torch.zeros_like(neg_score)]),
        )
        return feat_loss + edge_loss
```

Multi-level pre-training: feature reconstruction plus link prediction. The encoder learns from both objectives simultaneously.
Enterprise example: multi-database foundation model
KumoRFM pre-trains on diverse relational databases:
- E-commerce databases (customers, orders, products)
- Financial databases (accounts, transactions, merchants)
- Healthcare databases (patients, encounters, prescriptions)
- Social networks (users, posts, interactions)
From this diverse pre-training, KumoRFM learns universal relational patterns: what it means for entities to be connected by foreign keys, how temporal sequences of events predict outcomes, and how multi-table relationships create compound features. When pointed at a new database it has never seen, it achieves an AUROC of 76.71 zero-shot, outperforming task-specific models trained from scratch on the target data.