The business problem
Retailers lose $1.1 trillion globally to inventory distortion: $634 billion from out-of-stocks and $472 billion from overstocks. Accurate demand forecasting is the single highest-leverage optimization in retail operations. Even a 1% improvement in forecast accuracy reduces inventory costs by 5-10%.
Traditional forecasting models (ARIMA, Prophet, LSTM) treat each product-store combination as an independent time series. They miss the dependencies: a Coca-Cola promotion cannibalizes Pepsi demand. A snowstorm in Chicago affects salt demand in Milwaukee. A viral TikTok video creates sudden correlated demand across geographies. These cross-product and cross-store effects are graph-structured.
Why flat ML fails
- Independent forecasts: Time series models forecast each SKU-store separately. They cannot model substitution effects (Coke promo reduces Pepsi demand) or complementary effects (chips drive salsa sales).
- No spatial correlation: Stores 5 miles apart serve overlapping customers. Weather, events, and demographics create correlated demand patterns that independent models ignore.
- Promotion blindness: A promotion on one product affects demand for related products. Flat models need hand-engineered cross-product features that are brittle and incomplete.
- New product cold start: New products with no history get no forecast. Graph-based models can transfer patterns from similar products.
The relational schema
Node types:
Product (id, category, brand, price, shelf_life)
Store (id, format, sqft, geo_lat, geo_lon)
Promo (id, type, discount_pct, start_date, end_date)
Edge types:
Product --[substitute_of]--> Product (elasticity)
Product --[complement_of]--> Product (lift_factor)
Store --[near]--> Store (distance_km)
Promo --[applies_to]--> Product
Product --[sold_at]--> Store (daily_units, date)

Products, stores, and promotions form an interconnected graph. The sold_at edges carry temporal demand data.
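The store-proximity edges in the schema can be built directly from the `geo_lat`/`geo_lon` columns. A minimal stdlib sketch, assuming a haversine distance and a 150 km threshold (the store IDs, coordinates, and threshold are illustrative, not from the source):

```python
import math

def haversine_km(lat1, lon1, lat2, lon2):
    """Great-circle distance between two (lat, lon) points in km."""
    r = 6371.0  # mean Earth radius in km
    p1, p2 = math.radians(lat1), math.radians(lat2)
    dphi = math.radians(lat2 - lat1)
    dlmb = math.radians(lon2 - lon1)
    a = math.sin(dphi / 2) ** 2 + math.cos(p1) * math.cos(p2) * math.sin(dlmb / 2) ** 2
    return 2 * r * math.asin(math.sqrt(a))

# Hypothetical store coordinates
stores = {
    "S1": (41.88, -87.63),   # Chicago
    "S2": (43.04, -87.91),   # Milwaukee
    "S3": (34.05, -118.24),  # Los Angeles
}

# Store --[near]--> Store edges, with distance_km as the edge attribute
near_edges = []
for a in stores:
    for b in stores:
        if a < b:  # undirected: emit each pair once
            d = haversine_km(*stores[a], *stores[b])
            if d <= 150.0:  # threshold in km (assumption)
                near_edges.append((a, b, round(d, 1)))
```

With these toy coordinates, only the Chicago-Milwaukee pair falls under the threshold, which matches the snowstorm example above: nearby stores share correlated demand shocks.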
PyG architecture: SAGEConv + temporal regression
We use SAGEConv for spatial aggregation (cross-product, cross-store) combined with a GRU cell for temporal dynamics. Each product-store node gets a time-varying embedding that captures both its own history and its neighbors' influence.
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, Linear

class DemandGNN(torch.nn.Module):
    def __init__(self, node_dim, hidden_dim=128, seq_len=28):
        super().__init__()
        self.node_lin = Linear(node_dim, hidden_dim)
        # Spatial: aggregate from substitutes, complements, nearby stores
        self.conv1 = SAGEConv(hidden_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, hidden_dim)
        # Temporal: GRU over daily snapshots
        self.gru = torch.nn.GRU(
            input_size=hidden_dim, hidden_size=hidden_dim, batch_first=True)
        # Regression head: predict next 7 days of demand
        self.head = torch.nn.Sequential(
            Linear(hidden_dim, 64),
            torch.nn.ReLU(),
            Linear(64, 7),  # 7-day forecast horizon
        )

    def forward(self, x_seq, edge_index):
        # x_seq: (num_nodes, seq_len, node_dim)
        N, T, D = x_seq.shape
        embeddings = []
        for t in range(T):
            x = self.node_lin(x_seq[:, t, :])
            x = F.relu(self.conv1(x, edge_index))
            x = self.conv2(x, edge_index)
            embeddings.append(x)
        # Stack per-timestep embeddings and run the GRU over time
        h_seq = torch.stack(embeddings, dim=1)  # (N, T, H)
        _, h_final = self.gru(h_seq)
        # Predict the next 7 days from the final hidden state
        return self.head(h_final.squeeze(0))  # (N, 7)

SAGEConv aggregates cross-product/store signals per timestep, the GRU captures temporal patterns, and the output is a 7-day demand forecast per product-store node.
Training considerations
- Loss function: Use quantile regression loss to produce prediction intervals (P10, P50, P90), not just point estimates. Inventory planning needs uncertainty bounds.
- Temporal split: Train on 12 months of history, validate on month 13, test on month 14. Shift the window weekly for rolling evaluation.
- Graph construction: Substitute/complement edges can come from purchase basket analysis or category taxonomy. Store-proximity edges use geographic distance thresholds.
- Feature encoding: Calendar features (holiday, day-of-week), promotion flags, weather, and recent sales velocity are critical temporal node features.
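The quantile objective in the first bullet can be sketched as a pinball loss over P10/P50/P90 heads (a minimal sketch; the layout of one prediction column per quantile is an assumption about how the head would be widened):

```python
import torch

def pinball_loss(pred, target, quantiles=(0.1, 0.5, 0.9)):
    """Quantile (pinball) loss.

    pred:   (N, len(quantiles)) - one predicted value per quantile
    target: (N,)                - observed demand
    """
    losses = []
    for i, q in enumerate(quantiles):
        err = target - pred[:, i]
        # Under-prediction is penalized by q, over-prediction by (1 - q)
        losses.append(torch.maximum(q * err, (q - 1) * err))
    return torch.cat(losses).mean()

# Toy check: target 2.0 against P10/P50/P90 predictions of 1, 2, 3
loss = pinball_loss(torch.tensor([[1.0, 2.0, 3.0]]), torch.tensor([2.0]))
```

Minimizing this loss drives each column toward its quantile, giving the P10-P90 band that inventory planning needs for safety-stock decisions.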
Expected performance
Demand forecasting is a regression task. The right metric is Weighted Mean Absolute Percentage Error (WMAPE), not AUROC:
- ARIMA (per-SKU): ~18% WMAPE
- LightGBM (flat-table): ~14% WMAPE
- GNN (SAGEConv + GRU): ~10-12% WMAPE
- KumoRFM (zero-shot): ~10% WMAPE
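For concreteness, WMAPE weights absolute errors by actual volume, so high-demand SKUs dominate the score. A minimal stdlib sketch:

```python
def wmape(actual, forecast):
    """Weighted MAPE: sum of absolute errors over sum of absolute actuals."""
    return sum(abs(a - f) for a, f in zip(actual, forecast)) / sum(abs(a) for a in actual)

# Toy example: 7 units of total error on 60 units of actual demand
print(round(wmape([10, 20, 30], [12, 18, 33]), 4))  # → 0.1167
```

Unlike per-row MAPE, WMAPE is stable when some actuals are near zero, which is why it is the standard retail forecasting metric.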
Or use KumoRFM in one line
PREDICT daily_units FOR product, store
USING product, store, promotion, sales_history

One PQL query. KumoRFM auto-discovers product and store relationships, handles temporal dynamics, and outputs 7-day forecasts per product-store pair.
KumoRFM replaces the graph construction, temporal architecture, and training pipeline with a single query. It handles substitute/complement discovery, seasonal patterns, and promotion effects automatically, matching or exceeding hand-tuned GNN forecasters.