What GENConv does
GENConv is the message-passing layer from the DeeperGCN framework. It introduces a generalized aggregation function:
- Compute messages from neighbors, folding in edge features when available
- Aggregate with a softmax-weighted sum whose temperature is learnable
- Update the node via an MLP; the residual connection and normalization come from the surrounding DeepGCNLayer
The math (simplified)
# Generalized aggregation (softmax variant)
m_j  = ReLU(h_j + e_ij) + ε              # message, with optional edge features
w_j  = exp(t · m_j) / Σ_k exp(t · m_k)   # softmax with inverse temperature t
h_i' = MLP(h_i + Σ_j w_j · m_j)          # aggregate, then update through an MLP
When t -> 0: weights become uniform, approaching mean aggregation
When t -> infinity: weight concentrates on the largest message, approaching max aggregation
Intermediate t: a learnable, attention-like weighting, applied elementwise per feature channel
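The mean/max limits are easy to check numerically. Below is a minimal pure-Python sketch (the helper `softmax_agg` and the sample values are illustrative, not PyG API; it treats messages as scalars, whereas the real layer weights each feature channel independently):

```python
import math

def softmax_agg(messages, t):
    """Softmax-weighted sum of scalar messages with inverse temperature t."""
    weights = [math.exp(t * m) for m in messages]
    total = sum(weights)
    return sum((w / total) * m for w, m in zip(weights, messages))

msgs = [1.0, 2.0, 4.0]
print(softmax_agg(msgs, t=0.001))  # ≈ 2.333, the mean of msgs
print(softmax_agg(msgs, t=100.0))  # ≈ 4.0, the max of msgs
```

Sweeping t between these extremes interpolates smoothly between mean and max, which is exactly the degree of freedom learn_t exposes to the optimizer.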
# DeeperGCN block (pre-activation residual)
h = BatchNorm(h)
h = ReLU(h)
h = GENConv(h, edge_index)
h = h + h_residual # skip connection back to the block input

The temperature parameter t is learnable, allowing the model to find the optimal aggregation between mean and max for each task.
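The ordering matters: in the pre-activation ('res+') block, normalization and activation are applied before the convolution, and the residual adds back the untouched block input. A toy sketch with stand-in functions (the three ops below are placeholders, not the real layers) makes the ordering concrete:

```python
# Toy stand-ins for the real ops, only to show the 'res+' ordering.
def norm(h):  return [v * 0.5 for v in h]        # stands in for BatchNorm
def act(h):   return [max(v, 0.0) for v in h]    # stands in for ReLU
def conv(h):  return [v + 1.0 for v in h]        # stands in for GENConv

def res_plus_block(h):
    out = conv(act(norm(h)))                # pre-activation: norm -> act -> conv
    return [o + r for o, r in zip(out, h)]  # residual adds the *block input*

print(res_plus_block([2.0, -4.0]))  # [4.0, -3.0]
```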
PyG implementation
import torch
from torch_geometric.nn import GENConv, DeepGCNLayer, global_mean_pool

class DeeperGCN(torch.nn.Module):
    def __init__(self, in_channels, hidden, out_channels, num_layers=14,
                 edge_dim=3):  # edge_dim must match the width of edge_attr
        super().__init__()
        self.node_encoder = torch.nn.Linear(in_channels, hidden)
        self.layers = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = GENConv(hidden, hidden, aggr='softmax', t=1.0,
                           learn_t=True, num_layers=2, edge_dim=edge_dim)
            norm = torch.nn.BatchNorm1d(hidden)
            act = torch.nn.ReLU(inplace=True)
            layer = DeepGCNLayer(conv, norm, act,
                                 block='res+', dropout=0.1)
            self.layers.append(layer)
        self.classifier = torch.nn.Linear(hidden, out_channels)

    def forward(self, x, edge_index, edge_attr, batch):
        x = self.node_encoder(x)
        # 'res+' applies norm/act before conv, so run the first conv directly
        # and finish with the norm/act that the last block deferred.
        x = self.layers[0].conv(x, edge_index, edge_attr)
        for layer in self.layers[1:]:
            x = layer(x, edge_index, edge_attr)
        x = self.layers[0].act(self.layers[0].norm(x))
        x = global_mean_pool(x, batch)  # graph-level readout
        return self.classifier(x)
model = DeeperGCN(in_channels=9, hidden=256, out_channels=1,
                  num_layers=14)

DeepGCNLayer wraps GENConv with pre-activation residual connections: block='res+' applies normalization and activation before the convolution, then adds the block input back. learn_t=True makes the softmax temperature a learnable parameter.
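What a single GENConv neighborhood update computes can be sketched in pure Python (illustrative helper names, not PyG API; the real layer is vectorized over tensors, and the message form ReLU(h_j + e_ij) + eps follows the DeeperGCN formulation):

```python
import math

def message(h_j, e_ij, eps=1e-7):
    """GENConv-style message: ReLU(h_j + e_ij) + eps, elementwise."""
    return [max(a + b, 0.0) + eps for a, b in zip(h_j, e_ij)]

def softmax_aggregate(messages, t=1.0):
    """Softmax-weighted sum, applied independently per feature channel."""
    out = []
    for d in range(len(messages[0])):
        col = [m[d] for m in messages]
        w = [math.exp(t * c) for c in col]
        s = sum(w)
        out.append(sum((wi / s) * c for wi, c in zip(w, col)))
    return out

# Two neighbors with 2-dim node features and 2-dim edge features:
msgs = [message([1.0, 2.0], [0.0, 1.0]),    # -> ≈ [1.0, 3.0]
        message([3.0, 0.0], [-4.0, 1.0])]   # -> ≈ [0.0, 1.0]
print(softmax_aggregate(msgs, t=10.0))   # sharp:  ≈ [1.0, 3.0] (per-channel max)
print(softmax_aggregate(msgs, t=0.001))  # nearly uniform: ≈ [0.5, 2.0] (per-channel mean)
```

The aggregated vector would then be added to the node's own features and passed through the layer's internal MLP (the num_layers=2 argument above controls that MLP's depth).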
When to use GENConv
- Molecular property prediction. Complex chemical properties depend on long-range atomic interactions that require deep propagation to capture.
- Point cloud processing. 3D shape understanding benefits from deep models that can propagate information across large spatial extents.
- When you need more than 3-4 layers. Any task where oversmoothing caps a shallow GCNConv stack and deeper propagation is the bottleneck.
When not to use GENConv
- Shallow tasks. If 2-3 layers of GCNConv already achieve your target, GENConv's depth overhead is unnecessary.
- Heterogeneous graphs. GENConv is designed for homogeneous graphs. Use HGTConv for multi-type data.