The problem with GATConv
The original GATConv computes attention scores with this pattern:
# GATv1 (static attention)
e_ij = LeakyReLU(a^T · [W·h_i || W·h_j])
     = LeakyReLU(a_left^T · W·h_i + a_right^T · W·h_j)
# Problem: inside the (monotonic) LeakyReLU the two terms are
# INDEPENDENT, so the ranking of neighbors j is the same for
# every query node i.

# GATv2 (dynamic attention)
e_ij = a^T · LeakyReLU(W · [h_i || h_j])
# Fix: apply the nonlinearity AFTER combining features.
# Now the score depends on the INTERACTION between i and j.
The key insight: in GATv1 the attention score decomposes into two independent terms, and the monotonic LeakyReLU cannot undo that. GATv2 applies the nonlinearity after combining features, so the score depends on the joint representation.
In GATv1, the attention score for edge (i, j) decomposes into a sum of two terms: one that depends only on node i and one that depends only on node j. This means the ranking of neighbors can be the same for every query node. In practice, high-degree or high-feature-norm nodes get high attention from everyone.
GATv2 fixes this by applying the nonlinearity after combining the source and target features. Now the attention score depends on the interaction between the two nodes, not just their individual properties.
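The decomposition can be checked numerically. A minimal sketch in plain PyTorch with random weights (no PyG; all names here are illustrative): because the i-dependent term is a constant offset per row, every query node ranks its neighbors identically.

```python
import torch

torch.manual_seed(0)
N, d = 5, 4
h = torch.randn(N, d)            # node features
W = torch.randn(d, d)            # shared linear transform
a_left, a_right = torch.randn(d), torch.randn(d)

# GATv1-style (static) scores: e[i, j] = a_left·Wh_i + a_right·Wh_j
Wh = h @ W.T
e_static = (Wh @ a_left).unsqueeze(1) + (Wh @ a_right).unsqueeze(0)

# The i-dependent term shifts each row by a constant, and LeakyReLU is
# monotonic, so applying it would not change the ordering either.
rankings = e_static.argsort(dim=1, descending=True)
print((rankings == rankings[0]).all())  # tensor(True): same ranking for every i
```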
PyG implementation
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv  # drop-in replacement for GATConv


class GATv2(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, heads=8):
        super().__init__()
        # Same API as GATConv
        self.conv1 = GATv2Conv(in_channels, hidden_channels, heads=heads)
        self.conv2 = GATv2Conv(hidden_channels * heads, out_channels,
                               heads=1, concat=False)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        return x
# Migration from GATConv: change the import, nothing else
# from torch_geometric.nn import GATConv    # old
# from torch_geometric.nn import GATv2Conv  # new
GATv2Conv is a drop-in replacement: the constructor, forward signature, and output shape match GATConv. (It also accepts an extra share_weights argument, which defaults to False and can be left alone when migrating.)
When static attention fails
Static attention is problematic whenever the relevance of a neighbor depends on who is asking. Consider a fraud detection graph:
- Node A (fraudster): Connected to merchant M. Merchant M should get high attention because it is a known fraud-associated entity.
- Node B (legitimate user): Also connected to merchant M. Merchant M should get low attention because B's other connections indicate legitimate behavior.
With static attention, merchant M gets the same attention score from both A and B. With dynamic attention, the model can learn that M is important in the context of A but not B.
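The contrast can be made concrete with the same kind of sketch as before, now using a GATv2-style score (nonlinearity after combining). With random weights and illustrative names, the neighbor ranking varies with the query node:

```python
import torch

torch.manual_seed(0)
N, d = 5, 4
h = torch.randn(N, d)
W = torch.randn(d, 2 * d)   # acts on the concatenation [h_i || h_j]
a = torch.randn(d)

# GATv2-style (dynamic) scores: e[i, j] = a · LeakyReLU(W [h_i || h_j])
pairs = torch.cat([h.unsqueeze(1).expand(N, N, d),
                   h.unsqueeze(0).expand(N, N, d)], dim=-1)
e_dyn = torch.nn.functional.leaky_relu(pairs @ W.T) @ a

# Unlike the static case, rows are no longer a shared vector plus a
# per-row constant, so rankings can differ across query nodes i
# (and with random weights they almost surely do).
rankings = e_dyn.argsort(dim=1, descending=True)
print((rankings == rankings[0]).all())
```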
When to use GATv2Conv
- Always, when you would use GATConv. GATv2Conv is strictly more expressive at roughly the same computational cost. It does use about twice as many weight parameters by default (separate transforms for source and target nodes, unless share_weights=True), but runtime and accuracy are comparable or better in practice.
- Fraud detection and anomaly detection. Dynamic attention correctly handles the case where the same entity is suspicious in one context and benign in another.
- Knowledge graphs and heterogeneous networks. High-degree hub nodes play different roles depending on the query context. Dynamic attention captures this.
When not to use GATv2Conv
- When attention itself is unnecessary. If all neighbors are equally informative (regular grids, uniform molecular bonds), GCNConv is faster and simpler.
- Reproducing published GATv1 results. Some benchmarks report GATConv numbers specifically. Use the original for fair comparisons.