diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 039c725e..66d5afb8 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -21,11 +21,10 @@ jobs: cache: 'pip' cache-dependency-path: setup.py - - name: Set up ArangoDB Instance via Docker - run: docker create --name adb -p 8529:8529 -e ARANGO_ROOT_PASSWORD=test arangodb/arangodb - - - name: Start ArangoDB Instance - run: docker start adb + - name: Set up ArangoDB + run: | + chmod +x starter.sh + ./starter.sh - name: Setup pip run: python -m pip install --upgrade pip setuptools wheel diff --git a/README.md b/README.md index 0b835a31..18fb4e9f 100644 --- a/README.md +++ b/README.md @@ -3,29 +3,85 @@ Development Sandbox: Open In Colab What's currently possible: -- Algorithm dispatching for GPU & CPU (`betweenness_centrality`, `pagerank`, `louvain_communities`) -- Data Load from ArangoDB to `nx` -- Data Load from ArangoDB to `nxcg` +- ArangoDB CRUD Interface for `nx.Graph` +- Algorithm dispatching to `nx` & `nxcg` (`betweenness_centrality`, `pagerank`, `louvain_communities`) +- Algorithm dispatching to ArangoDB (`shortest_path`) +- Data Load from ArangoDB to `nx` object +- Data Load from ArangoDB to `nxcg` object +- Data Load from ArangoDB via dictionary-based remote connection -Next Milestone: -- NetworkX CRUD Interface for ArangoDB +Next steps: +- Generalize `nxadb`'s support for `nx` & `nxcg` algorithms +- Improve support for `nxadb.DiGraph` +- CRUD Interface Improvements -Planned, but not yet scopped: -- NetworkX Graph Query Method -- Data Write to ArangoDB from `nx` -- Data Write to ArangoDB from `nxcg` +Planned: +- Support for `nxadb.MultiGraph` & `nxadb.MultiDiGraph` +- Data Load from `nx` to ArangoDB +- Data Load from `nxcg` to ArangoDB ```py - +import os import networkx as nx import nx_arangodb as nxadb -G_1 = nx.karate_club_graph() +os.environ["DATABASE_HOST"] = "http://localhost:8529" +os.environ["DATABASE_USERNAME"] = "root" +os.environ["DATABASE_PASSWORD"] = "password" +os.environ["DATABASE_NAME"] = "_system" + +G = nxadb.Graph(graph_name="KarateGraph") + +G_nx = nx.karate_club_graph() +assert len(G.nodes) == len(G_nx.nodes) +assert len(G.adj) == len(G_nx.adj) +assert len(G.edges) == len(G_nx.edges) + +nx.betweenness_centrality(G) +nx.pagerank(G) +nx.community.louvain_communities(G) +nx.shortest_path(G, "person/1", "person/34") +nx.all_neighbors(G, "person/1") + +G.nodes(data='club', default='unknown') +G.edges(data='weight', default=1000) + +G.nodes["person/1"] +G.adj["person/1"] +G.edges[("person/1", "person/3")] + +G.nodes["person/1"]["name"] = "John Doe" +G.nodes["person/1"].update({"age": 40}) +del G.nodes["person/1"]["name"] + +G.adj["person/1"]["person/3"]["weight"] = 2 +G.adj["person/1"]["person/3"].update({"weight": 3}) +del G.adj["person/1"]["person/3"]["weight"] + +G.edges[("person/1", "person/3")]["weight"] = 0.5 +assert G.adj["person/1"]["person/3"]["weight"] == 0.5 + +G.add_node("person/35", name="Jane Doe") +G.add_nodes_from( + [("person/36", {"name": "Jack Doe"}), ("person/37", {"name": "Jill Doe"})] +) +G.add_edge("person/1", "person/35", weight=1.5, _edge_type="knows") +G.add_edges_from( + [ + ("person/1", "person/36", {"weight": 2}), + ("person/1", "person/37", {"weight": 3}), + ], + _edge_type="knows", +) + +G.remove_edge("person/1", "person/35") +G.remove_edges_from([("person/1", "person/36"), ("person/1", "person/37")]) +G.remove_node("person/35") +G.remove_nodes_from(["person/36", "person/37"]) -G_2 = nxadb.Graph(G_1) +G.clear() -bc_1 = nx.betweenness_centrality(G_1) -bc_2 = nx.betweenness_centrality(G_2) -bc_3 = nx.betweenness_centrality(G_1, backend="arangodb") -bc_4 = nx.betweenness_centrality(G_2, backend="arangodb") -``` +assert len(G.nodes) == len(G_nx.nodes) +assert len(G.adj) == len(G_nx.adj) +assert len(G.edges) == len(G_nx.edges) +``` \ No newline at end of file diff --git a/_nx_arangodb/__init__.py b/_nx_arangodb/__init__.py index 0bf0e9b9..ff11a6d0 100644 --- a/_nx_arangodb/__init__.py +++ b/_nx_arangodb/__init__.py @@ -1,5 +1,3 @@ -# Copied from nx-cugraph - """Tell NetworkX about the arangodb backend. This file can update itself: $ make plugin-info @@ -31,20 +29,25 @@ "functions": { # BEGIN: functions "betweenness_centrality", + "is_partition", "louvain_communities", "louvain_partitions", "modularity", "pagerank", + "shortest_path", "to_scipy_sparse_array", # END: functions }, "additional_docs": { # BEGIN: additional_docs - + "shortest_path": "limited version of nx.shortest_path", # END: additional_docs }, "additional_parameters": { # BEGIN: additional_parameters + "is_partition": { + "dtype : dtype or None, optional": "The data type (np.float32, np.float64, or None) to use for the edge weights in the algorithm. If None, then dtype is determined by the edge values.", + }, "louvain_communities": { "dtype : dtype or None, optional": "The data type (np.float32, np.float64, or None) to use for the edge weights in the algorithm. If None, then dtype is determined by the edge values.", }, @@ -57,6 +60,9 @@ "pagerank": { "dtype : dtype or None, optional": "The data type (np.float32, np.float64, or None) to use for the edge weights in the algorithm. If None, then dtype is determined by the edge values.", }, + "shortest_path": { + "dtype : dtype or None, optional": "The data type (np.float32, np.float64, or None) to use for the edge weights in the algorithm. If None, then dtype is determined by the edge values.", + }, "to_scipy_sparse_array": { "dtype : dtype or None, optional": "The data type (np.float32, np.float64, or None) to use for the edge weights in the algorithm. If None, then dtype is determined by the edge values.", }, diff --git a/nx_arangodb/__init__.py b/nx_arangodb/__init__.py index fcb0b7ac..9ea41f95 100644 --- a/nx_arangodb/__init__.py +++ b/nx_arangodb/__init__.py @@ -1,5 +1,3 @@ -# Copied from nx-cugraph - from networkx.exception import * from . import utils @@ -13,4 +11,6 @@ from . import algorithms from .algorithms import * +from .logger import logger + from _nx_arangodb._version import __git_commit__, __version__ diff --git a/nx_arangodb/algorithms/__init__.py b/nx_arangodb/algorithms/__init__.py index 60bb5633..00c2d294 100644 --- a/nx_arangodb/algorithms/__init__.py +++ b/nx_arangodb/algorithms/__init__.py @@ -1,4 +1,5 @@ -from . import centrality, community, link_analysis +from . import centrality, community, link_analysis, shortest_paths from .centrality import * from .community import * from .link_analysis import * +from .shortest_paths import * diff --git a/nx_arangodb/algorithms/centrality/betweenness.py b/nx_arangodb/algorithms/centrality/betweenness.py index 2ec752b4..57af49c7 100644 --- a/nx_arangodb/algorithms/centrality/betweenness.py +++ b/nx_arangodb/algorithms/centrality/betweenness.py @@ -1,24 +1,19 @@ -from networkx.algorithms.centrality import betweenness as nx_betweenness +import networkx as nx from nx_arangodb.convert import _to_nxadb_graph, _to_nxcg_graph +from nx_arangodb.logger import logger from nx_arangodb.utils import networkx_algorithm try: import nx_cugraph as nxcg GPU_ENABLED = True - print("ANTHONY: GPU is enabled") except ModuleNotFoundError: GPU_ENABLED = False - print("ANTHONY: GPU is disabled") __all__ = ["betweenness_centrality"] -# 1. If GPU is enabled, call nx-cugraph bc() after converting to an ncxg graph (in-memory graph) -# 2. If GPU is not enabled, call networkx bc() after converting to an nxadb graph (in-memory graph) -# 3. If GPU is not enabled, call networkx bc() **without** converting to a nxadb graph (remote graph) - @networkx_algorithm( is_incomplete=True, @@ -27,56 +22,26 @@ _plc="betweenness_centrality", ) def betweenness_centrality( - G, k=None, normalized=True, weight=None, endpoints=False, seed=None, run_on_gpu=True + G, + k=None, + normalized=True, + weight=None, + endpoints=False, + seed=None, + run_on_gpu=True, + pull_graph_on_cpu=True, ): - print("ANTHONY: Calling betweenness_centrality from nx_arangodb") + logger.debug(f"nxadb.betweenness_centrality for {G.__class__.__name__}") - # 1. if GPU_ENABLED and run_on_gpu: - print("ANTHONY: to_nxcg") G = _to_nxcg_graph(G, weight) - print("ANTHONY: Using nxcg bc()") + logger.debug("using nxcg.betweenness_centrality") return nxcg.betweenness_centrality(G, k=k, normalized=normalized, weight=weight) - # 2. - else: - - print("ANTHONY: to_nxadb") - G = _to_nxadb_graph(G) - - print("ANTHONY: Using nx bc()") - - betweenness = dict.fromkeys(G, 0.0) # b[v]=0 for v in G - if k is None: - nodes = G - else: - nodes = seed.sample(list(G.nodes()), k) - for s in nodes: - # single source shortest paths - if weight is None: # use BFS - S, P, sigma, _ = nx_betweenness._single_source_shortest_path_basic(G, s) - else: # use Dijkstra's algorithm - S, P, sigma, _ = nx_betweenness._single_source_dijkstra_path_basic( - G, s, weight - ) - # accumulation - if endpoints: - betweenness, _ = nx_betweenness._accumulate_endpoints( - betweenness, S, P, sigma, s - ) - else: - betweenness, _ = nx_betweenness._accumulate_basic( - betweenness, S, P, sigma, s - ) - - betweenness = nx_betweenness._rescale( - betweenness, - len(G), - normalized=normalized, - directed=G.is_directed(), - k=k, - endpoints=endpoints, - ) + G = _to_nxadb_graph(G, pull_graph=pull_graph_on_cpu) - return betweenness + logger.debug("using nx.betweenness_centrality") + return nx.betweenness_centrality.orig_func( + G, k=k, normalized=normalized, weight=weight, endpoints=endpoints, seed=seed + ) diff --git a/nx_arangodb/algorithms/community/louvain.py b/nx_arangodb/algorithms/community/louvain.py index 3aa77857..4a7f3473 100644 --- a/nx_arangodb/algorithms/community/louvain.py +++ b/nx_arangodb/algorithms/community/louvain.py @@ -3,16 +3,15 @@ import networkx as nx from nx_arangodb.convert import _to_nxadb_graph, _to_nxcg_graph +from nx_arangodb.logger import logger from nx_arangodb.utils import _dtype_param, networkx_algorithm try: import nx_cugraph as nxcg GPU_ENABLED = True - print("ANTHONY: GPU is enabled") except ModuleNotFoundError: GPU_ENABLED = False - print("ANTHONY: GPU is disabled") @networkx_algorithm( @@ -33,12 +32,14 @@ def louvain_communities( max_level=None, seed=None, run_on_gpu=True, + pull_graph_on_cpu=True, ): + logger.debug(f"nxadb.louvain_communities for {G.__class__.__name__}") + if GPU_ENABLED and run_on_gpu: - print("ANTHONY: to_nxcg") G = _to_nxcg_graph(G, weight) - print("ANTHONY: Using nxcg louvain()") + logger.debug("using nxcg.louvain_communities") return nxcg.algorithms.community.louvain._louvain_communities( G, weight=weight, @@ -48,16 +49,17 @@ def louvain_communities( seed=seed, ) - else: - print("ANTHONY: to_nxadb") - G = _to_nxadb_graph(G) - - print("ANTHONY: Using nx pagerank()") - import random + G = _to_nxadb_graph(G, pull_graph=pull_graph_on_cpu) - d = louvain_partitions(G, weight, resolution, threshold, random.Random()) - q = deque(d, maxlen=1) - return q.pop() + logger.debug("using nx.louvain_communities") + return nx.community.louvain_communities.orig_func( + G, + weight=weight, + resolution=resolution, + threshold=threshold, + max_level=max_level, + seed=seed, + ) @networkx_algorithm( @@ -73,37 +75,9 @@ def louvain_communities( def louvain_partitions( G, weight="weight", resolution=1, threshold=0.0000001, seed=None ): - partition = [{u} for u in G.nodes()] - if nx.is_empty(G): - yield partition - return - mod = modularity(G, partition, resolution=resolution, weight=weight) - is_directed = G.is_directed() - if G.is_multigraph(): - graph = nx.community._convert_multigraph(G, weight, is_directed) - else: - graph = G.__class__() - graph.add_nodes_from(G) - graph.add_weighted_edges_from(G.edges(data=weight, default=1)) - - m = graph.size(weight="weight") - partition, inner_partition, improvement = nx.community.louvain._one_level( - graph, m, partition, resolution, is_directed, seed + return nx.community.louvain_partitions.orig_func( + G, weight=weight, resolution=resolution, threshold=threshold, seed=seed ) - improvement = True - while improvement: - # gh-5901 protect the sets in the yielded list from further manipulation here - yield [s.copy() for s in partition] - new_mod = modularity( - graph, inner_partition, resolution=resolution, weight="weight" - ) - if new_mod - mod <= threshold: - return - mod = new_mod - graph = nx.community.louvain._gen_graph(graph, inner_partition) - partition, inner_partition, improvement = nx.community.louvain._one_level( - graph, m, partition, resolution, is_directed, seed - ) @networkx_algorithm( @@ -115,30 +89,18 @@ def louvain_partitions( version_added="23.10", ) def modularity(G, communities, weight="weight", resolution=1): - if not isinstance(communities, list): - communities = list(communities) - # if not is_partition(G, communities): - # raise NotAPartition(G, communities) - - directed = G.is_directed() - if directed: - out_degree = dict(G.out_degree(weight=weight)) - in_degree = dict(G.in_degree(weight=weight)) - m = sum(out_degree.values()) - norm = 1 / m**2 - else: - out_degree = in_degree = dict(G.degree(weight=weight)) - deg_sum = sum(out_degree.values()) - m = deg_sum / 2 - norm = 1 / deg_sum**2 - - def community_contribution(community): - comm = set(community) - L_c = sum(wt for u, v, wt in G.edges(comm, data=weight, default=1) if v in comm) - - out_degree_sum = sum(out_degree[u] for u in comm) - in_degree_sum = sum(in_degree[u] for u in comm) if directed else out_degree_sum - - return L_c / m - resolution * out_degree_sum * in_degree_sum * norm - - return sum(map(community_contribution, communities)) + return nx.community.modularity.orig_func( + G, communities, weight=weight, resolution=resolution + ) + + +@networkx_algorithm( + extra_params={ + **_dtype_param, + }, + is_incomplete=True, # seed not supported; self-loops not supported + is_different=True, # RNG different + version_added="23.10", +) +def is_partition(G, communities): + return nx.community.is_partition.orig_func(G, communities) diff --git a/nx_arangodb/algorithms/link_analysis/pagerank_alg.py b/nx_arangodb/algorithms/link_analysis/pagerank_alg.py index d8f0212c..2cd9dfc1 100644 --- a/nx_arangodb/algorithms/link_analysis/pagerank_alg.py +++ b/nx_arangodb/algorithms/link_analysis/pagerank_alg.py @@ -1,16 +1,15 @@ import networkx as nx from nx_arangodb.convert import _to_nxadb_graph, _to_nxcg_graph +from nx_arangodb.logger import logger from nx_arangodb.utils import _dtype_param, networkx_algorithm try: import nx_cugraph as nxcg GPU_ENABLED = True - print("ANTHONY: GPU is enabled") except ModuleNotFoundError: GPU_ENABLED = False - print("ANTHONY: GPU is disabled") @networkx_algorithm( @@ -31,15 +30,14 @@ def pagerank( *, dtype=None, run_on_gpu=True, + pull_graph_on_cpu=True, ): - print("ANTHONY: Calling pagerank from nx_arangodb") + logger.debug(f"nxadb.pagerank for {G.__class__.__name__}") - # 1. if GPU_ENABLED and run_on_gpu: - print("ANTHONY: to_nxcg") G = _to_nxcg_graph(G, weight) - print("ANTHONY: Using nxcg pagerank()") + logger.debug("using nxcg.pagerank") return nxcg.pagerank( G, alpha=alpha, @@ -52,22 +50,19 @@ def pagerank( dtype=dtype, ) - # 2. - else: - print("ANTHONY: to_nxadb") - G = _to_nxadb_graph(G) + G = _to_nxadb_graph(G, pull_graph=pull_graph_on_cpu) - print("ANTHONY: Using nx pagerank()") - return nx.algorithms.link_analysis.pagerank_alg._pagerank_scipy( - G, - alpha=alpha, - personalization=personalization, - max_iter=max_iter, - tol=tol, - nstart=nstart, - weight=weight, - dangling=dangling, - ) + logger.debug("using nx.pagerank") + return nx.algorithms.pagerank.orig_func( + G, + alpha=alpha, + personalization=personalization, + max_iter=max_iter, + tol=tol, + nstart=nstart, + weight=weight, + dangling=dangling, + ) @networkx_algorithm( @@ -75,54 +70,4 @@ def pagerank( version_added="23.12", ) def to_scipy_sparse_array(G, nodelist=None, dtype=None, weight="weight", format="csr"): - import scipy as sp - - if len(G) == 0: - raise nx.NetworkXError("Graph has no nodes or edges") - - if nodelist is None: - nodelist = list(G) - nlen = len(G) - else: - nlen = len(nodelist) - if nlen == 0: - raise nx.NetworkXError("nodelist has no nodes") - nodeset = set(G.nbunch_iter(nodelist)) - if nlen != len(nodeset): - for n in nodelist: - if n not in G: - raise nx.NetworkXError(f"Node {n} in nodelist is not in G") - raise nx.NetworkXError("nodelist contains duplicates.") - if nlen < len(G): - G = G.subgraph(nodelist) - - index = dict(zip(nodelist, range(nlen))) - coefficients = zip( - *((index[u], index[v], wt) for u, v, wt in G.edges(data=weight, default=1)) - ) - try: - row, col, data = coefficients - except ValueError: - # there is no edge in the subgraph - row, col, data = [], [], [] - - if G.is_directed(): - A = sp.sparse.coo_array((data, (row, col)), shape=(nlen, nlen), dtype=dtype) - else: - # symmetrize matrix - d = data + data - r = row + col - c = col + row - # selfloop entries get double counted when symmetrizing - # so we subtract the data on the diagonal - selfloops = list(nx.selfloop_edges(G, data=weight, default=1)) - if selfloops: - diag_index, diag_data = zip(*((index[u], -wt) for u, v, wt in selfloops)) - d += diag_data - r += diag_index - c += diag_index - A = sp.sparse.coo_array((d, (r, c)), shape=(nlen, nlen), dtype=dtype) - try: - return A.asformat(format) - except ValueError as err: - raise nx.NetworkXError(f"Unknown sparse matrix format: {format}") from err + return nx.to_scipy_sparse_array.orig_func(G, nodelist, dtype, weight, format) diff --git a/nx_arangodb/algorithms/shortest_paths/__init__.py b/nx_arangodb/algorithms/shortest_paths/__init__.py new file mode 100644 index 00000000..c9840bc1 --- /dev/null +++ b/nx_arangodb/algorithms/shortest_paths/__init__.py @@ -0,0 +1 @@ +from .generic import * diff --git a/nx_arangodb/algorithms/shortest_paths/generic.py b/nx_arangodb/algorithms/shortest_paths/generic.py new file mode 100644 index 00000000..b0ec7c09 --- /dev/null +++ b/nx_arangodb/algorithms/shortest_paths/generic.py @@ -0,0 +1,53 @@ +import networkx as nx + +import nx_arangodb as nxadb +from nx_arangodb.exceptions import ShortestPathError +from nx_arangodb.utils import _dtype_param, networkx_algorithm + +__all__ = ["shortest_path"] + + +@networkx_algorithm( + extra_params=_dtype_param, version_added="24.04", _plc={"bfs", "sssp"} +) +def shortest_path( + G: nxadb.Graph | nxadb.DiGraph, + source=None, + target=None, + weight=None, + method="dijkstra", + *, + dtype=None, +): + """limited version of nx.shortest_path""" + + if not G.graph_exists: + return nx.shortest_path.orig_func( + G, source=source, target=target, weight=weight, method=method + ) + + if target is None or source is None: + raise ShortestPathError("Both source and target must be specified for now") + + if method != "dijkstra": + raise ShortestPathError("Only dijkstra method is supported") + + query = """ + FOR vertex IN ANY SHORTEST_PATH @source TO @target GRAPH @graph + OPTIONS {'weightAttribute': @weight} + RETURN vertex._id + """ + + bind_vars = { + "source": source, + "target": target, + "graph": G.graph_name, + "weight": weight, + } + + result = list(G.aql(query, bind_vars=bind_vars)) + + if not result: + raise nx.NodeNotFound(f"Either source {source} or target {target} is not in G") + + return result diff --git a/nx_arangodb/classes/dict.py b/nx_arangodb/classes/dict.py new file mode 100644 index 00000000..c7e4ee91 --- /dev/null +++ b/nx_arangodb/classes/dict.py @@ -0,0 +1,1096 @@ +from __future__ import annotations + +from collections import UserDict, defaultdict +from collections.abc import Iterator +from typing import Any, Callable + +from arango.database import StandardDatabase +from arango.exceptions import DocumentInsertError +from arango.graph import Graph + +from nx_arangodb.logger import logger + +from .function import ( + aql, + aql_as_list, + aql_doc_get_key, + aql_doc_get_keys, + aql_doc_get_length, + aql_doc_has_key, + aql_edge_exists, + aql_edge_get, + aql_edge_id, + aql_fetch_data, + aql_single, + create_collection, + doc_delete, + doc_get_or_insert, + doc_insert, + doc_update, + get_node_id, + get_node_type_and_id, + key_is_not_reserved, + key_is_string, + keys_are_not_reserved, + keys_are_strings, +) + + +def graph_dict_factory( + db: StandardDatabase, graph_name: str +) -> Callable[..., GraphDict]: + return lambda: GraphDict(db, graph_name) + + +def node_dict_factory( + db: StandardDatabase, graph: Graph, default_node_type: str +) -> Callable[..., NodeDict]: + return lambda: NodeDict(db, graph, default_node_type) + + +def node_attr_dict_factory( + db: StandardDatabase, graph: Graph +) -> Callable[..., NodeAttrDict]: + return lambda: NodeAttrDict(db, graph) + + +def adjlist_outer_dict_factory( + db: StandardDatabase, + graph: Graph, + default_node_type: str, + edge_type_func: Callable[[str, str], str], +) -> Callable[..., AdjListOuterDict]: + return lambda: AdjListOuterDict(db, graph, default_node_type, edge_type_func) + + +def adjlist_inner_dict_factory( + db: StandardDatabase, + graph: Graph, + default_node_type: str, + edge_type_func: Callable[[str, str], str], + adjlist_outer_dict: AdjListOuterDict | None = None, +) -> Callable[..., AdjListInnerDict]: + return lambda: AdjListInnerDict( + db, graph, default_node_type, edge_type_func, adjlist_outer_dict + ) + + +def edge_attr_dict_factory( + db: StandardDatabase, graph: Graph +) -> Callable[..., EdgeAttrDict]: + return lambda: EdgeAttrDict(db, graph) + + +class GraphDict(UserDict): + """A dictionary-like object for storing graph attributes. + + Given that ArangoDB does not have a concept of graph attributes, this class + stores the attributes in a collection with the graph name as the document key. + + :param db: The ArangoDB database. + :type db: StandardDatabase + :param graph_name: The graph name. + :type graph_name: str + """ + + COLLECTION_NAME = "nxadb_graphs" + + def __init__(self, db: StandardDatabase, graph_name: str, *args, **kwargs): + logger.debug("GraphDict.__init__") + super().__init__(*args, **kwargs) + + self.db = db + self.graph_name = graph_name + self.graph_id = f"{self.COLLECTION_NAME}/{graph_name}" + + self.adb_graph = db.graph(graph_name) + self.collection = create_collection(db, self.COLLECTION_NAME) + + data = doc_get_or_insert(self.db, self.COLLECTION_NAME, self.graph_id) + self.data.update(data) + + @key_is_string + def __contains__(self, key: str) -> bool: + """'foo' in G.graph""" + if key in self.data: + logger.debug(f"cached in GraphDict.__contains__({key})") + return True + + logger.debug("aql_doc_has_key in GraphDict.__contains__") + return aql_doc_has_key(self.db, self.graph_id, key) + + @key_is_string + def __getitem__(self, key: Any) -> Any: + """G.graph['foo']""" + if value := self.data.get(key): + return value + + logger.debug("aql_doc_get_key in GraphDict.__getitem__") + result = aql_doc_get_key(self.db, self.graph_id, key) + + if not result: + raise KeyError(key) + + self.data[key] = result + + return result + + @key_is_string + @key_is_not_reserved + # @value_is_json_serializable # TODO? + def __setitem__(self, key: str, value: Any): + """G.graph['foo'] = 'bar'""" + self.data[key] = value + logger.debug(f"doc_update in GraphDict.__setitem__({key})") + doc_update(self.db, self.graph_id, {key: value}) + + @key_is_string + @key_is_not_reserved + def __delitem__(self, key): + """del G.graph['foo']""" + self.data.pop(key, None) + logger.debug(f"doc_update in GraphDict.__delitem__({key})") + doc_update(self.db, self.graph_id, {key: None}) + + @keys_are_strings + @keys_are_not_reserved + # @values_are_json_serializable # TODO? + def update(self, attrs): + """G.graph.update({'foo': 'bar'})""" + if attrs: + self.data.update(attrs) + logger.debug(f"doc_update in GraphDict.update({attrs})") + doc_update(self.db, self.graph_id, attrs) + + def clear(self): + """G.graph.clear()""" + self.data.clear() + logger.debug("cleared GraphDict") + + # if clear_remote: + # doc_insert(self.db, self.COLLECTION_NAME, self.graph_id, silent=True) + + +class NodeDict(UserDict): + """The outer-level of the dict of dict structure representing the nodes (vertices) of a graph. + + The outer dict is keyed by ArangoDB Vertex IDs and the inner dict is keyed by Vertex attributes. + + :param db: The ArangoDB database. + :type db: StandardDatabase + :param graph: The ArangoDB graph. + :type graph: Graph + :param default_node_type: The default node type. Used if the node ID is not formatted as 'type/id'. + :type default_node_type: str + """ + + def __init__( + self, + db: StandardDatabase, + graph: Graph, + default_node_type: str, + *args, + **kwargs, + ): + logger.debug("NodeDict.__init__") + super().__init__(*args, **kwargs) + + self.db = db + self.graph = graph + self.default_node_type = default_node_type + self.node_attr_dict_factory = node_attr_dict_factory(self.db, self.graph) + + @key_is_string + def __contains__(self, key: str) -> bool: + """'node/1' in G._node""" + node_id = get_node_id(key, self.default_node_type) + + if node_id in self.data: + logger.debug(f"cached in NodeDict.__contains__({node_id})") + return True + + logger.debug(f"graph.has_vertex in NodeDict.__contains__({node_id})") + return self.graph.has_vertex(node_id) + + @key_is_string + def __getitem__(self, key: str) -> NodeAttrDict: + """G._node['node/1']""" + node_id = get_node_id(key, self.default_node_type) + + if value := self.data.get(node_id): + logger.debug(f"cached in NodeDict.__getitem__({node_id})") + return value + + if value := self.graph.vertex(node_id): + logger.debug(f"graph.vertex in NodeDict.__getitem__({node_id})") + node_attr_dict: NodeAttrDict = self.node_attr_dict_factory() + node_attr_dict.node_id = node_id + node_attr_dict.data = value + + self.data[node_id] = node_attr_dict + + return node_attr_dict + + raise KeyError(key) + + @key_is_string + def __setitem__(self, key: str, value: NodeAttrDict): + """G._node['node/1'] = {'foo': 'bar'} + + Not to be confused with: + - G.add_node('node/1', foo='bar') + """ + assert isinstance(value, NodeAttrDict) + + node_type, node_id = get_node_type_and_id(key, self.default_node_type) + + logger.debug(f"doc_insert in NodeDict.__setitem__({key})") + result = doc_insert(self.db, node_type, node_id, value.data) + + node_attr_dict = self.node_attr_dict_factory() + node_attr_dict.node_id = node_id + node_attr_dict.data = result + + self.data[node_id] = node_attr_dict + + @key_is_string + def __delitem__(self, key: Any) -> None: + """del g._node['node/1']""" + node_id = get_node_id(key, self.default_node_type) + + if not self.graph.has_vertex(node_id): + raise KeyError(key) + + remove_statements = "\n".join( + f"REMOVE e IN `{edge_def['edge_collection']}` OPTIONS {{ignoreErrors: true}}" + for edge_def in self.graph.edge_definitions() + ) + + query = f""" + FOR v, e IN 1..1 ANY @src_node_id GRAPH @graph_name + {remove_statements} + """ + + bind_vars = {"src_node_id": node_id, "graph_name": self.graph.name} + + logger.debug(f"remove_edges in NodeDict.__delitem__({node_id})") + aql(self.db, query, bind_vars) + + logger.debug(f"doc_delete in NodeDict.__delitem__({node_id})") + doc_delete(self.db, node_id) + + self.data.pop(node_id, None) + + def __len__(self) -> int: + """len(g._node)""" + logger.debug("NodeDict.__len__") + return sum( + [ + self.graph.vertex_collection(c).count() + for c in self.graph.vertex_collections() + ] + ) + + def __iter__(self) -> Iterator[str]: + """iter(g._node)""" + logger.debug("NodeDict.__iter__") + for collection in self.graph.vertex_collections(): + for node_id in self.graph.vertex_collection(collection).ids(): + yield node_id + + def clear(self): + """g._node.clear()""" + self.data.clear() + logger.debug("cleared NodeDict") + + # if clear_remote: + # for collection in self.graph.vertex_collections(): + # self.graph.vertex_collection(collection).truncate() + + @keys_are_strings + def update(self, nodes: dict[str, dict[str, Any]]): + """g._node.update({'node/1': {'foo': 'bar'}, 'node/2': {'baz': 'qux'}})""" + raise NotImplementedError("NodeDict.update()") + # for node_id, attrs in nodes.items(): + # node_id = get_node_id(node_id, self.default_node_type) + + # result = doc_update(self.db, node_id, attrs) + + # node_attr_dict = self.node_attr_dict_factory() + # node_attr_dict.node_id = node_id + # node_attr_dict.data = result + + # self.data[node_id] = node_attr_dict + + def keys(self): + """g._node.keys()""" + logger.debug("NodeDict.keys()") + return self.__iter__() + + def values(self): + """g._node.values()""" + logger.debug("NodeDict.values()") + self.__fetch_all() + return self.data.values() + + def items(self, data: str | None = None, default: Any | None = None): + """g._node.items() or G._node.items(data='foo')""" + if data is None: + logger.debug("NodeDict.items(data=None)") + self.__fetch_all() + return self.data.items() + + logger.debug(f"NodeDict.items(data={data})") + v_cols = list(self.graph.vertex_collections()) + return aql_fetch_data(self.db, v_cols, data, default, is_edge=False) + + def __fetch_all(self): + logger.debug("NodeDict.__fetch_all()") + + self.data.clear() + for collection in self.graph.vertex_collections(): + for doc in self.graph.vertex_collection(collection).all(): + node_id = doc["_id"] + + node_attr_dict = self.node_attr_dict_factory() + node_attr_dict.node_id = node_id + node_attr_dict.data = doc + + self.data[node_id] = node_attr_dict + + +class NodeAttrDict(UserDict): + """The inner-level of the dict of dict structure representing the nodes (vertices) of a graph. + + :param db: The ArangoDB database. + :type db: StandardDatabase + :param graph: The ArangoDB graph. + :type graph: Graph + """ + + def __init__(self, db: StandardDatabase, graph: Graph, *args, **kwargs): + logger.debug("NodeAttrDict.__init__") + + self.db = db + self.graph = graph + self.node_id: str | None = None + + super().__init__(*args, **kwargs) + + @key_is_string + def __contains__(self, key: str) -> bool: + """'foo' in G._node['node/1']""" + if key in self.data: + logger.debug(f"cached in NodeAttrDict.__contains__({key})") + return True + + logger.debug("aql_doc_has_key in NodeAttrDict.__contains__") + return aql_doc_has_key(self.db, self.node_id, key) + + @key_is_string + def __getitem__(self, key: str) -> Any: + """G._node['node/1']['foo']""" + if value := self.data.get(key): + logger.debug(f"cached in NodeAttrDict.__getitem__({key})") + return value + + logger.debug(f"aql_doc_get_key in NodeAttrDict.__getitem__({key})") + result = aql_doc_get_key(self.db, self.node_id, key) + + if not result: + raise KeyError(key) + + self.data[key] = result + + return result + + @key_is_string + @key_is_not_reserved + # @value_is_json_serializable # TODO? + def __setitem__(self, key: str, value: Any): + """G._node['node/1']['foo'] = 'bar'""" + self.data[key] = value + logger.debug(f"doc_update in NodeAttrDict.__setitem__({key})") + doc_update(self.db, self.node_id, {key: value}) + + @key_is_string + @key_is_not_reserved + def __delitem__(self, key: str): + """del G._node['node/1']['foo']""" + self.data.pop(key, None) + logger.debug(f"doc_update in NodeAttrDict({self.node_id}).__delitem__({key})") + doc_update(self.db, self.node_id, {key: None}) + + def __iter__(self) -> Iterator[str]: + """for key in G._node['node/1']""" + logger.debug(f"NodeAttrDict({self.node_id}).__iter__") + for key in aql_doc_get_keys(self.db, self.node_id): + yield key + + def __len__(self) -> int: + """len(G._node['node/1'])""" + logger.debug(f"NodeAttrDict({self.node_id}).__len__") + return aql_doc_get_length(self.db, self.node_id) + + def keys(self): + """G._node['node/1'].keys()""" + logger.debug(f"NodeAttrDict({self.node_id}).keys()") + return self.__iter__() + + def values(self): + """G._node['node/1'].values()""" + logger.debug(f"NodeAttrDict({self.node_id}).values()") + self.data = self.db.document(self.node_id) + return self.data.values() + + def items(self): + """G._node['node/1'].items()""" + logger.debug(f"NodeAttrDict({self.node_id}).items()") + self.data = self.db.document(self.node_id) + return self.data.items() + + def clear(self): + """G._node['node/1'].clear()""" + self.data.clear() + logger.debug(f"cleared NodeAttrDict({self.node_id})") + + # if clear_remote: + # doc_insert(self.db, self.node_id, silent=True, overwrite=True) + + def update(self, attrs: dict[str, Any]): + """G._node['node/1'].update({'foo': 'bar'})""" + if attrs: + self.data.update(attrs) + + if not self.node_id: + logger.debug(f"Node ID not set, skipping NodeAttrDict(?).update()") + return + + logger.debug(f"NodeAttrDict({self.node_id}).update({attrs})") + doc_update(self.db, self.node_id, attrs) + + +class AdjListOuterDict(UserDict): + """The outer-level of the dict of dict of dict structure representing the Adjacency List of a graph. + + The outer-dict is keyed by the node ID of the source node. + + :param db: The ArangoDB database. + :type db: StandardDatabase + :param graph: The ArangoDB graph. + :type graph: Graph + :param default_node_type: The default node type. + :type default_node_type: str + :param edge_type_func: The function to generate the edge type. + :type edge_type_func: Callable[[str, str], str] + """ + + def __init__( + self, + db: StandardDatabase, + graph: Graph, + default_node_type: str, + edge_type_func: Callable[[str, str], str], + *args, + **kwargs, + ): + logger.debug("AdjListOuterDict.__init__") + + super().__init__(*args, **kwargs) + + self.db = db + self.graph = graph + self.default_node_type = default_node_type + self.edge_type_func = edge_type_func + self.adjlist_inner_dict_factory = adjlist_inner_dict_factory( + db, graph, default_node_type, edge_type_func, self + ) + + self.FETCHED_ALL_DATA = False + + # def __repr__(self) -> str: + # return f"'{self.graph.name}'" + + # def __str__(self) -> str: + # return f"'{self.graph.name}'" + + @key_is_string + def __contains__(self, key) -> bool: + """'node/1' in G.adj""" + node_id = get_node_id(key, self.default_node_type) + + if node_id in self.data: + logger.debug(f"cached in AdjListOuterDict.__contains__({node_id})") + return True + + logger.debug("graph.has_vertex in AdjListOuterDict.__contains__") + return self.graph.has_vertex(node_id) + + @key_is_string + def __getitem__(self, key: str) -> AdjListInnerDict: + """G.adj["node/1"]""" + node_type, node_id = get_node_type_and_id(key, self.default_node_type) + + if value := self.data.get(node_id): + logger.debug(f"cached in AdjListOuterDict.__getitem__({node_id})") + return value + + if self.graph.has_vertex(node_id): + logger.debug(f"graph.vertex in AdjListOuterDict.__getitem__({node_id})") + adjlist_inner_dict: AdjListInnerDict = self.adjlist_inner_dict_factory() + adjlist_inner_dict.src_node_id = node_id + + self.data[node_id] = adjlist_inner_dict + + return adjlist_inner_dict + + raise KeyError(key) + + @key_is_string + def __setitem__(self, src_key: str, adjlist_inner_dict: AdjListInnerDict): + """ + g._adj['node/1'] = AdjListInnerDict() + """ + assert isinstance(adjlist_inner_dict, AdjListInnerDict) + assert not adjlist_inner_dict.src_node_id + + logger.debug(f"AdjListOuterDict.__setitem__({src_key})") + + src_node_type, src_node_id = get_node_type_and_id( + src_key, self.default_node_type + ) + + # NOTE: this might not actually be needed... + results = {} + edge_dict: dict[str, Any] + for dst_key, edge_dict in adjlist_inner_dict.data.items(): + dst_node_type, dst_node_id = get_node_type_and_id( + dst_key, self.default_node_type + ) + + edge_type = edge_dict.get("_edge_type") + if edge_type is None: + edge_type = self.edge_type_func(src_node_type, dst_node_type) + + logger.debug(f"graph.link({src_key}, {dst_key})") + results[dst_key] = self.graph.link( + edge_type, src_node_id, dst_node_id, edge_dict + ) + + adjlist_inner_dict.src_node_id = src_node_id + adjlist_inner_dict.data = results + + self.data[src_node_id] = adjlist_inner_dict + + @key_is_string + def __delitem__(self, key: Any) -> None: + """ + del G._adj['node/1'] + """ + # Nothing else to do here, as this delete is always invoked by + # G.remove_node(), which already removes all edges via + # del G._node['node/1'] + logger.debug(f"AdjListOuterDict.__delitem__({key}) (just cache)") + node_id = get_node_id(key, self.default_node_type) + self.data.pop(node_id, None) + + def __len__(self) -> int: + """len(g._adj)""" + logger.debug("AdjListOuterDict.__len__") + return sum( + [ + self.graph.vertex_collection(c).count() + for c in self.graph.vertex_collections() + ] + ) + + def __iter__(self) -> Iterator[str]: + """for k in g._adj""" + logger.debug("AdjListOuterDict.__iter__") + + if self.FETCHED_ALL_DATA: + yield from self.data.keys() + + else: + for collection in self.graph.vertex_collections(): + for id in self.graph.vertex_collection(collection).ids(): + yield id + + def keys(self): + """g._adj.keys()""" + logger.debug("AdjListOuterDict.keys()") + return self.__iter__() + + def clear(self): + """g._node.clear()""" + self.data.clear() + self.FETCHED_ALL_DATA = False + logger.debug("cleared AdjListOuterDict") + + # if clear_remote: + # for ed in self.graph.edge_definitions(): + # self.graph.edge_collection(ed["edge_collection"]).truncate() + + @keys_are_strings + def update(self, edges: dict[str, dict[str, dict[str, Any]]]): + """g._adj.update({'node/1': {'node/2': {'foo': 'bar'}})""" + raise NotImplementedError("AdjListOuterDict.update()") + + def values(self): + """g._adj.values()""" + logger.debug("AdjListOuterDict.values()") + self.__fetch_all() + return self.data.values() + + def items(self, data: str | None = None, default: Any | None = None): + """g._adj.items() or G._adj.items(data='foo')""" + if data is None: + logger.debug("AdjListOuterDict.items(data=None)") + self.__fetch_all() + return self.data.items() + + logger.debug(f"AdjListOuterDict.items(data={data})") + e_cols = [ed["edge_collection"] for ed in self.graph.edge_definitions()] + result = aql_fetch_data(self.db, e_cols, data, default, is_edge=True) + yield from result + + # TODO: Revisit + def __fetch_all(self) -> None: + logger.debug("AdjListOuterDict.__fetch_all()") + + if self.FETCHED_ALL_DATA: + logger.debug("Already fetched data, skipping fetch") + return + + self.clear() + # items = defaultdict(dict) + for ed in self.graph.edge_definitions(): + collection = ed["edge_collection"] + + for edge in self.graph.edge_collection(collection): + src_node_id = edge["_from"] + dst_node_id = edge["_to"] + + # items[src_node_id][dst_node_id] = edge + # items[dst_node_id][src_node_id] = edge + + if src_node_id in self.data: + src_inner_dict = self.data[src_node_id] + else: + src_inner_dict = self.adjlist_inner_dict_factory() + src_inner_dict.src_node_id = src_node_id + self.data[src_node_id] = src_inner_dict + + if dst_node_id in self.data: + dst_inner_dict = self.data[dst_node_id] + else: + dst_inner_dict = self.adjlist_inner_dict_factory() + dst_inner_dict.src_node_id = dst_node_id + self.data[dst_node_id] = dst_inner_dict + + edge_attr_dict = src_inner_dict.edge_attr_dict_factory() + edge_attr_dict.edge_id = edge["_id"] + edge_attr_dict.data = edge + + self.data[src_node_id].data[dst_node_id] = edge_attr_dict + self.data[dst_node_id].data[src_node_id] = edge_attr_dict + + self.FETCHED_ALL_DATA = True + + +class AdjListInnerDict(UserDict): + """The inner-level of the dict of dict of dict structure representing the Adjacency List of a graph. + + The inner-dict is keyed by the node ID of the destination node. + + :param db: The ArangoDB database. + :type db: StandardDatabase + :param graph: The ArangoDB graph. + :type graph: Graph + :param default_node_type: The default node type. + :type default_node_type: str + :param edge_type_func: The function to generate the edge type. + :type edge_type_func: Callable[[str, str], str] + """ + + def __init__( + self, + db: StandardDatabase, + graph: Graph, + default_node_type: str, + edge_type_func: Callable[[str, str], str], + adjlist_outer_dict: AdjListOuterDict, + *args, + **kwargs, + ): + logger.debug("AdjListInnerDict.__init__") + + super().__init__(*args, **kwargs) + + self.db = db + self.graph = graph + self.default_node_type = default_node_type + self.edge_type_func = edge_type_func + self.adjlist_outer_dict = adjlist_outer_dict + + self.src_node_id = None + + self.edge_attr_dict_factory = edge_attr_dict_factory(self.db, self.graph) + + self.FETCHED_ALL_DATA = False + + def __get_mirrored_edge_attr_dict(self, dst_node_id: str) -> bool: + logger.debug(f"checking for mirrored edge ({self.src_node_id}, {dst_node_id})") + if dst_node_id in self.adjlist_outer_dict.data: + if self.src_node_id in self.adjlist_outer_dict.data[dst_node_id].data: + return self.adjlist_outer_dict.data[dst_node_id].data[self.src_node_id] + + return None + + # def __repr__(self) -> str: + # return f"'{self.src_node_id}'" + + # def __str__(self) -> str: + # return f"'{self.src_node_id}'" + + @key_is_string + def __contains__(self, key) -> bool: + """'node/2' in G.adj['node/1']""" + dst_node_id = get_node_id(key, self.default_node_type) + + if dst_node_id in self.data: + logger.debug(f"cached in AdjListInnerDict.__contains__({dst_node_id})") + return True + + logger.debug(f"aql_edge_exists in AdjListInnerDict.__contains__({dst_node_id})") + return aql_edge_exists( + self.db, + self.src_node_id, + dst_node_id, + self.graph.name, + direction="ANY", + ) + + @key_is_string + def __getitem__(self, key) -> EdgeAttrDict: + """g._adj['node/1']['node/2']""" + dst_node_id = get_node_id(key, self.default_node_type) + + if dst_node_id in self.data: + m = f"cached in AdjListInnerDict({self.src_node_id}).__getitem__({dst_node_id})" + logger.debug(m) + return self.data[dst_node_id] + + if mirrored_edge_attr_dict := self.__get_mirrored_edge_attr_dict(dst_node_id): + logger.debug("No need to fetch the edge, as it is already cached") + self.data[dst_node_id] = mirrored_edge_attr_dict + return mirrored_edge_attr_dict + + m = f"aql_edge_get in AdjListInnerDict({self.src_node_id}).__getitem__({dst_node_id})" + edge = aql_edge_get( + self.db, + self.src_node_id, + dst_node_id, + self.graph.name, + direction="ANY", + ) + + if not edge: + raise KeyError(key) + + edge_attr_dict = self.edge_attr_dict_factory() + edge_attr_dict.edge_id = edge["_id"] + edge_attr_dict.data = edge + + self.data[dst_node_id] = edge_attr_dict + + return edge_attr_dict + + @key_is_string + def __setitem__(self, key: str, value: dict | EdgeAttrDict): + """g._adj['node/1']['node/2'] = {'foo': 'bar'}""" + assert isinstance(value, EdgeAttrDict) + logger.debug(f"AdjListInnerDict({self.src_node_id}).__setitem__({key})") + + src_node_type = self.src_node_id.split("/")[0] + dst_node_type, dst_node_id = get_node_type_and_id(key, self.default_node_type) + + if mirrored_edge_attr_dict := self.__get_mirrored_edge_attr_dict(dst_node_id): + logger.debug("No need to create a new edge, as it already exists") + self.data[dst_node_id] = mirrored_edge_attr_dict + return + + edge_type = value.data.get("_edge_type") + if edge_type is None: + edge_type = self.edge_type_func(src_node_type, dst_node_type) + logger.debug(f"No edge type specified, so generated: {edge_type})") + + if edge_id := value.edge_id: + m = f"edge id found, deleting ({self.src_node_id, dst_node_id})" + logger.debug(m) + self.graph.delete_edge(edge_id) + + elif edge_id := aql_edge_id( + self.db, + self.src_node_id, + dst_node_id, + self.graph.name, + direction="ANY", + ): + m = f"existing edge found, deleting ({self.src_node_id, dst_node_id})" + logger.debug(m) + self.graph.delete_edge(edge_id) + + edge_data = value.data + logger.debug(f"graph.link({self.src_node_id}, {dst_node_id})") + edge = self.graph.link(edge_type, self.src_node_id, dst_node_id, edge_data) + + edge_attr_dict = self.edge_attr_dict_factory() + edge_attr_dict.edge_id = edge["_id"] + edge_attr_dict.data = { + **edge_data, + **edge, + "_from": self.src_node_id, + "_to": dst_node_id, + } + + self.data[dst_node_id] = edge_attr_dict + + @key_is_string + def __delitem__(self, key: Any) -> None: + """del g._adj['node/1']['node/2']""" + dst_node_id = get_node_id(key, self.default_node_type) + self.data.pop(dst_node_id, None) + + if self.__get_mirrored_edge_attr_dict(dst_node_id): + m = "No need to delete the edge, as the next del will take care of it" + logger.debug(m) + return + + logger.debug(f"fetching edge ({self.src_node_id, dst_node_id})") + edge_id = aql_edge_id( + self.db, + self.src_node_id, + dst_node_id, + self.graph.name, + direction="ANY", + ) + + if not edge_id: + m = f"edge not found, AdjListInnerDict({self.src_node_id}).__delitem__({dst_node_id})" + logger.debug(m) + return + + logger.debug(f"graph.delete_edge({edge_id})") + self.graph.delete_edge(edge_id) + + def __len__(self) -> int: + """len(g._adj['node/1'])""" + assert self.src_node_id + + if self.FETCHED_ALL_DATA: + m = f"Already fetched data, skipping AdjListInnerDict({self.src_node_id}).__len__" + logger.debug(m) + return len(self.data) + + query = """ + RETURN LENGTH( + FOR v, e IN 1..1 OUTBOUND @src_node_id GRAPH @graph_name + RETURN 1 + ) + """ + + bind_vars = {"src_node_id": self.src_node_id, "graph_name": self.graph.name} + + logger.debug(f"aql_single in AdjListInnerDict({self.src_node_id}).__len__") + count = aql_single(self.db, query, bind_vars) + + return count if count is not None else 0 + + def __iter__(self) -> Iterator[str]: + """for k in g._adj['node/1']""" + if self.FETCHED_ALL_DATA: + m = f"Already fetched data, skipping AdjListInnerDict({self.src_node_id}).__iter__" + logger.debug(m) + yield from self.data.keys() + + else: + query = """ + FOR v, e IN 1..1 OUTBOUND @src_node_id GRAPH @graph_name + RETURN e._to + """ + + bind_vars = {"src_node_id": self.src_node_id, "graph_name": self.graph.name} + + logger.debug(f"aql in AdjListInnerDict({self.src_node_id}).__iter__") + yield from aql(self.db, query, bind_vars) + + def keys(self): + """g._adj['node/1'].keys()""" + logger.debug(f"AdjListInnerDict({self.src_node_id}).keys()") + return self.__iter__() + + def clear(self): + """G._adj['node/1'].clear()""" + self.data.clear() + self.FETCHED_ALL_DATA = False + logger.debug(f"cleared AdjListInnerDict({self.src_node_id})") + + def update(self, edges: dict[str, dict[str, Any]]): + """g._adj['node/1'].update({'node/2': {'foo': 'bar'}})""" + raise NotImplementedError("AdjListInnerDict.update()") + + def values(self): + """g._adj['node/1'].values()""" + logger.debug(f"AdjListInnerDict({self.src_node_id}).values()") + self.__fetch_all() + return self.data.values() + + def items(self): + """g._adj['node/1'].items()""" + logger.debug(f"AdjListInnerDict({self.src_node_id}).items()") + self.__fetch_all() + return self.data.items() + + def __fetch_all(self): + logger.debug(f"AdjListInnerDict({self.src_node_id}).__fetch_all()") + + if self.FETCHED_ALL_DATA: + logger.debug("Already fetched data, skipping fetch") + return + + self.clear() + + query = """ + FOR v, e IN 1..1 OUTBOUND @src_node_id GRAPH @graph_name + RETURN e + """ + + bind_vars = {"src_node_id": self.src_node_id, "graph_name": self.graph.name} + + for edge in aql(self.db, query, bind_vars): + edge_attr_dict = self.edge_attr_dict_factory() + edge_attr_dict.edge_id = edge["_id"] + edge_attr_dict.data = edge + + self.data[edge["_to"]] = edge_attr_dict + + self.FETCHED_ALL_DATA = True + + +class EdgeAttrDict(UserDict): + """The innermost-level of the dict of dict of dict structure representing the Adjacency List of a graph. + + The innermost-dict is keyed by the edge attribute key. + + :param db: The ArangoDB database. + :type db: StandardDatabase + :param graph: The ArangoDB graph. + :type graph: Graph + """ + + def __init__( + self, + db: StandardDatabase, + graph: Graph, + *args, + **kwargs, + ): + logger.debug("EdgeAttrDict.__init__") + + super().__init__(*args, **kwargs) + + self.db = db + self.graph = graph + self.edge_id: str | None = None + + @key_is_string + def __contains__(self, key: str) -> bool: + """'foo' in G._adj['node/1']['node/2']""" + if key in self.data: + logger.debug(f"cached in EdgeAttrDict({self.edge_id}).__contains__({key})") + return True + + logger.debug(f"aql_doc_has_key in EdgeAttrDict({self.edge_id}).__contains__") + return aql_doc_has_key(self.db, self.edge_id, key) + + @key_is_string + def __getitem__(self, key: str) -> Any: + """G._adj['node/1']['node/2']['foo']""" + if value := self.data.get(key): + logger.debug(f"cached in EdgeAttrDict({self.edge_id}).__getitem__({key})") + return value + + logger.debug( + f"aql_doc_get_key in EdgeAttrDict({self.edge_id}).__getitem__({key})" + ) + result = aql_doc_get_key(self.db, self.edge_id, key) + + if not result: + raise KeyError(key) + + self.data[key] = result + + return result + + @key_is_string + @key_is_not_reserved + # @value_is_json_serializable # TODO? + def __setitem__(self, key: str, value: Any): + """G._adj['node/1']['node/2']['foo'] = 'bar'""" + self.data[key] = value + logger.debug(f"doc_update in EdgeAttrDict({self.edge_id}).__setitem__({key})") + doc_update(self.db, self.edge_id, {key: value}) + + @key_is_string + @key_is_not_reserved + def __delitem__(self, key: str): + """del G._adj['node/1']['node/2']['foo']""" + self.data.pop(key, None) + logger.debug(f"doc_update in EdgeAttrDict({self.edge_id}).__delitem__({key})") + doc_update(self.db, self.edge_id, {key: None}) + + def __iter__(self) -> Iterator[str]: + """for key in G._adj['node/1']['node/2']""" + logger.debug(f"EEdgeAttrDict({self.edge_id}).__iter__") + for key in aql_doc_get_keys(self.db, self.edge_id): + yield key + + def __len__(self) -> int: + """len(G._adj['node/1']['node/'2])""" + logger.debug(f"EdgeAttrDict({self.edge_id}).__len__") + return aql_doc_get_length(self.db, self.edge_id) + + def keys(self): + """G._adj['node/1']['node/'2].keys()""" + logger.debug(f"EdgeAttrDict({self.edge_id}).keys()") + return self.__iter__() + + def values(self): + """G._adj['node/1']['node/'2].values()""" + logger.debug(f"EdgeAttrDict({self.edge_id}).values()") + self.data = self.db.document(self.edge_id) + return self.data.values() + + def items(self): + """G._adj['node/1']['node/'2].items()""" + logger.debug(f"EdgeAttrDict({self.edge_id}).items()") + self.data = self.db.document(self.edge_id) + return self.data.items() + + def clear(self): + """G._adj['node/1']['node/'2].clear()""" + self.data.clear() + logger.debug(f"cleared EdgeAttrDict({self.edge_id})") + + def update(self, attrs: dict[str, Any]): + """G._adj['node/1']['node/'2].update({'foo': 'bar'})""" + if attrs: + self.data.update(attrs) + + if not self.edge_id: + logger.debug("Edge ID not set, skipping EdgeAttrDict(?).update()") + return + + logger.debug(f"EdgeAttrDict({self.edge_id}).update({attrs})") + doc_update(self.db, self.edge_id, attrs) diff --git a/nx_arangodb/classes/digraph.py b/nx_arangodb/classes/digraph.py index 838f2dc9..83b7333a 100644 --- a/nx_arangodb/classes/digraph.py +++ b/nx_arangodb/classes/digraph.py @@ -1,76 +1,93 @@ import os +from typing import ClassVar import networkx as nx from arango import ArangoClient +from arango.cursor import Cursor from arango.database import StandardDatabase +from arango.exceptions import ServerConnectionError import nx_arangodb as nxadb -from nx_arangodb.classes.graph import Graph +from nx_arangodb.exceptions import * +from nx_arangodb.logger import logger networkx_api = nxadb.utils.decorators.networkx_class(nx.DiGraph) __all__ = ["DiGraph"] -class DiGraph(nx.DiGraph, Graph): +class DiGraph(nx.DiGraph): + __networkx_backend__: ClassVar[str] = "arangodb" # nx >=3.2 + __networkx_plugin__: ClassVar[str] = "arangodb" # nx <3.2 + @classmethod def to_networkx_class(cls) -> type[nx.DiGraph]: return nx.DiGraph def __init__( self, + graph_name: str | None = None, + # default_node_type: str = "nxadb_nodes", + # edge_type_func: Callable[[str, str], str] = lambda u, v: f"{u}_to_{v}", *args, **kwargs, ): - super().__init__(*args, **kwargs) + m = "Please note that nxadb.DiGraph has no ArangoDB CRUD support yet." + logger.warning(m) + + if kwargs.get("incoming_graph_data") is not None and graph_name is not None: + m = "Cannot pass both **incoming_graph_data** and **graph_name** yet" + raise NotImplementedError(m) self.__db = None self.__graph_name = None self.__graph_exists = False + self.__set_db() + if self.__db is not None: + self.__set_graph_name(graph_name) + + self.auto_sync = True + self.graph_loader_parallelism = None self.graph_loader_batch_size = None - self.use_node_and_adj_dict_cache = False + # NOTE: Need to revisit these... + # self.maintain_node_dict_cache = False + # self.maintain_adj_dict_cache = False + self.use_nx_cache = False self.use_coo_cache = False self.src_indices = None self.dst_indices = None self.vertex_ids_to_index = None - self.set_db() - if self.__db is not None: - self.set_graph_name() + # self.default_node_type = default_node_type + # self.edge_type_func = edge_type_func + # self.default_edge_type = edge_type_func(default_node_type, default_node_type) - @property - def db(self) -> StandardDatabase: - if self.__db is None: - raise ValueError("Database not set") - - return self.__db - - @property - def graph_name(self) -> str: - if self.__graph_name is None: - raise ValueError("Graph name not set") + if self.__graph_exists: + self.adb_graph = self.db.graph(graph_name) + # self.__create_default_collections() + # self.__set_factory_methods() - return self.__graph_name + super().__init__(*args, **kwargs) - @property - def graph_exists(self) -> bool: - return self.__graph_exists + ########### + # Getters # + ########### @property def db(self) -> StandardDatabase: if self.__db is None: - raise ValueError("Database not set") + raise DatabaseNotSet("Database not set") return self.__db @property def graph_name(self) -> str: if self.__graph_name is None: - raise ValueError("Graph name not set") + raise GraphNameNotSet("Graph name not set") return self.__graph_name @@ -78,106 +95,64 @@ def graph_name(self) -> str: def graph_exists(self) -> bool: return self.__graph_exists - def clear_coo_cache(self): - self.src_indices = None - self.dst_indices = None - self.vertex_ids_to_index = None + ########### + # Setters # + ########### - def set_db(self, db: StandardDatabase | None = None): + def __set_db(self, db: StandardDatabase | None = None): if db is not None: if not isinstance(db, StandardDatabase): - raise TypeError( - "**db** must be an instance of arango.database.StandardDatabase" - ) + m = "arango.database.StandardDatabase" + raise TypeError(m) self.__db = db return - self.__host = os.getenv("DATABASE_HOST") - self.__username = os.getenv("DATABASE_USERNAME") - self.__password = os.getenv("DATABASE_PASSWORD") - self.__db_name = os.getenv("DATABASE_NAME") + self._host = os.getenv("DATABASE_HOST") + self._username = os.getenv("DATABASE_USERNAME") + self._password = os.getenv("DATABASE_PASSWORD") + self._db_name = os.getenv("DATABASE_NAME") # TODO: Raise a custom exception if any of the environment # variables are missing. For now, we'll just set db to None. - if not all([self.__host, self.__username, self.__password, self.__db_name]): + if not all([self._host, self._username, self._password, self._db_name]): self.__db = None + logger.warning("Database environment variables not set") return - self.__db = ArangoClient(hosts=self.__host, request_timeout=None).db( - self.__db_name, self.__username, self.__password, verify=True - ) + try: + self.__db = ArangoClient(hosts=self._host, request_timeout=None).db( + self._db_name, self._username, self._password, verify=True + ) + except ServerConnectionError as e: + self.__db = None + logger.warning(f"Could not connect to the database: {e}") - def set_graph_name(self, graph_name: str | None = None): + def __set_graph_name(self, graph_name: str | None = None): if self.__db is None: - raise ValueError("Cannot set graph name without setting the database first") + raise DatabaseNotSet( + "Cannot set graph name without setting the database first" + ) - self.__graph_name = os.getenv("DATABASE_GRAPH_NAME") - if graph_name is not None: - if not isinstance(graph_name, str): - raise TypeError("**graph_name** must be a string") + if graph_name is None: + self.__graph_exists = False + logger.warning(f"**graph_name** not set for {self.__class__.__name__}") + return - self.__graph_name = graph_name + if not isinstance(graph_name, str): + raise TypeError("**graph_name** must be a string") - if self.__graph_name is None: - self.__graph_exists = False - print("DATABASE_GRAPH_NAME environment variable not set") + self.__graph_name = graph_name + self.__graph_exists = self.db.has_graph(graph_name) - elif not self.db.has_graph(self.__graph_name): - self.__graph_exists = False - print(f"Graph '{self.__graph_name}' does not exist in the database") - - else: - self.__graph_exists = True - print(f"Found graph '{self.__graph_name}' in the database") - - def pull(self, load_node_and_adj_dict=True, load_coo=True): - if not self.graph_exists: - raise ValueError("Graph does not exist in the database") - - adb_graph = self.db.graph(self.graph_name) - - v_cols = adb_graph.vertex_collections() - edge_definitions = adb_graph.edge_definitions() - e_cols = {c["edge_collection"] for c in edge_definitions} - - metagraph = { - "vertexCollections": {col: {} for col in v_cols}, - "edgeCollections": {col: {} for col in e_cols}, - } - - from phenolrs.graph_loader import GraphLoader - - kwargs = {} - if self.graph_loader_parallelism is not None: - kwargs["parallelism"] = self.graph_loader_parallelism - if self.graph_loader_batch_size is not None: - kwargs["batch_size"] = self.graph_loader_batch_size - - result = GraphLoader.load( - self.db.name, - metagraph, - [self.__host], - username=self.__username, - password=self.__password, - load_node_dict=load_node_and_adj_dict, - load_adj_dict=load_node_and_adj_dict, - load_adj_dict_as_undirected=False, - load_coo=load_coo, - **kwargs, - ) - - if load_node_and_adj_dict: - # hacky, i don't like this - # need to revisit... - # consider using nx.convert.from_dict_of_dicts instead - self._node = result[0] - self._adj = result[1] - - if load_coo: - self.src_indices = result[2] - self.dst_indices = result[3] - self.vertex_ids_to_index = result[4] - - def push(self): - raise NotImplementedError("What would this look like?") + logger.info(f"Graph '{graph_name}' exists: {self.__graph_exists}") + + #################### + # ArangoDB Methods # + #################### + + def aql(self, query: str, bind_vars: dict | None = None, **kwargs) -> Cursor: + return nxadb.classes.function.aql(self.db, query, bind_vars, **kwargs) + + def pull(self, load_node_dict=True, load_adj_dict=True, load_coo=True): + raise NotImplementedError("nxadb.DiGraph.pull() is not implemented yet.") diff --git a/nx_arangodb/classes/function.py b/nx_arangodb/classes/function.py new file mode 100644 index 00000000..b50107ca --- /dev/null +++ b/nx_arangodb/classes/function.py @@ -0,0 +1,371 @@ +from __future__ import annotations + +from typing import Any, Tuple + +import arango +import networkx as nx +import numpy as np +from arango import exceptions, graph + +import nx_arangodb as nxadb + +from ..exceptions import ( + AQLMultipleResultsFound, + GraphDoesNotExist, + InvalidTraversalDirection, +) + + +def get_arangodb_graph( + G: nxadb.Graph | nxadb.DiGraph, + load_node_dict: bool, + load_adj_dict: bool, + load_adj_dict_as_directed: bool, + load_coo: bool, +) -> Tuple[ + dict[str, dict[str, Any]], + dict[str, dict[str, dict[str, Any]]], + np.ndarray, + np.ndarray, + dict[str, int], +]: + """Pulls the graph from the database, assuming the graph exists. + + Returns the folowing representations: + - Node dictionary (nx.Graph) + - Adjacency dictionary (nx.Graph) + - Source Indices (COO) + - Destination Indices (COO) + - Node-ID-to-index mapping (COO) + """ + if not G.graph_exists: + raise GraphDoesNotExist( + "Graph does not exist in the database. Can't load graph." + ) + + adb_graph = G.db.graph(G.graph_name) + v_cols = adb_graph.vertex_collections() + edge_definitions = adb_graph.edge_definitions() + e_cols = {c["edge_collection"] for c in edge_definitions} + + metagraph = { + "vertexCollections": {col: {} for col in v_cols}, + "edgeCollections": {col: {} for col in e_cols}, + } + + from phenolrs.graph_loader import GraphLoader + + kwargs = {} + if G.graph_loader_parallelism is not None: + kwargs["parallelism"] = G.graph_loader_parallelism + if G.graph_loader_batch_size is not None: + kwargs["batch_size"] = G.graph_loader_batch_size + + return GraphLoader.load( + G.db.name, + metagraph, + [G._host], + username=G._username, + password=G._password, + load_node_dict=load_node_dict, + load_adj_dict=load_adj_dict, + load_adj_dict_as_directed=load_adj_dict_as_directed, + load_coo=load_coo, + **kwargs, + ) + + +def key_is_string(func) -> Any: + """Decorator to check if the key is a string.""" + + def wrapper(self, key, *args, **kwargs) -> Any: + if not isinstance(key, str): + raise TypeError(f"'{key}' is not a string.") + + return func(self, key, *args, **kwargs) + + return wrapper + + +def keys_are_strings(func) -> Any: + """Decorator to check if the keys are strings.""" + + def wrapper(self, dict, *args, **kwargs) -> Any: + if not all(isinstance(key, str) for key in dict): + raise TypeError(f"All keys must be strings.") + + return func(self, dict, *args, **kwargs) + + return wrapper + + +RESERVED_KEYS = {"_id", "_key", "_rev"} + + +def key_is_not_reserved(func) -> Any: + """Decorator to check if the key is not reserved.""" + + def wrapper(self, key, *args, **kwargs) -> Any: + if key in RESERVED_KEYS: + raise KeyError(f"'{key}' is a reserved key.") + + return func(self, key, *args, **kwargs) + + return wrapper + + +def keys_are_not_reserved(func) -> Any: + """Decorator to check if the keys are not reserved.""" + + def wrapper(self, dict, *args, **kwargs) -> Any: + if any(key in RESERVED_KEYS for key in dict): + raise KeyError(f"All keys must not be reserved.") + + return func(self, dict, *args, **kwargs) + + return wrapper + + +def create_collection( + db: arango.StandardDatabase, collection_name: str, edge: bool = False +) -> arango.StandardCollection: + """Creates a collection if it does not exist and returns it.""" + if not db.has_collection(collection_name): + db.create_collection(collection_name, edge=edge) + + return db.collection(collection_name) + + +def aql( + db: arango.StandardDatabase, query: str, bind_vars: dict[str, Any], **kwargs +) -> arango.Cursor: + """Executes an AQL query and returns the cursor.""" + return db.aql.execute(query, bind_vars=bind_vars, stream=True, **kwargs) + + +def aql_as_list( + db: arango.StandardDatabase, query: str, bind_vars: dict[str, Any], **kwargs +) -> list[Any]: + """Executes an AQL query and returns the results as a list.""" + return list(aql(db, query, bind_vars, **kwargs)) + + +def aql_single( + db: arango.StandardDatabase, query: str, bind_vars: dict[str, Any] +) -> Any: + """Executes an AQL query and returns the first result.""" + result = aql_as_list(db, query, bind_vars) + if len(result) == 0: + return None + + if len(result) > 1: + raise AQLMultipleResultsFound(f"Multiple results found: {result}") + + return result[0] + + +def aql_doc_has_key(db: arango.StandardDatabase, id: str, key: str) -> bool: + """Checks if a document has a key.""" + query = f"RETURN HAS(DOCUMENT(@id), @key)" + bind_vars = {"id": id, "key": key} + return aql_single(db, query, bind_vars) + + +def aql_doc_get_key(db: arango.StandardDatabase, id: str, key: str) -> Any: + """Gets a key from a document.""" + query = f"RETURN DOCUMENT(@id).@key" + bind_vars = {"id": id, "key": key} + return aql_single(db, query, bind_vars) + + +def aql_doc_get_keys(db: arango.StandardDatabase, id: str) -> list[str]: + """Gets the keys of a document.""" + query = f"RETURN ATTRIBUTES(DOCUMENT(@id))" + bind_vars = {"id": id} + return aql_single(db, query, bind_vars) + + +def aql_doc_get_length(db: arango.StandardDatabase, id: str) -> int: + """Gets the length of a document.""" + query = f"RETURN LENGTH(DOCUMENT(@id))" + bind_vars = {"id": id} + return aql_single(db, query, bind_vars) + + +def aql_edge_exists( + db: arango.StandardDatabase, + src_node_id: str, + dst_node_id: str, + graph_name: str, + direction: str, +): + return aql_edge( + db, + src_node_id, + dst_node_id, + graph_name, + direction, + return_clause="true", + ) + + +def aql_edge_get( + db: arango.StandardDatabase, + src_node_id: str, + dst_node_id: str, + graph_name: str, + direction: str, +): + # TODO: need the use of DISTINCT + return_clause = "DISTINCT e" if direction == "ANY" else "e" + return aql_edge( + db, + src_node_id, + dst_node_id, + graph_name, + direction, + return_clause=return_clause, + ) + + +def aql_edge_id( + db: arango.StandardDatabase, + src_node_id: str, + dst_node_id: str, + graph_name: str, + direction: str, +): + # TODO: need the use of DISTINCT + return_clause = "DISTINCT e._id" if direction == "ANY" else "e._id" + return aql_edge( + db, + src_node_id, + dst_node_id, + graph_name, + direction, + return_clause=return_clause, + ) + + +def aql_edge( + db: arango.StandardDatabase, + src_node_id: str, + dst_node_id: str, + graph_name: str, + direction: str, + return_clause: str, +): + if direction == "INBOUND": + filter_clause = f"e._from == @dst_node_id" + elif direction == "OUTBOUND": + filter_clause = f"e._to == @dst_node_id" + elif direction == "ANY": + filter_clause = f"(e._from == @dst_node_id AND e._to == @src_node_id) OR (e._to == @dst_node_id AND e._from == @src_node_id)" + else: + raise InvalidTraversalDirection(f"Invalid direction: {direction}") + + query = f""" + FOR v, e IN 1..1 {direction} @src_node_id GRAPH @graph_name + FILTER {filter_clause} + RETURN {return_clause} + """ + + bind_vars = { + "src_node_id": src_node_id, + "dst_node_id": dst_node_id, + "graph_name": graph_name, + } + + return aql_single(db, query, bind_vars) + + +def aql_fetch_data( + db: arango.StandardDatabase, + collections: list[str], + data: str, + default: Any, + is_edge: bool = True, +) -> dict[str, Any] | list[tuple[str, str, Any]]: + if is_edge: + items = [] + for collection in collections: + query = f""" + LET result = ( + FOR doc IN `{collection}` + RETURN [doc._from, doc._to, doc.@data or @default] + ) + + RETURN result + """ + + bind_vars = {"data": data, "default": default} + + items.extend(aql_single(db, query, bind_vars)) + + return items + + else: + return_clause = f"{{[doc._id]: doc.@data or @default}}" + + items = {} + for collection in collections: + query = f""" + LET result = ( + FOR doc IN `{collection}` + RETURN {return_clause} + ) + + RETURN MERGE(result) + """ + + bind_vars = {"data": data, "default": default} + + items.update(aql_single(db, query, bind_vars)) + + return items.items() + + +def doc_update( + db: arango.StandardDatabase, id: str, data: dict[str, Any], **kwargs +) -> None: + """Updates a document in the collection.""" + db.update_document({**data, "_id": id}, keep_none=False, silent=True, **kwargs) + + +def doc_delete(db: arango.StandardDatabase, id: str, **kwargs) -> None: + """Deletes a document from the collection.""" + db.delete_document(id, silent=True, **kwargs) + + +def doc_insert( + db: arango.StandardDatabase, + collection: str, + id: str, + data: dict[str, Any] = {}, + **kwargs, +) -> dict[str, Any] | bool: + """Inserts a document into a collection.""" + return db.insert_document(collection, {**data, "_id": id}, overwrite=True, **kwargs) + + +def doc_get_or_insert( + db: arango.StandardDatabase, collection: str, id: str, **kwargs +) -> dict[str, Any]: + """Loads a document if existing, otherwise inserts it & returns it.""" + if db.has_document(id): + return db.document(id) + + return doc_insert(db, collection, id, **kwargs) + + +def get_node_id(key: str, default_node_type: str) -> str: + """Gets the node ID.""" + return key if "/" in key else f"{default_node_type}/{key}" + + +def get_node_type_and_id(key: str, default_node_type: str) -> tuple[str, str]: + """Gets the node type and ID.""" + if "/" in key: + return key.split("/")[0], key + + return default_node_type, f"{default_node_type}/{key}" diff --git a/nx_arangodb/classes/graph.py b/nx_arangodb/classes/graph.py index 71ea891d..a5f26022 100644 --- a/nx_arangodb/classes/graph.py +++ b/nx_arangodb/classes/graph.py @@ -1,11 +1,26 @@ import os -from typing import ClassVar +from functools import cached_property +from typing import Callable, ClassVar import networkx as nx from arango import ArangoClient +from arango.cursor import Cursor from arango.database import StandardDatabase +from arango.exceptions import ServerConnectionError import nx_arangodb as nxadb +from nx_arangodb.exceptions import * +from nx_arangodb.logger import logger + +from .dict import ( + adjlist_inner_dict_factory, + adjlist_outer_dict_factory, + edge_attr_dict_factory, + graph_dict_factory, + node_attr_dict_factory, + node_dict_factory, +) +from .reportviews import CustomEdgeView, CustomNodeView networkx_api = nxadb.utils.decorators.networkx_class(nx.Graph) @@ -22,40 +37,109 @@ def to_networkx_class(cls) -> type[nx.Graph]: def __init__( self, + graph_name: str | None = None, + default_node_type: str = "nxadb_nodes", + edge_type_func: Callable[[str, str], str] = lambda u, v: f"{u}_to_{v}", *args, **kwargs, ): - super().__init__(*args, **kwargs) + if kwargs.get("incoming_graph_data") is not None and graph_name is not None: + m = "Cannot pass both **incoming_graph_data** and **graph_name** yet" + raise NotImplementedError(m) self.__db = None self.__graph_name = None self.__graph_exists = False + self.__set_db() + if self.__db is not None: + self.__set_graph_name(graph_name) + + self.auto_sync = True + self.graph_loader_parallelism = None self.graph_loader_batch_size = None - self.use_node_and_adj_dict_cache = False + # NOTE: Need to revisit these... + # self.maintain_node_dict_cache = False + # self.maintain_adj_dict_cache = False + self.use_nx_cache = False self.use_coo_cache = False self.src_indices = None self.dst_indices = None self.vertex_ids_to_index = None - self.set_db() - if self.__db is not None: - self.set_graph_name() + self.default_node_type = default_node_type + self.edge_type_func = edge_type_func + self.default_edge_type = edge_type_func(default_node_type, default_node_type) + + if self.__graph_exists: + self.adb_graph = self.db.graph(graph_name) + self.__create_default_collections() + self.__set_factory_methods() + + super().__init__(*args, **kwargs) + + ####################### + # Init helper methods # + ####################### + + def __set_factory_methods(self) -> None: + """Set the factory methods for the graph, _node, and _adj dictionaries. + + The ArangoDB CRUD operations are handled by the modified dictionaries. + + Handles the creation of the following dictionaries: + - graph_attr_dict_factory (graph-level attributes) + - node_dict_factory (nodes in the graph) + - node_attr_dict_factory (attributes of the nodes in the graph) + - adjlist_outer_dict_factory (outer dictionary for the adjacency list) + - adjlist_inner_dict_factory (inner dictionary for the adjacency list) + - edge_attr_dict_factory (attributes of the edges in the graph) + """ + self.graph_attr_dict_factory = graph_dict_factory(self.db, self.graph_name) + + self.node_dict_factory = node_dict_factory( + self.db, self.adb_graph, self.default_node_type + ) + + self.node_attr_dict_factory = node_attr_dict_factory(self.db, self.adb_graph) + + self.adjlist_outer_dict_factory = adjlist_outer_dict_factory( + self.db, self.adb_graph, self.default_node_type, self.edge_type_func + ) + self.adjlist_inner_dict_factory = adjlist_inner_dict_factory( + self.db, self.adb_graph, self.default_node_type, self.edge_type_func + ) + self.edge_attr_dict_factory = edge_attr_dict_factory(self.db, self.adb_graph) + + def __create_default_collections(self) -> None: + if self.default_node_type not in self.adb_graph.vertex_collections(): + self.adb_graph.create_vertex_collection(self.default_node_type) + + if not self.adb_graph.has_edge_definition(self.default_edge_type): + self.adb_graph.create_edge_definition( + edge_collection=self.default_edge_type, + from_vertex_collections=[self.default_node_type], + to_vertex_collections=[self.default_node_type], + ) + + ########### + # Getters # + ########### @property def db(self) -> StandardDatabase: if self.__db is None: - raise ValueError("Database not set") + raise DatabaseNotSet("Database not set") return self.__db @property def graph_name(self) -> str: if self.__graph_name is None: - raise ValueError("Graph name not set") + raise GraphNameNotSet("Graph name not set") return self.__graph_name @@ -63,106 +147,171 @@ def graph_name(self) -> str: def graph_exists(self) -> bool: return self.__graph_exists - def clear_coo_cache(self): - self.src_indices = None - self.dst_indices = None - self.vertex_ids_to_index = None + ########### + # Setters # + ########### - def set_db(self, db: StandardDatabase | None = None): + def __set_db(self, db: StandardDatabase | None = None): if db is not None: if not isinstance(db, StandardDatabase): - raise TypeError( - "**db** must be an instance of arango.database.StandardDatabase" - ) + m = "arango.database.StandardDatabase" + raise TypeError(m) self.__db = db return - self.__host = os.getenv("DATABASE_HOST") - self.__username = os.getenv("DATABASE_USERNAME") - self.__password = os.getenv("DATABASE_PASSWORD") - self.__db_name = os.getenv("DATABASE_NAME") + self._host = os.getenv("DATABASE_HOST") + self._username = os.getenv("DATABASE_USERNAME") + self._password = os.getenv("DATABASE_PASSWORD") + self._db_name = os.getenv("DATABASE_NAME") # TODO: Raise a custom exception if any of the environment # variables are missing. For now, we'll just set db to None. - if not all([self.__host, self.__username, self.__password, self.__db_name]): + if not all([self._host, self._username, self._password, self._db_name]): self.__db = None + logger.warning("Database environment variables not set") return - self.__db = ArangoClient(hosts=self.__host, request_timeout=None).db( - self.__db_name, self.__username, self.__password, verify=True - ) + try: + self.__db = ArangoClient(hosts=self._host, request_timeout=None).db( + self._db_name, self._username, self._password, verify=True + ) + except ServerConnectionError as e: + self.__db = None + logger.warning(f"Could not connect to the database: {e}") - def set_graph_name(self, graph_name: str | None = None): + def __set_graph_name(self, graph_name: str | None = None): if self.__db is None: - raise ValueError("Cannot set graph name without setting the database first") + raise DatabaseNotSet( + "Cannot set graph name without setting the database first" + ) - self.__graph_name = os.getenv("DATABASE_GRAPH_NAME") - if graph_name is not None: - if not isinstance(graph_name, str): - raise TypeError("**graph_name** must be a string") + if graph_name is None: + self.__graph_exists = False + logger.warning(f"**graph_name** not set for {self.__class__.__name__}") + return - self.__graph_name = graph_name + if not isinstance(graph_name, str): + raise TypeError("**graph_name** must be a string") + + self.__graph_name = graph_name + self.__graph_exists = self.db.has_graph(graph_name) + + logger.info(f"Graph '{graph_name}' exists: {self.__graph_exists}") + + #################### + # ArangoDB Methods # + #################### + + # TODO: proper subgraphing! + def aql(self, query: str, bind_vars: dict | None = None, **kwargs) -> Cursor: + return nxadb.classes.function.aql(self.db, query, bind_vars, **kwargs) + + def pull(self, load_node_dict=True, load_adj_dict=True, load_coo=True): + """Load the graph from the ArangoDB database, and update existing graph object. + + :param load_node_dict: Load the node dictionary. + Enabling this option will clear the existing node dictionary, + and replace it with the node data from the database. Comes with + a remote reference to the database. <--- TODO: Should we paramaterize this? + :type load_node_dict: bool + :param load_adj_dict: Load the adjacency dictionary. + Enabling this option will clear the existing adjacency dictionary, + and replace it with the edge data from the database. Comes with + a remote reference to the database. <--- TODO: Should we paramaterize this? + :type load_adj_dict: bool + :param load_coo: Load the COO representation. If False, the src & dst indices will be empty, + along with the node-ID-to-index mapping. Used for nx-cuGraph compatibility. + :type load_coo: bool + """ + node_dict, adj_dict, src_indices, dst_indices, vertex_ids_to_indices = ( + nxadb.classes.function.get_arangodb_graph( + self, + load_node_dict=load_node_dict, + load_adj_dict=load_adj_dict, + load_adj_dict_as_directed=False, + load_coo=load_coo, + ) + ) - if self.__graph_name is None: - self.__graph_exists = False - print("DATABASE_GRAPH_NAME environment variable not set") + if load_node_dict: + self._node.clear() - elif not self.db.has_graph(self.__graph_name): - self.__graph_exists = False - print(f"Graph '{self.__graph_name}' does not exist in the database") + for node_id, node_data in node_dict.items(): + node_attr_dict = self.node_attr_dict_factory() + node_attr_dict.node_id = node_id + node_attr_dict.data = node_data + self._node.data[node_id] = node_attr_dict - else: - self.__graph_exists = True - print(f"Found graph '{self.__graph_name}' in the database") - - def pull(self, load_node_and_adj_dict=True, load_coo=True): - if not self.graph_exists: - raise ValueError("Graph does not exist in the database") - - adb_graph = self.db.graph(self.graph_name) - - v_cols = adb_graph.vertex_collections() - edge_definitions = adb_graph.edge_definitions() - e_cols = {c["edge_collection"] for c in edge_definitions} - - metagraph = { - "vertexCollections": {col: {} for col in v_cols}, - "edgeCollections": {col: {} for col in e_cols}, - } - - from phenolrs.graph_loader import GraphLoader - - kwargs = {} - if self.graph_loader_parallelism is not None: - kwargs["parallelism"] = self.graph_loader_parallelism - if self.graph_loader_batch_size is not None: - kwargs["batch_size"] = self.graph_loader_batch_size - - result = GraphLoader.load( - self.db.name, - metagraph, - [self.__host], - username=self.__username, - password=self.__password, - load_node_dict=load_node_and_adj_dict, - load_adj_dict=load_node_and_adj_dict, - load_adj_dict_as_undirected=True, - load_coo=load_coo, - **kwargs, - ) + if load_adj_dict: + self._adj.clear() + + for src_node_id, dst_dict in adj_dict.items(): + adjlist_inner_dict = self.adjlist_inner_dict_factory() + adjlist_inner_dict.src_node_id = src_node_id + + self._adj.data[src_node_id] = adjlist_inner_dict - if load_node_and_adj_dict: - # hacky, i don't like this - # need to revisit... - # consider using nx.convert.from_dict_of_dicts instead - self._node = result[0] - self._adj = result[1] + for dst_id, edge_data in dst_dict.items(): + edge_attr_dict = self.edge_attr_dict_factory() + edge_attr_dict.edge_id = edge_data["_id"] + edge_attr_dict.data = edge_data + + adjlist_inner_dict.data[dst_id] = edge_attr_dict if load_coo: - self.src_indices = result[2] - self.dst_indices = result[3] - self.vertex_ids_to_index = result[4] + self.src_indices = src_indices + self.dst_indices = dst_indices + self.vertex_ids_to_index = vertex_ids_to_indices def push(self): raise NotImplementedError("What would this look like?") + + ##################### + # nx.Graph Overides # + ##################### + + @cached_property + def nodes(self): + if self.graph_exists: + logger.warning("nxadb.CustomNodeView is currently EXPERIMENTAL") + return CustomNodeView(self) + + return nx.classes.reportviews.NodeView(self) + + @cached_property + def edges(self): + if self.graph_exists: + logger.warning("nxadb.CustomEdgeView is currently EXPERIMENTAL") + return CustomEdgeView(self) + + return nx.classes.reportviews.EdgeView(self) + + def add_node(self, node_for_adding, **attr): + if node_for_adding not in self._node: + if node_for_adding is None: + raise ValueError("None cannot be a node") + self._adj[node_for_adding] = self.adjlist_inner_dict_factory() + + ###################### + # NOTE: monkey patch # + ###################### + + # Old: + # attr_dict = self._node[node_for_adding] = self.node_attr_dict_factory() + # attr_dict.update(attr) + + # New: + self._node[node_for_adding] = self.node_attr_dict_factory() + self._node[node_for_adding].update(attr) + + # Reason: + # Invoking `update` on the `attr_dict` without `attr_dict.node_id` being set + # i.e trying to update a node's attributes before we know _which_ node it is + + ########################### + + else: + self._node[node_for_adding].update(attr) + + nx._clear_cache(self) diff --git a/nx_arangodb/classes/multidigraph.py b/nx_arangodb/classes/multidigraph.py index e2f86584..cabc1b93 100644 --- a/nx_arangodb/classes/multidigraph.py +++ b/nx_arangodb/classes/multidigraph.py @@ -1,33 +1,25 @@ +from typing import ClassVar + import networkx as nx import nx_arangodb as nxadb -from nx_arangodb.classes.digraph import DiGraph -from nx_arangodb.classes.multigraph import MultiGraph +from nx_arangodb.logger import logger networkx_api = nxadb.utils.decorators.networkx_class(nx.MultiDiGraph) __all__ = ["MultiDiGraph"] -class MultiDiGraph(nx.MultiDiGraph, MultiGraph, DiGraph): +class MultiDiGraph(nx.MultiDiGraph): + __networkx_backend__: ClassVar[str] = "arangodb" # nx >=3.2 + __networkx_plugin__: ClassVar[str] = "arangodb" # nx <3.2 + @classmethod def to_networkx_class(cls) -> type[nx.MultiDiGraph]: return nx.MultiDiGraph - def __init__( - self, - *args, - **kwargs, - ): + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - - self.__db = None - self.__graph_name = None - self.__graph_exists = False - - self.coo_load_parallelism = None - self.coo_load_batch_size = None - - self.set_db() - if self.__db is not None: - self.set_graph_name() + self.graph_exists = False + m = "nxadb.MultiDiGraph has not been implemented yet. This is a pass-through subclass of nx.MultiDiGraph for now." + logger.warning(m) diff --git a/nx_arangodb/classes/multigraph.py b/nx_arangodb/classes/multigraph.py index 0a252284..80dc47fa 100644 --- a/nx_arangodb/classes/multigraph.py +++ b/nx_arangodb/classes/multigraph.py @@ -1,32 +1,25 @@ +from typing import ClassVar + import networkx as nx import nx_arangodb as nxadb -from nx_arangodb.classes.graph import Graph +from nx_arangodb.logger import logger networkx_api = nxadb.utils.decorators.networkx_class(nx.MultiGraph) __all__ = ["MultiGraph"] -class MultiGraph(nx.MultiGraph, Graph): +class MultiGraph(nx.MultiGraph): + __networkx_backend__: ClassVar[str] = "arangodb" # nx >=3.2 + __networkx_plugin__: ClassVar[str] = "arangodb" # nx <3.2 + @classmethod def to_networkx_class(cls) -> type[nx.MultiGraph]: return nx.MultiGraph - def __init__( - self, - *args, - **kwargs, - ): + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - - self.__db = None - self.__graph_name = None - self.__graph_exists = False - - self.coo_load_parallelism = None - self.coo_load_batch_size = None - - self.set_db() - if self.__db is not None: - self.set_graph_name() + self.graph_exists = False + m = "nxadb.MultiGraph has not been implemented yet. This is a pass-through subclass of nx.MultiGraph for now." + logger.warning(m) diff --git a/nx_arangodb/classes/reportviews.py b/nx_arangodb/classes/reportviews.py new file mode 100644 index 00000000..0dd4655f --- /dev/null +++ b/nx_arangodb/classes/reportviews.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +import networkx as nx + +import nx_arangodb as nxadb + + +class CustomNodeView(nx.classes.reportviews.NodeView): + def __call__(self, data=False, default=None): + if data is False: + return self + return CustomNodeDataView(self._nodes, data, default) + + def data(self, data=True, default=None): + if data is False: + return self + return CustomNodeDataView(self._nodes, data, default) + + +class CustomNodeDataView(nx.classes.reportviews.NodeDataView): + def __iter__(self): + data = self._data + if data is False: + return iter(self._nodes) + if data is True: + return iter(self._nodes.items()) + + ###################### + # NOTE: Monkey Patch # + ###################### + + # Old: + # return ( + # (n, dd[data] if data in dd else self._default) + # for n, dd in self._nodes.items() + # ) + + # New: + return iter(self._nodes.items(data=data, default=self._default)) + + # Reason: We can utilize AQL to filter the data we + # want to return, instead of filtering it in Python + + ########################### + + +class CustomEdgeDataView(nx.classes.reportviews.EdgeDataView): + + ###################### + # NOTE: Monkey Patch # + ###################### + + # Reason: We can utilize AQL to filter the data we + # want to return, instead of filtering it in Python + # This is hacky for now, but it's meant to show that + # the data can be filtered server-side. + # We solve this by relying on self._adjdict, which + # is the AdjListOuterDict object that has a custom + # items() method that can filter data with AQL. + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + if self._data and not isinstance(self._data, bool): + self._report = lambda n, nbr, dd: self._adjdict.items( + data=self._data, default=self._default + ) + + def __iter__(self): + if self._data and not isinstance(self._data, bool): + # don't need to filter data in Python + return self._report("", "", "") + + return ( + self._report(n, nbr, dd) + for n, nbrs in self._nodes_nbrs() + for nbr, dd in nbrs.items() + ) + + +class CustomEdgeView(nx.classes.reportviews.EdgeView): + dataview = CustomEdgeDataView + + def __len__(self): + + ###################### + # NOTE: Monkey Patch # + ###################### + + # Old: + # num_nbrs = (len(nbrs) + (n in nbrs) for n, nbrs in self._nodes_nbrs()) + # return sum(num_nbrs) // 2 + + # New: + G: nxadb.Graph = self._graph + return sum( + [ + G.db.collection(ed["edge_collection"]).count() + for ed in G.adb_graph.edge_definitions() + ] + ) + + # Reason: We can utilize AQL to count the number of edges + # instead of making individual requests to the database + # i.e avoid having to do `n in nbrs` for each node + + ###################### diff --git a/nx_arangodb/convert.py b/nx_arangodb/convert.py index 12f46c11..25ce13cd 100644 --- a/nx_arangodb/convert.py +++ b/nx_arangodb/convert.py @@ -1,13 +1,25 @@ -# Copied from nx-cugraph - from __future__ import annotations import itertools +import time from typing import TYPE_CHECKING import networkx as nx import nx_arangodb as nxadb +from nx_arangodb.logger import logger + +try: + import cupy as cp + import numpy as np + import nx_cugraph as nxcg + + GPU_ENABLED = True + logger.info("NXCG is enabled.") +except ModuleNotFoundError as e: + GPU_ENABLED = False + logger.info(f"NXCG is disabled. {e}.") + if TYPE_CHECKING: # pragma: no cover from nx_arangodb.typing import AttrKey, Dtype, EdgeValue, NodeValue @@ -102,6 +114,8 @@ def from_networkx( -------- to_networkx : The opposite; convert nx_arangodb graph to networkx graph """ + logger.debug(f"from_networkx for {graph.__class__.__name__}") + if not isinstance(graph, nx.Graph): if isinstance(graph, nx.classes.reportviews.NodeView): # Convert to a Graph with only nodes (no edges) @@ -123,7 +137,6 @@ def from_networkx( else: klass = nxadb.Graph - print(f"ANTHONY: Called from_networkx for {graph.__class__.__name__}") return klass(incoming_graph_data=graph) @@ -153,29 +166,67 @@ def to_networkx(G: nxadb.Graph, *, sort_edges: bool = False) -> nx.Graph: -------- from_networkx : The opposite; convert networkx graph to nx_cugraph graph """ + logger.debug(f"to_networkx for {G.__class__.__name__}") + if not isinstance(G, nxadb.Graph): raise TypeError(f"Expected nx_arangodb.Graph; got {type(G)}") - print(f"ANTHONY: Called to_networkx for {G.__class__.__name__}") return G.to_networkx_class()(incoming_graph_data=G) -def from_networkx_arangodb(G: nxadb.Graph) -> nxadb.Graph: +def from_networkx_arangodb( + G: nxadb.Graph | nxadb.DiGraph, pull_graph: bool +) -> nxadb.Graph | nxadb.DiGraph: + logger.debug(f"from_networkx_arangodb for {G.__class__.__name__}") + + if not isinstance(G, (nxadb.Graph, nxadb.DiGraph)): + raise TypeError(f"Expected nx_arangodb.(Graph || DiGraph); got {type(G)}") + if not G.graph_exists: - print("ANTHONY: Graph does not exist, nothing to pull") + logger.debug("graph does not exist, nothing to pull") return G - if G.use_node_and_adj_dict_cache and len(G.nodes) > 0 and len(G.adj) > 0: - print("ANTHONY: Using cached node and adj dict") + if not pull_graph: + if isinstance(G, nxadb.DiGraph): + m = "nx_arangodb.DiGraph has no CRUD Support yet. Cannot rely on remote connection." + raise NotImplementedError(m) + + logger.debug("graph exists, but not pulling. relying on remote connection...") return G + # if G.use_nx_cache and G._node and G._adj: + # m = "**use_nx_cache** is enabled. using cached data. no pull required." + # logger.debug(m) + # return G + + logger.debug("pulling as NetworkX Graph...") start_time = time.time() - G.pull(load_coo=False) + node_dict, adj_dict, _, _, _ = nxadb.classes.function.get_arangodb_graph( + G, + load_node_dict=True, + load_adj_dict=True, + load_adj_dict_as_directed=G.is_directed(), + load_coo=False, + ) end_time = time.time() + logger.debug(f"load took {end_time - start_time} seconds") + + # Copied from nx.convert.to_networkx_graph + try: + logger.debug("creating nx graph from loaded ArangoDB data...") + result = nx.convert.from_dict_of_dicts( + adj_dict, + create_using=G.__class__, + multigraph_input=G.is_multigraph(), + ) - print("ANTHONY: Node & Adj Load took:", end_time - start_time) + for n, dd in node_dict.items(): + result._node[n].update(dd) - return G + return result + + except Exception as err: + raise nx.NetworkXError("Input is not a correct NetworkX graph.") from err def _to_nxadb_graph( @@ -183,10 +234,13 @@ def _to_nxadb_graph( edge_attr: AttrKey | None = None, edge_default: EdgeValue | None = 1, edge_dtype: Dtype | None = None, + pull_graph: bool = True, ) -> nxadb.Graph | nxadb.DiGraph: """Ensure that input type is a nx_arangodb graph, and convert if necessary.""" - if isinstance(G, nxadb.Graph): - return from_networkx_arangodb(G) + logger.debug(f"_to_nxadb_graph for {G.__class__.__name__}") + + if isinstance(G, (nxadb.Graph, nxadb.DiGraph)): + return from_networkx_arangodb(G, pull_graph) if isinstance(G, nx.Graph): return from_networkx( @@ -196,13 +250,7 @@ def _to_nxadb_graph( raise TypeError -try: - import os - import time - - import cupy as cp - import numpy as np - import nx_cugraph as nxcg +if GPU_ENABLED: def _to_nxcg_graph( G, @@ -212,9 +260,13 @@ def _to_nxcg_graph( as_directed: bool = False, ) -> nxcg.Graph | nxcg.DiGraph: """Ensure that input type is a nx_cugraph graph, and convert if necessary.""" + logger.debug(f"_to_nxcg_graph for {G.__class__.__name__}") + if isinstance(G, nxcg.Graph): + logger.debug("already an nx_cugraph graph") return G - if isinstance(G, nxadb.Graph): + + if isinstance(G, (nxadb.Graph, nxadb.DiGraph)): # Assumption: G.adb_graph_name points to an existing graph in ArangoDB # Therefore, the user wants us to pull the graph from ArangoDB, # and convert it to an nx_cugraph graph. @@ -223,14 +275,20 @@ def _to_nxcg_graph( # the NetworkX graph to an nx_cugraph graph. # TODO: Implement a direct conversion from ArangoDB to nx_cugraph if G.graph_exists: - print("ANTHONY: Graph exists, running _nxadb_to_nxcg()") - return _nxadb_to_nxcg(G, as_directed=as_directed) + logger.debug("converting nx_arangodb graph to nx_cugraph graph") + return nxcg_from_networkx_arangodb(G, as_directed=as_directed) + + if isinstance(G, (nxadb.MultiGraph, nxadb.MultiDiGraph)): + raise NotImplementedError( + "nxadb.MultiGraph not yet supported for _to_nxcg_graph()" + ) # If G is a networkx graph, or is a nxadb graph that doesn't point to an "existing" # ArangoDB graph, then we just treat it as a normal networkx graph & # convert it to nx_cugraph. # TODO: Need to revisit the "existing" ArangoDB graph condition... if isinstance(G, nx.Graph): + logger.debug("converting networkx graph to nx_cugraph graph") return nxcg.convert.from_networkx( G, {edge_attr: edge_default} if edge_attr is not None else None, @@ -241,9 +299,12 @@ def _to_nxcg_graph( # TODO: handle cugraph.Graph raise TypeError - def _nxadb_to_nxcg( - G: nxadb.Graph, as_directed: bool = False + def nxcg_from_networkx_arangodb( + G: nxadb.Graph | nxadb.DiGraph, as_directed: bool = False ) -> nxcg.Graph | nxcg.DiGraph: + """Convert an nx_arangodb graph to nx_cugraph graph.""" + logger.debug(f"nxcg_from_networkx_arangodb for {G.__class__.__name__}") + if G.is_multigraph(): raise NotImplementedError("Multigraphs not yet supported") @@ -253,14 +314,27 @@ def _nxadb_to_nxcg( and G.dst_indices is not None and G.vertex_ids_to_index is not None ): - print("ANTHONY: Using cached COO") + m = "**use_coo_cache** is enabled. using cached COO data. no pull required." + logger.debug(m) else: + logger.debug("pulling as NetworkX-CuGraph Graph...") start_time = time.time() - G.pull(load_node_and_adj_dict=False) + _, _, src_indices, dst_indices, vertex_ids_to_index = ( + nxadb.classes.function.get_arangodb_graph( + G, + load_node_dict=False, + load_adj_dict=False, + load_adj_dict_as_directed=G.is_directed(), + load_coo=True, + ) + ) end_time = time.time() + logger.debug(f"load took {end_time - start_time} seconds") - print("ANTHONY: COO Load took:", end_time - start_time) + G.src_indices = src_indices + G.dst_indices = dst_indices + G.vertex_ids_to_index = vertex_ids_to_index N = len(G.vertex_ids_to_index) @@ -269,8 +343,8 @@ def _nxadb_to_nxcg( else: klass = nxcg.Graph + logger.debug("creating nx_cugraph graph from COO data...") start_time = time.time() - rv = klass.from_coo( N, cp.array(G.src_indices), @@ -278,13 +352,11 @@ def _nxadb_to_nxcg( key_to_id=G.vertex_ids_to_index, ) end_time = time.time() - - print("ANTHONY: from_coo took:", end_time - start_time) + logger.debug(f"nxcg from_coo took {end_time - start_time}") return rv -except ModuleNotFoundError as e: - print(f"ANTHONY: {e}") +else: def _to_nxcg_graph( G, @@ -292,6 +364,6 @@ def _to_nxcg_graph( edge_default: EdgeValue | None = 1, edge_dtype: Dtype | None = None, as_directed: bool = False, - ) -> nxadb.Graph: + ) -> nxcg.Graph | nxcg.DiGraph: m = "nx-cugraph is not installed; cannot convert to nx-cugraph graph" raise NotImplementedError(m) diff --git a/nx_arangodb/exceptions.py b/nx_arangodb/exceptions.py new file mode 100644 index 00000000..0f014e40 --- /dev/null +++ b/nx_arangodb/exceptions.py @@ -0,0 +1,37 @@ +class NetworkXArangoDBException(Exception): + pass + + +class GraphDoesNotExist(NetworkXArangoDBException): + pass + + +class DatabaseNotSet(NetworkXArangoDBException): + pass + + +class GraphNameNotSet(NetworkXArangoDBException): + pass + + +class InvalidTraversalDirection(NetworkXArangoDBException): + pass + + +EDGE_ALREADY_EXISTS_ERROR_CODE = 1210 + + +class EdgeAlreadyExists(NetworkXArangoDBException): + pass + + +class AQLMultipleResultsFound(NetworkXArangoDBException): + pass + + +class ArangoDBAlgorithmError(NetworkXArangoDBException): + pass + + +class ShortestPathError(ArangoDBAlgorithmError): + pass diff --git a/nx_arangodb/interface.py b/nx_arangodb/interface.py index 738cfe81..3acf4ecc 100644 --- a/nx_arangodb/interface.py +++ b/nx_arangodb/interface.py @@ -1,5 +1,3 @@ -# Copied from nx-cugraph - from __future__ import annotations import os diff --git a/nx_arangodb/logger.py b/nx_arangodb/logger.py new file mode 100644 index 00000000..a69ea0b2 --- /dev/null +++ b/nx_arangodb/logger.py @@ -0,0 +1,19 @@ +import logging + +logger = logging.getLogger(__package__) + +if logger.hasHandlers(): + logger.handlers.clear() + +handler = logging.StreamHandler() + +formatter = logging.Formatter( + f"[%(asctime)s] [%(levelname)s]: %(message)s", + "%H:%M:%S %z", +) + +handler.setFormatter(formatter) + +logger.addHandler(handler) + +logger.setLevel(logging.INFO) diff --git a/run_nx_tests.sh b/run_nx_tests.sh index db982aad..0c2fd844 100755 --- a/run_nx_tests.sh +++ b/run_nx_tests.sh @@ -1,10 +1,17 @@ # Copied from nx-cugraph +set -e + +# TODO: address the following tests +# --pyargs networkx.algorithms.community.louvain \ NETWORKX_GRAPH_CONVERT=arangodb \ NETWORKX_TEST_BACKEND=arangodb \ NETWORKX_FALLBACK_TO_NX=True \ pytest \ - --pyargs networkx.classes networkx.algorithms.centrality \ + --pyargs networkx.classes \ + --pyargs networkx.algorithms.centrality \ + --pyargs networkx.algorithms.link_analysis \ + --pyargs networkx.algorithms.shortest_paths \ --cov-config=$(dirname $0)/pyproject.toml \ --cov=nx_arangodb \ --cov-report= \ diff --git a/starter.sh b/starter.sh new file mode 100755 index 00000000..c0bd939f --- /dev/null +++ b/starter.sh @@ -0,0 +1,23 @@ +#!/bin/bash + +# Starts a local ArangoDB server or cluster (community or enterprise). +# Useful for testing the python-arango driver against a local ArangoDB setup. + +# Usage: +# ./starter.sh [single|cluster] [community|enterprise] +# Example: +# ./starter.sh cluster enterprise + +extra_ports="-p 8539:8539 -p 8549:8549" +image_name="enterprise" +conf_file="cluster.conf" + +docker run -d \ + --name arango \ + -p 8528:8528 \ + -p 8529:8529 \ + $extra_ports \ + -v "$(pwd)/tests/static/":/tests/static \ + -v /tmp:/tmp \ + "arangodb/$image_name:latest" \ + /bin/sh -c "arangodb --configuration=/tests/static/$conf_file" diff --git a/temp.py b/temp.py new file mode 100644 index 00000000..30f17c18 --- /dev/null +++ b/temp.py @@ -0,0 +1 @@ +g = nxadb.Graph(graph_name=graph_name) diff --git a/tests/conftest.py b/tests/conftest.py index da5298ed..1ed72b36 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,12 +1,22 @@ +import logging +import os from typing import Any + +import networkx as nx +import pytest +from adbnx_adapter import ADBNX_Adapter from arango import ArangoClient +from nx_arangodb.logger import logger + +logger.setLevel(logging.INFO) + def pytest_addoption(parser: Any) -> None: parser.addoption("--url", action="store", default="http://localhost:8529") parser.addoption("--dbName", action="store", default="_system") parser.addoption("--username", action="store", default="root") - parser.addoption("--password", action="store", default="test") + parser.addoption("--password", action="store", default="passwd") def pytest_configure(config: Any) -> None: @@ -28,3 +38,25 @@ def pytest_configure(config: Any) -> None: db = ArangoClient(hosts=con["url"]).db( con["dbName"], con["username"], con["password"], verify=True ) + + os.environ["DATABASE_HOST"] = con["url"] + os.environ["DATABASE_USERNAME"] = con["username"] + os.environ["DATABASE_PASSWORD"] = con["password"] + os.environ["DATABASE_NAME"] = con["dbName"] + + +@pytest.fixture(scope="function") +def load_graph(): + db.delete_graph("KarateGraph", drop_collections=True, ignore_missing=True) + adapter = ADBNX_Adapter(db) + adapter.networkx_to_arangodb( + "KarateGraph", + nx.karate_club_graph(), + edge_definitions=[ + { + "edge_collection": "knows", + "from_vertex_collections": ["person"], + "to_vertex_collections": ["person"], + } + ], + ) diff --git a/tests/static/cluster.conf b/tests/static/cluster.conf new file mode 100644 index 00000000..d33e07a3 --- /dev/null +++ b/tests/static/cluster.conf @@ -0,0 +1,15 @@ +[starter] +mode = cluster +local = true +address = 0.0.0.0 +port = 8528 + +[auth] +jwt-secret = /tests/static/keyfile + +[args] +all.database.password = passwd +all.database.extended-names = true +all.log.api-enabled = true +all.javascript.allow-admin-execute = true +all.server.options-api = admin diff --git a/tests/static/keyfile b/tests/static/keyfile new file mode 100644 index 00000000..d97c5ead --- /dev/null +++ b/tests/static/keyfile @@ -0,0 +1 @@ +secret diff --git a/tests/static/service.zip b/tests/static/service.zip new file mode 100644 index 00000000..00bf513e Binary files /dev/null and b/tests/static/service.zip differ diff --git a/tests/static/setup.sh b/tests/static/setup.sh new file mode 100644 index 00000000..0d2189ba --- /dev/null +++ b/tests/static/setup.sh @@ -0,0 +1,7 @@ +#!/bin/sh + +mkdir -p /tests/static +wget -O /tests/static/service.zip "http://localhost:8000/$PROJECT/tests/static/service.zip" +wget -O /tests/static/keyfile "http://localhost:8000/$PROJECT/tests/static/keyfile" +wget -O /tests/static/arangodb.conf "http://localhost:8000/$PROJECT/tests/static/$ARANGODB_CONF" +arangodb --configuration=/tests/static/arangodb.conf diff --git a/tests/test.py b/tests/test.py index f21b176b..8f8ab06f 100644 --- a/tests/test.py +++ b/tests/test.py @@ -1,47 +1,393 @@ -import pytest import networkx as nx +import pytest + import nx_arangodb as nxadb from .conftest import db -def test_db(): +def test_db(load_graph): assert db.version() -def test_bc(): +def test_bc(load_graph): G_1 = nx.karate_club_graph() - - G_2 = nxadb.Graph(G_1) + G_2 = nxadb.Graph(incoming_graph_data=G_1) r_1 = nx.betweenness_centrality(G_1) r_2 = nx.betweenness_centrality(G_2) r_3 = nx.betweenness_centrality(G_1, backend="arangodb") r_4 = nx.betweenness_centrality(G_2, backend="arangodb") - assert r_1 and r_2 and r_3 and r_4 + assert len(r_1) == len(r_2) == len(r_3) == len(r_4) > 0 + + try: + import phenolrs + except ModuleNotFoundError: + return + + G_3 = nxadb.Graph(graph_name="KarateGraph") + r_5 = nx.betweenness_centrality(G_3) -def test_pagerank(): + G_4 = nxadb.Graph(graph_name="KarateGraph") + r_6 = nxadb.betweenness_centrality(G_4, pull_graph_on_cpu=False) + + G_5 = nxadb.DiGraph(graph_name="KarateGraph") + r_7 = nx.betweenness_centrality(G_5) + + # assert r_5 == r_6 # this is acting strange. I need to revisit + assert r_6 == r_7 + assert len(r_5) == len(r_6) == len(r_7) > 0 + + +def test_pagerank(load_graph): G_1 = nx.karate_club_graph() - G_2 = nxadb.Graph(G_1) + G_2 = nxadb.Graph(incoming_graph_data=G_1) r_1 = nx.pagerank(G_1) r_2 = nx.pagerank(G_2) r_3 = nx.pagerank(G_1, backend="arangodb") r_4 = nx.pagerank(G_2, backend="arangodb") - assert r_1 and r_2 and r_3 and r_4 + assert len(r_1) == len(r_2) == len(r_3) == len(r_4) > 0 + + try: + import phenolrs + except ModuleNotFoundError: + return + + G_3 = nxadb.Graph(graph_name="KarateGraph") + r_5 = nx.pagerank(G_3) + G_4 = nxadb.Graph(graph_name="KarateGraph") + r_6 = nxadb.pagerank(G_4, pull_graph_on_cpu=False) -def test_louvain(): + G_5 = nxadb.DiGraph(graph_name="KarateGraph") + r_7 = nx.pagerank(G_5) + + assert len(r_5) == len(r_6) == len(r_7) == len(G_4) + + +def test_louvain(load_graph): G_1 = nx.karate_club_graph() - G_2 = nxadb.Graph(G_1) + G_2 = nxadb.Graph(incoming_graph_data=G_1) r_1 = nx.community.louvain_communities(G_1) r_2 = nx.community.louvain_communities(G_2) r_3 = nx.community.louvain_communities(G_1, backend="arangodb") r_4 = nx.community.louvain_communities(G_2, backend="arangodb") - assert r_1 and r_2 and r_3 and r_4 + assert len(r_1) > 0 + assert len(r_2) > 0 + assert len(r_3) > 0 + assert len(r_4) > 0 + + try: + import phenolrs + except ModuleNotFoundError: + return + + G_3 = nxadb.Graph(graph_name="KarateGraph") + r_5 = nx.community.louvain_communities(G_3) + + G_4 = nxadb.Graph(graph_name="KarateGraph") + r_6 = nxadb.community.louvain_communities(G_4, pull_graph_on_cpu=False) + + G_5 = nxadb.DiGraph(graph_name="KarateGraph") + r_7 = nx.community.louvain_communities(G_5) + + assert len(r_5) > 0 + assert len(r_6) > 0 + assert len(r_7) > 0 + + +def test_shortest_path(load_graph): + G_1 = nxadb.Graph(graph_name="KarateGraph") + G_2 = nxadb.DiGraph(graph_name="KarateGraph") + + r_1 = nx.shortest_path(G_1, source="person/1", target="person/34") + r_2 = nx.shortest_path(G_1, source="person/1", target="person/34", weight="weight") + r_3 = nx.shortest_path(G_2, source="person/1", target="person/34") + r_4 = nx.shortest_path(G_2, source="person/1", target="person/34", weight="weight") + + assert r_1 == r_3 + assert r_2 == r_4 + assert r_1 != r_2 + assert r_3 != r_4 + + +def test_graph_nodes_crud(load_graph): + G_1 = nxadb.Graph(graph_name="KarateGraph", foo="bar") + G_2 = nx.Graph(nx.karate_club_graph()) + + assert G_1.graph_name == "KarateGraph" + assert G_1.graph["foo"] == "bar" + + assert len(G_1.nodes) == len(G_2.nodes) + + for k, v in G_1.nodes(data=True): + assert db.document(k) == v + + for k, v in G_1.nodes(data="club"): + assert db.document(k)["club"] == v + + for k, v in G_1.nodes(data="bad_key", default="boom!"): + doc = db.document(k) + assert "bad_key" not in doc + assert v == "boom!" + + G_1.clear() # clear cache + + person_1 = G_1.nodes["person/1"] + assert person_1["_key"] == "1" + assert person_1["_id"] == "person/1" + assert person_1["club"] == "Mr. Hi" + + assert G_1.nodes["person/2"]["club"] + assert set(G_1._node.data.keys()) == {"person/1", "person/2"} + + G_1.nodes["person/3"]["club"] = "foo" + assert db.document("person/3")["club"] == "foo" + G_1.nodes["person/3"]["club"] = "bar" + assert db.document("person/3")["club"] == "bar" + + for k in G_1: + assert G_1.nodes[k] == db.document(k) + + for v in G_1.nodes.values(): + assert v + + G_1.clear() + + assert not G_1._node.data + + for k, v in G_1.nodes.items(): + assert k == v["_id"] + + with pytest.raises(KeyError): + G_1.nodes["person/unknown"] + + assert G_1.nodes["person/1"]["club"] == "Mr. Hi" + G_1.add_node("person/1", club="updated value") + assert G_1.nodes["person/1"]["club"] == "updated value" + len(G_1.nodes) == len(G_2.nodes) + + G_1.add_node("person/35", foo={"bar": "baz"}) + len(G_1.nodes) == len(G_2.nodes) + 1 + G_1.clear() + assert G_1.nodes["person/35"]["foo"] == {"bar": "baz"} + # TODO: Support this use case: + # G_1.nodes["person/35"]["foo"]["bar"] = "baz2" + + G_1.add_nodes_from(["1", "2", "3"], foo="bar") + G_1.clear() + assert G_1.nodes["1"]["foo"] == "bar" + assert G_1.nodes["2"]["foo"] == "bar" + assert G_1.nodes["3"]["foo"] == "bar" + + assert db.collection(G_1.default_node_type).count() == 3 + assert db.collection(G_1.default_node_type).has("1") + assert db.collection(G_1.default_node_type).has("2") + assert db.collection(G_1.default_node_type).has("3") + + G_1.remove_node("1") + assert not db.collection(G_1.default_node_type).has("1") + with pytest.raises(KeyError): + G_1.nodes["1"] + + with pytest.raises(KeyError): + G_1.adj["1"] + + G_1.remove_nodes_from(["2", "3"]) + assert not db.collection(G_1.default_node_type).has("2") + assert not db.collection(G_1.default_node_type).has("3") + + with pytest.raises(KeyError): + G_1.nodes["2"] + + with pytest.raises(KeyError): + G_1.adj["2"] + + assert len(G_1.adj["person/1"]) > 0 + assert G_1.adj["person/1"]["person/2"] + edge_id = G_1.adj["person/1"]["person/2"]["_id"] + G_1.remove_node("person/1") + assert not db.has_document("person/1") + assert not db.has_document(edge_id) + + +def test_graph_edges_crud(load_graph): + G_1 = nxadb.Graph(graph_name="KarateGraph") + G_2 = nx.karate_club_graph() + + assert len(G_1.adj) == len(G_2.adj) + assert len(G_1.edges) == len(G_2.edges) + + for src, dst, w in G_1.edges.data("weight"): + assert G_1.adj[src][dst]["weight"] == w + + for src, dst, w in G_1.edges.data("bad_key", default="boom!"): + assert "bad_key" not in G_1.adj[src][dst] + assert w == "boom!" + + for k, edge in G_1.adj["person/1"].items(): + assert db.has_document(k) + assert db.has_document(edge["_id"]) + + G_1.add_edge("person/1", "person/1", foo="bar", _edge_type="knows") + edge_id = G_1.adj["person/1"]["person/1"]["_id"] + doc = db.document(edge_id) + assert doc["foo"] == "bar" + assert G_1.adj["person/1"]["person/1"]["foo"] == "bar" + + del G_1.adj["person/1"]["person/1"]["foo"] + doc = db.document(edge_id) + assert "foo" not in doc + + G_1.adj["person/1"]["person/1"].update({"bar": "foo"}) + doc = db.document(edge_id) + assert doc["bar"] == "foo" + + assert len(G_1.adj["person/1"]["person/1"]) == len(doc) + adj_count = len(G_1.adj["person/1"]) + G_1.remove_edge("person/1", "person/1") + assert len(G_1.adj["person/1"]) == adj_count - 1 + assert not db.has_document(edge_id) + assert "person/1" in G_1 + + assert not db.has_document(f"{G_1.default_node_type}/new_node_1") + col_count = db.collection(G_1.default_edge_type).count() + + G_1.add_edge("new_node_1", "new_node_2", foo="bar") + G_1.add_edge("new_node_1", "new_node_2", foo="bar", bar="foo") + + bind_vars = { + "src": f"{G_1.default_node_type}/new_node_1", + "dst": f"{G_1.default_node_type}/new_node_2", + } + + result = list( + db.aql.execute( + f"FOR e IN {G_1.default_edge_type} FILTER e._from == @src AND e._to == @dst RETURN e", + bind_vars=bind_vars, + ) + ) + + assert len(result) == 1 + + result = list( + db.aql.execute( + f"FOR e IN {G_1.default_edge_type} FILTER e._from == @dst AND e._to == @src RETURN e", + bind_vars=bind_vars, + ) + ) + + assert len(result) == 0 + + assert db.collection(G_1.default_edge_type).count() == col_count + 1 + assert G_1.adj["new_node_1"]["new_node_2"] + assert G_1.adj["new_node_1"]["new_node_2"]["foo"] == "bar" + assert G_1.adj["new_node_2"]["new_node_1"] + assert ( + G_1.adj["new_node_2"]["new_node_1"]["_id"] + == G_1.adj["new_node_1"]["new_node_2"]["_id"] + ) + edge_id = G_1.adj["new_node_1"]["new_node_2"]["_id"] + doc = db.document(edge_id) + assert db.has_document(doc["_from"]) + assert db.has_document(doc["_to"]) + assert G_1.nodes["new_node_1"] + assert G_1.nodes["new_node_2"] + + G_1.remove_edge("new_node_1", "new_node_2") + G_1.clear() + assert "new_node_1" in G_1 + assert "new_node_2" in G_1 + assert "new_node_2" not in G_1.adj["new_node_1"] + + G_1.add_edges_from( + [("new_node_1", "new_node_2"), ("new_node_1", "new_node_3")], foo="bar" + ) + G_1.clear() + assert "new_node_1" in G_1 + assert "new_node_2" in G_1 + assert "new_node_3" in G_1 + assert G_1.adj["new_node_1"]["new_node_2"]["foo"] == "bar" + assert G_1.adj["new_node_1"]["new_node_3"]["foo"] == "bar" + + G_1.remove_edges_from([("new_node_1", "new_node_2"), ("new_node_1", "new_node_3")]) + assert "new_node_1" in G_1 + assert "new_node_2" in G_1 + assert "new_node_3" in G_1 + assert "new_node_2" not in G_1.adj["new_node_1"] + assert "new_node_3" not in G_1.adj["new_node_1"] + + assert G_1["person/1"]["person/2"] == G_1["person/2"]["person/1"] + new_weight = 1000 + G_1["person/1"]["person/2"]["weight"] = new_weight + assert G_1["person/1"]["person/2"]["weight"] == new_weight + assert G_1["person/2"]["person/1"]["weight"] == new_weight + G_1.clear() + assert G_1["person/1"]["person/2"]["weight"] == new_weight + assert G_1["person/2"]["person/1"]["weight"] == new_weight + + +def test_readme(load_graph): + G = nxadb.Graph(graph_name="KarateGraph") + G_nx = nx.karate_club_graph() + + assert len(G.nodes) == len(G_nx.nodes) + assert len(G.adj) == len(G_nx.adj) + assert len(G.edges) == len(G_nx.edges) + + G.nodes(data="club", default="unknown") + G.edges(data="weight", default=1000) + + G.nodes["person/1"] + G.adj["person/1"] + G.edges[("person/1", "person/3")] + + G.nodes["person/1"]["name"] = "John Doe" + G.nodes["person/1"].update({"age": 40}) + del G.nodes["person/1"]["name"] + + G.adj["person/1"]["person/3"]["weight"] = 2 + G.adj["person/1"]["person/3"].update({"weight": 3}) + del G.adj["person/1"]["person/3"]["weight"] + + G.edges[("person/1", "person/3")]["weight"] = 0.5 + assert G.adj["person/1"]["person/3"]["weight"] == 0.5 + + G.add_node("person/35", name="Jane Doe") + G.add_nodes_from( + [("person/36", {"name": "Jack Doe"}), ("person/37", {"name": "Jill Doe"})] + ) + G.add_edge("person/1", "person/35", weight=1.5, _edge_type="knows") + G.add_edges_from( + [ + ("person/1", "person/36", {"weight": 2}), + ("person/1", "person/37", {"weight": 3}), + ], + _edge_type="knows", + ) + + G.remove_edge("person/1", "person/35") + G.remove_edges_from([("person/1", "person/36"), ("person/1", "person/37")]) + G.remove_node("person/35") + G.remove_nodes_from(["person/36", "person/37"]) + + G.clear() + + assert len(G.nodes) == len(G_nx.nodes) + assert len(G.adj) == len(G_nx.adj) + assert len(G.edges) == len(G_nx.edges) + + +def test_digraph_nodes_crud(): + pytest.skip("Not implemented yet") + + +def test_digraph_edges_crud(): + pytest.skip("Not implemented yet")