PyTorch Geometric Integration
This page describes how to integrate FalkorDB with PyTorch Geometric (PyG) using the falkordb-pyg package.
Overview
falkordb-pyg implements PyG’s Remote Backend interface (FeatureStore + GraphStore) for FalkorDB. Once connected, you can plug the backend directly into NeighborLoader, LinkNeighborLoader, and other standard PyG data loaders — no changes to your model or training code required.
Key features:
- Zero-copy lazy loading — features and topology are fetched on demand and cached locally
- Heterogeneous graph support — multiple node and edge types
- Automatic ID remapping — non-contiguous FalkorDB node IDs are mapped to contiguous PyG indices transparently
- Drop-in replacement — works with any PyG workflow that accepts a remote backend
Installation
Prerequisite: PyTorch and PyTorch Geometric must be installed first. Follow the PyTorch and PyG installation guides for your platform and CUDA version.
pip install falkordb-pyg
Or install with PyTorch and PyG included (CPU-only defaults):
pip install 'falkordb-pyg[torch]'
Requires: Python ≥ 3.10, PyTorch ≥ 2.0, PyTorch Geometric ≥ 2.4, FalkorDB Python client ≥ 1.0.
Quick Start
1. Start FalkorDB
docker run -p 6379:6379 falkordb/falkordb:latest
2. Load data into FalkorDB
from falkordb import FalkorDB
db = FalkorDB(host="localhost", port=6379)
graph = db.select_graph("papers")
# Create nodes with features and labels
graph.query("CREATE (:paper {x: [1.0, 0.0, 1.0], y: 0})")
graph.query("CREATE (:paper {x: [0.0, 1.0, 0.5], y: 1})")
# Create edges
graph.query(
"MATCH (a:paper), (b:paper) "
"WHERE ID(a) = 0 AND ID(b) = 1 "
"CREATE (a)-[:cites]->(b)"
)
3. Create the remote backend
from falkordb_pyg import get_remote_backend
feature_store, graph_store = get_remote_backend(
host="localhost",
port=6379,
graph_name="papers",
)
4. Use with NeighborLoader
from torch_geometric.loader import NeighborLoader
import torch
train_nodes = torch.tensor([0])
loader = NeighborLoader(
data=(feature_store, graph_store),
num_neighbors={("paper", "cites", "paper"): [10, 10]},
batch_size=32,
input_nodes=("paper", train_nodes),
)
for batch in loader:
paper_x = batch["paper"].x
paper_y = batch["paper"].y
edge_index = batch["paper", "cites", "paper"].edge_index
# ... forward pass, loss, backward ...
API Reference
get_remote_backend
from falkordb_pyg import get_remote_backend
feature_store, graph_store = get_remote_backend(
host="localhost", # FalkorDB / Redis hostname
port=6379, # FalkorDB / Redis port
graph_name="default", # Graph name in FalkorDB
node_type_to_label=None, # Dict[str, str] — PyG type → FalkorDB label
edge_type_to_rel=None, # Dict[Tuple, str] — PyG edge triple → rel type
)
Returns a (FalkorDBFeatureStore, FalkorDBGraphStore) tuple.
FalkorDBFeatureStore
Implements torch_geometric.data.FeatureStore.
| Method | Description |
|---|---|
_get_tensor(attr) | Fetch a node-feature tensor (lazy, cached) |
_put_tensor(tensor, attr) | Store a tensor in the local cache |
_remove_tensor(attr) | Remove a cached tensor |
_get_tensor_size(attr) | Return the shape of a tensor |
get_all_tensor_attrs() | List all registered TensorAttr objects |
Constructor:
FalkorDBFeatureStore(
graph, # falkordb.Graph instance
node_type_to_label=None, # Optional Dict[str, str]
)
FalkorDBGraphStore
Implements torch_geometric.data.GraphStore.
| Method | Description |
|---|---|
_get_edge_index(attr) | Fetch a COO edge index (lazy, cached) |
_put_edge_index(edge_index, attr) | Store a COO edge index in the local cache |
_remove_edge_index(attr) | Remove a cached edge index |
get_all_edge_attrs() | List all registered EdgeAttr objects |
Constructor:
FalkorDBGraphStore(
graph, # falkordb.Graph instance
node_type_to_label=None, # Optional Dict[str, str]
edge_type_to_rel=None, # Optional Dict[Tuple[str,str,str], str]
)
NodeIDMapper
Bidirectional mapping between FalkorDB internal node IDs and contiguous 0-based PyG indices.
from falkordb_pyg.utils import NodeIDMapper
mapper = NodeIDMapper(falkordb_ids=[100, 200, 300])
mapper.falkor_to_pyg(200) # -> 1
mapper.pyg_to_falkor(1) # -> 200
mapper.num_nodes # -> 3
Node ID Remapping
FalkorDB assigns internal integer IDs to nodes that may not be contiguous or start at zero. falkordb-pyg transparently builds a NodeIDMapper for each node type on first access, converting FalkorDB IDs to contiguous PyG indices. Edges referencing IDs not present in the mapper are silently dropped.
Example: Training GraphSAGE
See examples/train_example.py for a complete GraphSAGE training script that:
- Populates FalkorDB with a synthetic paper-citation graph
- Creates a remote backend with
get_remote_backend - Fetches features and edge indices lazily
- Trains a two-layer GraphSAGE classifier