# How to improve model performance

## Overview

This document provides a step-by-step guide to improving the performance of **KumoRFM**. The tutorial demonstrates how to evaluate and optimize model settings on the **H\&M dataset**, available through the RelBench repository.

## 1. Introduction

The **H\&M database** contains extensive customer and product data from the company’s e-commerce operations. It includes detailed purchase histories and metadata, ranging from demographic information to product attributes.

In this example, we predict the **total price per `article_id`** (the sum of `price`) over a 7-day window `(start_date, start_date + 7 days]`, restricted to articles with non-zero sales in the previous 7 days.

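* **Prediction window:** `(2020-09-07, 2020-09-14]`. Transactions dated `2020-09-07` are excluded from the target.
* **Eligibility window:** `(2020-08-31, 2020-09-07]`. Only articles with non-zero total sales in this window are scored.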
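The prediction window, like the 7-day eligibility window before it, is left-exclusive and right-inclusive. The minimal sketch below illustrates that boundary rule with the standard library; `in_window` is a hypothetical helper for illustration, not part of KumoRFM or RelBench.

```python
from datetime import date, timedelta

def in_window(day: date, start: date, end: date) -> bool:
    """Return True if `day` lies in the half-open window (start, end]."""
    return start < day <= end

anchor = date(2020, 9, 7)
prediction = (anchor, anchor + timedelta(days=7))   # (2020-09-07, 2020-09-14]
eligibility = (anchor - timedelta(days=7), anchor)  # (2020-08-31, 2020-09-07]

for day in [date(2020, 8, 31), date(2020, 9, 7), date(2020, 9, 14)]:
    print(day, "eligibility:", in_window(day, *eligibility),
          "prediction:", in_window(day, *prediction))
# 2020-08-31 eligibility: False prediction: False
# 2020-09-07 eligibility: True prediction: False
# 2020-09-14 eligibility: False prediction: True
```

Anchoring at `2020-09-07` therefore counts that day's transactions toward eligibility, never toward the target.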

## 2. Environment Setup

Install required packages:

```bash theme={null}
# In a Jupyter or Colab notebook, prefix each command with "!"
pip install --pre kumoai
pip install relbench
```

Initialize the KumoRFM environment:

```python theme={null}
import kumoai.rfm as rfm
rfm.authenticate()
rfm.init()
```

## 3. Data Loading

Download the dataset and prepare data frames for each table:

```python theme={null}
from relbench.datasets import get_dataset

# Download rel-hm and load its tables (article, customer, transactions)
db = get_dataset('rel-hm', download=True).get_db()

# Map each table name to its pandas DataFrame
df_dict = {
    table_name: table.df
    for table_name, table in db.table_dict.items()
}
```

## 4. Ground Truth Calculation

Define a reference function that computes the target variable: total sales for each article within the next 7 days, restricted to articles with sales in the previous 7 days.

```python theme={null}
import pandas as pd

def total_price_by_item_next_7_days(transactions_df: pd.DataFrame, anchor_time: str) -> pd.DataFrame:
    """
    Compute total price per article_id for the next 7-day window (left-exclusive, right-inclusive],
    restricted to items with non-zero sales in the previous 7 days.
    """
    data = transactions_df.copy()
    data["t_dat"] = pd.to_datetime(data["t_dat"])

    anchor_time = pd.to_datetime(anchor_time)
    anchor_time_7d_after = anchor_time + pd.Timedelta(days=7)
    anchor_time_7d_before = anchor_time - pd.Timedelta(days=7)

    prev_mask = (data["t_dat"] > anchor_time_7d_before) & (data["t_dat"] <= anchor_time)
    prev_totals = data.loc[prev_mask].groupby("article_id")["price"].sum()
    eligible_ids = prev_totals[prev_totals > 0].index

    next_mask = (data["t_dat"] > anchor_time) & (data["t_dat"] <= anchor_time_7d_after)
    out = (
        data.loc[next_mask]
        .groupby("article_id")["price"]
        .sum()
        .reindex(eligible_ids, fill_value=0.0)
        .rename("total_price_7d")
        .reset_index()
        .assign(anchor_time=anchor_time)
    )
    return out

ground_truth_df = total_price_by_item_next_7_days(df_dict['transactions'], '2020-09-07')
```

## 5. Model Initialization

Initialize the local relational graph and create a KumoRFM model instance.

```python theme={null}
graph = rfm.Graph.from_data(df_dict, verbose=False)
model = rfm.KumoRFM(graph)
```

Sample 2,000 articles from the ground truth as the evaluation set:

```python theme={null}
import numpy as np

test_df = ground_truth_df.sample(n=2000, random_state=42, replace=False)
test_indices = test_df['article_id']
```
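Fixing `random_state` keeps the evaluation set identical across every configuration compared below, so MAE differences reflect the settings rather than the sample. The same idea in plain Python, as a minimal sketch with made-up integer IDs standing in for `article_id` values (`sample_ids` is a hypothetical helper):

```python
import random

def sample_ids(ids, n, seed=42):
    """Draw n distinct IDs reproducibly: same seed and input order give the same sample."""
    rng = random.Random(seed)        # local generator; leaves the global random state alone
    return rng.sample(list(ids), n)  # sampling without replacement

ids = range(1000)  # hypothetical stand-in for eligible article IDs
first = sample_ids(ids, 5)
second = sample_ids(ids, 5)
print(first == second)       # True: the draw is reproducible
print(len(set(first)) == 5)  # True: no duplicates
```

This does not reproduce pandas' exact draw (pandas seeds NumPy's generator); it only illustrates the seeding principle.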

## 6. Evaluation Helper Function

Define a helper that predicts for the sampled articles and prints the mean absolute error (MAE), so every configuration is scored the same way.

```python theme={null}
def evaluate(run_mode, num_neighbors, use_prediction_time, query):
    """Predict for the sampled test articles and print the mean absolute error."""
    with model.batch_mode():
        df = model.predict(
            query=query,
            anchor_time=pd.Timestamp('2020-09-07'),
            indices=test_indices,
            run_mode=run_mode,
            num_neighbors=num_neighbors,
            use_prediction_time=use_prediction_time,
            verbose=False,
        )

    # Positional comparison: assumes predictions are returned in the same order as test_indices
    y_pred = df['TARGET_PRED'].to_numpy()
    y_test = test_df['total_price_7d'].to_numpy()
    print(f'MAE (lower is better): {np.abs(y_test - y_pred).mean():.4f}')
```
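`evaluate` compares predictions and labels by position, which relies on `model.predict` returning rows in the same order as `test_indices`. If you prefer not to rely on that ordering, you can match rows by article ID before computing MAE. A minimal sketch follows; `mae_by_id` is a hypothetical helper, not part of KumoRFM.

```python
def mae_by_id(y_true: dict, y_pred: dict) -> float:
    """Mean absolute error after matching predictions to labels by ID.

    Both arguments map an ID (e.g. an article_id) to a value. Raises if the ID sets
    differ, so a dropped or extra prediction is not silently ignored.
    """
    if not y_true:
        raise ValueError("no labels to score")
    if y_true.keys() != y_pred.keys():
        missing = y_true.keys() - y_pred.keys()
        extra = y_pred.keys() - y_true.keys()
        raise ValueError(f"ID mismatch: {len(missing)} missing, {len(extra)} extra")
    return sum(abs(y_true[k] - y_pred[k]) for k in y_true) / len(y_true)

labels = {101: 2.0, 102: 4.0}
preds = {102: 3.0, 101: 2.5}    # different order, same IDs
print(mae_by_id(labels, preds))  # 0.75
```

With this tutorial's objects, the inputs could be built as `dict(zip(test_df['article_id'], test_df['total_price_7d']))` and `dict(zip(df['ENTITY'], df['TARGET_PRED']))`; the `ENTITY` column name is an assumption to check against your actual prediction output, and the ID dtypes must match for the keys to line up.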

## 7. Performance Optimization

### 7.1 Align Predictive Query with Criteria

Make sure the predictive query matches the business logic of the task: predict only for articles that had prior sales activity.

```python theme={null}
query1 = 'PREDICT SUM(transactions.price, 0, 7, days) FOR EACH article.article_id'
query2 = 'PREDICT SUM(transactions.price, 0, 7, days) FOR EACH article.article_id WHERE SUM(transactions.price, -7, 0, days) > 0'
```

Query 2 is more closely aligned with the prediction criteria: it only generates predictions for articles with **non-zero sales in the previous 7 days**. This alignment matters because in-context sampling is driven by the predictive query's criteria.

Without the condition `WHERE SUM(transactions.price, -7, 0, days) > 0`, any `article_id` can be sampled as an in-context example. With the condition, only eligible articles (those meeting the criterion) are sampled.
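Conceptually, the `WHERE` clause shrinks the pool of articles that can serve as in-context examples. The sketch below applies the same `(-7, 0]` day criterion to made-up transactions; it illustrates the filter only, not KumoRFM's internal sampler, and `candidate_pool` is a hypothetical helper.

```python
from collections import defaultdict

# Hypothetical transactions: (article_id, day offset from the anchor time, price)
transactions = [
    ("a1", -3, 0.05),   # sold inside (-7, 0]            -> eligible
    ("a2", -10, 0.02),  # sold only before the window     -> not eligible
    ("a3", 0, 0.01),    # sold on the anchor day (offset 0 is inside (-7, 0])
    ("a4", 2, 0.03),    # sold only after the anchor      -> not eligible
]
all_articles = {"a1", "a2", "a3", "a4"}

def candidate_pool(transactions, articles, start=-7, end=0):
    """Articles whose summed price in the window (start, end] is greater than zero."""
    totals = defaultdict(float)
    for article_id, offset, price in transactions:
        if start < offset <= end:
            totals[article_id] += price
    return {a for a in articles if totals[a] > 0}

print(sorted(all_articles))                                # query1 pool: ['a1', 'a2', 'a3', 'a4']
print(sorted(candidate_pool(transactions, all_articles)))  # query2 pool: ['a1', 'a3']
```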

Evaluate both queries:

```python theme={null}
evaluate(run_mode='fast', num_neighbors=[], use_prediction_time=False, query=query1)  # MAE: 0.3913
evaluate(run_mode='fast', num_neighbors=[], use_prediction_time=False, query=query2)  # MAE: 0.3745 (better performance)
```

| Run Mode | Num Neighbors | Use Prediction Time | Query  | MAE    | Notes                  |
| -------- | ------------- | ------------------- | ------ | ------ | ---------------------- |
| fast     | \[]           | False               | query1 | 0.3913 | Baseline               |
| fast     | \[]           | False               | query2 | 0.3745 | **Better performance** |

Keeping (1) the predictive query, (2) in-context sampling, and (3) the prediction criteria consistent with each other improves model performance; here it lowered MAE from 0.3913 to 0.3745.

### 7.2 Tune Neighborhood Sampling

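The number of neighbors (`num_neighbors`) sets the model's receptive field: each list entry is the number of neighbors sampled at one hop, so `[64]` samples one hop and `[64, 64]` samples two. In this experiment, sampling more neighbors in a single hop steadily reduced MAE, while adding a second hop helped only marginally at 8 neighbors and made results worse at 32 and 64, possibly because the extra hop introduces noise.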

```python theme={null}
evaluate(run_mode='fast', num_neighbors=[], use_prediction_time=False, query=query2)  # MAE: 0.3745
evaluate(run_mode='fast', num_neighbors=[8], use_prediction_time=False, query=query2) # MAE: 0.3297
evaluate(run_mode='fast', num_neighbors=[8, 8], use_prediction_time=False, query=query2) # MAE: 0.3276
evaluate(run_mode='fast', num_neighbors=[32], use_prediction_time=False, query=query2) # MAE: 0.2506
evaluate(run_mode='fast', num_neighbors=[32, 32], use_prediction_time=False, query=query2) # MAE: 0.3352
evaluate(run_mode='fast', num_neighbors=[64], use_prediction_time=False, query=query2) # MAE: 0.2493 (best performance)
evaluate(run_mode='fast', num_neighbors=[64, 64], use_prediction_time=False, query=query2) # MAE: 0.3350
```

| Run Mode | Num Neighbors | Use Prediction Time | Query  | MAE    | Notes                |
| -------- | ------------- | ------------------- | ------ | ------ | -------------------- |
| fast     | \[]           | False               | query2 | 0.3745 | Baseline             |
| fast     | \[8]          | False               | query2 | 0.3297 | Improved             |
| fast     | \[8, 8]       | False               | query2 | 0.3276 | Slightly better      |
| fast     | \[32]         | False               | query2 | 0.2506 | Significant gain     |
| fast     | \[32, 32]     | False               | query2 | 0.3352 | Performance drop     |
| fast     | \[64]         | False               | query2 | 0.2493 | **Best performance** |
| fast     | \[64, 64]     | False               | query2 | 0.3350 | Performance drop     |
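Because `evaluate` prints the MAE instead of returning it, comparing many settings is manual. If you change its last line to `return` the MAE, a sweep can be automated. The sketch below assumes such a score function; `sweep` is a hypothetical helper, and the stub scorer simply replays the MAE values from the table above so the example runs without KumoRFM.

```python
from typing import Callable, Dict, List, Sequence, Tuple

def sweep(score_fn: Callable[[List[int]], float],
          candidates: Sequence[List[int]]) -> List[Tuple[List[int], float]]:
    """Score each num_neighbors setting; return (setting, MAE) pairs, lowest MAE first."""
    results = [(list(c), score_fn(list(c))) for c in candidates]
    return sorted(results, key=lambda pair: pair[1])

# Stub scorer replaying the MAE values above (query2, run_mode='fast').
recorded: Dict[Tuple[int, ...], float] = {
    (): 0.3745, (8,): 0.3297, (8, 8): 0.3276,
    (32,): 0.2506, (32, 32): 0.3352, (64,): 0.2493, (64, 64): 0.3350,
}

ranking = sweep(lambda nn: recorded[tuple(nn)], [list(k) for k in recorded])
print(ranking[0])  # ([64], 0.2493)
```

With a modified `evaluate` that returns the MAE, the scorer could instead be `lambda nn: evaluate('fast', nn, False, query2)`.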

### 7.3 Adjust Run Mode

`run_mode` controls the number of in-context examples used during prediction:

| Run Mode | In-Context Examples | Description             |
| -------- | ------------------- | ----------------------- |
| fast     | 1,000               | Quick but less accurate |
| normal   | 5,000               | Balanced                |
| best     | 10,000              | Highest accuracy        |

```python theme={null}
evaluate(run_mode='fast', num_neighbors=[64], use_prediction_time=False, query=query2) # MAE: 0.2493
evaluate(run_mode='normal', num_neighbors=[64], use_prediction_time=False, query=query2) # MAE: 0.2156
evaluate(run_mode='best', num_neighbors=[64], use_prediction_time=False, query=query2)  # MAE: 0.2106 (best performance)
```

| Run Mode | Num Neighbors | Use Prediction Time | Query  | MAE    | Notes                |
| -------- | ------------- | ------------------- | ------ | ------ | -------------------- |
| fast     | \[64]         | False               | query2 | 0.2493 | Baseline             |
| normal   | \[64]         | False               | query2 | 0.2156 | Improved             |
| best     | \[64]         | False               | query2 | 0.2106 | **Best performance** |

### 7.4 Enable Prediction Time Feature (Optional)

Setting `use_prediction_time=True` adds the prediction time as a feature, which can help capture temporal patterns such as seasonality. In this example it did **not** help: MAE rose from 0.2106 to 0.2355.

```python theme={null}
evaluate(run_mode='best', num_neighbors=[64], use_prediction_time=False, query=query2) # MAE: 0.2106
evaluate(run_mode='best', num_neighbors=[64], use_prediction_time=True, query=query2) # MAE: 0.2355
```
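As an intuition for what a time feature can carry, the sketch below derives a few calendar attributes from the anchor time. These specific features are a hypothetical illustration; this guide does not describe how KumoRFM encodes the prediction time internally.

```python
from datetime import date

def calendar_features(anchor: date) -> dict:
    """Hypothetical calendar attributes that could expose weekly or yearly seasonality."""
    return {
        "day_of_week": anchor.weekday(),      # Monday = 0
        "month": anchor.month,
        "iso_week": anchor.isocalendar()[1],  # ISO week number
    }

print(calendar_features(date(2020, 9, 7)))
# {'day_of_week': 0, 'month': 9, 'iso_week': 37}
```

Whether a time signal helps is task-dependent; in this example enabling `use_prediction_time` raised MAE, so validate it on your own data before turning it on.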

### 7.5 Add or Remove Features from the Graph

By default, all tables and columns in the graph are used for prediction. However, overly fine-grained features can add noise to in-context learning, and removing them may improve performance.

Conversely, you may improve performance by adding signal: new tables, or new feature columns in existing tables.
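One way to try this is to edit `df_dict` before building the graph. The sketch below drops columns with pandas and leaves the original frames untouched, so both variants can be evaluated side by side. `drop_columns` is a hypothetical helper, and which columns count as noisy is an assumption you should test with `evaluate`.

```python
import pandas as pd

def drop_columns(df_dict: dict, to_drop: dict) -> dict:
    """Return a copy of df_dict with the listed columns removed from each named table.

    to_drop maps a table name to a list of column names; unlisted tables are reused as-is.
    """
    pruned = dict(df_dict)  # shallow copy: untouched tables are shared, not copied
    for table, columns in to_drop.items():
        pruned[table] = df_dict[table].drop(columns=columns)  # drop() returns a new frame
    return pruned

# Toy example with a hypothetical free-text column
tables = {"article": pd.DataFrame({"article_id": [1, 2], "detail_desc": ["x", "y"]})}
pruned = drop_columns(tables, {"article": ["detail_desc"]})
print(list(pruned["article"].columns))  # ['article_id']
print(list(tables["article"].columns))  # ['article_id', 'detail_desc']
```

You would then rebuild with `graph = rfm.Graph.from_data(pruned, verbose=False)` and reassign `model = rfm.KumoRFM(graph)` before calling `evaluate` again, since `evaluate` uses the global `model`. Avoid dropping primary or foreign key columns such as `article_id`, which link the tables.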

## 8. Results Summary

By iteratively refining predictive query alignment, neighborhood sampling, and run mode, **MAE improved from 0.3913 → 0.2106**.
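The contribution of each step can be read off the MAE values reported in Section 7. A short calculation of the relative reduction at each step and overall:

```python
# MAE after each tuning step, taken from the tables in Section 7
stages = [
    ("baseline (query1, fast, [])", 0.3913),
    ("aligned query (query2)", 0.3745),
    ("num_neighbors=[64]", 0.2493),
    ("run_mode='best'", 0.2106),
]

# Pair each stage with the one before it and report the relative MAE reduction
for (_, prev), (name, mae) in zip(stages, stages[1:]):
    print(f"{name}: {prev:.4f} -> {mae:.4f} ({(prev - mae) / prev:.1%} lower)")

overall = (stages[0][1] - stages[-1][1]) / stages[0][1]
print(f"overall: {overall:.1%} lower")
```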

**Best Configuration:**

```python theme={null}
graph = rfm.Graph.from_data(df_dict, verbose=False)
model = rfm.KumoRFM(graph)

evaluate(run_mode='best', num_neighbors=[64], use_prediction_time=False, query=query2)
```

## 9. Key Takeaways

* Align predictive queries with business logic and in-context sampling.
* Optimize neighborhood sampling (`num_neighbors`).
* Use a higher `run_mode` (`normal` or `best`) for accuracy-sensitive applications.

<Info>
  To learn more about explainability, see this [example notebook](https://colab.research.google.com/drive/1X_uPP2Z8Xizo2JoXdqjHaVRgDTTvYyTF?usp=sharing).
</Info>
