diff --git a/.gitignore b/.gitignore index 07ddb5d3..752223e5 100644 --- a/.gitignore +++ b/.gitignore @@ -122,4 +122,7 @@ node_modules/ # test results *_results.txt -*.egg-info \ No newline at end of file +*.egg-info + +# VSCode +.vscode/ \ No newline at end of file diff --git a/nx_arangodb/classes/digraph.py b/nx_arangodb/classes/digraph.py index e836081b..310f33b5 100644 --- a/nx_arangodb/classes/digraph.py +++ b/nx_arangodb/classes/digraph.py @@ -29,7 +29,7 @@ def to_networkx_class(cls) -> type[nx.DiGraph]: def __init__( self, graph_name: str | None = None, - # default_node_type: str = "nxadb_nodes", + # default_node_type: str = "node", # edge_type_func: Callable[[str, str], str] = lambda u, v: f"{u}_to_{v}", *args: Any, **kwargs: Any, diff --git a/nx_arangodb/classes/function.py b/nx_arangodb/classes/function.py index 84ef83aa..cada904a 100644 --- a/nx_arangodb/classes/function.py +++ b/nx_arangodb/classes/function.py @@ -86,9 +86,12 @@ def get_arangodb_graph( def key_is_string(func: Callable[..., Any]) -> Any: """Decorator to check if the key is a string.""" - def wrapper(self: Any, key: str, *args: Any, **kwargs: Any) -> Any: + def wrapper(self: Any, key: Any, *args: Any, **kwargs: Any) -> Any: if not isinstance(key, str): - raise TypeError(f"'{key}' is not a string.") + if not isinstance(key, (int, float)): + raise TypeError(f"{key} cannot be casted to string.") + + key = str(key) return func(self, key, *args, **kwargs) @@ -98,12 +101,29 @@ def wrapper(self: Any, key: str, *args: Any, **kwargs: Any) -> Any: def keys_are_strings(func: Callable[..., Any]) -> Any: """Decorator to check if the keys are strings.""" - def wrapper(self: Any, dict: dict[Any, Any], *args: Any, **kwargs: Any) -> Any: - for key in dict: + def wrapper( + self: Any, data: dict[Any, Any] | zip[Any], *args: Any, **kwargs: Any + ) -> Any: + data_dict = {} + + items: Any + if isinstance(data, dict): + items = data.items() + elif isinstance(data, zip): + items = list(data) + else: + raise TypeError(f"Decorator found unsupported type: {type(data)}.") + + for key, value in items: if not isinstance(key, str): - raise TypeError(f"'{key}' is not a string.") + if not isinstance(key, (int, float)): + raise TypeError(f"{key} cannot be casted to string.") + + key = str(key) + + data_dict[key] = value - return func(self, dict, *args, **kwargs) + return func(self, data_dict, *args, **kwargs) return wrapper @@ -126,12 +146,22 @@ def wrapper(self: Any, key: str, *args: Any, **kwargs: Any) -> Any: def keys_are_not_reserved(func: Any) -> Any: """Decorator to check if the keys are not reserved.""" - def wrapper(self: Any, dict: dict[Any, Any], *args: Any, **kwargs: Any) -> Any: - for key in dict: + def wrapper( + self: Any, data: dict[Any, Any] | zip[Any], *args: Any, **kwargs: Any + ) -> Any: + keys: Any + if isinstance(data, dict): + keys = data.keys() + elif isinstance(data, zip): + keys = (key for key, _ in list(data)) + else: + raise TypeError(f"Decorator found unsupported type: {type(data)}.") + + for key in keys: if key in RESERVED_KEYS: raise KeyError(f"'{key}' is a reserved key.") - return func(self, dict, *args, **kwargs) + return func(self, data, *args, **kwargs) return wrapper @@ -229,7 +259,6 @@ def aql_edge_get( graph_name: str, direction: str, ) -> Any | None: - # TODO: need the use of DISTINCT return_clause = "DISTINCT e" if direction == "ANY" else "e" return aql_edge( db, @@ -248,7 +277,6 @@ def aql_edge_id( graph_name: str, direction: str, ) -> str | None: - # TODO: need the use of DISTINCT return_clause = "DISTINCT e._id" if direction == "ANY" else "e._id" result = aql_edge( db, diff --git a/nx_arangodb/classes/graph.py b/nx_arangodb/classes/graph.py index 6912c0b7..95e9e2d6 100644 --- a/nx_arangodb/classes/graph.py +++ b/nx_arangodb/classes/graph.py @@ -41,8 +41,9 @@ def to_networkx_class(cls) -> type[nx.Graph]: def __init__( self, graph_name: str | None = None, - default_node_type: str = "nxadb_node", + default_node_type: str = "node", edge_type_func: Callable[[str, str], str] = lambda u, v: f"{u}_to_{v}", + db: StandardDatabase | None = None, *args: Any, **kwargs: Any, ): @@ -50,7 +51,7 @@ def __init__( self.__graph_name = None self.__graph_exists_in_db = False - self.__set_db() + self.__set_db(db) if self.__db is not None: self.__set_graph_name(graph_name) @@ -73,41 +74,47 @@ def __init__( self.edge_type_func = edge_type_func self.default_edge_type = edge_type_func(default_node_type, default_node_type) + # self.__qa_chain = None + incoming_graph_data = kwargs.get("incoming_graph_data") if self.__graph_exists_in_db: - self.adb_graph = self.db.graph(graph_name) - self.__create_default_collections() - self.__set_factory_methods() - - if incoming_graph_data: + if incoming_graph_data is not None: m = "Cannot pass both **incoming_graph_data** and **graph_name** yet if the already graph exists" # noqa: E501 raise NotImplementedError(m) - elif self.__graph_name and incoming_graph_data: - if not isinstance(incoming_graph_data, nx.Graph): - m = f"Type of **incoming_graph_data** not supported yet ({type(incoming_graph_data)})" # noqa: E501 - raise NotImplementedError(m) + self.adb_graph = self.db.graph(self.__graph_name) + self.__create_default_collections() + self.__set_factory_methods() - adapter = ADBNX_Adapter(self.db) - self.adb_graph = adapter.networkx_to_arangodb( - graph_name, - incoming_graph_data, - # TODO: Parameterize the edge definitions - # How can we work with a heterogenous **incoming_graph_data**? - edge_definitions=[ - { - "edge_collection": self.default_edge_type, - "from_vertex_collections": [self.default_node_type], - "to_vertex_collections": [self.default_node_type], - } - ], - ) + elif self.__graph_name and incoming_graph_data is not None: + # TODO: Parameterize the edge definitions + # How can we work with a heterogenous **incoming_graph_data**? + edge_definitions = [ + { + "edge_collection": self.default_edge_type, + "from_vertex_collections": [self.default_node_type], + "to_vertex_collections": [self.default_node_type], + } + ] + + if isinstance(incoming_graph_data, nx.Graph): + self.adb_graph = ADBNX_Adapter(self.db).networkx_to_arangodb( + self.__graph_name, + incoming_graph_data, + edge_definitions=edge_definitions, + ) + + # No longer need this (we've already populated the graph) + del kwargs["incoming_graph_data"] + + else: + self.adb_graph = self.db.create_graph( + self.__graph_name, + edge_definitions=edge_definitions, + ) self.__set_factory_methods() self.__graph_exists_in_db = True - del kwargs["incoming_graph_data"] - - # self.__qa_chain = None super().__init__(*args, **kwargs) @@ -187,6 +194,7 @@ def __set_db(self, db: StandardDatabase | None = None) -> None: m = "arango.database.StandardDatabase" raise TypeError(m) + db.version() self.__db = db return @@ -232,7 +240,6 @@ def __set_graph_name(self, graph_name: str | None = None) -> None: # ArangoDB Methods # #################### - # TODO: proper subgraphing! def aql(self, query: str, bind_vars: dict[str, Any] = {}, **kwargs: Any) -> Cursor: return nxadb.classes.function.aql(self.db, query, bind_vars, **kwargs) @@ -267,12 +274,12 @@ def pull(self, load_node_dict=True, load_adj_dict=True, load_coo=True): :param load_node_dict: Load the node dictionary. Enabling this option will clear the existing node dictionary, and replace it with the node data from the database. Comes with - a remote reference to the database. <--- TODO: Should we paramaterize this? + a remote reference to the database. :type load_node_dict: bool :param load_adj_dict: Load the adjacency dictionary. Enabling this option will clear the existing adjacency dictionary, and replace it with the edge data from the database. Comes with - a remote reference to the database. <--- TODO: Should we paramaterize this? + a remote reference to the database. :type load_adj_dict: bool :param load_coo: Load the COO representation. If False, the src & dst indices will be empty, along with the node-ID-to-index mapping. diff --git a/nx_arangodb/convert.py b/nx_arangodb/convert.py index 15c89d02..bf0d31f3 100644 --- a/nx_arangodb/convert.py +++ b/nx_arangodb/convert.py @@ -44,12 +44,6 @@ def from_networkx( ) -> nxadb.Graph | nxadb.DiGraph: """Convert a networkx graph to nx_arangodb graph. - TEMPORARY ASSUMPTION: The nx_arangodb Graph is a subclass of networkx Graph. - Therefore, I'm going to assume that we _should_ be able instantiate an - nx_arangodb Graph using the **incoming_graph_data** parameter. - - TODO: The actual implementation should store the graph in ArangoDB. - Parameters ---------- G : networkx.Graph @@ -187,8 +181,7 @@ def _to_nx_graph( if isinstance(G, nx.Graph): return G - # TODO: handle cugraph.Graph - raise TypeError + raise TypeError(f"Expected nx_arangodb.Graph or nx.Graph; got {type(G)}") if GPU_ENABLED: diff --git a/pyproject.toml b/pyproject.toml index 559566d5..ab91102f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,7 @@ dev = [ "Flake8-pyproject", "isort", "mypy", + "pandas", ] gpu = [ "nx-cugraph-cu12 @ https://pypi.nvidia.com" diff --git a/tests/conftest.py b/tests/conftest.py index ae2955ec..d2353c7a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,7 +19,7 @@ def pytest_addoption(parser: Any) -> None: parser.addoption("--url", action="store", default="http://localhost:8529") parser.addoption("--dbName", action="store", default="_system") parser.addoption("--username", action="store", default="root") - parser.addoption("--password", action="store", default="passwd") + parser.addoption("--password", action="store", default="test") def pytest_configure(config: Any) -> None: diff --git a/tests/static/cluster.conf b/tests/static/cluster.conf index d33e07a3..6df5ec0b 100644 --- a/tests/static/cluster.conf +++ b/tests/static/cluster.conf @@ -8,7 +8,7 @@ port = 8528 jwt-secret = /tests/static/keyfile [args] -all.database.password = passwd +all.database.password = test all.database.extended-names = true all.log.api-enabled = true all.javascript.allow-admin-execute = true diff --git a/tests/test.py b/tests/test.py index 32fe8a20..a589ef74 100644 --- a/tests/test.py +++ b/tests/test.py @@ -1,12 +1,15 @@ from typing import Any import networkx as nx +import pandas as pd import pytest import nx_arangodb as nxadb from .conftest import db +G_NX = nx.karate_club_graph() + def test_db(load_graph: Any) -> None: assert db.version() @@ -17,25 +20,23 @@ def test_load_graph_from_nxadb(): db.delete_graph(graph_name, drop_collections=True, ignore_missing=True) - G_nx = nx.karate_club_graph() - _ = nxadb.Graph( graph_name=graph_name, - incoming_graph_data=G_nx, + incoming_graph_data=G_NX, default_node_type="person", ) assert db.has_graph(graph_name) assert db.has_collection("person") assert db.has_collection("person_to_person") - assert db.collection("person").count() == len(G_nx.nodes) - assert db.collection("person_to_person").count() == len(G_nx.edges) + assert db.collection("person").count() == len(G_NX.nodes) + assert db.collection("person_to_person").count() == len(G_NX.edges) db.delete_graph(graph_name, drop_collections=True) def test_bc(load_graph): - G_1 = nx.karate_club_graph() + G_1 = G_NX G_2 = nxadb.Graph(incoming_graph_data=G_1) G_3 = nxadb.Graph(graph_name="KarateGraph") @@ -71,7 +72,7 @@ def test_bc(load_graph): def test_pagerank(load_graph: Any) -> None: - G_1 = nx.karate_club_graph() + G_1 = G_NX G_2 = nxadb.Graph(incoming_graph_data=G_1) G_3 = nxadb.Graph(graph_name="KarateGraph") @@ -105,7 +106,7 @@ def test_pagerank(load_graph: Any) -> None: def test_louvain(load_graph: Any) -> None: - G_1 = nx.karate_club_graph() + G_1 = G_NX G_2 = nxadb.Graph(incoming_graph_data=G_1) G_3 = nxadb.Graph(graph_name="KarateGraph") @@ -158,7 +159,7 @@ def test_shortest_path(load_graph: Any) -> None: def test_graph_nodes_crud(load_graph: Any) -> None: G_1 = nxadb.Graph(graph_name="KarateGraph", foo="bar") - G_2 = nx.Graph(nx.karate_club_graph()) + G_2 = nx.Graph(G_NX) assert G_1.graph_name == "KarateGraph" assert G_1.graph["foo"] == "bar" @@ -258,7 +259,7 @@ def test_graph_nodes_crud(load_graph: Any) -> None: def test_graph_edges_crud(load_graph: Any) -> None: G_1 = nxadb.Graph(graph_name="KarateGraph") - G_2 = nx.karate_club_graph() + G_2 = G_NX assert len(G_1.adj) == len(G_2.adj) assert len(G_1.edges) == len(G_2.edges) @@ -373,12 +374,11 @@ def test_graph_edges_crud(load_graph: Any) -> None: def test_readme(load_graph: Any) -> None: - G = nxadb.Graph(graph_name="KarateGraph") - G_nx = nx.karate_club_graph() + G = nxadb.Graph(graph_name="KarateGraph", default_node_type="person") - assert len(G.nodes) == len(G_nx.nodes) - assert len(G.adj) == len(G_nx.adj) - assert len(G.edges) == len(G_nx.edges) + assert len(G.nodes) == len(G_NX.nodes) + assert len(G.adj) == len(G_NX.adj) + assert len(G.edges) == len(G_NX.edges) G.nodes(data="club", default="unknown") G.edges(data="weight", default=1000) @@ -387,6 +387,8 @@ def test_readme(load_graph: Any) -> None: G.adj["person/1"] G.edges[("person/1", "person/3")] + assert G.nodes["1"] == G.nodes["person/1"] == G.nodes[1] + G.nodes["person/1"]["name"] = "John Doe" G.nodes["person/1"].update({"age": 40}) del G.nodes["person/1"]["name"] @@ -418,9 +420,50 @@ def test_readme(load_graph: Any) -> None: G.clear() - assert len(G.nodes) == len(G_nx.nodes) - assert len(G.adj) == len(G_nx.adj) - assert len(G.edges) == len(G_nx.edges) + assert len(G.nodes) == len(G_NX.nodes) + assert len(G.adj) == len(G_NX.adj) + assert len(G.edges) == len(G_NX.edges) + + +@pytest.mark.parametrize( + "data_type, incoming_graph_data, has_club, has_weight", + [ + ("dict of dicts", nx.karate_club_graph()._adj, False, True), + ( + "dict of lists", + {k: list(v) for k, v in G_NX._adj.items()}, + False, + False, + ), + ("container of edges", list(G_NX.edges), False, False), + ("iterator of edges", iter(G_NX.edges), False, False), + ("generator of edges", (e for e in G_NX.edges), False, False), + ("2D numpy array", nx.to_numpy_array(G_NX), False, True), + ( + "scipy sparse array", + nx.to_scipy_sparse_array(G_NX), + False, + True, + ), + ("Pandas EdgeList", nx.to_pandas_edgelist(G_NX), False, True), + # TODO: Address **nx.relabel.relabel_nodes** issue + # ("Pandas Adjacency", nx.to_pandas_adjacency(G_NX), False, True), + ], +) +def test_incoming_graph_data_not_nx_graph( + data_type: str, incoming_graph_data: Any, has_club: bool, has_weight: bool +) -> None: + # See nx.convert.to_networkx_graph for the official supported types + name = "KarateGraph" + db.delete_graph(name, drop_collections=True, ignore_missing=True) + + G = nxadb.Graph(incoming_graph_data=incoming_graph_data, graph_name=name) + + assert len(G.nodes) == len(G_NX.nodes) == db.collection(G.default_node_type).count() + assert len(G.adj) == len(G_NX.adj) == db.collection(G.default_node_type).count() + assert len(G.edges) == len(G_NX.edges) == db.collection(G.default_edge_type).count() + assert has_club == ("club" in G.nodes["0"]) + assert has_weight == ("weight" in G.adj["0"]["1"]) def test_digraph_nodes_crud() -> None: