What HGTConv does
HGTConv applies transformer attention with type-specific projections. For each target node:
- Project the target node's features into a query using a type-specific W_Q
- For each neighbor, project features into key and value using source-type-specific W_K and W_V
- Compute attention scores using an edge-type-specific attention function
- Aggregate neighbor values weighted by attention
The key difference from TransformerConv: every projection depends on the types involved. A “User” query uses different parameters than a “Product” query. A “purchases” edge computes attention differently than a “reviews” edge.
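The idea of type-specific parameters can be sketched in a few lines: one projection matrix per node type, selected by indexing a dictionary with the node's type (a minimal illustration, not PyG's internals; the names `W_Q`, `h_user` are hypothetical):

```python
import torch

# One query projection per node type, keyed by type name.
d = 8
node_types = ['user', 'product']
W_Q = {t: torch.nn.Linear(d, d, bias=False) for t in node_types}

h_user = torch.randn(1, d)   # features of a 'user' target node
Q = W_Q['user'](h_user)      # query uses the user-specific W_Q
# Q has shape (1, 8); a 'product' node would use W_Q['product']
```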
The math (simplified)
# Type-specific projections
Q_i = W_Q[type(i)] · h_i    # query depends on target type
K_j = W_K[type(j)] · h_j    # key depends on source type
V_j = W_V[type(j)] · h_j    # value depends on source type

# Type-specific attention over neighbors j
alpha_ij = softmax_j( Q_i^T · W_ATT[type(e_ij)] · K_j / sqrt(d) )

# Weighted aggregation with edge-type-specific message
h_i' = Σ_j alpha_ij · W_MSG[type(e_ij)] · V_j
Where:
type(i) = node type of i (User, Product, etc.)
type(e_ij) = edge type connecting i and j (purchases, reviews, etc.)
W_ATT, W_MSG = edge-type-specific matrices

Every projection and attention computation is conditioned on node types and edge types. This is the most type-aware attention mechanism among standard PyG layers.
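The simplified math above can be executed directly for one target node and one edge type. This sketch uses plain tensors with random weights (all dimensions and the three-neighbor setup are illustrative assumptions, not PyG code):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d = 8  # hidden dimension

# One target node i (e.g. type 'user') with 3 neighbors j (e.g. type
# 'product'), all connected by a single edge type (e.g. 'purchases').
h_i = torch.randn(d)
h_j = torch.randn(3, d)

# Stand-ins for W_Q[type(i)], W_K[type(j)], W_V[type(j)],
# W_ATT[type(e_ij)], W_MSG[type(e_ij)]
W_Q, W_K, W_V = (torch.randn(d, d) for _ in range(3))
W_ATT, W_MSG = (torch.randn(d, d) for _ in range(2))

Q_i = W_Q @ h_i        # query: target-type projection
K_j = h_j @ W_K.T      # keys: source-type projection
V_j = h_j @ W_V.T      # values: source-type projection

# Edge-type-specific attention: Q_i^T · W_ATT · K_j / sqrt(d), per j
scores = (K_j @ W_ATT.T) @ Q_i / d ** 0.5
alpha = F.softmax(scores, dim=0)   # normalize over neighbors j

# Weighted aggregation with the edge-type-specific message matrix
h_i_new = (alpha.unsqueeze(1) * (V_j @ W_MSG.T)).sum(dim=0)
# h_i_new has shape (8,) and alpha sums to 1 over the neighbors
```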
PyG implementation
import torch
import torch.nn.functional as F
from torch_geometric.nn import HGTConv, Linear
class HGT(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels, num_heads,
                 num_layers, metadata):
        super().__init__()
        # metadata = (node_types, edge_types) from HeteroData
        self.lin_dict = torch.nn.ModuleDict()
        for node_type in metadata[0]:
            self.lin_dict[node_type] = Linear(-1, hidden_channels)
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.convs.append(HGTConv(hidden_channels, hidden_channels,
                                      metadata, heads=num_heads))
        self.out = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        # Project each node type to shared hidden dim
        x_dict = {k: self.lin_dict[k](v) for k, v in x_dict.items()}
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
        return self.out(x_dict['target_node_type'])
# With PyG HeteroData
from torch_geometric.data import HeteroData
data = HeteroData()
data['user'].x = user_features
data['product'].x = product_features
data['user', 'purchases', 'product'].edge_index = purchase_edges
data['user', 'reviews', 'product'].edge_index = review_edges
model = HGT(hidden_channels=64, out_channels=num_classes,
            num_heads=4, num_layers=2, metadata=data.metadata())

HGTConv works with PyG's HeteroData format. The metadata tuple (node_types, edge_types) is extracted automatically. Each node type can have different input dimensions.
When to use HGTConv
- Enterprise relational data. Databases with multiple tables (customers, orders, products, merchants) connected by different relationships. HGTConv is designed exactly for this structure.
- Complex heterogeneous graphs. Graphs with 3+ node types and 3+ edge types where different relationships have fundamentally different semantics.
- When you need both types and attention. RGCNConv has types but no attention. GATConv has attention but no types. HGTConv has both.
- Academic knowledge graph tasks. Node classification and link prediction on heterogeneous academic graphs (papers, authors, venues, topics).
When not to use HGTConv
- Homogeneous graphs. If all nodes and edges are the same type, HGTConv adds unnecessary complexity. Use GATConv or TransformerConv instead.
- Many types with limited data per type. With 100+ node types, the per-type projections may have too few training examples per parameter. Consider HeteroConv with shared base layers.
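One shared-base alternative for the many-types regime can be sketched as a single shared projection plus a small learned per-type embedding, rather than a full weight matrix per type (a hypothetical design sketch in plain torch, not a PyG API):

```python
import torch

# Shared projection + per-type embedding: far fewer parameters than
# 100 separate weight matrices, while still encoding type identity.
num_types, d_in, d_hid = 100, 16, 64
shared = torch.nn.Linear(d_in, d_hid)
type_emb = torch.nn.Embedding(num_types, d_hid)

x = torch.randn(10, d_in)                   # 10 nodes, all of type 7
t = torch.full((10,), 7, dtype=torch.long)  # their type indices
h = shared(x) + type_emb(t)                 # type info added, shape (10, 64)
```

The per-type cost drops from O(d_in · d_hid) parameters to O(d_hid), which matters when many types have only a handful of training nodes.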