PyG / Guide · 7 min read

Fine-Tuning: Adapting a Pre-Trained GNN to a Specific Task

Fine-tuning takes a GNN pre-trained on large data and adapts it to your specific prediction task. It converges faster, needs less labeled data, and often outperforms training from scratch.


TL;DR

  • Fine-tuning continues training a pre-trained GNN on task-specific labeled data. The pre-trained model provides general graph understanding; fine-tuning adapts it to your prediction target.
  • Use small learning rates (1e-4 to 5e-5) for pre-trained layers, higher rates (1e-3) for new task-specific layers. This preserves pre-trained knowledge while learning the new task.
  • Layer freezing strategy depends on data size: freeze early layers with <100 labels, fine-tune everything with 1,000+ labels. Gradual unfreezing is a robust default.
  • Fine-tuning converges in 5-20 epochs (vs 100+ from scratch), needs 10-100x less labeled data, and often achieves higher accuracy thanks to pre-trained regularization.
  • KumoRFM fine-tuning improves from 76.71 (zero-shot) to 81.14 AUROC on RelBench. The pre-trained relational patterns adapt to your specific database and prediction task.

Fine-tuning adapts a pre-trained GNN to a specific downstream task. Instead of training a model from random initialization, you start with weights that already encode general graph patterns (from pre-training on large data) and gently adjust them for your specific prediction target. This is faster, requires less labeled data, and typically produces better results than training from scratch.

Fine-tuning is the second stage of the pre-train / fine-tune paradigm that has transformed NLP (BERT, GPT) and vision (CLIP, ViT). For graphs, it is the mechanism by which foundation models like KumoRFM achieve state-of-the-art results on specific enterprise prediction tasks.

Fine-tuning recipe

fine_tuning_gnn.py
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

# Load pre-trained encoder (load_pretrained is a project-specific helper
# that returns a trained GNN encoder module with an `out_dim` attribute)
encoder = load_pretrained('graph_encoder.pt')

# Add task-specific head
class FineTuneModel(torch.nn.Module):
    def __init__(self, encoder, num_classes):
        super().__init__()
        self.encoder = encoder
        self.head = torch.nn.Sequential(
            torch.nn.Linear(encoder.out_dim, 64),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(64, num_classes),
        )

    def forward(self, x, edge_index):
        z = self.encoder(x, edge_index)
        return self.head(z)

model = FineTuneModel(encoder, num_classes=2)

# Differential learning rates
optimizer = AdamW([
    {'params': model.encoder.parameters(), 'lr': 5e-5},   # small LR
    {'params': model.head.parameters(), 'lr': 1e-3},      # larger LR
], weight_decay=0.01)

scheduler = CosineAnnealingLR(optimizer, T_max=20)

# Fine-tune for 10-20 epochs (vs 100+ from scratch)
for epoch in range(20):
    train_one_epoch(model, train_loader, optimizer)
    scheduler.step()

Differential learning rates: small for pre-trained layers, larger for new layers. Cosine schedule decays both smoothly.

Freezing strategies

How much of the pre-trained model to freeze depends on how much labeled data you have:

  • Few labels (<100): freeze entire encoder, train only the classification head. The encoder is a fixed feature extractor.
  • Moderate labels (100-1,000): fine-tune all layers with small learning rates. The encoder adapts gently.
  • Many labels (1,000+): fine-tune everything with moderate learning rates. The model has enough data to adapt significantly.
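
The few-label regime can be sketched in plain PyTorch: flip `requires_grad` off for the encoder and give the optimizer only the trainable head parameters. The `encoder` here is a hypothetical stand-in module (a real one would be your pre-trained GNN), but the freezing pattern is the same.

```python
import torch

# Stand-in for a pre-trained encoder; substitute your loaded GNN module.
encoder = torch.nn.Sequential(
    torch.nn.Linear(16, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 32),
)

# Few labels (<100): freeze the entire encoder.
for param in encoder.parameters():
    param.requires_grad = False

# Only the new classification head is trained.
head = torch.nn.Linear(32, 2)

# Pass only trainable parameters to the optimizer, so the frozen
# encoder weights are never updated (and no optimizer state is kept for them).
trainable = [
    p for p in list(encoder.parameters()) + list(head.parameters())
    if p.requires_grad
]
optimizer = torch.optim.AdamW(trainable, lr=1e-3, weight_decay=0.01)
```

With the encoder frozen you can also precompute embeddings once and train the head on them, which makes each epoch much cheaper.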

Enterprise example: KumoRFM fine-tuning

A retailer wants to predict customer churn from their database of 5 million customers, 50 million orders, and 200,000 products. They have 2,000 confirmed churn labels.

  1. Zero-shot: KumoRFM makes predictions immediately using only the database schema and relational structure. No training needed. This establishes the baseline AUROC.
  2. Fine-tuning: using the 2,000 churn labels, the pre-trained model adapts its attention patterns to the specific churn signals in this database. 10 epochs of fine-tuning.
  3. Result: fine-tuned AUROC is 5-10 points higher than zero-shot, and 15-20 points higher than training a GNN from scratch on only 2,000 labels.

Common mistakes

  • Learning rate too high: destroys pre-trained knowledge in the first few batches. Use a learning rate 10-100x smaller than you would for training from scratch.
  • Training too long: fine-tuning overfits quickly because the model starts from a good solution. 10-20 epochs is usually sufficient. Monitor validation loss and stop early.
  • No warmup: the first few steps with a cold head can produce large gradients that damage the encoder. Use linear warmup for 5-10% of total steps.
  • Ignoring class imbalance: enterprise datasets are often imbalanced (1% fraud). Use weighted loss, focal loss, or oversampling during fine-tuning.

Frequently asked questions

What is GNN fine-tuning?

Fine-tuning adapts a pre-trained GNN to a specific downstream task by continuing training on task-specific labeled data. The pre-trained model provides a strong starting point with general graph understanding. Fine-tuning adjusts these representations for the target prediction (e.g., fraud detection, churn prediction, drug property prediction).

How does fine-tuning differ from training from scratch?

Training from scratch initializes weights randomly and learns everything from the task data alone. Fine-tuning starts from pre-trained weights that already encode general graph patterns. Fine-tuning converges faster (fewer epochs), needs less labeled data (10-100x less), and often achieves better accuracy because the pre-trained knowledge regularizes the model.

What learning rate should I use for fine-tuning?

Use a learning rate 10-100x smaller than for training from scratch. Typical values: 1e-4 to 5e-5 for the pre-trained layers, with a higher rate (1e-3) for newly added task-specific layers (the classification head). This preserves pre-trained knowledge while adapting to the new task. Use linear warmup for the first 5-10% of training steps.

Should I freeze some layers during fine-tuning?

It depends on data size. With very few labels (<100), freeze early GNN layers and only train the last layer + classifier (prevents overfitting). With moderate labels (100-1,000), fine-tune all layers with small learning rate. With many labels (1,000+), fine-tune everything. Gradual unfreezing (unfreeze one layer at a time) is a robust middle ground.
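
Gradual unfreezing is easy to express as an epoch-indexed schedule over `requires_grad`. A minimal sketch, assuming a list-like stack of encoder layers (real GNN layers, e.g. PyG conv layers, would replace the `torch.nn.Linear` stand-ins):

```python
import torch

# Hypothetical 3-layer encoder stack; layers[-1] is closest to the head.
encoder_layers = torch.nn.ModuleList(
    torch.nn.Linear(32, 32) for _ in range(3)
)

def set_trainable_top(layers, n):
    """Freeze everything, then unfreeze only the top n layers."""
    for i, layer in enumerate(layers):
        trainable = i >= len(layers) - n
        for p in layer.parameters():
            p.requires_grad = trainable

# Start with only the top layer trainable, then unfreeze one more
# layer every 2 epochs until the whole encoder is fine-tuned.
for epoch in range(6):
    n_unfrozen = min(epoch // 2 + 1, len(encoder_layers))
    set_trainable_top(encoder_layers, n_unfrozen)
    # ... train_one_epoch(...) would go here ...
```

If you rebuild the optimizer as layers unfreeze, remember that newly unfrozen parameters start with fresh optimizer state; keeping one optimizer over all parameters and relying on `requires_grad` avoids that wrinkle.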

How does KumoRFM fine-tuning work?

KumoRFM can be fine-tuned on your specific database and prediction task. The pre-trained relational graph transformer already understands relational structure. Fine-tuning adapts it to your specific tables, features, and prediction target. On RelBench, fine-tuning improves from 76.71 (zero-shot) to 81.14 AUROC.

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.