
Node Embeddings: Dense Vector Representations of Graph Nodes

A node embedding compresses a node's features, its neighborhood, and its structural position into a single dense vector. These embeddings are the output of GNN layers and the input to every downstream prediction task.


TL;DR

  1. A node embedding is a dense vector (64-256 dimensions) that encodes a node's own features plus information from its graph neighborhood. Similar nodes end up with similar embeddings.
  2. GNN-based embeddings use message passing to aggregate neighbor information through learnable layers. They are supervised and task-specific, outperforming unsupervised alternatives on most enterprise tasks.
  3. On relational databases, node embeddings replace manual feature engineering. A customer embedding automatically captures spending patterns, product preferences, and behavioral signals from linked tables.
  4. Two approaches exist: GNN-based (GCN, GAT, GraphSAGE) for supervised tasks, and random-walk-based (node2vec, DeepWalk) for unsupervised pre-training. GNN embeddings dominate for enterprise prediction.
  5. In PyG, node embeddings are simply the output tensor of a GNN forward pass. Stack GCNConv layers, pass node features and edge_index, and the output is a matrix of node embeddings.

A node embedding is a dense, low-dimensional vector that represents a node in a graph by encoding its features, its neighborhood structure, and its position within the overall graph topology. Nodes with similar neighborhoods and features produce similar embeddings. These vectors serve as the input to downstream tasks: classification, regression, link prediction, and clustering. Every GNN layer produces node embeddings. The final layer's output is what gets fed to a prediction head.

Why it matters for enterprise data

Enterprise relational databases store information across many linked tables. A customer table alone tells you age, location, and account tenure. But the customer's behavior is spread across orders, products, support tickets, and payment tables. Traditional ML requires a data scientist to manually join and aggregate these tables into a single feature vector per customer.

Node embeddings eliminate this bottleneck. When you represent the relational database as a graph (rows = nodes, foreign keys = edges), a GNN computes embeddings that automatically incorporate information from all linked tables. A customer's embedding encodes their orders, the products in those orders, and the behavior of similar customers, all without a single SQL GROUP BY.
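The rows = nodes, foreign keys = edges mapping can be sketched in a few lines of plain Python. The tables below are illustrative, not a real schema; in PyG these pairs would be collected into an `edge_index` tensor.

```python
# Sketch: turning relational rows into a graph (rows = nodes, FKs = edges).
# Table contents are illustrative.

customers = [{"customer_id": 0}, {"customer_id": 1}]
orders = [
    {"order_id": 0, "customer_id": 0},
    {"order_id": 1, "customer_id": 0},
    {"order_id": 2, "customer_id": 1},
]

# Assign each row a global node id: customers first, then orders.
num_customers = len(customers)

def node_id_of_order(order_id):
    return num_customers + order_id

# Each foreign key (order.customer_id -> customer) becomes an edge.
edges = [(node_id_of_order(o["order_id"]), o["customer_id"]) for o in orders]

print(edges)  # [(2, 0), (3, 0), (4, 1)]
```

Once every table follows the same scheme, the concatenated edge list is exactly the `edge_index` a GNN consumes.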

On the RelBench benchmark, GNN-based node embeddings achieve 75.83 AUROC across 30 enterprise tasks, compared to 62.44 for manually engineered flat-table features fed to LightGBM.

How node embeddings are computed

GNN-based embeddings (supervised)

A GNN computes node embeddings through repeated rounds of message passing. Each layer aggregates neighbor information and transforms it through learnable weights. After k layers, each node's embedding encodes information from its k-hop neighborhood.

  • Layer 0: raw node features (age, amount, category)
  • Layer 1: features + 1-hop neighbor information
  • Layer 2: features + 2-hop neighborhood context

The embeddings are optimized end-to-end for a specific prediction task (e.g., churn prediction), so the model learns which neighborhood patterns are predictive.
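A single round of message passing can be sketched without any library: each node averages its neighbors' vectors and mixes the result with its own. The toy features and graph below are made up, and the learnable weight matrices of a real GNN layer are omitted.

```python
# Minimal sketch of one message-passing round (no learnable weights):
# each node's new embedding mixes its own features with the mean of its
# neighbors' features. Real GNN layers add weight matrices and
# nonlinearities on top of this aggregation.

features = {            # toy 2-dim node features
    "alice": [1.0, 0.0],
    "bob":   [0.0, 1.0],
    "carol": [1.0, 1.0],
}
neighbors = {"alice": ["bob", "carol"], "bob": ["alice"], "carol": ["alice"]}

def message_passing_round(features, neighbors):
    updated = {}
    for node, own in features.items():
        msgs = [features[n] for n in neighbors[node]]
        mean = [sum(dim) / len(msgs) for dim in zip(*msgs)]
        # Combine self features with the aggregated neighbor message.
        updated[node] = [(s + m) / 2 for s, m in zip(own, mean)]
    return updated

h1 = message_passing_round(features, neighbors)  # 1-hop information
h2 = message_passing_round(h1, neighbors)        # 2-hop information
print(h1["alice"])  # [0.75, 0.5]
```

After the first round, Alice's vector already reflects Bob's and Carol's features; after the second, it reflects their neighbors too, which is exactly the layer-by-hop expansion listed above.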

Random-walk-based embeddings (unsupervised)

Methods like node2vec and DeepWalk generate node embeddings without labels. They perform random walks on the graph, then train a skip-gram model (similar to Word2Vec) to predict which nodes co-occur in walks. Nodes that appear in similar walk contexts get similar embeddings.

These embeddings capture structural similarity but are not optimized for any specific task. They are useful for exploratory analysis, visualization, and as pre-trained features for downstream models.
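The random-walk stage of DeepWalk/node2vec can be sketched in plain Python; the graph is illustrative and the skip-gram training step is omitted.

```python
import random
from collections import Counter

# Sketch of the random-walk stage of DeepWalk/node2vec.
# The skip-gram training step is omitted; the graph is illustrative.

neighbors = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2]}

def random_walk(start, length, rng):
    walk = [start]
    for _ in range(length - 1):
        walk.append(rng.choice(neighbors[walk[-1]]))
    return walk

rng = random.Random(42)
walks = [random_walk(n, length=5, rng=rng) for n in neighbors for _ in range(10)]

# Count co-occurrences within a context window, as skip-gram would see them.
window = 2
cooc = Counter()
for walk in walks:
    for i, u in enumerate(walk):
        for v in walk[max(0, i - window):i]:
            cooc[(min(u, v), max(u, v))] += 1

# Node pairs that co-occur often in walks would end up with similar
# embeddings after skip-gram training.
print(cooc.most_common(3))
```

The skip-gram model then pushes frequently co-occurring pairs toward similar vectors, which is how structural similarity ends up encoded without labels.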

Concrete example: customer embeddings from a transaction graph

Consider a banking database with three tables:

  • Customers: [customer_id, age, income_bracket, account_type]
  • Transactions: [txn_id, customer_id, merchant_id, amount, timestamp]
  • Merchants: [merchant_id, category, avg_ticket, fraud_rate]

After two layers of a GNN, customer Alice's 128-dimensional embedding encodes:

  • Her own features (age, income bracket)
  • Transaction patterns (average amount, frequency, time-of-day distribution)
  • Merchant characteristics (categories she shops at, their fraud rates)
  • Behavioral similarity to other customers who share the same merchants

PyG implementation

node_embeddings.py
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class NodeEmbedder(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, embedding_dim):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, embedding_dim)

    def forward(self, x, edge_index):
        # Layer 1: raw features -> 1-hop embeddings
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)

        # Layer 2: 1-hop -> 2-hop embeddings
        x = self.conv2(x, edge_index)
        return x  # shape: [num_nodes, embedding_dim]

# Usage: `data` is a torch_geometric.data.Data object with node
# features `data.x` (shape [num_nodes, 16]) and `data.edge_index`
model = NodeEmbedder(in_channels=16, hidden_channels=64, embedding_dim=128)
embeddings = model(data.x, data.edge_index)

# embeddings[i] is the 128-dim embedding for node i
# Use for classification: logits = classifier(embeddings)
# Use for link prediction: score = (embeddings[src] * embeddings[dst]).sum()

The output of the forward pass IS the node embedding matrix. Each row is one node's learned representation.
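Once you have the embedding matrix, downstream use is plain vector arithmetic. A library-free sketch with made-up embedding values, showing the similarity and link-scoring operations mentioned in the comments above:

```python
import math

# Sketch: downstream use of an embedding matrix is vector arithmetic.
# The embedding values below are made up for illustration.

embeddings = [
    [0.9, 0.1, 0.3],   # node 0, e.g. customer Alice
    [0.8, 0.2, 0.4],   # node 1, a behaviorally similar customer
    [0.1, 0.9, 0.0],   # node 2, a very different customer
]

def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv)

def link_score(u, v):
    # Link-prediction-style score: dot product of two node embeddings.
    return sum(a * b for a, b in zip(u, v))

# Similar neighborhoods -> similar embeddings -> higher cosine similarity.
assert cosine(embeddings[0], embeddings[1]) > cosine(embeddings[0], embeddings[2])
```

The same arithmetic works on the rows of the PyG output tensor, where it is usually written with `torch` operations instead of Python loops.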

Limitations and what comes next

  1. Fixed dimensionality: All nodes get the same embedding size regardless of their neighborhood complexity. A customer with 1,000 transactions gets the same 128 dimensions as one with 3 transactions. Information is necessarily lost for high-degree nodes.
  2. Shallow reach: GNN embeddings capture only k-hop neighborhoods where k is the number of layers. Beyond 2-3 layers, over-smoothing degrades embeddings.
  3. Static snapshots: Standard GNN embeddings represent the graph at a single point in time. Temporal dynamics require re-computation or specialized temporal GNNs.

Graph transformers address the reach limitation by computing embeddings with global attention, allowing each node to incorporate information from the entire graph rather than just its local neighborhood. KumoRFM uses a relational graph transformer to produce node embeddings that achieve 81.14 AUROC after fine-tuning on RelBench tasks.

Frequently asked questions

What is a node embedding?

A node embedding is a dense, low-dimensional vector (typically 64-256 dimensions) that represents a node in a graph. It encodes the node's own features, the features of its neighbors, and its structural position within the graph. Nodes with similar neighborhoods end up with similar embeddings, enabling downstream tasks like classification, regression, and link prediction.

How are node embeddings different from raw node features?

Raw node features describe only the node itself (e.g., a customer's age and location). Node embeddings incorporate information from the entire neighborhood: the customer's orders, the products in those orders, and even the behavior of other customers who bought similar products. Embeddings are learned representations that capture relational context that raw features cannot.

What is the difference between GNN-based and random-walk-based node embeddings?

GNN-based embeddings (GCN, GAT, GraphSAGE) use message passing to aggregate neighbor features through learnable neural network layers. They are supervised and task-specific. Random-walk-based embeddings (node2vec, DeepWalk) learn from co-occurrence patterns in random walks over the graph. They are unsupervised and task-agnostic. GNN embeddings generally outperform random-walk embeddings on supervised tasks because they are optimized for the target.

How do node embeddings apply to enterprise relational data?

In a relational database, each row becomes a node and each foreign key becomes an edge. A GNN computes node embeddings by aggregating information across these edges. A customer embedding encodes not just the customer's attributes but their entire transaction history, product preferences, and behavioral patterns, all without manual feature engineering.

What dimensionality should node embeddings have?

Common choices are 64, 128, or 256 dimensions. Smaller dimensions (32-64) work for simple graphs. Larger dimensions (128-256) capture more information but increase memory and computation. The optimal size depends on graph complexity and downstream task. In practice, 128 dimensions works well for most enterprise applications.

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.