Over-smoothing is the phenomenon where node representations in a graph neural network converge to indistinguishable values as more message passing layers are stacked. Each GNN layer aggregates information from neighbors, spreading signals further across the graph. After enough layers, every node has received information from the entire graph, and the averaging effect of aggregation washes out the local differences that distinguish one node from another. The representations become “smooth” and the model loses its ability to make accurate per-node predictions.
Why it matters for enterprise data
Enterprise relational databases contain rich multi-hop patterns. A customer's fraud risk depends on merchants 3 hops away. A product's demand depends on supplier reliability 4 hops away. Capturing these patterns requires deep GNNs. But over-smoothing limits practical depth to 2-3 layers, meaning GNNs can only see 2-3 hops into the relational graph.
This is a real constraint. On the RelBench benchmark, GNNs with 2 layers achieve 75.83 AUROC. But deeper GNNs do not improve. KumoRFM's graph transformer, which bypasses over-smoothing with global attention, reaches 81.14 AUROC after fine-tuning.
How over-smoothing works
Think of message passing as diffusion. Each layer spreads each node's information to its neighbors. After k layers:
- k=1: Each node knows about its direct neighbors
- k=2: Each node knows about its 2-hop neighborhood
- k=5: Each node knows about its 5-hop neighborhood (often the entire graph for small-world networks)
- k=10: All nodes have aggregated nearly identical global information
Mathematically, repeated multiplication by the normalized adjacency matrix projects node features onto its dominant eigenvector, which corresponds to the stationary distribution of a random walk on the graph. For the symmetric normalization used by GCNs, all node representations converge to vectors proportional to the square root of node degree: in the limit, a node's embedding depends only on its degree, not on its features or local structure.
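This convergence is easy to check numerically. The sketch below uses a toy 4-node graph (an assumed example, not from the text): repeatedly applying the symmetric-normalized adjacency to an arbitrary node signal drives it toward the square-root-of-degree direction.

```python
import numpy as np

# Toy 4-node graph (assumed example): triangle 0-1-2 plus pendant node 3
A = np.array([[0, 1, 1, 0],
              [1, 0, 1, 0],
              [1, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)
deg = A.sum(axis=1)
A_hat = A / np.sqrt(np.outer(deg, deg))  # symmetric-normalized adjacency

x = np.array([1.0, -2.0, 0.5, 3.0])      # arbitrary initial node signal
for _ in range(200):
    x = A_hat @ x                        # one round of diffusion per step

# The limiting direction is proportional to sqrt(degree)
direction = x / np.linalg.norm(x)
target = np.sqrt(deg) / np.linalg.norm(np.sqrt(deg))
print(direction)
print(target)
```

Because the two printed vectors match, the diffused signal has forgotten everything about the initial features except each node's degree.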
Concrete example: customer nodes losing identity
Consider a banking graph with 10,000 customer nodes, 500,000 transaction edges, and 5,000 merchant nodes. Average shortest path is 4 hops.
- 2 layers: Customer “Alice” (frequent online shopper) and “Bob” (in-store only) have distinct embeddings. Cosine similarity: 0.3.
- 4 layers: Both Alice and Bob have aggregated information from shared merchants. Cosine similarity: 0.7.
- 8 layers: Both have aggregated information from nearly the entire graph. Cosine similarity: 0.95. The model can barely tell them apart.
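The same rising-similarity pattern can be reproduced on synthetic data. The sketch below uses an assumed random graph as a stand-in for the banking graph and applies pure diffusion (no learned weights), measuring average pairwise cosine similarity after 2, 4, and 8 propagation steps:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200
# Assumed stand-in graph: symmetric random adjacency, ~5% edge density
A = (rng.random((n, n)) < 0.05).astype(float)
A = np.triu(A, 1)
A = A + A.T + np.eye(n)                  # undirected edges + self-loops
deg = A.sum(axis=1)
A_hat = A / np.sqrt(np.outer(deg, deg))  # symmetric normalization

def avg_cos_sim(x):
    xn = x / np.linalg.norm(x, axis=1, keepdims=True)
    return float((xn @ xn.T).mean())

X = rng.standard_normal((n, 16))         # random 16-dim node features
sims = []
for depth in (2, 4, 8):
    Xk = X.copy()
    for _ in range(depth):
        Xk = A_hat @ Xk                  # pure diffusion, no learned weights
    sims.append(avg_cos_sim(Xk))
    print(f"{depth} layers: avg cosine similarity = {sims[-1]:.3f}")
```

The similarities grow with depth, mirroring the Alice-and-Bob progression above; real GNN layers add learned weights and nonlinearities, but the underlying diffusion behaves the same way.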
Detection and mitigation
Detecting over-smoothing
```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

def measure_smoothing(model, data):
    """Track embedding similarity across layers."""
    x = data.x
    similarities = []
    for i, conv in enumerate(model.convs):
        x = F.relu(conv(x, data.edge_index))
        # Compute average pairwise cosine similarity
        x_norm = F.normalize(x, dim=1)
        sim = (x_norm @ x_norm.T).mean().item()
        similarities.append(sim)
        print(f"Layer {i+1}: avg cosine similarity = {sim:.4f}")
    return similarities

# Healthy: similarity stays below 0.5
# Over-smoothed: similarity approaches 1.0
```

Track cosine similarity between node embeddings after each layer. When similarity approaches 1.0, over-smoothing has occurred.
Mitigation techniques
```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.utils import dropout_edge

class DeepGCN(torch.nn.Module):
    """GCN with skip connections and DropEdge to resist over-smoothing."""

    def __init__(self, in_dim, hidden_dim, out_dim, num_layers=8):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        self.convs.append(GCNConv(in_dim, hidden_dim))
        for _ in range(num_layers - 2):
            self.convs.append(GCNConv(hidden_dim, hidden_dim))
        self.convs.append(GCNConv(hidden_dim, out_dim))

    def forward(self, x, edge_index):
        # DropEdge: randomly remove 20% of edges each forward pass
        edge_index_dropped, _ = dropout_edge(edge_index, p=0.2,
                                             training=self.training)
        for conv in self.convs[:-1]:
            identity = x  # save for skip connection
            x = conv(x, edge_index_dropped)
            x = F.relu(x)
            if x.shape == identity.shape:
                x = x + identity  # residual (skip) connection
            x = F.dropout(x, p=0.5, training=self.training)
        x = self.convs[-1](x, edge_index_dropped)
        return x
```

Skip connections preserve early-layer features. DropEdge randomly removes edges to slow diffusion. Together they extend practical depth from 2-3 to 6-8 layers.
Limitations of mitigation techniques
- Skip connections help but do not solve: They preserve early-layer features alongside smoothed deeper features. The deep features are still smoothed. Practical improvement: from 2-3 layers to 6-8 layers.
- DropEdge slows but does not stop: Removing edges reduces the diffusion rate but does not change the convergence destination. Given enough layers, smoothing still occurs.
- Normalization prevents collapse but limits capacity: PairNorm forces embeddings to maintain spread, but this constraint can prevent the model from learning useful representations.
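For reference, the normalization step PairNorm performs is only a few lines. The sketch below is a minimal NumPy version of the idea (center, then rescale to a fixed spread), not the authors' official implementation; `scale` and `eps` are assumed hyperparameter names.

```python
import numpy as np

def pair_norm(x, scale=1.0, eps=1e-6):
    """Minimal PairNorm-style step: center embeddings across nodes, then
    rescale so the mean squared row norm equals scale**2. This keeps the
    embeddings spread out instead of collapsing toward a single point."""
    x = x - x.mean(axis=0, keepdims=True)             # remove the shared mean
    rms = np.sqrt((x ** 2).sum(axis=1).mean() + eps)  # current average spread
    return scale * x / rms
```

Applied after each convolution, this keeps the total pairwise distance between embeddings roughly constant across layers; the cost is exactly the capacity limit noted above, since the spread is forced rather than learned.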
Graph transformers offer a fundamentally different architecture. Instead of iterative local aggregation (which inherently diffuses), they use global attention where every node attends directly to every other node. No diffusion, no over-smoothing. KumoRFM's Relational Graph Transformer is built on this principle.
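The contrast with message passing can be made concrete. The sketch below implements single-head global attention over node features in NumPy; it is an illustrative toy under assumed weight matrices `Wq`, `Wk`, `Wv`, not KumoRFM's actual architecture. Each node's output is a content-weighted mix of all nodes computed in one step, so no hop-by-hop diffusion occurs.

```python
import numpy as np

def global_attention(X, Wq, Wk, Wv):
    """Single-head global attention: every node attends directly to every
    other node, so distant information arrives in one step rather than
    diffusing through intermediate hops."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    scores = (Q @ K.T) / np.sqrt(K.shape[1])      # scaled dot products
    scores -= scores.max(axis=1, keepdims=True)   # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True) # softmax over all nodes
    return weights @ V                            # content-based mixing

rng = np.random.default_rng(1)
X = rng.standard_normal((10, 8))                  # 10 nodes, 8-dim features
Wq, Wk, Wv = (rng.standard_normal((8, 8)) for _ in range(3))
out = global_attention(X, Wq, Wk, Wv)
print(out.shape)
```

Because the mixing weights come from learned content similarity rather than repeated neighborhood averaging, stacking such layers does not push all node representations toward a common stationary vector.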