Overview

This document provides a step-by-step guide to improving the performance of KumoRFM, a graph-based relational foundation model for predictive analytics. The tutorial demonstrates how to evaluate and optimize model settings on the H&M dataset, available through the RelBench repository.

1. Introduction

The H&M database contains extensive customer and product data from the company’s e-commerce operations. It includes detailed purchase histories and metadata, ranging from demographic information to product attributes. In this example, we predict the total price per article_id over a 7-day window (start_date, start_date + 7 days], but only for items that had non-zero sales in the previous 7 days.
  • Prediction Window: (2020-09-07, 2020-09-14]
  • Transactions on 2020-09-07 are excluded from the prediction window; they count toward the preceding 7-day eligibility window instead.
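
The half-open window can be checked with a few lines of standard-library Python. This is only a sketch of the boundary rule described above; the helper name in_prediction_window is hypothetical and is not part of KumoRFM or RelBench.

```python
from datetime import date, timedelta

def in_prediction_window(t_dat: date, anchor: date, days: int = 7) -> bool:
    """Return True if t_dat falls in (anchor, anchor + days]."""
    return anchor < t_dat <= anchor + timedelta(days=days)

anchor = date(2020, 9, 7)
for d in (date(2020, 9, 7), date(2020, 9, 8), date(2020, 9, 14), date(2020, 9, 15)):
    # The anchor day itself is excluded; the 7th day after it is included.
    print(d, in_prediction_window(d, anchor))
```

The anchor day and anything after 2020-09-14 fall outside the window, while 2020-09-08 through 2020-09-14 fall inside it.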

2. Environment Setup

Install required packages:
!pip install kumoai --pre
!pip install relbench
Initialize the KumoRFM environment:
from kumoai.experimental import rfm
rfm.authenticate()
rfm.init()

3. Data Loading

Download the dataset and prepare data frames for each table:
from relbench.datasets import get_dataset
from relbench.tasks import get_task

db = get_dataset('rel-hm', download=True).get_db()
df_dict = {
    table_name: table.df
    for table_name, table in db.table_dict.items()
}

4. Ground Truth Calculation

Define a reference function to compute the target variable — total sales for each item within the next 7 days, restricted to items with prior sales activity.
import pandas as pd

def total_price_by_item_next_7_days(transactions_df: pd.DataFrame, anchor_time: str) -> pd.DataFrame:
    """
    Compute total price per article_id for the window (anchor_time, anchor_time + 7 days], i.e. left-exclusive and right-inclusive,
    restricted to items with non-zero sales in the previous 7 days.
    """
    data = transactions_df.copy()
    data["t_dat"] = pd.to_datetime(data["t_dat"])

    anchor_time = pd.to_datetime(anchor_time)
    anchor_time_7d_after = anchor_time + pd.Timedelta(days=7)
    anchor_time_7d_before = anchor_time - pd.Timedelta(days=7)

    prev_mask = (data["t_dat"] > anchor_time_7d_before) & (data["t_dat"] <= anchor_time)
    prev_totals = data.loc[prev_mask].groupby("article_id")["price"].sum()
    eligible_ids = prev_totals[prev_totals > 0].index

    next_mask = (data["t_dat"] > anchor_time) & (data["t_dat"] <= anchor_time_7d_after)
    out = (
        data.loc[next_mask]
        .groupby("article_id")["price"]
        .sum()
        .reindex(eligible_ids, fill_value=0.0)
        .rename("total_price_7d")
        .reset_index()
        .assign(anchor_time=anchor_time)
    )
    return out

ground_truth_df = total_price_by_item_next_7_days(df_dict['transactions'], '2020-09-07')

5. Model Initialization

Initialize the local relational graph and create a KumoRFM model instance.
graph = rfm.LocalGraph.from_data(df_dict, verbose=False)
model = rfm.KumoRFM(graph)
Sample the evaluation dataset:
import numpy as np

test_df = ground_truth_df.sample(n=2000, random_state=42, replace=False)
test_indices = test_df['article_id']

6. Evaluation Helper Function

Define an evaluation function for consistent benchmarking.
def evaluate(run_mode, num_neighbors, use_prediction_time, query):
    with model.batch_mode():
        df = model.predict(
            query=query,
            anchor_time=pd.Timestamp('2020-09-07'),
            indices=test_indices,
            run_mode=run_mode,
            num_neighbors=num_neighbors,
            use_prediction_time=use_prediction_time,
            verbose=False,
        )

    y_pred = df['TARGET_PRED'].to_numpy()
    y_test = test_df['total_price_7d'].to_numpy()
    print(f'MAE (lower is better): {np.abs(y_test - y_pred).mean():.4f}')
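
The helper above compares predictions and labels by position, which is only correct if the prediction frame comes back in the same order as test_df. If that cannot be assumed, one option is to join on the entity key before computing the error. The sketch below uses toy data; the ENTITY column name for the prediction frame is an assumption, and aligned_mae is a hypothetical helper, not part of the KumoRFM API.

```python
import numpy as np
import pandas as pd

def aligned_mae(pred_df, truth_df, pred_key="ENTITY", truth_key="article_id",
                pred_col="TARGET_PRED", truth_col="total_price_7d"):
    """MAE after joining predictions to labels on the entity key.

    pred_key is an assumed column name; adjust it to the actual output schema.
    """
    merged = truth_df.merge(
        pred_df, left_on=truth_key, right_on=pred_key,
        how="inner", validate="one_to_one",  # fail loudly on duplicate keys
    )
    return float(np.abs(merged[truth_col] - merged[pred_col]).mean())

# Toy frames (values invented for illustration); predictions arrive in a different order.
truth = pd.DataFrame({"article_id": [1, 2, 3], "total_price_7d": [0.5, 0.0, 1.0]})
pred = pd.DataFrame({"ENTITY": [3, 1, 2], "TARGET_PRED": [0.8, 0.4, 0.1]})

print(f"MAE: {aligned_mae(pred, truth):.4f}")
```

Joining on the key makes the metric independent of row order, and validate="one_to_one" surfaces duplicated entities instead of silently inflating the row count.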

7. Performance Optimization

7.1 Align Predictive Query with Criteria

Ensure that the predictive query aligns with the business logic of the task—only predict for articles that had prior sales activity.
query1 = 'PREDICT SUM(transactions.price, 0, 7, days) FOR EACH article.article_id'
query2 = 'PREDICT SUM(transactions.price, 0, 7, days) FOR EACH article.article_id WHERE SUM(transactions.price, -7, 0, days) > 0'
Query 2 matches the prediction criteria more closely: it only generates predictions for articles that had transactions in the previous 7 days. This matters because in-context sampling is driven by the predictive query. Without the condition WHERE SUM(transactions.price, -7, 0, days) > 0, any article_id can be drawn as an in-context example; with it, only eligible articles (those meeting the criterion) are sampled.
Evaluate both queries:
evaluate(run_mode='fast', num_neighbors=[], use_prediction_time=False, query=query1)  # MAE: 0.3913
evaluate(run_mode='fast', num_neighbors=[], use_prediction_time=False, query=query2)  # MAE: 0.3745 (better performance)
Run Mode | Num Neighbors | Use Prediction Time | Query | MAE | Notes
fast | [] | False | query1 | 0.3913 | Baseline
fast | [] | False | query2 | 0.3745 | Better performance
Ensuring that the (1) predictive query, (2) in-context sampling, and (3) prediction criteria are consistent helps improve overall model performance.
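
To see why the filter changes what the model learns from, it can help to compare the target distribution with and without the eligibility condition. The sketch below uses invented toy numbers and plain Python; it illustrates the population shift only and does not reproduce how KumoRFM samples context internally.

```python
from statistics import mean

# Toy data, invented for illustration:
# article -> (sales in the previous 7 days, sales in the next 7 days)
toy_sales = {
    "a1": (5.0, 4.0),
    "a2": (0.0, 0.0),
    "a3": (2.0, 3.0),
    "a4": (0.0, 0.0),
    "a5": (0.0, 1.0),
}

# Population implied by query1: every article.
all_targets = [nxt for _, nxt in toy_sales.values()]

# Population implied by query2: only articles with prior sales.
eligible_targets = [nxt for prev, nxt in toy_sales.values() if prev > 0]

print(f"all articles:      n={len(all_targets)}, mean target={mean(all_targets):.2f}")
print(f"eligible articles: n={len(eligible_targets)}, mean target={mean(eligible_targets):.2f}")
```

In this toy case the unfiltered population is dominated by zero-sales articles, so examples drawn from it would describe a different distribution than the articles actually being scored.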

7.2 Tune Neighborhood Sampling

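The num_neighbors setting controls how many neighbors are sampled at each hop, and therefore the model's receptive field. In this example, sampling more neighbors in a single hop lowers MAE, while adding a second hop helps only marginally at 8 neighbors and hurts at 32 and 64, which suggests that extra hops can introduce noise.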
evaluate(run_mode='fast', num_neighbors=[], use_prediction_time=False, query=query2)  # MAE: 0.3745
evaluate(run_mode='fast', num_neighbors=[8], use_prediction_time=False, query=query2) # MAE: 0.3297
evaluate(run_mode='fast', num_neighbors=[8, 8], use_prediction_time=False, query=query2) # MAE: 0.3276
evaluate(run_mode='fast', num_neighbors=[32], use_prediction_time=False, query=query2) # MAE: 0.2506
evaluate(run_mode='fast', num_neighbors=[32, 32], use_prediction_time=False, query=query2) # MAE: 0.3352
evaluate(run_mode='fast', num_neighbors=[64], use_prediction_time=False, query=query2) # MAE: 0.2493 (best performance)
evaluate(run_mode='fast', num_neighbors=[64, 64], use_prediction_time=False, query=query2) # MAE: 0.3350

Run Mode | Num Neighbors | Use Prediction Time | Query | MAE | Notes
fast | [] | False | query2 | 0.3745 | Baseline
fast | [8] | False | query2 | 0.3297 | Improved
fast | [8, 8] | False | query2 | 0.3276 | Slightly better
fast | [32] | False | query2 | 0.2506 | Significant gain
fast | [32, 32] | False | query2 | 0.3352 | Performance drop
fast | [64] | False | query2 | 0.2493 | Best performance
fast | [64, 64] | False | query2 | 0.3350 | Performance drop
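
One way to read these results is to put each configuration next to a rough bound on how many neighbors it can pull in per seed entity. The sketch below assumes each entry of num_neighbors is a per-node fan-out for one hop, as in common neighbor samplers; that interpretation and the helper max_sampled_neighbors are assumptions, not documented KumoRFM behavior. The MAE values are the ones reported above.

```python
from itertools import accumulate
from operator import mul

# MAE values reported above for query2 in 'fast' mode, keyed by num_neighbors.
reported_mae = {
    (): 0.3745,
    (8,): 0.3297,
    (8, 8): 0.3276,
    (32,): 0.2506,
    (32, 32): 0.3352,
    (64,): 0.2493,
    (64, 64): 0.3350,
}

def max_sampled_neighbors(num_neighbors):
    """Upper bound on neighbors per seed: n1 + n1*n2 + ... (assumed per-hop fan-out)."""
    return sum(accumulate(num_neighbors, mul))

# List configurations from best to worst MAE alongside their fan-out bound.
for cfg, mae in sorted(reported_mae.items(), key=lambda kv: kv[1]):
    print(f"num_neighbors={list(cfg)!s:<10} max neighbors={max_sampled_neighbors(cfg):>5}  MAE={mae:.4f}")

best_cfg = min(reported_mae, key=reported_mae.get)
print("best:", list(best_cfg))
```

Under this assumption, [32, 32] can reach up to 1,056 neighbors per seed versus 32 for [32], which is consistent with the observation that a second hop adds far more context than signal here.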

7.3 Adjust Run Mode

run_mode controls the number of in-context examples used during prediction:
Run Mode | In-Context Examples | Description
fast | 1,000 | Quick but less accurate
normal | 5,000 | Balanced
best | 10,000 | Highest accuracy
evaluate(run_mode='fast', num_neighbors=[64], use_prediction_time=False, query=query2) # MAE: 0.2493
evaluate(run_mode='normal', num_neighbors=[64], use_prediction_time=False, query=query2) # MAE: 0.2156
evaluate(run_mode='best', num_neighbors=[64], use_prediction_time=False, query=query2)  # MAE: 0.2106 (best performance)
Run Mode | Num Neighbors | Use Prediction Time | Query | MAE | Notes
fast | [64] | False | query2 | 0.2493 | Baseline
normal | [64] | False | query2 | 0.2156 | Improved
best | [64] | False | query2 | 0.2106 | Best performance

7.4 Enable Prediction Time Feature (Optional)

Setting use_prediction_time=True adds the prediction time as a feature, which can help capture temporal seasonality. In this example, enabling it increased MAE from 0.2106 to 0.2355, so it is left disabled.
evaluate(run_mode='best', num_neighbors=[64], use_prediction_time=False, query=query2) # MAE: 0.2106
evaluate(run_mode='best', num_neighbors=[64], use_prediction_time=True, query=query2) # MAE: 0.2355

7.5 Add or Remove Features from the Graph

By default, all data in the graph is used for prediction. However, overly fine-grained features may introduce noise during in-context learning, and removing them can improve performance. Conversely, you may improve performance by adding signal, either as new tables or as new columns in existing tables.
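
Since the graph is built from a dictionary of data frames, removing a feature can be done in pandas before the graph is constructed. The sketch below uses a toy stand-in for df_dict; the tables, the free_text_note column, and the drop_columns helper are all invented for illustration.

```python
import pandas as pd

# Toy stand-in for df_dict; tables and columns are invented for illustration.
toy_df_dict = {
    "article": pd.DataFrame({
        "article_id": [1, 2],
        "product_type": ["Trousers", "Sweater"],
        "free_text_note": ["soft cotton blend", "ribbed knit"],  # hypothetical noisy column
    }),
    "transactions": pd.DataFrame({
        "article_id": [1, 2, 2],
        "t_dat": pd.to_datetime(["2020-09-01", "2020-09-02", "2020-09-03"]),
        "price": [0.02, 0.03, 0.03],
    }),
}

def drop_columns(df_dict, table_name, columns):
    """Return a new dict with `columns` removed from one table; inputs are not modified."""
    pruned = dict(df_dict)  # shallow copy: untouched tables are shared, not duplicated
    pruned[table_name] = df_dict[table_name].drop(columns=columns)
    return pruned

pruned_df_dict = drop_columns(toy_df_dict, "article", ["free_text_note"])
print(list(pruned_df_dict["article"].columns))
```

The pruned dictionary would then be passed to rfm.LocalGraph.from_data as in Section 5 and scored with the same evaluation helper, so that the effect of each removed or added column is measured rather than assumed.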

8. Results Summary

By iteratively refining predictive query alignment, neighborhood sampling, and run mode, MAE improved from 0.3913 to 0.2106.
Best configuration:
graph = rfm.LocalGraph.from_data(df_dict, verbose=False)
model = rfm.KumoRFM(graph)

evaluate(run_mode='best', num_neighbors=[64], use_prediction_time=False, query=query2)
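
To see how much each step contributed, the reported MAE values can be broken down into per-step gains. The snippet below is plain arithmetic on the numbers reported in Section 7; the step labels and the step_gains helper are only illustrative.

```python
# MAE after each tuning step, as reported in Section 7.
steps = [
    ("baseline (query1)", 0.3913),
    ("aligned query (query2)", 0.3745),
    ("num_neighbors=[64]", 0.2493),
    ("run_mode='best'", 0.2106),
]

def step_gains(steps):
    """Absolute and relative MAE reduction of each step versus the previous one."""
    gains = []
    for (_, prev), (name, cur) in zip(steps, steps[1:]):
        gains.append((name, prev - cur, (prev - cur) / prev))
    return gains

for name, abs_gain, rel_gain in step_gains(steps):
    print(f"{name:<24} -{abs_gain:.4f} MAE ({rel_gain:.1%} relative)")

overall = (steps[0][1] - steps[-1][1]) / steps[0][1]
print(f"overall: {overall:.1%} lower MAE than the baseline")
```

On these numbers, neighborhood sampling accounts for the largest share of the reduction, followed by run mode and then query alignment.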

9. Key Takeaways

  • Align predictive queries with business logic and in-context sampling.
  • Optimize neighborhood sampling (num_neighbors).
  • Use higher run_mode values for accuracy-sensitive applications.

The full notebook for this tutorial is available on Colab.