diff --git a/nx_arangodb/classes/dict.py b/nx_arangodb/classes/dict.py index 76071d91..43ee7489 100644 --- a/nx_arangodb/classes/dict.py +++ b/nx_arangodb/classes/dict.py @@ -33,6 +33,7 @@ doc_get_or_insert, doc_insert, doc_update, + get_arangodb_graph, get_node_id, get_node_type_and_id, key_is_not_reserved, @@ -321,56 +322,6 @@ def __delitem__(self, key: str) -> None: root_data = self.root.data if self.root else self.data root_data["_rev"] = doc_update(self.db, self.node_id, update_dict) - # @logger_debug - # def __iter__(self) -> Iterator[str]: - # """for key in G._node['node/1']""" - # yield from aql_doc_get_keys(self.db, self.node_id, self.parent_keys) - - # @logger_debug - # def __len__(self) -> int: - # """len(G._node['node/1'])""" - # return aql_doc_get_length(self.db, self.node_id, self.parent_keys) - - # @logger_debug - # def keys(self) -> Any: - # """G._node['node/1'].keys()""" - # yield from self.__iter__() - - # @logger_debug - # # TODO: Revisit typing of return value - # def values(self) -> Any: - # """G._node['node/1'].values()""" - # self.data = self.db.document(self.node_id) - # yield from self.data.values() - - # @logger_debug - # # TODO: Revisit typing of return value - # def items(self) -> Any: - # """G._node['node/1'].items()""" - - # # TODO: Revisit this lazy hack - # if self.parent_keys: - # yield from self.data.items() - # else: - # self.data = self.db.document(self.node_id) - # yield from self.data.items() - - # ? - # def pull(): - # pass - - # ? - # def push(): - # pass - - # @logger_debug - # def clear(self) -> None: - # """G._node['node/1'].clear()""" - # self.data.clear() - - # # if clear_remote: - # # doc_insert(self.db, self.node_id, silent=True, overwrite=True) - @keys_are_strings @keys_are_not_reserved # @values_are_json_serializable # TODO? 
@@ -435,6 +386,9 @@ def __contains__(self, key: str) -> bool: if node_id in self.data: return True + if self.FETCHED_ALL_DATA: + return False + return bool(self.graph.has_vertex(node_id)) @key_is_string @@ -446,6 +400,9 @@ def __getitem__(self, key: str) -> NodeAttrDict: if vertex := self.data.get(node_id): return vertex + if self.FETCHED_ALL_DATA: + raise KeyError(key) + if vertex := self.graph.vertex(node_id): node_attr_dict: NodeAttrDict = self.node_attr_dict_factory() node_attr_dict.node_id = node_id @@ -472,7 +429,7 @@ def __setitem__(self, key: str, value: NodeAttrDict) -> None: node_attr_dict = self.node_attr_dict_factory() node_attr_dict.node_id = node_id - node_attr_dict.data = result + node_attr_dict.data = build_node_attr_dict_data(node_attr_dict, result) self.data[node_id] = node_attr_dict @@ -570,16 +527,23 @@ def items(self, data: str | None = None, default: Any | None = None) -> Any: @logger_debug def __fetch_all(self): - self.data.clear() - for collection in self.graph.vertex_collections(): - for doc in self.graph.vertex_collection(collection).all(): - node_id = doc["_id"] + self.clear() - node_attr_dict = self.node_attr_dict_factory() - node_attr_dict.node_id = node_id - node_attr_dict.data = doc + node_dict, _, _, _, _ = get_arangodb_graph( + self.graph, + load_node_dict=True, + load_adj_dict=False, + load_adj_dict_as_directed=False, # not used + load_adj_dict_as_multigraph=False, # not used + load_coo=False, + ) - self.data[node_id] = node_attr_dict + for node_id, node_data in node_dict.items(): + node_attr_dict = self.node_attr_dict_factory() + node_attr_dict.node_id = node_id + node_attr_dict.data = build_node_attr_dict_data(node_attr_dict, node_data) + + self.data[node_id] = node_attr_dict self.FETCHED_ALL_DATA = True @@ -710,43 +674,6 @@ def __delitem__(self, key: str) -> None: root_data = self.root.data if self.root else self.data root_data["_rev"] = doc_update(self.db, self.edge_id, update_dict) - # @logger_debug - # def __iter__(self) -> 
Iterator[str]: - # """for key in G._adj['node/1']['node/2']""" - # assert self.edge_id - # yield from aql_doc_get_keys(self.db, self.edge_id) - - # @logger_debug - # def __len__(self) -> int: - # """len(G._adj['node/1']['node/'2])""" - # assert self.edge_id - # return aql_doc_get_length(self.db, self.edge_id) - - # # TODO: Revisit typing of return value - # @logger_debug - # def keys(self) -> Any: - # """G._adj['node/1']['node/'2].keys()""" - # return self.__iter__() - - # # TODO: Revisit typing of return value - # @logger_debug - # def values(self) -> Any: - # """G._adj['node/1']['node/'2].values()""" - # self.data = self.db.document(self.edge_id) - # yield from self.data.values() - - # # TODO: Revisit typing of return value - # @logger_debug - # def items(self) -> Any: - # """G._adj['node/1']['node/'2].items()""" - # self.data = self.db.document(self.edge_id) - # yield from self.data.items() - - # @logger_debug - # def clear(self) -> None: - # """G._adj['node/1']['node/'2].clear()""" - # self.data.clear() - @keys_are_strings @keys_are_not_reserved @logger_debug @@ -836,6 +763,9 @@ def __contains__(self, key: str) -> bool: if dst_node_id in self.data: return True + if self.FETCHED_ALL_DATA: + return False + result = aql_edge_exists( self.db, self.src_node_id, @@ -859,6 +789,9 @@ def __getitem__(self, key: str) -> EdgeAttrDict: self.data[dst_node_id] = edge return edge # type: ignore # false positive + if self.FETCHED_ALL_DATA: + raise KeyError(key) + assert self.src_node_id edge = aql_edge_get( self.db, @@ -1022,8 +955,7 @@ def items(self) -> Any: @logger_debug def __fetch_all(self) -> None: - if self.FETCHED_ALL_DATA: - return + assert self.src_node_id self.clear() @@ -1037,8 +969,7 @@ def __fetch_all(self) -> None: for edge in aql(self.db, query, bind_vars): edge_attr_dict = self.edge_attr_dict_factory() edge_attr_dict.edge_id = edge["_id"] - edge_attr_dict.data = edge - + edge_attr_dict.data = build_edge_attr_dict_data(edge_attr_dict, edge) 
self.data[edge["_to"]] = edge_attr_dict self.FETCHED_ALL_DATA = True @@ -1100,6 +1031,9 @@ def __contains__(self, key: str) -> bool: if node_id in self.data: return True + if self.FETCHED_ALL_DATA: + return False + return bool(self.graph.has_vertex(node_id)) @key_is_string @@ -1114,7 +1048,6 @@ def __getitem__(self, key: str) -> AdjListInnerDict: if self.graph.has_vertex(node_id): adjlist_inner_dict: AdjListInnerDict = self.adjlist_inner_dict_factory() adjlist_inner_dict.src_node_id = node_id - self.data[node_id] = adjlist_inner_dict return adjlist_inner_dict @@ -1237,29 +1170,32 @@ def items(self, data: str | None = None, default: Any | None = None) -> Any: result = aql_fetch_data_edge(self.db, e_cols, data, default) yield from result - # TODO: Revisit this logic @logger_debug def __fetch_all(self) -> None: - if self.FETCHED_ALL_DATA: - return - self.clear() - # items = defaultdict(dict) - for ed in self.graph.edge_definitions(): - collection = ed["edge_collection"] - for edge in self.graph.edge_collection(collection): - src_node_id = edge["_from"] - dst_node_id = edge["_to"] + _, adj_dict, _, _, _ = get_arangodb_graph( + self.graph, + load_node_dict=False, + load_adj_dict=True, + load_adj_dict_as_directed=False, # TODO: Abstract based on Graph type + load_adj_dict_as_multigraph=False, # TODO: Abstract based on Graph type + load_coo=False, + ) + + for src_node_id, inner_dict in adj_dict.items(): + for dst_node_id, edge in inner_dict.items(): - # items[src_node_id][dst_node_id] = edge - # items[dst_node_id][src_node_id] = edge + if src_node_id in self.data: + if dst_node_id in self.data[src_node_id].data: + continue if src_node_id in self.data: src_inner_dict = self.data[src_node_id] else: src_inner_dict = self.adjlist_inner_dict_factory() src_inner_dict.src_node_id = src_node_id + src_inner_dict.FETCHED_ALL_DATA = True self.data[src_node_id] = src_inner_dict if dst_node_id in self.data: @@ -1267,11 +1203,12 @@ def __fetch_all(self) -> None: else: dst_inner_dict = 
self.adjlist_inner_dict_factory() dst_inner_dict.src_node_id = dst_node_id + dst_inner_dict.FETCHED_ALL_DATA = True self.data[dst_node_id] = dst_inner_dict edge_attr_dict = src_inner_dict.edge_attr_dict_factory() edge_attr_dict.edge_id = edge["_id"] - edge_attr_dict.data = edge + edge_attr_dict.data = build_edge_attr_dict_data(edge_attr_dict, edge) self.data[src_node_id].data[dst_node_id] = edge_attr_dict self.data[dst_node_id].data[src_node_id] = edge_attr_dict diff --git a/nx_arangodb/classes/digraph.py b/nx_arangodb/classes/digraph.py index 310f33b5..e9cedc73 100644 --- a/nx_arangodb/classes/digraph.py +++ b/nx_arangodb/classes/digraph.py @@ -154,6 +154,3 @@ def __set_graph_name(self, graph_name: str | None = None) -> None: def aql(self, query: str, bind_vars: dict[str, Any] = {}, **kwargs: Any) -> Cursor: return nxadb.classes.function.aql(self.db, query, bind_vars, **kwargs) - - def pull(self, load_node_dict=True, load_adj_dict=True, load_coo=True): - raise NotImplementedError("nxadb.DiGraph.pull() is not implemented yet.") diff --git a/nx_arangodb/classes/function.py b/nx_arangodb/classes/function.py index f48e338b..67b578fb 100644 --- a/nx_arangodb/classes/function.py +++ b/nx_arangodb/classes/function.py @@ -8,11 +8,13 @@ from collections import UserDict from typing import Any, Callable, Tuple +import networkx as nx import numpy as np import numpy.typing as npt from arango.collection import StandardCollection from arango.cursor import Cursor from arango.database import StandardDatabase +from arango.graph import Graph import nx_arangodb as nxadb from nx_arangodb.logger import logger @@ -25,10 +27,11 @@ def get_arangodb_graph( - G: nxadb.Graph | nxadb.DiGraph, + adb_graph: Graph, load_node_dict: bool, load_adj_dict: bool, load_adj_dict_as_directed: bool, + load_adj_dict_as_multigraph: bool, load_coo: bool, ) -> Tuple[ dict[str, dict[str, Any]], @@ -46,12 +49,6 @@ def get_arangodb_graph( - Destination Indices (COO) - Node-ID-to-index mapping (COO) """ - if not 
G.graph_exists_in_db: - raise GraphDoesNotExist( - "Graph does not exist in the database. Can't load graph." - ) - - adb_graph = G.db.graph(G.graph_name) v_cols = adb_graph.vertex_collections() edge_definitions = adb_graph.edge_definitions() e_cols = {c["edge_collection"] for c in edge_definitions} @@ -63,22 +60,30 @@ def get_arangodb_graph( from phenolrs.networkx_loader import NetworkXLoader + config = nx.config.backends.arangodb + kwargs = {} - if G.graph_loader_parallelism is not None: - kwargs["parallelism"] = G.graph_loader_parallelism - if G.graph_loader_batch_size is not None: - kwargs["batch_size"] = G.graph_loader_batch_size + if parallelism := config.get("load_parallelism"): + kwargs["parallelism"] = parallelism + if batch_size := config.get("load_batch_size"): + kwargs["batch_size"] = batch_size + + assert config.db_name + assert config.host + assert config.username + assert config.password # TODO: Remove ignore when phenolrs is published return NetworkXLoader.load_into_networkx( # type: ignore - G.db.name, - metagraph, - [G._host], - username=G._username, - password=G._password, + config.db_name, + metagraph=metagraph, + hosts=[config.host], + username=config.username, + password=config.password, load_node_dict=load_node_dict, load_adj_dict=load_adj_dict, load_adj_dict_as_directed=load_adj_dict_as_directed, + load_adj_dict_as_multigraph=load_adj_dict_as_multigraph, load_coo=load_coo, **kwargs, ) @@ -103,7 +108,7 @@ def logger_debug(func: Callable[..., Any]) -> Any: """Decorator to log debug messages.""" def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: - logger.debug(f"{func.__name__} - {args} - {kwargs}") + logger.debug(f"{type(self)}.{func.__name__} - {args} - {kwargs}") return func(self, *args, **kwargs) return wrapper diff --git a/nx_arangodb/classes/graph.py b/nx_arangodb/classes/graph.py index 95e9e2d6..40f65637 100644 --- a/nx_arangodb/classes/graph.py +++ b/nx_arangodb/classes/graph.py @@ -10,9 +10,10 @@ from arango.cursor import Cursor 
from arango.database import StandardDatabase from arango.exceptions import ServerConnectionError +from networkx.utils import Config import nx_arangodb as nxadb -from nx_arangodb.exceptions import DatabaseNotSet, GraphNameNotSet +from nx_arangodb.exceptions import DatabaseNotSet, GraphDoesNotExist, GraphNameNotSet from nx_arangodb.logger import logger from .dict import ( @@ -85,6 +86,7 @@ def __init__( self.adb_graph = self.db.graph(self.__graph_name) self.__create_default_collections() self.__set_factory_methods() + self.__set_arangodb_backend_config() elif self.__graph_name and incoming_graph_data is not None: # TODO: Parameterize the edge definitions @@ -114,6 +116,7 @@ def __init__( ) self.__set_factory_methods() + self.__set_arangodb_backend_config() self.__graph_exists_in_db = True super().__init__(*args, **kwargs) @@ -122,6 +125,22 @@ def __init__( # Init helper methods # ####################### + def __set_arangodb_backend_config(self) -> None: + if not all([self._host, self._username, self._password, self._db_name]): + m = "Must set all environment variables to use the ArangoDB Backend with an existing graph" # noqa: E501 + raise OSError(m) + + config = Config( + host=self._host, + username=self._username, + password=self._password, + db_name=self._db_name, + load_parallelism=self.graph_loader_parallelism, + load_batch_size=self.graph_loader_batch_size, + ) + + nx.config.backends.arangodb = config + def __set_factory_methods(self) -> None: """Set the factory methods for the graph, _node, and _adj dictionaries. 
@@ -189,6 +208,11 @@ def graph_exists_in_db(self) -> bool: ########### def __set_db(self, db: StandardDatabase | None = None) -> None: + self._host = os.getenv("DATABASE_HOST") + self._username = os.getenv("DATABASE_USERNAME") + self._password = os.getenv("DATABASE_PASSWORD") + self._db_name = os.getenv("DATABASE_NAME") + if db is not None: if not isinstance(db, StandardDatabase): m = "arango.database.StandardDatabase" @@ -198,11 +222,6 @@ def __set_db(self, db: StandardDatabase | None = None) -> None: self.__db = db return - self._host = os.getenv("DATABASE_HOST") - self._username = os.getenv("DATABASE_USERNAME") - self._password = os.getenv("DATABASE_PASSWORD") - self._db_name = os.getenv("DATABASE_NAME") - # TODO: Raise a custom exception if any of the environment # variables are missing. For now, we'll just set db to None. if not all([self._host, self._username, self._password, self._db_name]): @@ -210,13 +229,9 @@ def __set_db(self, db: StandardDatabase | None = None) -> None: logger.warning("Database environment variables not set") return - try: - self.__db = ArangoClient(hosts=self._host, request_timeout=None).db( - self._db_name, self._username, self._password, verify=True - ) - except ServerConnectionError as e: - self.__db = None - logger.warning(f"Could not connect to the database: {e}") + self.__db = ArangoClient(hosts=self._host, request_timeout=None).db( + self._db_name, self._username, self._password, verify=True + ) def __set_graph_name(self, graph_name: str | None = None) -> None: if self.__db is None: @@ -243,7 +258,7 @@ def __set_graph_name(self, graph_name: str | None = None) -> None: def aql(self, query: str, bind_vars: dict[str, Any] = {}, **kwargs: Any) -> Cursor: return nxadb.classes.function.aql(self.db, query, bind_vars, **kwargs) - # NOTE: Ignore this for now + # NOTE: OUT OF SERVICE # def chat(self, prompt: str) -> str: # if self.__qa_chain is None: # if not self.__graph_exists_in_db: @@ -268,67 +283,6 @@ def aql(self, query: str, 
bind_vars: dict[str, Any] = {}, **kwargs: Any) -> Curs # print(result["result"]) - def pull(self, load_node_dict=True, load_adj_dict=True, load_coo=True): - """Load the graph from the ArangoDB database, and update existing graph object. - - :param load_node_dict: Load the node dictionary. - Enabling this option will clear the existing node dictionary, - and replace it with the node data from the database. Comes with - a remote reference to the database. - :type load_node_dict: bool - :param load_adj_dict: Load the adjacency dictionary. - Enabling this option will clear the existing adjacency dictionary, - and replace it with the edge data from the database. Comes with - a remote reference to the database. - :type load_adj_dict: bool - :param load_coo: Load the COO representation. If False, the src & dst - indices will be empty, along with the node-ID-to-index mapping. - Used for nx-cuGraph compatibility. - :type load_coo: bool - """ - node_dict, adj_dict, src_indices, dst_indices, vertex_ids_to_indices = ( - nxadb.classes.function.get_arangodb_graph( - self, - load_node_dict=load_node_dict, - load_adj_dict=load_adj_dict, - load_adj_dict_as_directed=False, - load_coo=load_coo, - ) - ) - - if load_node_dict: - self._node.clear() - - for node_id, node_data in node_dict.items(): - node_attr_dict = self.node_attr_dict_factory() - node_attr_dict.node_id = node_id - node_attr_dict.data = node_data - self._node.data[node_id] = node_attr_dict - - if load_adj_dict: - self._adj.clear() - - for src_node_id, dst_dict in adj_dict.items(): - adjlist_inner_dict = self.adjlist_inner_dict_factory() - adjlist_inner_dict.src_node_id = src_node_id - - self._adj.data[src_node_id] = adjlist_inner_dict - - for dst_id, edge_data in dst_dict.items(): - edge_attr_dict = self.edge_attr_dict_factory() - edge_attr_dict.edge_id = edge_data["_id"] - edge_attr_dict.data = edge_data - - adjlist_inner_dict.data[dst_id] = edge_attr_dict - - if load_coo: - self.src_indices = src_indices - 
self.dst_indices = dst_indices - self.vertex_ids_to_index = vertex_ids_to_indices - - def push(self): - raise NotImplementedError("What would this look like?") - ##################### # nx.Graph Overides # ##################### diff --git a/nx_arangodb/convert.py b/nx_arangodb/convert.py index bf0d31f3..fe03b84b 100644 --- a/nx_arangodb/convert.py +++ b/nx_arangodb/convert.py @@ -133,39 +133,41 @@ def from_networkx_arangodb( # return G logger.debug("pulling as NetworkX Graph...") - print(f"Fetching {G.graph_name} as Node & Adj dictionaries...") + print(f"Fetching {G.graph_name} as dictionaries...") start_time = time.time() - node_dict, adj_dict, _, _, _ = nxadb.classes.function.get_arangodb_graph( - G, - load_node_dict=True, + _, adj_dict, _, _, _ = nxadb.classes.function.get_arangodb_graph( + adb_graph=G.adb_graph, + load_node_dict=False, # TODO: Should we load node dict? load_adj_dict=True, load_adj_dict_as_directed=G.is_directed(), + load_adj_dict_as_multigraph=G.is_multigraph(), load_coo=False, ) end_time = time.time() logger.debug(f"load took {end_time - start_time} seconds") - print(f"ADB -> Node & Adj load took {end_time - start_time} seconds") + print(f"ADB -> Dictionaries load took {end_time - start_time} seconds") - # Copied from nx.convert.to_networkx_graph - try: - logger.debug("creating nx graph from loaded ArangoDB data...") - print("Creating nx graph from loaded ArangoDB data...") - start_time = time.time() - result: nx.Graph = nx.convert.from_dict_of_dicts( - adj_dict, - create_using=G.to_networkx_class(), - multigraph_input=G.is_multigraph(), - ) + return G.to_networkx_class()(incoming_graph_data=adj_dict) - for n, dd in node_dict.items(): - result._node[n].update(dd) - end_time = time.time() - print(f"NX Graph creation took {end_time - start_time}") + # try: + # logger.debug("creating nx graph from loaded ArangoDB data...") + # print("Creating nx graph from loaded ArangoDB data...") + # start_time = time.time() + # result: nx.Graph = 
nx.convert.from_dict_of_dicts( + # adj_dict, + # create_using=G.to_networkx_class(), + # multigraph_input=G.is_multigraph(), + # ) + + # for n, dd in node_dict.items(): + # result._node[n].update(dd) + # end_time = time.time() + # print(f"NX Graph creation took {end_time - start_time}") - return result + # return result - except Exception as err: - raise nx.NetworkXError("Input is not a correct NetworkX graph.") from err + # except Exception as err: + # raise nx.NetworkXError("Input is not a correct NetworkX graph.") from err def _to_nx_graph( @@ -238,10 +240,11 @@ def nxcg_from_networkx_arangodb( start_time = time.time() _, _, src_indices, dst_indices, vertex_ids_to_index = ( nxadb.classes.function.get_arangodb_graph( - G, + adb_graph=G.adb_graph, load_node_dict=False, load_adj_dict=False, - load_adj_dict_as_directed=G.is_directed(), + load_adj_dict_as_directed=G.is_directed(), # not used + load_adj_dict_as_multigraph=G.is_multigraph(), # not used load_coo=True, ) ) diff --git a/phenolrs-0.4.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl b/phenolrs-0.4.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl index db5add3a..2b3fd2ae 100644 Binary files a/phenolrs-0.4.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl and b/phenolrs-0.4.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl differ