PyG Guide · 7 min read

Multi-Task Learning: One GNN, Multiple Predictions

Why train five models when one will do? Multi-task learning shares a GNN encoder across multiple prediction tasks, producing richer representations and reducing infrastructure complexity.


TL;DR

  • Multi-task learning trains a single GNN encoder with multiple task-specific heads. The shared encoder learns relational patterns that benefit all tasks, while each head specializes in its objective.
  • It works best when tasks share underlying patterns: churn + LTV, fraud + credit risk, click-through + conversion. Unrelated tasks can cause negative transfer.
  • Loss balancing is critical. Uncertainty weighting and GradNorm dynamically adjust task weights so no single task dominates training.
  • Practical benefits: fewer models to deploy, lower infrastructure cost, more consistent predictions, and a natural regularization effect that reduces overfitting.
  • KumoRFM is inherently multi-task: its pre-trained encoder supports any prediction task on the same relational database with just a task-specific head swap.

Multi-task learning trains a single graph neural network on multiple prediction tasks simultaneously. Instead of building separate models for churn prediction, lifetime value estimation, and product recommendation, a shared GNN encoder learns representations that serve all three. Each task gets its own lightweight prediction head, but the expensive graph encoding happens once.

Architecture: shared encoder, separate heads

A multi-task GNN has two components:

  • Shared encoder: The GNN layers (GCN, GAT, or transformer) that process the graph and produce node representations. This is the expensive part, involving neighborhood aggregation across the entire graph.
  • Task-specific heads: Lightweight MLPs that map node representations to task-specific outputs. A churn head outputs a binary probability. An LTV head outputs a dollar value. A recommendation head scores candidate products.
multi_task_gnn.py
import torch
import torch.nn as nn
from torch_geometric.nn import SAGEConv

class MultiTaskGNN(nn.Module):
    def __init__(self, in_channels, hidden, num_segments):
        super().__init__()
        # Shared encoder
        self.conv1 = SAGEConv(in_channels, hidden)
        self.conv2 = SAGEConv(hidden, hidden)

        # Task-specific heads
        self.churn_head = nn.Linear(hidden, 2)        # binary classification
        self.ltv_head = nn.Linear(hidden, 1)          # regression
        self.segment_head = nn.Linear(hidden, num_segments)

    def forward(self, x, edge_index):
        h = self.conv1(x, edge_index).relu()
        h = self.conv2(h, edge_index).relu()
        return {
            'churn': self.churn_head(h),
            'ltv': self.ltv_head(h),
            'segment': self.segment_head(h),
        }

The encoder runs once. Each head is a single linear layer, adding negligible overhead.
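A minimal training-step sketch for a model like this. To keep it self-contained (no torch_geometric or graph data required), it substitutes a plain linear layer for the SAGEConv encoder; the fixed loss weights, shapes, and hyperparameters are illustrative, not recommendations:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyMultiTask(nn.Module):
    """Stand-in for MultiTaskGNN: a linear encoder in place of the GNN layers."""
    def __init__(self, in_channels, hidden, num_segments):
        super().__init__()
        self.encoder = nn.Linear(in_channels, hidden)   # shared
        self.churn_head = nn.Linear(hidden, 2)          # binary classification
        self.ltv_head = nn.Linear(hidden, 1)            # regression
        self.segment_head = nn.Linear(hidden, num_segments)

    def forward(self, x):
        h = self.encoder(x).relu()                      # encode once
        return {
            'churn': self.churn_head(h),
            'ltv': self.ltv_head(h),
            'segment': self.segment_head(h),
        }

model = TinyMultiTask(in_channels=8, hidden=16, num_segments=4)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

# Synthetic batch: 32 nodes with labels for all three tasks
x = torch.randn(32, 8)
churn_y = torch.randint(0, 2, (32,))
ltv_y = torch.randn(32, 1)
seg_y = torch.randint(0, 4, (32,))

out = model(x)
# Fixed weighting: one combined loss, one backward pass through the shared encoder
loss = (1.0 * F.cross_entropy(out['churn'], churn_y)
        + 0.1 * F.mse_loss(out['ltv'], ltv_y)
        + 1.0 * F.cross_entropy(out['segment'], seg_y))
opt.zero_grad()
loss.backward()
opt.step()
```

Note that all three task losses are summed before the single `backward()` call, so gradients from every head flow into the shared encoder in one pass.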

Why multi-task learning improves GNN performance

The key benefit is positive transfer: patterns learned for one task help others.

  • Shared relational patterns: Customer spending velocity predicts both churn and LTV. Transaction graph density predicts both fraud and credit risk. The shared encoder captures these patterns once.
  • Regularization: Multiple tasks act as implicit regularizers. The encoder cannot overfit to the idiosyncrasies of any single task because it must produce representations that work for all tasks.
  • Data efficiency: Some tasks have sparse labels (fraud is rare). By sharing an encoder with a label-rich task (transaction classification), the fraud task benefits from representations trained on much more data.

Balancing task losses

The most common pitfall: one task dominates the gradient, starving others. If churn loss gradients are 100x larger than LTV gradients, the encoder optimizes almost exclusively for churn.

Three solutions:

  1. Fixed weighting: Manually set loss weights (e.g., 1.0 for churn, 0.01 for LTV). Simple but requires extensive tuning and breaks when loss scales shift during training.
  2. Uncertainty weighting: Learn a weight for each task based on the homoscedastic uncertainty. Tasks with higher noise get lower weight. This is the most widely used automatic method.
  3. GradNorm: Dynamically adjust weights to normalize gradient magnitudes across tasks. Ensures all tasks train at similar rates regardless of loss scale differences.
uncertainty_weighting.py
class UncertaintyWeightedLoss(nn.Module):
    def __init__(self, num_tasks):
        super().__init__()
        # log(sigma^2) for each task, learned during training
        self.log_vars = nn.Parameter(torch.zeros(num_tasks))

    def forward(self, losses):
        total = 0
        for i, loss in enumerate(losses):
            precision = torch.exp(-self.log_vars[i])
            total += precision * loss + self.log_vars[i]
        return total

Uncertainty weighting learns to down-weight noisy tasks automatically, and it adds only a few lines to any multi-task training loop.
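A sketch of how this module plugs into a training step; the single-head model, shapes, and learning rate are illustrative. The one detail that is easy to miss: `log_vars` are learnable, so the weighter's parameters must be registered with the optimizer alongside the model's:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class UncertaintyWeightedLoss(nn.Module):
    def __init__(self, num_tasks):
        super().__init__()
        # log(sigma^2) per task, learned during training
        self.log_vars = nn.Parameter(torch.zeros(num_tasks))

    def forward(self, losses):
        total = 0
        for i, loss in enumerate(losses):
            precision = torch.exp(-self.log_vars[i])
            total += precision * loss + self.log_vars[i]
        return total

head = nn.Linear(4, 1)
weighter = UncertaintyWeightedLoss(num_tasks=2)
# Both the model AND the weighter parameters go to the optimizer
opt = torch.optim.Adam(list(head.parameters()) + list(weighter.parameters()), lr=1e-2)

x, y = torch.randn(16, 4), torch.randn(16, 1)
pred = head(x)
task_losses = [F.mse_loss(pred, y), F.l1_loss(pred, y)]  # two toy objectives
total = weighter(task_losses)
opt.zero_grad()
total.backward()
opt.step()
```

If `log_vars` are left out of the optimizer, the weights stay frozen at their initial values and the scheme silently degrades to equal fixed weighting.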

Multi-task learning in production

Beyond accuracy, multi-task learning simplifies operations:

  • One model, one graph pass: A single inference call produces predictions for all tasks. No need to maintain separate feature pipelines, training jobs, and model registries.
  • Consistent representations: All tasks share the same view of each entity. There is no risk of one model treating a customer as high-risk while another treats them as high-value due to different encodings.
  • Incremental task addition: Adding a new task requires only training a new head. The shared encoder, which represents 95% of compute, is already trained.
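The head-swap pattern in the last bullet can be sketched as follows, using a linear stand-in for the already-trained encoder (names and shapes are hypothetical): freeze the encoder, train only the new head, and optionally cache the embeddings.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Stand-in for the trained GNN encoder
encoder = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
for p in encoder.parameters():
    p.requires_grad = False          # freeze: only the new head trains

new_head = nn.Linear(16, 1)          # e.g. a newly added propensity task
opt = torch.optim.Adam(new_head.parameters(), lr=1e-3)

x, y = torch.randn(32, 8), torch.randn(32, 1)
with torch.no_grad():
    h = encoder(x)                   # embeddings can be precomputed and cached
loss = F.mse_loss(new_head(h), y)
opt.zero_grad()
loss.backward()
opt.step()
```

Because the encoder is frozen and run under `no_grad`, the backward pass touches only the new head, which is why adding a task is cheap relative to retraining the full model.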

Frequently asked questions

What is multi-task learning in graph neural networks?

Multi-task learning trains a single GNN with a shared encoder on multiple prediction tasks simultaneously. For example, one model predicts both customer churn probability and expected lifetime value. The shared encoder learns richer representations because patterns useful for one task (spending behavior) also help the other.

When does multi-task learning help GNN performance?

Multi-task learning helps most when tasks share underlying relational patterns. Churn prediction and LTV estimation both depend on customer engagement patterns. Fraud detection and credit scoring both rely on transaction graph structure. When tasks are unrelated, multi-task learning can hurt performance through negative transfer.

How do you balance losses across multiple GNN tasks?

Common approaches include fixed weighting (manually set loss weights), uncertainty weighting (learn weights based on task uncertainty), and gradient normalization (GradNorm, which dynamically adjusts weights so all tasks train at similar rates). Uncertainty weighting is the most popular in practice.

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.