
Pre-Training: Training GNNs on Large Datasets Before Fine-Tuning

Pre-training trains a GNN on massive graph data with self-supervised objectives, building general representations that transfer to downstream tasks. This is how foundation models like KumoRFM achieve state-of-the-art performance with zero or few labels.

PyTorch Geometric

TL;DR

  • Pre-training trains a GNN on large unlabeled graph data before fine-tuning on the target task. The model learns general structural patterns that transfer to specific predictions.
  • Three objective levels: node-level (predict masked features), edge-level (predict masked edges), graph-level (contrastive between views). Combining levels works best.
  • Pre-training provides the biggest lift when labeled data is scarce. With 100 labels, a pre-trained model can match a from-scratch model trained on 10,000 labels.
  • Scale matters: pre-training on 2M+ molecules or 100M+ relational rows produces significantly stronger representations than pre-training on smaller datasets.
  • KumoRFM is a graph foundation model pre-trained on diverse relational databases. It achieves zero-shot 76.71 AUROC on RelBench, outperforming task-specific models trained from scratch.

Pre-training trains GNNs on large datasets before fine-tuning on specific tasks. The insight: graph structure contains enormous amounts of information that can be extracted without any task labels. A GNN pre-trained on 2 million molecules learns what ring structures mean, how branching affects properties, and how functional groups interact. This knowledge transfers to drug discovery tasks where you have only 500 labeled molecules.

This is the same paradigm that produced BERT, GPT, and CLIP. The graph version is newer but follows the same principle: train a large model on abundant data with self-supervised objectives, then adapt to specific tasks with fine-tuning.

Pre-training objectives

Node-level

  • Feature masking: mask 15% of node features, predict them from graph context (like BERT for graphs)
  • Context prediction: predict whether two subgraphs come from the same neighborhood
  • Property prediction: predict structural properties (degree, clustering) from embeddings

Edge-level

  • Link prediction: mask edges and predict their existence from node embeddings
  • Edge property prediction: predict edge attributes from endpoint embeddings

Graph-level

  • Contrastive: match augmented views of the same graph (GraphCL, GRACE)
  • Graph property prediction: predict global properties (if supervised labels available)
multi_level_pretraining.py
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.utils import negative_sampling

class PreTrainingModel(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        # GCNConv takes (x, edge_index), so it can't live in nn.Sequential
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        # Multiple pre-training heads
        self.feature_decoder = torch.nn.Linear(hidden_dim, in_dim)
        self.edge_predictor = torch.nn.Bilinear(hidden_dim, hidden_dim, 1)

    def encode(self, x, edge_index):
        return self.conv2(F.relu(self.conv1(x, edge_index)), edge_index)

    def pretrain_step(self, x, edge_index, mask_rate=0.15):
        # Mask a random subset of node features
        mask = torch.rand(x.size(0), device=x.device) < mask_rate
        x_masked = x.clone()
        x_masked[mask] = 0

        # Encode the corrupted graph
        z = self.encode(x_masked, edge_index)

        # Node-level loss: reconstruct the masked features
        feat_loss = F.mse_loss(self.feature_decoder(z[mask]), x[mask])

        # Edge-level loss: score real edges against sampled negatives
        neg_edge = negative_sampling(edge_index, num_nodes=x.size(0))
        pos_score = self.edge_predictor(z[edge_index[0]], z[edge_index[1]])
        neg_score = self.edge_predictor(z[neg_edge[0]], z[neg_edge[1]])
        edge_loss = F.binary_cross_entropy_with_logits(
            torch.cat([pos_score, neg_score]),
            torch.cat([torch.ones_like(pos_score), torch.zeros_like(neg_score)]),
        )

        return feat_loss + edge_loss

Multi-level pre-training: feature reconstruction + link prediction. The encoder learns from both objectives simultaneously.

Enterprise example: multi-database foundation model

KumoRFM pre-trains on diverse relational databases:

  • E-commerce databases (customers, orders, products)
  • Financial databases (accounts, transactions, merchants)
  • Healthcare databases (patients, encounters, prescriptions)
  • Social networks (users, posts, interactions)

From this diverse pre-training, KumoRFM learns universal relational patterns: what it means for entities to be connected by foreign keys, how temporal sequences of events predict outcomes, how multi-table relationships create compound features. When pointed at a new database it has never seen, it achieves 76.71 AUROC zero-shot, outperforming task-specific models trained from scratch on the target data.

Frequently asked questions

What is GNN pre-training?

GNN pre-training trains a graph neural network on a large dataset with self-supervised or supervised objectives before applying it to a specific downstream task. The pre-trained model learns general graph representations that transfer to new tasks. Pre-training is followed by fine-tuning on the target task with limited labeled data.

What pre-training objectives work for GNNs?

Three levels: (1) Node-level: predict masked node features, predict node properties like degree or clustering coefficient. (2) Edge-level: predict masked edges (link prediction), predict edge properties. (3) Graph-level: predict graph properties, contrastive learning between graph views. Combining objectives across levels often works best.

How much data do you need for GNN pre-training?

Pre-training benefits scale with data size. For molecular GNNs, pre-training on 2M+ molecules from ChEMBL or ZINC significantly outperforms training from scratch. For relational data, KumoRFM pre-trains on multiple databases totaling 100M+ rows. Generally, pre-training is worthwhile when you have 10x+ more unlabeled than labeled data.

Is pre-training always worth it?

Not always. If you have abundant labeled data for your target task (10,000+ labels), training from scratch may match pre-training performance. Pre-training provides the biggest lift when labeled data is scarce (100-1,000 labels) and unlabeled graph data is abundant. The less labeled data you have, the more pre-training helps.

What is the difference between pre-training and self-supervised learning?

Self-supervised learning is a training paradigm (create signal from structure). Pre-training is a training stage (train first on large data, then adapt to target task). Pre-training often uses self-supervised objectives, but can also use supervised objectives on large labeled datasets. SSL is how you pre-train; pre-training is when you do it.

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.