diff --git a/README.md b/README.md
index 43fe9f6b..211b1f09 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,5 @@
# nx-arangodb
-
-
@@ -55,7 +50,7 @@ Benefits of having ArangoDB as a backend to NetworkX include:
## Does this replace NetworkX?
-No. This is a plugin to NetworkX, which means that you can use NetworkX as you normally would, but with the added benefit of persisting your graphs to a database.
+Not really. This is a plugin to NetworkX, which means that you can use NetworkX as you normally would, but with the added benefit of persisting your graphs to a database.
```python
import os
@@ -111,7 +106,7 @@ pip install nx-cugraph-cu12 --extra-index-url https://pypi.nvidia.com
pip install nx-arangodb
```
-## What are the easiests ways to set up ArangoDB?
+## How can I set up ArangoDB?
**1) Local Instance via Docker**
@@ -149,7 +144,7 @@ os.environ["DATABASE_NAME"] = credentials["dbName"]
# ...
```
-## How does Algorithm Dispatching work?
+## How does algorithm dispatching work?
`nx-arangodb` will automatically dispatch algorithm calls to either CPU or GPU depending on whether `nx-cugraph` is installed. We rely on a Rust-based library called [phenolrs](https://github.com/arangoml/phenolrs) to retrieve ArangoDB Graphs as fast as possible.
diff --git a/_nx_arangodb/__init__.py b/_nx_arangodb/__init__.py
index 0498f571..616e961b 100644
--- a/_nx_arangodb/__init__.py
+++ b/_nx_arangodb/__init__.py
@@ -26,8 +26,8 @@
"project": "nx-arangodb",
"package": "nx_arangodb",
"url": "https://github.com/arangodb/nx-arangodb",
- "short_summary": "Remote storage backend.",
- # "description": "TODO",
+ "short_summary": "ArangoDB storage backend to NetworkX.",
+ "description": "Persist, maintain, and reload NetworkX graphs with ArangoDB.",
"functions": {
# BEGIN: functions
"shortest_path",
@@ -81,7 +81,6 @@ def get_info():
"db_name": None,
"read_parallelism": None,
"read_batch_size": None,
- "write_batch_size": None,
"use_gpu": True,
}
diff --git a/nx_arangodb/algorithms/shortest_paths/generic.py b/nx_arangodb/algorithms/shortest_paths/generic.py
index f5a9025b..7328b257 100644
--- a/nx_arangodb/algorithms/shortest_paths/generic.py
+++ b/nx_arangodb/algorithms/shortest_paths/generic.py
@@ -54,7 +54,7 @@ def shortest_path(
"weight": weight,
}
- result = list(G.aql(query, bind_vars=bind_vars))
+ result = list(G.query(query, bind_vars=bind_vars))
if not result:
raise nx.NodeNotFound(f"Either source {source} or target {target} is not in G")
diff --git a/nx_arangodb/classes/dict/node.py b/nx_arangodb/classes/dict/node.py
index f41b1666..e55c5171 100644
--- a/nx_arangodb/classes/dict/node.py
+++ b/nx_arangodb/classes/dict/node.py
@@ -19,6 +19,7 @@
doc_delete,
doc_insert,
doc_update,
+ edges_delete,
get_arangodb_graph,
get_node_id,
get_node_type_and_id,
@@ -303,21 +304,7 @@ def __delitem__(self, key: str) -> None:
if not self.graph.has_vertex(node_id):
raise KeyError(key)
- # TODO: wrap in edges_delete() method
- remove_statements = "\n".join(
- f"REMOVE e IN `{edge_def['edge_collection']}` OPTIONS {{ignoreErrors: true}}" # noqa
- for edge_def in self.graph.edge_definitions()
- )
-
- query = f"""
- FOR v, e IN 1..1 ANY @src_node_id GRAPH @graph_name
- {remove_statements}
- """
-
- bind_vars = {"src_node_id": node_id, "graph_name": self.graph.name}
-
- aql(self.db, query, bind_vars)
- #####
+ edges_delete(self.db, self.graph, node_id)
doc_delete(self.db, node_id)
diff --git a/nx_arangodb/classes/digraph.py b/nx_arangodb/classes/digraph.py
index e2bea65c..ccf7d65f 100644
--- a/nx_arangodb/classes/digraph.py
+++ b/nx_arangodb/classes/digraph.py
@@ -66,6 +66,13 @@ def __init__(
self.remove_node = self.remove_node_override
self.reverse = self.reverse_override
+ assert isinstance(self._succ, AdjListOuterDict)
+ assert isinstance(self._pred, AdjListOuterDict)
+ self._succ.mirror = self._pred
+ self._pred.mirror = self._succ
+ self._succ.traversal_direction = TraversalDirection.OUTBOUND
+ self._pred.traversal_direction = TraversalDirection.INBOUND
+
if (
not self.is_multigraph()
and incoming_graph_data is not None
@@ -78,6 +85,8 @@ def __init__(
#######################
# TODO?
+ # If we want to continue with "Experimental Views" we need to implement the
+ # InEdgeView and OutEdgeView classes.
# @cached_property
# def in_edges(self):
# pass
diff --git a/nx_arangodb/classes/function.py b/nx_arangodb/classes/function.py
index 19f09324..c9b73822 100644
--- a/nx_arangodb/classes/function.py
+++ b/nx_arangodb/classes/function.py
@@ -578,6 +578,24 @@ def doc_delete(db: StandardDatabase, id: str, **kwargs: Any) -> None:
db.delete_document(id, silent=True, **kwargs)
+def edges_delete(
+ db: StandardDatabase, graph: Graph, src_node_id: str, **kwargs: Any
+) -> None:
+ remove_statements = "\n".join(
+ f"REMOVE e IN `{edge_def['edge_collection']}` OPTIONS {{ignoreErrors: true}}" # noqa
+ for edge_def in graph.edge_definitions()
+ )
+
+ query = f"""
+ FOR v, e IN 1..1 ANY @src_node_id GRAPH @graph_name
+ {remove_statements}
+ """
+
+ bind_vars = {"src_node_id": src_node_id, "graph_name": graph.name}
+
+ aql(db, query, bind_vars)
+
+
def doc_insert(
db: StandardDatabase,
collection: str,
diff --git a/nx_arangodb/classes/graph.py b/nx_arangodb/classes/graph.py
index 53e7abd6..2bb23831 100644
--- a/nx_arangodb/classes/graph.py
+++ b/nx_arangodb/classes/graph.py
@@ -27,8 +27,6 @@
node_attr_dict_factory,
node_dict_factory,
)
-from .dict.adj import AdjListOuterDict
-from .enum import TraversalDirection
from .function import get_node_id
from .reportviews import CustomEdgeView, CustomNodeView
@@ -76,27 +74,21 @@ def __init__(
*args: Any,
**kwargs: Any,
):
- self._db = None
+ self.__db = None
self.__name = None
self.__use_experimental_views = use_experimental_views
+ self.__graph_exists_in_db = False
- self._graph_exists_in_db = False
- self._loaded_incoming_graph_data = False
-
- self._set_db(db)
- if self._db is not None:
- self._set_graph_name(name)
+ self.__set_db(db)
+ if self.__db is not None:
+ self.__set_graph_name(name)
- self.read_parallelism = read_parallelism
- self.read_batch_size = read_batch_size
- self.write_batch_size = write_batch_size
-
- self._set_edge_collections_attributes_to_fetch(edge_collections_attributes)
+ self.__set_edge_collections_attributes(edge_collections_attributes)
# NOTE: Need to revisit these...
# self.maintain_node_dict_cache = False
# self.maintain_adj_dict_cache = False
- self.use_nx_cache = True
+ # self.use_nx_cache = True
self.use_nxcg_cache = True
self.nxcg_graph = None
@@ -111,7 +103,9 @@ def __init__(
# m = "Must set **graph_name** if passing **incoming_graph_data**"
# raise ValueError(m)
- if self._graph_exists_in_db:
+ self._loaded_incoming_graph_data = False
+
+ if self.__graph_exists_in_db:
if incoming_graph_data is not None:
m = "Cannot pass both **incoming_graph_data** and **name** yet if the already graph exists" # noqa: E501
raise NotImplementedError(m)
@@ -152,7 +146,7 @@ def edge_type_func(u: str, v: str) -> str:
self.default_node_type = default_node_type
self._set_factory_methods()
- self._set_arangodb_backend_config()
+ self.__set_arangodb_backend_config(read_parallelism, read_batch_size)
elif self.__name:
@@ -181,7 +175,7 @@ def edge_type_func(u: str, v: str) -> str:
self.__name,
incoming_graph_data,
edge_definitions=edge_definitions,
- batch_size=self.write_batch_size,
+ batch_size=write_batch_size,
use_async=write_async,
)
@@ -194,23 +188,15 @@ def edge_type_func(u: str, v: str) -> str:
)
self._set_factory_methods()
- self._set_arangodb_backend_config()
+ self.__set_arangodb_backend_config(read_parallelism, read_batch_size)
logger.info(f"Graph '{name}' created.")
- self._graph_exists_in_db = True
+ self.__graph_exists_in_db = True
if self.__name is not None:
kwargs["name"] = self.__name
super().__init__(*args, **kwargs)
- if self.is_directed() and self.graph_exists_in_db:
- assert isinstance(self._succ, AdjListOuterDict)
- assert isinstance(self._pred, AdjListOuterDict)
- self._succ.mirror = self._pred
- self._pred.mirror = self._succ
- self._succ.traversal_direction = TraversalDirection.OUTBOUND
- self._pred.traversal_direction = TraversalDirection.INBOUND
-
if self.graph_exists_in_db:
self.copy = self.copy_override
self.subgraph = self.subgraph_override
@@ -220,6 +206,11 @@ def edge_type_func(u: str, v: str) -> str:
self.number_of_edges = self.number_of_edges_override
self.nbunch_iter = self.nbunch_iter_override
+ # If incoming_graph_data wasn't loaded by the NetworkX Adapter,
+ # then we can rely on the CRUD operations of the modified dictionaries
+ # to load the data into the graph. However, if the graph is directed
+ # or multigraph, then we leave that responsibility to the child classes
+ # due to the possibility of additional CRUD-based method overrides.
if (
not self.is_directed()
and not self.is_multigraph()
@@ -232,21 +223,6 @@ def edge_type_func(u: str, v: str) -> str:
# Init helper methods #
#######################
- def _set_arangodb_backend_config(self) -> None:
- if not all([self._host, self._username, self._password, self._db_name]):
- m = "Must set all environment variables to use the ArangoDB Backend with an existing graph" # noqa: E501
- raise OSError(m)
-
- config = nx.config.backends.arangodb
- config.host = self._host
- config.username = self._username
- config.password = self._password
- config.db_name = self._db_name
- config.read_parallelism = self.read_parallelism
- config.read_batch_size = self.read_batch_size
- config.write_batch_size = self.write_batch_size
- config.use_gpu = True # Only used by default if nx-cugraph is available
-
def _set_factory_methods(self) -> None:
"""Set the factory methods for the graph, _node, and _adj dictionaries.
@@ -281,58 +257,33 @@ def _set_factory_methods(self) -> None:
*adj_args, self.symmetrize_edges
)
- def _set_edge_collections_attributes_to_fetch(
- self, attributes: set[str] | None
+ def __set_arangodb_backend_config(
+ self, read_parallelism: int, read_batch_size: int
) -> None:
- if attributes is None:
- self._edge_collections_attributes = set()
- return
- if len(attributes) > 0:
- self._edge_collections_attributes = attributes
- if "_id" not in attributes:
- self._edge_collections_attributes.add("_id")
-
- ###########
- # Getters #
- ###########
-
- @property
- def db(self) -> StandardDatabase:
- if self._db is None:
- raise DatabaseNotSet("Database not set")
-
- return self._db
-
- @property
- def name(self) -> str:
- if self.__name is None:
- raise GraphNameNotSet("Graph name not set")
-
- return self.__name
-
- @name.setter
- def name(self, s):
- if self.__name is not None:
- raise ValueError("Existing graph cannot be renamed")
+ if not all([self._host, self._username, self._password, self._db_name]):
+ m = "Must set all environment variables to use the ArangoDB Backend with an existing graph" # noqa: E501
+ raise OSError(m)
- self.__name = s
- m = "Note that setting the graph name does not create the graph in the database" # noqa: E501
- logger.warning(m)
- nx._clear_cache(self)
+ config = nx.config.backends.arangodb
+ config.host = self._host
+ config.username = self._username
+ config.password = self._password
+ config.db_name = self._db_name
+ config.read_parallelism = read_parallelism
+ config.read_batch_size = read_batch_size
+ config.use_gpu = True # Only used by default if nx-cugraph is available
- @property
- def graph_exists_in_db(self) -> bool:
- return self._graph_exists_in_db
+ def __set_edge_collections_attributes(self, attributes: set[str] | None) -> None:
+ if not attributes:
+ self._edge_collections_attributes = set()
+ return
- @property
- def get_edge_attributes(self) -> set[str]:
- return self._edge_collections_attributes
+ self._edge_collections_attributes = attributes
- ###########
- # Setters #
- ###########
+ if "_id" not in attributes:
+ self._edge_collections_attributes.add("_id")
- def _set_db(self, db: StandardDatabase | None = None) -> None:
+ def __set_db(self, db: Any = None) -> None:
self._host = os.getenv("DATABASE_HOST")
self._username = os.getenv("DATABASE_USERNAME")
self._password = os.getenv("DATABASE_PASSWORD")
@@ -344,27 +295,26 @@ def _set_db(self, db: StandardDatabase | None = None) -> None:
raise TypeError(m)
db.version()
- self._db = db
+ self.__db = db
return
- # TODO: Raise a custom exception if any of the environment
- # variables are missing. For now, we'll just set db to None.
if not all([self._host, self._username, self._password, self._db_name]):
- self._db = None
- logger.warning("Database environment variables not set")
+ m = "Database environment variables not set. Can't connect to the database"
+ logger.warning(m)
+ self.__db = None
return
- self._db = ArangoClient(hosts=self._host, request_timeout=None).db(
+ self.__db = ArangoClient(hosts=self._host, request_timeout=None).db(
self._db_name, self._username, self._password, verify=True
)
- def _set_graph_name(self, name: str | None = None) -> None:
- if self._db is None:
+ def __set_graph_name(self, name: Any = None) -> None:
+ if self.__db is None:
m = "Cannot set graph name without setting the database first"
raise DatabaseNotSet(m)
if name is None:
- self._graph_exists_in_db = False
+ self.__graph_exists_in_db = False
logger.warning(f"**name** not set for {self.__class__.__name__}")
return
@@ -372,9 +322,51 @@ def _set_graph_name(self, name: str | None = None) -> None:
raise TypeError("**name** must be a string")
self.__name = name
- self._graph_exists_in_db = self.db.has_graph(name)
+ self.__graph_exists_in_db = self.db.has_graph(name)
+
+ logger.info(f"Graph '{name}' exists: {self.__graph_exists_in_db}")
+
+ ###########
+ # Getters #
+ ###########
+
+    @property
+    def db(self) -> StandardDatabase:
+        """Return the database handle held by this graph.
+
+        Raises ``DatabaseNotSet`` when no handle is stored.
+        """
+        if self.__db is None:
+            raise DatabaseNotSet("Database not set")
+
+        return self.__db
+
+    @property
+    def name(self) -> str:
+        """Return the graph name.
+
+        Raises ``GraphNameNotSet`` when no name has been assigned.
+        """
+        if self.__name is None:
+            raise GraphNameNotSet("Graph name not set")
+
+        return self.__name
- logger.info(f"Graph '{name}' exists: {self._graph_exists_in_db}")
+    @name.setter
+    def name(self, s: str) -> None:
+        """Assign a name to a graph that does not exist in the database.
+
+        Only the in-memory name changes; a warning is logged to say that no
+        graph is created in the database.
+
+        Raises ``ValueError`` when the graph already exists in the database.
+        """
+        if self.graph_exists_in_db:
+            raise ValueError("Existing graph cannot be renamed")
+
+        m = "Note that setting the graph name does not create the graph in the database"  # noqa: E501
+        logger.warning(m)
+
+        self.__name = s
+        # Keep the "name" entry of the ``graph`` attribute dict in step.
+        self.graph["name"] = s
+        # NOTE(review): presumably invalidates NetworkX's cached results for
+        # this graph after the rename - confirm against ``nx._clear_cache``.
+        nx._clear_cache(self)
+
+    @property
+    def graph_exists_in_db(self) -> bool:
+        """Return the read-only flag recording whether the graph exists in the database."""  # noqa: E501
+        return self.__graph_exists_in_db
+
+    @property
+    def edge_attributes(self) -> set[str]:
+        """Return the stored set of edge attribute names.
+
+        The set object itself is returned, not a copy.
+        """
+        return self._edge_collections_attributes
+
+ ###########
+ # Setters #
+ ###########
####################
# ArangoDB Methods #
@@ -383,7 +375,9 @@ def _set_graph_name(self, name: str | None = None) -> None:
def clear_nxcg_cache(self):
self.nxcg_graph = None
- def aql(self, query: str, bind_vars: dict[str, Any] = {}, **kwargs: Any) -> Cursor:
+ def query(
+ self, query: str, bind_vars: dict[str, Any] = {}, **kwargs: Any
+ ) -> Cursor:
return nxadb.classes.function.aql(self.db, query, bind_vars, **kwargs)
# def pull(self) -> None:
@@ -399,7 +393,7 @@ def chat(
m = "LLM dependencies not installed. Install with **pip install nx-arangodb[llm]**" # noqa: E501
raise ModuleNotFoundError(m)
- if not self._graph_exists_in_db:
+ if not self.__graph_exists_in_db:
m = "Cannot chat without a graph in the database"
raise GraphNameNotSet(m)
@@ -440,7 +434,7 @@ def adj(self):
def edges(self):
if self.__use_experimental_views and self.graph_exists_in_db:
if self.is_directed():
- logger.warning("CustomEdgeView for Directed Graphs not yet implemented")
+ logger.warning("CustomEdgeView for DiGraphs not yet implemented")
return super().edges
if self.is_multigraph():
@@ -463,7 +457,11 @@ def copy_override(self, *args, **kwargs):
return G
def subgraph_override(self, nbunch):
- raise NotImplementedError("Subgraphing is not yet implemented")
+        # Subgraphing is refused while the graph exists in the database;
+        # otherwise the call is delegated to the parent class's ``subgraph``.
+ if self.graph_exists_in_db:
+ m = "Subgraphing an ArangoDB Graph is not yet implemented"
+ raise NotImplementedError(m)
+
+ return super().subgraph(nbunch)
def clear_override(self):
logger.info("Note that clearing only erases the local cache")
diff --git a/nx_arangodb/classes/multidigraph.py b/nx_arangodb/classes/multidigraph.py
index d9a57f9e..fe25eb93 100644
--- a/nx_arangodb/classes/multidigraph.py
+++ b/nx_arangodb/classes/multidigraph.py
@@ -7,7 +7,6 @@
import nx_arangodb as nxadb
from nx_arangodb.classes.digraph import DiGraph
from nx_arangodb.classes.multigraph import MultiGraph
-from nx_arangodb.logger import logger
networkx_api = nxadb.utils.decorators.networkx_class(nx.MultiDiGraph) # type: ignore
diff --git a/nx_arangodb/convert.py b/nx_arangodb/convert.py
index 16a724d0..17458b90 100644
--- a/nx_arangodb/convert.py
+++ b/nx_arangodb/convert.py
@@ -13,7 +13,6 @@
try:
import cupy as cp
- import numpy as np
import nx_cugraph as nxcg
GPU_AVAILABLE = True
@@ -127,9 +126,9 @@ def nxadb_to_nx(G: nxadb.Graph) -> nx.Graph:
load_node_dict=True,
load_adj_dict=True,
load_coo=False,
- edge_collections_attributes=G.get_edge_attributes,
+ edge_collections_attributes=G.edge_attributes,
load_all_vertex_attributes=False,
- load_all_edge_attributes=do_load_all_edge_attributes(G.get_edge_attributes),
+ load_all_edge_attributes=do_load_all_edge_attributes(G.edge_attributes),
is_directed=G.is_directed(),
is_multigraph=G.is_multigraph(),
symmetrize_edges_if_directed=G.symmetrize_edges if G.is_directed() else False,
@@ -185,9 +184,9 @@ def nxadb_to_nxcg(G: nxadb.Graph, as_directed: bool = False) -> nxcg.Graph:
load_node_dict=False,
load_adj_dict=False,
load_coo=True,
- edge_collections_attributes=G.get_edge_attributes,
+ edge_collections_attributes=G.edge_attributes,
load_all_vertex_attributes=False, # not used
- load_all_edge_attributes=do_load_all_edge_attributes(G.get_edge_attributes),
+ load_all_edge_attributes=do_load_all_edge_attributes(G.edge_attributes),
is_directed=G.is_directed(),
is_multigraph=G.is_multigraph(),
symmetrize_edges_if_directed=(
diff --git a/nx_arangodb/interface.py b/nx_arangodb/interface.py
index 725c110d..47048752 100644
--- a/nx_arangodb/interface.py
+++ b/nx_arangodb/interface.py
@@ -62,7 +62,6 @@ def _auto_func(func_name: str, /, *args: Any, **kwargs: Any) -> Any:
"""
dfunc = _registered_algorithms[func_name]
- # TODO: Use `nx.config.backends.arangodb.backend_priority` instead
backend_priority = []
if nxadb.convert.GPU_AVAILABLE and nx.config.backends.arangodb.use_gpu:
backend_priority.append("cugraph")
@@ -143,6 +142,8 @@ def _run_with_backend(
# TODO: Convert to nxadb.Graph?
# What would this look like? Create a new graph in ArangoDB?
# Or just establish a remote connection?
+ # For now, if dfunc._returns_graph is True, it will return a
+ # regular nx.Graph object.
# if dfunc._returns_graph:
# raise NotImplementedError("Returning Graphs not implemented yet")
diff --git a/pyproject.toml b/pyproject.toml
index 6f930e30..7d181d32 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,7 +17,6 @@ authors = [
license = { text = "Apache 2.0" }
requires-python = ">=3.10"
classifiers = [
- "Development Status :: 3 - Alpha",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
@@ -30,7 +29,6 @@ classifiers = [
]
dependencies = [
"networkx>=3.0,<=3.3",
- "numpy>=1.23,<2.0a0",
"phenolrs",
"python-arango",
"adbnx-adapter"
@@ -55,9 +53,6 @@ dev = [
"sphinx",
"sphinx_rtd_theme",
]
-gpu = [
- "nx-cugraph-cu12 @ https://pypi.nvidia.com"
-]
llm = [
"langchain~=0.2.14",
"langchain-openai~=0.1.22",
@@ -65,7 +60,7 @@ llm = [
]
[project.urls]
-Homepage = "https://github.com/aMahanna/nx-arangodb"
+Homepage = "https://github.com/arangodb/nx-arangodb"
# "plugin" used in nx version < 3.2
[project.entry-points."networkx.plugins"]
diff --git a/run_nx_tests.sh b/run_nx_tests.sh
index 977e991b..6d0c499f 100755
--- a/run_nx_tests.sh
+++ b/run_nx_tests.sh
@@ -10,7 +10,7 @@ NETWORKX_FALLBACK_TO_NX=True \
--cov-report= \
"$@"
coverage report \
- --include="*/nx_arangodb/algorithms/*" \
+ --include="*/nx_arangodb/classes/*" \
--omit=__init__.py \
--show-missing \
--rcfile=$(dirname $0)/pyproject.toml
diff --git a/tests/test.py b/tests/test.py
index 79277180..58ea73f8 100644
--- a/tests/test.py
+++ b/tests/test.py
@@ -65,23 +65,6 @@ def assert_pagerank(
assert_same_dict_values(d1, d2, digit)
-def assert_louvain(l1: list[set[Any]], l2: list[set[Any]]) -> None:
- # TODO: Implement some kind of comparison
- # Reason: Louvain returns different results on different runs
- assert l1
- assert l2
- pass
-
-
-def assert_k_components(
- d1: dict[int, list[set[Any]]], d2: dict[int, list[set[Any]]]
-) -> None:
- assert d1
- assert d2
- assert d1.keys() == d2.keys(), "Dictionaries have different keys"
- assert d1 == d2
-
-
def test_db(load_karate_graph: Any) -> None:
assert db.version()
@@ -312,7 +295,7 @@ def assert_symmetry_differences(
assert_func(r_13_orig, r_9_orig)
-def test_shortest_path_remote_algorithm(load_karate_graph: Any) -> None:
+def test_shortest_path(load_karate_graph: Any) -> None:
G_1 = nxadb.Graph(name="KarateGraph")
G_2 = nxadb.DiGraph(name="KarateGraph")
@@ -321,8 +304,15 @@ def test_shortest_path_remote_algorithm(load_karate_graph: Any) -> None:
r_3 = nx.shortest_path(G_2, source="person/0", target="person/33")
r_4 = nx.shortest_path(G_2, source="person/0", target="person/33", weight="weight")
+ r_5 = nx.shortest_path.orig_func(
+ G_1, source="person/0", target="person/33", weight="weight"
+ )
+ r_6 = nx.shortest_path.orig_func(
+ G_2, source="person/0", target="person/33", weight="weight"
+ )
+
assert r_1 == r_3
- assert r_2 == r_4
+ assert r_2 == r_4 == r_5 == r_6
assert r_1 != r_2
assert r_3 != r_4