
How the Kumo Training Backend Works
Large-scale training of graph transformers: why precomputing neighborhoods fails, how online sampling solves the problem, and what it takes to train on 10-30 billion node graphs.
The First Attempt: Precomputing Neighborhoods with Spark
The natural first approach is to precompute every neighborhood offline. Run a Spark job that materializes each target node's multi-hop subgraph as a self-contained training example. Write the results to disk. Then feed them into a standard PyTorch training loop.
This approach breaks down in four specific ways.
Data explosion
A single 3-hop neighborhood can contain hundreds of megabytes of node and edge features. Multiply that by millions of training examples, and the materialized dataset reaches terabytes. Storage costs become a bottleneck before training even begins.
Slow iteration
Any change to the sampling configuration (hop depth, fanout, edge types) requires a full re-run of the Spark pipeline. In practice, this means hours of recomputation for each experiment. Teams that start this way quickly find that their iteration speed is gated by data preparation, not model development.
Stale data
Static dumps do not reflect evolving graphs. Real-world graphs change constantly: new users sign up, new transactions occur, new relationships form. A precomputed snapshot is outdated the moment it finishes writing. Models trained on stale neighborhoods learn patterns that no longer hold.
Temporal leakage
This is the most dangerous failure mode. Without careful filtering, future edges infiltrate training sets. A fraud model trained with temporal leakage sees transactions that occurred after the prediction timestamp, inflating offline metrics while failing in production. Enforcing temporal correctness in a batch Spark job requires meticulous timestamp tracking across every join, and most implementations get it wrong.
The Shift: Online Sampling
The solution is to stop precomputing neighborhoods and start generating them on the fly. Kumo's training backend replaces the Spark pipeline with three tightly integrated components that run in parallel during training.
Graph Sampler (RAM)
Holds the graph structure in memory. Expands neighborhoods per configuration: hop depth, fanouts per edge type, metapaths, and time constraints. Automatically enforces temporal correctness by only surfacing edges valid at the training example's timestamp.
Feature Store (SSD)
Receives node IDs from the sampler and fetches attributes from SSD storage. Enables terabyte-scale feature handling without GPU memory overload.
GPU Trainer (PyTorch)
Standard GNN/transformer layer pipeline. Loss computation and backpropagation follow familiar patterns. Each batch is a freshly constructed neighborhood.
How the components interact
The graph sampler sits in RAM and serves as a high-speed index over the graph topology. When the trainer requests a batch, the sampler expands neighborhoods according to the current configuration (hop depth, per-edge fanouts, temporal constraints) and returns a set of node IDs. Those IDs are passed to the feature store, which looks up the corresponding attributes on SSD. The assembled subgraphs, now with full features, are sent to the GPU for the forward pass.
This architecture generates thousands of neighborhoods per second, keeping the GPU saturated without ever writing intermediate neighborhoods to disk. Configuration changes (different fanouts, deeper hops, new edge types) take effect immediately: there is no re-materialization step.
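To make the flow concrete, here is a minimal sketch of the three stages, with toy in-memory stand-ins: a dict for the RAM-resident topology and another for the SSD feature store. The names and data structures are illustrative, not Kumo's internal API.

```python
import torch

# Toy stand-ins: topology in RAM, features on "SSD" (here, just another dict).
adjacency = {0: [1, 2], 1: [3], 2: [3, 4], 3: [], 4: []}
feature_store = {i: torch.randn(8) for i in range(5)}

def sample_neighborhood(seed, fanouts):
    """Expand a multi-hop neighborhood; returns the visited node IDs."""
    frontier, visited = [seed], {seed}
    for fanout in fanouts:                        # one fanout entry per hop
        nxt = []
        for node in frontier:
            for nbr in adjacency[node][:fanout]:  # truncate to the fanout
                if nbr not in visited:
                    visited.add(nbr)
                    nxt.append(nbr)
        frontier = nxt
    return sorted(visited)

def fetch_features(node_ids):
    """Feature store lookup: node IDs in, attribute tensors out."""
    return torch.stack([feature_store[i] for i in node_ids])

node_ids = sample_neighborhood(seed=0, fanouts=[2, 2])  # sampler (RAM)
x = fetch_features(node_ids)                            # feature store (SSD)
print(node_ids, x.shape)  # x is what the GPU trainer sees in the forward pass
```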
Temporal correctness by construction
The sampler enforces temporal correctness at the point of neighborhood construction. For each training example, it filters edges based on the example's timestamp, ensuring that only historically valid relationships appear in the subgraph. This eliminates the temporal leakage problem that plagues batch precomputation.
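A minimal sketch of what filtering at construction time looks like, assuming each edge stores its creation timestamp (the data layout here is illustrative, not Kumo's implementation):

```python
# Each edge carries the time it was created: (neighbor, timestamp).
timestamped_adjacency = {
    "user_42": [("txn_a", 100), ("txn_b", 205), ("txn_c", 310)],
}

def valid_neighbors(node, anchor_time):
    # Edges created after anchor_time would leak future information,
    # so they are dropped before any fanout sampling happens.
    return [nbr for nbr, t in timestamped_adjacency[node] if t <= anchor_time]

print(valid_neighbors("user_42", anchor_time=250))  # ['txn_a', 'txn_b']
```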
What This Architecture Enables
The online sampling architecture removes the constraints that forced earlier systems into shallow, small-scale graph learning. Here is what becomes possible.
| Parameter | Capability |
|---|---|
| Graph Size | 10-30 billion nodes |
| Subgraph Depth | 6+ hops (on sparse graphs) |
| Inference Mode | Batch or Online (inductive) |
| Sampling Strategy | Static or Temporal (per metapath) |
Deeper models
With precomputation, going beyond 2 hops was prohibitively expensive because each additional hop multiplied the materialized data size. Online sampling makes 6+ hop models practical on sparse graphs, because only the neighborhoods actually needed for the current batch are constructed in memory.
Billion-node graphs
The graph structure (adjacency lists) lives in RAM, while features live on SSD. A single server with 1TB of RAM can hold graphs with 10 to 30 billion nodes. The separation of structure and features is what makes this possible: storing adjacency lists is far cheaper per node than storing full feature vectors.
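A quick back-of-envelope using the figures above shows why the split matters (the byte counts are illustrative assumptions, not measured numbers):

```python
# 1 TB of RAM hosting 30 billion nodes leaves roughly 33 bytes of topology
# budget per node, which a compact adjacency representation can fit on
# sparse graphs. Even a modest per-node feature vector would blow past it.
ram_bytes = 1e12                # 1 TB server
nodes = 30e9                    # upper end of the stated range
structure_budget = ram_bytes / nodes   # ~33 bytes/node for adjacency
feature_bytes = 64 * 4          # e.g., 64 float32 features = 256 bytes/node
print(f"structure budget: {structure_budget:.0f} bytes/node; "
      f"features alone would need {feature_bytes} bytes/node")
```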
Batch and online inference
The same sampler that generates training neighborhoods also serves inference. For batch scoring, it iterates over all target entities. For online scoring, it constructs a single neighborhood on demand when a prediction is requested. Because the sampler settings are consistent across training and inference, there is no train/serve skew.
Relating Back to PyTorch
If you have trained a model in PyTorch, the Kumo training backend maps directly to concepts you already know.
| Kumo Component | PyTorch Equivalent | What It Emits |
|---|---|---|
| Graph Sampler | DataLoader | Subgraphs (instead of rows/images) |
| Feature Store | collate_fn | Node/edge attributes attached to subgraphs |
| Model | nn.Module | GraphSAGE, GAT, transformers (all compatible) |
The training loop itself is unchanged. Fetch a batch, run the forward pass, compute the loss, call optimizer.step(). The only difference is that each batch is a freshly constructed neighborhood rather than a static tensor.
This is an important design decision. By keeping the training loop standard, the system remains compatible with PyTorch ecosystem tools: learning rate schedulers, gradient clipping, mixed-precision training, distributed data parallelism. Nothing about the optimizer or the model architecture needs to change.
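To make that concrete, here is a self-contained toy loop in exactly that style. The batch generator stands in for the online sampler and the mean-aggregation layer for a real GNN; none of these names are Kumo's API. The point is that the optimizer and loss mechanics are plain PyTorch.

```python
import torch
import torch.nn as nn

class MeanSageLayer(nn.Module):
    """A minimal GraphSAGE-style layer: concat self features with the mean
    of neighbor features, then project."""
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.lin = nn.Linear(2 * in_dim, out_dim)

    def forward(self, x, edge_index):
        src, dst = edge_index                       # edge_index: shape (2, E)
        agg = torch.zeros_like(x)
        agg.index_add_(0, dst, x[src])              # sum neighbor messages
        deg = torch.bincount(dst, minlength=x.size(0)).clamp(min=1)
        agg = agg / deg.unsqueeze(-1)               # mean aggregation
        return self.lin(torch.cat([x, agg], dim=-1))

def fake_neighbor_batches(num_batches, n=32, d=8):
    """Stand-in for the online sampler: each batch is a fresh subgraph."""
    for _ in range(num_batches):
        x = torch.randn(n, d)
        edge_index = torch.randint(0, n, (2, 4 * n))
        y = torch.randint(0, 2, (n,)).float()
        yield x, edge_index, y

model = MeanSageLayer(8, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.BCEWithLogitsLoss()

# The loop itself is unchanged PyTorch: fetch, forward, loss, step.
for x, edge_index, y in fake_neighbor_batches(10):
    optimizer.zero_grad()
    logits = model(x, edge_index).squeeze(-1)
    loss = loss_fn(logits, y)
    loss.backward()
    optimizer.step()
```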
Working with Real Graphs
Production graphs are not simple. They have multiple node types, multiple edge types, and timestamped interactions. The sampling configuration needs to reflect this complexity, giving fine-grained control over how neighborhoods are constructed for each edge type and each hop.
The num_neighbors parameter
The core configuration is num_neighbors, which controls how many neighbors are sampled at each hop. The simplest version sets a default fanout per hop:
```yaml
num_neighbors:
  - hop1:
      default: 16
  - hop2:
      default: 8
```

This samples 16 neighbors at the first hop and 8 at the second hop, applied uniformly across all edge types.
Per-edge type customization
Real graphs require per-edge control. Some relationships are high-signal and deserve larger fanouts. Others are noise and should be excluded entirely. The configuration supports overrides per edge type at each hop:
```yaml
num_neighbors:
  - hop1:
      default: 16
      USERS.USER_ID->TRANSACTIONS.USER_ID: 128
  - hop2:
      default: 8
      TRANSACTIONS.STORE_ID->STORES.STORE_ID: 0
```

In this example, user-to-transaction edges get a fanout of 128 at hop 1 (because transaction history is highly predictive), while transaction-to-store edges are disabled entirely at hop 2 (by setting the fanout to 0). Every other edge type uses the default.
Temporal sampling strategies
For timestamped edges, the sampler supports multiple strategies. The default weights newer edges more heavily, reflecting the intuition that recent interactions are more predictive than old ones. An alternative strategy samples uniformly across time, giving equal weight to historical and recent edges. The system can also infer optimal fanouts per connection based on degree statistics, automatically allocating more budget to high-degree nodes and less to low-degree ones.
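Here is a toy sketch of recency-weighted sampling, assuming an exponential decay with a configurable half-life. Both the decay function and the half_life parameter are illustrative assumptions, not the documented behavior:

```python
import random

def sample_neighbors(edges, anchor_time, fanout, half_life=30.0):
    """Recency-weighted neighbor sampling over (neighbor, timestamp) pairs.
    Sampling is with replacement for brevity; a production sampler would
    sample without replacement."""
    # Temporal correctness first: drop edges from the anchor's future.
    valid = [(nbr, t) for nbr, t in edges if t <= anchor_time]
    if not valid:
        return []
    # An edge's weight halves every `half_life` time units before the anchor.
    weights = [0.5 ** ((anchor_time - t) / half_life) for _, t in valid]
    picked = random.choices(valid, weights=weights, k=min(fanout, len(valid)))
    return [nbr for nbr, _ in picked]

edges = [("txn_a", 10.0), ("txn_b", 95.0), ("txn_c", 120.0)]
print(sample_neighbors(edges, anchor_time=100.0, fanout=2))  # favors txn_b
```

Uniform-over-time is the same call with equal weights; the per-connection fanout inference mentioned above would adjust the fanout argument from degree statistics.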
Putting It Into Practice
Here is how the full workflow comes together, from defining a prediction task to running inference.
1. Define Task
Write a Predictive Query specifying the prediction target: churn, fraud risk, lifetime value, or any other entity-level or link-level task.
2. Generate ModelPlan
The SDK suggests a baseline configuration: model architecture, sampling depth, and fanouts. This plan is fully customizable.
3. Configure Sampling
Set num_neighbors with per-hop and per-edge-type fanouts. Choose temporal sampling strategies per metapath.
4. Train Model
The Trainer API launches training. The online sampling backend generates fresh neighborhoods for each batch, maintaining GPU saturation.
5. Run Inference
Use the same sampler settings for batch or online scoring. Consistent sampler configuration eliminates train/serve skew.
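To trace the sequence end to end, here is a heavily simplified, hypothetical sketch. Every function name below is a stand-in (the article does not show the actual SDK surface); only the order of the five steps is taken from the workflow above.

```python
# Hypothetical stand-ins for the five workflow steps; none of these names
# are the real Kumo SDK, they only mirror the sequence described above.

def generate_model_plan(query: str) -> dict:
    # 2. Baseline plan: architecture, sampling depth, fanouts (customizable).
    return {"hops": 2, "num_neighbors": {"hop1": {"default": 16},
                                         "hop2": {"default": 8}}}

def train_model(query: str, plan: dict) -> dict:
    # 4. In the real system, the online sampler feeds fresh neighborhoods
    #    to the GPU trainer here.
    return {"query": query, "plan": plan}

def predict(model: dict, online: bool) -> str:
    # 5. Same sampler settings at inference time: no train/serve skew.
    return f"scoring with hops={model['plan']['hops']}, online={online}"

# 1. Define the task as a predictive query (a churn-style example).
query = "PREDICT COUNT(TRANSACTIONS.*, 0, 30, days) = 0 FOR EACH USERS.USER_ID"
plan = generate_model_plan(query)
# 3. Customize sampling: boost the high-signal edge type at hop 1.
plan["num_neighbors"]["hop1"]["USERS.USER_ID->TRANSACTIONS.USER_ID"] = 128
model = train_model(query, plan)
print(predict(model, online=True))
```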
The bottleneck is the neighborhoods, not the math
The central insight behind the Kumo training backend is that the computational bottleneck in graph learning is not the GNN or transformer layers. It is the construction of input neighborhoods. The matrix multiplications and attention computations are well optimized by existing GPU kernels. What slows teams down is the upstream data pipeline: materializing neighborhoods, enforcing temporal correctness, iterating on sampling configurations, and scaling to production graph sizes.
By moving neighborhood construction online and separating graph topology (RAM) from features (SSD), the Kumo backend eliminates this bottleneck. Billion-node graphs become trainable. 6+ hop models become practical. Temporal correctness is enforced by construction. And the PyTorch training loop stays exactly the same.
Try KumoRFM on your own data
Zero-shot predictions are free. Fine-tuning is available with a trial.