
Graph Augmentation: Creating Modified Graph Views for Training

Graph augmentation produces variant views of a graph by dropping edges, masking features, or perturbing structure. These views drive contrastive self-supervised learning and provide regularization during supervised training.


TL;DR

  • Graph augmentation creates modified graph views by applying random transformations: edge dropping, feature masking, node dropping, or subgraph sampling.
  • Two purposes: creating positive pairs for contrastive self-supervised learning, and regularizing supervised training to improve generalization.
  • Key difference from image augmentation: graph structure IS the content. Dropping the wrong edge can destroy critical information. Augmentation rate tuning is essential.
  • Adaptive augmentation (GCA) learns which elements to preserve vs drop, outperforming uniform random augmentation by targeting redundant structure.
  • Standard rates: edge dropping 10-30%, feature masking 10-30%, node dropping 5-20%. Start conservative and increase. Too aggressive destroys graph structure.

Graph augmentation creates modified graph views by applying random transformations to the structure, features, or nodes. In image augmentation, you rotate, crop, or color-jitter an image, and the content stays the same. In graph augmentation, you drop edges, mask features, or remove nodes, creating a perturbed view that preserves essential structural properties while introducing variation.

These augmented views serve two purposes: they provide the positive pairs for contrastive self-supervised learning (where the model learns to match two views of the same graph), and they provide regularization during supervised training (similar to how dropout prevents overfitting by randomly perturbing the network).

Augmentation strategies

graph_augmentations.py
import torch
from torch_geometric.utils import dropout_edge, subgraph

def edge_dropping(edge_index, drop_rate=0.2):
    """Remove random edges. Most common augmentation."""
    edge_index, _ = dropout_edge(edge_index, p=drop_rate)
    return edge_index

def feature_masking(x, mask_rate=0.2):
    """Zero out random feature dimensions across all nodes."""
    mask = torch.rand(x.size(1), device=x.device) > mask_rate
    return x * mask.float().unsqueeze(0)

def node_dropping(x, edge_index, drop_rate=0.1):
    """Remove random nodes along with their incident edges."""
    keep_mask = torch.rand(x.size(0), device=x.device) > drop_rate
    keep_idx = keep_mask.nonzero(as_tuple=False).view(-1)
    # subgraph() keeps only edges whose endpoints both survive and
    # relabels them to the new, compact node indices
    new_edge_index, _ = subgraph(keep_idx, edge_index,
                                 relabel_nodes=True, num_nodes=x.size(0))
    return x[keep_idx], new_edge_index

def subgraph_sampling(edge_index, num_nodes, ratio=0.8):
    """Grow a connected subgraph from a random seed node.

    Simple illustration: expand one frontier node at a time until
    ratio * num_nodes nodes are visited (or the component is exhausted).
    """
    row, col = edge_index
    visited = torch.zeros(num_nodes, dtype=torch.bool)
    visited[torch.randint(num_nodes, (1,))] = True
    target = int(ratio * num_nodes)
    while int(visited.sum()) < target:
        frontier = col[visited[row] & ~visited[col]]
        if frontier.numel() == 0:  # no unvisited neighbors left
            break
        visited[frontier[torch.randint(frontier.numel(), (1,))]] = True
    keep_idx = visited.nonzero(as_tuple=False).view(-1)
    new_edge_index, _ = subgraph(keep_idx, edge_index,
                                 relabel_nodes=True, num_nodes=num_nodes)
    return keep_idx, new_edge_index

# Compose multiple augmentations
def augment(x, edge_index):
    edge_index = edge_dropping(edge_index, 0.2)
    x = feature_masking(x, 0.15)
    return x, edge_index

Four augmentation strategies. Composing multiple augmentations (edge drop + feature mask) generally works better than any single one.
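To make the contrastive use concrete: two independent calls to an augment function yield two views of the same graph, and each node's embeddings under the two views form a positive pair scored with an NT-Xent loss. The sketch below is plain PyTorch and simplified from GRACE-style training (which also uses intra-view negatives); the random `base` tensor stands in for encoder outputs on the two augmented graphs.

```python
import torch
import torch.nn.functional as F

def nt_xent(z1, z2, tau=0.5):
    """NT-Xent contrastive loss between two augmented views.

    z1[i] and z2[i] are embeddings of the same node under two
    independent augmentations (the positive pair); every other
    node in the batch acts as a negative.
    """
    z1, z2 = F.normalize(z1, dim=1), F.normalize(z2, dim=1)
    sim = z1 @ z2.t() / tau             # [N, N] cosine similarities
    targets = torch.arange(z1.size(0))  # positives lie on the diagonal
    return F.cross_entropy(sim, targets)

# Stand-in for encoder outputs on two augmented views of one graph
torch.manual_seed(0)
base = torch.randn(8, 16)
view1 = base + 0.1 * torch.randn(8, 16)
view2 = base + 0.1 * torch.randn(8, 16)
loss = nt_xent(view1, view2)
```

Because the two views come from the same graph, the loss pulls matching rows together while pushing all mismatched pairs apart.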

Which augmentation for which domain

  • Social networks: edge dropping works well because friendships are redundant (removing one connection among many has low information loss).
  • Molecular graphs: be careful with edge dropping since every bond is structurally important. Feature masking is safer. Subgraph sampling preserves local chemistry.
  • Knowledge graphs: edge dropping is acceptable because knowledge graphs are inherently incomplete. Feature masking is less applicable since entities often have sparse features.
  • Enterprise graphs: moderate edge dropping (10-20%) with feature masking (15-25%) is a good starting point. The redundancy in large graphs makes augmentation generally safe.
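These rules of thumb can be collected into a per-domain configuration. The rates and the `make_augment` helper below are illustrative assumptions (not a PyG API); treat them as starting points to tune on a validation task.

```python
import torch

# Illustrative starting rates reflecting the per-domain guidance above
AUG_RATES = {
    "social":     {"edge_drop": 0.30, "feat_mask": 0.20},
    "molecular":  {"edge_drop": 0.05, "feat_mask": 0.20},
    "knowledge":  {"edge_drop": 0.20, "feat_mask": 0.05},
    "enterprise": {"edge_drop": 0.15, "feat_mask": 0.20},
}

def make_augment(domain):
    """Return an augmentation function with domain-appropriate rates."""
    rates = AUG_RATES[domain]
    def augment(x, edge_index):
        # drop edges uniformly at the domain's rate
        keep = torch.rand(edge_index.size(1)) > rates["edge_drop"]
        edge_index = edge_index[:, keep]
        # mask whole feature dimensions at the domain's rate
        mask = (torch.rand(x.size(1)) > rates["feat_mask"]).float()
        return x * mask, edge_index
    return augment
```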

Enterprise example: robust transaction embeddings

A bank wants transaction graph embeddings that are robust to data quality issues: missing transactions (dropped edges), incomplete account attributes (masked features), and accounts closed after training (dropped nodes).

By training with augmentation that simulates these real-world data issues:

  • Edge dropping simulates missing transactions
  • Feature masking simulates incomplete records
  • Node dropping simulates closed or removed accounts

The resulting model produces embeddings that are stable under production data quality conditions, not just clean training conditions.
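A minimal sketch of such a data-quality simulator in plain PyTorch. The function name and default rates are hypothetical; in practice the rates would be matched to the missing-data statistics observed in production.

```python
import torch

def simulate_data_quality(x, edge_index, p_edge=0.1, p_feat=0.15, p_node=0.05):
    """Perturb a graph the way production data degrades."""
    # Missing transactions: drop random edges
    keep_e = torch.rand(edge_index.size(1)) > p_edge
    edge_index = edge_index[:, keep_e]
    # Incomplete records: mask random feature dimensions
    keep_f = (torch.rand(x.size(1)) > p_feat).float()
    x = x * keep_f
    # Closed accounts: drop nodes, their incident edges, and relabel
    keep_n = torch.rand(x.size(0)) > p_node
    idx = torch.full((x.size(0),), -1, dtype=torch.long)
    idx[keep_n] = torch.arange(int(keep_n.sum()))
    e_ok = keep_n[edge_index[0]] & keep_n[edge_index[1]]
    edge_index = idx[edge_index[:, e_ok]]
    return x[keep_n], edge_index
```

Applying this per training step exposes the encoder to the same perturbation distribution it will see at inference time.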

Adaptive augmentation (GCA)

Uniform random augmentation treats all edges and features equally. But some edges are critical (a bridge connecting two communities) and some are redundant (one of 100 connections in a dense cluster). GCA (Graph Contrastive learning with Adaptive augmentation) learns to:

  • Preserve high-centrality edges (bridges, connectors)
  • Drop low-centrality edges (redundant within clusters)
  • Preserve high-variance features (informative)
  • Mask low-variance features (redundant)
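A sketch of the degree-centrality variant of GCA's edge dropping in plain PyTorch: edge importance is the log-degree of the edge's target node, the drop probability rises as importance falls, and a ceiling keeps any edge from being dropped too aggressively. GCA also supports PageRank and eigenvector centrality; the function name and defaults here are illustrative.

```python
import torch

def adaptive_edge_drop(edge_index, num_nodes, p_mean=0.2, p_max=0.7):
    """Drop edges with probability inversely related to centrality."""
    # Node centrality: in-degree of each node
    deg = torch.zeros(num_nodes)
    deg.scatter_add_(0, edge_index[1], torch.ones(edge_index.size(1)))
    # Edge importance: log-degree of the target endpoint
    w = torch.log(deg[edge_index[1]] + 1)
    # Normalize so low-importance edges get high drop probability,
    # following p = min((s_max - s) / (s_max - s_mean) * p_mean, p_max)
    s = (w.max() - w) / (w.max() - w.mean() + 1e-9)
    p = torch.clamp(p_mean * s, max=p_max)
    keep = torch.rand(edge_index.size(1)) > p
    return edge_index[:, keep]
```

Edges into high-degree hubs survive most samples, while edges into peripheral nodes are thinned out first, which biases the augmented views toward preserving load-bearing structure.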

Frequently asked questions

What is graph augmentation?

Graph augmentation creates modified versions of a graph by applying random transformations: dropping edges, masking features, removing nodes, adding noise, or sampling subgraphs. These augmented views serve two purposes: training data for contrastive self-supervised learning, and regularization during supervised training.

What augmentation strategies work best for graphs?

Four main strategies: edge dropping (remove 10-30% of random edges), feature masking (zero out 10-30% of feature dimensions), node dropping (remove 5-20% of nodes with their edges), and subgraph sampling (extract random connected subgraphs, e.g. via random walks or k-hop neighborhoods). The best choice depends on the graph domain and downstream task.

How is graph augmentation different from image augmentation?

Image augmentation (rotation, cropping, color jitter) preserves semantic content by exploiting known invariances of visual recognition. Graph augmentation is harder because graph structure is the content. Dropping an edge could remove a critical connection or a noisy one. There are fewer universal invariances in graph data, making augmentation strategy selection more important.

What is adaptive graph augmentation?

Adaptive augmentation (GCA) learns which augmentations to apply instead of using uniform random perturbations. It identifies important edges (high centrality) and important features (high variance) and preferentially preserves them while dropping redundant elements. This produces better augmented views for contrastive learning.

Can graph augmentation hurt performance?

Yes. Aggressive augmentation can destroy critical graph structure. Dropping important bridge edges disconnects communities. Masking key features removes task-relevant signal. The augmentation rate must be tuned: too little does not regularize enough, too much destroys information. Start conservative (10-20% drop rate) and increase.

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.