
Heterogeneous Attention: Attention Mechanisms for Multi-Type Graphs

In a heterogeneous graph, a customer's purchased product should get different attention weight than a viewed product. Heterogeneous attention learns type-specific attention patterns, enabling GNNs to weight relationship types by their predictive importance.


TL;DR

  • Heterogeneous attention computes attention weights that depend on source node type, target node type, and edge type. Each (source, edge, target) triplet gets its own learned attention pattern.
  • HGT (Heterogeneous Graph Transformer) uses type-specific Query, Key, and Value projections: Q depends on target type, K and V depend on source type, and a type-specific matrix modulates the attention for each edge type.
  • This captures semantic differences: a purchased product should receive higher attention than a viewed product for churn prediction, but the opposite for product recommendation.
  • More parameter-efficient than HeteroConv with many edge types, because it shares the base attention computation and only adds type-specific projection matrices.
  • In PyG, HGTConv implements heterogeneous attention. It handles arbitrary numbers of node types and edge types with a single layer definition.

Heterogeneous attention is an attention mechanism that computes type-aware weights for graphs with multiple node and edge types. Standard graph attention (GAT) scores all neighbors with the same learned parameters regardless of type. In a retail graph, this means a customer's purchased product and a merely viewed product get the same kind of attention computation. Heterogeneous attention uses different learned projections per type, enabling the model to discover that purchases matter more than views for churn prediction but views matter more for recommendation.

The HGT architecture

The Heterogeneous Graph Transformer (HGT) is the most widely used heterogeneous attention layer. For each edge from source node s (type tau_s) to target node t (type tau_t) with edge type phi:

hgt_attention.py
# HGT attention computation (simplified)
# For edge (source_type, edge_type, target_type)

# Type-specific projections
Q = W_Q[target_type] @ h_target    # Query depends on target type
K = W_K[source_type] @ h_source    # Key depends on source type
V = W_V[source_type] @ h_source    # Value depends on source type

# Type-specific attention weight
attention_logit = (Q @ W_ATT[edge_type] @ K.T) / sqrt(d)

# Multi-head attention with type-aware softmax
alpha = softmax(attention_logit)   # over all neighbors of target

# Weighted message
message = alpha * V

# In PyG:
from torch_geometric.nn import HGTConv

conv = HGTConv(
    in_channels=64,
    out_channels=64,
    metadata=data.metadata(),  # auto-detects node/edge types
    heads=4,                    # multi-head attention
)

Four type-specific matrices (W_Q, W_K, W_V, W_ATT) differentiate attention patterns per type triplet. W_Q uses target type, W_K and W_V use source type, W_ATT uses edge type.
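To make the pseudocode above concrete, here is a minimal, self-contained NumPy sketch of the type-specific attention computation for a single target node. It is a single-head simplification with randomly initialized weights: the node types, edge types, and dimensions are illustrative, and full HGT additionally includes per-edge-type message matrices, multi-head splitting, and a residual connection, all omitted here.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8  # hidden dimension (single head, for clarity)

# Illustrative type vocabularies (not from a real dataset)
node_types = ["customer", "product"]
edge_types = [("product", "purchased_by", "customer"),
              ("product", "viewed_by", "customer")]

# Type-specific parameters: W_Q per target type, W_K / W_V per source type,
# W_ATT per edge type (all d x d in this simplified sketch)
W_Q = {t: rng.standard_normal((d, d)) / np.sqrt(d) for t in node_types}
W_K = {t: rng.standard_normal((d, d)) / np.sqrt(d) for t in node_types}
W_V = {t: rng.standard_normal((d, d)) / np.sqrt(d) for t in node_types}
W_ATT = {e: rng.standard_normal((d, d)) / np.sqrt(d) for e in edge_types}

def hgt_attention(h_target, target_type, neighbors):
    """neighbors: list of (h_source, source_type, edge_type) triplets."""
    q = W_Q[target_type] @ h_target          # query depends on target type
    logits, values = [], []
    for h_src, src_type, edge_type in neighbors:
        k = W_K[src_type] @ h_src            # key/value depend on source type
        values.append(W_V[src_type] @ h_src)
        logits.append(q @ W_ATT[edge_type] @ k / np.sqrt(d))
    logits = np.array(logits)
    alpha = np.exp(logits - logits.max())
    alpha /= alpha.sum()                     # softmax over the target's neighbors
    message = sum(a * v for a, v in zip(alpha, values))
    return alpha, message

h_cust = rng.standard_normal(d)
neighbors = [
    (rng.standard_normal(d), "product", edge_types[0]),  # purchased
    (rng.standard_normal(d), "product", edge_types[1]),  # viewed
]
alpha, message = hgt_attention(h_cust, "customer", neighbors)
```

Note how the purchased and viewed neighbors pass through different W_ATT matrices even though both are products: that is the mechanism that lets the model learn different weights per relation.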

Why type-aware attention matters

Consider a financial graph with three edge types:

  • Transferred_to: Direct money flow. Strongest fraud signal. Should get high attention for fraud detection.
  • Shared_device: Same login device. Moderate fraud signal for account takeover. Medium attention.
  • Same_merchant: Shopped at the same store. Weak fraud signal. Low attention.

Standard GAT computes one attention weight per neighbor regardless of edge type. It might learn to ignore same_merchant edges, but it cannot learn different attention patterns per type. HGT can learn that for transferred_to edges, the amount is the key feature, while for shared_device edges, the timing is what matters.
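The mechanism behind this is easy to isolate. In the toy sketch below (random, untrained weights; edge-type names taken from the list above), the same pair of node embeddings produces a different attention logit for each edge type purely because of the per-edge-type W_ATT matrix, the degree of freedom that standard GAT lacks.

```python
import numpy as np

rng = np.random.default_rng(2)
d = 4

# One query/key pair: identical endpoint embeddings for every edge type
q = rng.standard_normal(d)
k = rng.standard_normal(d)

# A separate attention matrix per edge type (randomly initialized here)
W_ATT = {e: rng.standard_normal((d, d)) for e in
         ["transferred_to", "shared_device", "same_merchant"]}

# Same q and k, yet each edge type yields its own attention logit
logits = {e: float(q @ W @ k / np.sqrt(d)) for e, W in W_ATT.items()}
```

During training, gradient descent shapes each W_ATT so that, say, transferred_to logits end up systematically higher than same_merchant logits for fraud detection.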

HGT vs HeteroConv

Two approaches to heterogeneous graphs in PyG:

  • HeteroConv: Wraps a separate GNN layer per edge type. Each type gets its own complete layer (SAGEConv, GATConv, etc.). Simple and effective with 3-5 edge types. Parameters scale with the number of edge types, roughly O(T · d^2), since every relation carries a full layer.
  • HGTConv: Shared attention computation with type-specific projections. The Q, K, and V projections depend only on node type, so they are shared across every relation touching that type; each edge type adds only small per-head attention matrices. With many relations among few node types this is far more parameter-efficient, and it captures cross-type interactions that separate layers miss.

Rule of thumb: use HeteroConv for simple heterogeneous graphs (3-5 edge types) where you want maximum control. Use HGTConv for complex heterogeneous graphs (10+ edge types) where parameter efficiency and cross-type learning matter.
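A back-of-envelope count illustrates why the crossover sits around 10+ edge types. The sketch below is deliberately simplified (biases, HGT's per-edge-type message matrices, and output projections are all ignored, and exact layer definitions vary by library), so treat the numbers as scaling intuition, not real parameter counts.

```python
# Hedged parameter-count sketch: HeteroConv-with-GAT vs HGT-style sharing
d, heads = 64, 4
n_node_types = 3

def heteroconv_params(n_edge_types):
    # One full GAT-style layer per edge type:
    # a d x d weight plus ~2d attention parameters per relation.
    return n_edge_types * (d * d + 2 * d)

def hgt_params(n_edge_types):
    # Q/K/V projections per *node* type, shared across all relations...
    node_part = n_node_types * 3 * d * d
    # ...plus one (d/heads x d/heads) matrix per head per *edge* type.
    edge_part = n_edge_types * heads * (d // heads) ** 2
    return node_part + edge_part

print(heteroconv_params(5), hgt_params(5))    # few relations: HeteroConv is smaller
print(heteroconv_params(40), hgt_params(40))  # many relations: HGT is far smaller
```

The node-type part of HGT is a fixed cost, while its per-relation cost is only d^2/heads; HeteroConv pays a full ~d^2 per relation, so its total overtakes HGT as the number of edge types grows.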

Multi-head heterogeneous attention

Like standard multi-head attention, HGT uses multiple attention heads. Each head can learn a different aspect of type-dependent importance:

  • Head 1 might focus on recency across all edge types
  • Head 2 might focus on amount for transaction edges and frequency for interaction edges
  • Head 3 might focus on structural position (hub vs leaf) per node type

The heads are concatenated and linearly projected, combining multiple type-aware attention patterns into a rich representation.
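The concatenate-then-project step looks like this in isolation, a tiny NumPy sketch with toy per-head messages and a random output projection (both placeholders for learned quantities):

```python
import numpy as np

rng = np.random.default_rng(1)
d, H = 8, 2          # hidden size and number of heads
d_h = d // H         # per-head dimension

# Toy per-head aggregated messages for one target node
head_out = [rng.standard_normal(d_h) for _ in range(H)]

# Concatenate the heads back to dimension d, then apply a
# learned output projection to mix information across heads
W_O = rng.standard_normal((d, d)) / np.sqrt(d)
h_new = W_O @ np.concatenate(head_out)
```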

Frequently asked questions

What is heterogeneous attention?

Heterogeneous attention computes attention weights that depend on the types of both the source and target nodes, as well as the edge type connecting them. In a retail graph, the attention weight for a customer-purchased-product edge is computed differently than for a customer-viewed-product edge, using type-specific query, key, and value projections.

How does HGT (Heterogeneous Graph Transformer) work?

HGT uses type-specific linear projections for queries (target node type), keys (source node type), and values (source node type), plus a type-specific attention weight matrix for each edge type. This means every combination of (source type, edge type, target type) gets its own learned attention pattern.

When should you use heterogeneous attention vs HeteroConv?

HeteroConv applies separate GNN layers per edge type (e.g., separate SAGEConv per relation). Heterogeneous attention (HGT) shares attention computation across types with type-specific projections. HGT is more parameter-efficient with many edge types (10+) and captures cross-type interactions. HeteroConv is simpler and works well with few edge types (3-5).

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.