diff --git a/nx_arangodb/classes/dict.py b/nx_arangodb/classes/dict.py index 32949822..b73cd29a 100644 --- a/nx_arangodb/classes/dict.py +++ b/nx_arangodb/classes/dict.py @@ -21,10 +21,7 @@ from .enum import DIRECTED_GRAPH_TYPES, MULTIGRAPH_TYPES, GraphType, TraversalDirection from .function import ( aql, - aql_as_list, aql_doc_get_key, - aql_doc_get_keys, - aql_doc_get_length, aql_doc_has_key, aql_edge_count_src, aql_edge_count_src_dst, @@ -33,7 +30,6 @@ aql_edge_id, aql_fetch_data, aql_fetch_data_edge, - aql_single, create_collection, doc_delete, doc_get_or_insert, @@ -45,7 +41,6 @@ get_update_dict, json_serializable, key_is_adb_id_or_int, - key_is_int, key_is_not_reserved, key_is_string, keys_are_not_reserved, @@ -752,6 +747,7 @@ def _fetch_all(self): load_node_dict=True, load_adj_dict=False, load_coo=False, + edge_collections_attributes=set(), load_all_vertex_attributes=True, load_all_edge_attributes=False, # not used is_directed=False, # not used @@ -2254,6 +2250,7 @@ def set_edge_multigraph( load_node_dict=False, load_adj_dict=True, load_coo=False, + edge_collections_attributes=set(), load_all_vertex_attributes=False, # not used load_all_edge_attributes=True, is_directed=self.is_directed, diff --git a/nx_arangodb/classes/digraph.py b/nx_arangodb/classes/digraph.py index 8989ac1d..801bc21c 100644 --- a/nx_arangodb/classes/digraph.py +++ b/nx_arangodb/classes/digraph.py @@ -28,6 +28,7 @@ def __init__( default_node_type: str | None = None, edge_type_key: str = "_edge_type", edge_type_func: Callable[[str, str], str] | None = None, + edge_collections_attributes: set[str] | None = None, db: StandardDatabase | None = None, read_parallelism: int = 10, read_batch_size: int = 100000, @@ -41,6 +42,7 @@ def __init__( default_node_type, edge_type_key, edge_type_func, + edge_collections_attributes, db, read_parallelism, read_batch_size, diff --git a/nx_arangodb/classes/function.py b/nx_arangodb/classes/function.py index 781c6461..ef87dc6d 100644 --- 
a/nx_arangodb/classes/function.py +++ b/nx_arangodb/classes/function.py @@ -21,6 +21,7 @@ DiGraphAdjDict, DstIndices, EdgeIndices, + EdgeValuesDict, GraphAdjDict, MultiDiGraphAdjDict, MultiGraphAdjDict, @@ -38,11 +39,19 @@ ) +def do_load_all_edge_attributes(attributes: set[str]) -> bool: + if len(attributes) == 0: + return True + + return False + + def get_arangodb_graph( adb_graph: Graph, load_node_dict: bool, load_adj_dict: bool, load_coo: bool, + edge_collections_attributes: set[str], load_all_vertex_attributes: bool, load_all_edge_attributes: bool, is_directed: bool, @@ -55,6 +64,7 @@ def get_arangodb_graph( DstIndices, EdgeIndices, ArangoIDtoIndex, + EdgeValuesDict, ]: """Pulls the graph from the database, assuming the graph exists. @@ -71,7 +81,7 @@ def get_arangodb_graph( metagraph: dict[str, dict[str, Any]] = { "vertexCollections": {col: set() for col in v_cols}, - "edgeCollections": {col: set() for col in e_cols}, + "edgeCollections": {col: edge_collections_attributes for col in e_cols}, } if not any((load_node_dict, load_adj_dict, load_coo)): @@ -89,6 +99,21 @@ def get_arangodb_graph( assert config.username assert config.password + res_do_load_all_edge_attributes = do_load_all_edge_attributes( + edge_collections_attributes + ) + + if res_do_load_all_edge_attributes is not load_all_edge_attributes: + if len(edge_collections_attributes) > 0: + raise ValueError( + "You have specified to load at least one specific edge attribute" + " and at the same time set the parameter `load_all_edge_attributes`" + " to true. This combination is not allowed." 
+ ) + else: + # We need this case as the user wants by purpose to not load any edge data + res_do_load_all_edge_attributes = load_all_edge_attributes + ( node_dict, adj_dict, @@ -106,7 +131,7 @@ def get_arangodb_graph( load_adj_dict=load_adj_dict, load_coo=load_coo, load_all_vertex_attributes=load_all_vertex_attributes, - load_all_edge_attributes=load_all_edge_attributes, + load_all_edge_attributes=res_do_load_all_edge_attributes, is_directed=is_directed, is_multigraph=is_multigraph, symmetrize_edges_if_directed=symmetrize_edges_if_directed, @@ -121,6 +146,7 @@ def get_arangodb_graph( dst_indices, edge_indices, vertex_ids_to_index, + edge_values, ) diff --git a/nx_arangodb/classes/graph.py b/nx_arangodb/classes/graph.py index 843bb2b2..b7589294 100644 --- a/nx_arangodb/classes/graph.py +++ b/nx_arangodb/classes/graph.py @@ -48,6 +48,7 @@ def __init__( default_node_type: str | None = None, edge_type_key: str = "_edge_type", edge_type_func: Callable[[str, str], str] | None = None, + edge_collections_attributes: set[str] | None = None, db: StandardDatabase | None = None, read_parallelism: int = 10, read_batch_size: int = 100000, @@ -69,6 +70,8 @@ def __init__( self.read_batch_size = read_batch_size self.write_batch_size = write_batch_size + self._set_edge_collections_attributes_to_fetch(edge_collections_attributes) + # NOTE: Need to revisit these... 
# self.maintain_node_dict_cache = False # self.maintain_adj_dict_cache = False @@ -80,6 +83,7 @@ def __init__( self.dst_indices: npt.NDArray[np.int64] | None = None self.edge_indices: npt.NDArray[np.int64] | None = None self.vertex_ids_to_index: dict[str, int] | None = None + self.edge_values: dict[str, list[int | float]] | None = None # Does not apply to undirected graphs self.symmetrize_edges = symmetrize_edges @@ -236,6 +240,17 @@ def _set_factory_methods(self) -> None: *adj_args, self.symmetrize_edges ) + def _set_edge_collections_attributes_to_fetch( + self, attributes: set[str] | None + ) -> None: + if attributes is None: + self._edge_collections_attributes = set() + return + if len(attributes) > 0: + self._edge_collections_attributes = attributes + if "_id" not in attributes: + self._edge_collections_attributes.add("_id") + ########### # Getters # ########### @@ -258,6 +273,10 @@ def graph_name(self) -> str: def graph_exists_in_db(self) -> bool: return self._graph_exists_in_db + @property + def get_edge_attributes(self) -> set[str]: + return self._edge_collections_attributes + ########### # Setters # ########### diff --git a/nx_arangodb/classes/multidigraph.py b/nx_arangodb/classes/multidigraph.py index 2ec8d3cc..189d987c 100644 --- a/nx_arangodb/classes/multidigraph.py +++ b/nx_arangodb/classes/multidigraph.py @@ -27,6 +27,7 @@ def __init__( default_node_type: str | None = None, edge_type_key: str = "_edge_type", edge_type_func: Callable[[str, str], str] | None = None, + edge_collections_attributes: set[str] | None = None, db: StandardDatabase | None = None, read_parallelism: int = 10, read_batch_size: int = 100000, @@ -40,6 +41,7 @@ def __init__( default_node_type, edge_type_key, edge_type_func, + edge_collections_attributes, db, read_parallelism, read_batch_size, diff --git a/nx_arangodb/classes/multigraph.py b/nx_arangodb/classes/multigraph.py index b3a3a43f..c108456e 100644 --- a/nx_arangodb/classes/multigraph.py +++ b/nx_arangodb/classes/multigraph.py 
@@ -28,6 +28,7 @@ def __init__( default_node_type: str | None = None, edge_type_key: str = "_edge_type", edge_type_func: Callable[[str, str], str] | None = None, + edge_collections_attributes: set[str] | None = None, db: StandardDatabase | None = None, read_parallelism: int = 10, read_batch_size: int = 100000, @@ -40,6 +41,7 @@ def __init__( default_node_type, edge_type_key, edge_type_func, + edge_collections_attributes, db, read_parallelism, read_batch_size, diff --git a/nx_arangodb/convert.py b/nx_arangodb/convert.py index 72a5a09a..a82ede16 100644 --- a/nx_arangodb/convert.py +++ b/nx_arangodb/convert.py @@ -6,6 +6,7 @@ import networkx as nx import nx_arangodb as nxadb +from nx_arangodb.classes.function import do_load_all_edge_attributes from nx_arangodb.logger import logger try: @@ -126,9 +127,9 @@ def nxadb_to_nx(G: nxadb.Graph) -> nx.Graph: load_node_dict=True, load_adj_dict=True, load_coo=False, + edge_collections_attributes=G.get_edge_attributes, load_all_vertex_attributes=False, - # TODO: Only return the edge attributes that are needed - load_all_edge_attributes=True, + load_all_edge_attributes=do_load_all_edge_attributes(G.get_edge_attributes), is_directed=G.is_directed(), is_multigraph=G.is_multigraph(), symmetrize_edges_if_directed=G.symmetrize_edges if G.is_directed() else False, @@ -158,6 +159,7 @@ def nxadb_to_nxcg(G: nxadb.Graph, as_directed: bool = False) -> nxcg.Graph: and G.dst_indices is not None and G.edge_indices is not None and G.vertex_ids_to_index is not None + and G.edge_values is not None ): m = "**use_coo_cache** is enabled. using cached COO data. no pull required." 
logger.debug(m) @@ -165,20 +167,29 @@ def nxadb_to_nxcg(G: nxadb.Graph, as_directed: bool = False) -> nxcg.Graph: else: start_time = time.time() - _, _, src_indices, dst_indices, edge_indices, vertex_ids_to_index = ( - nxadb.classes.function.get_arangodb_graph( - adb_graph=G.adb_graph, - load_node_dict=False, - load_adj_dict=False, - load_coo=True, - load_all_vertex_attributes=False, # not used - load_all_edge_attributes=False, # not used - is_directed=G.is_directed(), - is_multigraph=G.is_multigraph(), - symmetrize_edges_if_directed=( - G.symmetrize_edges if G.is_directed() else False - ), - ) + ( + _, + _, + src_indices, + dst_indices, + edge_indices, + vertex_ids_to_index, + edge_values, + ) = nxadb.classes.function.get_arangodb_graph( + adb_graph=G.adb_graph, + load_node_dict=False, + load_adj_dict=False, + load_coo=True, + edge_collections_attributes=G.get_edge_attributes, + load_all_vertex_attributes=False, # not used + load_all_edge_attributes=do_load_all_edge_attributes( + G.get_edge_attributes + ), + is_directed=G.is_directed(), + is_multigraph=G.is_multigraph(), + symmetrize_edges_if_directed=( + G.symmetrize_edges if G.is_directed() else False + ), ) print(f"ADB -> COO load took {time.time() - start_time}s") @@ -187,6 +198,7 @@ def nxadb_to_nxcg(G: nxadb.Graph, as_directed: bool = False) -> nxcg.Graph: G.dst_indices = dst_indices G.edge_indices = edge_indices G.vertex_ids_to_index = vertex_ids_to_index + G.edge_values = edge_values N = len(G.vertex_ids_to_index) # type: ignore src_indices_cp = cp.array(G.src_indices) @@ -204,7 +216,7 @@ def nxadb_to_nxcg(G: nxadb.Graph, as_directed: bool = False) -> nxcg.Graph: src_indices=src_indices_cp, dst_indices=dst_indices_cp, edge_indices=edge_indices_cp, - # edge_values, + edge_values=G.edge_values, # edge_masks, # node_values, # node_masks, @@ -222,7 +234,7 @@ def nxadb_to_nxcg(G: nxadb.Graph, as_directed: bool = False) -> nxcg.Graph: N=N, src_indices=src_indices_cp, dst_indices=dst_indices_cp, - # edge_values, 
+ edge_values=G.edge_values, # edge_masks, # node_values, # node_masks, diff --git a/tests/test.py b/tests/test.py index 38d96b05..a7f08309 100644 --- a/tests/test.py +++ b/tests/test.py @@ -12,6 +12,23 @@ G_NX = nx.karate_club_graph() +def create_line_graph(load_attributes: set[str]) -> nxadb.Graph: + G = nx.Graph() + G.add_edge(1, 2, my_custom_weight=1) + G.add_edge(2, 3, my_custom_weight=1) + G.add_edge(3, 4, my_custom_weight=1000) + G.add_edge(4, 5, my_custom_weight=1000) + + if load_attributes: + return nxadb.Graph( + incoming_graph_data=G, + graph_name="LineGraph", + edge_collections_attributes=load_attributes, + ) + + return nxadb.Graph(incoming_graph_data=G, graph_name="LineGraph") + + def assert_same_dict_values( d1: dict[str | int, float], d2: dict[str | int, float], digit: int ) -> None: @@ -80,6 +97,69 @@ def test_load_graph_from_nxadb(): db.delete_graph(graph_name, drop_collections=True) +def test_load_graph_from_nxadb_w_specific_edge_attribute(): + graph_name = "KarateGraph" + + db.delete_graph(graph_name, drop_collections=True, ignore_missing=True) + + graph = nxadb.Graph( + graph_name=graph_name, + incoming_graph_data=G_NX, + default_node_type="person", + edge_collections_attributes={"weight"}, + ) + # TODO: re-enable this line as soon as CPU based data caching is implemented + # graph._adj._fetch_all() + + for _from, adj in graph._adj.items(): + for _to, edge in adj.items(): + assert "weight" in edge + assert isinstance(edge["weight"], (int, float)) + + # call without specifying weight, fallback to weight: 1 for each + nx.pagerank(graph) + + # call with specifying weight + nx.pagerank(graph, weight="weight") + + db.delete_graph(graph_name, drop_collections=True) + + +def test_load_graph_from_nxadb_w_not_available_edge_attribute(): + graph_name = "KarateGraph" + + db.delete_graph(graph_name, drop_collections=True, ignore_missing=True) + + graph = nxadb.Graph( + graph_name=graph_name, + incoming_graph_data=G_NX, + default_node_type="person", + # This 
will lead to weight not being loaded into the edge data + edge_collections_attributes={"_id"}, + ) + + # Should just succeed without any errors (fallback to weight: 1 as above) + nx.pagerank(graph, weight="weight_does_not_exist") + + db.delete_graph(graph_name, drop_collections=True) + + +def test_load_graph_with_non_default_weight_attribute(): + graph_name = "LineGraph" + + db.delete_graph(graph_name, drop_collections=True, ignore_missing=True) + + graph = create_line_graph(load_attributes={"my_custom_weight"}) + res_custom = nx.pagerank(graph, weight="my_custom_weight") + res_default = nx.pagerank(graph) + + # to check that the results are different in case of different weights + # custom specified weights vs. fallback default weight to 1 + assert res_custom != res_default + + db.delete_graph(graph_name, drop_collections=True) + + @pytest.mark.parametrize( "algorithm_func, assert_func", [