Link prediction is the task of predicting whether an edge should exist between two nodes in a graph, either to fill in missing edges in an incomplete graph or to forecast edges that will appear in the future. A GNN computes embeddings for each node through message passing, then scores candidate node pairs based on embedding similarity. Pairs with high scores are predicted as likely links. This task underpins recommendation systems, fraud network discovery, knowledge graph completion, and relationship forecasting.
Why it matters for enterprise data
Many high-value enterprise predictions are link prediction problems in disguise:
- Product recommendation: “Will customer X buy product Y?” = predicting a customer-product edge in a purchase graph.
- Fraud network discovery: “Are these two accounts connected through shell companies?” = predicting hidden edges in a financial network.
- Supply chain optimization: “Should factory X source from supplier Y?” = predicting edges in a supply chain graph.
- Cross-sell / up-sell: “Will this customer upgrade to premium?” = predicting a customer-tier edge.
Link prediction on relational enterprise graphs captures multi-hop patterns that collaborative filtering misses. A GNN-based product recommendation considers not just which products the customer bought, but also which products similar customers bought, which categories are trending, and which products are frequently co-purchased.
How link prediction works
Step 1: Compute node embeddings
A GNN (e.g., PyG's GCNConv, GATConv, or SAGEConv layers) computes embeddings for all nodes via message passing over the training edges.
Step 2: Score candidate pairs
For a candidate pair (u, v), compute a link score:
- Dot product: score = z_u^T * z_v. Simple, fast, works well.
- MLP decoder: score = MLP(concat(z_u, z_v)). More expressive.
- Distance-based: score = -||z_u - z_v||. Closer embeddings = higher score.
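The three decoders above can be sketched in plain PyTorch. This is a minimal illustration; the function and class names (`dot_product_score`, `MLPDecoder`, `distance_score`) are hypothetical, not part of any library:

```python
import torch

def dot_product_score(z_u, z_v):
    # Elementwise product summed over the feature dim: z_u^T z_v per pair
    return (z_u * z_v).sum(dim=-1)

class MLPDecoder(torch.nn.Module):
    # Scores the concatenated pair embedding with a small feed-forward net
    def __init__(self, dim, hidden=32):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(2 * dim, hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, 1),
        )

    def forward(self, z_u, z_v):
        return self.net(torch.cat([z_u, z_v], dim=-1)).squeeze(-1)

def distance_score(z_u, z_v):
    # Negative Euclidean distance: closer embeddings score higher
    return -(z_u - z_v).norm(dim=-1)
```

All three take batches of source and destination embeddings of shape [num_pairs, dim] and return one score per pair; the MLP decoder trades speed for expressiveness.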
Step 3: Train with negative sampling
Positive examples are real edges. Negative examples are randomly sampled non-edges. Train with binary cross-entropy to push connected nodes' embeddings together and unconnected nodes' embeddings apart.
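Uniform negative sampling can be sketched in plain PyTorch as follows. The helper name is hypothetical and the rejection loop is only suitable for sparse graphs; PyG's `torch_geometric.utils.negative_sampling` provides a vectorized version:

```python
import torch

def sample_negative_edges(edge_index, num_nodes, num_samples):
    # Uniformly sample node pairs, rejecting self-loops and real edges
    existing = set(zip(edge_index[0].tolist(), edge_index[1].tolist()))
    neg = []
    while len(neg) < num_samples:
        u = torch.randint(num_nodes, (1,)).item()
        v = torch.randint(num_nodes, (1,)).item()
        if u != v and (u, v) not in existing:
            neg.append((u, v))
    return torch.tensor(neg).t()  # shape [2, num_samples]
```

Concatenating the sampled negatives with the real edges, with labels 0 and 1 respectively, yields the supervision signal for the binary cross-entropy loss.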
import torch
from torch_geometric.nn import GCNConv
from torch_geometric.transforms import RandomLinkSplit

# Split edges into train/val/test, sampling negatives for training
transform = RandomLinkSplit(num_val=0.1, num_test=0.1,
                            add_negative_train_samples=True)
train_data, val_data, test_data = transform(data)

class LinkPredictor(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)

    def encode(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)

    def decode(self, z, edge_label_index):
        # Dot product decoder: score = z_u^T z_v
        src, dst = edge_label_index
        return (z[src] * z[dst]).sum(dim=-1)

    def forward(self, x, edge_index, edge_label_index):
        z = self.encode(x, edge_index)
        return self.decode(z, edge_label_index)

model = LinkPredictor(data.num_features, 64)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(100):
    model.train()
    optimizer.zero_grad()
    scores = model(train_data.x, train_data.edge_index,
                   train_data.edge_label_index)
    loss = torch.nn.functional.binary_cross_entropy_with_logits(
        scores, train_data.edge_label.float())
    loss.backward()
    optimizer.step()

Complete link prediction pipeline. RandomLinkSplit handles train/val/test splitting and negative sampling; the dot product decoder scores candidate pairs from node embeddings.
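Held-out performance is usually reported as ROC-AUC over the validation or test positives and negatives. A minimal, tie-unaware rank-sum sketch (the `roc_auc` helper is illustrative; in practice `sklearn.metrics.roc_auc_score` is the common choice):

```python
import torch

@torch.no_grad()
def roc_auc(scores, labels):
    # AUC = probability a random positive outscores a random negative,
    # via the Mann-Whitney U statistic (ties are ignored in this sketch)
    order = scores.argsort()
    ranks = torch.empty_like(order, dtype=torch.float)
    ranks[order] = torch.arange(1, len(scores) + 1, dtype=torch.float)
    pos = labels.bool()
    n_pos, n_neg = pos.sum().item(), (~pos).sum().item()
    u = ranks[pos].sum().item() - n_pos * (n_pos + 1) / 2
    return u / (n_pos * n_neg)
```

Applied to the pipeline above, this would score `val_data.edge_label_index` against `val_data.edge_label`; a score of 1.0 means every true edge outranks every sampled non-edge.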
Concrete example: product recommendations
An e-commerce company wants to recommend products:
- Customer nodes: 1M customers with features [age, location, tenure]
- Product nodes: 100K products with features [price, category, rating]
- Purchase edges: 10M historical purchases (customer bought product)
The GNN learns embeddings where customers and their purchased products are close in embedding space. At inference, for customer Alice, the system scores all 100K products by dot product with Alice's embedding and returns the top 10 highest-scoring unowned products as recommendations.
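The inference step reduces to a matrix-vector product followed by a top-k lookup. A minimal sketch with a hypothetical helper name, assuming precomputed embeddings and a boolean mask of already-purchased products:

```python
import torch

def recommend_top_k(customer_emb, product_emb, owned_mask, k=10):
    # Score every product by dot product with the customer embedding,
    # mask out already-purchased products, and return the top-k indices
    scores = product_emb @ customer_emb            # shape [num_products]
    scores = scores.masked_fill(owned_mask, float("-inf"))
    return scores.topk(k).indices
```

For 100K products this is a single matmul per customer, cheap enough to run online; at larger scale, approximate nearest-neighbor indexes are typically used instead of exhaustive scoring.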
The 2-hop neighborhood gives each customer information about products bought by similar customers (collaborative signal) and products in the same categories they prefer (content signal), unified in a single model.
Limitations and what comes next
- Scalability of negative sampling: For a graph with 1M nodes, there are ~10^12 possible edges. Sampling informative negatives (hard negatives) is critical for training efficiency and quality.
- Temporal validity: Link prediction on a static snapshot does not respect time. A model might predict that customer A will buy product B, but product B was discontinued yesterday. Temporal link prediction requires time-aware training splits.
- Cold start: New nodes with few connections have poor embeddings, leading to poor link predictions. This is the classic cold-start problem in recommendation systems.
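The temporal-validity point above amounts to splitting edges by timestamp rather than at random, so the model trains only on edges that existed before the evaluation window. A minimal sketch, assuming each edge carries a float timestamp (the helper name is hypothetical):

```python
import torch

def temporal_split(edge_index, edge_time, train_frac=0.8):
    # Train on the oldest edges and evaluate on the newest, so the
    # model never sees "future" edges at training time
    cutoff = edge_time.quantile(train_frac)
    train_mask = edge_time <= cutoff
    return edge_index[:, train_mask], edge_index[:, ~train_mask]
```

Unlike RandomLinkSplit, this guarantees every test edge is strictly newer than every training edge, which is the evaluation regime a deployed recommender actually faces.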