02/28/2024
Graph Machine Learning at the Scale of Modern Data Warehouses
Introduction
What is Deep Learning?
Deep learning is a subset of machine learning (ML) that uses neural networks with multiple layers (hence “deep”) to learn and extract patterns from data. Inspired by the structure and function of the human brain, neural networks are composed of interconnected nodes (i.e., neurons) organized into multiple layers, each layer processing information passed from the previous layer, enabling the extraction of increasingly abstract features from input data.
Deep learning has significantly impacted ML in several ways:
- Feature Representation: Deep learning models automatically learn hierarchical representations of data, eliminating the need for handcrafted feature engineering. This ability to automatically learn features from raw data has led to improved performance in various tasks.
- Scalability: Deep learning algorithms efficiently handle large amounts of data. The availability of massive datasets and computational resources allows deep learning models to dramatically scale and tackle complex problems in areas such as computer vision, natural language processing, and speech recognition.
- Performance: Deep learning models achieve state-of-the-art performance in numerous tasks, surpassing traditional ML approaches in various domains. This includes tasks such as image classification, object detection, machine translation, and sentiment analysis.
- Generalization: Deep learning models generalize well to unseen data, capturing underlying patterns and structures in data. This ability to generalize effectively contributes to the robustness of deep learning models in real-world applications.
- Versatility: Deep learning techniques can be applied across a wide range of domains, including but not limited to image recognition, speech recognition, natural language processing, reinforcement learning, recommendation systems, and healthcare.
Overall, the impact of deep learning on ML has been transformative, unlocking new levels of accuracy, efficiency, and capabilities across a wide range of applications. The ability to transfer learning from data-rich tasks to data-poor tasks has further extended the capabilities of deep learning, allowing high model performance to be achieved with fewer labels.
Although deep learning excels at processing visual data and natural language, enterprise datasets more often consist of interconnected tabular data with numerous attribute-rich entities. These data structures are aptly visualized as a graph, with tables linked via primary and foreign keys. Graph neural networks (GNNs) can therefore deliver the benefits of deep learning to graph-structured enterprise data.
What is a Graph Neural Network (GNN)?
Graph neural networks (GNNs) are a class of deep learning models designed to operate on graph-structured data. Unlike traditional neural networks that process fixed-size vectors or sequences, GNNs can directly operate on data with an arbitrary graph structure. This makes them particularly suitable for tasks involving relational data, including social networks, knowledge graphs, enterprise data, and more.
The core concept behind GNNs involves deriving representations for individual nodes within a graph by consolidating information from their neighboring nodes. This process unfolds through a sequence of message-passing iterations, wherein each node receives messages from its neighbors, integrates them, and subsequently transmits its own message to neighboring nodes. During each iteration, the GNN computes a hidden state for each node based on its current state and the received messages, followed by an update utilizing a nonlinear activation function. Updated hidden states are then propagated to subsequent GNN layers in an iterative procedure that continues until a final representation is established for each node, which can be utilized downstream for various node classification, link prediction, and graph-level prediction tasks. GNNs serve as potent tools for effectively capturing the complex relationships and dependencies present in the graph data.
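To make the message-passing loop concrete, here is a minimal sketch of a single layer in plain PyTorch. The mean aggregation, layer name, and dimensions are illustrative choices for exposition, not a prescribed GNN design:

```python
import torch
import torch.nn as nn

class SimpleMessagePassing(nn.Module):
    """One message-passing layer: aggregate neighbor messages, then update."""
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.update_fn = nn.Linear(2 * in_dim, out_dim)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        # x: [num_nodes, in_dim] node states; edge_index: [2, num_edges] (src, dst)
        src, dst = edge_index
        # Message + aggregate: mean of neighbor states arriving at each node.
        agg = torch.zeros_like(x).index_add_(0, dst, x[src])
        deg = torch.zeros(x.size(0), device=x.device).index_add_(
            0, dst, torch.ones(dst.size(0), device=x.device)
        )
        agg = agg / deg.clamp(min=1).unsqueeze(-1)
        # Update: combine each node's current state with its aggregated
        # messages, then apply a nonlinearity to produce the new hidden state.
        return torch.relu(self.update_fn(torch.cat([x, agg], dim=-1)))

x = torch.randn(5, 8)                      # 5 nodes, 8 features each
edge_index = torch.tensor([[0, 1, 2, 3],   # source nodes
                           [1, 2, 3, 4]])  # target nodes
h = SimpleMessagePassing(8, 16)(x, edge_index)  # [5, 16] updated hidden states
```

Stacking several such layers lets each node's final representation reflect information from progressively larger multi-hop neighborhoods.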
Comparing Graph Machine Learning to Tabular Machine Learning
Conventional tabular ML typically involves several distinct steps, including problem formulation, feature engineering, algorithm selection, and dataset construction, followed by model training using various frameworks. Because of the multitude of individual steps involved, traditional tabular ML is often error-prone, resembling a trial-and-error process in which various methods are attempted to see which works best. Problem formulation and mapping must be done from scratch for each use case, making it difficult to generalize across multiple tasks. In addition, juggling multiple frameworks and their idiosyncrasies further increases complexity.
In terms of scalability and generalizability, graph-based learning is superior to conventional methods across a diverse range of tasks. Problem formulation becomes notably simpler, as each use case translates into just one of a few graph ML tasks. Once a graph ML task is defined, GNNs autonomously aggregate and fuse information to learn intricate relational patterns at scale. GNNs excel at reasoning over multi-hop connections, something traditional models struggle to capture through manually calculated input features. GNNs' learned representations therefore surpass manually engineered features in effectiveness and generalizability, making them better suited for downstream tasks.
Given these numerous benefits, it's no surprise that GNN adoption has surged in recent years. GNN-based ML models have consistently proven their efficacy in practical settings, and ML teams at leading organizations have successfully implemented them across a broad range of use cases: recommendation systems, personalization, fraud detection, dynamic system forecasting, and complex network modeling, to name a few.
Graph Representation Learning with PyG
The following is a quick overview of how modern graph representation learning works, with a focus on PyTorch Geometric (PyG).
PyG is a Python library for deep learning on irregular input data, particularly graph-structured data. By providing a framework for implementing GNNs using PyTorch, a popular deep learning library, PyG offers various utilities for data preprocessing, dataset handling, and building custom graph neural network architectures. The library also includes pre-implemented GNN layers, graph convolutions, pooling operations, and message passing functions, simplifying the process of constructing and training GNN models. PyG is widely used in both research and industry settings for tasks such as node classification, graph classification, link prediction, and graph generation, among others.
How to implement PyG
PyG is designed to be flexible and modular, allowing users to easily define and experiment with different types of GNNs for various graph-based ML tasks.
Step 1: Create and instantiate graph datasets and transformations
PyG provides a variety of predefined graph datasets, including citation networks, social networks, and bioinformatics graphs. You can also create your own custom dataset by extending the PyG dataset class. Graph transformations allow you to perform preprocessing steps on your graphs such as adding self-loops, standardizing node features, and so on.
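A brief sketch of this step, using one of PyG's bundled citation-network datasets; the dataset choice, root path, and transforms are illustrative:

```python
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid

transform = T.Compose([
    T.NormalizeFeatures(),  # row-normalize node features
    T.AddSelfLoops(),       # add a self-loop to every node
])
dataset = Planetoid(root="data/Planetoid", name="Cora", transform=transform)
data = dataset[0]  # the single citation graph in this dataset
print(data.num_nodes, data.num_edges, dataset.num_classes)
```

A custom dataset follows the same pattern: subclass torch_geometric.data.Dataset (or InMemoryDataset) and return Data objects.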
Step 2: Define how to obtain mini-batches from your dataset
In PyG, mini-batches are generated using the DataLoader class, which requires a graph dataset and a batch size as inputs. PyG offers various pre-implemented sampling methods for mini-batching, including random sampling and neighbor sampling over nodes or edges, among others. You also have the option to define custom sampling methods by designing a personalized sampler class.
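Continuing with the Cora graph from the previous sketch, mini-batching with neighbor sampling might look like this; the batch size and fan-outs are illustrative:

```python
from torch_geometric.loader import NeighborLoader

loader = NeighborLoader(
    data,
    num_neighbors=[10, 10],       # sample up to 10 neighbors per node, for 2 hops
    batch_size=128,
    input_nodes=data.train_mask,  # seed nodes that anchor each mini-batch
    shuffle=True,
)
batch = next(iter(loader))        # a sampled subgraph around the seed nodes
```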
Step 3: Design your custom GNN
This step involves either assembling customized GNN components from pre-defined building blocks or employing pre-made GNN models. PyG offers an array of predefined GNN layers, including graph convolutional networks, graph attention networks, and graph edge modules. Alternatively, users can tailor their own GNN layers by extending the foundational PyG message passing class. Because GNNs in PyG are modular in design, they can be stacked together seamlessly to form multilayer GNN architectures.
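For example, a small two-layer graph convolutional network can be assembled from PyG's predefined GCNConv layer; the hidden width and dropout rate are illustrative:

```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int, num_classes: int):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, num_classes)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.5, training=self.training)
        return self.conv2(x, edge_index)  # per-node class logits

model = GCN(dataset.num_features, 64, dataset.num_classes)
```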
Step 4: Implement your own training and inference routines
PyG offers an API tailored for training and evaluating GNNs, mirroring the structure of the PyTorch API. Training a GNN entails specifying a loss function and an optimizer, then invoking the train method on the GNN module. Inference involves calling the eval method on the GNN module, followed by making predictions on new graphs.
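A typical node-classification loop looks like the following; for brevity it trains full-batch on the small citation graph rather than iterating over the neighbor-sampling loader:

```python
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()

model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = criterion(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

model.eval()
with torch.no_grad():
    pred = model(data.x, data.edge_index).argmax(dim=-1)
    test_acc = (pred[data.test_mask] == data.y[data.test_mask]).float().mean()
```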
Drawbacks of using PyG in production
Although PyG offers a solid foundation for graph learning, operationalizing GNNs requires numerous additional functionalities that can be challenging to develop and scale effectively.
Graph creation
PyG requires graph connectivity to be expressed in either coordinate list (COO) or compressed sparse row (CSR) format, even for complex heterogeneous graphs, while graph nodes and edges can carry any assortment of curated features. Although PyG provides readily usable datasets, users must reconstruct graphs into these expected formats when working with the non-curated datasets commonly encountered in enterprise settings. Managing graph creation, particularly at scale, presents non-trivial challenges.
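As a small illustration of this reconstruction burden, hand-building a heterogeneous graph in COO format from two hypothetical tables might look like this; every table, column, and size below is invented for illustration:

```python
import torch
from torch_geometric.data import HeteroData

data = HeteroData()
data["customer"].x = torch.randn(100, 16)   # 100 customers, 16 features each
data["order"].x = torch.randn(500, 8)       # 500 orders, 8 features each

# COO edge list derived from the orders table's customer_id foreign key:
# one (customer -> order) edge per foreign-key reference.
customer_idx = torch.randint(0, 100, (500,))  # stand-in for the FK column
order_idx = torch.arange(500)
data["customer", "places", "order"].edge_index = torch.stack(
    [customer_idx, order_idx], dim=0
)
```

Doing this reliably across many tables, with large row counts and evolving schemas, is where the real difficulty lies.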
Problem formulation
While PyG accommodates various graph-related ML tasks, problem formulation is up to the user, who must translate their business problem into a supported PyG graph learning task type. Users are also responsible for curating training labels for a given task. During this process, temporal consistency must be preserved in label generation and neighbor sampling to prevent future entity data from leaking into the ML pipeline. The same applies to predictions. When using PyG, mitigating data leaks and deploying GNN predictions can pose substantial challenges.
Customization
While GNNs support full customization from model architecture to training pipeline, determining the optimal model architecture is both data and task dependent. Users must consider numerous factors like the most suitable GNN for a particular task, the appropriate number of neighbors and hops to sample, the model’s temporal generalization ability, and issues related to class imbalances and overfitting. Moreover, when new data arrives or the graph structure changes, you must update the graph, retrain the model, and version the model outputs.
Kumo’s Graph Neural Network platform
Kumo simplifies the process of translating business challenges into graph-based ML tasks, facilitates secure connectivity to extensive data warehouses, and enables seamless query execution. Upon establishing a connection to a data source and defining a meta-graph, Kumo automatically generates the graph at the record level. The business problem can then be specified declaratively using predictive query (pquery) syntax. Pqueries are written using PQL, a SQL-like language interface for describing ML predictions.
These pqueries are internally compiled into query execution plans, encompassing the training plan for the corresponding ML task. As data within the data source evolves, the materialized graph and features undergo incremental updates and Kumo dynamically optimizes the graph structure for specific tasks. This optimization includes actions like introducing meta-paths between records in the same table to minimize the necessary hops for effective representation learning.
For model training, Kumo offers out-of-the-box few-shot AutoML capabilities specifically tailored to GNNs. Additionally, Kumo delivers enterprise-grade features such as explainability alongside MLOps functionalities.
How to implement Kumo
Kumo’s programming model for training and deploying GNNs follows a straightforward, five-step approach:
Step 1: Create secure connectors to one or more data sources
Kumo connectors abstract physical, underlying data sources into a standardized interface for accessing metadata and data. Connectors are currently available for Amazon S3 as well as cloud data warehouses such as Snowflake, Google BigQuery, and Amazon Redshift. Kumo supports both CSV and Parquet file formats for datasets residing in an AWS S3 bucket, as well as several common partitioning schemes and directory tree layouts.
Kumo securely handles data ingestion and caches connections for subsequent processing tasks. Maintaining data cleanliness and addressing typing-related issues can be challenging when managing diverse data sources; however, Kumo takes care of these tasks automatically.
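As a rough sketch of what this step could look like in code, the snippet below imagines connecting through Kumo's Python SDK. The module, class, and function names (kumoai, init, S3Connector) are assumptions made for illustration and may not match the actual SDK surface:

```python
# Hypothetical sketch; all names below are illustrative assumptions.
import kumoai as kumo

kumo.init(url="https://<your-org>.kumoai.cloud/api", api_key="<api-key>")

# Register a connector over Parquet/CSV datasets in an S3 bucket.
connector = kumo.S3Connector(root_dir="s3://my-bucket/warehouse/")
customers_src = connector["customers"]        # reference a source table by name
transactions_src = connector["transactions"]
```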
Step 2: Create one or more business graphs and define linkages
This step involves registering a meta-graph by incorporating tables and defining connections between them. Ideally, these connections should include primary key and foreign key relationships, with the meta-graph reflecting the graph at the schema level. The actual graph at the record level can be difficult to construct and maintain; however, Kumo’s backend automatically materializes this graph in a scalable way.
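Continuing the hypothetical sketch from Step 1, registering tables and a meta-graph might look as follows; the class names, table names, and key columns are all illustrative assumptions:

```python
# Hypothetical sketch; names are illustrative assumptions.
customers = kumo.Table.from_source_table(
    source_table=customers_src,
    primary_key="customer_id",
).infer_metadata()

transactions = kumo.Table.from_source_table(
    source_table=transactions_src,
    primary_key="txn_id",
    time_column="txn_time",   # timestamps enable temporal-consistency handling
).infer_metadata()

# The meta-graph links tables along primary-key/foreign-key relationships.
graph = kumo.Graph(
    tables={"customers": customers, "transactions": transactions},
    edges=[dict(src_table="transactions", fkey="customer_id",
                dst_table="customers")],
)
```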
Step 3: Formulate the business problem with predictive queries on the business graphs
Users can declaratively define their business use cases using PQL, the pquery interface and language syntax. This approach simplifies the expression of business problems, eliminating concerns about mapping them to graph-based ML tasks.
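For instance, a churn-style use case on the hypothetical graph above might be expressed as follows. The PQL shape (predict an aggregate over a future time window for each entity) follows Kumo's public examples, but the exact syntax and SDK names should be treated as illustrative:

```python
# Hypothetical sketch: predict, for each customer, whether they will make
# zero transactions over the next 30 days (a churn signal).
pquery = kumo.PredictiveQuery(
    graph=graph,
    query="""
        PREDICT COUNT(transactions.*, 0, 30, days) = 0
        FOR EACH customers.customer_id
    """,
)
pquery.validate()  # hypothetical sanity check of the query against the graph
```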
Step 4: Train the predictive queries
During pquery training, Kumo automatically handles temporal accuracy to prevent data leakage. Users can create any number of pqueries for a given graph and leverage multiple options for model creation. As part of its few-shot AutoML process, Kumo automatically deduces the task type, creates training labels, manages the training and evaluation data splits, handles target class upsampling/downsampling, tunes within the allocated training budget, and selects the optimal training strategy and search space. These configurations are tailored to each pquery based on Kumo’s understanding of both the data and the task, encompassing GNN design space options like model architecture, optimization parameters, and more.
The choice of encoding for features is also part of the search space, selected automatically based on data understanding and statistics to circumvent error-prone and ad hoc manual feature engineering. Advanced users have the flexibility to customize the AutoML search space, adjusting the set of training experiments to be conducted.
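In code, this entire step could reduce to a few calls along these lines; the method names are illustrative assumptions:

```python
# Hypothetical sketch; names are illustrative assumptions.
model_plan = pquery.suggest_model_plan()           # AutoML-chosen search space
trainer = kumo.Trainer(model_plan)
trainer.fit(
    graph=graph,
    train_table=pquery.generate_training_table(),  # labels derived by Kumo
)
```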
Step 5: Run inference multiple times on the trained model
After pquery training, inferences are run on the best trained model. This inference process can be executed multiple times per day, with predictions either sent to an AWS S3 bucket or stored directly in a data warehouse like Snowflake, Amazon Redshift, or Google BigQuery. Kumo’s GNNs can generate both predictions and embeddings. ML practitioners typically integrate predictions directly into business applications, while embeddings feed various downstream applications to enhance their performance.
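Concluding the hypothetical sketch, a recurring batch-inference job might look like this; the parameter names are illustrative assumptions:

```python
# Hypothetical sketch; names are illustrative assumptions.
predictions = trainer.predict(
    graph=graph,
    prediction_table=pquery.generate_prediction_table(),
    output_types={"predictions", "embeddings"},  # scores and/or node embeddings
    output_connector=connector,                  # e.g., write back to S3
    output_table_name="churn_predictions",
)
```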
In summary, Kumo’s GNNs offer a flexible and robust approach to conducting ML on enterprise graph data. The simplicity of Kumo’s programming model enables both seasoned ML practitioners and non-experts to achieve the fastest time-to-value.
Read more about how Kumo’s architecture resolves the numerous challenges in deploying graph neural networks at scale.