Node classification is the task of predicting a categorical label for each node in a graph by leveraging both the node's own features and information aggregated from its neighborhood through message passing. It is the most fundamental and widely studied GNN task. In enterprise terms, every entity-level categorical prediction on relational data is a node classification problem: fraud/not-fraud, will-churn/will-stay, risk-tier-A/B/C.
Why it matters for enterprise data
Traditional ML classification treats each entity as an independent row in a feature table. Node classification treats each entity as a node in a graph, connected to related entities through foreign-key relationships. This relational context is what drives the performance gap:
- Fraud detection: A transaction looks normal in isolation. But connected to a merchant with a 40% chargeback rate and an account linked to 3 other flagged accounts, the fraud signal is clear. Node classification captures this.
- Churn prediction: A customer has not changed their behavior. But their closest contacts (most-called numbers) are all churning. Node classification captures social contagion that flat-table models miss.
- Credit risk: A borrower looks creditworthy. But their business partners are all in financial distress. The 2-hop neighborhood reveals risk invisible in the borrower's own features.
How node classification works
- Build the graph: Entities become nodes. Foreign keys become edges. Row features become node features.
- Message passing: GNN layers aggregate neighbor information. After 2 layers, each node's representation encodes its 2-hop neighborhood.
- Classification head: A linear layer + softmax maps each node's embedding to class probabilities.
- Train: Cross-entropy loss on labeled nodes. Gradients flow back through the GNN layers.
```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class NodeClassifier(torch.nn.Module):
    def __init__(self, num_features, hidden_dim, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, num_classes)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

# Training loop (`data` is a torch_geometric.data.Data object with
# x, edge_index, y, and train_mask attributes)
model = NodeClassifier(num_features=16, hidden_dim=64, num_classes=4)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
for epoch in range(200):
    model.train()
    optimizer.zero_grad()  # clear gradients from the previous step
    out = model(data.x, data.edge_index)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

# Inference: predict ALL nodes, including unlabeled
model.eval()
with torch.no_grad():
    predictions = model(data.x, data.edge_index).argmax(dim=1)
```

A complete node classification pipeline in about 30 lines. The model trains only on labeled nodes but predicts every node by leveraging graph structure.
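Step 1 of the pipeline, turning a foreign-key table into the `edge_index` tensor that GCNConv consumes, can be sketched as follows. The toy rows and column meanings here are hypothetical, not from the example above:

```python
import torch

# Hypothetical foreign-key rows: (customer_id, merchant_id) pairs,
# already re-indexed to contiguous node IDs.
fk_rows = [(0, 3), (1, 3), (1, 4), (2, 5)]

src = torch.tensor([r[0] for r in fk_rows])
dst = torch.tensor([r[1] for r in fk_rows])

# PyG expects edge_index of shape [2, num_edges]; stacking both
# directions makes message passing run over an undirected graph.
edge_index = torch.cat(
    [torch.stack([src, dst]), torch.stack([dst, src])], dim=1
)
print(edge_index.shape)  # torch.Size([2, 8])
```

In practice the ID re-indexing step matters: raw database keys are sparse, so each table's primary keys are mapped to a dense 0..N-1 range before building the tensor.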
Concrete example: customer risk categorization
A financial services company wants to categorize customers into 4 risk tiers:
- Customers: 500,000 nodes with features [income, age, tenure, credit_score]
- Transactions: 5,000,000 edges linking customer nodes to merchant nodes
- Labels: 25,000 customers (5%) have known risk tier labels from manual review
After 2 GCNConv layers, each customer's embedding encodes their own financial profile plus the transaction patterns and merchant types in their 2-hop neighborhood. The classification head maps this to a risk tier. The model predicts tiers for all 475,000 unlabeled customers, using the graph structure to propagate risk signals from labeled to unlabeled nodes.
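The semi-supervised split in this example (5% labeled) comes down to a boolean mask over the node set; a minimal sketch, assuming the labeled IDs come from a manual-review table (hypothetical here):

```python
import torch

num_customers = 500_000
num_labeled = 25_000  # 5% with risk tiers from manual review

# Boolean mask selecting the labeled nodes; in practice labeled_ids
# would be looked up from the review table.
labeled_ids = torch.arange(num_labeled)
train_mask = torch.zeros(num_customers, dtype=torch.bool)
train_mask[labeled_ids] = True

# Loss is computed only on masked nodes; inference covers all nodes.
print(int(train_mask.sum()))     # 25000
print(int((~train_mask).sum()))  # 475000
```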
Limitations and what comes next
- Homophily dependence: Node classification with standard GNNs works best when connected nodes share labels. Under heterophily, performance can drop below a simple MLP.
- Class imbalance: Enterprise classification tasks often have extreme imbalance (1% fraud, 99% legitimate). Standard cross-entropy loss biases toward the majority class. Focal loss, oversampling, or class-weighted loss are needed.
- Scalability: Full-graph training does not scale to millions of nodes. Neighbor sampling and subgraph sampling are required for enterprise-scale graphs.
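For the class-imbalance point above, a class-weighted loss is a one-line change to the training loop. A minimal sketch with made-up label counts (1% positive class); the inverse-frequency weighting shown is one common choice, not the only one:

```python
import torch
import torch.nn.functional as F

# Made-up imbalanced labels: 1% fraud (class 1), 99% legitimate (class 0).
y = torch.zeros(1000, dtype=torch.long)
y[:10] = 1

# Inverse-frequency class weights: rare classes get large weights.
counts = torch.bincount(y, minlength=2).float()
weights = counts.sum() / (2 * counts)

logits = torch.randn(1000, 2)  # stand-in for the GNN's output
# Same objective as before, but minority-class errors cost ~100x more.
loss = F.nll_loss(F.log_softmax(logits, dim=1), y, weight=weights)
```

Focal loss or minority-class oversampling follow the same pattern: reshape the objective so the 1% class contributes meaningfully to the gradient.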