Multi-task learning trains a single graph neural network on multiple prediction tasks simultaneously. Instead of building separate models for churn prediction, lifetime value estimation, and product recommendation, a shared GNN encoder learns representations that serve all three. Each task gets its own lightweight prediction head, but the expensive graph encoding happens once.
Architecture: shared encoder, separate heads
A multi-task GNN has two components:
- Shared encoder: The GNN layers (GCN, GAT, or transformer) that process the graph and produce node representations. This is the expensive part, involving neighborhood aggregation across the entire graph.
- Task-specific heads: Lightweight MLPs that map node representations to task-specific outputs. A churn head outputs a binary probability. An LTV head outputs a dollar value. A recommendation head scores candidate products.
```python
import torch
import torch.nn as nn
from torch_geometric.nn import SAGEConv

class MultiTaskGNN(nn.Module):
    def __init__(self, in_channels, hidden, num_segments):
        super().__init__()
        # Shared encoder
        self.conv1 = SAGEConv(in_channels, hidden)
        self.conv2 = SAGEConv(hidden, hidden)
        # Task-specific heads
        self.churn_head = nn.Linear(hidden, 2)    # binary classification
        self.ltv_head = nn.Linear(hidden, 1)      # regression
        self.segment_head = nn.Linear(hidden, num_segments)

    def forward(self, x, edge_index):
        h = self.conv1(x, edge_index).relu()
        h = self.conv2(h, edge_index).relu()
        return {
            'churn': self.churn_head(h),
            'ltv': self.ltv_head(h),
            'segment': self.segment_head(h),
        }
```

The encoder runs once per forward pass. Each head is a single linear layer, adding negligible overhead.
Why multi-task learning improves GNN performance
The key benefit is positive transfer: patterns learned for one task help others.
- Shared relational patterns: Customer spending velocity predicts both churn and LTV. Transaction graph density predicts both fraud and credit risk. The shared encoder captures these patterns once.
- Regularization: Multiple tasks act as implicit regularizers. The encoder cannot overfit to the idiosyncrasies of any single task because it must produce representations that work for all tasks.
- Data efficiency: Some tasks have sparse labels (fraud is rare). By sharing an encoder with a label-rich task (transaction classification), the fraud task benefits from representations trained on much more data.
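When labels are sparse, the sparse task should contribute loss only on its labeled nodes while the encoder still trains on every task. A minimal masking sketch, with `-1` as a hypothetical "unlabeled" marker:

```python
import torch
import torch.nn.functional as F

num_nodes = 10
logits = torch.randn(num_nodes, 2)        # stand-in for a fraud head's output
labels = torch.full((num_nodes,), -1)     # -1 = no label for this task
labels[[1, 4]] = torch.tensor([1, 0])     # only two nodes are labeled

mask = labels >= 0                        # restrict the loss to labeled nodes
loss = F.cross_entropy(logits[mask], labels[mask])
```

The encoder's gradients for this task flow only through the masked nodes, while the label-rich tasks keep shaping the shared representations everywhere else.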
Balancing task losses
The most common pitfall: one task dominates the gradient, starving others. If churn loss gradients are 100x larger than LTV gradients, the encoder optimizes almost exclusively for churn.
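One quick way to detect this imbalance is to compare per-task gradient norms on the shared parameters. A diagnostic sketch, using a linear layer as a stand-in for the shared encoder and two toy losses at deliberately different scales:

```python
import torch
import torch.nn as nn

encoder = nn.Linear(8, 8)                        # stand-in for the shared GNN encoder
x = torch.randn(4, 8)

task_losses = {
    'churn': 100.0 * encoder(x).pow(2).mean(),   # deliberately large-scale loss
    'ltv': encoder(x).abs().mean(),
}

norms = {}
for name, loss in task_losses.items():
    grads = torch.autograd.grad(loss, encoder.parameters(), retain_graph=True)
    norms[name] = torch.cat([g.flatten() for g in grads]).norm().item()
```

If one task's gradient norm dwarfs the others', the encoder is effectively training on that task alone, and one of the balancing schemes below is needed.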
Three solutions:
- Fixed weighting: Manually set loss weights (e.g., 1.0 for churn, 0.01 for LTV). Simple but requires extensive tuning and breaks when loss scales shift during training.
- Uncertainty weighting: Learn a weight for each task based on the homoscedastic uncertainty. Tasks with higher noise get lower weight. This is the most widely used automatic method.
- GradNorm: Dynamically adjust weights to normalize gradient magnitudes across tasks. Ensures all tasks train at similar rates regardless of loss scale differences.
```python
class UncertaintyWeightedLoss(nn.Module):
    def __init__(self, num_tasks):
        super().__init__()
        # log(sigma^2) for each task, learned during training
        self.log_vars = nn.Parameter(torch.zeros(num_tasks))

    def forward(self, losses):
        total = 0
        for i, loss in enumerate(losses):
            precision = torch.exp(-self.log_vars[i])
            total += precision * loss + self.log_vars[i]
        return total
```

Uncertainty weighting learns to down-weight noisy tasks automatically, and it adds only two lines to any multi-task training loop.
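Those two lines are creating the weighter and passing the per-task losses through it. A self-contained sketch (repeating the class so it runs standalone, with scalar tensors standing in for real head losses):

```python
import torch
import torch.nn as nn

class UncertaintyWeightedLoss(nn.Module):
    def __init__(self, num_tasks):
        super().__init__()
        # log(sigma^2) for each task, learned during training
        self.log_vars = nn.Parameter(torch.zeros(num_tasks))

    def forward(self, losses):
        total = 0
        for i, loss in enumerate(losses):
            precision = torch.exp(-self.log_vars[i])
            total += precision * loss + self.log_vars[i]
        return total

task_weighter = UncertaintyWeightedLoss(num_tasks=3)
# Its log_vars must be optimized alongside the model, e.g.:
# optimizer = torch.optim.Adam(
#     list(model.parameters()) + list(task_weighter.parameters()))

# Stand-in per-task losses at very different scales
losses = [torch.tensor(2.0), torch.tensor(0.02), torch.tensor(0.5)]
total = task_weighter(losses)   # line 1: combine the losses
# total.backward()              # line 2: backprop through weights and model
```

Because `log_vars` are parameters, gradient descent raises a task's variance (lowering its weight) whenever that task's loss stays stubbornly high, which is exactly the down-weighting of noisy tasks described above.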
Multi-task learning in production
Beyond accuracy, multi-task learning simplifies operations:
- One model, one graph pass: A single inference call produces predictions for all tasks. No need to maintain separate feature pipelines, training jobs, and model registries.
- Consistent representations: All tasks share the same view of each entity. There is no risk of one model treating a customer as high-risk while another treats them as high-value due to different encodings.
- Incremental task addition: Adding a new task requires only training a new head. The shared encoder, which represents 95% of compute, is already trained.
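Adding a head to a trained encoder can be sketched as freezing the encoder and optimizing only the new head's parameters. A minimal illustration, with a plain linear layer standing in for the trained GNN encoder:

```python
import torch
import torch.nn as nn

encoder = nn.Sequential(nn.Linear(16, 32), nn.ReLU())  # stand-in for the trained GNN
for p in encoder.parameters():
    p.requires_grad = False            # freeze the shared encoder

new_head = nn.Linear(32, 1)            # e.g. a new propensity-score task
optimizer = torch.optim.Adam(new_head.parameters(), lr=1e-3)

h = encoder(torch.randn(4, 16))        # representations from the frozen encoder
pred = new_head(h)                     # only the head receives gradients
```

In practice you might later unfreeze the encoder for a short joint fine-tune, but the frozen-encoder step is what makes new tasks cheap to add.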