
Graph Classification: Predicting Properties of Entire Graphs

Graph classification makes one prediction per graph rather than per node. It combines message passing with graph pooling to collapse all node representations into a single graph-level vector that feeds a classifier.

TL;DR

  • Graph classification predicts a label for an entire graph, not individual nodes. It requires graph pooling to compress node embeddings into one graph-level vector.
  • Architecture: GNN layers (message passing) -> graph pooling (collapse nodes) -> classifier (predict label). The pooling step is what distinguishes graph classification from node classification.
  • Enterprise applications: classify customer transaction subgraphs (risk level), supply chain paths (reliability), support ticket clusters (severity), compliance report subgraphs (violation type).
  • PyG handles batching by concatenating multiple graphs into one disconnected graph. The batch tensor maps each node to its graph. global_mean_pool uses this for per-graph aggregation.
  • In PyG: stack GCNConv layers, add global_mean_pool(x, batch), add a linear classifier. Train with cross-entropy loss on graph-level labels.

Graph classification is the task of predicting a categorical label for an entire graph by combining GNN message passing with graph pooling to produce a single graph-level representation. While node classification makes per-node predictions, graph classification makes one prediction per graph. The architecture adds a pooling step that aggregates all node embeddings into a fixed-size vector, which is then passed to a classification head. This is the standard approach for molecular property prediction, document classification, and any task where the prediction target is a graph-level property.

Why it matters for enterprise data

In enterprise relational databases, entity-level predictions often require reasoning about an entity's entire relational neighborhood. When you extract that neighborhood as a subgraph, entity-level prediction becomes graph classification:

  • Account risk: Extract all transactions, merchants, and devices linked to an account. Classify the entire subgraph as high-risk or low-risk.
  • Supply chain reliability: Extract the supply path from raw material to finished product. Classify the chain as reliable, at-risk, or disrupted.
  • Claim severity: Extract all medical codes, providers, and treatments in an insurance claim. Classify severity level.

This approach captures holistic patterns that per-entity features miss. An account's risk depends not on any single transaction but on the pattern of all its connections.
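As a sketch of what "extracting a neighborhood as a subgraph" means in practice, the snippet below represents one account's subgraph as a node list, an edge list, and a single graph-level label. The schema and names are hypothetical, chosen only to illustrate the idea:

```python
# Hypothetical account-risk subgraph. The node types, edges, and label
# are illustrative, not drawn from a real dataset.
subgraph = {
    # Node 0 is the account; nodes 1-2 are merchants; node 3 is a shared device.
    "node_types": ["account", "merchant", "merchant", "device"],
    # Edges as (src, dst) pairs: transactions and device links.
    "edges": [(0, 1), (0, 2), (0, 3), (3, 1)],
    # One label for the whole subgraph, not one per node.
    "label": "high-risk",
}

def num_nodes(g):
    return len(g["node_types"])

def degree(g, node):
    # Count how many edges touch this node (as either endpoint).
    return sum(node in e for e in g["edges"])

# The risk signal here is structural: the account's own connections plus
# the fact that its device (node 3) also links to a merchant.
print(num_nodes(subgraph), degree(subgraph, 0))  # 4 3
```

Each such subgraph becomes one example in a graph classification dataset, with the label attached to the graph rather than to any node.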

How graph classification works

graph_classification.py
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.loader import DataLoader

class GraphClassifier(torch.nn.Module):
    def __init__(self, num_features, hidden_dim, num_classes):
        super().__init__()
        # Message passing layers
        self.conv1 = GCNConv(num_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, hidden_dim)
        # Classification head
        self.classifier = torch.nn.Linear(hidden_dim, num_classes)

    def forward(self, x, edge_index, batch):
        # Step 1: Message passing (node-level)
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)

        # Step 2: Graph pooling (collapse to graph-level)
        x = global_mean_pool(x, batch)  # [num_graphs, hidden_dim]

        # Step 3: Classify
        return self.classifier(x)  # [num_graphs, num_classes]

# DataLoader batches graphs automatically
loader = DataLoader(dataset, batch_size=32, shuffle=True)
model = GraphClassifier(dataset.num_features, 64, dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for batch in loader:
    optimizer.zero_grad()
    out = model(batch.x, batch.edge_index, batch.batch)
    loss = F.cross_entropy(out, batch.y)  # batch.y = graph-level labels
    loss.backward()
    optimizer.step()

The batch tensor is critical: it maps each node to its graph in the mini-batch. global_mean_pool uses it to average within each graph, not across graphs.
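To make the batch tensor's role concrete, here is a minimal pure-Python sketch of the averaging that global_mean_pool performs (PyG's actual implementation does the same thing with vectorized scatter operations on tensors):

```python
# x: one embedding (list of floats) per node; batch: graph index per node.
def mean_pool(x, batch):
    num_graphs = max(batch) + 1
    dim = len(x[0])
    sums = [[0.0] * dim for _ in range(num_graphs)]
    counts = [0] * num_graphs
    for emb, g in zip(x, batch):
        # Route each node's embedding to its graph's accumulator.
        counts[g] += 1
        for j, v in enumerate(emb):
            sums[g][j] += v
    # Average within each graph, never across graphs.
    return [[s / counts[g] for s in sums[g]] for g in range(num_graphs)]

# Five nodes across two graphs: nodes 0-2 belong to graph 0, nodes 3-4 to graph 1.
x = [[1.0, 0.0], [3.0, 0.0], [5.0, 0.0], [2.0, 2.0], [4.0, 4.0]]
batch = [0, 0, 0, 1, 1]
print(mean_pool(x, batch))  # [[3.0, 0.0], [3.0, 3.0]]
```

Without the batch tensor, the pool would average all five nodes together and the two graphs' representations would bleed into each other.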

Concrete example: molecular toxicity prediction

The TU Datasets collection includes molecular graphs:

  • Each graph = one molecule (10-50 atoms)
  • Nodes = atoms with features [element, charge, hybridization]
  • Edges = chemical bonds with features [bond_type, is_aromatic]
  • Label = toxic (1) or non-toxic (0)

After 3 GCNConv layers, each atom's embedding encodes its local chemical environment (3-hop radius). Global mean pooling averages all atom embeddings into a single 64-dimensional molecule vector. The classifier predicts toxicity from this vector.
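The 3-hop claim can be checked with a small breadth-first search: after k message-passing layers, a node's embedding can depend only on nodes within k hops. The toy 6-atom chain below is illustrative, not a real molecule:

```python
from collections import deque

def nodes_within_k_hops(adj, start, k):
    # BFS up to depth k; adj maps node -> list of neighbors.
    seen = {start: 0}
    queue = deque([start])
    while queue:
        node = queue.popleft()
        if seen[node] == k:
            continue  # do not expand past k hops
        for nbr in adj[node]:
            if nbr not in seen:
                seen[nbr] = seen[node] + 1
                queue.append(nbr)
    return set(seen)

# A toy 6-atom chain: 0-1-2-3-4-5.
chain = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3, 5], 5: [4]}

# After 3 message-passing layers, atom 0's embedding can see atoms 0..3
# but not atoms 4 or 5.
print(nodes_within_k_hops(chain, 0, 3))  # {0, 1, 2, 3}
```

This is why the number of GCNConv layers is a modeling choice: it sets the radius of the chemical environment each atom embedding can encode.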

Limitations and what comes next

  1. Pooling information loss: Graph pooling compresses variable-size graphs into fixed-size vectors. Large, complex graphs lose more information than small ones. Hierarchical pooling mitigates this.
  2. Expressiveness: Standard message passing is bounded by the 1-Weisfeiler-Leman (1-WL) test, so it cannot distinguish certain structurally different graphs. If two such graphs map to the same graph-level vector, they will be classified identically even when their true labels differ.
  3. Graph size variation: Enterprise subgraphs can vary enormously in size (a customer with 5 transactions vs. one with 50,000). Models must handle this range without bias toward larger graphs.
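The size-variation point is easy to see by comparing mean and sum pooling on toy numbers: sum pooling scales with node count while mean pooling does not, which is one reason mean pooling is the usual default when graph sizes vary widely:

```python
def sum_pool(x):
    # Element-wise sum over all node embeddings.
    return [sum(col) for col in zip(*x)]

def mean_pool(x):
    # Element-wise mean over all node embeddings.
    return [sum(col) / len(x) for col in zip(*x)]

# Two graphs whose nodes carry the same average signal, at 10x different sizes.
small = [[1.0, 2.0]] * 5    # 5 nodes
large = [[1.0, 2.0]] * 50   # 50 nodes

print(mean_pool(small), mean_pool(large))  # [1.0, 2.0] [1.0, 2.0] -- size-invariant
print(sum_pool(small), sum_pool(large))    # [5.0, 10.0] [50.0, 100.0] -- grows with size
```

Neither choice is universally right: sum pooling preserves information about graph size that mean pooling discards, so the decision depends on whether size itself is predictive.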

Frequently asked questions

What is graph classification?

Graph classification is the task of predicting a label for an entire graph. Unlike node classification (one label per node), graph classification produces one prediction per graph. It requires graph pooling to compress all node representations into a single graph-level vector, which is then fed to a classifier. Common applications include molecular property prediction (is this molecule toxic?) and document classification (what topic is this document graph about?).

How does graph classification differ from node classification?

Node classification makes one prediction per node. Graph classification makes one prediction per graph. Node classification uses per-node embeddings directly. Graph classification requires an additional pooling step that aggregates all node embeddings into a single graph-level vector. The architecture is: GNN layers -> graph pooling -> classifier.

What pooling methods are used for graph classification?

Global mean pooling (average all node embeddings), global sum pooling (sum all node embeddings), global max pooling (element-wise max), hierarchical pooling (DiffPool, SAGPool that progressively coarsen the graph), and attention pooling (learned weighted sum of node embeddings). Global mean pooling is the most common starting point.
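The three global variants differ only in the reduction applied across nodes. A minimal sketch on one graph's node embeddings (PyG provides these as global_mean_pool, global_add_pool, and global_max_pool):

```python
def global_mean(x):
    # Average each embedding dimension over all nodes.
    return [sum(col) / len(x) for col in zip(*x)]

def global_sum(x):
    # Sum each embedding dimension over all nodes.
    return [sum(col) for col in zip(*x)]

def global_max(x):
    # Element-wise max over all nodes.
    return [max(col) for col in zip(*x)]

# One graph with three nodes and 2-dimensional embeddings.
x = [[1.0, 4.0], [2.0, 0.0], [3.0, 2.0]]

print(global_mean(x))  # [2.0, 2.0]
print(global_sum(x))   # [6.0, 6.0]
print(global_max(x))   # [3.0, 4.0]
```

Hierarchical and attention pooling replace these fixed reductions with learned ones, at the cost of extra parameters and compute.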

How does graph classification apply to enterprise data?

In enterprise settings, entity subgraphs become individual graphs for classification. A customer's entire transaction history forms a subgraph; classify it as high-risk or low-risk. A supply chain path forms a subgraph; classify it as reliable or unreliable. A support ticket cluster forms a subgraph; classify its severity. Each entity's relational neighborhood is one graph in the classification dataset.

How is batching handled for graph classification?

PyG batches multiple graphs by concatenating them into a single large disconnected graph. A batch tensor maps each node to its graph index. When computing global_mean_pool(x, batch), the function averages only within each graph using the batch tensor. This is memory-efficient and allows GPU parallelism across graphs.
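The concatenation-with-offsets scheme can be sketched in plain Python (PyG's Batch.from_data_list does the equivalent on tensors):

```python
def batch_graphs(graphs):
    # Each graph: {"num_nodes": int, "edges": [(src, dst), ...]}.
    # Concatenate all graphs into one disconnected graph by offsetting
    # node indices, and build the batch list mapping node -> graph index.
    edges, batch, offset = [], [], 0
    for g_idx, g in enumerate(graphs):
        edges += [(s + offset, d + offset) for s, d in g["edges"]]
        batch += [g_idx] * g["num_nodes"]
        offset += g["num_nodes"]
    return edges, batch

# Graph 0 is a triangle; graph 1 is a single edge.
g0 = {"num_nodes": 3, "edges": [(0, 1), (1, 2), (2, 0)]}
g1 = {"num_nodes": 2, "edges": [(0, 1)]}

edges, batch = batch_graphs([g0, g1])
print(edges)  # [(0, 1), (1, 2), (2, 0), (3, 4)] -- g1's edge shifted by 3
print(batch)  # [0, 0, 0, 1, 1]
```

Because no edges cross the offset boundary, message passing over the combined graph is identical to running it on each graph separately, which is what makes this batching scheme both correct and parallel-friendly.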

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.