
Graph Attention Network (GAT): The Architecture That Learns Which Neighbors Matter

GCN treats all neighbors equally. GAT learns that some neighbors matter more. Through learned attention weights, each node dynamically decides how much to listen to each of its neighbors, producing task-adaptive, interpretable aggregation.


TL;DR

  • GAT computes learned attention weights for each edge. Instead of degree-normalized averaging (GCN), each neighbor's contribution is weighted by a learned score that depends on both source and target features.
  • Attention mechanism: project both node features, concatenate, pass through LeakyReLU, normalize via softmax. The result is a per-edge weight in [0, 1] that sums to 1 across all neighbors.
  • Multi-head attention runs K independent attention functions in parallel. Each head may focus on different neighbor aspects (feature similarity, degree, edge type). Outputs are concatenated or averaged.
  • GAT provides built-in interpretability: attention weights show which edges the model considers most important for each prediction. Useful for fraud investigation, clinical decision support, and debugging.
  • GATv2 fixes a known expressiveness limitation: the original GAT applies the attention vector before the nonlinearity, limiting the functions it can express. GATv2 applies the nonlinearity first, enabling strictly more expressive attention.

The Graph Attention Network (GAT) introduced learned attention to graph neural networks. Where GCN weights all neighbors by a fixed function of node degrees, GAT learns which neighbors matter for each specific node and task. This makes the aggregation step adaptive: the same node in the same graph produces different attention patterns for fraud detection vs recommendation, because the model learns which connections carry signal for each task.

The attention mechanism

For each edge from neighbor j to target node i, GAT computes an attention weight in four steps:

gat_attention.py
import torch
from torch_geometric.nn import GATConv

# GATConv with 8 attention heads
conv = GATConv(
    in_channels=64,
    out_channels=8,    # per-head output dimension
    heads=8,           # number of attention heads
    concat=True,       # concatenate heads (intermediate layers)
)
# Output dimension: 8 * 8 = 64

# Forward pass
x_new = conv(x, edge_index)
# x_new: [num_nodes, 64]

# Access attention weights (for explainability)
x_new, (edge_index_out, attention_weights) = conv(
    x, edge_index, return_attention_weights=True
)
# attention_weights: [num_edges, 8] (one weight per head per edge)

PyG's GATConv implements multi-head attention with a single line. return_attention_weights=True provides interpretable edge-level importance scores.

Step 1: linear projection

Both source and target node features are projected through a shared linear layer W. This transforms features into the attention space.

Step 2: attention score

The projected features of source and target are concatenated and passed through a learnable attention vector a followed by LeakyReLU. This produces a scalar score e_ij for each edge.

Step 3: normalization

Softmax is applied across all neighbors of node i: the scores are normalized to sum to 1. This produces the final attention weight alpha_ij.

Step 4: weighted aggregation

Node i's updated embedding is the weighted sum of its neighbors' projected features, using the attention weights.
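The four steps can be sketched in plain PyTorch for a single target node. This is a minimal toy illustration, not PyG's actual implementation: the dimensions, the random parameters, and the three-neighbor setup are all hypothetical, and `GATConv` computes this for every edge in the graph at once.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy setup (hypothetical sizes): node i with 3 neighbors,
# input feature dim 4, projected dim 2
W = torch.randn(4, 2)    # shared linear projection (step 1)
a = torch.randn(4)       # learnable attention vector over [Wh_i || Wh_j]
h_i = torch.randn(4)     # target node features
h_j = torch.randn(3, 4)  # one row per neighbor

# Step 1: project source and target through the shared layer W
z_i = h_i @ W            # [2]
z_j = h_j @ W            # [3, 2]

# Step 2: concatenate target with each neighbor, dot with a, LeakyReLU
# -> one scalar score e_ij per edge
e = F.leaky_relu(
    torch.cat([z_i.expand(3, -1), z_j], dim=1) @ a,
    negative_slope=0.2,
)                        # [3]

# Step 3: softmax across the neighbors of i -> alpha_ij sums to 1
alpha = torch.softmax(e, dim=0)

# Step 4: weighted sum of projected neighbor features
h_i_new = (alpha.unsqueeze(1) * z_j).sum(dim=0)  # [2]
```

Note that the softmax normalizes per target node: each node's incoming attention weights sum to 1 independently of every other node.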

Multi-head attention

A single attention head might focus on only one aspect of neighbor relevance. Multi-head attention runs K independent attention functions, each with its own parameters:

  • Head 1 might learn to attend to neighbors with similar features
  • Head 2 might learn to attend to high-degree hub neighbors
  • Head 3 might learn to attend to recently connected neighbors

In intermediate layers, head outputs are concatenated (output dimension = heads x per_head_dim). In the final layer, outputs are averaged (output dimension = per_head_dim). This multi-perspective approach produces more robust and informative embeddings.
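The concatenate-vs-average rule determines each layer's output dimension. A small sketch with made-up per-head outputs (the values are random placeholders; in a real GAT each head would run its own attention pass):

```python
import torch

num_nodes, K, d = 10, 8, 8  # hypothetical: 8 heads, 8 dims per head

# Stand-in for the K per-head outputs of one GAT layer
head_outputs = [torch.randn(num_nodes, d) for _ in range(K)]

# Intermediate layers: concatenate heads -> [num_nodes, K * d]
h_concat = torch.cat(head_outputs, dim=1)   # [10, 64]

# Final layer: average heads -> [num_nodes, d]
h_avg = torch.stack(head_outputs, dim=0).mean(dim=0)  # [10, 8]
```

In PyG, the same switch is the `concat` argument of `GATConv`: `concat=True` for intermediate layers, `concat=False` (averaging) for the output layer.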

GATv2: fixing the expressiveness limitation

Brody et al. (2021) showed that the original GAT has a limited form of attention: the attention function is “static” in that the ranking of attention scores does not change based on the query node's features. GATv2 fixes this by changing the order of operations:

  • GAT: e = LeakyReLU(a^T [Wh_i || Wh_j]) (the attention vector a is applied to the concatenation of independently transformed features, before the nonlinearity)
  • GATv2: e = a^T LeakyReLU(W [h_i || h_j]) (features are concatenated and jointly transformed first; a is applied after the nonlinearity)

GATv2 is strictly more expressive: it can compute any attention function that GAT can, plus additional functions that GAT cannot. Use GATv2Conv in PyG.

When to use GAT

  • Heterogeneous neighbor importance: fraud networks (suspicious vs normal connections), social networks (influential vs casual friends)
  • Interpretability requirements: attention weights show which edges drove each prediction, useful for explainability
  • Dynamic relevance: when which neighbors matter changes depending on the target node's features or the prediction task

When GCN suffices

  • Uniform neighbor importance: molecular graphs where all bonds are structurally significant
  • Computational efficiency: GAT adds O(|E| x d) computation for attention scores per layer
  • Maximum expressiveness needed: GINConv (sum without weighting) is provably more expressive for graph isomorphism testing than attention-weighted aggregation

Frequently asked questions

What is a Graph Attention Network?

A Graph Attention Network (GAT) is a GNN layer that computes learned attention weights for each edge during message passing. Instead of treating all neighbors equally (like GCN), GAT learns which neighbors are more important for each target node. The attention weight is computed from both the source and target node features using a small neural network, then normalized via softmax across all neighbors.

How does GAT compute attention weights?

For each edge (j -> i): (1) project both node features through a shared linear layer, (2) concatenate the projected source and target features, (3) pass through a learnable attention vector and LeakyReLU activation to get a scalar score, (4) apply softmax across all neighbors of i to normalize scores to sum to 1. The resulting attention weight determines how much of neighbor j's message contributes to node i's updated embedding.

What is multi-head attention in GAT?

Multi-head attention runs K independent attention mechanisms (heads) in parallel, each with its own set of learnable parameters. Each head may learn to attend to different aspects: one head might focus on feature similarity, another on degree, another on edge type. In intermediate layers, head outputs are concatenated. In the final layer, they are averaged. GATv2 fixes a known expressiveness limitation of the original GAT attention mechanism.

When should you use GAT over GCN?

Use GAT when: (1) neighbors vary in importance (fraud networks, social networks with influential vs casual connections), (2) interpretability matters (attention weights show which edges the model relies on), (3) the graph is heterogeneous (different edge types have different relevance). Use GCN when all neighbors are equally important (regular graphs, molecular bonds) or when computational efficiency is the priority.

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.