diff --git a/_nx_arangodb/__init__.py b/_nx_arangodb/__init__.py index ff11a6d0..4d7a2de5 100644 --- a/_nx_arangodb/__init__.py +++ b/_nx_arangodb/__init__.py @@ -11,6 +11,8 @@ $ python _nx_arangodb/__init__.py """ +import networkx as nx + from _nx_arangodb._version import __version__ # This is normally handled by packaging.version.Version, but instead of adding @@ -28,14 +30,7 @@ # "description": "TODO", "functions": { # BEGIN: functions - "betweenness_centrality", - "is_partition", - "louvain_communities", - "louvain_partitions", - "modularity", - "pagerank", "shortest_path", - "to_scipy_sparse_array", # END: functions }, "additional_docs": { @@ -45,27 +40,9 @@ }, "additional_parameters": { # BEGIN: additional_parameters - "is_partition": { - "dtype : dtype or None, optional": "The data type (np.float32, np.float64, or None) to use for the edge weights in the algorithm. If None, then dtype is determined by the edge values.", - }, - "louvain_communities": { - "dtype : dtype or None, optional": "The data type (np.float32, np.float64, or None) to use for the edge weights in the algorithm. If None, then dtype is determined by the edge values.", - }, - "louvain_partitions": { - "dtype : dtype or None, optional": "The data type (np.float32, np.float64, or None) to use for the edge weights in the algorithm. If None, then dtype is determined by the edge values.", - }, - "modularity": { - "dtype : dtype or None, optional": "The data type (np.float32, np.float64, or None) to use for the edge weights in the algorithm. If None, then dtype is determined by the edge values.", - }, - "pagerank": { - "dtype : dtype or None, optional": "The data type (np.float32, np.float64, or None) to use for the edge weights in the algorithm. If None, then dtype is determined by the edge values.", - }, "shortest_path": { "dtype : dtype or None, optional": "The data type (np.float32, np.float64, or None) to use for the edge weights in the algorithm. If None, then dtype is determined by the edge values.", }, - "to_scipy_sparse_array": { - "dtype : dtype or None, optional": "The data type (np.float32, np.float64, or None) to use for the edge weights in the algorithm. If None, then dtype is determined by the edge values.", - }, # END: additional_parameters }, } @@ -96,6 +73,17 @@ def get_info(): for key in info_keys: del d[key] + + d["default_config"] = { + "host": None, + "username": None, + "password": None, + "db_name": None, + "load_parallelism": None, + "load_batch_size": None, + "pull_graph": True, + } + return d diff --git a/nx_arangodb/algorithms/__init__.py b/nx_arangodb/algorithms/__init__.py index 00c2d294..ee55416c 100644 --- a/nx_arangodb/algorithms/__init__.py +++ b/nx_arangodb/algorithms/__init__.py @@ -1,5 +1,2 @@ -from . import centrality, community, link_analysis, shortest_paths -from .centrality import * -from .community import * -from .link_analysis import * +from . import shortest_paths from .shortest_paths import * diff --git a/nx_arangodb/algorithms/centrality/__init__.py b/nx_arangodb/algorithms/centrality/__init__.py deleted file mode 100644 index cf7adb68..00000000 --- a/nx_arangodb/algorithms/centrality/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .betweenness import * diff --git a/nx_arangodb/algorithms/centrality/betweenness.py b/nx_arangodb/algorithms/centrality/betweenness.py deleted file mode 100644 index a949fbd1..00000000 --- a/nx_arangodb/algorithms/centrality/betweenness.py +++ /dev/null @@ -1,51 +0,0 @@ -# type: ignore -# NOTE: NetworkX algorithms are not typed - -import networkx as nx - -from nx_arangodb.convert import _to_nx_graph, _to_nxcg_graph -from nx_arangodb.logger import logger -from nx_arangodb.utils import networkx_algorithm - -try: - import nx_cugraph as nxcg - - GPU_ENABLED = True -except ModuleNotFoundError: - GPU_ENABLED = False - - -__all__ = ["betweenness_centrality"] - - -@networkx_algorithm( - is_incomplete=True, - is_different=True, - version_added="23.10", - _plc="betweenness_centrality", -) -def betweenness_centrality( - G, - k=None, - normalized=True, - weight=None, - endpoints=False, - seed=None, - run_on_gpu=True, - pull_graph_on_cpu=True, -): - logger.debug(f"nxadb.betweenness_centrality for {G.__class__.__name__}") - - if GPU_ENABLED and run_on_gpu: - G = _to_nxcg_graph(G, weight) - - logger.debug("using nxcg.betweenness_centrality") - print("Running nxcg.betweenness_centrality()") - return nxcg.betweenness_centrality(G, k=k, normalized=normalized, weight=weight) - - G = _to_nx_graph(G, pull_graph=pull_graph_on_cpu) - - logger.debug("using nx.betweenness_centrality") - return nx.betweenness_centrality.orig_func( - G, k=k, normalized=normalized, weight=weight, endpoints=endpoints, seed=seed - ) diff --git a/nx_arangodb/algorithms/community/__init__.py b/nx_arangodb/algorithms/community/__init__.py deleted file mode 100644 index 5b43a3e4..00000000 --- a/nx_arangodb/algorithms/community/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .louvain import * diff --git a/nx_arangodb/algorithms/community/louvain.py b/nx_arangodb/algorithms/community/louvain.py deleted file mode 100644 index 2c708863..00000000 --- a/nx_arangodb/algorithms/community/louvain.py +++ /dev/null @@ -1,110 +0,0 @@ -# type: ignore -# NOTE: NetworkX algorithms are not typed - -from collections import deque - -import networkx as nx - -from nx_arangodb.convert import _to_nx_graph, _to_nxcg_graph -from nx_arangodb.logger import logger -from nx_arangodb.utils import _dtype_param, networkx_algorithm - -try: - import nx_cugraph as nxcg - - GPU_ENABLED = True -except ModuleNotFoundError: - GPU_ENABLED = False - - -@networkx_algorithm( - extra_params={ - **_dtype_param, - }, - is_incomplete=True, # seed not supported; self-loops not supported - is_different=True, # RNG different - version_added="23.10", - _plc="louvain", - name="louvain_communities", -) -def louvain_communities( - G, - weight="weight", - resolution=1, - threshold=0.0000001, - max_level=None, - seed=None, - run_on_gpu=True, - pull_graph_on_cpu=True, -): - logger.debug(f"nxadb.louvain_communities for {G.__class__.__name__}") - - if GPU_ENABLED and run_on_gpu: - G = _to_nxcg_graph(G, weight) - - logger.debug("using nxcg.louvain_communities") - print("Running nxcg.louvain_communities()") - return nxcg.algorithms.community.louvain._louvain_communities( - G, - weight=weight, - resolution=resolution, - threshold=threshold, - max_level=max_level, - seed=seed, - ) - - G = _to_nx_graph(G, pull_graph=pull_graph_on_cpu) - - logger.debug("using nx.louvain_communities") - return nx.community.louvain_communities.orig_func( - G, - weight=weight, - resolution=resolution, - threshold=threshold, - max_level=max_level, - seed=seed, - ) - - -@networkx_algorithm( - extra_params={ - **_dtype_param, - }, - is_incomplete=True, # seed not supported; self-loops not supported - is_different=True, # RNG different - version_added="23.10", - _plc="louvain", - name="louvain_partitions", -) -def louvain_partitions( - G, weight="weight", resolution=1, threshold=0.0000001, seed=None -): - return nx.community.louvain_partitions.orig_func( - G, weight=weight, resolution=resolution, threshold=threshold, seed=seed - ) - - -@networkx_algorithm( - extra_params={ - **_dtype_param, - }, - is_incomplete=True, # seed not supported; self-loops not supported - is_different=True, # RNG different - version_added="23.10", -) -def modularity(G, communities, weight="weight", resolution=1): - return nx.community.modularity.orig_func( - G, communities, weight=weight, resolution=resolution - ) - - -@networkx_algorithm( - extra_params={ - **_dtype_param, - }, - is_incomplete=True, # seed not supported; self-loops not supported - is_different=True, # RNG different - version_added="23.10", -) -def is_partition(G, communities): - return nx.community.is_partition.orig_func(G, communities) diff --git a/nx_arangodb/algorithms/link_analysis/__init__.py b/nx_arangodb/algorithms/link_analysis/__init__.py deleted file mode 100644 index 7e957e4f..00000000 --- a/nx_arangodb/algorithms/link_analysis/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .pagerank_alg import * diff --git a/nx_arangodb/algorithms/link_analysis/pagerank_alg.py b/nx_arangodb/algorithms/link_analysis/pagerank_alg.py deleted file mode 100644 index a4da41c9..00000000 --- a/nx_arangodb/algorithms/link_analysis/pagerank_alg.py +++ /dev/null @@ -1,77 +0,0 @@ -# type: ignore -# NOTE: NetworkX algorithms are not typed - -import networkx as nx - -from nx_arangodb.convert import _to_nx_graph, _to_nxcg_graph -from nx_arangodb.logger import logger -from nx_arangodb.utils import _dtype_param, networkx_algorithm - -try: - import nx_cugraph as nxcg - - GPU_ENABLED = True -except ModuleNotFoundError: - GPU_ENABLED = False - - -@networkx_algorithm( - extra_params=_dtype_param, - is_incomplete=True, # dangling not supported - version_added="23.12", - _plc={"pagerank", "personalized_pagerank"}, -) -def pagerank( - G, - alpha=0.85, - personalization=None, - max_iter=100, - tol=1.0e-6, - nstart=None, - weight="weight", - dangling=None, - *, - dtype=None, - run_on_gpu=True, - pull_graph_on_cpu=True, -): - logger.debug(f"nxadb.pagerank for {G.__class__.__name__}") - - if GPU_ENABLED and run_on_gpu: - G = _to_nxcg_graph(G, weight) - - logger.debug("using nxcg.pagerank") - print("Running nxcg.pagerank()") - return nxcg.pagerank( - G, - alpha=alpha, - personalization=personalization, - max_iter=max_iter, - tol=tol, - nstart=nstart, - weight=weight, - dangling=dangling, - dtype=dtype, - ) - - G = _to_nx_graph(G, pull_graph=pull_graph_on_cpu) - - logger.debug("using nx.pagerank") - return nx.algorithms.pagerank.orig_func( - G, - alpha=alpha, - personalization=personalization, - max_iter=max_iter, - tol=tol, - nstart=nstart, - weight=weight, - dangling=dangling, - ) - - -@networkx_algorithm( - extra_params=_dtype_param, - version_added="23.12", -) -def to_scipy_sparse_array(G, nodelist=None, dtype=None, weight="weight", format="csr"): - return nx.to_scipy_sparse_array.orig_func(G, nodelist, dtype, weight, format) diff --git a/nx_arangodb/algorithms/shortest_paths/generic.py b/nx_arangodb/algorithms/shortest_paths/generic.py index dbc578c0..ba7fab2d 100644 --- a/nx_arangodb/algorithms/shortest_paths/generic.py +++ b/nx_arangodb/algorithms/shortest_paths/generic.py @@ -30,10 +30,10 @@ def shortest_path( ) if target is None or source is None: - raise ShortestPathError("Both source and target must be specified for now") + raise NotImplementedError("Both source and target must be specified for now") if method != "dijkstra": - raise ShortestPathError("Only dijkstra method is supported") + raise NotImplementedError("Only dijkstra method is supported") query = """ FOR vertex IN ANY SHORTEST_PATH @source TO @target GRAPH @graph diff --git a/nx_arangodb/classes/dict.py b/nx_arangodb/classes/dict.py index d15bfcc5..ce2f7060 100644 --- a/nx_arangodb/classes/dict.py +++ b/nx_arangodb/classes/dict.py @@ -529,13 +529,19 @@ def items(self, data: str | None = None, default: Any | None = None) -> Any: def __fetch_all(self): self.clear() - node_dict, _, _, _, _ = get_arangodb_graph( + ( + node_dict, + *_, + ) = get_arangodb_graph( self.graph, load_node_dict=True, load_adj_dict=False, + load_coo=False, + load_all_vertex_attributes=True, + load_all_edge_attributes=False, # not used is_directed=False, # not used is_multigraph=False, # not used - load_coo=False, + symmetrize_edges_if_directed=False, # not used ) for node_id, node_data in node_dict.items(): @@ -1174,13 +1180,20 @@ def items(self, data: str | None = None, default: Any | None = None) -> Any: def __fetch_all(self) -> None: self.clear() - _, adj_dict, _, _, _ = get_arangodb_graph( + ( + _, + adj_dict, + *_, + ) = get_arangodb_graph( self.graph, load_node_dict=False, load_adj_dict=True, + load_coo=False, + load_all_vertex_attributes=False, # not used + load_all_edge_attributes=True, is_directed=False, # TODO: Abstract based on Graph type is_multigraph=False, # TODO: Abstract based on Graph type - load_coo=False, + symmetrize_edges_if_directed=False, # TODO: Abstract based on Graph type ) for src_node_id, inner_dict in adj_dict.items(): diff --git a/nx_arangodb/classes/digraph.py b/nx_arangodb/classes/digraph.py index e9cedc73..8f491c93 100644 --- a/nx_arangodb/classes/digraph.py +++ b/nx_arangodb/classes/digraph.py @@ -31,6 +31,7 @@ def __init__( graph_name: str | None = None, # default_node_type: str = "node", # edge_type_func: Callable[[str, str], str] = lambda u, v: f"{u}_to_{v}", + symmetrize_edges: bool = False, *args: Any, **kwargs: Any, ): @@ -62,6 +63,7 @@ def __init__( self.src_indices: npt.NDArray[np.int64] | None = None self.dst_indices: npt.NDArray[np.int64] | None = None + self.edge_indices: npt.NDArray[np.int64] | None = None self.vertex_ids_to_index: dict[str, int] | None = None # self.default_node_type = default_node_type @@ -75,6 +77,8 @@ def __init__( super().__init__(*args, **kwargs) + self.symmetrize_edges = symmetrize_edges + ########### # Getters # ########### diff --git a/nx_arangodb/classes/function.py b/nx_arangodb/classes/function.py index 44a2ef98..d00e35be 100644 --- a/nx_arangodb/classes/function.py +++ b/nx_arangodb/classes/function.py @@ -15,6 +15,13 @@ from arango.cursor import Cursor from arango.database import StandardDatabase from arango.graph import Graph +from phenolrs.networkx import NetworkXLoader +from phenolrs.networkx.typings import ( + DiGraphAdj, + GraphAdj, + MultiDiGraphAdj, + MultiGraphAdj, +) import nx_arangodb as nxadb from nx_arangodb.logger import logger @@ -30,12 +37,16 @@ def get_arangodb_graph( adb_graph: Graph, load_node_dict: bool, load_adj_dict: bool, + load_coo: bool, + load_all_vertex_attributes: bool, + load_all_edge_attributes: bool, is_directed: bool, is_multigraph: bool, - load_coo: bool, + symmetrize_edges_if_directed: bool, ) -> Tuple[ dict[str, dict[str, Any]], - dict[str, dict[str, dict[str, Any]]], + GraphAdj | DiGraphAdj | MultiGraphAdj | MultiDiGraphAdj, + npt.NDArray[np.int64], npt.NDArray[np.int64], npt.NDArray[np.int64], dict[str, int], @@ -54,8 +65,8 @@ def get_arangodb_graph( e_cols = {c["edge_collection"] for c in edge_definitions} metagraph: dict[str, dict[str, Any]] = { - "vertexCollections": {col: {} for col in v_cols}, - "edgeCollections": {col: {} for col in e_cols}, + "vertexCollections": {col: set() for col in v_cols}, + "edgeCollections": {col: set() for col in e_cols}, } if not any((load_node_dict, load_adj_dict, load_coo)): @@ -80,20 +91,31 @@ def get_arangodb_graph( assert config.username assert config.password - from phenolrs.networkx_loader import NetworkXLoader - - # TODO: Remove ignore when phenolrs is published - return NetworkXLoader.load_into_networkx( # type: ignore - config.db_name, - metagraph=metagraph, - hosts=[config.host], - username=config.username, - password=config.password, - load_adj_dict=load_adj_dict, - is_directed=is_directed, - is_multigraph=is_multigraph, - load_coo=load_coo, - **kwargs, + node_dict, adj_dict, src_indices, dst_indices, edge_indices, vertex_ids_to_index = ( + NetworkXLoader.load_into_networkx( + config.db_name, + metagraph=metagraph, + hosts=[config.host], + username=config.username, + password=config.password, + load_adj_dict=load_adj_dict, + load_coo=load_coo, + load_all_vertex_attributes=load_all_vertex_attributes, + load_all_edge_attributes=load_all_edge_attributes, + is_directed=is_directed, + is_multigraph=is_multigraph, + symmetrize_edges_if_directed=symmetrize_edges_if_directed, + **kwargs, + ) + ) + + return ( + node_dict, + adj_dict, + src_indices, + dst_indices, + edge_indices, + vertex_ids_to_index, ) diff --git a/nx_arangodb/classes/graph.py b/nx_arangodb/classes/graph.py index 40f65637..168e521a 100644 --- a/nx_arangodb/classes/graph.py +++ b/nx_arangodb/classes/graph.py @@ -69,8 +69,11 @@ def __init__( self.src_indices: npt.NDArray[np.int64] | None = None self.dst_indices: npt.NDArray[np.int64] | None = None + self.edge_indices: npt.NDArray[np.int64] | None = None self.vertex_ids_to_index: dict[str, int] | None = None + self.symmetrize_edges = False + self.default_node_type = default_node_type self.edge_type_func = edge_type_func self.default_edge_type = edge_type_func(default_node_type, default_node_type) @@ -130,16 +133,13 @@ def __set_arangodb_backend_config(self) -> None: m = "Must set all environment variables to use the ArangoDB Backend with an existing graph" # noqa: E501 raise OSError(m) - config = Config( - host=self._host, - username=self._username, - password=self._password, - db_name=self._db_name, - load_parallelism=self.graph_loader_parallelism, - load_batch_size=self.graph_loader_batch_size, - ) - - nx.config.backends.arangodb = config + config = nx.config.backends.arangodb + config.host = self._host + config.username = self._username + config.password = self._password + config.db_name = self._db_name + config.load_parallelism = self.graph_loader_parallelism + config.load_batch_size = self.graph_loader_batch_size def __set_factory_methods(self) -> None: """Set the factory methods for the graph, _node, and _adj dictionaries. @@ -331,3 +331,29 @@ def add_node(self, node_for_adding, **attr): self._node[node_for_adding].update(attr) nx._clear_cache(self) + + def number_of_edges(self, u=None, v=None): + if u is None: + ###################### + # NOTE: monkey patch # + ###################### + + # Old: + # return int(self.size()) + + # New: + edge_collections = { + e_d["edge_collection"] for e_d in self.adb_graph.edge_definitions() + } + num = sum( + self.adb_graph.edge_collection(e).count() for e in edge_collections + ) + num *= 2 if self.is_directed() and self.symmetrize_edges else 1 + + return num + + # Reason: + # It is more efficient to count the number of edges in the edge collections + # compared to relying on the DegreeView. + + super().number_of_edges(u, v) diff --git a/nx_arangodb/convert.py b/nx_arangodb/convert.py index b395cafa..a67f570c 100644 --- a/nx_arangodb/convert.py +++ b/nx_arangodb/convert.py @@ -1,8 +1,7 @@ from __future__ import annotations -import itertools import time -from typing import TYPE_CHECKING, Any +from typing import Any import networkx as nx @@ -20,49 +19,79 @@ GPU_ENABLED = False logger.info(f"NXCG is disabled. {e}.") - -if TYPE_CHECKING: # pragma: no cover - from nx_arangodb.typing import AttrKey, Dtype, EdgeValue, NodeValue - __all__ = [ - "from_networkx", - "to_networkx", + "_to_nx_graph", + "_to_nxadb_graph", + "_to_nxcg_graph", ] -concat = itertools.chain.from_iterable -# A "required" attribute is one that all edges or nodes must have or KeyError is raised -REQUIRED = ... +def _to_nx_graph( + G: Any, *args: Any, pull_graph: bool = True, **kwargs: Any +) -> nx.Graph: + logger.debug(f"_to_nx_graph for {G.__class__.__name__}") + + if isinstance(G, nxadb.Graph | nxadb.DiGraph): + return nxadb_to_nx(G, pull_graph) + + if isinstance(G, nx.Graph): + return G + + raise TypeError(f"Expected nxadb.Graph or nx.Graph; got {type(G)}") + + +def _to_nxadb_graph( + G: Any, + *args: Any, + as_directed: bool = False, + **kwargs: Any, +) -> nxadb.Graph: + logger.debug(f"_to_nxadb_graph for {G.__class__.__name__}") + + if isinstance(G, nxadb.Graph): + return G -def from_networkx( + if isinstance(G, nx.Graph): + return nx_to_nxadb(G, as_directed=as_directed) + + raise TypeError(f"Expected nxadb.Graph or nx.Graph; got {type(G)}") + + +if GPU_ENABLED: + + def _to_nxcg_graph(G: Any, as_directed: bool = False) -> nxcg.Graph: + logger.debug(f"_to_nxcg_graph for {G.__class__.__name__}") + + if isinstance(G, nxcg.Graph): + return G + + if isinstance(G, nxadb.Graph): + if not G.graph_exists_in_db: + m = "nx_arangodb.Graph does not exist in ArangoDB. Cannot pull graph." + raise ValueError(m) + + logger.debug("converting nx_arangodb graph to nx_cugraph graph") + return nxadb_to_nxcg(G, as_directed=as_directed) + + raise TypeError(f"Expected nx_arangodb.Graph or nxcg.Graph; got {type(G)}") + +else: + + def _to_nxcg_graph(G: Any, as_directed: bool = False) -> nxcg.Graph: + m = "nx-cugraph is not installed; cannot convert to nx-cugraph" + raise NotImplementedError(m) + + +def nx_to_nxadb( graph: nx.Graph, *args: Any, as_directed: bool = False, **kwargs: Any, # name: str | None = None, # graph_name: str | None = None, -) -> nxadb.Graph | nxadb.DiGraph: - """Convert a networkx graph to nx_arangodb graph. - - Parameters - ---------- - G : networkx.Graph - - See Also - -------- - to_networkx : The opposite; convert nx_arangodb graph to networkx graph - """ +) -> nxadb.Graph: logger.debug(f"from_networkx for {graph.__class__.__name__}") - if not isinstance(graph, nx.Graph): - if isinstance(graph, nx.classes.reportviews.NodeView): - # Convert to a Graph with only nodes (no edges) - G = nx.Graph() - G.add_nodes_from(graph.items()) - graph = G - else: - raise TypeError(f"Expected networkx.Graph; got {type(graph)}") - if graph.is_multigraph(): if graph.is_directed() or as_directed: klass = nxadb.MultiDiGraph @@ -78,46 +107,12 @@ def from_networkx( return klass(incoming_graph_data=graph) -def to_networkx(G: nxadb.Graph, *args: Any, **kwargs: Any) -> nx.Graph: - """Convert a nx_arangodb graph to networkx graph. - - All edge and node attributes and ``G.graph`` properties are converted. - - TEMPORARY ASSUMPTION: The nx_arangodb Graph is a subclass of networkx Graph. - Therefore, I'm going to assume that we _should_ be able instantiate an - nx Graph using the **incoming_graph_data** parameter. - - Parameters - ---------- - G : nx_arangodb.Graph - - Returns - ------- - networkx.Graph - - See Also - -------- - from_networkx : The opposite; convert networkx graph to nx_cugraph graph - """ - logger.debug(f"to_networkx for {G.__class__.__name__}") - - if not isinstance(G, nxadb.Graph): - raise TypeError(f"Expected nx_arangodb.Graph; got {type(G)}") - - return G.to_networkx_class()(incoming_graph_data=G) - - -def from_networkx_arangodb( - G: nxadb.Graph | nxadb.DiGraph, pull_graph: bool -) -> nx.Graph | nx.DiGraph: - logger.debug(f"from_networkx_arangodb for {G.__class__.__name__}") - - if not isinstance(G, (nxadb.Graph, nxadb.DiGraph)): - raise TypeError(f"Expected nx_arangodb.(Graph || DiGraph); got {type(G)}") - +def nxadb_to_nx(G: nxadb.Graph, pull_graph: bool) -> nx.Graph: if not G.graph_exists_in_db: logger.debug("graph does not exist, nothing to pull") - return G + # TODO: Consider just returning G here? + # Avoids the need to re-create the graph from scratch + return G.to_networkx_class()(incoming_graph_data=G) if not pull_graph: if isinstance(G, nxadb.DiGraph): @@ -127,101 +122,45 @@ def from_networkx_arangodb( logger.debug("graph exists, but not pulling. relying on remote connection...") return G + # TODO: Re-enable this # if G.use_nx_cache and G._node and G._adj: # m = "**use_nx_cache** is enabled. using cached data. no pull required." # logger.debug(m) # return G - logger.debug("pulling as NetworkX Graph...") - print(f"Fetching {G.graph_name} as dictionaries...") start_time = time.time() - _, adj_dict, _, _, _ = nxadb.classes.function.get_arangodb_graph( + + node_dict, adj_dict, *_ = nxadb.classes.function.get_arangodb_graph( adb_graph=G.adb_graph, - load_node_dict=False, # TODO: Should we load node dict? + load_node_dict=True, load_adj_dict=True, + load_coo=False, + load_all_vertex_attributes=False, + # TODO: Only return the edge attributes that are needed + load_all_edge_attributes=True, is_directed=G.is_directed(), is_multigraph=G.is_multigraph(), - load_coo=False, + symmetrize_edges_if_directed=G.symmetrize_edges if G.is_directed() else False, ) - end_time = time.time() - logger.debug(f"load took {end_time - start_time} seconds") - print(f"ADB -> Dictionaries load took {end_time - start_time} seconds") - - return G.to_networkx_class()(incoming_graph_data=adj_dict) - - # try: - # logger.debug("creating nx graph from loaded ArangoDB data...") - # print("Creating nx graph from loaded ArangoDB data...") - # start_time = time.time() - # result: nx.Graph = nx.convert.from_dict_of_dicts( - # adj_dict, - # create_using=G.to_networkx_class(), - # multigraph_input=G.is_multigraph(), - # ) - # for n, dd in node_dict.items(): - # result._node[n].update(dd) - # end_time = time.time() - # print(f"NX Graph creation took {end_time - start_time}") + print(f"ADB -> Dictionaries load took {time.time() - start_time}s") - # return result + G_NX: nx.Graph | nx.DiGraph = G.to_networkx_class()() + G_NX._node = node_dict - # except Exception as err: - # raise nx.NetworkXError("Input is not a correct NetworkX graph.") from err + if isinstance(G_NX, nx.DiGraph): + G_NX._succ = G._adj = adj_dict["succ"] + G_NX._pred = adj_dict["pred"] + else: + G_NX._adj = adj_dict -def _to_nx_graph( - G: Any, - pull_graph: bool = True, -) -> nx.Graph | nx.DiGraph: - """Ensure that input type is an nx graph, and convert if necessary.""" - logger.debug(f"_to_nx_graph for {G.__class__.__name__}") - - if isinstance(G, (nxadb.Graph, nxadb.DiGraph)): - return from_networkx_arangodb(G, pull_graph) - - if isinstance(G, nx.Graph): - return G - - raise TypeError(f"Expected nx_arangodb.Graph or nx.Graph; got {type(G)}") + return G_NX if GPU_ENABLED: - def _to_nxcg_graph(G: Any, as_directed: bool = False) -> nxcg.Graph | nxcg.DiGraph: - """Ensure that input type is a nx_cugraph graph, and convert if necessary.""" - logger.debug(f"_to_nxcg_graph for {G.__class__.__name__}") - - if isinstance(G, nxcg.Graph): - logger.debug("already an nx_cugraph graph") - return G - - if isinstance(G, (nxadb.Graph, nxadb.DiGraph)): - # Assumption: G.adb_graph_name points to an existing graph in ArangoDB - # Therefore, the user wants us to pull the graph from ArangoDB, - # and convert it to an nx_cugraph graph. - # We currently accomplish this by using the NetworkX adapter for ArangoDB, - # which converts the ArangoDB graph to a NetworkX graph, and then we convert - # the NetworkX graph to an nx_cugraph graph. - # TODO: Implement a direct conversion from ArangoDB to nx_cugraph - if G.graph_exists_in_db: - logger.debug("converting nx_arangodb graph to nx_cugraph graph") - return nxcg_from_networkx_arangodb(G, as_directed=as_directed) - - if isinstance(G, (nxadb.MultiGraph, nxadb.MultiDiGraph)): - raise NotImplementedError( - "nxadb.MultiGraph not yet supported for _to_nxcg_graph()" - ) - - # TODO: handle cugraph.Graph - raise TypeError(f"Expected nx_arangodb.Graph or nx.Graph; got {type(G)}") - - def nxcg_from_networkx_arangodb( - G: nxadb.Graph | nxadb.DiGraph, as_directed: bool = False - ) -> nxcg.Graph | nxcg.DiGraph: - """Convert an nx_arangodb graph to nx_cugraph graph.""" - logger.debug(f"nxcg_from_networkx_arangodb for {G.__class__.__name__}") - + def nxadb_to_nxcg(G: nxadb.Graph, as_directed: bool = False) -> nxcg.Graph: if G.is_multigraph(): raise NotImplementedError("Multigraphs not yet supported") @@ -229,64 +168,75 @@ def nxcg_from_networkx_arangodb( G.use_coo_cache and G.src_indices is not None and G.dst_indices is not None + and G.edge_indices is not None and G.vertex_ids_to_index is not None ): m = "**use_coo_cache** is enabled. using cached COO data. no pull required." logger.debug(m) else: - logger.debug("pulling as NetworkX-CuGraph Graph...") - print(f"Fetching {G.graph_name} as COO...") start_time = time.time() - _, _, src_indices, dst_indices, vertex_ids_to_index = ( + + _, _, src_indices, dst_indices, edge_indices, vertex_ids_to_index = ( nxadb.classes.function.get_arangodb_graph( adb_graph=G.adb_graph, load_node_dict=False, load_adj_dict=False, - is_directed=G.is_directed(), # not used - is_multigraph=G.is_multigraph(), # not used load_coo=True, + load_all_vertex_attributes=False, # not used + load_all_edge_attributes=False, # not used + is_directed=G.is_directed(), + is_multigraph=G.is_multigraph(), + symmetrize_edges_if_directed=( + G.symmetrize_edges if G.is_directed() else False + ), ) ) - end_time = time.time() - logger.debug(f"load took {end_time - start_time} seconds") - print(f"ADB -> COO load took {end_time - start_time} seconds") + + print(f"ADB -> COO load took {time.time() - start_time}s") G.src_indices = src_indices G.dst_indices = dst_indices + G.edge_indices = edge_indices G.vertex_ids_to_index = vertex_ids_to_index N = len(G.vertex_ids_to_index) - - if G.is_directed() or as_directed: - klass = nxcg.DiGraph - else: - klass = nxcg.Graph - - start_time = time.time() - print("Building CuPy arrays...") src_indices_cp = cp.array(G.src_indices) dst_indices_cp = cp.array(G.dst_indices) - end_time = time.time() - print(f"COO (NumPy) -> COO (CuPy) took {end_time - start_time}") - - logger.debug("creating nx_cugraph graph from COO data...") - print("creating nx_cugraph graph from COO data...") - start_time = time.time() - rv = klass.from_coo( - N=N, - src_indices=src_indices_cp, - dst_indices=dst_indices_cp, - key_to_id=G.vertex_ids_to_index, - ) - end_time = time.time() - print(f"COO -> NXCG took {end_time - start_time}") - logger.debug(f"nxcg from_coo took {end_time - start_time}") - - return rv + edge_indices_cp = cp.array(G.edge_indices) -else: + if G.is_multigraph(): + if G.is_directed() or as_directed: + klass = nxcg.MultiDiGraph + else: + klass = nxcg.MultiGraph + + return klass.from_coo( + N=N, + src_indices=src_indices_cp, + dst_indices=dst_indices_cp, + edge_indices=edge_indices_cp, + # edge_values, + # edge_masks, + # node_values, + # node_masks, + key_to_id=vertex_ids_to_index, + # edge_keys=edge_keys, + ) - def _to_nxcg_graph(G: Any, as_directed: bool = False) -> nxcg.Graph | nxcg.DiGraph: - m = "nx-cugraph is not installed; cannot convert to nx-cugraph graph" - raise NotImplementedError(m) + else: + if G.is_directed() or as_directed: + klass = nxcg.DiGraph + else: + klass = nxcg.Graph + + return klass.from_coo( + N=N, + src_indices=src_indices_cp, + dst_indices=dst_indices_cp, + # edge_values, + # edge_masks, + # node_values, + # node_masks, + key_to_id=vertex_ids_to_index, + ) diff --git a/nx_arangodb/interface.py b/nx_arangodb/interface.py index ec4a7f99..40e08523 100644 --- a/nx_arangodb/interface.py +++ b/nx_arangodb/interface.py @@ -2,324 +2,162 @@ import os import sys -from typing import Any +from functools import partial +from typing import Any, Callable, Protocol, Set import networkx as nx +from networkx.utils.backends import _load_backend, _registered_algorithms import nx_arangodb as nxadb +from nx_arangodb.logger import logger + +# Avoid infinite recursion when testing +_IS_TESTING = os.environ.get("NETWORKX_TEST_BACKEND") in {"arangodb"} + + +class NetworkXFunction(Protocol): + graphs: dict[str, Any] + name: str + list_graphs: Set[str] + orig_func: Callable[..., Any] + _returns_graph: bool class BackendInterface: - # Required conversions @staticmethod - def convert_from_nx( - graph: Any, *args: Any, **kwargs: Any - ) -> nxadb.Graph | nxadb.DiGraph: - return nxadb.from_networkx(graph, *args, **kwargs) + def convert_from_nx(graph: nx.Graph, *args: Any, **kwargs: Any) -> nxadb.Graph: + return nxadb._to_nxadb_graph(graph, *args, **kwargs) @staticmethod - def convert_to_nx( - obj: nx.Graph | nx.DiGraph | nxadb.Graph | nxadb.DiGraph, - *, - name: str | None = None, - ) -> nx.Graph | nx.DiGraph: - if isinstance(obj, nxadb.Graph): - return nxadb.to_networkx(obj) - return obj - - # TODO Anthony: Clarify what needs to be changed here - @staticmethod - def on_start_tests(items): - """Modify pytest items after tests have been collected. + def convert_to_nx(obj: Any, *args: Any, **kwargs: Any) -> nx.Graph: + if not isinstance(obj, nxadb.Graph): + return obj - This is called during ``pytest_collection_modifyitems`` phase of pytest. - We use this to set `xfail` on tests we expect to fail. See: + return nxadb._to_nx_graph(obj, *args, **kwargs) - https://docs.pytest.org/en/stable/reference/reference.html#std-hook-pytest_collection_modifyitems + def __getattr__(self, attr: str, *, from_backend_name: str = "arangodb") -> Any: """ - try: - import pytest - except ModuleNotFoundError: - return - - def key(testpath): - filename, path = testpath.split(":") - *names, testname = path.split(".") - if names: - [classname] = names - return (testname, frozenset({classname, filename})) - return (testname, frozenset({filename})) - - # Reasons for xfailing - no_weights = "weighted implementation not currently supported" - no_multigraph = "multigraphs not currently supported" - louvain_different = "Louvain may be different due to RNG" - no_string_dtype = "string edge values not currently supported" - sssp_path_different = "sssp may choose a different valid path" - - xfail = { - # This is removed while strongly_connected_components() is not - # dispatchable. See algorithms/components/strongly_connected.py for - # details. - # - # key( - # "test_strongly_connected.py:" - # "TestStronglyConnected.test_condensation_mapping_and_members" - # ): "Strongly connected groups in different iteration order", - key( - "test_cycles.py:TestMinimumCycleBasis.test_unweighted_diamond" - ): sssp_path_different, - key( - "test_cycles.py:TestMinimumCycleBasis.test_weighted_diamond" - ): sssp_path_different, - key( - "test_cycles.py:TestMinimumCycleBasis.test_petersen_graph" - ): sssp_path_different, - key( - "test_cycles.py:TestMinimumCycleBasis." - "test_gh6787_and_edge_attribute_names" - ): sssp_path_different, - } + Dispatching mechanism for all networkx algorithms. This avoids having to + write a separate function for each algorithm. + """ + if ( + attr not in _registered_algorithms + or _IS_TESTING + and attr in {"empty_graph"} + ): + raise AttributeError(attr) - from packaging.version import parse - - nxver = parse(nx.__version__) - - if nxver.major == 3 and nxver.minor <= 2: - xfail.update( - { - # NetworkX versions prior to 3.2.1 have tests written to - # expect sp.sparse.linalg.ArpackNoConvergence exceptions - # raised on no convergence in HITS. Newer versions since - # the merge of - # https://github.com/networkx/networkx/pull/7084 expect - # nx.PowerIterationFailedConvergence, which is what - # nx_cugraph.hits raises, so we mark them as xfail for - # previous versions of NX. - key( - "test_hits.py:TestHITS.test_hits_not_convergent" - ): "nx_cugraph.hits raises updated exceptions not caught in " - "these tests", - # NetworkX versions 3.2 and older contain tests that fail - # with pytest>=8. Assume pytest>=8 and mark xfail. - key( - "test_strongly_connected.py:" - "TestStronglyConnected.test_connected_raise" - ): "test is incompatible with pytest>=8", - } - ) + if from_backend_name != "arangodb": + raise ValueError(f"Unsupported source backend: '{from_backend_name}'") - if nxver.major == 3 and nxver.minor <= 1: - # MAINT: networkx 3.0, 3.1 - # NetworkX 3.2 added the ability to "fallback to nx" if backend algorithms - # raise NotImplementedError or `can_run` returns False. The tests below - # exercise behavior we have not implemented yet, so we mark them as xfail - # for previous versions of NX. - xfail.update( - { - key( - "test_agraph.py:TestAGraph.test_no_warnings_raised" - ): "pytest.warn(None) deprecated", - key( - "test_betweenness_centrality.py:" - "TestWeightedBetweennessCentrality.test_K5" - ): no_weights, - key( - "test_betweenness_centrality.py:" - "TestWeightedBetweennessCentrality.test_P3_normalized" - ): no_weights, - key( - "test_betweenness_centrality.py:" - "TestWeightedBetweennessCentrality.test_P3" - ): no_weights, - key( - "test_betweenness_centrality.py:" - "TestWeightedBetweennessCentrality.test_krackhardt_kite_graph" - ): no_weights, - key( - "test_betweenness_centrality.py:" - "TestWeightedBetweennessCentrality." - "test_krackhardt_kite_graph_normalized" - ): no_weights, - key( - "test_betweenness_centrality.py:" - "TestWeightedBetweennessCentrality." - "test_florentine_families_graph" - ): no_weights, - key( - "test_betweenness_centrality.py:" - "TestWeightedBetweennessCentrality.test_les_miserables_graph" - ): no_weights, - key( - "test_betweenness_centrality.py:" - "TestWeightedBetweennessCentrality.test_ladder_graph" - ): no_weights, - key( - "test_betweenness_centrality.py:" - "TestWeightedBetweennessCentrality.test_G" - ): no_weights, - key( - "test_betweenness_centrality.py:" - "TestWeightedBetweennessCentrality.test_G2" - ): no_weights, - key( - "test_betweenness_centrality.py:" - "TestWeightedBetweennessCentrality.test_G3" - ): no_multigraph, - key( - "test_betweenness_centrality.py:" - "TestWeightedBetweennessCentrality.test_G4" - ): no_multigraph, - key( - "test_betweenness_centrality.py:" - "TestWeightedEdgeBetweennessCentrality.test_K5" - ): no_weights, - key( - "test_betweenness_centrality.py:" - "TestWeightedEdgeBetweennessCentrality.test_C4" - ): no_weights, - key( - "test_betweenness_centrality.py:" - "TestWeightedEdgeBetweennessCentrality.test_P4" - ): no_weights, - key( - "test_betweenness_centrality.py:" - "TestWeightedEdgeBetweennessCentrality.test_balanced_tree" - ): no_weights, - key( - "test_betweenness_centrality.py:" - "TestWeightedEdgeBetweennessCentrality.test_weighted_graph" - ): no_weights, - key( - "test_betweenness_centrality.py:" - "TestWeightedEdgeBetweennessCentrality." - "test_normalized_weighted_graph" - ): no_weights, - key( - "test_betweenness_centrality.py:" - "TestWeightedEdgeBetweennessCentrality.test_weighted_multigraph" - ): no_multigraph, - key( - "test_betweenness_centrality.py:" - "TestWeightedEdgeBetweennessCentrality." - "test_normalized_weighted_multigraph" - ): no_multigraph, - } + return partial(_auto_func, attr) + + +def _auto_func(func_name: str, /, *args: Any, **kwargs: Any) -> Any: + """ + Function to automatically dispatch to the correct backend for a given algorithm. + + :param func_name: The name of the algorithm to run. + :type func_name: str + """ + dfunc = _registered_algorithms[func_name] + + # TODO: Use `nx.config.backends.arangodb.backend_priority` instead + backend_priority = [] + if nxadb.convert.GPU_ENABLED: + backend_priority.append("cugraph") + + for backend in backend_priority: + if not dfunc.__wrapped__._should_backend_run(backend, *args, **kwargs): + logger.warning(f"'{func_name}' cannot be run on backend '{backend}'") + continue + + try: + return _run_with_backend( + backend, + dfunc, + args, + kwargs, ) - else: - xfail.update( - { - key( - "test_louvain.py:test_karate_club_partition" - ): louvain_different, - key("test_louvain.py:test_none_weight_param"): louvain_different, - key("test_louvain.py:test_multigraph"): louvain_different, - # See networkx#6630 - key( - "test_louvain.py:test_undirected_selfloops" - ): "self-loops not handled in Louvain", - } + + except NotImplementedError: + logger.warning(f"'{func_name}' not implemented for backend '{backend}'") + pass + + default_backend = "networkx" + logger.debug(f"'{func_name}' running on default backend '{default_backend}'") + return _run_with_backend(default_backend, dfunc, args, kwargs) + + +def _run_with_backend( + backend_name: str, + dfunc: NetworkXFunction, + args: Any, + kwargs: Any, +) -> Any: + """ + :param backend: The name of the backend to run the algorithm on. + :type backend: str + :param dfunc: The function to run. + :type dfunc: NetworkXFunction + """ + func_name = dfunc.name + backend_func = ( + dfunc.orig_func + if backend_name == "networkx" + else getattr(_load_backend(backend_name), func_name) + ) + + graphs_resolved = { + gname: val + for gname, pos in dfunc.graphs.items() + if (val := args[pos] if pos < len(args) else kwargs.get(gname)) is not None + } + + if dfunc.list_graphs: + graphs_converted = { + gname: ( + [_convert_to_backend(g, backend_name) for g in val] + if gname in dfunc.list_graphs + else _convert_to_backend(val, backend_name) ) - if sys.version_info[:2] == (3, 9): - # This test is sensitive to RNG, which depends on Python version - xfail[key("test_louvain.py:test_threshold")] = ( - "Louvain does not support seed parameter" - ) - if nxver.major == 3 and nxver.minor >= 2: - xfail.update( - { - key( - "test_convert_pandas.py:TestConvertPandas." - "test_from_edgelist_multi_attr_incl_target" - ): no_string_dtype, - key( - "test_convert_pandas.py:TestConvertPandas." - "test_from_edgelist_multidigraph_and_edge_attr" - ): no_string_dtype, - key( - "test_convert_pandas.py:TestConvertPandas." - "test_from_edgelist_int_attr_name" - ): no_string_dtype, - } - ) - if nxver.minor == 2: - different_iteration_order = "Different graph data iteration order" - xfail.update( - { - key( - "test_cycles.py:TestMinimumCycleBasis." - "test_gh6787_and_edge_attribute_names" - ): different_iteration_order, - key( - "test_euler.py:TestEulerianCircuit." - "test_eulerian_circuit_cycle" - ): different_iteration_order, - key( - "test_gml.py:TestGraph.test_special_float_label" - ): different_iteration_order, - } - ) - elif nxver.minor >= 3: - xfail.update( - { - key("test_louvain.py:test_max_level"): louvain_different, - } - ) - - too_slow = "Too slow to run" - skip = { - key("test_tree_isomorphism.py:test_positive"): too_slow, - key("test_tree_isomorphism.py:test_negative"): too_slow, - # These repeatedly call `bfs_layers`, which converts the graph every call - key( - "test_vf2pp.py:TestGraphISOVF2pp.test_custom_graph2_different_labels" - ): too_slow, - key( - "test_vf2pp.py:TestGraphISOVF2pp.test_custom_graph3_same_labels" - ): too_slow, - key( - "test_vf2pp.py:TestGraphISOVF2pp.test_custom_graph3_different_labels" - ): too_slow, - key( - "test_vf2pp.py:TestGraphISOVF2pp.test_custom_graph4_same_labels" - ): too_slow, - key( - "test_vf2pp.py:TestGraphISOVF2pp." - "test_disconnected_graph_all_same_labels" - ): too_slow, - key( - "test_vf2pp.py:TestGraphISOVF2pp." - "test_disconnected_graph_all_different_labels" - ): too_slow, - key( - "test_vf2pp.py:TestGraphISOVF2pp." - "test_disconnected_graph_some_same_labels" - ): too_slow, - key( - "test_vf2pp.py:TestMultiGraphISOVF2pp." - "test_custom_multigraph3_same_labels" - ): too_slow, - key( - "test_vf2pp_helpers.py:TestNodeOrdering." - "test_matching_order_all_branches" - ): too_slow, + for gname, val in graphs_resolved.items() } - if os.environ.get("PYTEST_NO_SKIP", False): - skip.clear() - - for item in items: - kset = set(item.keywords) - for (test_name, keywords), reason in xfail.items(): - if item.name == test_name and keywords.issubset(kset): - item.add_marker(pytest.mark.xfail(reason=reason)) - for (test_name, keywords), reason in skip.items(): - if item.name == test_name and keywords.issubset(kset): - item.add_marker(pytest.mark.skip(reason=reason)) - - @classmethod - def can_run(cls, name, args, kwargs): - """Can this backend run the specified algorithms with the given arguments? - - This is a proposed API to add to networkx dispatching machinery and may change. - """ - return hasattr(cls, name) and getattr(cls, name).can_run(*args, **kwargs) + else: + graphs_converted = { + gname: _convert_to_backend(graph, backend_name) + for gname, graph in graphs_resolved.items() + } + + converted_args = list(args) + converted_kwargs = dict(kwargs) + + for gname, val in graphs_converted.items(): + if gname in kwargs: + converted_kwargs[gname] = val + else: + converted_args[dfunc.graphs[gname]] = val + + result = backend_func(*converted_args, **converted_kwargs) + + # TODO: Convert to nxadb.Graph? + # What would this look like? Create a new graph in ArangoDB? + # Or just establish a remote connection? + # if dfunc._returns_graph: + # raise NotImplementedError("Returning Graphs not implemented yet") + + return result + + +def _convert_to_backend(G_from: Any, backend_name: str) -> Any: + if backend_name == "networkx": + pull_graph = nx.config.backends.arangodb.pull_graph + return nxadb._to_nx_graph(G_from, pull_graph=pull_graph) + + if backend_name == "cugraph": + return nxadb._to_nxcg_graph(G_from) + + raise ValueError(f"Unsupported backend: '{backend_name}'") + + +backend_interface = BackendInterface() diff --git a/phenolrs-0.4.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl b/phenolrs-0.4.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl index 29f254c0..1e27da03 100644 Binary files a/phenolrs-0.4.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl and b/phenolrs-0.4.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl differ diff --git a/pyproject.toml b/pyproject.toml index ab91102f..07f12bcb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,14 +63,14 @@ Homepage = "https://github.com/aMahanna/nx-arangodb" # "plugin" used in nx version < 3.2 [project.entry-points."networkx.plugins"] -arangodb = "nx_arangodb.interface:BackendInterface" - -[project.entry-points."networkx.plugin_info"] -arangodb = "_nx_arangodb:get_info" +arangodb = "nx_arangodb.interface:backend_interface" # "backend" used in nx version >= 3.2 [project.entry-points."networkx.backends"] -arangodb = "nx_arangodb.interface:BackendInterface" +arangodb = "nx_arangodb.interface:backend_interface" + +[project.entry-points."networkx.plugin_info"] +arangodb = "_nx_arangodb:get_info" [project.entry-points."networkx.backend_info"] arangodb = "_nx_arangodb:get_info" diff --git a/tests/test.py b/tests/test.py index 0c9da34c..11d98c47 100644 --- a/tests/test.py +++ b/tests/test.py @@ -1,7 +1,6 @@ -from typing import Any +from typing import Any, Callable import networkx as nx -import pandas as pd import pytest import nx_arangodb as nxadb @@ -12,6 +11,42 @@ G_NX = nx.karate_club_graph() +def assert_same_dict_values( + d1: dict[str | int, float], d2: dict[str | int, float], digit: int +) -> None: + if type(next(iter(d1.keys()))) == int: + d1 = {f"person/{k+1}": v for k, v in d1.items()} # type: ignore + + if type(next(iter(d2.keys()))) == int: + d2 = {f"person/{k+1}": v for k, v in d2.items()} # type: ignore + + assert d1.keys() == d2.keys(), "Dictionaries have different keys" + for key in d1: + m = f"Values for key '{key}' are not equal up to digit {digit}" + assert round(d1[key], digit) == round(d2[key], digit), m + + +def assert_bc(d1: dict[str | int, float], d2: dict[str | int, float]) -> None: + assert_same_dict_values(d1, d2, 14) + + +def assert_pagerank(d1: dict[str | int, float], d2: dict[str | int, float]) -> None: + assert_same_dict_values(d1, d2, 15) + + +def assert_louvain(l1: list[set[Any]], l2: list[set[Any]]) -> None: + # TODO: Implement some kind of comparison + # Reason: Louvain returns different results on different runs + pass + + +def assert_k_components( + d1: dict[int, list[set[Any]]], d2: dict[int, list[set[Any]]] +) -> None: + assert d1.keys() == d2.keys(), "Dictionaries have different keys" + assert d1 == d2 + + def test_db(load_graph: Any) -> None: assert db.version() @@ -36,114 +71,60 @@ def test_load_graph_from_nxadb(): db.delete_graph(graph_name, drop_collections=True) -def test_bc(load_graph): - G_1 = G_NX - G_2 = nxadb.Graph(incoming_graph_data=G_1) - G_3 = nxadb.Graph(graph_name="KarateGraph") - - r_1 = nx.betweenness_centrality(G_1) - r_2 = nx.betweenness_centrality(G_2) - r_3 = nx.betweenness_centrality(G_1, backend="arangodb") - r_4 = nx.betweenness_centrality(G_2, backend="arangodb") - r_5 = nx.betweenness_centrality.orig_func(G_3) - - assert len(r_1) == len(G_1) - assert r_1 == r_2 - assert r_2 == r_3 - assert r_3 == r_4 - assert len(r_1) == len(r_5) - - try: - import phenolrs # noqa - except ModuleNotFoundError: - pytest.skip("phenolrs not installed") - - G_4 = nxadb.Graph(graph_name="KarateGraph") - r_6 = nx.betweenness_centrality(G_4) - - G_5 = nxadb.Graph(graph_name="KarateGraph") - r_7 = nxadb.betweenness_centrality(G_5, pull_graph_on_cpu=False) # type: ignore - - G_6 = nxadb.DiGraph(graph_name="KarateGraph") - r_8 = nx.betweenness_centrality(G_6) - - # assert r_6 == r_7 # this is acting strange. I need to revisit - assert r_7 == r_8 - assert len(r_6) == len(r_7) == len(r_8) == len(G_4) > 0 - - -def test_pagerank(load_graph: Any) -> None: +@pytest.mark.parametrize( + "algorithm_func, assert_func", + [ + (nx.betweenness_centrality, assert_bc), + (nx.pagerank, assert_pagerank), + (nx.community.louvain_communities, assert_louvain), + ], +) +def test_algorithm( + algorithm_func: Callable[..., Any], + assert_func: Callable[..., Any], + load_graph: Any, +) -> None: G_1 = G_NX G_2 = nxadb.Graph(incoming_graph_data=G_1) G_3 = nxadb.Graph(graph_name="KarateGraph") + G_4 = nxadb.DiGraph(graph_name="KarateGraph", symmetrize_edges=True) + G_5 = nxadb.DiGraph(graph_name="KarateGraph", symmetrize_edges=False) - r_1 = nx.pagerank(G_1) - r_2 = nx.pagerank(G_2) - r_3 = nx.pagerank(G_1, backend="arangodb") - r_4 = nx.pagerank(G_2, backend="arangodb") - r_5 = nx.pagerank.orig_func(G_3) - - assert len(r_1) == len(G_1) - assert r_1 == r_2 - assert r_2 == r_3 - assert r_3 == r_4 - assert len(r_1) == len(r_5) - - try: - import phenolrs # noqa - except ModuleNotFoundError: - pytest.skip("phenolrs not installed") - - G_4 = nxadb.Graph(graph_name="KarateGraph") - r_6 = nx.pagerank(G_4) + r_1 = algorithm_func(G_1) + r_2 = algorithm_func(G_2) + r_3 = algorithm_func(G_1, backend="arangodb") + r_4 = algorithm_func(G_2, backend="arangodb") - G_5 = nxadb.Graph(graph_name="KarateGraph") - r_7 = nxadb.pagerank(G_5, pull_graph_on_cpu=False) # type: ignore + r_5 = algorithm_func.orig_func(G_3) # type: ignore + nx.config.backends.arangodb.pull_graph = False + r_6 = algorithm_func(G_3) + nx.config.backends.arangodb.pull_graph = True - G_6 = nxadb.DiGraph(graph_name="KarateGraph") - r_8 = nx.pagerank(G_6) - - assert len(r_6) == len(r_7) == len(r_8) == len(G_4) > 0 - - -def test_louvain(load_graph: Any) -> None: - G_1 = G_NX - G_2 = nxadb.Graph(incoming_graph_data=G_1) - G_3 = nxadb.Graph(graph_name="KarateGraph") - - r_1 = nx.community.louvain_communities(G_1) - r_2 = nx.community.louvain_communities(G_2) - r_3 = nx.community.louvain_communities(G_1, backend="arangodb") - r_4 = nx.community.louvain_communities(G_2, backend="arangodb") - r_5 = nx.community.louvain_communities.orig_func(G_3) - - assert len(r_1) > 0 - assert len(r_2) > 0 - assert len(r_3) > 0 - assert len(r_4) > 0 - assert len(r_5) > 0 + assert all([r_1, r_2, r_3, r_4, r_5, r_6]) + assert_func(r_1, r_2) + assert_func(r_2, r_3) + assert_func(r_3, r_4) + assert_func(r_5, r_6) try: import phenolrs # noqa except ModuleNotFoundError: pytest.skip("phenolrs not installed") - G_4 = nxadb.Graph(graph_name="KarateGraph") - r_6 = nx.community.louvain_communities(G_4) - - G_5 = nxadb.Graph(graph_name="KarateGraph") - r_7 = nxadb.community.louvain_communities(G_5, pull_graph_on_cpu=False) # type: ignore # noqa + r_7 = algorithm_func(G_3) + r_8 = algorithm_func(G_4) + r_9 = algorithm_func(G_5) + r_10 = algorithm_func(nx.DiGraph(incoming_graph_data=G_NX)) - G_6 = nxadb.DiGraph(graph_name="KarateGraph") - r_8 = nx.community.louvain_communities(G_6) + assert all([r_7, r_8, r_9, r_10]) + assert_func(r_7, r_1) + assert_func(r_7, r_8) + assert len(r_8) == len(r_9) + assert r_8 != r_9 + assert_func(r_8, r_10) - assert len(r_5) > 0 - assert len(r_6) > 0 - assert len(r_7) > 0 - assert len(r_8) > 0 - -def test_shortest_path(load_graph: Any) -> None: +def test_shortest_path_remote_algorithm(load_graph: Any) -> None: G_1 = nxadb.Graph(graph_name="KarateGraph") G_2 = nxadb.DiGraph(graph_name="KarateGraph") @@ -496,8 +477,7 @@ def test_readme(load_graph: Any) -> None: True, ), ("Pandas EdgeList", nx.to_pandas_edgelist(G_NX), False, True), - # TODO: Address **nx.relabel.relabel_nodes** issue - # ("Pandas Adjacency", nx.to_pandas_adjacency(G_NX), False, True), + ("Pandas Adjacency", nx.to_pandas_adjacency(G_NX), False, True), ], ) def test_incoming_graph_data_not_nx_graph( @@ -509,9 +489,19 @@ def test_incoming_graph_data_not_nx_graph( G = nxadb.Graph(incoming_graph_data=incoming_graph_data, graph_name=name) - assert len(G.nodes) == len(G_NX.nodes) == db.collection(G.default_node_type).count() assert len(G.adj) == len(G_NX.adj) == db.collection(G.default_node_type).count() - assert len(G.edges) == len(G_NX.edges) == db.collection(G.default_edge_type).count() + assert ( + len(G.nodes) + == len(G_NX.nodes) + == db.collection(G.default_node_type).count() + == G.number_of_nodes() + ) + assert ( + len(G.edges) + == len(G_NX.edges) + == db.collection(G.default_edge_type).count() + == G.number_of_edges() + ) assert has_club == ("club" in G.nodes["0"]) assert has_weight == ("weight" in G.adj["0"]["1"])