
Graph Upsampling: Refining Coarsened Graphs Back to Original Resolution

Coarsening compresses a graph to capture global patterns; upsampling restores it to full resolution for per-node predictions. Together with skip connections, they form the Graph U-Net, an effective encoder-decoder architecture for node-level tasks on large graphs.

PyTorch Geometric

TL;DR

  • Graph upsampling reverses coarsening: supernode representations are projected back to original nodes using the assignment information saved during the coarsening step.
  • Skip connections from the encoder add fine-grained detail lost during coarsening. The result combines multi-scale information: global structure from coarsened levels + local detail from the original.
  • Three upsampling methods: assignment-based (reverse DiffPool assignments), index-based (reverse TopK selection masks), and interpolation-based (kNN interpolation for point clouds).
  • Graph U-Net architecture: encode at multiple coarsening levels, then decode by upsampling with skip connections. This is the graph equivalent of the image U-Net encoder-decoder.
  • Essential for node segmentation, per-point prediction in point clouds, and mesh-based simulation where coarsening captures global physics but per-node accuracy is required.

Graph upsampling projects representations from a coarsened graph back to the original full-resolution graph. When hierarchical coarsening reduces 100,000 nodes to 1,000 for efficient global processing, upsampling maps those 1,000 supernode representations back to the original 100,000 nodes. Combined with skip connections from the encoding phase, this enables per-node predictions that benefit from both multi-scale context and fine-grained local detail.

The Graph U-Net pattern

The most common use of graph upsampling is in Graph U-Net architectures, which mirror the encoder-decoder pattern from image segmentation:

  1. Encoder: GNN layers + coarsening. 100K nodes → 10K → 1K. Each level captures broader structural context.
  2. Bottleneck: GNN layers on the coarsest graph. Captures global patterns.
  3. Decoder: Upsampling + GNN layers. 1K → 10K → 100K. Each level restores resolution.
  4. Skip connections: Encoder features at each level are concatenated with decoder features at the corresponding level.
graph_unet.py
import torch
from torch_geometric.nn import SAGEConv, TopKPooling

class GraphUNet(torch.nn.Module):
    def __init__(self, in_ch, hidden, out_ch):
        super().__init__()
        # Encoder
        self.enc1 = SAGEConv(in_ch, hidden)
        self.pool1 = TopKPooling(hidden, ratio=0.5)
        self.enc2 = SAGEConv(hidden, hidden)
        self.pool2 = TopKPooling(hidden, ratio=0.5)

        # Bottleneck
        self.bottleneck = SAGEConv(hidden, hidden)

        # Decoder (GNN layers to refine after upsampling)
        self.dec2 = SAGEConv(hidden * 2, hidden)  # *2 for skip
        self.dec1 = SAGEConv(hidden * 2, hidden)
        self.out = torch.nn.Linear(hidden, out_ch)

    def forward(self, x, edge_index, batch):
        # Encode level 1
        x1 = self.enc1(x, edge_index).relu()
        x2, ei2, _, b2, perm2, _ = self.pool1(x1, edge_index, batch=batch)

        # Encode level 2
        x2 = self.enc2(x2, ei2).relu()
        x3, ei3, _, _, perm3, _ = self.pool2(x2, ei2, batch=b2)

        # Bottleneck
        x3 = self.bottleneck(x3, ei3).relu()

        # Upsample level 2: scatter supernode features back to the nodes
        # TopK retained (perm3); dropped nodes stay zero until the skip concat
        up2 = torch.zeros_like(x2)
        up2[perm3] = x3
        x2 = self.dec2(torch.cat([up2, x2], dim=-1), ei2).relu()

        # Upsample level 1: scatter level-2 features back to level-1 positions
        up1 = torch.zeros_like(x1)
        up1[perm2] = x2
        x1 = self.dec1(torch.cat([up1, x1], dim=-1), edge_index).relu()

        return self.out(x1)

Graph U-Net: encode (coarsen), process at low resolution, decode (upsample with skip connections). Per-node predictions benefit from multi-scale context.

Upsampling methods

Index-based upsampling (for TopKPool)

TopKPooling saves the indices of retained nodes. During upsampling, the coarsened features are scattered back to the retained positions. Non-retained nodes receive zero features (or skip-connection features from the encoder). This is simple and efficient.
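In isolation, the scatter step looks like this. A minimal sketch with toy tensor shapes (the node counts and feature values are illustrative, not from any real pooling run):

```python
import torch

# Toy setup: a 5-node graph was pooled down to 2 retained nodes.
# TopKPooling returns `perm`, the indices of the nodes it kept.
perm = torch.tensor([1, 4])                # nodes 1 and 4 survived pooling
x_coarse = torch.tensor([[10.0], [20.0]])  # features on the 2 retained nodes

# Scatter coarse features back to their original positions;
# non-retained nodes receive zeros (a skip connection fills them in later).
x_up = torch.zeros(5, 1)
x_up[perm] = x_coarse
# x_up = [[0.], [10.], [0.], [0.], [20.]]
```

This is exactly the `up[perm] = x` pattern used in the decoder of the Graph U-Net above.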

Assignment-based upsampling (for DiffPool)

DiffPool learns a soft assignment matrix S (shape N x C) mapping the N original nodes to C clusters; coarsening computes X_coarse = S^T X. Upsampling multiplies by S itself: X_up = S X_coarse distributes each supernode's features back to its constituent nodes, weighted by their assignment probabilities.
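A minimal sketch of the assignment-based round trip, using a hand-written soft assignment matrix rather than one learned by DiffPool:

```python
import torch

# Toy soft assignment: 4 original nodes, 2 clusters (each row sums to 1).
S = torch.tensor([[0.9, 0.1],
                  [0.8, 0.2],
                  [0.3, 0.7],
                  [0.1, 0.9]])

# One feature per supernode (e.g., produced by the bottleneck GNN).
x_coarse = torch.tensor([[1.0], [10.0]])

# Distribute supernode features back to original nodes, weighted by
# each node's assignment probabilities: X_up = S @ X_coarse.
x_up = S @ x_coarse   # shape [4, 1]
# Node 0 is mostly in cluster 0, so its feature stays close to 1.0;
# node 3 is mostly in cluster 1, so its feature is close to 10.0.
```

Because the rows of S are probability distributions, each upsampled node feature is a convex combination of the supernode features it was assigned to.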

Interpolation-based upsampling (for point clouds)

For point cloud data, upsampled features are interpolated from the nearest coarsened points using inverse-distance weighting. This is the approach used in PointNet++ and is geometrically meaningful for spatial data.
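A self-contained sketch of inverse-distance-weighted kNN interpolation in plain PyTorch (the helper name `idw_upsample` and the toy 1D coordinates are illustrative; PyG ships a similar utility as `knn_interpolate`):

```python
import torch

def idw_upsample(x_coarse, pos_coarse, pos_fine, k=3, eps=1e-8):
    """PointNet++-style feature propagation: interpolate fine-point features
    from the k nearest coarse points using inverse-distance weights."""
    # Pairwise distances between fine and coarse points: [N_fine, N_coarse]
    dist = torch.cdist(pos_fine, pos_coarse)
    # k nearest coarse neighbors of each fine point
    knn_dist, knn_idx = dist.topk(k, dim=1, largest=False)
    w = 1.0 / (knn_dist + eps)            # inverse-distance weights
    w = w / w.sum(dim=1, keepdim=True)    # normalize per fine point
    # Weighted sum of neighbor features: [N_fine, F]
    return (w.unsqueeze(-1) * x_coarse[knn_idx]).sum(dim=1)

# Toy example: two coarse points on a 1D line, three fine points between them.
pos_coarse = torch.tensor([[0.0], [1.0]])
x_coarse   = torch.tensor([[0.0], [10.0]])
pos_fine   = torch.tensor([[0.0], [0.5], [1.0]])
x_fine = idw_upsample(x_coarse, pos_coarse, pos_fine, k=2)
# Midpoint gets the average of both neighbors; endpoints recover
# (approximately) their coincident coarse point's feature.
```

Unlike index-based scattering, every fine point receives a nonzero feature here, which suits point clouds where coarse and fine points rarely coincide exactly.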

Applications

  • Point cloud segmentation: Coarsen for global context, upsample for per-point labels.
  • Mesh-based simulation: Solve at coarse resolution, refine to fine mesh for output.
  • Node classification on large graphs: Coarsen to capture community-level patterns, upsample to classify individual nodes.
  • Graph super-resolution: Generate higher-resolution graphs from coarse representations.

Frequently asked questions

What is graph upsampling?

Graph upsampling is the reverse of graph coarsening: it projects representations from a coarsened (reduced) graph back to the original full-resolution graph. Supernodes are expanded back to their constituent nodes, with representations distributed based on the original coarsening assignment.

Why is graph upsampling needed?

When you coarsen a graph for hierarchical processing but need per-node predictions (node classification, segmentation), you must map coarsened representations back to the original nodes. Upsampling with skip connections (like U-Net) combines multi-scale information for accurate per-node predictions.

How does graph upsampling compare to image upsampling?

Image upsampling (transposed convolution, bilinear interpolation) operates on regular grids. Graph upsampling must handle irregular topology. The assignment information from the coarsening step guides how supernode features are distributed back to original nodes. Skip connections from the encoding phase add fine-grained detail.

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.