Skip to content

fix: cache nxcg graph instead of coo representation #31

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions nx_arangodb/classes/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,10 @@ def __init__(
# self.maintain_node_dict_cache = False
# self.maintain_adj_dict_cache = False
self.use_nx_cache = True
self.use_coo_cache = True
# self.__qa_chain = None
self.use_nxcg_cache = True
self.nxcg_graph = None

self.src_indices: npt.NDArray[np.int64] | None = None
self.dst_indices: npt.NDArray[np.int64] | None = None
self.edge_indices: npt.NDArray[np.int64] | None = None
self.vertex_ids_to_index: dict[str, int] | None = None
self.edge_values: dict[str, list[int | float]] | None = None
# self.__qa_chain = None

# Does not apply to undirected graphs
self.symmetrize_edges = symmetrize_edges
Expand Down Expand Up @@ -379,6 +375,9 @@ def clear_edges(self):
logger.info("Note that clearing edges ony erases the edges in the local cache")
super().clear_edges()

def clear_nxcg_cache(self):
self.nxcg_graph = None

@cached_property
def nodes(self):
if self.graph_exists_in_db:
Expand Down
104 changes: 48 additions & 56 deletions nx_arangodb/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def nxadb_to_nx(G: nxadb.Graph) -> nx.Graph:
symmetrize_edges_if_directed=G.symmetrize_edges if G.is_directed() else False,
)

print(f"ADB -> Dictionaries load took {time.time() - start_time}s")
print(f"Graph '{G.adb_graph.name}' load took {time.time() - start_time}s")

G_NX: nx.Graph | nx.DiGraph = G.to_networkx_class()()
G_NX._node = node_dict
Expand All @@ -153,74 +153,62 @@ def nxadb_to_nx(G: nxadb.Graph) -> nx.Graph:
if GPU_ENABLED:

def nxadb_to_nxcg(G: nxadb.Graph, as_directed: bool = False) -> nxcg.Graph:
if (
G.use_coo_cache
and G.src_indices is not None
and G.dst_indices is not None
and G.edge_indices is not None
and G.vertex_ids_to_index is not None
and G.edge_values is not None
):
m = "**use_coo_cache** is enabled. using cached COO data. no pull required."
if G.use_nxcg_cache and G.nxcg_graph is not None:
m = "**use_nxcg_cache** is enabled. using cached NXCG Graph. no pull required." # noqa
logger.debug(m)

else:
start_time = time.time()

(
_,
_,
src_indices,
dst_indices,
edge_indices,
vertex_ids_to_index,
edge_values,
) = nxadb.classes.function.get_arangodb_graph(
adb_graph=G.adb_graph,
load_node_dict=False,
load_adj_dict=False,
load_coo=True,
edge_collections_attributes=G.get_edge_attributes,
load_all_vertex_attributes=False, # not used
load_all_edge_attributes=do_load_all_edge_attributes(
G.get_edge_attributes
),
is_directed=G.is_directed(),
is_multigraph=G.is_multigraph(),
symmetrize_edges_if_directed=(
G.symmetrize_edges if G.is_directed() else False
),
)

print(f"ADB -> COO load took {time.time() - start_time}s")

G.src_indices = src_indices
G.dst_indices = dst_indices
G.edge_indices = edge_indices
G.vertex_ids_to_index = vertex_ids_to_index
G.edge_values = edge_values

N = len(G.vertex_ids_to_index) # type: ignore
src_indices_cp = cp.array(G.src_indices)
dst_indices_cp = cp.array(G.dst_indices)
edge_indices_cp = cp.array(G.edge_indices)
return G.nxcg_graph

start_time = time.time()

(
_,
_,
src_indices,
dst_indices,
edge_indices,
vertex_ids_to_index,
edge_values,
) = nxadb.classes.function.get_arangodb_graph(
adb_graph=G.adb_graph,
load_node_dict=False,
load_adj_dict=False,
load_coo=True,
edge_collections_attributes=G.get_edge_attributes,
load_all_vertex_attributes=False, # not used
load_all_edge_attributes=do_load_all_edge_attributes(G.get_edge_attributes),
is_directed=G.is_directed(),
is_multigraph=G.is_multigraph(),
symmetrize_edges_if_directed=(
G.symmetrize_edges if G.is_directed() else False
),
)

print(f"ADB Graph '{G.adb_graph.name}' load took {time.time() - start_time}s")

start_time = time.time()

N = len(vertex_ids_to_index)
src_indices_cp = cp.array(src_indices)
dst_indices_cp = cp.array(dst_indices)
edge_indices_cp = cp.array(edge_indices)

if G.is_multigraph():
if G.is_directed() or as_directed:
klass = nxcg.MultiDiGraph
else:
klass = nxcg.MultiGraph

return klass.from_coo(
G.nxcg_graph = klass.from_coo(
N=N,
src_indices=src_indices_cp,
dst_indices=dst_indices_cp,
edge_indices=edge_indices_cp,
edge_values=G.edge_values,
edge_values=edge_values,
# edge_masks,
# node_values,
# node_masks,
key_to_id=G.vertex_ids_to_index,
key_to_id=vertex_ids_to_index,
# edge_keys=edge_keys,
)

Expand All @@ -230,13 +218,17 @@ def nxadb_to_nxcg(G: nxadb.Graph, as_directed: bool = False) -> nxcg.Graph:
else:
klass = nxcg.Graph

return klass.from_coo(
G.nxcg_graph = klass.from_coo(
N=N,
src_indices=src_indices_cp,
dst_indices=dst_indices_cp,
edge_values=G.edge_values,
edge_values=edge_values,
# edge_masks,
# node_values,
# node_masks,
key_to_id=G.vertex_ids_to_index,
key_to_id=vertex_ids_to_index,
)

print(f"NXCG Graph construction took {time.time() - start_time}s")

return G.nxcg_graph