
Ad Click Prediction: GNN on User-Ad Interaction Graphs

Digital advertising is a $600B market where a 0.1% CTR improvement translates to billions in revenue. Deep CTR models miss the collaborative graph signal. Here is how to build a GNN that captures user-ad affinity from the interaction graph.


TL;DR

  • Ad click prediction is a link prediction problem on a user-ad-publisher graph. GNNs capture collaborative click patterns that feature-interaction models (DeepFM, DCN) miss.
  • HeteroConv with SAGEConv models user-ad, ad-advertiser, and ad-publisher relationships with type-specific transformations for contextual CTR prediction.
  • On RelBench benchmarks, GNNs achieve 75.83 AUROC vs 62.44 for flat-table LightGBM. Graph-level collaborative signal provides the lift.
  • Production ad systems serve billions of impressions daily. Cached embeddings and dot-product scoring keep inference under 10ms.
  • KumoRFM predicts CTR with one PQL query (76.71 AUROC zero-shot), handling cold-start ads and collaborative patterns automatically.

The business problem

Digital advertising is a $600+ billion market. For ad platforms (Google, Meta, Amazon), CTR prediction quality directly determines revenue: better predictions mean higher-quality ad placements, higher click rates, and more advertiser spend. A 0.1% improvement in CTR prediction at Google's scale translates to billions in annual revenue.

State-of-the-art CTR models (DeepFM, DCN-V2, DLRM) learn feature interactions between user attributes and ad attributes. They are powerful but process each user-ad pair independently, missing the collaborative graph signal: users who clicked on this ad also clicked on that one, and ads that perform well on this publisher also perform well on similar publishers.

Why flat ML fails

  • No collaborative signal: Feature-interaction models see user X and ad Y in isolation. GNNs see that users similar to X clicked ads similar to Y, providing collaborative CTR estimation.
  • Cold-start ads: New ads have no click history. Their category, advertiser, and creative features connect them to the graph, enabling click prediction before any impressions.
  • Contextual blindness: The same ad performs differently on different publishers. The graph captures publisher-ad affinity patterns that feature engineering struggles to encode.
  • Long-tail users: Users with few interactions get poor predictions from feature models. Graph neighbors provide additional signal for sparse users.

The relational schema

schema.txt
Node types:
  User       (id, geo, device, interest_vector)
  Ad         (id, category, creative_emb, bid_price)
  Advertiser (id, industry, budget, quality_score)
  Publisher  (id, category, traffic_volume, audience_demo)

Edge types:
  User      --[clicked]-->     Ad         (timestamp)
  User      --[viewed]-->      Ad         (timestamp, position)
  Ad        --[owned_by]-->    Advertiser
  Ad        --[shown_on]-->    Publisher   (impressions, ctr)
  User      --[similar_to]-->  User       (interest_overlap)

Four node types. Click and view edges carry engagement signal. Publisher-ad edges capture contextual performance.

PyG architecture: HeteroConv for CTR

ctr_model.py
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, HeteroConv, Linear

class CTRGNN(torch.nn.Module):
    def __init__(self, hidden_dim=128):
        super().__init__()
        self.user_lin = Linear(-1, hidden_dim)
        self.ad_lin = Linear(-1, hidden_dim)
        self.advertiser_lin = Linear(-1, hidden_dim)
        self.publisher_lin = Linear(-1, hidden_dim)

        # HeteroConv only emits embeddings for destination node types, so
        # reverse (ad -> user) edges are needed to keep updating user nodes.
        self.conv1 = HeteroConv({
            ('user', 'clicked', 'ad'): SAGEConv(
                hidden_dim, hidden_dim),
            ('user', 'viewed', 'ad'): SAGEConv(
                hidden_dim, hidden_dim),
            ('ad', 'rev_clicked', 'user'): SAGEConv(
                hidden_dim, hidden_dim),
            ('ad', 'owned_by', 'advertiser'): SAGEConv(
                hidden_dim, hidden_dim),
            ('ad', 'shown_on', 'publisher'): SAGEConv(
                hidden_dim, hidden_dim),
        }, aggr='sum')

        self.conv2 = HeteroConv({
            ('user', 'clicked', 'ad'): SAGEConv(
                hidden_dim, hidden_dim),
            ('ad', 'rev_clicked', 'user'): SAGEConv(
                hidden_dim, hidden_dim),
            ('ad', 'owned_by', 'advertiser'): SAGEConv(
                hidden_dim, hidden_dim),
            ('ad', 'shown_on', 'publisher'): SAGEConv(
                hidden_dim, hidden_dim),
        }, aggr='sum')

    def encode(self, x_dict, edge_index_dict):
        x_dict['user'] = self.user_lin(x_dict['user'])
        x_dict['ad'] = self.ad_lin(x_dict['ad'])
        x_dict['advertiser'] = self.advertiser_lin(
            x_dict['advertiser'])
        x_dict['publisher'] = self.publisher_lin(
            x_dict['publisher'])

        x_dict = {k: F.relu(v) for k, v in
                  self.conv1(x_dict, edge_index_dict).items()}
        x_dict = self.conv2(x_dict, edge_index_dict)
        return x_dict

    def predict_ctr(self, user_emb, ad_emb):
        return torch.sigmoid(
            (user_emb * ad_emb).sum(dim=-1))

HeteroConv encodes users and ads with graph context. CTR prediction is a dot product between user and ad embeddings, enabling sub-10ms serving.
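Training is standard link prediction: score observed (clicked) user-ad pairs against sampled negatives with binary cross-entropy. A self-contained sketch with random stand-in embeddings in place of the GNN encoder's output (in practice `user_emb` and `ad_emb` come from `encode`, and positives come from the clicks table):

```python
import torch
import torch.nn.functional as F

# Stand-ins for the GNN encoder's output so the snippet runs on its own.
num_users, num_ads, d = 1000, 200, 128
user_emb = torch.randn(num_users, d, requires_grad=True)
ad_emb = torch.randn(num_ads, d, requires_grad=True)

# Positives: observed clicks (random here). Negatives: random user-ad pairs.
pos = torch.randint(0, num_users, (4096,)), torch.randint(0, num_ads, (4096,))
neg = torch.randint(0, num_users, (4096,)), torch.randint(0, num_ads, (4096,))

def score(u_idx, a_idx):
    # Dot-product CTR logit, matching predict_ctr in the model above.
    return (user_emb[u_idx] * ad_emb[a_idx]).sum(dim=-1)

logits = torch.cat([score(*pos), score(*neg)])
labels = torch.cat([torch.ones(4096), torch.zeros(4096)])

# Logits + BCE-with-logits is numerically stabler than sigmoid + BCE.
loss = F.binary_cross_entropy_with_logits(logits, labels)
loss.backward()
```

For unbiased evaluation, negatives should be sampled from actual unclicked impressions (views without clicks) rather than uniformly at random.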

Expected performance

  • Logistic regression: ~62 AUROC
  • LightGBM (flat-table): 62.44 AUROC
  • GNN (HeteroConv): 75.83 AUROC
  • KumoRFM (zero-shot): 76.71 AUROC

Or use KumoRFM in one line

KumoRFM PQL
PREDICT is_clicked FOR user, ad
USING user, ad, advertiser, publisher, impression

One PQL query. KumoRFM constructs the user-ad graph, handles cold-start, and predicts CTR with collaborative context.

Frequently asked questions

Why use GNNs for CTR prediction instead of deep CTR models?

Deep CTR models (DeepFM, DCN) process user-ad feature interactions but miss the graph-level signal: which users click similar ads, which ads appear in similar contexts, and how user behavior clusters relate. GNNs capture these collaborative patterns, providing lift especially for cold-start ads and long-tail users.

What graph structure works for ad click prediction?

Users, ads, advertisers, and publishers form a heterogeneous graph. Edges connect users to ads they clicked/viewed, ads to their advertisers, ads to publishers where they appeared, and users to user segments. The graph captures both content affinity and contextual relevance.

How do you handle the scale of ad systems with GNNs?

Ad systems serve billions of impressions daily. Use neighbor sampling (PyG's NeighborLoader) with small fan-out (10, 5) and cached node embeddings. Precompute user and ad embeddings offline, then score user-ad pairs in real-time using dot product. Target latency: under 10ms per prediction.
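The cached-embedding serving path can be sketched in a few lines. Embeddings below are random stand-ins for the precomputed GNN output, and `serve_ctr` is a hypothetical helper name; the point is that the online hot path is one embedding lookup plus a matrix-vector product, with no GNN inference:

```python
import torch

# Offline: precompute and cache embeddings for all users and ads
# (random stand-ins here; in practice the trained GNN encoder's output).
num_users, num_ads, d = 100_000, 5_000, 128
user_emb = torch.randn(num_users, d)
ad_emb = torch.randn(num_ads, d)

def serve_ctr(user_id: int, candidate_ad_ids: torch.Tensor) -> torch.Tensor:
    """Online scoring: one lookup + one matvec, no GNN in the hot path."""
    u = user_emb[user_id]               # [d]
    cands = ad_emb[candidate_ad_ids]    # [k, d]
    return torch.sigmoid(cands @ u)     # [k] predicted CTRs

ctrs = serve_ctr(42, torch.arange(100))   # score 100 candidate ads
top = torch.topk(ctrs, k=10).indices      # rank and keep the top 10
```

Embeddings are refreshed on a batch schedule (e.g. hourly), so freshness is a tunable trade-off against the cost of re-encoding the graph.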

Can GNNs improve ad relevance for new advertisers?

Yes. New advertisers with no click history can be embedded through their ad content features, category, and similarity to existing ads. The GNN transfers click patterns from similar ads to the new advertiser, solving the cold-start problem that plagues ID-based CTR models.

How does KumoRFM handle ad click prediction?

KumoRFM takes your ad platform data (users, ads, impressions, clicks, advertisers) and predicts CTR with one PQL query. It automatically constructs the user-ad graph, handles cold-start, and captures contextual and collaborative signals.

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.