Contrastive learning trains GNNs by comparing similar and dissimilar graph views. The idea: create two different augmented versions of the same graph (positive pair) and encode both with the GNN. Train the model to produce similar embeddings for the positive pair and dissimilar embeddings for views from different graphs (negative pairs). The resulting representations capture structural properties that are invariant to random noise, which tend to be the properties that matter for downstream tasks.
This is the same principle behind SimCLR (images) and CLIP (image-text). Applied to graphs, it is among the most effective self-supervised learning paradigms, often matching supervised performance with 10-100x fewer labels.
The contrastive pipeline
```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.utils import dropout_edge


class GraphCL(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim, proj_dim):
        super().__init__()
        # Encoder
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        # Projection head (discarded after pre-training)
        self.proj = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim, proj_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(proj_dim, proj_dim),
        )

    def encode(self, x, edge_index, batch):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return global_mean_pool(x, batch)  # one embedding per graph

    def forward(self, data):
        # Create two augmented views of the same batch of graphs
        ei1, _ = dropout_edge(data.edge_index, p=0.2)
        ei2, _ = dropout_edge(data.edge_index, p=0.2)
        # Encode both views through the shared encoder + projection head
        z1 = self.proj(self.encode(data.x, ei1, data.batch))
        z2 = self.proj(self.encode(data.x, ei2, data.batch))
        return info_nce_loss(z1, z2, temperature=0.1)


def info_nce_loss(z1, z2, temperature=0.1):
    # Cosine similarity between every pair across the two views
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)
    sim = torch.mm(z1, z2.t()) / temperature
    # The i-th graph in view 1 matches the i-th graph in view 2
    labels = torch.arange(z1.size(0), device=z1.device)
    # Symmetric cross-entropy: view1 -> view2 and view2 -> view1
    return (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
```

GraphCL: augment twice, encode twice, contrast. The projection head is discarded after pre-training; only the encoder transfers.
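A quick sanity check of the InfoNCE objective, using a standalone copy of the loss in plain PyTorch (no PyG needed): perfectly aligned views should score a much lower loss than unrelated ones.

```python
import torch
import torch.nn.functional as F

def info_nce_loss(z1, z2, temperature=0.1):
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)
    sim = torch.mm(z1, z2.t()) / temperature
    labels = torch.arange(z1.size(0), device=z1.device)
    return (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2

torch.manual_seed(0)
z = torch.randn(8, 16)                                 # 8 graphs, 16-dim projections
loss_aligned = info_nce_loss(z, z)                     # perfect positive pairs
loss_shuffled = info_nce_loss(z, torch.randn(8, 16))   # unrelated "views"
```

With aligned views, every positive sits on the diagonal of the similarity matrix and the loss approaches zero; with random views it hovers near ln(batch_size).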
Augmentation strategies
The choice of augmentation defines what the model learns:
- Edge dropping (p=0.1-0.3): model learns representations invariant to specific edges. Good for noisy graphs.
- Feature masking (p=0.1-0.3): model learns to infer features from context. Good when features are redundant.
- Node dropping (p=0.05-0.2): model learns representations not dependent on any single node. Good for robustness.
- Subgraph sampling: model learns from local context. Good for large graphs where global views are expensive.
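The first three augmentations are only a few lines each in plain PyTorch. A sketch (the function names are mine, not a library API), assuming `x` is `[num_nodes, num_features]` and `edge_index` is `[2, num_edges]`:

```python
import torch

def drop_edges(edge_index, p=0.2):
    # Keep each edge independently with probability 1 - p
    keep = torch.rand(edge_index.size(1)) >= p
    return edge_index[:, keep]

def mask_features(x, p=0.2):
    # Zero out whole feature columns, shared across all nodes
    mask = (torch.rand(x.size(1)) >= p).float()
    return x * mask

def drop_nodes(x, edge_index, p=0.1):
    # Remove nodes, their features, and all incident edges; relabel the rest
    keep = torch.rand(x.size(0)) >= p
    remap = torch.full((x.size(0),), -1, dtype=torch.long)
    remap[keep] = torch.arange(int(keep.sum()))
    edge_keep = keep[edge_index[0]] & keep[edge_index[1]]
    return x[keep], remap[edge_index[:, edge_keep]]

torch.manual_seed(0)
x = torch.randn(6, 4)
edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],
                           [1, 2, 3, 4, 5, 0]])  # a 6-cycle
ei_aug = drop_edges(edge_index)
x_masked = mask_features(x)
x_aug, ei_remapped = drop_nodes(x, edge_index, p=0.3)
```

PyG ships utilities for some of these (e.g. `dropout_edge`, used in the GraphCL code above), but the logic is simple enough to write directly when you need custom behavior.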
GCA (Graph Contrastive Learning with Adaptive Augmentation) learns which edges and features to drop instead of dropping uniformly. Important edges are kept; redundant ones are dropped. This produces better augmentations and better representations.
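A simplified sketch of the adaptive idea, condensed from GCA's centrality-based scheme (the paper scores edges with degree, eigenvector, or PageRank centrality; the normalization constants here are illustrative): edges whose endpoints are more central are treated as more important and dropped less often.

```python
import torch

def adaptive_drop_edges(edge_index, num_nodes, p=0.3, p_max=0.7):
    # Edge importance = mean log-degree centrality of its two endpoints
    deg = torch.zeros(num_nodes)
    ones = torch.ones(edge_index.size(1))
    deg.scatter_add_(0, edge_index[0], ones)
    deg.scatter_add_(0, edge_index[1], ones)
    s = torch.log1p(deg)
    score = (s[edge_index[0]] + s[edge_index[1]]) / 2
    # Low-importance edges get drop probability near p (clipped at p_max),
    # the most important edges get probability near 0
    drop = (score.max() - score) / (score.max() - score.mean() + 1e-9) * p
    drop = drop.clamp(max=p_max)
    keep = torch.rand(edge_index.size(1)) >= drop
    return edge_index[:, keep]

torch.manual_seed(0)
# Hub node 0 plus a peripheral edge 3-4
edge_index = torch.tensor([[0, 0, 0, 3],
                           [1, 2, 3, 4]])
ei_aug = adaptive_drop_edges(edge_index, num_nodes=5)
```

The peripheral edge (3, 4) has the lowest endpoint degrees, so it is the most likely to be dropped; the hub edges survive more often.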
Enterprise example: product embedding pre-training
An e-commerce platform wants to learn product embeddings for recommendation, search, and category prediction. They have 5 million products with 50 million co-purchase edges but only 10,000 category labels.
- Pre-train with contrastive learning on the full co-purchase graph (no labels)
- The model learns: products frequently co-purchased are similar, products in the same community cluster together
- Fine-tune for category prediction with 10,000 labels
- The fine-tuned model achieves 92% accuracy vs 78% for training from scratch
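The fine-tuning step is typically a lightweight head on top of the pre-trained encoder; with the encoder frozen, it reduces to a linear probe on precomputed embeddings. A minimal sketch (all shapes and sizes here are illustrative, not the platform's actual numbers):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
emb = torch.randn(10_000, 128)             # stand-in for frozen pre-trained embeddings
labels = torch.randint(0, 50, (10_000,))   # 50 hypothetical category labels

head = torch.nn.Linear(128, 50)            # linear probe
opt = torch.optim.Adam(head.parameters(), lr=1e-3)
losses = []
for _ in range(5):                         # a few full-batch steps for illustration
    opt.zero_grad()
    loss = F.cross_entropy(head(emb), labels)
    loss.backward()
    opt.step()
    losses.append(loss.item())
```

Because only the small head is trained, the 10,000 labels go much further than they would training the full GNN from scratch.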
The same pre-trained embeddings also power recommendation (nearest neighbors in embedding space) and search (query-product similarity) without additional training.
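Recommendation then reduces to cosine nearest neighbors in the embedding space; a sketch with toy sizes and a hypothetical helper name:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Stand-in for pre-trained product embeddings, unit-normalized
emb = F.normalize(torch.randn(1000, 128), dim=1)

def recommend(product_id, emb, k=5):
    sims = emb @ emb[product_id]        # cosine similarity (rows are unit-norm)
    sims[product_id] = float('-inf')    # exclude the query product itself
    return sims.topk(k).indices

recs = recommend(42, emb)
```

At 5 million products, an exact similarity scan is replaced by an approximate nearest-neighbor index in practice, but the embedding-space logic is the same.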