Original Paper
Recipe for a General, Powerful, Scalable Graph Transformer
Rampasek et al. (2022). NeurIPS 2022
What GPSConv does
GPSConv is a modular block with four components:
- Positional/structural encodings: Add information about each node's position in the graph (random walk, Laplacian eigenvectors) before the first layer.
- Local MPNN: A message-passing layer (GINConv, PNAConv, etc.) that aggregates neighbor features within the graph structure.
- Global attention: Multi-head self-attention across all nodes, like a standard transformer. Captures dependencies beyond the local neighborhood.
- Feedforward network: An MLP applied per node after combining local and global representations.
Stack multiple GPS blocks (typically 5-10) with residual connections and layer normalization. The result is a deep graph transformer.
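To make the data flow concrete, here is a minimal NumPy sketch of a single GPS block update. This is an illustrative toy, not the PyG implementation: it uses mean-neighbor aggregation as the "local MPNN", single-head dot-product attention as the "global attention", and a two-layer ReLU MLP as the feedforward network, with a simple residual sum in place of the paper's normalization layers. All weight names (`w_local`, `w_q`, etc.) are hypothetical.

```python
import numpy as np

def gps_block(x, adj, w_local, w_q, w_k, w_v, w_ffn1, w_ffn2):
    # Local MPNN (toy): mean-aggregate neighbor features, then a linear map
    deg = adj.sum(axis=1, keepdims=True).clip(min=1)
    x_local = (adj @ x / deg) @ w_local

    # Global attention (toy, single head): every node attends to every node
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / np.sqrt(k.shape[1])
    attn = np.exp(scores - scores.max(axis=1, keepdims=True))
    attn /= attn.sum(axis=1, keepdims=True)          # row-wise softmax
    x_global = attn @ v

    # Combine local and global views with a residual, then a per-node FFN
    h = x + x_local + x_global
    return h + np.maximum(h @ w_ffn1, 0) @ w_ffn2

rng = np.random.default_rng(0)
n, d = 6, 8
x = rng.normal(size=(n, d))                          # node features
adj = (rng.random((n, n)) < 0.3).astype(float)       # random adjacency
w = [rng.normal(size=(d, d)) * 0.1 for _ in range(6)]
out = gps_block(x, adj, *w)
print(out.shape)  # (6, 8)
```

Stacking this block `num_layers` times, with proper normalization and learned weights, is what a deep GPS model does.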
PyG implementation
```python
import torch
from torch_geometric.nn import GPSConv, GINConv, global_add_pool


class GPS(torch.nn.Module):
    def __init__(self, in_channels, hidden, out_channels, num_layers=5):
        super().__init__()
        self.node_emb = torch.nn.Linear(in_channels, hidden)
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            # Local layer: GINConv with a 2-layer MLP
            local_nn = torch.nn.Sequential(
                torch.nn.Linear(hidden, hidden),
                torch.nn.ReLU(),
                torch.nn.Linear(hidden, hidden),
            )
            local_layer = GINConv(local_nn)
            # GPSConv wraps the local layer and adds global attention
            self.convs.append(GPSConv(
                channels=hidden,
                conv=local_layer,
                heads=4,
                attn_dropout=0.5,
            ))
        self.classifier = torch.nn.Linear(hidden, out_channels)

    def forward(self, x, edge_index, batch):
        x = self.node_emb(x)
        for conv in self.convs:
            x = conv(x, edge_index, batch=batch)
        x = global_add_pool(x, batch)
        return self.classifier(x)


model = GPS(in_channels=9, hidden=64, out_channels=1, num_layers=5)
```

GPSConv wraps a local layer (GINConv here) and adds global multi-head attention. The `batch` vector is needed so that global attention operates within each graph in the batch rather than attending across graph boundaries.
When to use GPSConv
- Long-range dependency tasks. When distant nodes influence each other (molecular properties, protein function prediction), global attention captures dependencies that limited-hop message passing misses.
- Long-range graph benchmarks. Peptides-func and Peptides-struct specifically test long-range reasoning. GPS outperforms local-only models here.
- Small to medium graphs. The O(N^2) global attention is practical for graphs with up to ~10K nodes per graph (molecular, protein datasets).
- Research on graph transformers. GPS provides a modular framework to test different local layers, attention types, and positional encodings.
When not to use GPSConv
- Large graphs. Global attention is O(N^2) in memory and compute. For graphs with 100K+ nodes, use local-only layers (TransformerConv, GATConv) or switch to linear attention.
- When local structure is sufficient. On most node classification tasks (Cora, CiteSeer), 2-3 hops of local context capture all needed information. Global attention adds cost without benefit.
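The O(N^2) limit is easy to quantify with back-of-envelope arithmetic. The sketch below (illustrative assumptions: fp32 scores, 4 heads, one layer, no activation overhead) estimates the memory needed just to materialize the dense attention matrix:

```python
def attn_matrix_bytes(num_nodes, num_heads=4, bytes_per_elem=4):
    # One dense N x N score matrix per head, fp32
    return num_nodes ** 2 * num_heads * bytes_per_elem

for n in (10_000, 100_000):
    gib = attn_matrix_bytes(n) / 2**30
    print(f"N={n:>7,}: {gib:,.1f} GiB")
# N= 10,000:   1.5 GiB  -- feasible on one GPU
# N=100,000: 149.0 GiB  -- infeasible without linear attention
```

This is why ~10K nodes per graph is a practical ceiling for full attention, and why 100K+ node graphs call for local-only layers or linear-attention variants.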