A heterogeneous graph is a graph with multiple node types and edge types. Unlike a homogeneous graph, where every node and edge is of the same type, a heterogeneous graph distinguishes between fundamentally different entities and relationships. An e-commerce graph might have user nodes, product nodes, and order nodes, connected by “purchased”, “reviewed”, and “contains” edges.
This is the natural representation of enterprise relational databases: every table becomes a node type, and every foreign key relationship becomes an edge type. Preserving this heterogeneity lets GNNs learn type-specific transformations, which typically outperforms collapsing everything into a single node type.
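The table-to-node-type, foreign-key-to-edge-type mapping can be sketched in a few lines of plain Python. The table and foreign-key names below are hypothetical illustrations, not a real schema:

```python
# Toy relational schema: table name -> column names (hypothetical)
tables = {
    "users":    ["age", "location", "account_age"],
    "products": ["price", "category", "avg_rating"],
    "orders":   ["amount", "date", "payment_method"],
}

# Each foreign key as (child_table, fk_name, parent_table)
foreign_keys = [
    ("orders", "user_id", "users"),
    ("orders", "product_id", "products"),
]

# Every table becomes a node type; every foreign key becomes an edge
# type, keyed by a (source_type, relation, destination_type) triplet.
node_types = list(tables)
edge_types = [(child, fk, parent) for child, fk, parent in foreign_keys]

print(node_types)  # ['users', 'products', 'orders']
print(edge_types)  # [('orders', 'user_id', 'users'), ('orders', 'product_id', 'products')]
```

This (source_type, relation, destination_type) triplet is exactly the key PyG uses for heterogeneous edge storage, as shown below.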
Why heterogeneity matters
Consider an e-commerce database with three tables: users, products, and orders. In a homogeneous graph, you would pad all features to the same dimension and lose the distinction between a user's age and a product's price. In a heterogeneous graph:
- Users have features: age, location, account_age
- Products have features: price, category, avg_rating
- Orders have features: amount, date, payment_method
Each type gets its own embedding layer and its own learned transformations. The model learns that “user purchases product” is a fundamentally different relationship than “order contains product.”
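The padding problem can be made concrete with a small illustration. The feature values below are hypothetical (categorical fields shown as numeric ids, plus an invented fourth product feature so the widths differ):

```python
# Hypothetical numeric features standing in for two node types
user    = [34.0, 7.0, 412.0]       # age, location_id, account_age
product = [19.99, 3.0, 4.2, 1.0]   # price, category_id, avg_rating, in_stock

# Homogeneous: pad every node to the widest feature vector. Position 0
# now means "age" for users but "price" for products, and one shared
# weight matrix must serve both meanings.
width = max(len(user), len(product))

def pad(vec):
    return vec + [0.0] * (width - len(vec))

hom = {"user": pad(user), "product": pad(product)}

# Heterogeneous: each type keeps its own dimensionality and gets its
# own transformation into a shared hidden space (faked here with
# per-type scaling in place of learned weight matrices).
transforms = {"user": lambda v: [x * 0.01 for x in v],
              "product": lambda v: [x * 0.1 for x in v]}
het = {t: transforms[t](v)
       for t, v in {"user": user, "product": product}.items()}
```

In the homogeneous version both vectors are forced to the same width; in the heterogeneous version each type keeps its native feature space and its own transformation.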
HeteroData in PyG
PyTorch Geometric represents heterogeneous graphs with the HeteroData class. Node features and edge indices are stored per-type using dictionary-style access:
```python
from torch_geometric.data import HeteroData
import torch

data = HeteroData()

# Node features by type (each type keeps its own feature dimension)
data['user'].x = torch.randn(1000, 16)     # 1000 users, 16 features
data['product'].x = torch.randn(5000, 32)  # 5000 products, 32 features
data['order'].x = torch.randn(8000, 8)     # 8000 orders, 8 features

# Edge indices by (src_type, relation, dst_type) triplets.
# Row 0 holds source-node indices and row 1 destination-node indices,
# so each row must be sampled from its own type's index range.
data['user', 'purchases', 'product'].edge_index = torch.stack([
    torch.randint(0, 1000, (15000,)),  # user indices
    torch.randint(0, 5000, (15000,)),  # product indices
])
data['user', 'places', 'order'].edge_index = torch.stack([
    torch.randint(0, 1000, (8000,)),   # user indices
    torch.randint(0, 8000, (8000,)),   # order indices
])
data['order', 'contains', 'product'].edge_index = torch.stack([
    torch.randint(0, 8000, (20000,)),  # order indices
    torch.randint(0, 5000, (20000,)),  # product indices
])

print(data)
# HeteroData(
#   user={ x=[1000, 16] },
#   product={ x=[5000, 32] },
#   order={ x=[8000, 8] },
#   (user, purchases, product)={ edge_index=[2, 15000] },
#   ...
# )
```

Each node type has its own feature dimension. Each edge type has its own connectivity matrix.
Building a heterogeneous GNN
PyG offers two approaches to build GNNs on heterogeneous graphs:
Approach 1: to_hetero() (fast prototyping)
Write a standard homogeneous model, then convert it automatically:
```python
import torch
from torch_geometric.nn import SAGEConv, to_hetero

class HomoGNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # (-1, -1) lazily infers source and destination feature sizes,
        # which differ across node types in a heterogeneous graph
        self.conv1 = SAGEConv((-1, -1), 64)
        self.conv2 = SAGEConv((-1, -1), 32)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)

model = HomoGNN()

# Convert to heterogeneous: duplicates layers per type and sums
# messages arriving over different edge types
model = to_hetero(model, data.metadata(), aggr='sum')

# Now the model accepts the per-type dictionaries directly
out = model(data.x_dict, data.edge_index_dict)
```

to_hetero() creates separate weight matrices for each node type and edge type automatically. Because the input sizes are lazy, the parameter shapes are materialized on the first forward pass.
Approach 2: HGTConv (best performance)
The Heterogeneous Graph Transformer uses type-specific attention heads and is generally the strongest layer for heterogeneous data:
```python
from torch_geometric.nn import HGTConv, Linear

class HGT(torch.nn.Module):
    def __init__(self, metadata, hidden_dim=64, num_layers=2):
        super().__init__()
        # One input projection per node type, mapping each type's own
        # feature dimension into the shared hidden space
        self.lin_dict = torch.nn.ModuleDict()
        for node_type in metadata[0]:
            self.lin_dict[node_type] = Linear(-1, hidden_dim)

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.convs.append(HGTConv(hidden_dim, hidden_dim, metadata))

    def forward(self, x_dict, edge_index_dict):
        x_dict = {k: self.lin_dict[k](x).relu()
                  for k, x in x_dict.items()}
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
        return x_dict
```

HGTConv learns type-specific attention patterns. It knows that user-purchases-product edges carry different semantics than order-contains-product edges.
Enterprise example: churn prediction
A SaaS company wants to predict which customers will churn. Their database has:
- Customers: plan_tier, signup_date, company_size
- Support tickets: severity, resolution_time, satisfaction_score
- Features used: feature_name, usage_frequency, last_used
- Invoices: amount, paid_on_time, discount_applied
Represented as a heterogeneous graph, this database lets message passing inform a customer node from its support tickets (high-severity tickets with low satisfaction scores), its feature usage patterns (declining usage of core features), and its payment history (increasing late payments). Each signal travels along its own typed edge with type-specific learned weights.
A flat feature table would require a data engineer to manually aggregate each relationship: “count of high-severity tickets in last 30 days,” “average feature usage decline over 90 days.” The heterogeneous GNN discovers these patterns automatically from the raw relational structure.
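The typed-aggregation idea behind this can be sketched without any GNN machinery. Everything below is hypothetical (relation names, neighbor values); it shows one round of per-relation message passing where a learned model would apply relation-specific weights instead of a plain mean:

```python
# Hypothetical edge types for the churn graph
edge_types = [
    ("customer", "opened",    "ticket"),
    ("customer", "uses",      "feature"),
    ("customer", "billed_by", "invoice"),
]

# Toy neighbors of one customer, grouped by relation
neighbors = {
    "opened":    [{"severity": 3, "satisfaction": 1.0},
                  {"severity": 5, "satisfaction": 0.2}],
    "uses":      [{"usage_frequency": 0.9}, {"usage_frequency": 0.1}],
    "billed_by": [{"paid_on_time": 0}, {"paid_on_time": 1}],
}

def mean(xs):
    return sum(xs) / len(xs)

# Per-relation aggregation keeps each signal separate, instead of
# pooling tickets, features, and invoices into one neighbor set
customer_message = {
    "avg_ticket_severity": mean([n["severity"] for n in neighbors["opened"]]),
    "avg_feature_usage":   mean([n["usage_frequency"] for n in neighbors["uses"]]),
    "on_time_rate":        mean([n["paid_on_time"] for n in neighbors["billed_by"]]),
}
```

In a trained heterogeneous GNN, the per-relation mean is replaced by learned, relation-specific transformations, and the aggregates are themselves learned rather than hand-specified.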