diff --git a/nx_arangodb/classes/graph.py b/nx_arangodb/classes/graph.py index b7589294..6a0f13db 100644 --- a/nx_arangodb/classes/graph.py +++ b/nx_arangodb/classes/graph.py @@ -76,14 +76,10 @@ def __init__( # self.maintain_node_dict_cache = False # self.maintain_adj_dict_cache = False self.use_nx_cache = True - self.use_coo_cache = True - # self.__qa_chain = None + self.use_nxcg_cache = True + self.nxcg_graph = None - self.src_indices: npt.NDArray[np.int64] | None = None - self.dst_indices: npt.NDArray[np.int64] | None = None - self.edge_indices: npt.NDArray[np.int64] | None = None - self.vertex_ids_to_index: dict[str, int] | None = None - self.edge_values: dict[str, list[int | float]] | None = None + # self.__qa_chain = None # Does not apply to undirected graphs self.symmetrize_edges = symmetrize_edges @@ -379,6 +375,9 @@ def clear_edges(self): logger.info("Note that clearing edges ony erases the edges in the local cache") super().clear_edges() + def clear_nxcg_cache(self): + self.nxcg_graph = None + @cached_property def nodes(self): if self.graph_exists_in_db: diff --git a/nx_arangodb/convert.py b/nx_arangodb/convert.py index a82ede16..943daa76 100644 --- a/nx_arangodb/convert.py +++ b/nx_arangodb/convert.py @@ -135,7 +135,7 @@ def nxadb_to_nx(G: nxadb.Graph) -> nx.Graph: symmetrize_edges_if_directed=G.symmetrize_edges if G.is_directed() else False, ) - print(f"ADB -> Dictionaries load took {time.time() - start_time}s") + print(f"Graph '{G.adb_graph.name}' load took {time.time() - start_time}s") G_NX: nx.Graph | nx.DiGraph = G.to_networkx_class()() G_NX._node = node_dict @@ -153,57 +153,45 @@ def nxadb_to_nx(G: nxadb.Graph) -> nx.Graph: if GPU_ENABLED: def nxadb_to_nxcg(G: nxadb.Graph, as_directed: bool = False) -> nxcg.Graph: - if ( - G.use_coo_cache - and G.src_indices is not None - and G.dst_indices is not None - and G.edge_indices is not None - and G.vertex_ids_to_index is not None - and G.edge_values is not None - ): - m = "**use_coo_cache** is enabled. using cached COO data. no pull required." + if G.use_nxcg_cache and G.nxcg_graph is not None: + m = "**use_nxcg_cache** is enabled. using cached NXCG Graph. no pull required." # noqa logger.debug(m) - else: - start_time = time.time() - - ( - _, - _, - src_indices, - dst_indices, - edge_indices, - vertex_ids_to_index, - edge_values, - ) = nxadb.classes.function.get_arangodb_graph( - adb_graph=G.adb_graph, - load_node_dict=False, - load_adj_dict=False, - load_coo=True, - edge_collections_attributes=G.get_edge_attributes, - load_all_vertex_attributes=False, # not used - load_all_edge_attributes=do_load_all_edge_attributes( - G.get_edge_attributes - ), - is_directed=G.is_directed(), - is_multigraph=G.is_multigraph(), - symmetrize_edges_if_directed=( - G.symmetrize_edges if G.is_directed() else False - ), - ) - - print(f"ADB -> COO load took {time.time() - start_time}s") - - G.src_indices = src_indices - G.dst_indices = dst_indices - G.edge_indices = edge_indices - G.vertex_ids_to_index = vertex_ids_to_index - G.edge_values = edge_values - - N = len(G.vertex_ids_to_index) # type: ignore - src_indices_cp = cp.array(G.src_indices) - dst_indices_cp = cp.array(G.dst_indices) - edge_indices_cp = cp.array(G.edge_indices) + return G.nxcg_graph + + start_time = time.time() + + ( + _, + _, + src_indices, + dst_indices, + edge_indices, + vertex_ids_to_index, + edge_values, + ) = nxadb.classes.function.get_arangodb_graph( + adb_graph=G.adb_graph, + load_node_dict=False, + load_adj_dict=False, + load_coo=True, + edge_collections_attributes=G.get_edge_attributes, + load_all_vertex_attributes=False, # not used + load_all_edge_attributes=do_load_all_edge_attributes(G.get_edge_attributes), + is_directed=G.is_directed(), + is_multigraph=G.is_multigraph(), + symmetrize_edges_if_directed=( + G.symmetrize_edges if G.is_directed() else False + ), + ) + + print(f"ADB Graph '{G.adb_graph.name}' load took {time.time() - start_time}s") + + start_time = time.time() + + N = len(vertex_ids_to_index) + src_indices_cp = cp.array(src_indices) + dst_indices_cp = cp.array(dst_indices) + edge_indices_cp = cp.array(edge_indices) if G.is_multigraph(): if G.is_directed() or as_directed: @@ -211,16 +199,16 @@ def nxadb_to_nxcg(G: nxadb.Graph, as_directed: bool = False) -> nxcg.Graph: else: klass = nxcg.MultiGraph - return klass.from_coo( + G.nxcg_graph = klass.from_coo( N=N, src_indices=src_indices_cp, dst_indices=dst_indices_cp, edge_indices=edge_indices_cp, - edge_values=G.edge_values, + edge_values=edge_values, # edge_masks, # node_values, # node_masks, - key_to_id=G.vertex_ids_to_index, + key_to_id=vertex_ids_to_index, # edge_keys=edge_keys, ) @@ -230,13 +218,17 @@ def nxadb_to_nxcg(G: nxadb.Graph, as_directed: bool = False) -> nxcg.Graph: else: klass = nxcg.Graph - return klass.from_coo( + G.nxcg_graph = klass.from_coo( N=N, src_indices=src_indices_cp, dst_indices=dst_indices_cp, - edge_values=G.edge_values, + edge_values=edge_values, # edge_masks, # node_values, # node_masks, - key_to_id=G.vertex_ids_to_index, + key_to_id=vertex_ids_to_index, ) + + print(f"NXCG Graph construction took {time.time() - start_time}s") + + return G.nxcg_graph