
Explainability on Graphs: Understanding Why a GNN Made a Specific Prediction

A GNN flags an account as high-risk. Why? Graph explainability identifies the specific edges, neighbor nodes, and features that drove the prediction, turning a black-box score into an actionable, auditable explanation.


TL;DR

  • GNN predictions are determined by the input subgraph: which nodes, which edges, which features. Explainability identifies the minimal subgraph and features that most influence a specific prediction.
  • GNNExplainer learns soft masks over edges and features to find the most important subset. For a fraud prediction, it identifies which connections and which transaction patterns the model relied on.
  • Attention weights (from GAT) provide approximate explanations: high attention on an edge means that neighbor was weighted heavily. But attention is not always faithful to the true decision process.
  • Graph explanations are richer than tabular: they identify important features, important connections, and important structural patterns. The explanation is a subgraph, not just a feature importance vector.
  • In regulated industries (finance, healthcare), explainability is a compliance requirement. A fraud flag must include the reason. Graph explanations provide it: “this account shares devices with 3 flagged accounts.”

Graph explainability answers the question: which parts of the input graph caused this prediction? When a GNN classifies an account as fraudulent, the explanation is not a list of feature importances. It is a subgraph: the specific neighbor nodes, the specific edges, and the specific features that drove the score. This subgraph explanation is what an analyst needs to investigate the flag, what a regulator needs for compliance, and what a model developer needs for debugging.

What makes graph explainability different

Tabular explainability (SHAP, LIME) answers: “which input features mattered?” Graph explainability must answer three questions simultaneously:

  • Which features? What attributes of the target node mattered (account age, balance)?
  • Which edges? Which connections influenced the prediction (shared device with flagged account)?
  • Which subgraph structure? Which neighborhood patterns drove the score (dense cluster of interconnected accounts)?

The explanation is a subgraph with highlighted edges and features, not a flat list of feature importances. This is both more informative and more challenging to compute and present.

GNNExplainer

GNNExplainer (Ying et al., 2019) is the foundational method. For a specific prediction (node v classified as fraudulent), it learns two soft masks:

  • Edge mask: a weight in [0, 1] for each edge in node v's computation graph. Higher weight = more important for this prediction.
  • Feature mask: a weight in [0, 1] for each input feature dimension. Higher weight = more important.

The masks are optimized to maximize mutual information between the masked subgraph and the original prediction. The result: the minimal set of edges and features that, if kept, produce the same prediction.

Attention-based explanations

Graph Attention Networks (GAT) compute learned attention weights for each edge. These weights are naturally interpretable: a high attention weight on edge (j → i) means neighbor j's message was weighted heavily when computing node i's embedding.

Advantages of attention-based explanations:

  • Free: no additional computation beyond the forward pass
  • Per-layer: attention weights at each layer show how information flows through the graph
  • Intuitive: “the model paid attention to these neighbors”

Limitations:

  • Attention is not always faithful: high attention does not always mean high causal influence on the prediction
  • Attention only explains which neighbors were weighted, not which features were important
  • Only available for attention-based architectures (GAT, TransformerConv)

Other explanation methods

  • PGExplainer: learns a global explanation model (not per-instance like GNNExplainer). Faster at inference: one forward pass generates explanations for any node.
  • SubgraphX: uses Monte Carlo tree search to find the most important connected subgraph. Produces explanations that are connected subgraphs (more interpretable than disconnected edge sets).
  • GraphMask: learns which edges can be removed without changing the prediction. Edges that cannot be removed are the explanation.
  • Gradient-based: compute gradients of the prediction with respect to node features and edge weights. Higher gradient = more important. Fast but noisy.

Enterprise requirements

In regulated industries, explainability is not optional:

  • Financial services: regulators require explanations for fraud flags, credit decisions, and AML alerts. “The model said so” is not acceptable.
  • Healthcare: clinical decision support must explain why a patient is flagged as high-risk.
  • Insurance: claim denial must be justified with specific factors.

Graph explanations are especially valuable because they identify not just what features matter but which relationships matter: “this claim is flagged because the claimant shares a phone number with 3 other claimants who filed similar claims within 2 weeks.”

Frequently asked questions

Why is explainability important for GNNs?

In regulated industries (finance, healthcare), predictions must be explainable for compliance. A fraud flag must come with a reason: “this account shares a device with 3 flagged accounts and received deposits from a known laundering entity.” Without explainability, GNN predictions are black boxes that cannot be acted upon. Explainability also helps debug models and build trust with business stakeholders.

What is GNNExplainer?

GNNExplainer learns a soft mask over edges and node features that identifies the minimal subgraph and feature subset that is most important for a specific prediction. For a node classified as fraudulent, GNNExplainer identifies which edges (connections to which accounts) and which features (transaction amounts, timing patterns) the model relied on. It optimizes the mask to maximize mutual information between the masked subgraph and the prediction.

Can attention weights be used as explanations?

Attention weights from GAT indicate which neighbors the model attended to most for each prediction. High attention on a specific edge means that neighbor's message was weighted heavily. However, attention weights are not always faithful explanations: they show what the model looked at, not necessarily what caused the prediction. They are useful as approximate explanations but should not be the sole basis for compliance reporting.

How does graph explainability differ from tabular explainability?

Tabular explainability (SHAP, LIME) attributes importance to input features. Graph explainability must attribute importance to three things: node features (which attributes mattered), edges (which connections mattered), and subgraph structure (which neighborhood patterns mattered). The explanation is a subgraph, not a feature vector. This richer explanation format is both more informative and more complex to present.

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.