Masked token prediction is a self-supervised pre-training objective where a model learns to predict hidden cell values in relational data from their surrounding context. It is the mechanism that allows relational foundation models to learn general-purpose representations from unlabeled enterprise databases. The model sees a table row with some values masked out, reasons over the graph of connected entities, and predicts what the missing values should be.
From BERT to relational databases
In natural language processing, BERT revolutionized the field by introducing masked language modeling: hide roughly 15% of the tokens in a sentence, train the model to predict them, and the resulting representations transfer to virtually any NLP task. Masked token prediction applies the same principle to structured relational data.
The key difference is the nature of context. In text, context is sequential: the words before and after the mask. In relational data, context is structural: it flows through foreign key relationships across tables. To predict a masked purchase amount, the model can draw on the customer's demographic features, their previous orders, the product's category and price history, and the merchant's typical transaction size.
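To make the idea of context flowing through foreign keys concrete, here is a toy sketch in plain Python. The table names, columns, and the `context_for` helper are illustrative inventions, not part of any real system; the point is only that predicting one masked cell can draw on rows in other tables reached via foreign keys.

```python
# Toy relational data: the column names and values are illustrative.
customers = {1: {"segment": "premium", "region": "EU"}}
products = {7: {"category": "shoes", "list_price": 120.0}}
orders = [
    {"customer_id": 1, "product_id": 7, "amount": None},  # masked cell
    {"customer_id": 1, "product_id": 7, "amount": 85.0},
]

def context_for(order):
    """Follow foreign keys to assemble the cross-table context a model
    could reason over when predicting the masked amount."""
    return {
        "customer": customers[order["customer_id"]],
        "product": products[order["product_id"]],
        "history": [o["amount"] for o in orders
                    if o["customer_id"] == order["customer_id"]
                    and o["amount"] is not None],
    }

print(context_for(orders[0]))
```

The masked `amount` in the first order is unpredictable from its own row alone, but the assembled context (customer segment, product list price, order history) makes a plausible prediction possible.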
How it works
The process has four stages:
- Tokenization: Each cell value in a relational database is converted into a token. Numerical values are discretized into bins. Categorical values map to learned embeddings. Timestamps get special temporal encodings.
- Masking: A fraction of tokens (typically 15-30%) are replaced with a special [MASK] token. The selection is random but stratified to ensure coverage across columns and tables.
- Encoding: The masked relational graph is processed through a graph neural network (typically a graph transformer). Each node aggregates information from its neighborhood, including cross-table connections.
- Prediction: The model predicts the original value of each masked token. For categorical values, this is a classification head. For numerical values, it is a regression head.
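As a toy illustration of the tokenization stage, numeric binning and categorical lookup might look like the following. The bin boundaries, vocabulary, and helper names are assumptions for illustration, not a specific library's API.

```python
def numeric_to_bin(value, boundaries):
    """Discretize a numeric cell value into a bin index (its token id)."""
    for i, b in enumerate(boundaries):
        if value < b:
            return i
    return len(boundaries)  # last bin: everything above the final boundary

def categorical_to_id(value, vocab):
    """Map a categorical cell value to a token id via a lookup table."""
    return vocab[value]

# Example: a price column binned at $10 / $50 / $100 (illustrative boundaries)
boundaries = [10.0, 50.0, 100.0]
print(numeric_to_bin(7.5, boundaries))    # 0
print(numeric_to_bin(120.0, boundaries))  # 3
print(categorical_to_id("shoes", {"shoes": 2, "books": 3}))  # 2
```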
# Simplified masked token prediction pipeline
import torch
from torch_geometric.data import HeteroData

MASK_TOKEN_ID = 0  # reserved token id for [MASK]

def mask_tokens(data: HeteroData, mask_rate: float = 0.15):
    """Randomly mask cell values across all node types."""
    for node_type in data.node_types:
        x = data[node_type].x
        mask = torch.rand(x.numel()) < mask_rate
        data[node_type].mask = mask.view(x.shape)
        data[node_type].original = x.clone()
        x[data[node_type].mask] = MASK_TOKEN_ID
    return data

# Training loop (simplified)
for batch in dataloader:
    optimizer.zero_grad()
    masked_batch = mask_tokens(batch)
    predictions = model(masked_batch)
    # loss is computed only at the masked positions
    loss = reconstruction_loss(predictions, masked_batch)
    loss.backward()
    optimizer.step()

The core loop: mask, encode, predict, backpropagate. The model learns relational patterns without any task-specific labels.
Why relational context makes this powerful
Consider a retail database with customers, orders, and products. When the model encounters a masked “order_amount” field, it cannot simply look at the same row. It must reason:
- This customer's average order is $85 (from order history)
- The product costs $120 (from the product table)
- This merchant gives 10% discounts on weekends (from merchant patterns)
- The order was placed on a Saturday (from the timestamp)
The predicted value: approximately $108. To arrive at this answer, the model had to learn customer spending patterns, product pricing, merchant behavior, and temporal effects. These are exactly the representations needed for downstream tasks like demand forecasting or churn prediction.
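The back-of-the-envelope arithmetic behind that prediction can be written out directly; the figures ($120 list price, 10% weekend discount) come from the bullet list above.

```python
# Worked arithmetic for the example above: list price minus the weekend discount.
list_price = 120.00
weekend_discount = 0.10
predicted_amount = list_price * (1 - weekend_discount)
print(predicted_amount)  # 108.0
```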
Masking strategies
The choice of what and how much to mask significantly affects pre-training quality:
- Uniform random masking: Simple and effective. Each cell has an equal probability of being masked. Works well as a baseline.
- Column-aware masking: Ensures each column gets masked proportionally. Prevents the model from ignoring rare columns that happen to be masked less frequently.
- Correlated masking: Masks entire groups of related cells (e.g., all fields in an order row). Forces the model to rely more on cross-table context rather than intra-row correlations.
Masking rate is critical. At 15%, the task is relatively easy and training is stable but slow to converge. At 30%, the model is forced to learn richer representations, but training requires more careful tuning. Most relational foundation models use rates between 15% and 25%.
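A minimal sketch of the column-aware strategy described above, in plain Python: mask the same fraction of cells in every column so that rare columns are never under-masked by chance. The function name, signature, and defaults are illustrative assumptions, not a library API.

```python
import random

def column_aware_mask(num_rows, columns, mask_rate=0.2, seed=0):
    """Sketch of column-aware masking: select the same number of masked
    row positions in every column (illustrative, not a real API)."""
    rng = random.Random(seed)
    k = max(1, round(num_rows * mask_rate))  # masked cells per column
    return {col: set(rng.sample(range(num_rows), k)) for col in columns}

masked = column_aware_mask(num_rows=100, columns=["amount", "category", "ts"])
print({col: len(rows) for col, rows in masked.items()})
# every column gets exactly 20 masked positions at mask_rate=0.2
```

Uniform random masking, by contrast, would leave the per-column counts to chance, so a rare column could end up masked far less often than 20% of the time.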
Connection to downstream tasks
After pre-training with masked token prediction, the model's encoder produces rich representations for every entity in the database. These representations are then used for downstream tasks:
- Classification: Add a classification head on top of customer node representations to predict churn.
- Regression: Use product node representations to predict demand or lifetime value.
- Link prediction: Use pairs of node representations to predict future interactions (recommendations).
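For the classification case, attaching a head can be as simple as a linear layer over the pre-trained representations. In this hedged sketch, the embedding dimension is an arbitrary choice and the random tensor stands in for the encoder's output for customer nodes.

```python
import torch
import torch.nn as nn

# Illustrative dimensions; customer_repr stands in for frozen encoder output.
embed_dim, num_customers = 64, 32
customer_repr = torch.randn(num_customers, embed_dim)

churn_head = nn.Linear(embed_dim, 2)  # binary churn classification head
logits = churn_head(customer_repr)
print(logits.shape)  # torch.Size([32, 2])
```

During fine-tuning, only the head (and optionally the encoder) is updated with a standard cross-entropy loss on labeled churn outcomes.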
On the RelBench benchmark, models pre-trained with masked token prediction and then fine-tuned achieve 81.14 AUROC, compared to 62.44 for flat-table approaches that cannot leverage relational pre-training.