Documentation Index

Fetch the complete documentation index at: https://kumo.ai/docs/llms.txt

Use this file to discover all available pages before exploring further.

The kumoai.trainer module provides the Trainer class for training custom GNN models on a Graph and TrainingTable, and generating predictions with a PredictionTable. Models can be customized with ModelPlan, though the plan suggested by PredictiveQuery.suggest_model_plan() is typically sufficient for strong out-of-the-box performance.

Model Plan

A ModelPlan defines the full parameter specification for training a Kumo model. It is composed of five sub-plans:
  • ColumnProcessingPlan — encoder overrides for individual columns
  • ModelArchitecturePlan — GNN or Graph Transformer parameters
  • NeighborSamplingPlan — subgraph sampling parameters
  • OptimizationPlan — learning rate, batch size, epochs, and related settings
  • TrainingJobPlan — AutoML-level settings
After generating a default model plan with PredictiveQuery.suggest_model_plan(), no further changes are required to train your first model; the sub-plans above are available when you want to fine-tune.

ModelPlan

The top-level model configuration object. Each sub-plan is accessible as an attribute.
model_plan = pq.suggest_model_plan()

# Customize a sub-plan:
model_plan.optimization.max_epochs = 50
model_plan.optimization.base_lr = [1e-3, 3e-3]
  • training_job (TrainingJobPlan, default: TrainingJobPlan()) — AutoML job-level settings.
  • column_processing (ColumnProcessingPlan, default: ColumnProcessingPlan()) — Encoder overrides for individual columns.
  • neighbor_sampling (NeighborSamplingPlan, default: NeighborSamplingPlan()) — Subgraph sampling configuration.
  • optimization (OptimizationPlan, default: OptimizationPlan()) — Optimization hyperparameters.
  • model_architecture (ModelArchitecturePlan, default: ModelArchitecturePlan()) — GNN or Graph Transformer architecture parameters.

ColumnProcessingPlan

Specifies encoder overrides and missing value strategy overrides for individual table columns.
from kumoai.encoder import GloVe
from kumoai.trainer import ColumnProcessingPlan

plan = ColumnProcessingPlan(
    encoder_overrides={"products.description": GloVe(model_name="glove-wiki-gigaword-50")},
)
  • encoder_overrides (Optional[Dict[str, Encoder]], default: None) — A mapping from "table.column" to an encoder instance. Overrides Kumo’s auto-inferred encoder for that column.
  • na_strategy (Optional[Dict[Stype, NAStrategy]], default: None) — A mapping from semantic type to NAStrategy. Overrides the default imputation strategy for all columns of that semantic type.

ModelArchitecturePlan

Base class for architecture plans. Use GNNModelPlan or GraphTransformerModelPlan to configure a specific architecture.

GNNModelPlan

Configures a Graph Neural Network architecture.
from kumoai.trainer import GNNModelPlan

arch = GNNModelPlan(channels=[64, 128], aggregation=[["sum", "mean"]])
  • channels (List[int], default: inferred) — Candidate hidden channel sizes for AutoML search.
  • aggregation (List[List[AggregationType]], default: inferred) — Candidate aggregation function combinations for AutoML search.
  • dropout (List[float], default: inferred) — Candidate dropout rates for AutoML search.
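A configured architecture plan can be assigned back to the top-level model plan through its model_architecture attribute (listed under ModelPlan above). A minimal sketch, assuming direct attribute assignment is supported and that model_plan was obtained from pq.suggest_model_plan():

```python
from kumoai.trainer import GNNModelPlan

# Search over two channel widths and two aggregation combinations.
arch = GNNModelPlan(
    channels=[64, 128],
    aggregation=[["sum", "mean"], ["max"]],
    dropout=[0.0, 0.2],
)

# Replace the suggested architecture sub-plan with the custom one.
model_plan.model_architecture = arch
```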

GraphTransformerModelPlan

Configures a Graph Transformer architecture.
from kumoai.trainer import GraphTransformerModelPlan

arch = GraphTransformerModelPlan(num_layers=[2, 4], num_heads=[4, 8])
  • channels (List[int], default: inferred) — Candidate hidden channel sizes.
  • num_layers (List[int], default: inferred) — Candidate numbers of transformer layers.
  • num_heads (List[int], default: inferred) — Candidate numbers of attention heads.
  • dropout (List[float], default: inferred) — Candidate dropout rates.
  • positional_encodings (List[List[PositionalEncodingType]], default: inferred) — Candidate positional encoding combinations.

NeighborSamplingPlan

Controls how Kumo samples subgraphs during training.
  • num_neighbors (List[List[int]], default: inferred) — Candidate per-hop neighbor counts for AutoML search. Each inner list specifies the number of neighbors to sample at each hop.
  • sample_from_entity_table (bool, default: inferred) — Whether to sample neighbors from the entity table.
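As a sketch, a sampling plan that searches over two two-hop fanouts could look like the following (assuming NeighborSamplingPlan is importable from kumoai.trainer like the other plans shown on this page, and that model_plan comes from pq.suggest_model_plan()):

```python
from kumoai.trainer import NeighborSamplingPlan

# Each inner list is one AutoML candidate:
# neighbors sampled at hop 1, then hop 2.
sampling = NeighborSamplingPlan(
    num_neighbors=[[16, 8], [32, 16]],
)
model_plan.neighbor_sampling = sampling
```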

OptimizationPlan

Controls learning rate, batch size, epochs, and other training optimization parameters.
  • max_epochs (int, default: inferred) — Maximum number of training epochs.
  • max_steps_per_epoch (int, default: inferred) — Maximum number of training steps per epoch.
  • max_val_steps (int, default: inferred) — Maximum number of validation steps.
  • max_test_steps (int, default: inferred) — Maximum number of test steps.
  • loss (List[Union[str, LossConfig]], default: inferred) — Candidate loss functions for AutoML search.
  • base_lr (List[float], default: inferred) — Candidate base learning rates.
  • weight_decay (List[float], default: inferred) — Candidate weight decay values.
  • batch_size (List[int], default: inferred) — Candidate batch sizes.
  • early_stopping (List[Optional[EarlyStoppingConfig]], default: inferred) — Candidate early stopping configurations.
  • lr_scheduler (List[Optional[LRSchedulerConfig]], default: inferred) — Candidate learning rate scheduler configurations.
  • majority_sampling_ratio (List[Optional[float]], default: inferred) — Candidate majority class sampling ratios for imbalanced classification.
  • weight_mode (List[Optional[WeightMode]], default: inferred) — Candidate sample weighting modes.
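Optimization settings are typically adjusted in place on a suggested plan, following the same attribute-access pattern as the ModelPlan example above. A sketch, with illustrative candidate values:

```python
# Tighten the AutoML search space on an existing plan.
# Attribute names follow the OptimizationPlan fields listed above;
# the candidate values here are illustrative.
model_plan.optimization.max_epochs = 20
model_plan.optimization.base_lr = [1e-3, 3e-3]
model_plan.optimization.batch_size = [512, 1024]
model_plan.optimization.weight_decay = [0.0, 1e-5]
```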

TrainingJobPlan

AutoML job-level settings controlling the number of experiments and evaluation metrics.
  • num_experiments (int, default: inferred) — Number of hyperparameter experiments to run during AutoML.
  • metrics (List[str], default: inferred) — Evaluation metrics to compute.
  • tune_metric (str, default: inferred) — The primary metric used to select the best model.
  • refit_trainval (bool, default: True) — Whether to refit the best model on the combined train+validation set.
  • refit_full (bool, default: False) — Whether to additionally refit on the full dataset (train+validation+test).
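Job-level settings live on the training_job sub-plan. A sketch, where the metric names are illustrative placeholders (use the metrics valid for your task type):

```python
# Run more AutoML experiments and select the best model by a
# specific metric. "auroc"/"auprc" are illustrative metric names.
model_plan.training_job.num_experiments = 8
model_plan.training_job.metrics = ["auroc", "auprc"]
model_plan.training_job.tune_metric = "auprc"
```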

Training

Trainer

Trains a Kumo GNN model on a PredictiveQuery. The two primary methods are fit() (training) and predict() (batch inference).
from kumoai.trainer import Trainer

trainer = Trainer(model_plan=model_plan)
result = trainer.fit(graph=graph, train_table=train_table)
  • model_plan (ModelPlan, required) — The model plan specifying architecture, optimization, and sampling parameters.

model_plan property

Returns Optional[ModelPlan]

encoders property

Returns Optional[Dict[str, str]] — The encoder configuration used during training.

is_trained property

Returns bool — True if this trainer has been successfully fit and is ready for prediction.

fit()

Trains a model on the provided graph and training table.
  • graph (Graph, required) — The relational graph.
  • train_table (Union[TrainingTable, TrainingTableJob], required) — The training table generated from a PredictiveQuery.
  • non_blocking (bool, default: False) — If True, returns a TrainingJob immediately rather than blocking.
  • custom_tags (Mapping[str, str], default: {}) — Optional key-value tags attached to the training job.
  • warm_start_job_id (str, default: None) — Training job ID to warm-start from (initializes from an existing model’s weights).
Returns Union[TrainingJob, TrainingJobResult]
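For long-running training, non_blocking=True returns a TrainingJob handle whose documented methods (status(), metrics_so_far(), result()) support monitoring. A workflow sketch, assuming graph, train_table, and model_plan were already constructed:

```python
from kumoai.trainer import Trainer

trainer = Trainer(model_plan=model_plan)

# Launch training without blocking; fit() returns a TrainingJob.
job = trainer.fit(graph=graph, train_table=train_table, non_blocking=True)

# Poll status and any metrics computed so far while the job runs.
print(job.status())
print(job.metrics_so_far())

# Block until training finishes and obtain the TrainingJobResult.
result = job.result()
print(result.tracking_url)
```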

predict()

Generates batch predictions using the trained model.
  • graph (Graph, required) — The relational graph.
  • prediction_table (Union[PredictionTable, PredictionTableJob], required) — The prediction table generated from a PredictiveQuery.
  • output_dir (str, default: None) — S3 or object store path to write prediction outputs.
  • output_connector (Connector, default: None) — Connector to write predictions to.
  • output_table_name (str, default: None) — Table name in the output connector.
  • non_blocking (bool, default: False) — If True, returns a BatchPredictionJob immediately.
  • custom_tags (Mapping[str, str], default: {}) — Optional key-value tags attached to the prediction job.
Returns Union[BatchPredictionJob, BatchPredictionJobResult]
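A non-blocking prediction workflow mirrors the fit() pattern: the returned BatchPredictionJob can be awaited with result(), and the completed result exposes data_df() (documented below). A sketch, with an illustrative output path and assuming trainer, graph, and prediction_table already exist:

```python
# Launch batch prediction without blocking;
# predict() returns a BatchPredictionJob.
job = trainer.predict(
    graph=graph,
    prediction_table=prediction_table,
    output_dir="s3://my-bucket/predictions/",  # illustrative path
    non_blocking=True,
)

# Block until complete, then read predictions as a DataFrame.
result = job.result()
predictions = result.data_df()
```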

load() staticmethod

Loads a trained Trainer from a completed training job.
  • training_job_id (str, required) — The training job ID.
Returns Trainer
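A short sketch of rehydrating a trainer from an earlier job, using the documented load() method and is_trained property:

```python
from kumoai.trainer import Trainer

# Rehydrate a trained Trainer from a completed training job ID.
trainer = Trainer.load("<training_job_id>")
assert trainer.is_trained  # ready to call predict()
```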

TrainingJob

Represents an ongoing training job.

result()

Blocks until the training job completes. Returns TrainingJobResult

status()

Returns JobStatusReport

cancel()

Cancels the running training job.

metrics_so_far()

Returns Optional[ModelEvaluationMetrics] — Metrics computed so far during training.

progress()

Returns AutoTrainerProgress — Detailed progress information.

TrainingJobResult

Represents a completed training job.
result = trainer.fit(graph=graph, train_table=train_table)
metrics = result.metrics()
  • job_id (TrainingJobID, required) — The training job ID.

id property

Returns TrainingJobID

model_plan property

Returns ModelPlan — The model plan used in this training job.

training_table property

Returns Union[TrainingTableJob, TrainingTable]

predictive_query property

Returns PredictiveQuery — The predictive query that defined this training job.

tracking_url property

Returns str — URL to the training job in the Kumo UI.

metrics()

Returns ModelEvaluationMetrics — Evaluation metrics for the completed job.

holdout_df()

Returns pd.DataFrame — The holdout dataset as a DataFrame.

explain()

Returns per-entity feature importances for a predictive query.
  • query (str, required) — The PQL query string.
  • indices (Sequence[Union[str, float, int]], default: None) — Entity indices to explain. Explains all entities if None.
  • run_mode (RunMode, default: RunMode.FAST) — The run mode for explanation computation.
  • num_neighbors (List[int], default: None) — Per-hop neighbor counts for subgraph sampling.
  • anchor_time (Union[pd.Timestamp, Literal['entity']], default: None) — The anchor time for temporal explanation.
Returns pd.DataFrame
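A sketch of calling explain() on a TrainingJobResult; the PQL string and entity indices below are illustrative placeholders, not a query from this page:

```python
# Explain predictions for two specific entities.
# The query string and indices are illustrative only.
explanations = result.explain(
    query="PREDICT COUNT(orders.*, 0, 30) > 0 FOR EACH customers.customer_id",
    indices=[1001, 1002],
)
print(explanations.head())
```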

Batch Prediction

BatchPredictionJob

Represents an ongoing batch prediction job.

result()

Returns BatchPredictionJobResult

status()

Returns JobStatusReport

cancel()

Cancels the running batch prediction job.

BatchPredictionJobResult

Represents a completed batch prediction job.

data_df()

Returns pd.DataFrame — Prediction results.

data_urls()

Returns List[str] — Download URLs for prediction results.

summary()

Returns BatchPredictionJobSummary

Online Serving and Distillation

Distillation training and export_model() produce a serving bundle (online model directory and embeddings.parquet from batch prediction) in storage you control. Inference uses NVIDIA Triton Inference Server to load that bundle. See the Online Serving guide for the end-to-end flow.

DistillationTrainer

Trains a shallow model for online serving by reusing representations (embeddings) from a base GNN training job.
from kumoai.trainer import DistillationTrainer

distil = DistillationTrainer(model_plan=distil_plan, base_training_job_id="<job_id>")
result = distil.fit(graph=graph, train_table=train_table)
  • model_plan (DistilledModelPlan, required) — The distilled model plan.
  • base_training_job_id (str, required) — The training job ID of the base GNN model to distill from.

is_trained property

Returns bool

fit()

  • graph (Graph, required) — The relational graph.
  • train_table (Union[TrainingTable, TrainingTableJob], required) — The training table.
  • non_blocking (bool, default: False) — If True, returns a TrainingJob immediately.
  • custom_tags (Mapping[str, str], default: {}) — Optional job tags.
Returns Union[TrainingJob, TrainingJobResult]

load() classmethod

  • job_id (str, required) — The training job ID.
Returns DistillationTrainer

DistilledModelPlan

Model plan for distillation. Composed of TrainingJobPlan, ColumnProcessingPlan, OptimizationPlan, DistillationPlan, and a distillation-specific architecture plan.

DistillationPlan

Configuration for the distillation process, specifying embedding keys, time offsets, and real-time interaction settings.
  • embedding_keys (List[str], default: inferred) — Column keys used as embedding inputs.
  • max_embedding_offset (TimeOffset, default: inferred) — Maximum time offset for embedding lookups.
  • min_embedding_offset (TimeOffset, default: inferred) — Minimum time offset for embedding lookups.
  • real_time_interactions (Dict[str, int], default: {}) — Real-time interaction table configuration.

export_model()

Exports online serving model files and batch prediction embeddings to external storage for use with Triton Inference Server.
from kumoai.trainer import export_model, ModelOutputConfig

result = export_model(config=output_config, non_blocking=False)
  • config (ModelOutputConfig, required) — Specifies the training job, output path, and batch prediction job to bundle.
  • non_blocking (bool, default: True) — If True, returns an ArtifactExportJob immediately.
Returns Union[ArtifactExportJob, ArtifactExportResult]

ModelOutputConfig

Output configuration for export_model(). Specifies output types, destination connector, and table name.
  • output_types (Set[str], required) — The output types to produce. Valid values: "predictions", "embeddings", or both.
  • output_connector (Connector, default: None) — The connector to write outputs to. Local download only if None.
  • output_table_name (Union[str, Tuple[str, str]], default: None) — Table name in the output connector. For Databricks, provide a (schema, table) tuple.
  • output_metadata_fields (List[MetadataField], default: None) — Additional metadata columns to include in prediction output. Options: JOB_TIMESTAMP, ANCHOR_TIMESTAMP.
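A sketch combining ModelOutputConfig with export_model(), using only the fields documented above; with output_connector left unset, outputs are available for local download. Any additional required fields (e.g. identifying the training job to export) are environment-specific and omitted here:

```python
from kumoai.trainer import ModelOutputConfig, export_model

# Bundle both predictions and embeddings for online serving.
config = ModelOutputConfig(
    output_types={"predictions", "embeddings"},
)

# Block until the export completes; returns an ArtifactExportResult.
result = export_model(config=config, non_blocking=False)
print(result.tracking_url())
```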

ArtifactExportJob

Represents an ongoing model artifact export job.

id property

Returns str

result()

Returns ArtifactExportResult

status()

Returns JobStatus

cancel()

Returns bool — True if the job was successfully cancelled.

ArtifactExportResult

Represents a completed model artifact export.

tracking_url()

Returns str — URL to the export job in the Kumo UI.