
Few-Shot Learning: Training GNNs from Very Few Labels

Few-shot learning makes accurate predictions from just 1-10 labeled examples per class. Graph structure provides implicit supervision that compensates for label scarcity.

PyTorch Geometric

TL;DR

  • Few-shot learning classifies nodes using just 1-10 labeled examples per class. Graph structure provides implicit supervision: connected nodes share labels, structural roles predict functional roles.
  • Prototypical networks: compute class prototypes (average embeddings of labeled nodes), classify new nodes by nearest prototype. Simple, effective, and works with any GNN encoder.
  • Meta-learning (MAML, ProtoNet): train on many few-shot tasks during pre-training. The model learns how to learn from few examples, not just how to classify.
  • Pre-training + few-shot is the strongest combination. SSL pre-training provides general representations; few-shot fine-tuning adapts with minimal labels.
  • Enterprise impact: confirming fraud costs $50-100 per investigation. Few-shot learning maximizes the value of each label by propagating through graph structure.

Few-shot learning trains GNNs to make predictions from very few labeled examples per class. In a 5-shot setting, you have exactly 5 labeled nodes per class. That is it. The model must classify the remaining thousands of unlabeled nodes using only those 5 examples plus the graph structure. This is realistic for enterprise settings where labels are expensive: each confirmed fraud case requires investigation, each medical label requires expert review.

Graphs are uniquely well-suited for few-shot learning because the structure itself provides implicit supervision. Connected nodes tend to share labels. Structurally similar nodes tend to have similar roles. A few labeled nodes in a dense graph provide far more signal than a few labeled rows in a flat table.

Prototypical networks for graphs

The simplest and often most effective few-shot approach for graphs is the prototypical network:

graph_prototypical_network.py
import torch
from torch_geometric.nn import GCNConv

class GraphProtoNet(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)

    def encode(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)

    def few_shot_classify(self, x, edge_index, support_idx, support_labels):
        """
        support_idx: indices of labeled nodes (e.g., 5 per class)
        support_labels: labels of those nodes
        """
        # Encode all nodes using graph structure
        z = self.encode(x, edge_index)

        # Compute prototype per class (mean of support embeddings)
        classes = support_labels.unique()
        prototypes = torch.stack([
            z[support_idx[support_labels == c]].mean(dim=0) for c in classes
        ])

        # Classify all nodes by nearest prototype, mapping the argmin
        # index back to the original label values
        dists = torch.cdist(z, prototypes)
        return classes[dists.argmin(dim=-1)]  # predicted label for all nodes

5 labeled nodes per class, thousands classified via nearest prototype in GNN embedding space.
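The nearest-prototype step itself needs nothing beyond plain tensors. A minimal worked example on hand-built 3-D embeddings (standing in for GNN output; the values are illustrative): nodes 0-3 are the support set, nodes 4-5 are unlabeled and get the label of the closest prototype.

```python
import torch

# Toy embedding space: 2 classes, 3-D embeddings
z = torch.tensor([
    [1.0, 0.0, 0.0],   # node 0, labeled class 0
    [0.9, 0.1, 0.0],   # node 1, labeled class 0
    [0.0, 1.0, 0.0],   # node 2, labeled class 1
    [0.1, 0.9, 0.0],   # node 3, labeled class 1
    [0.8, 0.2, 0.0],   # node 4, unlabeled -- closer to class 0
    [0.2, 0.8, 0.0],   # node 5, unlabeled -- closer to class 1
])
support_idx = torch.tensor([0, 1, 2, 3])
support_labels = torch.tensor([0, 0, 1, 1])

# Prototype per class: mean of that class's support embeddings
classes = support_labels.unique()
prototypes = torch.stack([
    z[support_idx[support_labels == c]].mean(dim=0) for c in classes
])

# Every node gets the label of its nearest prototype
pred = classes[torch.cdist(z, prototypes).argmin(dim=-1)]
print(pred.tolist())  # [0, 0, 1, 1, 0, 1]
```

With a real GNN encoder, `z` is just the output of `encode(x, edge_index)`; the classification logic is unchanged.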

Enterprise example: new fraud pattern detection

A bank discovers a new type of fraud that does not match existing patterns. Investigators confirm 8 accounts as this new fraud type. The question: are there more?

  1. Use a pre-trained GNN to encode all accounts in the transaction graph
  2. Compute the fraud prototype from the 8 confirmed accounts
  3. Find the nearest 500 accounts to this prototype
  4. Investigators review these 500 accounts (instead of 10 million)

The graph-aware encoding ensures that accounts connected to the 8 confirmed fraudsters (through shared merchants, intermediaries, or timing patterns) rank highest. The 8 labels, amplified by graph structure, identify a fraud ring of 200 accounts that would have been invisible to rule-based systems.
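The triage loop above reduces to a top-k nearest-neighbor query against a single prototype. A minimal sketch with synthetic embeddings standing in for the pre-trained GNN output; the planted 200-account ring, the account counts, and the 500-review budget are illustrative assumptions for the demo:

```python
import torch

torch.manual_seed(0)

num_accounts, dim = 10_000, 64
z = torch.randn(num_accounts, dim)   # stand-in for pre-trained GNN embeddings
ring = torch.arange(200)             # hidden fraud ring (planted for the demo)
z[ring] += 3.0                       # ring members share a structural signature
confirmed_idx = ring[:8]             # investigators confirmed only 8 of them

# One prototype for the new fraud pattern
fraud_proto = z[confirmed_idx].mean(dim=0)

# Rank every account by distance to the prototype; review the closest k
k = 500
dists = torch.cdist(z, fraud_proto.unsqueeze(0)).squeeze(-1)
review_queue = dists.argsort()[:k]

# How many ring members land in the 500-account review queue?
recovered = len(set(review_queue.tolist()) & set(ring.tolist()))
print(recovered)  # 200
```

Because the ring shares an embedding signature, all 200 members fall inside the 500-account queue: 8 labels, amplified by the encoder, surface the entire pattern.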

Meta-learning for graphs

Meta-learning goes further: instead of training a general encoder and applying prototypes, it trains the model specifically to be good at few-shot tasks:

  • MAML: learn initial weights that adapt quickly to any few-shot task in 1-5 gradient steps
  • ProtoNet: learn an embedding space where prototypical classification works optimally
  • Task sampling: during meta-training, sample many random few-shot tasks from the graph. Each task has a random class subset with random support nodes.
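Episodic task sampling is the loop shared by these methods: each meta-training step draws a random N-way K-shot task from the graph's labeled nodes. A minimal sketch of such a sampler; the `n_way`/`k_shot`/`q_query` names and episode format are illustrative conventions, not a PyG API:

```python
import torch

def sample_episode(labels, n_way=3, k_shot=5, q_query=10, generator=None):
    """Sample one N-way K-shot episode from node labels."""
    # Pick a random subset of n_way classes for this task
    classes = labels.unique()
    perm = torch.randperm(len(classes), generator=generator)
    task_classes = classes[perm[:n_way]]

    # For each task class, split random nodes into support and query sets
    support, query = [], []
    for c in task_classes:
        idx = (labels == c).nonzero(as_tuple=True)[0]
        idx = idx[torch.randperm(len(idx), generator=generator)]
        support.append(idx[:k_shot])
        query.append(idx[k_shot:k_shot + q_query])
    return torch.cat(support), torch.cat(query), task_classes

# Example: 1,000 nodes over 10 classes
g = torch.Generator().manual_seed(0)
labels = torch.randint(0, 10, (1000,), generator=g)
support_idx, query_idx, task_classes = sample_episode(labels, generator=g)
print(support_idx.shape, task_classes.shape)  # torch.Size([15]) torch.Size([3])
```

During meta-training, each episode's support set builds prototypes (or drives MAML's inner-loop adaptation) and the query set supplies the loss, so the model is optimized directly for the few-shot setting it will face at test time.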

Frequently asked questions

What is few-shot learning on graphs?

Few-shot learning trains GNNs to make predictions using very few labeled examples per class, typically 1-10. Instead of requiring thousands of labeled nodes, a few-shot GNN learns to generalize from just a handful of examples by leveraging graph structure, pre-trained representations, or meta-learning strategies.

Why is few-shot learning important for enterprise graphs?

Enterprise labels are expensive and scarce. Confirming fraud requires investigation ($50-100 per case). Labeling churn requires waiting for customers to actually leave. Medical labels require expert review. Few-shot learning maximizes the value of each label by combining it with abundant unlabeled graph structure.

What is a prototypical network for graphs?

A prototypical network computes a 'prototype' (average embedding) for each class from the few labeled examples. New nodes are classified by nearest prototype. With a good GNN encoder (especially pre-trained), even 5 examples per class produce reliable prototypes because graph-aware embeddings cluster by structural similarity.

How does meta-learning help few-shot graph learning?

Meta-learning trains the model on many few-shot tasks during pre-training, learning how to learn from few examples. At test time, the model can quickly adapt to a new few-shot task. MAML and ProtoNet are popular meta-learning frameworks adapted for graphs.

Can graph structure compensate for few labels?

Yes. Graph structure provides implicit supervision: connected nodes tend to share labels (homophily), structural roles tend to correspond to functional roles, and multi-hop neighborhoods provide context. A GNN with just 10 labels per class can propagate those labels through the graph to achieve 60-70% accuracy on citation networks.

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.