The business problem
Digital advertising is a $600+ billion market. For ad platforms (Google, Meta, Amazon), CTR prediction quality directly determines revenue: better predictions mean higher-quality ad placements, higher click rates, and more advertiser spend. A 0.1% improvement in CTR prediction at Google's scale translates to billions in annual revenue.
State-of-the-art CTR models (DeepFM, DCN-V2, DLRM) learn feature interactions between user attributes and ad attributes. They are powerful but process each user-ad pair independently, missing the collaborative graph signal: users who clicked on this ad also clicked on that one, and ads that perform well on this publisher also perform well on similar publishers.
Why flat ML fails
- No collaborative signal: Feature-interaction models see user X and ad Y in isolation. GNNs see that users similar to X clicked ads similar to Y, providing collaborative CTR estimation.
- Cold-start ads: New ads have no click history. Their category, advertiser, and creative features connect them to the graph, enabling click prediction before any impressions.
- Contextual blindness: The same ad performs differently on different publishers. The graph captures publisher-ad affinity patterns that feature engineering struggles to encode.
- Long-tail users: Users with few interactions get poor predictions from feature models. Graph neighbors provide additional signal for sparse users.
The relational schema
Node types:
User (id, geo, device, interest_vector)
Ad (id, category, creative_emb, bid_price)
Advertiser (id, industry, budget, quality_score)
Publisher (id, category, traffic_volume, audience_demo)
Edge types:
User --[clicked]--> Ad (timestamp)
User --[viewed]--> Ad (timestamp, position)
Ad --[owned_by]--> Advertiser
Ad --[shown_on]--> Publisher (impressions, ctr)
User --[similar_to]--> User (interest_overlap)

Four node types. Click and view edges carry engagement signal. Publisher-ad edges capture contextual performance.
PyG architecture: HeteroConv for CTR
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, HeteroConv, Linear

class CTRGNN(torch.nn.Module):
    def __init__(self, hidden_dim=128):
        super().__init__()
        # Project each node type to a shared hidden size. Linear(-1, ...)
        # lazily infers the input dimension on the first forward pass.
        self.user_lin = Linear(-1, hidden_dim)
        self.ad_lin = Linear(-1, hidden_dim)
        self.advertiser_lin = Linear(-1, hidden_dim)
        self.publisher_lin = Linear(-1, hidden_dim)
        # HeteroConv only updates node types that appear as an edge
        # destination, so reverse relations (e.g. rev_clicked, added with
        # T.ToUndirected()) are required for user embeddings to update.
        self.conv1 = HeteroConv({
            ('user', 'clicked', 'ad'): SAGEConv(hidden_dim, hidden_dim),
            ('user', 'viewed', 'ad'): SAGEConv(hidden_dim, hidden_dim),
            ('ad', 'rev_clicked', 'user'): SAGEConv(hidden_dim, hidden_dim),
            ('ad', 'rev_viewed', 'user'): SAGEConv(hidden_dim, hidden_dim),
            ('ad', 'owned_by', 'advertiser'): SAGEConv(hidden_dim, hidden_dim),
            ('ad', 'shown_on', 'publisher'): SAGEConv(hidden_dim, hidden_dim),
        }, aggr='sum')
        self.conv2 = HeteroConv({
            ('user', 'clicked', 'ad'): SAGEConv(hidden_dim, hidden_dim),
            ('ad', 'rev_clicked', 'user'): SAGEConv(hidden_dim, hidden_dim),
            ('ad', 'owned_by', 'advertiser'): SAGEConv(hidden_dim, hidden_dim),
            ('ad', 'shown_on', 'publisher'): SAGEConv(hidden_dim, hidden_dim),
        }, aggr='sum')

    def encode(self, x_dict, edge_index_dict):
        x_dict = {
            'user': self.user_lin(x_dict['user']),
            'ad': self.ad_lin(x_dict['ad']),
            'advertiser': self.advertiser_lin(x_dict['advertiser']),
            'publisher': self.publisher_lin(x_dict['publisher']),
        }
        x_dict = {k: F.relu(v)
                  for k, v in self.conv1(x_dict, edge_index_dict).items()}
        x_dict = self.conv2(x_dict, edge_index_dict)
        return x_dict

    def predict_ctr(self, user_emb, ad_emb):
        # Click probability from a user-ad dot product.
        return torch.sigmoid((user_emb * ad_emb).sum(dim=-1))

HeteroConv encodes users and ads with graph context. CTR prediction is a dot product between user and ad embeddings, enabling sub-10ms serving.
Expected performance
- Logistic regression: ~62 AUROC
- LightGBM (flat-table): 62.44 AUROC
- GNN (HeteroConv): 75.83 AUROC
- KumoRFM (zero-shot): 76.71 AUROC
Or use KumoRFM in one line
PREDICT is_clicked FOR user, ad
USING user, ad, advertiser, publisher, impression

One PQL query. KumoRFM constructs the user-ad graph, handles cold-start, and predicts CTR with collaborative context.