Link Prediction: Predicting Missing or Future Edges in a Graph

Link prediction answers the question: should an edge exist between these two nodes? It is the foundation of recommendation systems, fraud discovery, and knowledge graph completion, and it maps naturally to enterprise relational data problems.

PyTorch Geometric

TL;DR

  • Link prediction predicts whether an edge should exist between two nodes. It scores candidate pairs using node embedding similarity (dot product, MLP, or distance-based).
  • Enterprise applications: product recommendation (will customer buy product?), fraud discovery (are accounts secretly linked?), supply chain optimization (should we partner with supplier?), knowledge graph completion.
  • Training uses positive edges (real links) and sampled negative edges (non-links) with binary cross-entropy. The model learns embeddings where connected nodes are close and unconnected nodes are far apart.
  • GNN-based link prediction outperforms collaborative filtering because it leverages multi-hop graph structure, not just direct user-item interactions.
  • In PyG: compute node embeddings with a GNN, score pairs with dot product, train with negative sampling. Use torch_geometric.transforms.RandomLinkSplit for proper evaluation splits.

Link prediction is the task of predicting whether an edge should exist between two nodes in a graph, either to fill in missing edges in an incomplete graph or to forecast edges that will appear in the future. A GNN computes embeddings for each node through message passing, then scores candidate node pairs based on embedding similarity. Pairs with high scores are predicted as likely links. This task underpins recommendation systems, fraud network discovery, knowledge graph completion, and relationship forecasting.

Why it matters for enterprise data

Many high-value enterprise predictions are link prediction problems in disguise:

  • Product recommendation: “Will customer X buy product Y?” = predicting a customer-product edge in a purchase graph.
  • Fraud network discovery: “Are these two accounts connected through shell companies?” = predicting hidden edges in a financial network.
  • Supply chain optimization: “Should factory X source from supplier Y?” = predicting edges in a supply chain graph.
  • Cross-sell / up-sell: “Will this customer upgrade to premium?” = predicting a customer-tier edge.

Link prediction on relational enterprise graphs captures multi-hop patterns that collaborative filtering misses. A product recommendation considers not just which products the customer bought, but which products similar customers bought, which categories are trending, and which products are frequently co-purchased.

How link prediction works

Step 1: Compute node embeddings

A GNN (GCNConv, GATConv, SAGEConv) computes embeddings for all nodes via message passing on the training edges.

Step 2: Score candidate pairs

For a candidate pair (u, v), compute a link score:

  • Dot product: score = z_u^T * z_v. Simple, fast, works well.
  • MLP decoder: score = MLP(concat(z_u, z_v)). More expressive.
  • Distance-based: score = -||z_u - z_v||. Closer embeddings = higher score.
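
The three decoders above can be sketched in a few lines of PyTorch. This is a toy sketch with made-up embeddings (the tensor values and the small MLP architecture are illustrative, not prescribed by PyG):

```python
import torch

# Toy setup (hypothetical values): 4 nodes with 16-dim embeddings,
# scoring two candidate pairs (0, 1) and (2, 3).
torch.manual_seed(0)
z = torch.randn(4, 16)
src = torch.tensor([0, 2])
dst = torch.tensor([1, 3])

# Dot product: simple, fast, no extra parameters.
dot_score = (z[src] * z[dst]).sum(dim=-1)

# MLP decoder: concatenate both embeddings, let a small net learn the score.
mlp = torch.nn.Sequential(
    torch.nn.Linear(32, 16), torch.nn.ReLU(), torch.nn.Linear(16, 1))
mlp_score = mlp(torch.cat([z[src], z[dst]], dim=-1)).squeeze(-1)

# Distance-based: closer embeddings give a higher (less negative) score.
dist_score = -(z[src] - z[dst]).norm(dim=-1)
```

All three produce one scalar per candidate pair; the dot product is a strong default, and the MLP is worth trying when node features interact non-linearly.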

Step 3: Train with negative sampling

Positive examples are real edges. Negative examples are randomly sampled non-edges. Train with binary cross-entropy to push connected nodes' embeddings together and unconnected nodes' embeddings apart.

link_prediction.py
import torch
from torch_geometric.nn import GCNConv
from torch_geometric.transforms import RandomLinkSplit

# Split edges into train/val/test
transform = RandomLinkSplit(num_val=0.1, num_test=0.1,
                            add_negative_train_samples=True)
train_data, val_data, test_data = transform(data)

class LinkPredictor(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)

    def encode(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)

    def decode(self, z, edge_label_index):
        # Dot product decoder
        src, dst = edge_label_index
        return (z[src] * z[dst]).sum(dim=-1)

    def forward(self, x, edge_index, edge_label_index):
        z = self.encode(x, edge_index)
        return self.decode(z, edge_label_index)

model = LinkPredictor(data.num_features, 64)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(100):
    model.train()
    optimizer.zero_grad()  # clear gradients from the previous step
    scores = model(train_data.x, train_data.edge_index,
                   train_data.edge_label_index)
    loss = torch.nn.functional.binary_cross_entropy_with_logits(
        scores, train_data.edge_label.float())
    loss.backward()
    optimizer.step()

Complete link prediction pipeline. RandomLinkSplit handles train/test splitting. The dot product decoder scores candidate pairs from node embeddings.
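
For evaluation on the held-out val_data or test_data edges, AUC is the usual first metric (sklearn's roc_auc_score works directly on the raw logits, since AUC only depends on ranking). AUC also has a simple pairwise meaning, the probability that a random positive edge outscores a random negative, which this dependency-free sketch with made-up scores computes directly:

```python
import torch

# Hypothetical model scores on held-out edges; edge_label marks
# positives (real edges, 1) and sampled negatives (0).
edge_label = torch.tensor([1, 1, 1, 0, 0, 0])
scores = torch.tensor([2.1, 0.1, 1.5, -0.3, 0.8, -1.1])

# AUC = probability that a random positive edge outscores a random negative.
pos = scores[edge_label == 1]
neg = scores[edge_label == 0]
auc = (pos.unsqueeze(1) > neg.unsqueeze(0)).float().mean().item()
# Here 8 of the 9 positive/negative pairs are ordered correctly: AUC ~ 0.889
```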

Concrete example: product recommendations

An e-commerce company wants to recommend products:

  • Customer nodes: 1M customers with features [age, location, tenure]
  • Product nodes: 100K products with features [price, category, rating]
  • Purchase edges: 10M historical purchases (customer bought product)

The GNN learns embeddings where customers and their purchased products are close in embedding space. At inference, for customer Alice, the system scores all 100K products by dot product with Alice's embedding and returns the top 10 highest-scoring unowned products as recommendations.
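
The inference step reduces to a matrix-vector product plus top-k. A minimal sketch, assuming the encoder has already produced the embeddings (the sizes and the already_bought mask are hypothetical):

```python
import torch

# Hypothetical trained embeddings: one customer, 8 products, 16 dims.
torch.manual_seed(0)
customer_emb = torch.randn(16)
product_embs = torch.randn(8, 16)
already_bought = torch.tensor([1, 4])  # products the customer already owns

# Score every product by dot product with the customer embedding.
scores = product_embs @ customer_emb

# Mask out owned products so they can never be recommended again.
scores[already_bought] = float('-inf')

# The k highest-scoring unowned products are the recommendations.
top_scores, top_products = torch.topk(scores, k=3)
```

At 100K products this exhaustive scoring is still one matmul; at larger scales, approximate nearest-neighbor search over the embeddings replaces the full scan.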

The 2-hop neighborhood gives each customer information about products bought by similar customers (collaborative signal) and products in the same categories they prefer (content signal), unified in a single model.

Limitations and what comes next

  1. Scalability of negative sampling: For a graph with 1M nodes, there are ~10^12 possible edges. Sampling informative negatives (hard negatives) is critical for training efficiency and quality.
  2. Temporal validity: Link prediction on a static snapshot does not respect time. A model might predict that customer A will buy product B, but product B was discontinued yesterday. Temporal link prediction requires time-aware training splits.
  3. Cold start: New nodes with few connections have poor embeddings, leading to poor link predictions. This is the classic cold-start problem in recommendation systems.
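
On the negative-sampling point: a common pattern is to set add_negative_train_samples=False in RandomLinkSplit and draw fresh negatives each epoch. PyG provides torch_geometric.utils.negative_sampling for this; the dependency-free sketch below shows the idea with uniform rejection sampling (hard-negative mining would replace the uniform draw with something smarter, e.g. near-misses in embedding space):

```python
import torch

def sample_negatives(edge_index, num_nodes, num_samples):
    # Uniform rejection sampling: draw random pairs, keep only non-edges.
    # (torch_geometric.utils.negative_sampling does this efficiently.)
    existing = set(map(tuple, edge_index.t().tolist()))
    neg = []
    while len(neg) < num_samples:
        u, v = torch.randint(num_nodes, (2,)).tolist()
        if u != v and (u, v) not in existing:
            neg.append((u, v))
    return torch.tensor(neg).t()

edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])  # tiny toy graph
neg_edge_index = sample_negatives(edge_index, num_nodes=5, num_samples=4)
```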

Frequently asked questions

What is link prediction?

Link prediction is the task of predicting whether an edge (link) should exist between two nodes in a graph. It can predict missing edges in an incomplete graph or forecast future edges that will appear. The model scores candidate node pairs by comparing their embeddings (dot product, MLP, or distance-based). Higher scores indicate higher probability of a link.

What are enterprise applications of link prediction?

Product recommendation (will customer X buy product Y?), fraud network discovery (are these two accounts connected through hidden channels?), knowledge graph completion (does drug X treat disease Y?), supply chain optimization (should factory X partner with supplier Y?), and social network suggestions (should user X follow user Y?).

How does GNN-based link prediction work?

A GNN computes node embeddings via message passing. For a candidate pair (u, v), the model scores the link using a decoder: typically dot product (z_u^T * z_v), concatenation + MLP, or cosine similarity. Training uses positive edges (real links) and negative edges (random non-links) with binary cross-entropy loss. At inference, all candidate pairs are scored and ranked.

What is the difference between link prediction and collaborative filtering?

Collaborative filtering predicts user-item preferences, typically on a bipartite graph. Link prediction is the more general formulation that works on any graph type. GNN-based link prediction subsumes collaborative filtering: it handles bipartite user-item graphs but also homogeneous networks (social graphs), heterogeneous graphs (customer-order-product), and temporal graphs (future edge prediction).

How do you evaluate link prediction?

Common metrics: AUC-ROC (area under the ROC curve, measures ranking quality), Average Precision (precision-recall trade-off), Hits@K (fraction of true links in top-K predictions), and Mean Reciprocal Rank (MRR). Evaluation splits the edges into train/val/test sets, trains on the train edges, and evaluates on held-out test edges with randomly sampled negative edges.
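
Hits@K and MRR come from ranking each held-out true edge against its own sampled negatives. A small sketch with made-up scores (three test edges, three negatives each):

```python
import torch

# Hypothetical eval: each held-out true edge is scored against a few
# sampled negative edges; its rank among them drives Hits@K and MRR.
true_score = torch.tensor([3.2, 0.5, 1.9])      # one score per test edge
neg_scores = torch.tensor([[1.0, 2.5, 0.1],     # negatives for edge 0
                           [0.9, 1.4, 2.2],     # negatives for edge 1
                           [0.3, 0.8, 1.1]])    # negatives for edge 2

# rank = 1 + number of negatives that outscore the true edge
rank = 1 + (neg_scores > true_score.unsqueeze(1)).sum(dim=1)

hits_at_1 = (rank <= 1).float().mean().item()   # fraction ranked first
mrr = (1.0 / rank.float()).mean().item()        # mean reciprocal rank
```

Here edges 0 and 2 rank first and edge 1 ranks fourth, giving Hits@1 = 2/3 and MRR = 0.75. Reported numbers depend heavily on how many and which negatives are sampled, so compare models only under identical evaluation protocols.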

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.