Original Paper
Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification
Shi et al. (2020). IJCAI 2021
What TransformerConv does
TransformerConv applies the standard transformer attention pattern to graph neighborhoods. For each node, it:
- Projects the node's features into a query vector
- Projects each neighbor's features into key and value vectors
- Computes attention scores via dot-product of query and keys
- Aggregates neighbor values weighted by attention scores
This is exactly how self-attention works in standard transformers, except the attention is restricted to each node's local neighborhood rather than all tokens in a sequence.
The math (simplified)
Q_i = W_Q · h_i # query from target node
K_j = W_K · h_j # key from source neighbor
V_j = W_V · h_j # value from source neighbor
# Attention score (scaled dot-product)
alpha_ij = softmax_j( Q_i^T · K_j / sqrt(d) )
# Optional: add projected edge features to the keys before the dot product
alpha_ij = softmax_j( Q_i^T · (K_j + W_E · e_ij) / sqrt(d) )
# Weighted aggregation
h_i' = Σ_j alpha_ij · V_j
Where:
W_Q, W_K, W_V = projection matrices (separate, unlike GAT)
d = dimension per attention head
e_ij = optional edge features
Separate key, query, and value projections give TransformerConv more capacity than GAT's shared weight matrix with a single attention vector.
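The attention math above can be sketched in plain NumPy for a single node and its neighbors (a minimal single-head illustration with made-up dimensions; projected edge features are added to the keys, which is how PyG's TransformerConv incorporates them, and real implementations batch this over all edges):

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4                               # dimension per attention head
h_i = rng.normal(size=8)            # target node features
H_nbrs = rng.normal(size=(3, 8))    # features of 3 neighbors h_j
E_nbrs = rng.normal(size=(3, 2))    # optional edge features e_ij

W_Q = rng.normal(size=(d, 8))       # separate projections (unlike GAT)
W_K = rng.normal(size=(d, 8))
W_V = rng.normal(size=(d, 8))
W_E = rng.normal(size=(d, 2))       # projects edge features into key space

q = W_Q @ h_i                       # Q_i
K = H_nbrs @ W_K.T                  # one K_j per neighbor, shape (3, d)
V = H_nbrs @ W_V.T                  # one V_j per neighbor, shape (3, d)

# Scaled dot-product scores; projected edge features added to the keys
scores = (K + E_nbrs @ W_E.T) @ q / np.sqrt(d)
alpha = np.exp(scores - scores.max())
alpha /= alpha.sum()                # softmax over neighbors j
h_i_new = alpha @ V                 # h_i' = sum_j alpha_ij * V_j
```

Note that the softmax runs only over the node's neighbors, which is exactly the "attention restricted to the local neighborhood" described above.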
PyG implementation
import torch
import torch.nn.functional as F
from torch_geometric.nn import TransformerConv
class GraphTransformer(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels,
                 heads=4, edge_dim=None):
        super().__init__()
        self.conv1 = TransformerConv(in_channels, hidden_channels,
                                     heads=heads, edge_dim=edge_dim)
        self.conv2 = TransformerConv(hidden_channels * heads, out_channels,
                                     heads=1, concat=False, edge_dim=edge_dim)

    def forward(self, x, edge_index, edge_attr=None):
        x = self.conv1(x, edge_index, edge_attr)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index, edge_attr)
        return x

# With edge features (e.g., transaction amounts), pass edge_dim so the
# convolutions can project edge_attr into the attention scores
model = GraphTransformer(
    in_channels=64,
    hidden_channels=32,
    out_channels=num_classes,
    heads=4,
    edge_dim=edge_feature_dim,
)

# edge_attr shape: [num_edges, edge_feature_dim]
out = model(x, edge_index, edge_attr)
Set edge_dim to your edge feature dimension to enable edge-aware attention. Without edge features, TransformerConv behaves like transformer self-attention restricted to neighbors.
When to use TransformerConv
- Edge features carry important signal. Transaction amounts, timestamps, relationship types. TransformerConv natively incorporates edge attributes into attention scores.
- Complex neighbor interactions. When a neighbor's importance depends on a richer function of both endpoints' features than GATConv's simple additive attention score can express.
- Bridge to graph transformers. If you are exploring graph transformer architectures, TransformerConv is the natural local component. Combine it with GPSConv for local + global attention.
- Medium-scale graphs (10K-1M nodes). Where the 3-5x overhead over GCNConv is acceptable and the extra expressiveness improves accuracy.
When not to use TransformerConv
- Very large graphs without sampling. The key-query-value projections per edge are expensive at billion-node scale. Combine with NeighborLoader or use SAGEConv.
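The sampling remedy can be illustrated without any library: cap the number of neighbors expanded per node per hop, so the per-batch computation graph stays bounded no matter how dense the graph is. Below is a toy sketch of the idea behind NeighborLoader, using a hypothetical adjacency dict (the real loader also handles batching, feature slicing, and edge indices):

```python
import random

def sample_neighbors(adj, seeds, fanouts, seed=0):
    """Sample a fixed fanout of neighbors per node for each hop.

    adj: dict mapping node -> list of neighbor nodes
    seeds: nodes whose predictions this batch needs
    fanouts: e.g. [10, 10] for a 2-layer model
    Returns the set of nodes the sampled computation graph touches.
    """
    rng = random.Random(seed)
    frontier, visited = list(seeds), set(seeds)
    for fanout in fanouts:
        next_frontier = []
        for node in frontier:
            nbrs = adj.get(node, [])
            picked = nbrs if len(nbrs) <= fanout else rng.sample(nbrs, fanout)
            next_frontier.extend(picked)
        visited.update(next_frontier)
        frontier = next_frontier
    return visited

# Toy graph: hub node 0 with 100 neighbors
adj = {0: list(range(1, 101)), 1: [0], 2: [0]}
touched = sample_neighbors(adj, seeds=[0], fanouts=[5, 5])
# Touches at most 1 + 5 + 25 nodes, however dense node 0 is
```

With fanouts of [5, 5], the cost per seed node is bounded by a constant, which is what makes attention-heavy layers like TransformerConv tractable on large graphs.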
- Simple homogeneous graphs. On Cora-level tasks, the accuracy gain over GATConv is marginal (~0.2%) while compute cost is higher.
How KumoRFM builds on this
KumoRFM's Relational Graph Transformer is a direct descendant of TransformerConv. It uses the same key-query-value attention pattern but extends it for production enterprise data:
- Type-specific projections: Separate W_Q, W_K, W_V matrices per node and edge type, so customer-to-order attention uses different parameters than customer-to-merchant
- Temporal position encoding: Time-stamped edges get positional encodings similar to sequence transformers, capturing when events happened
- Scalable via sampling: Combines TransformerConv's expressiveness with SAGEConv's sampling efficiency for billion-node production graphs
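The temporal encoding idea can be sketched like a sequence transformer's sinusoidal positional encoding, applied to an edge timestamp instead of a token index. This is an illustration of the general technique, not KumoRFM's exact scheme:

```python
import numpy as np

def time_encoding(t, dim=8, max_period=10_000.0):
    """Map a scalar timestamp to a dim-sized sinusoidal vector."""
    freqs = max_period ** (-np.arange(0, dim, 2) / dim)  # geometric frequencies
    angles = t * freqs
    return np.concatenate([np.sin(angles), np.cos(angles)])

# Events close in time get similar encodings; distant ones diverge
e1 = time_encoding(100.0)
e2 = time_encoding(101.0)
e3 = time_encoding(5000.0)
```

Concatenated to (or added into) edge features, such a vector lets the attention scores depend on when an event happened, not just on what it was.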