The Graph Attention Network (GAT) introduced learned attention to graph neural networks. Where GCN weights all neighbors by a fixed function of node degrees, GAT learns which neighbors matter for each specific node and task. This makes the aggregation step adaptive: the same node in the same graph produces different attention patterns for fraud detection vs recommendation, because the model learns which connections carry signal for each task.
The attention mechanism
For each edge from neighbor j to target node i, GAT computes an attention weight in four steps. In PyG, the whole mechanism is available through GATConv:
import torch
from torch_geometric.nn import GATConv

# Toy inputs: 100 nodes with 64 features, 500 random edges
x = torch.randn(100, 64)
edge_index = torch.randint(0, 100, (2, 500))

# GATConv with 8 attention heads
conv = GATConv(
    in_channels=64,
    out_channels=8,  # per-head output dimension
    heads=8,         # number of attention heads
    concat=True,     # concatenate heads (intermediate layers)
)
# Output dimension: 8 * 8 = 64

# Forward pass
x_new = conv(x, edge_index)
# x_new: [num_nodes, 64]

# Access attention weights (for explainability)
x_new, (edge_index_out, attention_weights) = conv(
    x, edge_index, return_attention_weights=True
)
# attention_weights: [num_edges + num_nodes, 8] -- one weight per head per
# edge, including the self-loops GATConv adds by default

PyG's GATConv implements multi-head attention with a single layer, and return_attention_weights=True provides interpretable edge-level importance scores.
Step 1: linear projection
Both source and target node features are projected through a shared linear layer W. This transforms features into the attention space.
Step 2: attention score
The projected features of source and target are concatenated and passed through a learnable attention vector a followed by LeakyReLU. This produces a scalar score e_ij for each edge.
Step 3: normalization
Softmax is applied across all neighbors of node i: the scores are normalized to sum to 1. This produces the final attention weight alpha_ij.
Step 4: weighted aggregation
Node i's updated embedding is the weighted sum of its neighbors' projected features, using the attention weights.
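The four steps above can be sketched as a single-head forward pass in plain PyTorch. This is an illustrative sketch, not PyG's actual implementation: the function name and the manual index_add_-based softmax are assumptions made for clarity.

```python
import torch
import torch.nn.functional as F

def gat_single_head(x, edge_index, W, a, negative_slope=0.2):
    """Single attention head, illustrating GAT's four steps.

    x: [num_nodes, in_dim], edge_index: [2, num_edges] (row 0 = source j,
    row 1 = target i), W: [in_dim, out_dim], a: [2 * out_dim].
    """
    src, dst = edge_index  # edge j -> i: src = neighbor j, dst = target i

    # Step 1: linear projection of all node features
    h = x @ W  # [num_nodes, out_dim]

    # Step 2: raw attention score per edge: LeakyReLU(a^T [Wh_i || Wh_j])
    e = F.leaky_relu(torch.cat([h[dst], h[src]], dim=-1) @ a, negative_slope)

    # Step 3: softmax over each target node's incoming edges
    exp_e = (e - e.max()).exp()  # shift for numerical stability
    denom = torch.zeros(x.size(0)).index_add_(0, dst, exp_e)
    alpha = exp_e / denom[dst]  # [num_edges], sums to 1 per target node

    # Step 4: weighted sum of projected neighbor features
    out = torch.zeros(x.size(0), h.size(1))
    out.index_add_(0, dst, alpha.unsqueeze(-1) * h[src])
    return out, alpha
```

Note that the softmax in step 3 is per target node, not global: each node's incoming attention weights form their own probability distribution.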
Multi-head attention
A single attention head might focus on only one aspect of neighbor relevance. Multi-head attention runs K independent attention functions, each with its own parameters:
- Head 1 might learn to attend to neighbors with similar features
- Head 2 might learn to attend to high-degree hub neighbors
- Head 3 might learn to attend to recently connected neighbors
In intermediate layers, head outputs are concatenated (output dimension = heads x per_head_dim). In the final layer, outputs are averaged (output dimension = per_head_dim). This multi-perspective approach produces more robust and informative embeddings.
GATv2: fixing the expressiveness limitation
Brody et al. (2021) showed that the original GAT has a limited form of attention: the attention function is “static” in that the ranking of attention scores does not change based on the query node's features. GATv2 fixes this by changing the order of operations:
- GAT: attention = LeakyReLU(a^T [Wh_i || Wh_j]) (attention vector applied to the concatenation of independently transformed features, with LeakyReLU outside)
- GATv2: attention = a^T LeakyReLU(W[h_i || h_j]) (features concatenated first, then jointly transformed)
GATv2 is strictly more expressive: it can compute any attention function that GAT can, plus additional functions that GAT cannot. Use GATv2Conv in PyG.
When to use GAT
- Heterogeneous neighbor importance: fraud networks (suspicious vs normal connections), social networks (influential vs casual friends)
- Interpretability requirements: attention weights show which edges drove each prediction, useful for explainability
- Dynamic relevance: when which neighbors matter changes depending on the target node's features or the prediction task
When GCN suffices
- Uniform neighbor importance: molecular graphs where all bonds are structurally significant
- Computational efficiency: GAT adds O(|E| x d) attention-score computation per layer that GCN's fixed degree-based normalization avoids
- Maximum expressiveness needed: GINConv (sum without weighting) is provably more expressive for graph isomorphism testing than attention-weighted aggregation