Graph classification is the task of predicting a categorical label for an entire graph by combining GNN message passing with graph pooling to produce a single graph-level representation. While node classification makes per-node predictions, graph classification makes one prediction per graph. The architecture adds a pooling step that aggregates all node embeddings into a fixed-size vector, which is then passed to a classification head. This is the standard approach for molecular property prediction, document classification, and any task where the prediction target is a graph-level property.
Why it matters for enterprise data
In enterprise relational databases, entity-level predictions often require reasoning about an entity's entire relational neighborhood. When you extract that neighborhood as a subgraph, entity-level prediction becomes graph classification:
- Account risk: Extract all transactions, merchants, and devices linked to an account. Classify the entire subgraph as high-risk or low-risk.
- Supply chain reliability: Extract the supply path from raw material to finished product. Classify the chain as reliable, at-risk, or disrupted.
- Claim severity: Extract all medical codes, providers, and treatments in an insurance claim. Classify severity level.
This approach captures holistic patterns that per-entity features miss. An account's risk depends not on any single transaction but on the pattern of all its connections.
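The extraction step above can be sketched with a plain-Python breadth-first search over an adjacency list. This is a minimal, hedged illustration — the node names and the 2-hop limit are invented for the example, and a real pipeline over a graph library would typically use a built-in utility (e.g., PyG's `k_hop_subgraph`) instead:

```python
from collections import deque

def extract_k_hop_subgraph(adj, seed, k):
    """Collect all nodes within k hops of seed via BFS, then
    return the induced edge list -- the subgraph to classify."""
    seen = {seed}
    frontier = deque([(seed, 0)])
    while frontier:
        node, depth = frontier.popleft()
        if depth == k:
            continue
        for nbr in adj.get(node, []):
            if nbr not in seen:
                seen.add(nbr)
                frontier.append((nbr, depth + 1))
    edges = [(u, v) for u in seen for v in adj.get(u, []) if v in seen]
    return seen, edges

# Toy account neighborhood: account -> transactions -> merchants/devices
adj = {
    "acct_1": ["txn_a", "txn_b"],
    "txn_a": ["acct_1", "merchant_x"],
    "txn_b": ["acct_1", "device_y"],
    "merchant_x": ["txn_a"],
    "device_y": ["txn_b", "acct_9"],
    "acct_9": ["device_y"],
}
nodes, edges = extract_k_hop_subgraph(adj, "acct_1", 2)
# 2 hops from acct_1 reaches its transactions and their merchants/devices,
# but not acct_9, which is 3 hops away
```

Each extracted `(nodes, edges)` pair becomes one training example for the graph classifier, labeled at the account level.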
How graph classification works
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.loader import DataLoader

class GraphClassifier(torch.nn.Module):
    def __init__(self, num_features, hidden_dim, num_classes):
        super().__init__()
        # Message passing layers
        self.conv1 = GCNConv(num_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, hidden_dim)
        # Classification head
        self.classifier = torch.nn.Linear(hidden_dim, num_classes)

    def forward(self, x, edge_index, batch):
        # Step 1: Message passing (node-level)
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)
        # Step 2: Graph pooling (collapse to graph-level)
        x = global_mean_pool(x, batch)  # [num_graphs, hidden_dim]
        # Step 3: Classify
        return self.classifier(x)  # [num_graphs, num_classes]

# DataLoader batches graphs automatically
loader = DataLoader(dataset, batch_size=32, shuffle=True)
model = GraphClassifier(dataset.num_features, 64, dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
for batch in loader:
    optimizer.zero_grad()
    out = model(batch.x, batch.edge_index, batch.batch)
    loss = F.cross_entropy(out, batch.y)  # batch.y = graph-level labels
    loss.backward()
    optimizer.step()

The batch tensor is critical: it maps each node to its graph in the mini-batch. global_mean_pool uses it to average within each graph, not across graphs.
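To make the batch tensor's role concrete, here is a minimal pure-Python sketch of what mean pooling computes: group node embeddings by their batch assignment and average within each group. (The real `global_mean_pool` is a vectorized scatter-mean over torch tensors; this is only an illustration of the logic.)

```python
def mean_pool(node_embeddings, batch):
    """Average node embeddings within each graph, where
    batch[i] is the graph index of node i."""
    num_graphs = max(batch) + 1
    dim = len(node_embeddings[0])
    sums = [[0.0] * dim for _ in range(num_graphs)]
    counts = [0] * num_graphs
    for emb, g in zip(node_embeddings, batch):
        counts[g] += 1
        for j, v in enumerate(emb):
            sums[g][j] += v
    return [[s / counts[g] for s in sums[g]] for g in range(num_graphs)]

# Two graphs in one mini-batch: graph 0 has 3 nodes, graph 1 has 2
embeddings = [[1.0, 2.0], [3.0, 2.0], [2.0, 2.0],   # graph 0
              [10.0, 0.0], [20.0, 0.0]]              # graph 1
batch = [0, 0, 0, 1, 1]
mean_pool(embeddings, batch)
# -> [[2.0, 2.0], [15.0, 0.0]]: one fixed-size vector per graph
```

Note that the output has one row per graph, not per node — this is the collapse from node-level to graph-level representation.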
Concrete example: molecular toxicity prediction
The TU Datasets collection includes molecular graphs:
- Each graph = one molecule (10-50 atoms)
- Nodes = atoms with features [element, charge, hybridization]
- Edges = chemical bonds with features [bond_type, is_aromatic]
- Label = toxic (1) or non-toxic (0)
After 3 GCNConv layers, each atom's embedding encodes its local chemical environment (3-hop radius). Global mean pooling averages all atom embeddings into a single 64-dimensional molecule vector. The classifier predicts toxicity from this vector.
Limitations and what comes next
- Pooling information loss: Graph pooling compresses variable-size graphs into fixed-size vectors. Large, complex graphs lose more information than small ones. Hierarchical pooling mitigates this.
- Expressiveness: Standard message passing cannot distinguish certain graph structures. Two structurally different graphs can map to the same graph-level vector, and the classifier will then assign them the same prediction even when their true labels differ.
- Graph size variation: Enterprise subgraphs can vary enormously in size (a customer with 5 transactions vs. one with 50,000). Models must handle this range without bias toward larger graphs.
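The expressiveness limitation can be demonstrated with a toy sketch (illustrative only, using a simplified mean-aggregation layer rather than GCNConv): a hexagon and two disjoint triangles are non-isomorphic, but both are 2-regular, so with identical node features every round of neighbor averaging produces identical embeddings, and mean pooling cannot tell the two graphs apart.

```python
def mp_then_pool(adj, features, layers=3):
    """Simplified mean-aggregation GNN: each layer replaces a node's
    feature with the mean of its neighbors' features; then mean-pool
    all node features into one graph-level scalar."""
    for _ in range(layers):
        features = {n: sum(features[m] for m in nbrs) / len(nbrs)
                    for n, nbrs in adj.items()}
    return sum(features.values()) / len(features)

# Hexagon: a single 6-cycle
hexagon = {i: [(i - 1) % 6, (i + 1) % 6] for i in range(6)}
# Two disjoint triangles: also 6 nodes, also 2-regular
triangles = {0: [1, 2], 1: [0, 2], 2: [0, 1],
             3: [4, 5], 4: [3, 5], 5: [3, 4]}
feats = {i: 1.0 for i in range(6)}

mp_then_pool(hexagon, dict(feats)) == mp_then_pool(triangles, dict(feats))
# True: structurally different graphs, identical graph-level vector
```

This mirrors the Weisfeiler-Leman limit of standard message passing: structural features (cycle counts, distinguishing node identifiers) or more expressive architectures are needed when such structures carry the label signal.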