Online serving is Kumo’s low-latency inference path: a distilled (shallow) model runs at request time, using embeddings produced by a base (deep GNN) training job plus real-time interaction (RTI) features. For large-scale offline scoring, use batch prediction instead (see trainer). This walkthrough assumes you have completed introduction, so you can already use a connector, build a Graph, define a PredictiveQuery, and call fit(). Building on that, it covers distillation, batch prediction for embeddings, artifact export, and how those pieces connect to inference, using the concrete graph below.

Running example: H&M retail graph

The three-table retail pattern below (customers, transactions, articles) appears in introduction and in many Kumo examples. The scenario is inspired by the H&M personalized fashion recommendation dataset from RelBench (see the RelBench repository).
The fraud label and the transaction_id primary key in the code below are illustrative.
Define the graph and two query strings:
import kumoai as kumo
from kumoai.trainer import ModelPlan, Trainer, TrainingJobResult

# connector = ...  # Snowflake, S3, etc. — see `introduction`
customers = kumo.Table(
    source_table=connector.table("customers"),
    primary_key="customer_id",
)
customers.infer_metadata()
articles = kumo.Table(
    source_table=connector.table("articles"),
    primary_key="article_id",
)
articles.infer_metadata()
transactions = kumo.Table(source_table=connector.table("transactions"))
transactions.infer_metadata()

graph = kumo.Graph({
    "customers": customers,
    "articles": articles,
    "transactions": transactions,
})
graph.link("customers", "customer_id", "transactions")
graph.link("articles", "article_id", "transactions")

query_churn = (
    "PREDICT COUNT(transactions.*, 0, 10, days) = 0 "
    "FOR EACH customers.customer_id"
)
query_fraud = (
    "PREDICT transactions.fraud = 1 "
    "FOR EACH transactions.transaction_id"
)

pq_churn = kumo.PredictiveQuery(graph=graph, query=query_churn)
pq_fraud = kumo.PredictiveQuery(graph=graph, query=query_fraud)
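
Optionally, sanity-check the graph and both queries before training. A minimal sketch, assuming your kumoai version exposes validate() on Graph and PredictiveQuery (check the SDK reference if these calls differ):
# Optional pre-training checks (validate() availability is an assumption
# about your kumoai version; consult the SDK reference if it differs).
graph.validate()     # links, keys, and inferred metadata
pq_churn.validate()  # churn query compiles against the graph
pq_fraud.validate()  # fraud query compiles against the graph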

Why two predictive queries?

The example defines two PredictiveQuery objects on the same graph:
  • pq_churn — Trains the deep GNN on churn at the customer level (10-day activity). That job produces the node embeddings you reuse when exporting artifacts for online serving.
  • pq_fraud — Defines the serving task (fraud on each transaction) that the distilled model will score at low latency.
Distillation links the shallow model to the base embeddings so online predictions stay aligned with the deep model’s view of entities (e.g. customers).

End-to-end flow

  1. Train the base GNN with pq_churn (churn on customers).
  2. Distill with pq_fraud (fraud on transactions) using suggest_distilled_model_plan(..., base_model_id=...) and DistillationTrainer.
  3. Batch-predict embeddings with load() on the base job id, a prediction table from pq_churn, and output_types including embeddings—export consumes this batch job.
  4. Export with export_model(): training_job_id is the distilled job, batch_prediction_job_id is that embedding job, and output_path is your bundle prefix (e.g. S3).
  5. Deploy and infer from the exported artifacts; managed hosting is set up through Kumo.

Step 1 — Train the base (deep) model on churn

Use Trainer and fit() with pq_churn’s training table and suggested model plan, and keep the resulting base_job_id for the later steps.
train_churn = pq_churn.generate_training_table()
model_plan_churn: ModelPlan = pq_churn.suggest_model_plan()

trainer = Trainer(graph, pq_churn, model_plan_churn)
result: TrainingJobResult = trainer.fit(train_churn, non_blocking=False)
base_job_id = result.job_id

Step 2 — Suggest and train the distilled fraud model

On the fraud query’s PredictiveQuery object (pq_fraud in the example), call suggest_distilled_model_plan() with base_model_id=base_job_id. The platform checks graph/encoder alignment and that embedding keys resolve to base entities; in this example, customers is a base entity, so each transaction uses the deep model’s customer embedding via transactions.customer_id. Train with DistillationTrainer on the same graph and the fraud query’s training table, not the churn table. The YAML block further below shows the shape of the distillation section inside the returned DistilledModelPlan; numeric offsets, keys, and hop strings in your plan object can differ from the example.
from kumoai.trainer import DistillationTrainer, DistilledModelPlan

train_fraud = pq_fraud.generate_training_table()

distilled_plan: DistilledModelPlan = pq_fraud.suggest_distilled_model_plan(
    base_model_id=base_job_id,
)
distiller = DistillationTrainer(distilled_plan, base_job_id)
dist_result = distiller.fit(graph, train_fraud, non_blocking=False)
distilled_job_id = dist_result.job_id
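
To see the concrete distillation settings for your run, inspect the returned plan object; its printed form mirrors the YAML example below (the exact attribute layout is version-dependent):
# Print the suggested plan; the output includes the distillation section
# shown in the YAML example below (layout may vary by SDK version).
print(distilled_plan)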

Example distillation section

The following YAML is not something you paste into the SDK by hand; it mirrors part of what suggest_distilled_model_plan() returns inside DistilledModelPlan.
  • embedding_keys — Foreign keys on the fact row (here transactions) that point to base entities whose deep embeddings are attached for distillation (here transactions.customer_id → customer embedding).
  • max_embedding_offset / min_embedding_offset — How far back (and how “fresh”) the base embedding is allowed to be relative to the prediction time; your plan may use different values.
  • real_time_offset — How RTI history is anchored in time relative to the request (confirm on your plan).
  • real_time_interactions — Maps an RTI hop key (a path through the graph) to a maximum sequence length (here 32 recent transactions along the hop below). That same hop string appears again in Triton input tensor names at inference, as sketched after the YAML.
distillation:
  embedding_keys:
    - transactions.customer_id
  max_embedding_offset:
    value: 14
    unit: days
  min_embedding_offset:
    value: 12
    unit: hours
  real_time_offset:
    value: 1
    unit: hours
  real_time_interactions:
    transactions.customer_id->customers.customer_id->transactions.customer_id: 32
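
The hop key above is reused verbatim as the prefix of RTI input tensor names at inference (see Step 5). A minimal naming sketch in Python, assuming the price, article_id, and t_dat feature columns from this example:
# Each RTI input tensor is named "<hop>:<column>".
hop = "transactions.customer_id->customers.customer_id->transactions.customer_id"
for column in ["price", "article_id", "t_dat"]:
    print(f"{hop}:{column}")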

Step 3 — Batch prediction for embeddings (export input)

The export step needs a finished batch prediction job whose outputs include embeddings from the base (churn) model; see end_to_end_flow. Load the base trainer, build a churn prediction table, and call predict() with embeddings in output_types (details in Batch Prediction in trainer). Keep bp_job_id.
from kumoai.trainer import BatchPredictionJobResult
from kumoai.trainer.config import OutputConfig

base_trainer = Trainer.load(base_job_id)
pred_table = pq_churn.generate_prediction_table(non_blocking=True)
bp_result = base_trainer.predict(
    graph=graph,
    prediction_table=pred_table,
    output_config=OutputConfig(
        output_types={"predictions", "embeddings"},
        output_connector=connector,
        output_table_name="churn_embeddings_export_input",
    ),
    training_job_id=base_job_id,
    non_blocking=False,
)
assert isinstance(bp_result, BatchPredictionJobResult)
bp_job_id = bp_result.job_id

Step 4 — Export artifacts

export_model() (also available as kumoai.export_model) takes a ModelOutputConfig, copies the online serving model directory, and bundles embeddings.parquet from bp_job_id into output_path; see end_to_end_flow. Use non_blocking=True to get an ArtifactExportJob, or non_blocking=False to block until an ArtifactExportResult.
Object storage: Export targets S3-style URIs (s3://…) in typical flows. Contact Kumo if you need to export to another blob store.
The export_model / ModelOutputConfig API does not ask for a model name string. A fixed serving-side name (for example online-model) can be applied when Kumo wires your bundle into managed inference, without changing this SDK call.
from kumoai.trainer import ModelOutputConfig, export_model

config = ModelOutputConfig(
    training_job_id=distilled_job_id,
    batch_prediction_job_id=bp_job_id,
    output_path="s3://your-bucket/path/to/serving-bundle/",
)
export_job = export_model(config, non_blocking=True)
export_result = export_job.attach()
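
Once the export finishes, you can list the bundle at the output prefix to confirm what was written. A minimal sketch using boto3; the bucket and prefix below are the illustrative values from the config above, and the exact file layout is whatever Step 4 produced:
import boto3

# List the exported bundle (bucket/prefix are illustrative; match your
# ModelOutputConfig.output_path). Expect a Triton model repository plus
# embeddings.parquet among the keys.
s3 = boto3.client("s3")
response = s3.list_objects_v2(
    Bucket="your-bucket",
    Prefix="path/to/serving-bundle/",
)
for obj in response.get("Contents", []):
    print(obj["Key"])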

Step 5 — Deploy and run inference

Step 4 produces a Triton model repository: the online serving model layout plus bundled embeddings.parquet (and related artifacts), ready to load in NVIDIA Triton Inference Server. See the Triton Inference Server documentation for how Triton loads model repositories and exposes the HTTP/gRPC V2 inference API.
Managed deployment: Hosting exported model artifacts in production via KServe with Triton is arranged through Kumo. Contact your Kumo team for setup, URLs, and authentication.
Self-managed deployment: If you already operate a Triton-compatible inference stack, Kumo can provide a container image and guidance to run it with the artifacts from Step 4. Contact your Kumo team for details.
Request shape (example below, batch size 1): Inputs are a flat list of named tensors, exactly as in your exported config.pbtxt.
  • anchor_time — INT64, shape [1, 1], nanoseconds since Unix epoch.
  • Fact row — {table}.{column} tensors on the scored entity (in the example, transactions.*), including the embedding foreign key (e.g. transactions.customer_id).
  • RTI history — {RTI_key}:{column} tensors, where RTI_key matches real_time_interactions in the distillation plan (in the example, the three-segment hop transactions.customer_id->customers.customer_id->transactions.customer_id). Use the same feature columns as on the fact row when both are modeled (e.g. price, article_id, t_dat). Shape is [1, seq_len, 1] with the actual seq_len (no zero-padding to the configured maximum such as 32).
Authoritative names, dtypes, and dimensions always come from your config.pbtxt.
{
  "inputs": [
    {
      "name": "anchor_time",
      "datatype": "INT64",
      "shape": [1, 1],
      "data": [1710000000000000000]
    },
    {
      "name": "transactions.customer_id",
      "datatype": "INT64",
      "shape": [1, 1],
      "data": [42]
    },
    {
      "name": "transactions.transaction_id",
      "datatype": "INT64",
      "shape": [1, 1],
      "data": [1001]
    },
    {
      "name": "transactions.price",
      "datatype": "FP32",
      "shape": [1, 1],
      "data": [29.99]
    },
    {
      "name": "transactions.article_id",
      "datatype": "INT64",
      "shape": [1, 1],
      "data": [5408]
    },
    {
      "name": "transactions.t_dat",
      "datatype": "INT64",
      "shape": [1, 1],
      "data": [1710086400000000000]
    },
    {
      "name": "transactions.customer_id->customers.customer_id->transactions.customer_id:price",
      "datatype": "FP32",
      "shape": [1, 3, 1],
      "data": [19.99, 4.99, 12.5]
    },
    {
      "name": "transactions.customer_id->customers.customer_id->transactions.customer_id:article_id",
      "datatype": "INT64",
      "shape": [1, 3, 1],
      "data": [108, 775, 5408]
    },
    {
      "name": "transactions.customer_id->customers.customer_id->transactions.customer_id:t_dat",
      "datatype": "INT64",
      "shape": [1, 3, 1],
      "data": [1709904000000000000, 1710163200000000000, 1710336000000000000]
    }
  ]
}
In this example, 42 is the customer id whose base-model embedding comes from embeddings.parquet. Fact tensors describe the current transaction; RTI tensors use seq_len = 3 (three prior interactions). String/categorical columns may use BYTES in Triton; follow your generated config.
Example infer call (Triton V2 HTTP on localhost:8000; replace the host, port, and model name, and see NVIDIA’s Triton docs for the response format and errors):
curl -sS -X POST "http://localhost:8000/v2/models/<MODEL_NAME>/infer" \
  -H "Content-Type: application/json" \
  -d @payload.json
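
As an alternative to curl, the sketch below sends the same request with NVIDIA’s tritonclient Python package. The model name and output tensor name are placeholders; your config.pbtxt is authoritative for tensor names and dtypes:
import time

import numpy as np
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8000")

def int64_input(name, values, shape):
    # Build a named INT64 tensor in Triton V2 form.
    tensor = httpclient.InferInput(name, shape, "INT64")
    tensor.set_data_from_numpy(np.array(values, dtype=np.int64).reshape(shape))
    return tensor

inputs = [
    # anchor_time in nanoseconds since the Unix epoch (stdlib time.time_ns()).
    int64_input("anchor_time", [time.time_ns()], [1, 1]),
    int64_input("transactions.customer_id", [42], [1, 1]),
    # ... remaining fact and RTI tensors exactly as in the JSON payload above ...
]

response = client.infer(model_name="<MODEL_NAME>", inputs=inputs)
# The output tensor name is a placeholder; read it from your config.pbtxt.
scores = response.as_numpy("<OUTPUT_NAME>")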

See also

  • introduction — start here if you have not built a graph and trained a first model yet.
  • trainer — API reference for Trainer, batch prediction, distillation, and artifact export.