
Graph Distillation: Compressing Large GNN Models for Efficient Inference

Production systems need fast inference. Graph distillation trains a compact student GNN to mimic a powerful teacher model, delivering 95-99% of the accuracy at a fraction of the latency and cost.

PyTorch Geometric

TL;DR

  • Graph distillation trains a small student GNN to replicate the predictions of a large teacher GNN. The student learns from soft probability distributions, not hard labels, capturing richer relational patterns.
  • Three approaches: response-based distillation (match teacher outputs), feature-based (match intermediate representations), and graph-structure-aware (distill neighborhood aggregation patterns).
  • Typical results: 5-50x model compression, 3-20x inference speedup, retaining 95-99% of teacher accuracy. The gap narrows further with graph-aware distillation techniques.
  • Graph-specific challenges include preserving structural information across hops, handling heterogeneous node types, and maintaining temporal ordering during distillation.
  • In production, distillation is often the final step: train a large model for maximum accuracy, then distill it for deployment under latency and cost constraints.

Graph distillation compresses large graph neural networks into smaller, faster models while preserving prediction quality. A large “teacher” GNN is trained for maximum accuracy, then a compact “student” model learns to replicate the teacher's behavior. The student trains on the teacher's soft predictions rather than hard labels, inheriting the teacher's knowledge of relational patterns without needing the same model capacity.

Why distill instead of training small?

The teacher's soft probability outputs encode far more information than binary labels. Consider fraud detection: a hard label says “fraud” or “not fraud.” But the teacher's soft output might say “0.85 fraud, 0.10 suspicious, 0.05 legitimate.” That distribution tells the student about decision boundaries, edge cases, and class relationships.

Empirically, a 2-layer GCN student distilled from a 6-layer GAT teacher consistently outperforms the same 2-layer GCN trained from scratch, often by 2-5% in AUROC on node classification tasks.

Three distillation approaches

Response-based distillation

The simplest approach: the student minimizes KL divergence between its output distribution and the teacher's. A temperature parameter softens the distributions, exposing more information about class similarities.

response_distillation.py
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels,
                      temperature=4.0, alpha=0.5):
    """Combine soft distillation loss with hard label loss."""
    # Soft targets from teacher
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction='batchmean'
    ) * (temperature ** 2)

    # Hard targets for grounding
    hard_loss = F.cross_entropy(student_logits, labels)

    return alpha * soft_loss + (1 - alpha) * hard_loss

Temperature controls softness. Higher temperature exposes more of the teacher's uncertainty to the student.
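The effect is easy to see directly. The sketch below runs one set of hypothetical teacher logits (the three logit values are illustrative, not from any real model) through a temperature-scaled softmax at several temperatures:

```python
import torch
import torch.nn.functional as F

# Hypothetical teacher logits for one node over three classes
logits = torch.tensor([4.0, 2.0, 0.5])

for T in (1.0, 4.0, 10.0):
    probs = F.softmax(logits / T, dim=-1)
    print(f"T={T:>5}: {[round(p, 3) for p in probs.tolist()]}")
```

At T=1 nearly all mass sits on the top class; at higher temperatures the distribution flattens, so the relative rankings of the non-top classes become visible to the student.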

Feature-based distillation

Rather than matching only final outputs, the student also learns to match the teacher's intermediate node representations. This is particularly effective for GNNs because intermediate layers capture different neighborhood scales: layer 1 captures 1-hop patterns, layer 2 captures 2-hop patterns.
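A minimal sketch of the idea, assuming a student with 64-dimensional and a teacher with 256-dimensional node embeddings (both widths are illustrative): a learned linear projection bridges the width mismatch, and an MSE loss pulls the projected student embeddings toward the teacher's.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FeatureDistiller(nn.Module):
    """Match a student layer's node embeddings to a teacher layer's."""

    def __init__(self, student_dim=64, teacher_dim=256):
        super().__init__()
        # Linear projection bridges the dimension mismatch
        self.proj = nn.Linear(student_dim, teacher_dim)

    def forward(self, student_h, teacher_h):
        # Project student embeddings into the teacher's space and
        # penalize the per-node distance; the teacher is frozen.
        return F.mse_loss(self.proj(student_h), teacher_h.detach())

# Stand-in embeddings for 100 nodes
distiller = FeatureDistiller()
loss = distiller(torch.randn(100, 64), torch.randn(100, 256))
```

This loss is added to the response-based loss, typically one term per aligned layer pair.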

Graph-structure-aware distillation

The most sophisticated approach: the student learns to preserve the teacher's neighborhood aggregation patterns. This means the student not only matches the teacher's predictions but also learns which neighbors the teacher found most important (via attention weights or gradient-based importance scores).
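One simple way to realize this, sketched below under PyG's `edge_index` convention of shape `[2, num_edges]`: for each edge, compare how similar the teacher considers the two endpoint embeddings with how similar the student does, so the student inherits which neighborhoods the teacher treats as coherent. (Attention-weight matching is an alternative when the teacher exposes attention scores.)

```python
import torch
import torch.nn.functional as F

def structure_distillation_loss(student_h, teacher_h, edge_index):
    """Match the teacher's per-edge similarity pattern (a sketch)."""
    src, dst = edge_index
    # How similar does each model consider the endpoints of every edge?
    t_sim = F.cosine_similarity(teacher_h[src], teacher_h[dst], dim=-1)
    s_sim = F.cosine_similarity(student_h[src], student_h[dst], dim=-1)
    # Penalize disagreement; the teacher side is frozen.
    return F.mse_loss(s_sim, t_sim.detach())

# Toy graph: 4 nodes, 3 edges, matching embedding widths for simplicity
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
loss = structure_distillation_loss(torch.randn(4, 16),
                                   torch.randn(4, 16), edge_index)
```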

Practical considerations

  • Architecture mismatch: The student does not need the same architecture as the teacher. A GCN student can learn from a GAT teacher. In heterogeneous graphs, you can even distill a heterogeneous teacher into a simpler homogeneous student if the task allows it.
  • Layer mismatch: Teachers often have more layers than students. Feature-based distillation requires a mapping between teacher and student layers, typically aligning every N-th teacher layer with a student layer.
  • Temporal graphs: When distilling temporal GNNs, the student must preserve temporal ordering. Shuffling temporal edges during distillation introduces subtle leakage.
  • Heterogeneous graphs: Distilling models on heterogeneous graphs (multiple node and edge types) requires type-specific distillation losses to avoid collapsing type information.
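The layer-mismatch point above can be made concrete. One common mapping, sketched here as a uniform-stride alignment (the specific scheme is a choice, not the only option), assigns each student layer the teacher layer at the end of its stride:

```python
def align_layers(num_teacher_layers, num_student_layers):
    """Uniform-stride layer alignment (a sketch).

    Returns a list where entry k is the teacher layer whose features
    are distilled into student layer k. For a 6-layer teacher and a
    2-layer student this yields [2, 5]: student layer 0 matches
    teacher layer 2, student layer 1 matches teacher layer 5.
    """
    stride = num_teacher_layers / num_student_layers
    return [int((k + 1) * stride) - 1 for k in range(num_student_layers)]
```

This keeps the student's final layer aligned with the teacher's final layer, which matters because both represent the full receptive field each model can reach.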

When to use graph distillation

Distillation is most valuable when:

  • Production latency requirements are strict (sub-10ms per prediction)
  • The teacher model is too large for the deployment target (edge devices, serverless)
  • You need to serve predictions at high throughput (millions of inferences per second)
  • Cost optimization: a 10x smaller model running on CPUs can replace an expensive GPU deployment

Distillation is less useful when accuracy on rare events is paramount (distillation slightly degrades tail-class performance) or when the teacher model is already small.

Production pipeline

A typical production workflow:

  1. Train the largest feasible teacher model, optimizing purely for accuracy
  2. Evaluate teacher on a held-out set to establish the accuracy ceiling
  3. Distill into candidate student architectures of varying sizes
  4. Select the student that meets latency requirements with minimal accuracy loss
  5. Deploy the student model, retrain the teacher periodically, and re-distill
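Step 3 of the workflow above can be sketched as a compact full-batch training loop. The `teacher`/`student` models and the PyG-style `data` object (with `x`, `edge_index`, `y`, `train_mask`) are assumed interfaces, not a fixed API:

```python
import torch
import torch.nn.functional as F

def distill(teacher, student, data, optimizer, epochs=100,
            temperature=4.0, alpha=0.5):
    """Node-classification distillation loop (a sketch).

    `teacher` and `student` map (x, edge_index) to logits; `data` is a
    PyG-style object with x, edge_index, y, and train_mask.
    """
    # Teacher predictions are fixed; compute them once up front.
    teacher.eval()
    with torch.no_grad():
        teacher_logits = teacher(data.x, data.edge_index)

    for _ in range(epochs):
        student.train()
        optimizer.zero_grad()
        logits = student(data.x, data.edge_index)
        # Soft targets: temperature-scaled KL against the teacher
        soft = F.kl_div(
            F.log_softmax(logits / temperature, dim=-1),
            F.softmax(teacher_logits / temperature, dim=-1),
            reduction='batchmean',
        ) * temperature ** 2
        # Hard targets: ground-truth labels on the training mask
        hard = F.cross_entropy(logits[data.train_mask],
                               data.y[data.train_mask])
        loss = alpha * soft + (1 - alpha) * hard
        loss.backward()
        optimizer.step()
    return student
```

In practice each candidate student architecture runs through this loop, and step 4 then compares the resulting accuracy/latency trade-offs on the held-out set from step 2.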

Frequently asked questions

What is graph distillation?

Graph distillation is the process of training a smaller, faster student GNN to mimic the predictions of a larger, more accurate teacher GNN. The student learns from the teacher's soft probability distributions rather than hard labels, capturing richer information about class relationships and uncertainty.

Why not just train a smaller model from scratch?

A student trained with distillation consistently outperforms the same architecture trained from scratch on hard labels. The teacher's soft predictions contain information about class similarities and decision boundaries that hard labels discard. A teacher that assigns 0.7 fraud / 0.3 legitimate tells the student more than a hard 'fraud' label.

What compression ratios are typical for graph distillation?

Typical compression ratios range from 5x to 50x in model size, with inference speedups of 3x to 20x, while retaining 95-99% of the teacher's accuracy. The exact trade-off depends on the task complexity and graph structure.

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.