Overview
This document provides a step-by-step guide to improving the performance of KumoRFM. The tutorial demonstrates how to evaluate and optimize model settings on the H&M dataset, available through the RelBench repository.
1. Introduction
The H&M database contains extensive customer and product data from the company’s e-commerce operations. It includes detailed purchase histories and metadata, ranging from demographic information to product attributes.
In this example, we predict the total price per article_id over a 7-day window (start_date, start_date + 7 days], but only for items that had non-zero sales in the previous 7 days.
- Prediction Window: (2020-09-07, 2020-09-14]
- Transactions on 2020-09-07 are excluded.
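The half-open window convention can be checked with a small sketch. The helper name `in_prediction_window` is hypothetical and only illustrates the boundary handling described above; it is not part of KumoRFM or RelBench.

```python
from datetime import date, timedelta


def in_prediction_window(day: date, start: date, horizon_days: int = 7) -> bool:
    """Return True if `day` lies in the window (start, start + horizon_days]."""
    # Left boundary is exclusive, right boundary is inclusive.
    return start < day <= start + timedelta(days=horizon_days)


anchor = date(2020, 9, 7)
print(in_prediction_window(date(2020, 9, 7), anchor))   # the anchor day itself
print(in_prediction_window(date(2020, 9, 8), anchor))   # first day after the anchor
print(in_prediction_window(date(2020, 9, 14), anchor))  # last day of the window
print(in_prediction_window(date(2020, 9, 15), anchor))  # one day past the window
```

The anchor day is excluded and the seventh day after it is included, which matches the window (2020-09-07, 2020-09-14].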
2. Environment Setup
Install required packages:
!pip install kumoai --pre
!pip install relbench
Initialize the KumoRFM environment:
from kumoai.experimental import rfm
rfm.authenticate()
rfm.init()
3. Data Loading
Download the dataset and prepare data frames for each table:
from relbench.datasets import get_dataset
from relbench.tasks import get_task
db = get_dataset('rel-hm', download=True).get_db()
df_dict = {
    table_name: table.df
    for table_name, table in db.table_dict.items()
}
4. Ground Truth Calculation
Define a reference function to compute the target variable — total sales for each item within the next 7 days, restricted to items with prior sales activity.
import pandas as pd
def total_price_by_item_next_7_days(transactions_df: pd.DataFrame, anchor_time: str) -> pd.DataFrame:
    """
    Compute total price per article_id for the next 7-day window (left-exclusive, right-inclusive],
    restricted to items with non-zero sales in the previous 7 days.
    """
    data = transactions_df.copy()
    data["t_dat"] = pd.to_datetime(data["t_dat"])
    anchor_time = pd.to_datetime(anchor_time)
    anchor_time_7d_after = anchor_time + pd.Timedelta(days=7)
    anchor_time_7d_before = anchor_time - pd.Timedelta(days=7)

    # Eligible items: non-zero sales in (anchor_time - 7 days, anchor_time].
    prev_mask = (data["t_dat"] > anchor_time_7d_before) & (data["t_dat"] <= anchor_time)
    prev_totals = data.loc[prev_mask].groupby("article_id")["price"].sum()
    eligible_ids = prev_totals[prev_totals > 0].index

    # Target: total sales in (anchor_time, anchor_time + 7 days].
    next_mask = (data["t_dat"] > anchor_time) & (data["t_dat"] <= anchor_time_7d_after)
    out = (
        data.loc[next_mask]
        .groupby("article_id")["price"]
        .sum()
        .reindex(eligible_ids, fill_value=0.0)  # eligible items without sales get 0.0
        .rename("total_price_7d")
        .reset_index()
        .assign(anchor_time=anchor_time)
    )
    return out
ground_truth_df = total_price_by_item_next_7_days(df_dict['transactions'], '2020-09-07')
5. Model Initialization
Initialize the local relational graph and create a KumoRFM model instance.
graph = rfm.LocalGraph.from_data(df_dict, verbose=False)
model = rfm.KumoRFM(graph)
Sample the evaluation dataset:
import numpy as np
test_df = ground_truth_df.sample(n=2000, random_state=42, replace=False)
test_indices = test_df['article_id']
6. Evaluation Helper Function
Define an evaluation function for consistent benchmarking.
def evaluate(run_mode, num_neighbors, use_prediction_time, query):
    with model.batch_mode():
        df = model.predict(
            query=query,
            anchor_time=pd.Timestamp('2020-09-07'),
            indices=test_indices,
            run_mode=run_mode,
            num_neighbors=num_neighbors,
            use_prediction_time=use_prediction_time,
            verbose=False,
        )
    y_pred = df['TARGET_PRED'].to_numpy()
    y_test = test_df['total_price_7d'].to_numpy()
    print(f'MAE (lower is better): {np.abs(y_test - y_pred).mean():.4f}')
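The helper above compares y_pred and y_test by position, which is only valid if the predictions come back in the same row order as test_df. A minimal sketch of an order-independent comparison follows. It uses toy frames and assumes the prediction frame exposes an entity identifier column, called ENTITY here. That column name is an assumption and should be checked against the actual output.

```python
import numpy as np
import pandas as pd

# Toy ground truth in the same shape as test_df in this tutorial.
truth = pd.DataFrame(
    {"article_id": [101, 102, 103], "total_price_7d": [0.5, 0.0, 1.2]}
)

# Toy predictions, deliberately returned in a different row order.
# ASSUMPTION: the identifier column is named "ENTITY".
preds = pd.DataFrame(
    {"ENTITY": [103, 101, 102], "TARGET_PRED": [1.0, 0.4, 0.1]}
)

# Join on the identifier so each prediction is paired with its own target.
# validate="one_to_one" raises if an id appears more than once on either side.
aligned = truth.merge(
    preds, left_on="article_id", right_on="ENTITY", how="inner", validate="one_to_one"
)

mae = float(np.abs(aligned["total_price_7d"] - aligned["TARGET_PRED"]).mean())
print(f"MAE (lower is better): {mae:.4f}")
```

If the joined frame has fewer rows than test_df, some entities received no prediction, which is worth checking before reporting a metric.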
7.1 Align Predictive Query with Criteria
Ensure that the predictive query aligns with the business logic of the task—only predict for articles that had prior sales activity.
query1 = 'PREDICT SUM(transactions.price, 0, 7, days) FOR EACH article.article_id'
query2 = 'PREDICT SUM(transactions.price, 0, 7, days) FOR EACH article.article_id WHERE SUM(transactions.price, -7, 0, days) > 0'
Query 2 is more closely aligned with the prediction criteria: it only generates predictions for articles that had transactions in the previous 7 days. This alignment matters because in-context sampling is driven by the predictive query.
Without the condition WHERE SUM(transactions.price, -7, 0, days) > 0, any article_id can be sampled as an in-context example. With the condition, only eligible articles (those meeting the criteria) are sampled.
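To keep the target window and the eligibility filter in sync when experimenting with other horizons, the two query strings can be generated from shared parameters. The helper `build_query` below is hypothetical. It only formats the query text shown in this section and does not call any KumoRFM API.

```python
def build_query(horizon_days: int = 7, lookback_days: int = 7, filter_active: bool = True) -> str:
    """Assemble the predictive query string used in this tutorial (hypothetical helper)."""
    query = (
        f"PREDICT SUM(transactions.price, 0, {horizon_days}, days) "
        "FOR EACH article.article_id"
    )
    if filter_active:
        # Restrict to articles with non-zero sales in the lookback window.
        query += f" WHERE SUM(transactions.price, -{lookback_days}, 0, days) > 0"
    return query


print(build_query(filter_active=False))  # same text as query1
print(build_query())                     # same text as query2
```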
Evaluate both queries:
evaluate(run_mode='fast', num_neighbors=[], use_prediction_time=False, query=query1) # MAE: 0.3913
evaluate(run_mode='fast', num_neighbors=[], use_prediction_time=False, query=query2) # MAE: 0.3745 (better performance)
| Run Mode | Num Neighbors | Use Prediction Time | Query | MAE | Notes |
|---|---|---|---|---|---|
| fast | [] | False | query1 | 0.3913 | Baseline |
| fast | [] | False | query2 | 0.3745 | Better performance |
Ensuring that the (1) predictive query, (2) in-context sampling, and (3) prediction criteria are consistent helps improve overall model performance.
7.2 Tune Neighborhood Sampling
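The num_neighbors setting controls the model’s receptive field: each list entry adds a hop, and its value sets how many neighbors are sampled at that hop. In this example, sampling more neighbors in a single hop steadily improved MAE, while adding a second hop gave little benefit at 8 neighbors and hurt at 32 and 64, which suggests that additional hops can introduce noise.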
evaluate(run_mode='fast', num_neighbors=[], use_prediction_time=False, query=query2) # MAE: 0.3745
evaluate(run_mode='fast', num_neighbors=[8], use_prediction_time=False, query=query2) # MAE: 0.3297
evaluate(run_mode='fast', num_neighbors=[8, 8], use_prediction_time=False, query=query2) # MAE: 0.3276
evaluate(run_mode='fast', num_neighbors=[32], use_prediction_time=False, query=query2) # MAE: 0.2506
evaluate(run_mode='fast', num_neighbors=[32, 32], use_prediction_time=False, query=query2) # MAE: 0.3352
evaluate(run_mode='fast', num_neighbors=[64], use_prediction_time=False, query=query2) # MAE: 0.2493 (best performance)
evaluate(run_mode='fast', num_neighbors=[64, 64], use_prediction_time=False, query=query2) # MAE: 0.3350
| Run Mode | Num Neighbors | Use Prediction Time | Query | MAE | Notes |
|---|---|---|---|---|---|
| fast | [] | False | query2 | 0.3745 | Baseline |
| fast | [8] | False | query2 | 0.3297 | Improved |
| fast | [8, 8] | False | query2 | 0.3276 | Slightly better |
| fast | [32] | False | query2 | 0.2506 | Significant gain |
| fast | [32, 32] | False | query2 | 0.3352 | Performance drop |
| fast | [64] | False | query2 | 0.2493 | Best performance |
| fast | [64, 64] | False | query2 | 0.3350 | Performance drop |
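A sweep like this is easier to automate if the evaluation helper returns the MAE instead of printing it. The sketch below shows one possible shape for such a loop. The `sweep` function and the `recorded_evaluate` stand-in are hypothetical: the stand-in simply replays the MAE values from the table above so the example runs without a model.

```python
from typing import Callable, Dict, List, Tuple

# MAE values recorded in this section, keyed by num_neighbors.
RECORDED_MAE: Dict[Tuple[int, ...], float] = {
    (): 0.3745,
    (8,): 0.3297,
    (8, 8): 0.3276,
    (32,): 0.2506,
    (32, 32): 0.3352,
    (64,): 0.2493,
    (64, 64): 0.3350,
}


def recorded_evaluate(num_neighbors: List[int]) -> float:
    """Stand-in for a variant of evaluate() that returns the MAE."""
    return RECORDED_MAE[tuple(num_neighbors)]


def sweep(
    evaluate_fn: Callable[[List[int]], float],
    candidates: List[List[int]],
) -> Tuple[List[int], List[Tuple[List[int], float]]]:
    """Score every candidate and return the lowest-MAE one plus all scores."""
    scores = [(candidate, evaluate_fn(candidate)) for candidate in candidates]
    best, _ = min(scores, key=lambda item: item[1])
    return best, scores


candidates = [[], [8], [8, 8], [32], [32, 32], [64], [64, 64]]
best, scores = sweep(recorded_evaluate, candidates)
print(best)
```

In a real run, `recorded_evaluate` would be replaced by a function that calls the model and returns the metric for the given num_neighbors.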
7.3 Adjust Run Mode
run_mode controls the number of in-context examples used during prediction:
| Run Mode | In-Context Examples | Description |
|---|---|---|
| fast | 1,000 | Quick but less accurate |
| normal | 5,000 | Balanced |
| best | 10,000 | Highest accuracy |
evaluate(run_mode='fast', num_neighbors=[64], use_prediction_time=False, query=query2) # MAE: 0.2493
evaluate(run_mode='normal', num_neighbors=[64], use_prediction_time=False, query=query2) # MAE: 0.2156
evaluate(run_mode='best', num_neighbors=[64], use_prediction_time=False, query=query2) # MAE: 0.2106 (best performance)
| Run Mode | Num Neighbors | Use Prediction Time | Query | MAE | Notes |
|---|---|---|---|---|---|
| fast | [64] | False | query2 | 0.2493 | Baseline |
| normal | [64] | False | query2 | 0.2156 | Improved |
| best | [64] | False | query2 | 0.2106 | Best performance |
7.4 Enable Prediction Time Feature (Optional)
Including prediction_time as a feature can capture temporal seasonality. In this example, enabling it did not improve results.
evaluate(run_mode='best', num_neighbors=[64], use_prediction_time=False, query=query2) # MAE: 0.2106
evaluate(run_mode='best', num_neighbors=[64], use_prediction_time=True, query=query2) # MAE: 0.2355
7.5 Add or Remove Features from the Graph
By default, all data in the graph is used for prediction. However, overly fine-grained features can introduce noise during in-context learning, and removing them can improve performance.
Conversely, you may improve model performance by adding signal, either through new tables or through new columns in existing tables.
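One way to experiment with feature removal is to prune columns from the data frames before the graph is built. The sketch below uses a toy table. The helper `drop_columns` is hypothetical, and the column chosen for removal (detail_desc) is only an illustrative assumption, not a recommendation from this tutorial.

```python
from typing import Dict, List

import pandas as pd


def drop_columns(
    df_dict: Dict[str, pd.DataFrame],
    to_drop: Dict[str, List[str]],
) -> Dict[str, pd.DataFrame]:
    """Return a new dict of frames with the listed columns removed per table."""
    pruned = {}
    for table_name, df in df_dict.items():
        # Ignore names that are not present so the helper is safe to re-run.
        cols = [c for c in to_drop.get(table_name, []) if c in df.columns]
        pruned[table_name] = df.drop(columns=cols)
    return pruned


# Toy stand-in for one table of df_dict.
toy = {
    "article": pd.DataFrame(
        {"article_id": [1, 2], "prod_name": ["a", "b"], "detail_desc": ["x", "y"]}
    )
}
pruned = drop_columns(toy, {"article": ["detail_desc"]})
print(list(pruned["article"].columns))
```

The pruned dictionary can then be passed to rfm.LocalGraph.from_data in place of df_dict, and the result compared against the unpruned baseline with the same evaluation settings.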
8. Results Summary
By iteratively refining predictive query alignment, neighborhood sampling, and run mode, MAE improved from 0.3913 to 0.2106.
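The contribution of each step can be quantified as a relative MAE reduction. The sketch below only does arithmetic on the MAE values reported in this tutorial; the stage labels are informal.

```python
# MAE after each tuning stage, taken from the results in this tutorial.
stages = [
    ("baseline (query1)", 0.3913),
    ("aligned query (query2)", 0.3745),
    ("num_neighbors=[64]", 0.2493),
    ("run_mode='best'", 0.2106),
]


def relative_reduction(before: float, after: float) -> float:
    """Fraction by which the error dropped going from `before` to `after`."""
    return (before - after) / before


# Reduction contributed by each stage relative to the previous one.
step_reductions = {
    name: relative_reduction(prev_mae, mae)
    for (_, prev_mae), (name, mae) in zip(stages, stages[1:])
}
for name, reduction in step_reductions.items():
    print(f"{name}: {reduction:.1%} lower MAE than the previous stage")

overall = relative_reduction(stages[0][1], stages[-1][1])
print(f"overall: {overall:.1%} lower MAE than the baseline")
```

Overall, the final configuration reduces MAE by roughly 46% relative to the baseline, with the largest single step coming from neighborhood sampling.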
Best Configuration:
graph = rfm.LocalGraph.from_data(df_dict, verbose=False)
model = rfm.KumoRFM(graph)
evaluate(run_mode='best', num_neighbors=[64], use_prediction_time=False, query=query2)
9. Key Takeaways
- Align predictive queries with business logic and in-context sampling.
- Optimize neighborhood sampling (num_neighbors).
- Use a more thorough run_mode (normal or best) for accuracy-sensitive applications.