Fine-tuning adapts a pre-trained GNN to a specific downstream task. Instead of training a model from random initialization, you start with weights that already encode general graph patterns (from pre-training on large data) and gently adjust them for your specific prediction target. This is faster, requires less labeled data, and typically produces better results than training from scratch.
Fine-tuning is the second stage of the pre-train / fine-tune paradigm that has transformed NLP (BERT, GPT) and vision (CLIP, ViT). For graphs, it is the mechanism by which foundation models like KumoRFM achieve state-of-the-art results on specific enterprise prediction tasks.
Fine-tuning recipe
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

# Load pre-trained encoder
encoder = load_pretrained('graph_encoder.pt')

# Add task-specific head
class FineTuneModel(torch.nn.Module):
    def __init__(self, encoder, num_classes):
        super().__init__()
        self.encoder = encoder
        self.head = torch.nn.Sequential(
            torch.nn.Linear(encoder.out_dim, 64),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(64, num_classes),
        )

    def forward(self, x, edge_index):
        z = self.encoder(x, edge_index)
        return self.head(z)

model = FineTuneModel(encoder, num_classes=2)

# Differential learning rates
optimizer = AdamW([
    {'params': model.encoder.parameters(), 'lr': 5e-5},  # small LR
    {'params': model.head.parameters(), 'lr': 1e-3},     # larger LR
], weight_decay=0.01)
scheduler = CosineAnnealingLR(optimizer, T_max=20)

# Fine-tune for 10-20 epochs (vs 100+ from scratch)
for epoch in range(20):
    train_one_epoch(model, train_loader, optimizer)
    scheduler.step()

Differential learning rates: small for the pre-trained encoder, larger for the newly initialized head. The cosine schedule decays both smoothly over the run.
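The recipe leaves `train_one_epoch` undefined. A minimal sketch, assuming a PyG-style batch object with `x`, `edge_index`, `y`, and a `train_mask` (all names here are illustrative, not part of any specific library's API):

```python
import torch
import torch.nn.functional as F

def train_one_epoch(model, loader, optimizer):
    """One pass over the labeled fine-tuning set; returns mean loss."""
    model.train()
    total_loss = 0.0
    for batch in loader:
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index)
        # Compute the loss only on nodes that carry labels
        loss = F.cross_entropy(out[batch.train_mask], batch.y[batch.train_mask])
        loss.backward()
        # Gradient clipping guards the pre-trained weights against early spikes
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        total_loss += loss.item()
    return total_loss / max(len(loader), 1)
```

Clipping the gradient norm is a common companion to small encoder learning rates: both protect the pre-trained weights from being overwritten in the first few batches.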
Freezing strategies
How much of the pre-trained model to freeze depends on how much labeled data you have:
- Few labels (<100): freeze entire encoder, train only the classification head. The encoder is a fixed feature extractor.
- Moderate labels (100-1,000): fine-tune all layers with small learning rates. The encoder adapts gently.
- Many labels (1,000+): fine-tune everything with moderate learning rates. The model has enough data to adapt significantly.
Enterprise example: KumoRFM fine-tuning
A retailer wants to predict customer churn from their database of 5 million customers, 50 million orders, and 200,000 products. They have 2,000 confirmed churn labels.
- Zero-shot: KumoRFM makes predictions immediately using only the database schema and relational structure, with no training; this establishes a baseline AUROC.

- Fine-tuning: using the 2,000 churn labels, the pre-trained model adapts its attention patterns to the specific churn signals in this database. 10 epochs of fine-tuning.
- Result: fine-tuned AUROC is 5-10 points higher than zero-shot, and 15-20 points higher than training a GNN from scratch on only 2,000 labels.
Common mistakes
- Learning rate too high: destroys pre-trained knowledge in the first few batches. Use a learning rate 10-100x smaller than you would when training from scratch.
- Training too long: fine-tuning overfits quickly because the model starts from a good solution. 10-20 epochs is usually sufficient. Monitor validation loss and stop early.
- No warmup: the first few steps with a cold head can produce large gradients that damage the encoder. Use linear warmup for 5-10% of total steps.
- Ignoring class imbalance: enterprise datasets are often imbalanced (1% fraud). Use weighted loss, focal loss, or oversampling during fine-tuning.
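The warmup and class-imbalance fixes can be sketched together. This is one way to do it, assuming a two-class problem with roughly 1% positives; `warmup_lambda` is a hypothetical helper, and the numbers are placeholders:

```python
import torch
from torch.optim.lr_scheduler import LambdaLR

def warmup_lambda(warmup_steps):
    """Linear warmup from ~0 to the base LR, then hold constant."""
    def fn(step):
        if step < warmup_steps:
            return (step + 1) / warmup_steps
        return 1.0
    return fn

# Weighted loss: upweight the rare positive class (~1% of examples)
class_weights = torch.tensor([1.0, 99.0])
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)

model = torch.nn.Linear(8, 2)  # stand-in for the fine-tune model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
# 10 warmup steps here; in practice use ~5-10% of total steps
scheduler = LambdaLR(optimizer, lr_lambda=warmup_lambda(warmup_steps=10))
```

Calling `scheduler.step()` after each optimizer step ramps the learning rate linearly over the first 10 steps, so the cold head cannot push large gradients into the encoder; the weighted loss keeps the 1%-positive class from being ignored.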