Why GNNs stress feature stores
A tabular ML model fetches features for one entity per prediction. A 2-layer GNN with fanout [15, 10] samples up to 15 first-hop neighbors and 15 × 10 = 150 second-hop neighbors, so a single prediction can touch features for roughly 166 nodes. A batch of 1024 predictions can touch on the order of 150K node feature rows before deduplication (fewer distinct nodes in practice, since sampled neighborhoods overlap).
This amplification factor makes feature store latency and throughput critical for GNN performance. A feature store that works fine for tabular ML (one feature row per prediction) can collapse under GNN load (hundreds of rows per prediction).
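The amplification is easy to quantify. A small sketch of the worst-case fetch count for a fanout list, using the numbers above (the function name is illustrative):

```python
def worst_case_fetches(fanouts, batch_size=1):
    """Upper bound on feature rows fetched for one batch of seed nodes.

    Counts each seed plus every sampled neighbor at each hop, assuming no
    overlap between neighborhoods (the true distinct-node count is lower
    because sampled subgraphs share nodes).
    """
    per_seed = 1  # the seed node itself
    layer = 1
    for fanout in fanouts:
        layer *= fanout   # nodes sampled at this hop, per seed
        per_seed += layer
    return per_seed * batch_size

print(worst_case_fetches([15, 10]))        # 166 rows per prediction
print(worst_case_fetches([15, 10], 1024))  # 169984 rows before deduplication
```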
PyG’s FeatureStore and GraphStore interface
```python
import torch
from torch_geometric.data import FeatureStore


class RedisFeatureStore(FeatureStore):
    """Custom feature store backed by Redis."""

    def __init__(self, redis_client):
        super().__init__()
        self.redis = redis_client

    def _get_tensor(self, attr):
        """Fetch features for a group of nodes."""
        node_type = attr.group_name
        indices = attr.index
        # One batched Redis lookup instead of one round-trip per node.
        keys = [f"{node_type}:{i}" for i in indices.tolist()]
        values = self.redis.mget(keys)
        return torch.stack([
            # bytearray: torch.frombuffer expects a writable buffer
            torch.frombuffer(bytearray(v), dtype=torch.float32)
            for v in values
        ])

    def _put_tensor(self, tensor, attr):
        """Store features (for materialization)."""
        node_type = attr.group_name
        for i, row in enumerate(tensor):
            self.redis.set(
                f"{node_type}:{attr.index[i]}",
                row.numpy().tobytes(),
            )
        return True
```

Implement `_get_tensor` and `_put_tensor` (a concrete subclass must also provide the remaining abstract methods, such as `_remove_tensor` and `get_all_tensor_attrs`). PyG's `NeighborLoader` calls the public `get_tensor`/`put_tensor` wrappers automatically during sampling.
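The batched `mget` in the Redis store above is the load-bearing detail. A minimal sketch with a hypothetical in-memory stand-in for the Redis client shows why: one round-trip serves the whole sampled subgraph.

```python
import struct

class FakeRedis:
    """Hypothetical in-memory stand-in for a Redis client (same set/mget API)."""
    def __init__(self):
        self._data = {}
        self.round_trips = 0

    def set(self, key, value):
        self.round_trips += 1
        self._data[key] = value

    def mget(self, keys):
        self.round_trips += 1  # one round-trip regardless of key count
        return [self._data.get(k) for k in keys]

client = FakeRedis()
for i in range(150):  # 150 nodes ~ one sampled 2-hop subgraph
    client.set(f"user:{i}", struct.pack("4f", i, i, i, i))

client.round_trips = 0
values = client.mget([f"user:{i}" for i in range(150)])
print(client.round_trips)  # 1 round-trip for 150 nodes, not 150
```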
Architecture patterns
Training: offline store
During training, you access features in bulk (millions of nodes per epoch). Use an offline store optimized for throughput:
- Parquet files on S3/GCS: Cheapest. Load into memory at epoch start. Works up to ~100M nodes.
- Snowflake/BigQuery: SQL-based access. Good for feature engineering. Higher latency per query but handles any scale.
- Feast offline store: Standardized interface to BigQuery, Snowflake, or Redshift. Provides point-in-time feature retrieval for temporal correctness.
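The epoch-start bulk-load pattern above can be sketched as follows. The sizes and random features are illustrative; a real pipeline would read the matrix from Parquet on S3/GCS rather than generate it.

```python
import numpy as np

# Illustrative: assume the offline store exported a dense [num_nodes, dim]
# float32 matrix where row i holds the features of node id i.
num_nodes, dim = 10_000, 128  # real graphs are far larger
features = np.random.rand(num_nodes, dim).astype(np.float32)

# Epoch start: hold the whole matrix in RAM. Per-batch access during
# training is then a single fancy-index, with no network round-trips.
def fetch(node_ids):
    return features[node_ids]

batch = fetch(np.array([5, 42, 9_999]))
print(batch.shape)  # (3, 128)
```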
Serving: online store
During inference, you need sub-5ms feature retrieval for individual nodes:
- Redis: Sub-1ms per lookup. Good up to ~10M nodes per instance. Use Redis Cluster for larger scales.
- DynamoDB: Sub-5ms with auto-scaling. Cost-effective for sporadic access patterns.
- Feast online store: Abstracts over Redis/DynamoDB with a consistent API. Handles materialization from offline to online automatically.
Optimizing feature access for GNNs
- Batch lookups: Never fetch features one node at a time. Always batch into a single mget/multi-row query. This reduces round-trips from 150 to 1 per batch.
- Prefetching: During neighbor sampling on CPU, prefetch features for already-sampled nodes on a background thread. By the time the GPU needs features, they are ready.
- Feature compression: Store features as float16 or int8 in the feature store and decompress on the GPU. Relative to float32, float16 halves network transfer and storage costs; int8 quarters them.
- Hub caching: Cache features for high-degree nodes (top 1% by degree) locally. These nodes appear in most subgraphs, so caching them eliminates 30-50% of feature store lookups.
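The hub-caching idea can be sketched as a cache-first lookup that partitions each batch into local hits and feature-store misses (the `HubCache` class and list-valued features are illustrative):

```python
class HubCache:
    """Hypothetical local cache pinning features of high-degree (hub) nodes."""
    def __init__(self, hub_features):
        self._hot = dict(hub_features)  # node_id -> feature row, pinned at startup

    def split(self, node_ids):
        """Partition a lookup batch into cache hits and feature-store misses."""
        hits = {i: self._hot[i] for i in node_ids if i in self._hot}
        misses = [i for i in node_ids if i not in self._hot]
        return hits, misses

# Pin the top-degree nodes once at startup (features shown as plain lists).
cache = HubCache({0: [0.1] * 4, 7: [0.7] * 4})
hits, misses = cache.split([0, 3, 7, 9])
print(sorted(hits), misses)  # [0, 7] [3, 9] -> only 2 of 4 ids hit the store
```

Only the `misses` list goes to the remote feature store; since hubs recur in most sampled subgraphs, the hit set is large in practice.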
What breaks in production
- Feature staleness: Features computed daily become stale for real-time predictions. Define freshness SLAs per feature and set up refresh pipelines accordingly.
- Schema evolution: Adding a new feature column requires updating the feature store, the graph construction pipeline, and the model input layer simultaneously. Use feature versioning.
- Cost at scale: Storing 128-dim float32 features for 1B nodes in Redis requires ~512 GB of RAM ($5K+/month). Use tiered storage: hot nodes in Redis, cold nodes in S3 with on-demand loading.
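The tiered-storage pattern reduces to a hot-tier lookup with a cold-tier fallback and promotion on miss. A minimal sketch, with a dict standing in for Redis and a callable standing in for an S3 fetch (both names are illustrative):

```python
class TieredFeatureStore:
    """Hypothetical two-tier store: hot dict (Redis stand-in) + cold loader."""
    def __init__(self, hot, cold_loader):
        self.hot = hot                  # hot tier: in-memory / Redis
        self.cold_loader = cold_loader  # cold tier: e.g. on-demand S3 read

    def get(self, node_id):
        row = self.hot.get(node_id)
        if row is None:
            row = self.cold_loader(node_id)  # slow path, on demand
            self.hot[node_id] = row          # promote to the hot tier
        return row

store = TieredFeatureStore(
    hot={1: [1.0, 1.0]},
    cold_loader=lambda nid: [0.0, float(nid)],  # stand-in for an S3 fetch
)
print(store.get(1))     # hot hit, no cold access
print(store.get(42))    # cold fetch, then cached
print(42 in store.hot)  # promoted: next access is a hot hit
```

A production version would bound the hot tier (e.g. LRU eviction) so only frequently accessed nodes stay in RAM.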