
Over-Smoothing: Why Deep GNNs Fail

Over-smoothing is the primary reason GNNs cannot be stacked as deep as CNNs or transformers. After too many layers, every node's representation converges to the same value, making the model unable to distinguish between nodes.


TL;DR

  1. Over-smoothing occurs when stacking too many GNN layers causes all node representations to converge. Information gets diffused across the graph until local differences wash out.
  2. Practical limit: 2-3 layers for most tasks. On Cora, GCN accuracy peaks at 2-3 layers (~81.5%) and collapses below 30% at 8+ layers. Dense graphs smooth faster than sparse ones.
  3. Each message passing layer is a diffusion step. After k layers, each node's representation is a weighted average of all nodes within k hops. Eventually, all nodes see (nearly) the entire graph.
  4. Mitigation techniques: skip connections (preserve early features), DropEdge (slow diffusion), normalization (prevent collapse), graph rewiring (modify topology).
  5. Graph transformers avoid over-smoothing entirely by using global attention instead of iterative local aggregation. KumoRFM's graph transformer architecture does not suffer from this limitation.

Over-smoothing is the phenomenon where node representations in a graph neural network converge to indistinguishable values as more message passing layers are stacked. Each GNN layer aggregates information from neighbors, spreading signals further across the graph. After enough layers, every node has received information from the entire graph, and the averaging effect of aggregation washes out the local differences that distinguish one node from another. The representations become “smooth” and the model loses its ability to make accurate per-node predictions.

Why it matters for enterprise data

Enterprise relational databases contain rich multi-hop patterns. A customer's fraud risk depends on merchants 3 hops away. A product's demand depends on supplier reliability 4 hops away. Capturing these patterns requires deep GNNs. But over-smoothing limits practical depth to 2-3 layers, meaning GNNs can only see 2-3 hops into the relational graph.

This is a real constraint. On the RelBench benchmark, GNNs with 2 layers achieve 75.83 AUROC. But deeper GNNs do not improve. KumoRFM's graph transformer, which bypasses over-smoothing with global attention, reaches 81.14 AUROC after fine-tuning.

How over-smoothing works

Think of message passing as diffusion. Each layer spreads each node's information to its neighbors. After k layers:

  • k=1: Each node knows about its direct neighbors
  • k=2: Each node knows about its 2-hop neighborhood
  • k=5: Each node knows about its 5-hop neighborhood (often the entire graph for small-world networks)
  • k=10: All nodes have aggregated nearly identical global information
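
This receptive-field growth can be checked directly with adjacency-matrix powers: the (i, j) entry of (A + I)^k is nonzero exactly when node j lies within k hops of node i. A minimal sketch on a hypothetical 6-node path graph:

```python
import numpy as np

# Path graph 0-1-2-3-4-5 (a hypothetical toy example)
n = 6
A = np.zeros((n, n))
for i in range(n - 1):
    A[i, i + 1] = A[i + 1, i] = 1
A_self = A + np.eye(n)  # include self-loops, as GCN does

reach = np.eye(n)
for k in range(1, 6):
    reach = reach @ A_self
    # The count grows by one hop per layer until the whole path is visible
    print(f"k={k}: node 0 sees {int((reach[0] > 0).sum())} nodes")
```

On this path graph the end node needs 5 layers to see all 6 nodes; on a small-world graph with short average path length, the whole graph becomes visible after far fewer layers.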

Mathematically, repeated application of the symmetrically normalized adjacency matrix drives all node representations toward its dominant eigenvector, the analogue of the stationary distribution of a random walk on the graph. That eigenvector's entries are proportional to the square root of each node's degree, so in the limit nodes differ only by a degree-dependent scale factor.
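
A quick way to see this convergence numerically (a toy NumPy sketch on a hypothetical 5-node star graph, not any real dataset): repeatedly applying the symmetrically normalized adjacency matrix D^{-1/2}(A + I)D^{-1/2} pulls any starting vector toward a vector proportional to the square root of node degrees.

```python
import numpy as np

rng = np.random.default_rng(0)

# Star graph: hub 0 connected to leaves 1..4 (hypothetical toy graph)
n = 5
A = np.zeros((n, n))
A[0, 1:] = A[1:, 0] = 1
A_hat = A + np.eye(n)                      # add self-loops
deg = A_hat.sum(axis=1)
D_inv_sqrt = np.diag(deg ** -0.5)
A_norm = D_inv_sqrt @ A_hat @ D_inv_sqrt   # symmetric normalization

x = rng.standard_normal(n)
for _ in range(50):                        # 50 "layers" of pure diffusion
    x = A_norm @ x

# Limit direction is proportional to the square root of node degree
limit = np.sqrt(deg)
cos = abs(x @ limit) / (np.linalg.norm(x) * np.linalg.norm(limit))
print(f"cosine with sqrt-degree vector: {cos:.4f}")  # approaches 1
```

Real GNN layers interleave learned weights and nonlinearities with this diffusion, which slows but does not remove the underlying pull toward the fixed direction.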

Concrete example: customer nodes losing identity

Consider a banking graph with 10,000 customer nodes, 500,000 transaction edges, and 5,000 merchant nodes. Average shortest path is 4 hops.

  • 2 layers: Customer “Alice” (frequent online shopper) and “Bob” (in-store only) have distinct embeddings. Cosine similarity: 0.3.
  • 4 layers: Both Alice and Bob have aggregated information from shared merchants. Cosine similarity: 0.7.
  • 8 layers: Both have aggregated information from nearly the entire graph. Cosine similarity: 0.95. The model can barely tell them apart.
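
The same trend appears in any toy graph with community structure. The sketch below (a hypothetical stand-in for the banking graph, not real data) joins two 4-node cliques by a single bridge edge and diffuses random features; the cosine similarity between one node from each clique climbs toward 1.0 as layers accumulate:

```python
import numpy as np

rng = np.random.default_rng(1)

# Two 4-node cliques joined by one bridge edge (toy stand-in for
# "Alice" in one merchant cluster and "Bob" in another)
n = 8
A = np.zeros((n, n))
A[:4, :4] = 1          # clique 1 (diagonal doubles as self-loops)
A[4:, 4:] = 1          # clique 2
A[3, 4] = A[4, 3] = 1  # bridge edge
deg = A.sum(axis=1)
A_norm = np.diag(deg ** -0.5) @ A @ np.diag(deg ** -0.5)

X = rng.standard_normal((n, 16))   # random 16-dim node features
alice, bob = 0, 7                  # one node from each clique

def cos_sim(u, v):
    return float(u @ v / (np.linalg.norm(u) * np.linalg.norm(v)))

for k in [0, 2, 8, 32]:
    Xk = np.linalg.matrix_power(A_norm, k) @ X
    sim = cos_sim(Xk[alice], Xk[bob])
    print(f"{k} diffusion steps: cos(alice, bob) = {sim:.2f}")
```

Because the bridge is a bottleneck, cross-clique similarity rises more slowly than within-clique similarity, but with enough layers it still converges toward 1.0.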

Detection and mitigation

Detecting over-smoothing

detect_over_smoothing.py
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

def measure_smoothing(model, data, num_layers):
    """Track embedding similarity across layers."""
    x = data.x
    similarities = []

    for i, conv in enumerate(model.convs):
        x = F.relu(conv(x, data.edge_index))
        # Average pairwise cosine similarity (O(N^2) memory;
        # subsample nodes on large graphs)
        x_norm = F.normalize(x, dim=1)
        sim = (x_norm @ x_norm.T).mean().item()
        similarities.append(sim)
        print(f"Layer {i+1}: avg cosine similarity = {sim:.4f}")

    return similarities
# Healthy: similarity stays below 0.5
# Over-smoothed: similarity approaches 1.0

Track cosine similarity between node embeddings after each layer. When similarity approaches 1.0, over-smoothing has occurred.

Mitigation techniques

mitigate_over_smoothing.py
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.utils import dropout_edge

class DeepGCN(torch.nn.Module):
    """GCN with skip connections and DropEdge to resist over-smoothing."""
    def __init__(self, in_dim, hidden_dim, out_dim, num_layers=8):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        self.convs.append(GCNConv(in_dim, hidden_dim))
        for _ in range(num_layers - 2):
            self.convs.append(GCNConv(hidden_dim, hidden_dim))
        self.convs.append(GCNConv(hidden_dim, out_dim))

    def forward(self, x, edge_index):
        # DropEdge: randomly remove 20% of edges each forward pass
        edge_index_dropped, _ = dropout_edge(edge_index, p=0.2,
                                              training=self.training)

        for i, conv in enumerate(self.convs[:-1]):
            identity = x  # save for skip connection
            x = conv(x, edge_index_dropped)
            x = F.relu(x)
            if x.shape == identity.shape:
                x = x + identity  # residual (skip) connection
            x = F.dropout(x, p=0.5, training=self.training)

        x = self.convs[-1](x, edge_index_dropped)
        return x

Skip connections preserve early-layer features. DropEdge randomly removes edges to slow diffusion. Together they extend practical depth from 2-3 to 6-8 layers.

Limitations of mitigation techniques

  1. Skip connections help but do not solve: They preserve early-layer features alongside smoothed deeper features. The deep features are still smoothed. Practical improvement: from 2-3 layers to 6-8 layers.
  2. DropEdge slows but does not stop: Removing edges reduces the diffusion rate but does not change the convergence destination. Given enough layers, smoothing still occurs.
  3. Normalization prevents collapse but limits capacity: PairNorm forces embeddings to maintain spread, but this constraint can prevent the model from learning useful representations.
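
For reference, the core PairNorm operation is small enough to sketch: center the embeddings across nodes, then rescale so the mean row norm stays fixed, which keeps total pairwise distance from collapsing. A minimal NumPy version of this centering-and-rescaling scheme (`scale` is a tunable hyperparameter):

```python
import numpy as np

def pairnorm(x, scale=1.0, eps=1e-6):
    """Center node embeddings, then rescale to a fixed mean row norm."""
    # Step 1: subtract the per-feature mean so embeddings cannot all
    # collapse onto a single point
    x = x - x.mean(axis=0, keepdims=True)
    # Step 2: rescale so the root-mean-square row norm equals `scale`
    rms = np.sqrt((x ** 2).sum(axis=1).mean() + eps)
    return scale * x / rms

# Even heavily smoothed embeddings keep a fixed spread after PairNorm
x = np.ones((100, 8)) + 1e-2 * np.random.default_rng(0).standard_normal((100, 8))
out = pairnorm(x)
print(np.sqrt((out ** 2).sum(axis=1).mean()))  # close to 1.0
```

This is applied after each GNN layer; the trade-off noted above is that the forced spread is a hard constraint the model cannot relax even when tight clusters would be useful.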

Graph transformers offer a fundamentally different architecture. Instead of iterative local aggregation (which inherently diffuses), they use global attention where every node attends directly to every other node. No diffusion, no over-smoothing. KumoRFM's Relational Graph Transformer is built on this principle.
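
A toy single-head global attention in NumPy (a generic transformer attention sketch, not KumoRFM's actual implementation) makes the contrast concrete: every node attends to every other node in a single step, so the receptive field covers the whole graph at depth one, and the mixing weights are computed from features rather than fixed by graph topology:

```python
import numpy as np

rng = np.random.default_rng(0)

n, d = 6, 16                     # 6 nodes, 16-dim features
X = rng.standard_normal((n, d))
Wq, Wk, Wv = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))

Q, K, V = X @ Wq, X @ Wk, X @ Wv
scores = Q @ K.T / np.sqrt(d)    # every node scores every other node
weights = np.exp(scores - scores.max(axis=1, keepdims=True))
weights /= weights.sum(axis=1, keepdims=True)   # row-wise softmax
out = weights @ V                # one step: full-graph receptive field

print(weights.shape)             # (6, 6): each node attends to all nodes
```

Because the attention matrix is recomputed from the inputs at every layer instead of being a fixed diffusion operator, stacking layers refines the mixing rather than converging it to a stationary distribution.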

Frequently asked questions

What is over-smoothing in GNNs?

Over-smoothing is the phenomenon where stacking too many GNN layers causes all node representations to converge to similar or identical values. After each message passing layer, nodes average information from their neighbors. After enough layers, every node has aggregated information from the entire graph, washing out local differences. The representations become “over-smoothed” and nodes become indistinguishable, destroying the model's ability to make accurate per-node predictions.

How many GNN layers before over-smoothing occurs?

Over-smoothing typically becomes noticeable after 4-5 layers and severe after 6-8 layers. On the Cora citation network, GCN accuracy peaks at 2-3 layers (~81.5%) and drops below 30% at 8+ layers. The exact threshold depends on graph structure: dense graphs smooth faster than sparse ones, and heterogeneous graphs resist smoothing longer than homogeneous ones.

Why can't GNNs be as deep as CNNs or transformers?

CNNs and transformers benefit from depth because information is organized on regular grids or sequences with clear spatial/positional relationships. In graphs, repeated aggregation is fundamentally a diffusion process. Each layer spreads information further, and after enough layers, the diffusion reaches equilibrium where all nodes have the same mixture of global information. Skip connections and normalization help but do not fully solve the problem.

How do you detect over-smoothing?

Measure the pairwise cosine similarity between node embeddings after each layer. In a healthy GNN, embeddings maintain diversity (average similarity < 0.5). When over-smoothing occurs, average similarity approaches 1.0 as embeddings converge. You can also track validation accuracy as you add layers: a sharp drop beyond a certain depth signals over-smoothing.

What techniques mitigate over-smoothing?

Five main approaches: (1) Skip connections (residual connections) that preserve early-layer features. (2) DropEdge: randomly remove edges during training to slow information diffusion. (3) PairNorm or NodeNorm: normalization that prevents embedding collapse. (4) Graph rewiring to modify topology. (5) Graph transformers that replace local aggregation with global attention, avoiding the smoothing problem entirely.

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.