Graph upsampling projects representations from a coarsened graph back to the original full-resolution graph. When hierarchical coarsening reduces 100,000 nodes to 1,000 for efficient global processing, upsampling maps those 1,000 supernode representations back to the original 100,000 nodes. Combined with skip connections from the encoding phase, this enables per-node predictions that benefit from both multi-scale context and fine-grained local detail.
The Graph U-Net pattern
The most common use of graph upsampling is in Graph U-Net architectures, which mirror the encoder-decoder pattern from image segmentation:
- Encoder: GNN layers + coarsening. 100K nodes → 10K → 1K. Each level captures broader structural context.
- Bottleneck: GNN layers on the coarsest graph. Captures global patterns.
- Decoder: Upsampling + GNN layers. 1K → 10K → 100K. Each level restores resolution.
- Skip connections: Encoder features at each level are concatenated with decoder features at the corresponding level.
```python
import torch
from torch_geometric.nn import SAGEConv, TopKPooling


class GraphUNet(torch.nn.Module):
    def __init__(self, in_ch, hidden, out_ch):
        super().__init__()
        # Encoder
        self.enc1 = SAGEConv(in_ch, hidden)
        self.pool1 = TopKPooling(hidden, ratio=0.5)
        self.enc2 = SAGEConv(hidden, hidden)
        self.pool2 = TopKPooling(hidden, ratio=0.5)
        # Bottleneck
        self.bottleneck = SAGEConv(hidden, hidden)
        # Decoder (GNN layers refine features after upsampling)
        self.dec2 = SAGEConv(hidden * 2, hidden)  # *2 for the skip connection
        self.dec1 = SAGEConv(hidden * 2, hidden)
        self.out = torch.nn.Linear(hidden, out_ch)

    def forward(self, x, edge_index, batch):
        # Encode level 1
        x1 = self.enc1(x, edge_index).relu()
        x2, ei2, _, b2, perm2, _ = self.pool1(x1, edge_index, batch=batch)
        # Encode level 2
        x2 = self.enc2(x2, ei2).relu()
        x3, ei3, _, b3, perm3, _ = self.pool2(x2, ei2, batch=b2)
        # Bottleneck on the coarsest graph
        x3 = self.bottleneck(x3, ei3).relu()
        # Upsample to level 2: scatter coarse features back via perm indices;
        # nodes dropped by pooling keep zeros and rely on the skip connection
        up2 = torch.zeros_like(x2)
        up2[perm3] = x3
        x2 = self.dec2(torch.cat([up2, x2], dim=-1), ei2).relu()
        # Upsample to level 1
        up1 = torch.zeros_like(x1)
        up1[perm2] = x2
        x1 = self.dec1(torch.cat([up1, x1], dim=-1), edge_index).relu()
        return self.out(x1)
```

Graph U-Net: encode (coarsen), process at low resolution, decode (upsample with skip connections). Per-node predictions benefit from multi-scale context.
Upsampling methods
Index-based upsampling (for TopKPool)
TopKPooling saves the indices of retained nodes. During upsampling, the coarsened features are scattered back to the retained positions. Non-retained nodes receive zero features (or skip-connection features from the encoder). This is simple and efficient.
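The scatter step can be sketched in a few lines. This is a minimal illustration with made-up shapes; the `perm` tensor stands in for the retained-node indices that TopKPooling returns:

```python
import torch

# Index-based upsampling: scatter coarse features back to the positions
# of the nodes retained by TopK-style pooling. Shapes are illustrative.
num_fine, hidden = 8, 4
perm = torch.tensor([1, 4, 6])             # indices of nodes kept by pooling
x_coarse = torch.randn(len(perm), hidden)  # features on the coarsened graph

up = torch.zeros(num_fine, hidden)         # non-retained nodes stay zero
up[perm] = x_coarse                        # retained nodes get their features

print(up[1].equal(x_coarse[0]))   # True: node 1 received supernode 0's features
print(up[0].abs().sum().item())   # 0.0: node 0 was not retained
```

In a U-Net the zero rows are then filled in by the concatenated skip-connection features, which is why the pattern pairs naturally with encoder skips.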
Assignment-based upsampling (for DiffPool)
DiffPool learns a soft assignment matrix S (N x K) that maps each of the N original nodes to K clusters. Coarsening aggregates features as X_coarse = S^T X; upsampling applies S to distribute supernode features back to the original nodes, X_up = S X_coarse, so each node receives a mix of its clusters' features weighted by its assignment probabilities.
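The matrix form is a single matmul. A minimal sketch with illustrative shapes (the softmax assignment and random features are placeholder data, not a trained DiffPool):

```python
import torch

# Assignment-based upsampling: a soft assignment S (N x K) coarsens with
# X_coarse = S^T X and upsamples with X_up = S X_coarse.
N, K, hidden = 6, 2, 4
S = torch.softmax(torch.randn(N, K), dim=-1)  # soft assignment; rows sum to 1
x = torch.randn(N, hidden)                    # original node features

x_coarse = S.t() @ x                          # (K, hidden): cluster features
x_up = S @ x_coarse                           # (N, hidden): assignment-weighted
                                              # mix of cluster features per node
print(x_up.shape)  # torch.Size([6, 4])
```

Because the assignment is soft, every original node gets a nonzero upsampled feature, unlike the index-based scheme where dropped nodes receive zeros.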
Interpolation-based upsampling (for point clouds)
For point cloud data, upsampled features are interpolated from the nearest coarsened points using inverse-distance weighting. This is the approach used in PointNet++ and is geometrically meaningful for spatial data.
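A minimal pure-PyTorch sketch of k-nearest-neighbor inverse-distance interpolation, assuming 3D positions are available for both point sets (the `knn_interpolate` function here is illustrative; PyTorch Geometric provides a similar `knn_interpolate` utility):

```python
import torch

# Inverse-distance (k=3) interpolation, as in PointNet++ feature propagation:
# each fine point's feature is a distance-weighted mean of its k nearest
# coarse points. Positions and features below are illustrative random data.
def knn_interpolate(x_coarse, pos_coarse, pos_fine, k=3, eps=1e-8):
    d = torch.cdist(pos_fine, pos_coarse)   # (N_fine, N_coarse) distances
    dist, idx = d.topk(k, largest=False)    # k nearest coarse points each
    w = 1.0 / (dist + eps)                  # inverse-distance weights
    w = w / w.sum(dim=-1, keepdim=True)     # normalize weights per fine point
    return (w.unsqueeze(-1) * x_coarse[idx]).sum(dim=1)

pos_coarse = torch.randn(16, 3)   # 16 coarse points in 3D
pos_fine = torch.randn(128, 3)    # 128 fine points to upsample to
x_coarse = torch.randn(16, 8)     # 8-dim features on the coarse set

x_fine = knn_interpolate(x_coarse, pos_coarse, pos_fine)
print(x_fine.shape)  # torch.Size([128, 8])
```

Because the weights are normalized, the interpolation is a convex combination: a fine point sitting exactly on a coarse point recovers (approximately) that point's feature.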
Applications
- Point cloud segmentation: Coarsen for global context, upsample for per-point labels.
- Mesh-based simulation: Solve at coarse resolution, refine to fine mesh for output.
- Node classification on large graphs: Coarsen to capture community-level patterns, upsample to classify individual nodes.
- Graph super-resolution: Generate higher-resolution graphs from coarse representations.