Original Paper
Principal Neighbourhood Aggregation for Graph Nets
Corso et al. (2020). NeurIPS 2020
What PNAConv does
PNAConv addresses a fundamental limitation: no single aggregation function captures all neighborhood information. Its approach:
- Apply multiple aggregation functions in parallel: mean, sum, max, and standard deviation
- Scale each aggregated result by degree-based scalers: identity, amplification (log-degree), attenuation (inverse log-degree)
- Concatenate all aggregator-scaler combinations
- Pass through a linear transformation to combine them
With 4 aggregators and 3 scalers, each node's neighborhood is summarized by 12 distinct views instead of one, giving the combining linear layer far more signal than a single-aggregator layer.
The math (simplified)
# Multiple aggregations in parallel
agg_mean = MEAN(h_j for j in N(i))
agg_sum = SUM(h_j for j in N(i))
agg_max = MAX(h_j for j in N(i))
agg_std = STD(h_j for j in N(i))
# Multiple scalers per aggregation
for each agg in [mean, sum, max, std]:
    s_identity = agg
    s_amplify = agg * log(deg(i) + 1)
    s_attenuate = agg / log(deg(i) + 1)
# Concatenate all combinations (4 agg x 3 scalers = 12 vectors)
h_i' = W · CONCAT(all scaler outputs) + b
4 aggregators x 3 scalers = 12 neighborhood views per node. The linear layer learns which combinations matter for the task. In the full formulation, log(deg + 1) is normalized by the average log degree of the training graphs, which is why PNAConv requires a degree histogram.
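The 12-view computation above can be sketched in plain Python for a single node with scalar features. This is a minimal sketch of the simplified math, not PNAConv's implementation: the function name `pna_views` is made up, and the paper's normalization of the log-degree scalers by the training set's average log degree is omitted.

```python
import math

def pna_views(neighbor_feats):
    """Compute the 12 aggregator-scaler views for one node.

    neighbor_feats: scalar features of the node's neighbors.
    Returns a 12-element list (4 aggregators x 3 scalers).
    """
    n = len(neighbor_feats)
    mean = sum(neighbor_feats) / n
    total = sum(neighbor_feats)
    maximum = max(neighbor_feats)
    std = math.sqrt(sum((x - mean) ** 2 for x in neighbor_feats) / n)

    log_deg = math.log(n + 1)  # simplified scaler, no delta normalization
    views = []
    for agg in (mean, total, maximum, std):
        # identity, amplification, attenuation -- in that order
        views.extend([agg, agg * log_deg, agg / log_deg])
    return views

views = pna_views([1.0, 2.0, 3.0])  # a degree-3 node
assert len(views) == 12
```

In the real layer each view is a vector, so the concatenation is 12× the hidden width before the linear projection maps it back down.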
PyG implementation
import torch
import torch.nn.functional as F
from torch_geometric.nn import PNAConv, global_add_pool
from torch_geometric.utils import degree
class PNA(torch.nn.Module):
    def __init__(self, in_channels, hidden, out_channels, edge_dim, deg):
        super().__init__()
        aggregators = ['mean', 'sum', 'max', 'std']
        scalers = ['identity', 'amplification', 'attenuation']
        self.conv1 = PNAConv(in_channels, hidden, aggregators=aggregators,
                             scalers=scalers, deg=deg, edge_dim=edge_dim)
        self.conv2 = PNAConv(hidden, hidden, aggregators=aggregators,
                             scalers=scalers, deg=deg, edge_dim=edge_dim)
        self.classifier = torch.nn.Linear(hidden, out_channels)

    def forward(self, x, edge_index, edge_attr, batch):
        x = self.conv1(x, edge_index, edge_attr)
        x = F.relu(x)
        x = self.conv2(x, edge_index, edge_attr)
        x = F.relu(x)
        x = global_add_pool(x, batch)
        return self.classifier(x)

# Compute degree histogram (required, do once)
deg = degree(data.edge_index[1], data.num_nodes, dtype=torch.long)
deg_hist = torch.bincount(deg, minlength=1)

model = PNA(in_channels=9, hidden=64, out_channels=1,
            edge_dim=3, deg=deg_hist)
The degree histogram is computed once from the training graph. PNAConv also supports edge features natively via edge_dim.
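For intuition about what the `deg` argument holds: the histogram simply counts how many nodes have each in-degree. A dependency-free sketch of the `degree` + `bincount` step (the function name `degree_histogram` and the toy edge list are made up for illustration):

```python
from collections import Counter

def degree_histogram(edge_index_dst, num_nodes):
    """Count each node's in-degree, then tally how many nodes have
    each degree value -- mirroring degree() followed by bincount()."""
    deg = [0] * num_nodes
    for dst in edge_index_dst:
        deg[dst] += 1
    counts = Counter(deg)
    return [counts.get(d, 0) for d in range(max(deg) + 1)]

# Toy graph: 4 nodes, edges pointing into nodes 1, 1, and 2
hist = degree_histogram([1, 1, 2], num_nodes=4)  # → [2, 1, 1]
```

Here `hist[d]` is the number of nodes with in-degree `d`: two isolated targets, one node of degree 1, one of degree 2. PNAConv uses this distribution to normalize its degree scalers.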
When to use PNAConv
- Molecular property prediction. Molecules have diverse local structures: aromatic rings, functional groups, chain branches. Multiple aggregators capture this variety better than sum alone.
- When you need more expressiveness than GIN. If GINConv hits a ceiling on your graph classification task, PNAConv is the next step up before moving to full graph transformers.
- Graphs with variable degree distribution. The degree-based scalers help PNAConv handle graphs where node degrees range from 1 to 1000+, normalizing the aggregation appropriately.
- As the local layer in GPSConv. PNA is a common choice for the local message passing component in graph transformer architectures.
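To see why the log-based scalers cope with skewed degree distributions, compare the factors across degrees. A sketch using the simplified scalers from the math section (the function name `scaler_factors` is made up; the full PNA formulation additionally divides by the training set's average log degree, which keeps factors near 1 for typical nodes):

```python
import math

def scaler_factors(degree):
    """Return (amplification, attenuation) factors for a node of the
    given degree, using the simplified log(deg + 1) scaler."""
    log_deg = math.log(degree + 1)
    return log_deg, 1.0 / log_deg

# The log grows slowly: a degree-1000 node is amplified only ~10x
# relative to a degree-1 node, so high-degree hubs do not explode.
factors = {d: scaler_factors(d) for d in (1, 10, 100, 1000)}
```

This is why the scalers stay numerically stable even when degrees span 1 to 1000+: amplification grows logarithmically rather than linearly with degree, and attenuation is its exact inverse.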
When not to use PNAConv
- Node classification on citation networks. On Cora/CiteSeer, the extra expressiveness does not improve over GCNConv. The task is too simple to need multiple aggregators.
- Very large graphs. 12x the feature dimension means 12x more memory. On million-node graphs, this becomes prohibitive. Use SAGEConv with sampling instead.