A node embedding is a dense, low-dimensional vector that represents a node in a graph by encoding its features, its neighborhood structure, and its position within the overall graph topology. Nodes with similar neighborhoods and features produce similar embeddings. These vectors serve as the input to downstream tasks: classification, regression, link prediction, and clustering. Every GNN layer produces node embeddings. The final layer's output is what gets fed to a prediction head.
Why it matters for enterprise data
Enterprise relational databases store information across many linked tables. A customer table alone tells you age, location, and account tenure. But the customer's behavior is spread across orders, products, support tickets, and payment tables. Traditional ML requires a data scientist to manually join and aggregate these tables into a single feature vector per customer.
Node embeddings eliminate this bottleneck. When you represent the relational database as a graph (rows = nodes, foreign keys = edges), a GNN computes embeddings that automatically incorporate information from all linked tables. A customer's embedding encodes their orders, the products in those orders, and the behavior of similar customers, all without a single SQL GROUP BY.
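As a concrete sketch of the rows-as-nodes, foreign-keys-as-edges mapping, here is a toy conversion in plain Python. The table contents and IDs are invented for illustration; a real pipeline would read these rows from the database:

```python
# Hypothetical toy tables: each tuple is one row.
customers = [("c1",), ("c2",)]
orders = [("o1", "c1"), ("o2", "c1"), ("o3", "c2")]  # (order_id, customer_id FK)

# Every row becomes a node with an integer index.
node_index = {}
for (cid,) in customers:
    node_index[cid] = len(node_index)
for oid, _ in orders:
    node_index[oid] = len(node_index)

# Every foreign-key reference becomes an edge between the two rows.
edges = [(node_index[oid], node_index[cid]) for oid, cid in orders]
print(edges)  # [(2, 0), (3, 0), (4, 1)]
```

Message passing over these edges is what lets a customer's embedding absorb order-level information without an explicit join.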
On the RelBench benchmark, GNN-based node embeddings achieve 75.83 AUROC across 30 enterprise tasks, compared to 62.44 for manually engineered flat-table features fed to LightGBM.
How node embeddings are computed
GNN-based embeddings (supervised)
A GNN computes node embeddings through repeated rounds of message passing. Each layer aggregates neighbor information and transforms it through learnable weights. After k layers, each node's embedding encodes information from its k-hop neighborhood.
- Layer 0: raw node features (age, amount, category)
- Layer 1: features + 1-hop neighbor information
- Layer 2: features + 2-hop neighborhood context
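The layer-by-layer reach above can be demonstrated without any learned weights. In this minimal sketch (mean aggregation over self plus neighbors on a three-node path, no weight matrices or nonlinearities), node 2 only "sees" node 0's feature after the second round:

```python
# Path graph 0 - 1 - 2, with a 1-dimensional "feature" on each node.
adj = {0: [1], 1: [0, 2], 2: [1]}
x = {0: 1.0, 1: 0.0, 2: 0.0}

def propagate(x, adj):
    # One round of message passing: mean over the node itself + its neighbors.
    return {v: (x[v] + sum(x[u] for u in adj[v])) / (1 + len(adj[v]))
            for v in adj}

h1 = propagate(x, adj)   # 1-hop: node 2 still sees nothing of node 0
h2 = propagate(h1, adj)  # 2-hop: node 0's feature has now reached node 2
print(h1[2], h2[2])  # 0.0, then a positive value
```

A real GNN layer adds a learnable transform and nonlinearity around the aggregation, but the reach argument is identical.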
The embeddings are optimized end-to-end for a specific prediction task (e.g., churn prediction), so the model learns which neighborhood patterns are predictive.
Random-walk-based embeddings (unsupervised)
Methods like node2vec and DeepWalk generate node embeddings without labels. They perform random walks on the graph, then train a skip-gram model (similar to Word2Vec) to predict which nodes co-occur in walks. Nodes that appear in similar walk contexts get similar embeddings.
These embeddings capture structural similarity but are not optimized for any specific task. They are useful for exploratory analysis, visualization, and as pre-trained features for downstream models.
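The walk-generation step can be sketched in a few lines. The graph, walk length, and window size here are illustrative; a real DeepWalk/node2vec pipeline would feed the walks to a skip-gram model rather than tally raw co-occurrence counts:

```python
import random

# Toy graph with two disconnected components: {0, 1, 2} and {3, 4}.
adj = {0: [1, 2], 1: [0, 2], 2: [0, 1], 3: [4], 4: [3]}

def random_walk(adj, start, length, rng):
    walk = [start]
    for _ in range(length - 1):
        walk.append(rng.choice(adj[walk[-1]]))
    return walk

rng = random.Random(0)
cooc = {}
for start in adj:
    for _ in range(20):
        walk = random_walk(adj, start, 5, rng)
        for i, u in enumerate(walk):
            for v in walk[max(0, i - 2):i + 3]:  # window of size 2
                if u != v:
                    cooc[(u, v)] = cooc.get((u, v), 0) + 1

# Nodes in the same component co-occur; nodes across components never do,
# so skip-gram would push their embeddings apart.
print(cooc.get((0, 3), 0))  # 0
```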
Concrete example: customer embeddings from a transaction graph
Consider a banking database with three tables:
- Customers: [customer_id, age, income_bracket, account_type]
- Transactions: [txn_id, customer_id, merchant_id, amount, timestamp]
- Merchants: [merchant_id, category, avg_ticket, fraud_rate]
After 2 layers of a GNN, customer Alice's 128-dimensional embedding encodes:
- Her own features (age, income bracket)
- Transaction patterns (average amount, frequency, time-of-day distribution)
- Merchant characteristics (categories she shops at, their fraud rates)
- Behavioral similarity to other customers who share the same merchants
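"Behavioral similarity" between embeddings is typically operationalized as cosine similarity between the vectors. A minimal sketch with made-up 3-dimensional vectors standing in for the 128-dimensional ones:

```python
import math

# Invented vectors for illustration; in practice these come from the GNN.
alice = [0.9, 0.1, 0.3]
bob   = [0.8, 0.2, 0.4]    # shops at similar merchants -> nearby vector
carol = [-0.5, 0.9, -0.2]  # very different behavior

def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

print(cosine(alice, bob) > cosine(alice, carol))  # True
```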
PyG implementation
```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv


class NodeEmbedder(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, embedding_dim):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, embedding_dim)

    def forward(self, x, edge_index):
        # Layer 1: raw features -> 1-hop embeddings
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        # Layer 2: 1-hop -> 2-hop embeddings
        x = self.conv2(x, edge_index)
        return x  # shape: [num_nodes, embedding_dim]


# Usage (data is a torch_geometric.data.Data object holding
# node features in data.x and edges in data.edge_index)
model = NodeEmbedder(in_channels=16, hidden_channels=64, embedding_dim=128)
embeddings = model(data.x, data.edge_index)
# embeddings[i] is the 128-dim embedding for node i
# Use for classification: logits = classifier(embeddings)
# Use for link prediction: score = (embeddings[src] * embeddings[dst]).sum()
```

The output of the forward pass IS the node embedding matrix. Each row is one node's learned representation.
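The link-prediction comment above extends naturally to ranking candidate edges by their dot-product score. A small standalone sketch (the vectors are invented, not the output of the model):

```python
# Made-up 2-dim embeddings standing in for the model's output.
embeddings = [
    [1.0, 0.0],  # node 0
    [0.9, 0.1],  # node 1 -- close to node 0
    [0.0, 1.0],  # node 2 -- far from node 0
]

def score(src, dst):
    # Dot product of the two endpoint embeddings.
    return sum(a * b for a, b in zip(embeddings[src], embeddings[dst]))

# Rank candidate edges from node 0 by score: the likelier link comes first.
candidates = [(0, 1), (0, 2)]
ranked = sorted(candidates, key=lambda e: score(*e), reverse=True)
print(ranked)  # [(0, 1), (0, 2)]
```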
Limitations and what comes next
- Fixed dimensionality: All nodes get the same embedding size regardless of their neighborhood complexity. A customer with 1,000 transactions gets the same 128 dimensions as one with 3 transactions. Information is necessarily lost for high-degree nodes.
- Shallow reach: GNN embeddings capture only k-hop neighborhoods where k is the number of layers. Beyond 2-3 layers, over-smoothing degrades embeddings.
- Static snapshots: Standard GNN embeddings represent the graph at a single point in time. Temporal dynamics require re-computation or specialized temporal GNNs.
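The over-smoothing limitation can be seen directly: iterating an unweighted mean aggregation many times drives all node values toward the same constant, so nodes become indistinguishable. A toy demonstration (not a trained GNN, which would partly counteract this with learned transforms):

```python
# Path graph 0 - 1 - 2 - 3; only node 0 starts with a nonzero feature.
adj = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}
x = {0: 1.0, 1: 0.0, 2: 0.0, 3: 0.0}

def propagate(x, adj):
    # Mean aggregation over self + neighbors, as in the earlier sketch.
    return {v: (x[v] + sum(x[u] for u in adj[v])) / (1 + len(adj[v]))
            for v in adj}

h = dict(x)
for _ in range(50):  # 50 "layers" of aggregation
    h = propagate(h, adj)

spread_before = max(x.values()) - min(x.values())
spread_after = max(h.values()) - min(h.values())
print(spread_before, spread_after)  # 1.0 vs. nearly 0: nodes have converged
```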
Graph transformers address the reach limitation by computing embeddings with global attention, allowing each node to incorporate information from the entire graph rather than just its local neighborhood. KumoRFM uses a relational graph transformer to produce node embeddings that achieve 81.14 AUROC after fine-tuning on RelBench tasks.