PyG / Guide · 8 min read

Contrastive Learning: Comparing Similar and Dissimilar Graph Views

Contrastive learning trains GNNs by asking: are these two views from the same graph or different graphs? This simple question, applied at scale, produces representations that rival supervised training without requiring a single label.


TL;DR

  1. Contrastive learning creates two augmented views of the same graph, encodes both with a GNN, and trains the model to recognize them as the same while distinguishing them from other graphs.
  2. The InfoNCE loss pushes positive pair embeddings together and negative pair embeddings apart. Temperature controls the sharpness of the similarity distribution.
  3. Key augmentations: edge dropping, feature masking, node dropping, subgraph sampling. The augmentation strategy defines what properties the model learns to preserve.
  4. State-of-the-art methods: GraphCL (graph-level), GRACE (node-level), GCA (adaptive augmentation). All follow the augment-encode-contrast pipeline.
  5. Best used as pre-training before fine-tuning with limited labels. Contrastive pre-training + 100 labels often matches supervised training with 10,000 labels.

Contrastive learning trains GNNs by comparing similar and dissimilar graph views. The idea: create two different augmented versions of the same graph (positive pair) and encode both with the GNN. Train the model to produce similar embeddings for the positive pair and dissimilar embeddings for views from different graphs (negative pairs). The resulting representations capture structural properties that are invariant to random noise, which tend to be the properties that matter for downstream tasks.

This is the same principle behind SimCLR (images) and CLIP (image-text). Applied to graphs, it is among the most effective self-supervised learning paradigms, often matching supervised performance with 10-100x fewer labels.

The contrastive pipeline

graph_contrastive_learning.py
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.utils import dropout_edge

class GraphCL(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim, proj_dim):
        super().__init__()
        # Encoder
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        # Projection head (discarded after pre-training)
        self.proj = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim, proj_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(proj_dim, proj_dim),
        )

    def encode(self, x, edge_index, batch):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return global_mean_pool(x, batch)

    def forward(self, data):
        # Create two augmented views
        ei1, _ = dropout_edge(data.edge_index, p=0.2)
        ei2, _ = dropout_edge(data.edge_index, p=0.2)

        # Encode both views
        z1 = self.proj(self.encode(data.x, ei1, data.batch))
        z2 = self.proj(self.encode(data.x, ei2, data.batch))

        return info_nce_loss(z1, z2, temperature=0.1)

def info_nce_loss(z1, z2, temperature=0.1):
    # Cosine similarity between every pair of views in the batch
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)
    sim = torch.mm(z1, z2.t()) / temperature
    # The i-th view in z1 matches the i-th view in z2: positives on the diagonal
    labels = torch.arange(z1.size(0), device=z1.device)
    # Symmetrized cross-entropy: pick the positive out of all candidates
    return (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2

GraphCL: augment twice, encode twice, contrast. The projection head is discarded after pre-training; only the encoder transfers.
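The loss is easy to sanity-check numerically: two near-identical views of the same batch should score a much lower loss than two unrelated batches. A small self-contained check (restating info_nce_loss so the snippet runs on its own):

```python
import torch
import torch.nn.functional as F

def info_nce_loss(z1, z2, temperature=0.1):
    # Cosine similarity between every pair of views in the batch
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)
    sim = torch.mm(z1, z2.t()) / temperature
    # Positives sit on the diagonal; everything else is a negative
    labels = torch.arange(z1.size(0), device=z1.device)
    return (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2

torch.manual_seed(0)
z = torch.randn(32, 64)
aligned = info_nce_loss(z, z + 0.01 * torch.randn(32, 64))  # near-identical views
unrelated = info_nce_loss(z, torch.randn(32, 64))           # views of different graphs
print(aligned.item() < unrelated.item())  # aligned views give a much lower loss
```

With a batch of 32, the loss for unrelated views sits near log(32) ≈ 3.47 (random guessing over the batch), while aligned views drive it toward zero.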

Augmentation strategies

The choice of augmentation defines what the model learns:

  • Edge dropping (p=0.1-0.3): model learns representations invariant to specific edges. Good for noisy graphs.
  • Feature masking (p=0.1-0.3): model learns to infer features from context. Good when features are redundant.
  • Node dropping (p=0.05-0.2): model learns representations not dependent on any single node. Good for robustness.
  • Subgraph sampling: model learns from local context. Good for large graphs where global views are expensive.
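The first three augmentations can be sketched in plain PyTorch tensor operations. PyG ships helpers such as dropout_edge for this; the functions below are simplified illustrations (for instance, drop_nodes zeroes node features and removes incident edges rather than re-indexing the graph):

```python
import torch

def drop_edges(edge_index, p=0.2):
    # Keep each edge independently with probability 1 - p
    mask = torch.rand(edge_index.size(1)) >= p
    return edge_index[:, mask]

def mask_features(x, p=0.2):
    # Zero out whole feature columns with probability p (shared across nodes)
    mask = (torch.rand(x.size(1)) >= p).float()
    return x * mask

def drop_nodes(x, edge_index, p=0.1):
    # Zero dropped nodes' features and remove all edges incident to them
    keep = torch.rand(x.size(0)) >= p
    edge_mask = keep[edge_index[0]] & keep[edge_index[1]]
    return x * keep.unsqueeze(1).float(), edge_index[:, edge_mask]
```

Each call produces a different random view, which is exactly what the positive pair in the contrastive pipeline needs.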

GCA (Graph Contrastive Learning with Adaptive Augmentation) learns which edges and features to drop instead of dropping uniformly. Important edges are kept; redundant ones are dropped. This produces better augmentations and better representations.
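A rough sketch of the adaptive idea, using node degree as a cheap stand-in for the centrality measures GCA supports (degree, eigenvector, PageRank). The weighting below is illustrative, not GCA's exact formula: edges whose endpoints have low importance scores receive higher drop probabilities.

```python
import torch

def adaptive_drop_edges(edge_index, num_nodes, p_mean=0.3, p_max=0.9):
    # Degree as a cheap centrality proxy
    deg = torch.zeros(num_nodes).scatter_add_(
        0, edge_index[0], torch.ones(edge_index.size(1)))
    # Edge importance: mean log-degree of its two endpoints
    s = (torch.log1p(deg[edge_index[0]]) + torch.log1p(deg[edge_index[1]])) / 2
    # Less important edges get higher drop probability, averaging ~p_mean
    p = (s.max() - s) / (s.max() - s.mean() + 1e-12) * p_mean
    p = p.clamp(max=p_max)
    keep = torch.rand(edge_index.size(1)) >= p
    return edge_index[:, keep]
```

On a regular graph (all degrees equal) every edge has the same importance, so nothing is dropped preferentially; on a graph with hubs, peripheral edges survive more often.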

Enterprise example: product embedding pre-training

An e-commerce platform wants to learn product embeddings for recommendation, search, and category prediction. They have 5 million products with 50 million co-purchase edges but only 10,000 category labels.

  1. Pre-train with contrastive learning on the full co-purchase graph (no labels)
  2. The model learns: products frequently co-purchased are similar, products in the same community cluster together
  3. Fine-tune for category prediction with 10,000 labels
  4. The fine-tuned model achieves 92% accuracy vs 78% for training from scratch

The same pre-trained embeddings also power recommendation (nearest neighbors in embedding space) and search (query-product similarity) without additional training.
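The recommendation use case is a nearest-neighbor lookup in embedding space. A minimal sketch with a hypothetical recommend helper, assuming a tensor of pre-trained product embeddings (one row per product):

```python
import torch
import torch.nn.functional as F

def recommend(embeddings, product_id, k=5):
    # Cosine-similarity nearest neighbors in the pre-trained embedding space
    z = F.normalize(embeddings, dim=1)
    sim = z @ z[product_id]
    sim[product_id] = float('-inf')  # exclude the query product itself
    return sim.topk(k).indices

# Toy embeddings: products 0 and 1 point the same way, product 2 is orthogonal
emb = torch.tensor([[1., 0.], [1., 0.], [0., 1.]])
print(recommend(emb, 0, k=1))  # -> tensor([1])
```

The same function, pointed at query embeddings instead of a product row, covers the search use case.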

Frequently asked questions

What is contrastive learning on graphs?

Contrastive learning trains a GNN to produce similar embeddings for different augmented views of the same graph (positive pairs) and dissimilar embeddings for views of different graphs or nodes (negative pairs). The model learns representations that capture essential structural properties invariant to augmentation noise.

What are positive and negative pairs in graph contrastive learning?

Positive pairs: two augmented views of the same node or graph (e.g., the same node after different random edge drops). Negative pairs: views from different nodes or graphs. The contrastive loss pushes positive pairs together and negative pairs apart in embedding space.

What augmentations are used for graph contrastive learning?

Common augmentations: edge dropping (remove random edges), feature masking (zero random features), node dropping (remove random nodes), subgraph sampling (extract random subgraphs), and edge addition (add random edges). The choice of augmentation matters: it defines what the model learns to be invariant to.

What is InfoNCE loss?

InfoNCE (Information Noise-Contrastive Estimation) is the standard loss function for contrastive learning. For a positive pair (z_i, z_j), it computes -log(exp(sim(z_i, z_j)/tau) / sum_k exp(sim(z_i, z_k)/tau)), where the sum over k runs over the positive and all negatives. It maximizes the similarity of the positive pair relative to all negatives. The temperature tau controls the sharpness of the resulting distribution.
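The effect of tau is easy to see directly: dividing the similarities by a small temperature sharpens the softmax toward the best match, while a large temperature flattens it toward uniform.

```python
import torch
import torch.nn.functional as F

sims = torch.tensor([0.9, 0.5, 0.1])  # one anchor's similarity to three candidates
sharp = F.softmax(sims / 0.1, dim=0)  # low tau: mass concentrates on the best match
soft = F.softmax(sims / 1.0, dim=0)   # high tau: closer to uniform
print(sharp[0] > soft[0])  # True: lower tau puts more weight on the top candidate
```

In practice, tau around 0.1-0.5 is a common starting range; too low makes training unstable, too high makes positives and negatives indistinguishable.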

How does contrastive learning compare to supervised GNN training?

Contrastive learning requires no labels and uses the full graph structure. Supervised training requires labels but directly optimizes for the target task. In practice: contrastive pre-training + supervised fine-tuning outperforms either alone, especially when labels are scarce (fewer than 100 per class).
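The pre-train-then-fine-tune recipe can be sketched as a linear probe: freeze the pre-trained encoder and train only a small task head on the scarce labels. The Linear encoder below is a stand-in for the pre-trained GNN, and the data is synthetic:

```python
import torch

torch.manual_seed(0)
# Stand-in for the pre-trained GNN encoder; frozen during fine-tuning
encoder = torch.nn.Linear(16, 32)
for p in encoder.parameters():
    p.requires_grad = False
frozen_weight = encoder.weight.clone()

head = torch.nn.Linear(32, 4)  # task head trained on the labelled examples
opt = torch.optim.Adam(head.parameters(), lr=1e-2)

x = torch.randn(100, 16)        # e.g. 100 labelled examples
y = torch.randint(0, 4, (100,))
losses = []
for _ in range(50):
    opt.zero_grad()
    loss = torch.nn.functional.cross_entropy(head(encoder(x)), y)
    loss.backward()
    opt.step()
    losses.append(loss.item())
```

Only the head's few hundred parameters are updated, so a handful of labels suffices; unfreezing the encoder with a small learning rate is the usual next step when more labels are available.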

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.