
Node Classification: Predicting Labels for Individual Graph Nodes

Node classification is the most common GNN task. It predicts a category for each node using both its features and the structure of its neighborhood. In enterprise terms, it is entity-level prediction on relational data.

TL;DR

  • Node classification predicts a categorical label for each node using its features plus information aggregated from its graph neighborhood through message passing.
  • Enterprise applications: fraud detection (is this account fraudulent?), churn prediction (will this customer leave?), customer segmentation, credit risk rating, product categorization.
  • GNNs excel at semi-supervised node classification. On Cora, 140 labeled nodes (5.2%) are enough to classify 2,708 nodes at 81.5% accuracy. Message passing propagates label signal through the graph.
  • On RelBench enterprise benchmarks, GNN node classification achieves 75.83 AUROC vs. 62.44 for flat-table methods; KumoRFM reaches 81.14 fine-tuned.
  • In PyG: stack two GCNConv layers, add a softmax output, train with cross-entropy loss on labeled nodes. The model automatically leverages graph structure for unlabeled nodes.

Node classification is the task of predicting a categorical label for each node in a graph by leveraging both the node's own features and information aggregated from its neighborhood through message passing. It is the most fundamental and widely studied GNN task. In enterprise terms, every entity-level categorical prediction on relational data is a node classification problem: fraud/not-fraud, will-churn/will-stay, risk-tier-A/B/C.

Why it matters for enterprise data

Traditional ML classification treats each entity as an independent row in a feature table. Node classification treats each entity as a node in a graph, connected to related entities through foreign-key relationships. This relational context is what drives the performance gap:

  • Fraud detection: A transaction looks normal in isolation. But connected to a merchant with a 40% chargeback rate and an account linked to 3 other flagged accounts, the fraud signal is clear. Node classification captures this.
  • Churn prediction: A customer has not changed their behavior. But their closest contacts (most-called numbers) are all churning. Node classification captures social contagion that flat-table models miss.
  • Credit risk: A borrower looks creditworthy. But their business partners are all in financial distress. The 2-hop neighborhood reveals risk invisible in the borrower's own features.

How node classification works

  1. Build the graph: Entities become nodes. Foreign keys become edges. Row features become node features.
  2. Message passing: GNN layers aggregate neighbor information. After 2 layers, each node's representation encodes its 2-hop neighborhood.
  3. Classification head: A linear layer + softmax maps each node's embedding to class probabilities.
  4. Train: Cross-entropy loss on labeled nodes. Gradients flow back through the GNN layers.
node_classification.py
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class NodeClassifier(torch.nn.Module):
    def __init__(self, num_features, hidden_dim, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, num_classes)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

# Training loop. `data` is a torch_geometric.data.Data object with
# x, edge_index, y, and a boolean train_mask over the labeled nodes.
model = NodeClassifier(num_features=16, hidden_dim=64, num_classes=4)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(200):
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

# Inference: predict ALL nodes, including unlabeled ones
model.eval()
with torch.no_grad():
    predictions = model(data.x, data.edge_index).argmax(dim=1)

A complete node classification pipeline in about 30 lines. The model trains on the labeled nodes and predicts labels for every node by leveraging graph structure.

Concrete example: customer risk categorization

A financial services company wants to categorize customers into 4 risk tiers:

  • Customers: 500,000 nodes with features [income, age, tenure, credit_score]
  • Transactions: 5,000,000 edges linking customers to merchants
  • Labels: 25,000 customers (5%) have known risk tier labels from manual review

After 2 GCNConv layers, each customer's embedding encodes their own financial profile plus the transaction patterns and merchant types in their 2-hop neighborhood. The classification head maps this to a risk tier. The model predicts tiers for all 475,000 unlabeled customers, using the graph structure to propagate risk signals from labeled to unlabeled nodes.
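As a sketch, the example above might be assembled into graph tensors like this. The entity and label counts come from the example; the merchant count, node-offset layout, and random features are illustrative assumptions, not a prescribed schema:

```python
import torch

num_customers, num_merchants = 500_000, 20_000   # merchant count is an assumed figure
num_transactions = 5_000_000

# Customers occupy node ids [0, num_customers); merchants are offset after them.
customer_x = torch.randn(num_customers, 4)       # [income, age, tenure, credit_score]
merchant_x = torch.randn(num_merchants, 4)       # padded to the same width for simplicity

# Each transaction becomes a customer<->merchant edge, stored in both directions.
cust = torch.randint(0, num_customers, (num_transactions,))
merch = torch.randint(0, num_merchants, (num_transactions,)) + num_customers
edge_index = torch.cat([
    torch.stack([cust, merch]),
    torch.stack([merch, cust]),
], dim=1)                                        # shape [2, 2 * num_transactions]

x = torch.cat([customer_x, merchant_x], dim=0)

# Risk-tier labels exist only for the 25,000 manually reviewed customers.
train_mask = torch.zeros(x.size(0), dtype=torch.bool)
train_mask[torch.randperm(num_customers)[:25_000]] = True
```

These `x`, `edge_index`, and `train_mask` tensors are exactly the fields the training loop above reads from `data`. In practice a heterogeneous graph representation (separate node types for customers and merchants) is often preferable to this single flattened index space.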

Limitations and what comes next

  1. Homophily dependence: Node classification with standard GNNs works best when connected nodes share labels. Under heterophily, performance can drop below a simple MLP.
  2. Class imbalance: Enterprise classification tasks often have extreme imbalance (1% fraud, 99% legitimate). Standard cross-entropy loss biases toward the majority class. Focal loss, oversampling, or class-weighted loss are needed.
  3. Scalability: Full-graph training does not scale to millions of nodes. Neighbor sampling and subgraph sampling are required for enterprise-scale graphs.
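For the class-imbalance point, a minimal pure-PyTorch sketch of class-weighted cross-entropy. The 1%/99% split, random logits, and inverse-frequency weighting scheme are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Imbalanced labels: ~1% fraud (class 1), ~99% legitimate (class 0).
labels = (torch.rand(10_000) < 0.01).long()
logits = torch.randn(10_000, 2)

# Inverse-frequency class weights: the rare class contributes more per example.
counts = torch.bincount(labels, minlength=2).float()
weights = counts.sum() / (2 * counts)

plain_loss = F.cross_entropy(logits, labels)
weighted_loss = F.cross_entropy(logits, labels, weight=weights)
```

The `weight` argument rescales each example's loss by its class weight, so gradients from the rare fraud class are no longer drowned out by the legitimate majority.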

Frequently asked questions

What is node classification?

Node classification is the task of predicting a categorical label for each node in a graph. The model uses both the node's own features and information from its neighbors (via message passing) to make predictions. In semi-supervised node classification, only a small fraction of nodes are labeled during training, and the model predicts labels for all remaining nodes.

What are common enterprise examples of node classification?

Customer segmentation (which segment does each customer belong to?), fraud detection (is this account fraudulent?), churn prediction (will this customer leave?), credit risk rating (what is this borrower's risk category?), and product categorization (what category does this item belong to?). Any entity-level categorical prediction on relational data maps to node classification.

How does node classification differ from traditional classification?

Traditional classification uses only each entity's own features (one row from a feature table). Node classification uses the entity's features PLUS information from its relational neighborhood. A customer's churn prediction considers not just their own activity but also the behavior of their connected orders, products, support tickets, and even other customers. This relational context is what gives GNNs their advantage.

What accuracy can I expect from GNN node classification?

On academic benchmarks: GCN achieves ~81.5% on Cora (citation network, 7 classes). On enterprise data (RelBench): GNN-based models achieve 75.83 AUROC across 30 tasks vs. 62.44 for flat-table methods. KumoRFM achieves 81.14 fine-tuned. Actual accuracy depends heavily on the specific dataset, label distribution, and graph structure.

How many labeled nodes do I need for node classification?

GNNs excel in semi-supervised settings with very few labels. On Cora (2,708 nodes), standard benchmarks use only 140 labeled nodes (5.2%) for training. On enterprise data, even 1-5% labeled nodes can produce strong results because message passing propagates label information through the graph structure.
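A minimal pure-PyTorch sketch of this semi-supervised setup. The 2,708-node / 140-label figures are the Cora numbers quoted above, but the random labels and random masking here are illustrative, not Cora's fixed benchmark split:

```python
import torch
import torch.nn.functional as F

num_nodes, num_classes = 2708, 7
torch.manual_seed(0)
y = torch.randint(0, num_classes, (num_nodes,))

# Only 140 nodes (~5.2%) carry labels at training time.
train_mask = torch.zeros(num_nodes, dtype=torch.bool)
train_mask[torch.randperm(num_nodes)[:140]] = True

# The loss is computed on the labeled nodes only; in a GNN, message
# passing still spreads that label signal to the unlabeled majority.
logits = torch.randn(num_nodes, num_classes, requires_grad=True)
loss = F.cross_entropy(logits[train_mask], y[train_mask])
loss.backward()

# Direct loss gradients land only on the 140 labeled positions.
rows_with_grad = logits.grad.abs().sum(dim=1) > 0
```

With a GNN in place of the raw `logits`, those 140 supervised positions also push gradients into the shared message-passing weights, which is what lets the model generalize to the other ~2,568 nodes.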

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.