What GATConv does
GATConv replaces GCNConv's fixed degree-based weighting with learned attention weights. For each node, it:
- Transforms each neighbor's features with a shared weight matrix
- Computes an attention score for each edge (how important is this neighbor?)
- Normalizes scores across all neighbors using softmax
- Aggregates neighbor features weighted by these attention scores
The attention scores are learned end-to-end. The model discovers which neighbors are informative for the downstream task without manual feature engineering.
The math (simplified)
```
# Attention coefficient between nodes i and j
e_ij = LeakyReLU( a^T · [W·h_i || W·h_j] )

# Normalize across all neighbors
alpha_ij = softmax_j(e_ij) = exp(e_ij) / Σ_k exp(e_ik)

# Weighted aggregation
h_i' = σ( Σ_j alpha_ij · W · h_j )
```
Where:
- W = shared weight matrix (learnable)
- a = attention vector (learnable)
- || = concatenation
- alpha_ij = attention weight (sums to 1 over each node's neighbors)

The attention vector `a` learns a scoring function over pairs of transformed features. Softmax ensures the weights sum to 1 across each node's neighbors.
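As a sanity check, the score-and-softmax step can be sketched in plain NumPy for a single head (the weight matrix, attention vector, and feature values below are random placeholders, not anything from a trained model):

```python
import numpy as np

def leaky_relu(x, slope=0.2):
    return np.where(x > 0, x, slope * x)

rng = np.random.default_rng(0)
W = rng.normal(size=(4, 3))   # shared weight matrix: 3 -> 4 dims
a = rng.normal(size=8)        # attention vector over [W·h_i || W·h_j]

h_i = rng.normal(size=3)              # center node's features
neighbors = rng.normal(size=(3, 3))   # three neighbors' feature vectors

# e_ij = LeakyReLU(a^T [W·h_i || W·h_j]) for each neighbor j
e = np.array([leaky_relu(a @ np.concatenate([W @ h_i, W @ h_j]))
              for h_j in neighbors])

# alpha_ij = softmax over neighbors
alpha = np.exp(e) / np.exp(e).sum()

# h_i' = Σ_j alpha_ij · W·h_j (before the final nonlinearity σ)
h_out = (alpha[:, None] * (neighbors @ W.T)).sum(axis=0)

print(alpha.sum())  # attention weights sum to 1 (up to float error)
```

The only learnable pieces are `W` and `a`; everything else is fixed arithmetic, which is why attention adds relatively little parameter count but a per-edge compute cost.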
PyG implementation
```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv

class GAT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, heads=8):
        super().__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads)
        # Output layer: average heads instead of concatenating
        self.conv2 = GATConv(hidden_channels * heads, out_channels,
                             heads=1, concat=False)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        return x
```
```python
# Usage
model = GAT(dataset.num_features, 8, dataset.num_classes, heads=8)
# Note: hidden dim is 8 * 8 = 64 after concat in layer 1
```

With 8 heads and 8 features per head, layer 1 outputs 64-dimensional vectors. Layer 2 averages its heads to produce the final class logits.
When to use GATConv
- Fraud detection. One connection to a known fraudulent account is more informative than hundreds of normal transactions. GATConv learns to up-weight suspicious edges and down-weight routine ones.
- Recommendation systems. A user's most recent interactions are typically more relevant than older ones. Attention lets the model focus on the interactions that best predict preferences.
- Heterogeneous neighbor importance. Any graph where not all edges carry equal signal benefits from learned attention. Citation networks (some references are more relevant), social networks (close friends vs acquaintances), knowledge graphs.
- When you want inspectable attention patterns. Attention weights show which neighbors the model relies on most. Note: Jain & Wallace (2019) showed that attention weights do not always correlate with feature importance, so treat them as a diagnostic signal rather than a faithful explanation.
When not to use GATConv
1. All neighbors are equally important
On homogeneous graphs where every neighbor carries similar signal (regular lattices, molecular graphs with uniform bonds), GCNConv achieves the same accuracy 2-3x faster. The attention computation adds overhead without benefit.
2. Very large graphs without sampling
GATConv computes attention per edge. On billion-edge graphs, this is expensive. Combine with NeighborLoader for sampling, or use SAGEConv which is designed for scale from the ground up.
3. The static attention problem
GATConv's scoring function applies the attention vector after a single shared linear transform, so the score decomposes into independent query and key terms. As a result, a node ranks its neighbors the same way regardless of context ("static attention"). If your task requires truly dynamic attention, where the neighbor ranking changes based on the query node, use GATv2Conv instead.
Comparison to alternatives
| Layer | Attention | Dynamic | Speed |
|---|---|---|---|
| GCNConv | None (degree only) | N/A | Fastest |
| GATConv | Learned, per edge | Static | 2-3x slower |
| GATv2Conv | Learned, per edge | Dynamic | Similar to GAT |
| TransformerConv | Full transformer | Dynamic | 3-5x slower |
How KumoRFM builds on this
GATConv introduced the key insight: not all neighbors are created equal. KumoRFM's Relational Graph Transformer takes this further:
- Type-aware attention that learns separate attention patterns for each edge type (purchases, reviews, returns), not one shared attention function
- Temporal attention that weighs recent interactions higher than old ones, without manual feature engineering of time windows
- Dynamic attention (like GATv2Conv) where neighbor ranking always depends on the query context
- Scalable attention via sampling, combining GAT's expressiveness with SAGEConv's efficiency