
Masked Token Prediction: Pre-training on Relational Data

Masked token prediction is to relational foundation models what masked language modeling is to BERT. Hide a cell value, predict it from relational context, and learn representations that transfer to any downstream task.


TL;DR

  • Masked token prediction hides random cell values in relational tables and trains the model to reconstruct them from surrounding context, including linked rows in other tables via foreign keys.
  • This is a self-supervised objective: no labeled data is needed. The relational structure itself provides the training signal, enabling pre-training on massive unlabeled enterprise databases.
  • Context flows through the graph. To predict a masked order amount, the model uses the customer's history, the product's price range, the merchant's average ticket, and temporal patterns.
  • KumoRFM uses masked token prediction as its core pre-training objective, learning from 103 million rows across diverse relational databases before fine-tuning on specific prediction tasks.
  • The masking strategy matters: mask too little and the model learns only shallow patterns; mask too much and the remaining context is insufficient. Typical rates are 15-30% of cell values.

Masked token prediction is a self-supervised pre-training objective where a model learns to predict hidden cell values in relational data from their surrounding context. It is the mechanism that allows relational foundation models to learn general-purpose representations from unlabeled enterprise databases. The model sees a table row with some values masked out, reasons over the graph of connected entities, and predicts what the missing values should be.

From BERT to relational databases

In natural language processing, BERT revolutionized the field by introducing masked language modeling: hide 15% of words in a sentence, train the model to predict them, and the resulting representations transfer to virtually any NLP task. Masked token prediction applies the same principle to structured relational data.

The key difference is the nature of context. In text, context is sequential: the words before and after the mask. In relational data, context is structural: it flows through foreign key relationships across tables. To predict a masked purchase amount, the model can draw on the customer's demographic features, their previous orders, the product's category and price history, and the merchant's typical transaction size.

How it works

The process has four stages:

  1. Tokenization: Each cell value in a relational database is converted into a token. Numerical values are discretized into bins (see the sketch after this list). Categorical values map to learned embeddings. Timestamps get special temporal encodings.
  2. Masking: A fraction of tokens (typically 15-30%) are replaced with a special [MASK] token. The selection is random but stratified to ensure coverage across columns and tables.
  3. Encoding: The masked relational graph is processed through a graph neural network (typically a graph transformer). Each node aggregates information from its neighborhood, including cross-table connections.
  4. Prediction: The model predicts the original value of each masked token. For categorical values, this is a classification head. For numerical values, it is a regression head.
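As a concrete illustration of stage 1, here is a minimal sketch of discretizing a numerical column into token ids via quantile binning. The `NumericalBinner` class, the file name, the bin count, and the synthetic data are illustrative assumptions, not the tokenizer of any particular model.
numerical_binner.py
# Hypothetical sketch of stage 1: discretize a numerical column into token ids
import torch

class NumericalBinner:
    """Map raw numerical values to integer bin ids using quantile boundaries."""

    def __init__(self, num_bins: int = 64):
        self.num_bins = num_bins
        self.boundaries = None

    def fit(self, values: torch.Tensor):
        # Interior quantile boundaries estimated from the training column
        quantiles = torch.linspace(0, 1, self.num_bins + 1)[1:-1]
        self.boundaries = torch.quantile(values.float(), quantiles)
        return self

    def transform(self, values: torch.Tensor) -> torch.Tensor:
        # Each value is assigned the id of the bin it falls into (0 .. num_bins - 1)
        return torch.bucketize(values.float(), self.boundaries)

# Example: tokenize an order-amount column (synthetic data for illustration)
train_amounts = torch.rand(10_000) * 200.0
binner = NumericalBinner(num_bins=64).fit(train_amounts)
amount_tokens = binner.transform(torch.tensor([12.5, 85.0, 199.9]))

The full pre-training loop that masks, encodes, and predicts is shown next.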
masked_token_prediction.py
# Simplified masked token prediction pipeline
import torch
from torch_geometric.data import HeteroData

MASK_TOKEN_ID = 0  # reserved token id substituted for masked cells

def mask_tokens(data: HeteroData, mask_rate: float = 0.15):
    """Randomly mask cell values across all node types."""
    for node_type in data.node_types:
        x = data[node_type].x  # [num_rows, num_columns] of integer token ids
        # Bernoulli mask over every cell value of this node type
        mask = torch.rand(x.shape) < mask_rate
        data[node_type].mask = mask
        data[node_type].original = x.clone()  # keep targets for the loss
        x[mask] = MASK_TOKEN_ID
    return data

# Training loop (simplified; model, dataloader, optimizer, and
# reconstruction_loss are assumed to be defined elsewhere)
for batch in dataloader:
    masked_batch = mask_tokens(batch)
    predictions = model(masked_batch)
    # Compare predictions against the original values at masked positions
    loss = reconstruction_loss(predictions, masked_batch)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

The core loop: mask, encode, predict, backpropagate. The model learns relational patterns without any task-specific labels.

Why relational context makes this powerful

Consider a retail database with customers, orders, and products. When the model encounters a masked “order_amount” field, it cannot find the answer within that row alone. It must reason over the connected context:

  • This customer's average order is $85 (from order history)
  • The product costs $120 (from the product table)
  • This merchant gives 10% discounts on weekends (from merchant patterns)
  • The order was placed on a Saturday (from the timestamp)

The predicted value: approximately $108, i.e., the $120 product price less the 10% weekend discount. To arrive at this answer, the model had to learn customer spending patterns, product pricing, merchant behavior, and temporal effects. These are exactly the representations needed for downstream tasks like demand forecasting or churn prediction.
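To make this structural context concrete, here is a minimal sketch of how such a retail schema could be expressed as a heterogeneous graph in PyTorch Geometric. The table names, feature shapes, edge types, and random data are illustrative assumptions; a real schema would also carry merchant and timestamp information.
retail_schema.py
# Hypothetical retail schema as a heterogeneous graph
import torch
from torch_geometric.data import HeteroData

data = HeteroData()
data["customer"].x = torch.randint(0, 100, (1000, 8))   # 1,000 customers, 8 tokenized columns
data["product"].x = torch.randint(0, 100, (500, 6))     # 500 products, 6 tokenized columns
data["order"].x = torch.randint(0, 100, (20000, 5))     # 20,000 orders, 5 tokenized columns

# Foreign keys become typed edges: order -> customer, order -> product
data["order", "placed_by", "customer"].edge_index = torch.stack([
    torch.arange(20000), torch.randint(0, 1000, (20000,))
])
data["order", "contains", "product"].edge_index = torch.stack([
    torch.arange(20000), torch.randint(0, 500, (20000,))
])

# A GNN over this graph lets a masked order amount draw on the customer's
# history, the product's price range, and temporal patterns.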

Masking strategies

The choice of what and how much to mask significantly affects pre-training quality:

  • Uniform random masking: Simple and effective. Each cell has an equal probability of being masked. Works well as a baseline.
  • Column-aware masking: Ensures each column is masked proportionally. Prevents the model from ignoring rare columns that happen to be masked less frequently (see the sketch at the end of this section).
  • Correlated masking: Masks entire groups of related cells (e.g., all fields in an order row). Forces the model to rely more on cross-table context rather than intra-row correlations.

Masking rate is critical. At 15%, the task is relatively easy and training is stable but slow to converge. At 30%, the model is forced to learn richer representations, but training requires more careful tuning. Most relational foundation models use rates in the 15-25% range.
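As a concrete example of the column-aware strategy above, the sketch below masks roughly the same fraction of cells in every column of every node type. The function name, the file name, and the assumption that each feature matrix holds one integer token id per column are illustrative, not the scheme of any particular model.
column_aware_masking.py
# Hypothetical sketch: column-aware masking over a heterogeneous graph
import torch
from torch_geometric.data import HeteroData

MASK_TOKEN_ID = 0  # reserved token id substituted for masked cells

def mask_tokens_column_aware(data: HeteroData, mask_rate: float = 0.2):
    """Mask roughly mask_rate of the cells in every column, so rare
    columns are not under-represented in the training objective."""
    for node_type in data.node_types:
        x = data[node_type].x  # assumed shape: [num_rows, num_columns]
        num_rows, num_cols = x.shape
        mask = torch.zeros_like(x, dtype=torch.bool)
        for col in range(num_cols):
            # Select the same fraction of rows independently for each column
            num_masked = max(1, int(mask_rate * num_rows))
            rows = torch.randperm(num_rows)[:num_masked]
            mask[rows, col] = True
        data[node_type].mask = mask
        data[node_type].original = x.clone()
        x[mask] = MASK_TOKEN_ID
    return data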

Connection to downstream tasks

After pre-training with masked token prediction, the model's encoder produces rich representations for every entity in the database. These representations are then used for downstream tasks:

  • Classification: Add a classification head on top of customer node representations to predict churn (a fine-tuning sketch follows this list).
  • Regression: Use product node representations to predict demand or lifetime value.
  • Link prediction: Use pairs of node representations to predict future interactions (recommendations).
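For the classification case, a minimal fine-tuning sketch might look like the following. The `ChurnHead` class, the file name, the "customer" key, and the assumption that the pre-trained encoder returns a dict of per-node-type embeddings are all illustrative assumptions.
churn_finetune.py
# Hypothetical fine-tuning sketch: churn classification on customer embeddings
import torch
import torch.nn as nn

class ChurnHead(nn.Module):
    def __init__(self, encoder: nn.Module, hidden_dim: int):
        super().__init__()
        self.encoder = encoder                      # pre-trained via masked token prediction
        self.classifier = nn.Linear(hidden_dim, 1)  # task-specific head

    def forward(self, batch):
        # Assumes the encoder returns a dict of node embeddings keyed by node type
        z = self.encoder(batch)["customer"]
        return self.classifier(z).squeeze(-1)       # one churn logit per customer

# Fine-tuning step (model, batch, and labels assumed to exist):
# logits = model(batch)
# loss = nn.functional.binary_cross_entropy_with_logits(logits, labels.float())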

On the RelBench benchmark, models pre-trained with masked token prediction and then fine-tuned achieve 81.14 AUROC, compared to 62.44 for flat-table approaches that cannot leverage relational pre-training.

Frequently asked questions

What is masked token prediction in the context of graphs?

Masked token prediction is a self-supervised pre-training objective where random cell values in a relational database (or node/edge features in a graph) are hidden, and the model learns to reconstruct them from the surrounding relational context. It is the graph equivalent of masked language modeling in BERT.

Why is masked token prediction useful for relational data?

It forces the model to learn general-purpose representations of entities and relationships without requiring any labeled data. A model that can predict a hidden order amount from customer history, product details, and temporal patterns has learned deep relational semantics that transfer to downstream tasks like churn prediction or fraud detection.

How does masked token prediction differ from masked language modeling?

In language models, tokens are words in a sequence. In relational data, tokens are cell values in a table row, and context comes from both the row itself and connected rows across tables via foreign keys. The model must reason over graph structure, not just sequential context.

Learn more about graph ML

PyTorch Geometric is the open-source foundation for graph neural networks. Explore more layers, concepts, and production patterns.