Documentation Index
Fetch the complete documentation index at: https://kumo.ai/docs/llms.txt
Use this file to discover all available pages before exploring further.
The kumoai.trainer module provides the Trainer class for training custom GNN models on a Graph and TrainingTable, and for generating predictions with a PredictionTable. Models can be customized with a ModelPlan, though the plan suggested by PredictiveQuery.suggest_model_plan() is typically sufficient for strong out-of-the-box performance.
Model Plan
A ModelPlan defines the full parameter specification for training a Kumo model. It is composed of five sub-plans:
- ColumnProcessingPlan — encoder overrides for individual columns
- ModelArchitecturePlan — GNN or Graph Transformer parameters
- NeighborSamplingPlan — subgraph sampling parameters
- OptimizationPlan — learning rate, batch size, epochs, and related settings
- TrainingJobPlan — AutoML-level settings
After generating a default model plan with PredictiveQuery.suggest_model_plan(), no further changes are required to train your first model. The options below are available for fine-tuning.

ModelPlan
The top-level model configuration object. Each sub-plan is accessible as an attribute.
- TrainingJobPlan — AutoML job-level settings.
- ColumnProcessingPlan — encoder overrides for individual columns.
- NeighborSamplingPlan — subgraph sampling configuration.
- OptimizationPlan — optimization hyperparameters.
- ModelArchitecturePlan — GNN or Graph Transformer architecture parameters.
ColumnProcessingPlan
Specifies encoder overrides and missing value strategy overrides for individual table columns.
A mapping from "table.column" to an encoder instance. Overrides Kumo’s auto-inferred encoder for that column.

A mapping from semantic type to NAStrategy. Overrides the default imputation strategy for all columns of that semantic type.

ModelArchitecturePlan
Base class for architecture plans. Use GNNModelPlan or GraphTransformerModelPlan to configure a specific architecture.
GNNModelPlan
Configures a Graph Neural Network architecture.
Candidate hidden channel sizes for AutoML search.
Candidate aggregation function combinations for AutoML search.
Candidate dropout rates for AutoML search.
GraphTransformerModelPlan
Configures a Graph Transformer architecture.
Candidate hidden channel sizes.
Candidate number of transformer layers.
Candidate number of attention heads.
Candidate dropout rates.
Candidate positional encoding combinations.
NeighborSamplingPlan
Controls how Kumo samples subgraphs during training.
Candidate per-hop neighbor counts for AutoML search. Each inner list specifies the number of neighbors to sample at each hop.
Whether to sample neighbors from the entity table.
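To make the per-hop semantics concrete: each inner candidate list is a fan-out schedule, one entry per hop. The values below are purely illustrative (not Kumo defaults), and the helper is a hypothetical way to reason about worst-case subgraph size, not part of the SDK:

```python
# Hypothetical AutoML candidates: each inner list gives the number of
# neighbors sampled at hop 1, hop 2, ... (illustrative values only).
num_neighbors_candidates = [
    [8, 8],       # 2 hops, 8 neighbors per hop
    [16, 16],     # 2 hops, wider fan-out
    [16, 8, 4],   # 3 hops, shrinking fan-out
]

def max_subgraph_nodes(fanouts: list[int]) -> int:
    """Upper bound on sampled nodes per seed entity: 1 (the seed)
    plus the fan-out product accumulated hop by hop."""
    total, frontier = 1, 1
    for fanout in fanouts:
        frontier *= fanout
        total += frontier
    return total

for fanouts in num_neighbors_candidates:
    print(fanouts, "->", max_subgraph_nodes(fanouts))  # e.g. [8, 8] -> 73
```

Deeper or wider schedules grow the sampled subgraph multiplicatively, which is why fan-out candidates trade accuracy against training cost.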
OptimizationPlan
Controls learning rate, batch size, epochs, and other training optimization parameters.
Maximum number of training epochs.
Maximum number of training steps per epoch.
Maximum number of validation steps.
Maximum number of test steps.
Candidate loss functions for AutoML search.
Candidate base learning rates.
Candidate weight decay values.
Candidate batch sizes.
Candidate early stopping configurations.
Candidate learning rate scheduler configurations.
Candidate majority class sampling ratios for imbalanced classification.
Candidate sample weighting modes.
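Each "candidate" list above is one dimension of the AutoML search space. As a rough mental model (the values and key names here are hypothetical, and Kumo's AutoML need not enumerate an exhaustive grid), the number of possible configurations is the product of the candidate list lengths:

```python
from itertools import product

# Hypothetical candidate lists (illustrative values, not Kumo defaults).
search_space = {
    "base_lr": [1e-3, 5e-4],
    "batch_size": [256, 512],
    "weight_decay": [0.0, 1e-5],
}

# Grid-style enumeration: every combination is one candidate configuration.
trials = [
    dict(zip(search_space, values))
    for values in product(*search_space.values())
]
print(len(trials))  # 2 * 2 * 2 = 8 configurations
```

Keeping candidate lists short therefore keeps the search space, and the number of experiments needed to cover it, small.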
TrainingJobPlan
AutoML job-level settings controlling the number of experiments and evaluation metrics.
Number of hyperparameter experiments to run during AutoML.
Evaluation metrics to compute.
The primary metric used to select the best model.
Whether to refit the best model on the combined train+validation set.
Whether to additionally refit on the full dataset (train+validation+test).
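The primary metric drives best-model selection across experiments. A minimal sketch of that selection logic (metric names and values are hypothetical, and this assumes a higher-is-better metric):

```python
# Hypothetical per-experiment results; metric names are illustrative.
experiments = [
    {"id": "exp-1", "metrics": {"auroc": 0.81, "auprc": 0.42}},
    {"id": "exp-2", "metrics": {"auroc": 0.86, "auprc": 0.40}},
    {"id": "exp-3", "metrics": {"auroc": 0.84, "auprc": 0.45}},
]
primary_metric = "auroc"  # assumed higher-is-better for this sketch

# Pick the experiment that maximizes the primary metric.
best = max(experiments, key=lambda e: e["metrics"][primary_metric])
print(best["id"])  # exp-2
```

Note that a different primary metric can select a different winner (here, optimizing auprc would pick exp-3 instead).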
Training
Trainer
Trains a Kumo GNN model on a PredictiveQuery. The two primary methods are fit() (training) and predict() (batch inference).
The model plan specifying architecture, optimization, and sampling parameters.
model_plan property
Returns Optional[ModelPlan]
encoders property
Returns Optional[Dict[str, str]] — The encoder configuration used during training.
is_trained property
Returns bool — True if this trainer has been successfully fit and is ready for prediction.
fit()
Trains a model on the provided graph and training table.
The relational graph.
The training table generated from a PredictiveQuery.

If True, returns a TrainingJob immediately rather than blocking.

Optional key-value tags attached to the training job.
Training job ID to warm-start from (initializes from an existing model’s weights).
Union[TrainingJob, TrainingJobResult]
predict()
Generates batch predictions using the trained model.
The relational graph.
The prediction table generated from a PredictiveQuery.

S3 or object store path to write prediction outputs.
Connector to write predictions to.
Table name in the output connector.
If True, returns a BatchPredictionJob immediately.

Optional key-value tags attached to the prediction job.
Union[BatchPredictionJob, BatchPredictionJobResult]
load() staticmethod
Loads a trained Trainer from a completed training job.
The training job ID.
Trainer
TrainingJob
Represents an ongoing training job.
result()
Blocks until complete and returns the TrainingJobResult.
Returns TrainingJobResult
status()
Returns JobStatusReport
cancel()
Cancels the running training job.
metrics_so_far()
Returns Optional[ModelEvaluationMetrics] — Metrics computed so far during training.
progress()
Returns AutoTrainerProgress — Detailed progress information.
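The TrainingJob methods follow a common async-job contract: status() for non-blocking checks, result() to block until completion, and cancel() to abort. A generic stand-in illustrating that contract (this is not the Kumo implementation, just a sketch of the pattern):

```python
import threading
import time

class FakeJob:
    """Stand-in job handle mimicking a status()/result()/cancel() contract."""

    def __init__(self, work_seconds: float, outcome: str):
        self._outcome = outcome
        self._done = threading.Event()
        self._cancelled = False
        threading.Thread(target=self._run, args=(work_seconds,), daemon=True).start()

    def _run(self, seconds: float) -> None:
        time.sleep(seconds)   # pretend to do work
        self._done.set()

    def status(self) -> str:
        if self._cancelled:
            return "CANCELLED"
        return "DONE" if self._done.is_set() else "RUNNING"

    def cancel(self) -> None:
        self._cancelled = True
        self._done.set()      # unblock any waiter

    def result(self) -> str:
        self._done.wait()     # block until complete (or cancelled)
        if self._cancelled:
            raise RuntimeError("job was cancelled")
        return self._outcome

job = FakeJob(work_seconds=0.1, outcome="trained-model")
print(job.status())   # typically RUNNING right after launch
print(job.result())   # blocks, then: trained-model
```

The non_blocking flag on fit() decides whether you get this kind of handle (TrainingJob) or the already-resolved TrainingJobResult.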
TrainingJobResult
Represents a completed training job.
The training job ID.
id property
Returns TrainingJobID
model_plan property
Returns ModelPlan — The model plan used in this training job.
training_table property
Returns Union[TrainingTableJob, TrainingTable]
predictive_query property
Returns PredictiveQuery — The predictive query that defined this training job.
tracking_url property
Returns str — URL to the training job in the Kumo UI.
metrics()
Returns ModelEvaluationMetrics — Evaluation metrics for the completed job.
holdout_df()
Returns pd.DataFrame — The holdout dataset as a DataFrame.
explain()
Returns per-entity feature importances for a predictive query.
The PQL query string.
Entity indices to explain. Explains all entities if None.

The run mode for explanation computation.
Per-hop neighbor counts for subgraph sampling.
The anchor time for temporal explanation.
pd.DataFrame
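Since explain() returns a plain DataFrame, its output composes with ordinary pandas post-processing. The exact output schema is not documented here, so the following assumes a hypothetical long format (entity_id, feature, importance) and shows how one might extract the top feature per entity:

```python
import pandas as pd

# Hypothetical explain() output in long form (column names are assumptions).
imp = pd.DataFrame({
    "entity_id":  [1, 1, 1, 2, 2, 2],
    "feature":    ["age", "country", "n_orders"] * 2,
    "importance": [0.10, 0.05, 0.40, 0.30, 0.02, 0.08],
})

# Rank by absolute importance within each entity, then keep the top row.
top = (
    imp.assign(abs_imp=imp["importance"].abs())
       .sort_values(["entity_id", "abs_imp"], ascending=[True, False])
       .groupby("entity_id", as_index=False)
       .first()[["entity_id", "feature", "importance"]]
)
print(top)
```

Adapt the column names to whatever schema the real explain() output uses.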
Batch Prediction
BatchPredictionJob
Represents an ongoing batch prediction job.
result()
Returns BatchPredictionJobResult
status()
Returns JobStatusReport
cancel()
Cancels the running batch prediction job.
BatchPredictionJobResult
Represents a completed batch prediction job.
data_df()
Returns pd.DataFrame — Prediction results.
data_urls()
Returns List[str] — Download URLs for prediction results.
summary()
Returns BatchPredictionJobSummary
Online Serving and Distillation
Distillation training and export_model() produce a serving bundle (online model directory and embeddings.parquet from batch prediction) in storage you control. Inference uses NVIDIA Triton Inference Server to load that bundle. See the Online Serving guide for the end-to-end flow.
DistillationTrainer
Trains a shallow model for online serving by reusing representations (embeddings) from a base GNN training job.
The distilled model plan.
The training job ID of the base GNN model to distill from.
is_trained property
Returns bool
fit()
The relational graph.
The training table.
If True, returns a TrainingJob immediately.

Optional job tags.
Union[TrainingJob, TrainingJobResult]
load() classmethod
The training job ID.
DistillationTrainer
DistilledModelPlan
Model plan for distillation. Composed of TrainingJobPlan, ColumnProcessingPlan, OptimizationPlan, DistillationPlan, and a distillation-specific architecture plan.
DistillationPlan
Configuration for the distillation process, specifying embedding keys, time offsets, and real-time interaction settings.
Column keys used as embedding inputs.
Maximum time offset for embedding lookups.
Minimum time offset for embedding lookups.
Real-time interaction table configuration.
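The min/max time offsets bound which historical embeddings are eligible relative to an anchor time. A hedged sketch of that windowing logic with pandas (the table layout, column names, and day-based units are all assumptions for illustration, not the SDK's actual representation):

```python
import pandas as pd

# Hypothetical embedding lookup table for one entity.
embeddings = pd.DataFrame({
    "entity_id": [1, 1, 1],
    "timestamp": pd.to_datetime(["2024-01-01", "2024-02-01", "2024-03-01"]),
})
anchor_time = pd.Timestamp("2024-03-05")
min_offset, max_offset = pd.Timedelta(days=0), pd.Timedelta(days=45)

# Keep embeddings whose age relative to the anchor falls inside the window.
age = anchor_time - embeddings["timestamp"]
eligible = embeddings[(age >= min_offset) & (age <= max_offset)]
print(eligible["timestamp"].dt.strftime("%Y-%m-%d").tolist())
```

Here the 2024-01-01 embedding is 64 days old and falls outside the 45-day window, so only the two newer embeddings remain eligible.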
export_model()
Exports online serving model files and batch prediction embeddings to external storage for use with Triton Inference Server.
Specifies the training job, output path, and batch prediction job to bundle.
If True, returns an ArtifactExportJob immediately.

Union[ArtifactExportJob, ArtifactExportResult]
ModelOutputConfig
Output configuration for export_model(). Specifies output types, destination connector, and table name.
The output types to produce. Valid values: "predictions", "embeddings", or both.

The connector to write outputs to. Local download only if None.

Table name in the output connector. For Databricks, provide a (schema, table) tuple.

Additional metadata columns to include in prediction output. Options: JOB_TIMESTAMP, ANCHOR_TIMESTAMP.

ArtifactExportJob
Represents an ongoing model artifact export job.
id property
Returns str
result()
Returns ArtifactExportResult
status()
Returns JobStatus
cancel()
Returns bool — True if the job was successfully cancelled.
ArtifactExportResult
Represents a completed model artifact export.
tracking_url()
Returns str — URL to the export job in the Kumo UI.