
Graph Transformers: Attention Beyond Neighbors

Graph transformers apply self-attention to graphs, letting every node attend to every other node. This overcomes the three fundamental limitations of message passing: over-smoothing, over-squashing, and bounded expressiveness.

PyTorch Geometric

TL;DR

  1. Graph transformers apply self-attention to graphs. Every node can attend to every other node in the (sub)graph, not just its direct neighbors. Attention is global instead of local.
  2. They address three message passing limitations: over-smoothing (no need to stack many layers), over-squashing (no bottleneck edges), and the expressiveness ceiling (attention with structural encodings can distinguish structures that 1-WL cannot).
  3. GPSConv in PyG combines local message passing + global attention in each layer. This gives both local structural awareness and long-range information flow.
  4. Positional encodings (Laplacian eigenvectors, random walks) inject graph structure into the attention mechanism. Without them, the transformer ignores topology.
  5. KumoRFM uses a Relational Graph Transformer that extends this architecture to heterogeneous relational databases, achieving state-of-the-art results on the RelBench benchmark.

A graph transformer is an attention-based architecture that overcomes the fundamental limits of message passing. Standard GNN layers (GCN, GAT, GraphSAGE) pass messages only between direct neighbors. After 2 layers, a node sees its 2-hop neighborhood. After 5-6 layers, over-smoothing collapses all representations. Graph transformers bypass this by letting every node attend directly to every other node in the subgraph, with positional encodings that encode graph structure.

This is the same self-attention mechanism that powers language models like GPT, adapted for graph-structured data. The key challenge is incorporating graph topology: in a sequence, position is trivial (token 1, token 2, ...). In a graph, “position” is defined by the complex structure of the graph itself.
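To make this concrete, here is a minimal numpy sketch of dense self-attention applied to node features. It is illustrative only: a real graph transformer adds multi-head projections, residual connections, normalization, and the positional encodings discussed below.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def node_self_attention(X, Wq, Wk, Wv):
    """Dense self-attention over nodes: every node attends to every other node."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    scores = Q @ K.T / np.sqrt(K.shape[-1])   # [num_nodes, num_nodes]
    attn = softmax(scores, axis=-1)           # each row sums to 1
    return attn @ V, attn

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 8))                   # 5 nodes, 8 features each
Wq, Wk, Wv = [rng.normal(size=(8, 8)) for _ in range(3)]
out, attn = node_self_attention(X, Wq, Wk, Wv)
print(out.shape)   # (5, 8)
```

Note that `attn` is a dense `[num_nodes, num_nodes]` matrix with no zero entries: unlike a message passing layer, no adjacency structure restricts which pairs interact. That is exactly why positional encodings are needed to reintroduce topology.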

Why message passing hits a ceiling

Three limitations drive the shift to graph transformers:

  1. Over-smoothing: after 5-6 layers of averaging neighbor features, all node representations converge. On Cora, accuracy peaks at 2-3 layers (~81.5%) and drops to ~30% at 8 layers. You cannot go deeper.
  2. Over-squashing: information from a node 5 hops away must pass through bottleneck edges, getting compressed into fixed-size vectors at each step. Signal degrades exponentially with distance.
  3. Expressiveness ceiling: standard message passing is bounded by the 1-WL graph isomorphism test. Certain graph structures (e.g., 6-cycles vs two 3-cycles) are provably indistinguishable.
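Over-smoothing in particular is easy to demonstrate with plain numpy: repeatedly applying a row-normalized neighbor-averaging operator drives all node features toward the same value. This is a linearized sketch (no weights, no nonlinearities), not a trained GCN, but the collapse mechanism is the same.

```python
import numpy as np

# Path graph on 6 nodes with self-loops, row-normalized averaging: D^{-1}(A + I)
n = 6
A = np.eye(n)
for i in range(n - 1):
    A[i, i + 1] = A[i + 1, i] = 1.0
A_hat = A / A.sum(axis=1, keepdims=True)

H = np.random.default_rng(0).normal(size=(n, 4))

def spread(H):
    # Norm of deviations from the mean node feature:
    # a proxy for how distinguishable the nodes still are
    return np.linalg.norm(H - H.mean(axis=0))

s0 = spread(H)
for _ in range(64):
    H = A_hat @ H          # one round of linear neighbor averaging
s_deep = spread(H)
print(s0, s_deep)          # the spread collapses as depth grows
```

After enough rounds, every row of `H` is nearly identical: the node representations have converged, which is the collapse that caps useful GNN depth.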

Architecture: GPS (General, Powerful, Scalable)

The GPS framework, implemented as GPSConv in PyG, is the dominant graph transformer architecture. Each layer combines:

  • Local message passing: a standard GNN layer (GCN, GIN) that captures local neighborhood structure
  • Global self-attention: a transformer attention module that captures long-range dependencies
  • Combination: the outputs are added together, so each node gets both local and global information
gps_graph_transformer.py
import torch
from torch_geometric.nn import GPSConv, GINConv
from torch.nn import Linear, Sequential, ReLU

class GraphTransformer(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, num_layers=4):
        super().__init__()
        self.input_proj = Linear(in_dim, hidden_dim)

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            # Local: GIN for max expressiveness
            nn = Sequential(Linear(hidden_dim, hidden_dim), ReLU(),
                          Linear(hidden_dim, hidden_dim))
            gin = GINConv(nn)

            # GPS = local GNN + global attention
            gps = GPSConv(hidden_dim, conv=gin, heads=4, attn_dropout=0.1)
            self.convs.append(gps)

        self.output = Linear(hidden_dim, out_dim)

    def forward(self, x, edge_index, batch):
        x = self.input_proj(x)
        for conv in self.convs:
            x = conv(x, edge_index, batch)
        return self.output(x)

GPS combines local GIN convolution with global multi-head attention. 4 layers of GPS = deep local + long-range global.
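The GPSConv code above relies on PyG internals, but the core idea of the layer fits in a few lines: run a local aggregation branch and a global attention branch, then sum them. The single-head numpy sketch below is illustrative (no residuals, norms, or MLPs), and `gps_layer` is a hypothetical helper, not PyG API.

```python
import numpy as np

def gps_layer(X, A_hat, W_local, Wq, Wk, Wv):
    """One GPS-style layer: local neighbor aggregation + dense attention, summed."""
    local = A_hat @ X @ W_local                        # local branch (GCN-like)
    Q, K, V = X @ Wq, X @ Wk, X @ Wv                   # global branch
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    attn = np.exp(scores - scores.max(axis=1, keepdims=True))
    attn /= attn.sum(axis=1, keepdims=True)
    return local + attn @ V                            # combine the two branches

rng = np.random.default_rng(0)
n, d = 5, 8
A = np.eye(n)                                          # self-loops
A[0, 1] = A[1, 0] = A[1, 2] = A[2, 1] = 1.0            # a small path 0-1-2
A_hat = A / A.sum(axis=1, keepdims=True)               # row-normalized adjacency
Ws = [rng.normal(size=(d, d)) for _ in range(4)]
out = gps_layer(rng.normal(size=(n, d)), A_hat, *Ws)
print(out.shape)   # (5, 8)
```

Adding the branches means every layer preserves local structural information even when the attention weights are still uninformative early in training.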

Positional and structural encodings

Self-attention is permutation-invariant: without positional information, the transformer does not know the graph structure. Two types of encodings solve this:

  • Laplacian Positional Encoding (LapPE): uses the eigenvectors of the graph Laplacian corresponding to its smallest nonzero eigenvalues as node positions. Nodes close in the graph get similar positional vectors.
  • Random Walk Structural Encoding (RWSE): the probability of a random walk returning to the starting node after k steps. Captures local structural properties like clustering coefficient and local density.
positional_encodings.py
import torch
import torch_geometric.transforms as T

# Add Laplacian positional encoding (8 eigenvectors)
transform = T.AddLaplacianEigenvectorPE(k=8)
data = transform(data)
# data.laplacian_eigenvector_pe: [num_nodes, 8]

# Add random walk structural encoding (16 steps)
transform = T.AddRandomWalkPE(walk_length=16)
data = transform(data)
# data.random_walk_pe: [num_nodes, 16]

# Concatenate with node features before feeding to the transformer
# (the model's in_dim must account for the extra 8 + 16 dimensions)
x = torch.cat([data.x, data.laplacian_eigenvector_pe, data.random_walk_pe], dim=-1)

Positional encodings tell the transformer where each node sits in the graph. Without them, all nodes look identical to the attention mechanism.
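Under the hood, both encodings are simple linear algebra. The numpy sketch below computes them on a toy graph; it is illustrative and not the PyG implementation (which handles normalization variants, sign flips, and sparse graphs).

```python
import numpy as np

# Toy undirected graph: a 4-cycle (nodes 0-3) plus one pendant node (4)
edges = [(0, 1), (1, 2), (2, 3), (3, 0), (3, 4)]
n = 5
A = np.zeros((n, n))
for i, j in edges:
    A[i, j] = A[j, i] = 1.0
deg = A.sum(axis=1)

# Laplacian PE: eigenvectors of L = D - A with the smallest nonzero eigenvalues
L = np.diag(deg) - A
eigvals, eigvecs = np.linalg.eigh(L)     # eigenvalues in ascending order
lap_pe = eigvecs[:, 1:3]                 # skip the trivial constant eigenvector

# Random-walk SE: return probability (P^k)_{ii} with P = D^{-1} A
P = A / deg[:, None]
rw_se = np.stack([np.linalg.matrix_power(P, k).diagonal()
                  for k in range(1, 9)], axis=1)   # [n, 8]

print(lap_pe.shape, rw_se.shape)   # (5, 2) (5, 8)
```

Each row of `rw_se` is a structural fingerprint: nodes on the cycle have different return probabilities than the pendant node, so attention can tell them apart even though it sees no edges directly.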

Enterprise example: cross-department fraud detection

In a bank's transaction graph, fraud rings often span many hops: account A sends to B, B sends to C, C sends to D, D sends back to a shell company linked to A. This is a 4-hop pattern. Standard message passing with 2-3 layers cannot see it.

A graph transformer with global attention lets node A directly attend to node D, even if they are 4 hops apart. The attention mechanism learns that certain long-range patterns (circular money flows, rapid fan-out then fan-in) are suspicious. On the RelBench benchmark, graph transformers achieve 76.71 AUROC zero-shot vs 62.44 for flat-table LightGBM, with much of the improvement coming from exactly these long-range patterns.
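The receptive-field gap is easy to check numerically. In the toy numpy sketch below, money flows along a directed ring A→B→C→D→shell→A: two rounds of message passing never deliver A's information to D, while a single dense-attention layer assigns positive weight to every pair of nodes. (The feature matrix here is random; in practice the attention weights are learned.)

```python
import numpy as np

names = ["A", "B", "C", "D", "shell"]
A = np.zeros((5, 5))
for i in range(5):
    A[i, (i + 1) % 5] = 1.0            # directed ring: A->B->C->D->shell->A

# Receptive field after k message-passing layers: nonzero entries of (I + A)^k
reach2 = np.linalg.matrix_power(np.eye(5) + A, 2)
reach5 = np.linalg.matrix_power(np.eye(5) + A, 5)
print(reach2[0] > 0)   # A's information has not reached D after 2 layers

# One dense-attention layer: every softmax weight is positive,
# so A interacts with D (and the shell company) directly
X = np.random.default_rng(0).normal(size=(5, 8))
scores = X @ X.T / np.sqrt(8)
attn = np.exp(scores - scores.max(axis=1, keepdims=True))
attn /= attn.sum(axis=1, keepdims=True)
```

This is the structural reason long-range fraud patterns are invisible to shallow message passing but reachable in one attention layer.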

Frequently asked questions

What is a graph transformer?

A graph transformer is a neural network architecture that applies self-attention to graphs. Instead of only passing messages between direct neighbors (like GCN or GAT), graph transformers let every node attend to every other node in the (sub)graph, with positional and structural encodings to incorporate graph topology.

How do graph transformers differ from standard message passing?

Standard message passing is local: each node only sees its direct neighbors per layer. Graph transformers are global: each node can attend to any other node. This mitigates over-smoothing (no need for many layers) and over-squashing (no bottleneck edges), and raises the expressiveness ceiling (attention combined with structural encodings can distinguish structures that plain message passing cannot).

What is GPSConv in PyG?

GPSConv (General, Powerful, Scalable Graph Transformer) combines a local message passing layer (like GCN or GIN) with global self-attention in each layer. This gives the model both local structural awareness and long-range attention. It is the most popular graph transformer implementation in PyG.

How do graph transformers handle graph structure?

Through positional and structural encodings. Laplacian eigenvector positional encodings give each node a position in the graph. Random walk structural encodings capture local topology. These are added to node features before attention, so the transformer knows the graph structure even though attention itself is global.

Are graph transformers better than GCN and GAT?

On large-scale benchmarks, yes. Graph transformers consistently outperform message passing GNNs on datasets where long-range dependencies matter. On small homogeneous benchmarks like Cora, the difference is marginal. The advantage grows with graph size and structural complexity, especially on heterogeneous enterprise data.

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.