Node regression predicts a continuous numerical value for each node in a graph by combining the node's own features with information aggregated from its neighborhood through message passing. It is the regression counterpart to node classification. The GNN architecture is identical; only the output layer (linear with 1 output instead of num_classes outputs) and loss function (MSE instead of cross-entropy) differ. Node regression powers enterprise forecasting tasks: revenue prediction, demand estimation, risk scoring, and lifetime value calculation.
Why it matters for enterprise data
Most enterprise prediction tasks are regression problems: “How much will this customer spend next quarter?” “What is this product's expected demand?” “What is this loan's probability of default?” These are continuous values, not categories.
Flat-table regression uses only the entity's own features. Node regression on a relational graph adds the entity's full context: a customer's predicted revenue incorporates their order history, the products they bought, the categories trending among similar customers, and their interaction with support. This cross-table signal is captured automatically through 2-3 layers of message passing.
How node regression works
```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv

class NodeRegressor(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        self.conv1 = SAGEConv(in_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, hidden_dim)
        # Output head: 1 value per node
        self.head = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.3, training=self.training)
        x = self.conv2(x, edge_index)
        return self.head(x).squeeze(-1)  # [num_nodes]

model = NodeRegressor(in_dim=16, hidden_dim=64)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(200):
    model.train()
    optimizer.zero_grad()  # clear gradients from the previous step
    pred = model(data.x, data.edge_index)
    # MSE loss on labeled nodes only
    loss = F.mse_loss(pred[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

# Inference: predict continuous value for all nodes
model.eval()
with torch.no_grad():
    predictions = model(data.x, data.edge_index)  # continuous scores
```

This is identical to node classification except for two changes: Linear(hidden, 1) instead of Linear(hidden, num_classes), and mse_loss instead of cross_entropy.
Concrete example: customer lifetime value prediction
A subscription business wants to predict 12-month customer lifetime value (CLV):
- Customer nodes: features = [tenure_months, plan_tier, monthly_charge]
- Order nodes: features = [amount, item_count, discount_applied]
- Product nodes: features = [price, category, margin]
- Edges: customer → order (placed), order → product (contains)
- Target: total revenue from each customer over the next 12 months (continuous $)
After 2 SAGEConv layers, each customer's embedding captures:
- Their own subscription level and tenure
- Their purchasing patterns (average order value, frequency)
- The products they buy (high-margin vs. low-margin, growing vs. declining categories)
The regression head predicts a dollar amount. The model learns that customers buying growing-category, high-margin products have higher CLV, even if their current spending is moderate.
Limitations and what comes next
- Target distribution: Enterprise regression targets (revenue, claim amounts) are often heavily right-skewed, so a few large values dominate the MSE. Training on log-transformed targets and inverting the transform on predictions usually improves fit substantially.
- Temporal leakage: When predicting future values (next-quarter revenue), the graph must only include edges from before the prediction date. Future edges leak the answer. Temporal splits are essential.
- Uncertainty quantification: Point predictions (single values) are often insufficient for enterprise decision-making. Extending node regression to produce prediction intervals requires ensemble methods or probabilistic GNNs.
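The log-transform trick from the first point is two lines with torch: train against log1p of the target (log1p rather than log so zero-valued targets are safe), then invert with expm1 at inference. The values below are made up to show the skew:

```python
import torch

# Heavily right-skewed targets (e.g. revenue): one large value dominates MSE
y = torch.tensor([10.0, 25.0, 40.0, 120.0, 9500.0])

# Train the model against y_log instead of y
y_log = torch.log1p(y)

# Stand-in for the model's output in log space
pred_log = y_log + 0.05

# Invert at inference time to report dollar amounts
pred = torch.expm1(pred_log)
```

Because expm1 is the exact inverse of log1p, a perfect prediction in log space maps back to the original dollar value.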
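The temporal-leakage point amounts to masking the edge list by timestamp before message passing. A minimal sketch, assuming each edge carries a timestamp (the tensors and cutoff here are invented):

```python
import torch

# Toy edge list with one timestamp per edge (e.g. order date)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])
edge_time = torch.tensor([100, 200, 300, 400])

cutoff = 250  # prediction date: only edges observed before it may be used
keep = edge_time < cutoff
train_edge_index = edge_index[:, keep]
# Only the two pre-cutoff edges survive: 0->1 and 1->2
```

The same mask must be applied consistently at training and evaluation time; any edge dated after the prediction date that reaches the model leaks the answer.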