Graph distillation compresses large graph neural networks into smaller, faster models while preserving prediction quality. A large “teacher” GNN is trained for maximum accuracy, then a compact “student” model learns to replicate the teacher's behavior. The student trains on the teacher's soft predictions rather than hard labels, inheriting the teacher's knowledge of relational patterns without needing the same model capacity.
Why distill instead of training small?
The teacher's soft probability outputs encode far more information than binary labels. Consider fraud detection: a hard label says “fraud” or “not fraud.” But the teacher's soft output might say “0.85 fraud, 0.10 suspicious, 0.05 legitimate.” That distribution tells the student about decision boundaries, edge cases, and class relationships.
Empirically, a 2-layer GCN student distilled from a 6-layer GAT teacher consistently outperforms the same 2-layer GCN trained from scratch, often by 2-5 points of AUROC on node classification tasks.
Three distillation approaches
Response-based distillation
The simplest approach: the student minimizes KL divergence between its output distribution and the teacher's. A temperature parameter softens the distributions, exposing more information about class similarities.
```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels,
                      temperature=4.0, alpha=0.5):
    """Combine soft distillation loss with hard label loss."""
    # Soft targets from teacher
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction='batchmean'
    ) * (temperature ** 2)
    # Hard targets for grounding
    hard_loss = F.cross_entropy(student_logits, labels)
    return alpha * soft_loss + (1 - alpha) * hard_loss
```
Temperature controls softness. Higher temperature exposes more of the teacher's uncertainty to the student.
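To see the temperature effect concretely, here is a small sketch with made-up 3-class logits showing how dividing by a higher temperature redistributes probability mass toward the non-top classes:

```python
import torch
import torch.nn.functional as F

# Hypothetical teacher logits for a 3-class problem.
logits = torch.tensor([4.0, 2.0, 0.5])

# Temperature 1 is the standard softmax; temperature 4 softens it.
sharp = F.softmax(logits / 1.0, dim=-1)
soft = F.softmax(logits / 4.0, dim=-1)

print(sharp)  # mass concentrates heavily on the top class
print(soft)   # secondary classes receive visibly more mass
```

The softened distribution is what carries the "dark knowledge" about class similarities to the student.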
Feature-based distillation
Rather than matching only final outputs, the student also learns to match the teacher's intermediate node representations. This is particularly effective for GNNs because intermediate layers capture different neighborhood scales: layer 1 captures 1-hop patterns, layer 2 captures 2-hop patterns.
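A minimal sketch of a feature-matching term, assuming hypothetical hidden sizes (a 64-d student and a 256-d teacher) and a learned linear projector to bridge the dimension gap:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def feature_distillation_loss(student_feats, teacher_feats, projector):
    """MSE between projected student features and teacher features.

    student_feats: [num_nodes, d_s], teacher_feats: [num_nodes, d_t].
    The projector maps the student dimension up to the teacher's.
    """
    return F.mse_loss(projector(student_feats), teacher_feats)

# Hypothetical dimensions: 64-d student layer vs 256-d teacher layer.
projector = nn.Linear(64, 256)
s = torch.randn(100, 64)   # student hidden states for 100 nodes
t = torch.randn(100, 256)  # teacher hidden states for the same nodes
loss = feature_distillation_loss(s, t, projector)
```

In practice this term is added to the response-based loss with its own weight, and the projector is trained jointly with the student.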
Graph-structure-aware distillation
The most sophisticated approach: the student learns to preserve the teacher's neighborhood aggregation patterns. This means the student not only matches the teacher's predictions but also learns which neighbors the teacher found most important (via attention weights or gradient-based importance scores).
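One way to sketch this is a KL term between per-node neighbor-attention distributions, assuming both models expose attention weights in a hypothetical dense [num_nodes, max_neighbors] layout where each row sums to 1:

```python
import torch
import torch.nn.functional as F

def attention_distillation_loss(student_attn, teacher_attn, eps=1e-8):
    """KL divergence between per-node neighbor-attention distributions.

    Both tensors: [num_nodes, max_neighbors], rows summing to 1.
    eps guards the log against zero attention weights.
    """
    return F.kl_div(
        (student_attn + eps).log(),  # input must be log-probabilities
        teacher_attn,
        reduction='batchmean'
    )
```

Gradient-based importance scores can stand in for attention weights when the teacher has no attention mechanism; the loss shape stays the same.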
Practical considerations
- Architecture mismatch: The student does not need the same architecture as the teacher. A GCN student can learn from a GAT teacher. In heterogeneous graphs, you can even distill a heterogeneous teacher into a simpler homogeneous student if the task allows it.
- Layer mismatch: Teachers often have more layers than students. Feature-based distillation requires a mapping between teacher and student layers, typically aligning every N-th teacher layer with a student layer.
- Temporal graphs: When distilling temporal GNNs, the student must preserve temporal ordering. Shuffling temporal edges during distillation introduces subtle leakage.
- Heterogeneous graphs: Distilling models on heterogeneous graphs (multiple node and edge types) requires type-specific distillation losses to avoid collapsing type information.
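The layer-mismatch point above can be sketched as an even-spacing alignment; the helper name and indexing convention here are illustrative, not a standard API:

```python
def align_layers(num_teacher_layers, num_student_layers):
    """Map each student layer to a teacher layer by even spacing.

    E.g. a 6-layer teacher and a 2-layer student: student layer 0
    matches teacher layer 2, student layer 1 matches teacher layer 5,
    so the student's final layer aligns with the teacher's final layer.
    """
    stride = num_teacher_layers // num_student_layers
    return [stride * (i + 1) - 1 for i in range(num_student_layers)]
```

Feature-based losses are then applied only at these aligned pairs.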
When to use graph distillation
Distillation is most valuable when:
- Production latency requirements are strict (sub-10ms per prediction)
- The teacher model is too large for the deployment target (edge devices, serverless)
- You need to serve predictions at high throughput (millions of inferences per second)
- Cost optimization: a 10x smaller model running on CPUs can replace an expensive GPU deployment
Distillation is less useful when accuracy on rare events is paramount (distillation slightly degrades tail-class performance) or when the teacher model is already small.
Production pipeline
A typical production workflow:
- Train the largest feasible teacher model, optimizing purely for accuracy
- Evaluate teacher on a held-out set to establish the accuracy ceiling
- Distill into candidate student architectures of varying sizes
- Select the student that meets latency requirements with minimal accuracy loss
- Deploy the student model, retrain the teacher periodically, and re-distill
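The selection step in this workflow can be sketched as picking the most accurate candidate under a latency budget; the candidate tuples and the 10 ms budget below are hypothetical:

```python
def select_student(candidates, latency_budget_ms):
    """Pick the most accurate candidate that meets the latency budget.

    candidates: list of (name, latency_ms, accuracy) tuples, assumed
    to be measured after distillation on a held-out set.
    """
    feasible = [c for c in candidates if c[1] <= latency_budget_ms]
    if not feasible:
        raise ValueError("no student meets the latency budget")
    return max(feasible, key=lambda c: c[2])

candidates = [
    ("gcn-2x64", 3.1, 0.91),
    ("gcn-2x128", 5.8, 0.93),
    ("gcn-4x128", 12.4, 0.95),
]
best = select_student(candidates, latency_budget_ms=10.0)
# best -> ("gcn-2x128", 5.8, 0.93)
```

Re-running this selection after each periodic re-distillation keeps the deployed student current with the retrained teacher.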