Few-shot learning trains GNNs to make predictions from very few labeled examples per class. In a 5-shot setting, you have exactly 5 labeled nodes per class. That is it. The model must classify the remaining thousands of unlabeled nodes using only those 5 examples plus the graph structure. This is realistic for enterprise settings where labels are expensive: each confirmed fraud case requires an investigation, and each medical label requires expert review.
Graphs are uniquely well-suited for few-shot learning because the structure itself provides implicit supervision. Connected nodes tend to share labels. Structurally similar nodes tend to have similar roles. A few labeled nodes in a dense graph provide far more signal than a few labeled rows in a flat table.
Prototypical networks for graphs
The simplest, and often most effective, few-shot approach for graphs is the prototypical network:
```python
import torch
from torch_geometric.nn import GCNConv

class GraphProtoNet(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)

    def encode(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)

    def few_shot_classify(self, x, edge_index, support_idx, support_labels):
        """
        support_idx: indices of labeled nodes (e.g., 5 per class)
        support_labels: labels of those nodes
        """
        # Encode all nodes using graph structure
        z = self.encode(x, edge_index)

        # Compute one prototype per class (mean of its support embeddings)
        classes = support_labels.unique()
        prototypes = torch.stack([
            z[support_idx[support_labels == c]].mean(dim=0) for c in classes
        ])

        # Classify all nodes by nearest prototype, mapping the
        # argmin index back to the actual class label
        dists = torch.cdist(z, prototypes)
        return classes[dists.argmin(dim=-1)]  # predicted class for all nodes
```

5 labeled nodes per class, thousands classified via nearest prototype in GNN embedding space.
Enterprise example: new fraud pattern detection
A bank discovers a new type of fraud that does not match existing patterns. Investigators confirm 8 accounts as this new fraud type. The question: are there more?
- Use a pre-trained GNN to encode all accounts in the transaction graph
- Compute the fraud prototype from the 8 confirmed accounts
- Find the nearest 500 accounts to this prototype
- Investigators review these 500 accounts (instead of 10 million)
The graph-aware encoding ensures that accounts connected to the 8 confirmed fraudsters (through shared merchants, intermediaries, or timing patterns) rank highest. The 8 labels, amplified by graph structure, identify a fraud ring of 200 accounts that would have been invisible to rule-based systems.
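The ranking step above can be sketched in a few lines. This is a minimal illustration, not a production system: `rank_by_prototype`, its parameters, and the random embeddings standing in for GNN output are all hypothetical.

```python
import torch

def rank_by_prototype(z, confirmed_idx, k=500):
    """Rank all nodes by distance to the prototype of confirmed cases.

    z: (num_nodes, dim) node embeddings from a pre-trained GNN
    confirmed_idx: indices of confirmed fraud accounts (e.g., the 8 cases)
    k: number of candidates to surface for manual review
    """
    prototype = z[confirmed_idx].mean(dim=0)            # single fraud prototype
    dists = torch.cdist(z, prototype.unsqueeze(0)).squeeze(-1)
    dists[confirmed_idx] = float("inf")                 # skip already-confirmed accounts
    return dists.topk(k, largest=False).indices         # k nearest candidates

# Toy usage: random embeddings stand in for the pre-trained GNN's output
z = torch.randn(10_000, 64)
confirmed = torch.tensor([3, 17, 42, 99, 256, 1024, 4096, 8191])
candidates = rank_by_prototype(z, confirmed, k=500)
```

Investigators would then review only these 500 candidates, ordered from most to least similar to the confirmed cases.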
Meta-learning for graphs
Meta-learning goes further: instead of training a general encoder and applying prototypes, it trains the model specifically to be good at few-shot tasks:
- MAML: learn initial weights that adapt quickly to any few-shot task in 1-5 gradient steps
- ProtoNet: learn an embedding space where prototypical classification works optimally
- Task sampling: during meta-training, sample many random few-shot tasks from the graph. Each task has a random class subset with random support nodes.
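The task-sampling step can be sketched as below. This is a simplified sketch under stated assumptions: `sample_episode`, its parameter names, and the requirement that each class has at least `k_shot + q_query` nodes are illustrative choices, not part of a specific library's API.

```python
import torch

def sample_episode(labels, n_way=3, k_shot=5, q_query=10):
    """Sample one few-shot episode from node labels for meta-training.

    labels: (num_nodes,) class label per node
    n_way: number of classes in this episode
    k_shot: labeled support nodes per class
    q_query: query nodes per class, used to compute the episode loss

    Assumes every class has at least k_shot + q_query nodes.
    """
    classes = labels.unique()
    # Random subset of classes for this task
    episode_classes = classes[torch.randperm(len(classes))[:n_way]]
    support, query = [], []
    for c in episode_classes:
        idx = (labels == c).nonzero(as_tuple=True)[0]
        perm = idx[torch.randperm(len(idx))]
        support.append(perm[:k_shot])                   # random support nodes
        query.append(perm[k_shot:k_shot + q_query])     # held-out query nodes
    return torch.cat(support), torch.cat(query), episode_classes

# Toy usage: 1,000 nodes, exactly 100 per class across 10 classes
labels = torch.arange(1000) % 10
support_idx, query_idx, ep_classes = sample_episode(labels)
```

During meta-training, each episode's support nodes build the prototypes and the query nodes supply the loss, so the encoder is optimized directly for the few-shot setting it will face at test time.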