
TransformerConv: Full Transformer Attention on Graphs

TransformerConv brings the key-query-value attention pattern from transformers to graph neural networks. It is the bridge between classic GNN message passing and modern graph transformers, offering more expressiveness than GATConv while staying within the local neighborhood framework.


TL;DR

  • TransformerConv uses full dot-product key-query-value attention on graphs, similar to transformer self-attention but restricted to each node's neighborhood.
  • More expressive than GATConv: separate key, query, and value projections capture richer interaction patterns. Naturally produces dynamic attention (no GATv2 fix needed).
  • Supports edge features natively. Transaction amounts, timestamps, and relationship types can be injected into the attention computation via edge_attr.
  • 3-5x slower than GCNConv but more accurate on tasks where neighbor relationships are complex. The sweet spot between GAT and full graph transformers.
  • KumoRFM's Relational Graph Transformer is a production-grade descendant of TransformerConv, extended with type-aware projections and temporal encodings.

Original Paper

Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification

Shi et al. (2020), published at IJCAI 2021


What TransformerConv does

TransformerConv applies the standard transformer attention pattern to graph neighborhoods. For each node, it:

  1. Projects the node's features into a query vector
  2. Projects each neighbor's features into key and value vectors
  3. Computes attention scores via dot-product of query and keys
  4. Aggregates neighbor values weighted by attention scores

This is exactly how self-attention works in standard transformers, except the attention is restricted to each node's local neighborhood rather than all tokens in a sequence.
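The four steps above can be sketched in plain PyTorch for a single node and its neighborhood (the dimensions, random weights, and neighbor count are illustrative assumptions, not values from the paper):

```python
import math
import torch

torch.manual_seed(0)
d = 8                                 # dimension per attention head
h_i = torch.randn(d)                  # target node features
h_neighbors = torch.randn(5, d)       # features of 5 neighbors of node i

W_Q = torch.randn(d, d)
W_K = torch.randn(d, d)
W_V = torch.randn(d, d)

q = W_Q @ h_i                         # 1. query from the target node
K = h_neighbors @ W_K.T               # 2. keys from each neighbor
V = h_neighbors @ W_V.T               #    values from each neighbor

scores = K @ q / math.sqrt(d)         # 3. scaled dot-product scores
alpha = torch.softmax(scores, dim=0)  #    normalized over the neighborhood only

h_out = alpha @ V                     # 4. attention-weighted aggregation
```

Note that the softmax runs over the 5 neighbors, not over every node in the graph; that restriction is the only difference from sequence self-attention here.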

The math (simplified)

TransformerConv formula
Q_i = W_Q · h_i          # query from target node
K_j = W_K · h_j          # key from source neighbor
V_j = W_V · h_j          # value from source neighbor

# Attention score (scaled dot-product)
alpha_ij = softmax_j( Q_i^T · K_j / sqrt(d) )

# Optional: add edge features to attention
alpha_ij = softmax_j( (Q_i^T · K_j + W_E · e_ij) / sqrt(d) )

# Weighted aggregation
h_i' = Σ_j alpha_ij · V_j

Where:
  W_Q, W_K, W_V = projection matrices (separate, unlike GAT)
  d              = dimension per attention head
  e_ij           = optional edge features

Separate key, query, and value projections give TransformerConv more capacity than GAT's shared weight matrix with a single attention vector.

PyG implementation

transformer_conv_model.py
import torch
import torch.nn.functional as F
from torch_geometric.nn import TransformerConv

class GraphTransformer(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels,
                 heads=4, edge_dim=None):
        super().__init__()
        # edge_dim must be set for edge_attr to influence attention
        self.conv1 = TransformerConv(in_channels, hidden_channels,
                                     heads=heads, edge_dim=edge_dim)
        self.conv2 = TransformerConv(hidden_channels * heads, out_channels,
                                     heads=1, concat=False, edge_dim=edge_dim)

    def forward(self, x, edge_index, edge_attr=None):
        x = self.conv1(x, edge_index, edge_attr)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index, edge_attr)
        return x

# With edge features (e.g., transaction amounts)
model = GraphTransformer(
    in_channels=64,
    hidden_channels=32,
    out_channels=num_classes,
    heads=4,
    edge_dim=1,  # one scalar feature per edge
)
# edge_attr shape: [num_edges, edge_feature_dim]
out = model(x, edge_index, edge_attr)

Set edge_dim to your edge feature dimension to enable edge-aware attention. Without edge features, TransformerConv behaves like transformer self-attention restricted to neighbors.

When to use TransformerConv

  • Edge features carry important signal. Transaction amounts, timestamps, relationship types. TransformerConv natively incorporates edge attributes into attention scores.
  • Complex neighbor interactions. When a neighbor's importance depends on a complex interaction between both nodes' features, not just the simple additive score GATConv computes.
  • Bridge to graph transformers. If you are exploring graph transformer architectures, TransformerConv is the natural local component. Combine it with GPSConv for local + global attention.
  • Medium-scale graphs (10K-1M nodes). Where the 3-5x overhead over GCNConv is acceptable and the extra expressiveness improves accuracy.

When not to use TransformerConv

  • Very large graphs without sampling. The key-query-value projections per edge are expensive at billion-node scale. Combine with NeighborLoader or use SAGEConv.
  • Simple homogeneous graphs. On Cora-level tasks, the accuracy gain over GATConv is marginal (~0.2%) while compute cost is higher.

How KumoRFM builds on this

KumoRFM's Relational Graph Transformer is a direct descendant of TransformerConv. It uses the same key-query-value attention pattern but extends it for production enterprise data:

  • Type-specific projections: Separate W_Q, W_K, W_V matrices per node and edge type, so customer-to-order attention uses different parameters than customer-to-merchant
  • Temporal position encoding: Time-stamped edges get positional encodings similar to sequence transformers, capturing when events happened
  • Scalable via sampling: Combines TransformerConv's expressiveness with SAGEConv's sampling efficiency for billion-node production graphs
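The type-specific projection idea can be illustrated with per-edge-type weight matrices; this is a generic sketch of the concept, not KumoRFM's actual implementation, and the class and type names are hypothetical:

```python
import torch
from torch import nn

class TypedAttentionScore(nn.Module):
    """Separate query/key projections per edge type, so attention across
    different relationships uses different parameters (generic sketch)."""
    def __init__(self, d, edge_types):
        super().__init__()
        self.W_Q = nn.ModuleDict({t: nn.Linear(d, d) for t in edge_types})
        self.W_K = nn.ModuleDict({t: nn.Linear(d, d) for t in edge_types})

    def score(self, edge_type, h_i, h_j):
        q = self.W_Q[edge_type](h_i)          # type-specific query projection
        k = self.W_K[edge_type](h_j)          # type-specific key projection
        return (q * k).sum(-1) / h_i.size(-1) ** 0.5

proj = TypedAttentionScore(16, ["customer_order", "customer_merchant"])
h_i, h_j = torch.randn(16), torch.randn(16)
s1 = proj.score("customer_order", h_i, h_j)
s2 = proj.score("customer_merchant", h_i, h_j)  # same nodes, different params
```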

Frequently asked questions

What is TransformerConv in PyTorch Geometric?

TransformerConv implements a transformer-style multi-head attention layer for graphs from the UniMP paper (Shi et al., 2020). Unlike GATConv which uses a single attention vector, TransformerConv uses full key-query-value attention like standard transformers, making it more expressive for capturing complex neighbor relationships.

How does TransformerConv differ from GATConv?

GATConv uses a simple attention scoring function (concatenate + linear). TransformerConv uses full dot-product attention with separate key, query, and value projections, similar to transformer self-attention. This gives it more expressive power and naturally produces dynamic attention without needing the GATv2 fix.
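The two scoring functions can be put side by side in plain PyTorch (random weights and dimensions are illustrative):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d = 8
h_i, h_j = torch.randn(d), torch.randn(d)

# GAT-style: one shared projection W, a single attention vector a,
# additive score over the concatenated pair
W = torch.randn(d, d)
a = torch.randn(2 * d)
gat_score = F.leaky_relu(a @ torch.cat([W @ h_i, W @ h_j]))

# TransformerConv-style: separate query and key projections,
# multiplicative (scaled dot-product) score
W_Q, W_K = torch.randn(d, d), torch.randn(d, d)
tc_score = (W_Q @ h_i) @ (W_K @ h_j) / d ** 0.5
```

The dot product couples the query and key multiplicatively, so each target node can induce a different ranking over its neighbors; the additive GAT score is what GATv2 had to restructure to achieve the same dynamic-attention property.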

Can TransformerConv use edge features?

Yes. TransformerConv supports edge features through the edge_attr parameter. Edge features are projected and added to the attention scores, allowing the model to consider relationship properties (e.g., transaction amount, edge type) when computing attention weights.

Is TransformerConv the same as a graph transformer?

TransformerConv applies transformer-style attention within each node's local neighborhood (message passing). Full graph transformers (like GPSConv) combine local message passing with global attention across all nodes. TransformerConv is one building block of a graph transformer architecture.

When should I use TransformerConv vs GPSConv?

Use TransformerConv when local neighborhood attention is sufficient (most tasks). Use GPSConv when you need both local and global attention, such as on molecular property prediction where distant atoms influence properties. GPSConv wraps TransformerConv (or another local layer) with a global attention module.

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.