diff --git a/nx_arangodb/classes/coreviews.py b/nx_arangodb/classes/coreviews.py new file mode 100644 index 00000000..08abcd48 --- /dev/null +++ b/nx_arangodb/classes/coreviews.py @@ -0,0 +1,7 @@ +import networkx as nx + + +class CustomAdjacencyView(nx.classes.coreviews.AdjacencyView): + + def update(self, data): + return self._atlas.update(data) diff --git a/nx_arangodb/classes/dict/adj.py b/nx_arangodb/classes/dict/adj.py index 380584eb..2a586384 100644 --- a/nx_arangodb/classes/dict/adj.py +++ b/nx_arangodb/classes/dict/adj.py @@ -1,5 +1,6 @@ from __future__ import annotations +import warnings from collections import UserDict from collections.abc import Iterator from itertools import islice @@ -8,12 +9,19 @@ from arango.database import StandardDatabase from arango.exceptions import DocumentDeleteError from arango.graph import Graph +from phenolrs.networkx.typings import ( + DiGraphAdjDict, + GraphAdjDict, + MultiDiGraphAdjDict, + MultiGraphAdjDict, +) from nx_arangodb.exceptions import EdgeTypeAmbiguity, MultipleEdgesFound from nx_arangodb.logger import logger from ..enum import DIRECTED_GRAPH_TYPES, MULTIGRAPH_TYPES, GraphType, TraversalDirection from ..function import ( + ArangoDBBatchError, aql, aql_doc_get_key, aql_doc_has_key, @@ -23,6 +31,7 @@ aql_edge_get, aql_edge_id, aql_fetch_data_edge, + check_list_for_errors, doc_insert, doc_update, get_arangodb_graph, @@ -36,6 +45,8 @@ keys_are_not_reserved, keys_are_strings, logger_debug, + separate_edges_by_collections, + upsert_collection_edges, ) ############# @@ -169,7 +180,7 @@ def __init__( self.graph = graph self.edge_id: str | None = None - # NodeAttrDict may be a child of another NodeAttrDict + # EdgeAttrDict may be a child of another EdgeAttrDict # e.g G._adj['node/1']['node/2']['object']['foo'] = 'bar' # In this case, **parent_keys** would be ['object'] # and **root** would be G._adj['node/1']['node/2'] @@ -1482,8 +1493,31 @@ def clear(self) -> None: @keys_are_strings @logger_debug def update(self, 
edges: Any) -> None: - """g._adj.update({'node/1': {'node/2': {'foo': 'bar'}})""" - raise NotImplementedError("AdjListOuterDict.update()") + """g._adj.update({'node/1': {'node/2': {'_id': 'foo/bar', 'foo': "bar"}}})""" + separated_by_edge_collection = separate_edges_by_collections( + edges, graph_type=self.graph_type, default_node_type=self.default_node_type + ) + result = upsert_collection_edges(self.db, separated_by_edge_collection) + + all_good = check_list_for_errors(result) + if all_good: + # Means no single operation failed, in this case we update the local cache + self.__set_adj_elements(edges) + else: + # In this case some or all documents failed. Right now we will not + # update the local cache, but raise an error instead. + # Reason: We cannot set silent to True, because it does + # not report errors then. We need to update the driver to also pass + # the errors back to the user, then we can adjust the behavior here. + # This will also save network traffic and local computation time. + errors = [] + for collections_results in result: + for collection_result in collections_results: + errors.append(collection_result) + warnings.warn( + "Failed to insert at least one node. Will not update local cache." 
+ ) + raise ArangoDBBatchError(errors) @logger_debug def values(self) -> Any: @@ -1507,22 +1541,44 @@ def items(self, data: str | None = None, default: Any | None = None) -> Any: yield from result @logger_debug - def _fetch_all(self) -> None: - self.clear() + def __set_adj_elements( + self, + edges_dict: ( + GraphAdjDict | DiGraphAdjDict | MultiGraphAdjDict | MultiDiGraphAdjDict + ), + ) -> None: + def set_edge_graph( + src_node_id: str, dst_node_id: str, edge: dict[str, Any] + ) -> EdgeAttrDict: + adjlist_inner_dict = self.data[src_node_id] + + edge_attr_dict: EdgeAttrDict + edge_attr_dict = adjlist_inner_dict._create_edge_attr_dict(edge) + + adjlist_inner_dict.data[dst_node_id] = edge_attr_dict + + return edge_attr_dict + + def set_edge_multigraph( + src_node_id: str, dst_node_id: str, edges: dict[int, dict[str, Any]] + ) -> EdgeKeyDict: + adjlist_inner_dict = self.data[src_node_id] + + edge_key_dict = adjlist_inner_dict.edge_key_dict_factory() + edge_key_dict.src_node_id = src_node_id + edge_key_dict.dst_node_id = dst_node_id + edge_key_dict.FETCHED_ALL_DATA = True + edge_key_dict.FETCHED_ALL_IDS = True + + for edge in edges.values(): + edge_attr_dict = adjlist_inner_dict._create_edge_attr_dict(edge) + edge_key_dict.data[edge["_id"]] = edge_attr_dict - def set_adj_inner_dict( - adj_outer_dict: AdjListOuterDict, node_id: str - ) -> AdjListInnerDict: - if node_id in adj_outer_dict.data: - return adj_outer_dict.data[node_id] + adjlist_inner_dict.data[dst_node_id] = edge_key_dict - adj_inner_dict = self.adjlist_inner_dict_factory() - adj_inner_dict.src_node_id = node_id - adj_inner_dict.FETCHED_ALL_DATA = True - adj_inner_dict.FETCHED_ALL_IDS = True - adj_outer_dict.data[node_id] = adj_inner_dict + return edge_key_dict - return adj_inner_dict + set_edge_func = set_edge_multigraph if self.is_multigraph else set_edge_graph def propagate_edge_undirected( src_node_id: str, @@ -1536,7 +1592,7 @@ def propagate_edge_directed( dst_node_id: str, edge_key_or_attr_dict: 
EdgeKeyDict | EdgeAttrDict, ) -> None: - set_adj_inner_dict(self.mirror, dst_node_id) + self.__set_adj_inner_dict(self.mirror, dst_node_id) self.mirror.data[dst_node_id].data[src_node_id] = edge_key_or_attr_dict def propagate_edge_directed_symmetric( @@ -1546,7 +1602,7 @@ def propagate_edge_directed_symmetric( ) -> None: propagate_edge_directed(src_node_id, dst_node_id, edge_key_or_attr_dict) propagate_edge_undirected(src_node_id, dst_node_id, edge_key_or_attr_dict) - set_adj_inner_dict(self.mirror, src_node_id) + self.__set_adj_inner_dict(self.mirror, src_node_id) self.mirror.data[src_node_id].data[dst_node_id] = edge_key_or_attr_dict propagate_edge_func = ( @@ -1559,38 +1615,39 @@ def propagate_edge_directed_symmetric( ) ) - def set_edge_graph( - src_node_id: str, dst_node_id: str, edge: dict[str, Any] - ) -> EdgeAttrDict: - adjlist_inner_dict = self.data[src_node_id] - - edge_attr_dict: EdgeAttrDict - edge_attr_dict = adjlist_inner_dict._create_edge_attr_dict(edge) + for src_node_id, inner_dict in edges_dict.items(): + for dst_node_id, edge_or_edges in inner_dict.items(): - adjlist_inner_dict.data[dst_node_id] = edge_attr_dict + if not self.is_directed: + if src_node_id in self.data: + if dst_node_id in self.data[src_node_id].data: + continue # can skip due not directed - return edge_attr_dict + self.__set_adj_inner_dict(self, src_node_id) + self.__set_adj_inner_dict(self, dst_node_id) + edge_attr_or_key_dict = set_edge_func( # type: ignore[operator] + src_node_id, dst_node_id, edge_or_edges + ) - def set_edge_multigraph( - src_node_id: str, dst_node_id: str, edges: dict[int, dict[str, Any]] - ) -> EdgeKeyDict: - adjlist_inner_dict = self.data[src_node_id] + propagate_edge_func(src_node_id, dst_node_id, edge_attr_or_key_dict) - edge_key_dict = adjlist_inner_dict.edge_key_dict_factory() - edge_key_dict.src_node_id = src_node_id - edge_key_dict.dst_node_id = dst_node_id - edge_key_dict.FETCHED_ALL_DATA = True - edge_key_dict.FETCHED_ALL_IDS = True + def 
__set_adj_inner_dict( + self, adj_outer_dict: AdjListOuterDict, node_id: str + ) -> AdjListInnerDict: + if node_id in adj_outer_dict.data: + return adj_outer_dict.data[node_id] - for edge in edges.values(): - edge_attr_dict = adjlist_inner_dict._create_edge_attr_dict(edge) - edge_key_dict.data[edge["_id"]] = edge_attr_dict - - adjlist_inner_dict.data[dst_node_id] = edge_key_dict + adj_inner_dict = self.adjlist_inner_dict_factory() + adj_inner_dict.src_node_id = node_id + adj_inner_dict.FETCHED_ALL_DATA = True + adj_inner_dict.FETCHED_ALL_IDS = True + adj_outer_dict.data[node_id] = adj_inner_dict - return edge_key_dict + return adj_inner_dict - set_edge_func = set_edge_multigraph if self.is_multigraph else set_edge_graph + @logger_debug + def _fetch_all(self) -> None: + self.clear() ( _, @@ -1613,21 +1670,7 @@ def set_edge_multigraph( if self.is_directed: adj_dict = adj_dict["succ"] - for src_node_id, inner_dict in adj_dict.items(): - for dst_node_id, edge_or_edges in inner_dict.items(): - - if not self.is_directed: - if src_node_id in self.data: - if dst_node_id in self.data[src_node_id].data: - continue # can skip due not directed - - set_adj_inner_dict(self, src_node_id) - set_adj_inner_dict(self, dst_node_id) - edge_attr_or_key_dict = set_edge_func( # type: ignore[operator] - src_node_id, dst_node_id, edge_or_edges - ) - - propagate_edge_func(src_node_id, dst_node_id, edge_attr_or_key_dict) + self.__set_adj_elements(adj_dict) self.FETCHED_ALL_DATA = True self.FETCHED_ALL_IDS = True diff --git a/nx_arangodb/classes/function.py b/nx_arangodb/classes/function.py index f0985b05..d62c92c8 100644 --- a/nx_arangodb/classes/function.py +++ b/nx_arangodb/classes/function.py @@ -30,6 +30,7 @@ from nx_arangodb.logger import logger from ..exceptions import AQLMultipleResultsFound, InvalidTraversalDirection +from .enum import GraphType def do_load_all_edge_attributes(attributes: set[str]) -> bool: @@ -242,7 +243,7 @@ def wrapper(self: Any, data: Any, *args: Any, **kwargs: 
Any) -> Any: return wrapper -RESERVED_KEYS = {"_id", "_key", "_rev"} +RESERVED_KEYS = {"_id", "_key", "_rev", "_from", "_to"} def key_is_not_reserved(func: Callable[..., Any]) -> Any: @@ -744,3 +745,107 @@ def upsert_collection_documents(db: StandardDatabase, separated: Any) -> Any: ) return results + + +def separate_edges_by_collections_graph(edges: Any, default_node_type: str) -> Any: + """ + Separate the dictionary into collections for Graph and DiGraph types. + :param edges: The input dictionary with keys that must contain the real doc id. + :param default_node_type: The name of the default collection for keys without '/'. + :return: A dictionary where the keys are collection names and the + values are lists of the edge documents belonging to those collections. + """ + separated: Any = {} + + for from_doc_id, target_dict in edges.items(): + for to_doc_id, edge_doc in target_dict.items(): + assert edge_doc is not None and "_id" in edge_doc + edge_collection_name = get_node_type_and_id( + edge_doc["_id"], default_node_type + )[0] + + if edge_collection_name not in separated: + separated[edge_collection_name] = [] + + edge_doc["_from"] = from_doc_id + edge_doc["_to"] = to_doc_id + + separated[edge_collection_name].append(edge_doc) + + return separated + + +def separate_edges_by_collections_multigraph(edges: Any, default_node_type: str) -> Any: + """ + Separate the dictionary into collections for MultiGraph and MultiDiGraph types. + :param edges: The input dictionary with keys that must contain the real doc id. + :param default_node_type: The name of the default collection for keys without '/'. + :return: A dictionary where the keys are collection names and the + values are lists of the edge documents belonging to those collections. 
+ """ + separated: Any = {} + + for from_doc_id, target_dict in edges.items(): + for to_doc_id, edge_doc in target_dict.items(): + # edge_doc is expected to be a dict of edges in Multi(Di)Graph + for m_edge_id, m_edge_doc in edge_doc.items(): + assert m_edge_doc is not None and "_id" in m_edge_doc + edge_collection_name = get_node_type_and_id( + m_edge_doc["_id"], default_node_type + )[0] + + if edge_collection_name not in separated: + separated[edge_collection_name] = [] + + m_edge_doc["_from"] = from_doc_id + m_edge_doc["_to"] = to_doc_id + + separated[edge_collection_name].append(m_edge_doc) + + return separated + + +def separate_edges_by_collections( + edges: Any, graph_type: str, default_node_type: str +) -> Any: + """ + Wrapper function to separate the dictionary into collections based on graph type. + :param edges: The input dictionary with keys that must contain the real doc id. + :param graph_type: The type of graph to create. + :param default_node_type: The name of the default collection for keys without '/'. + :return: A dictionary where the keys are collection names and the + values are lists of the edge documents belonging to those collections. + """ + if graph_type in [GraphType.Graph.name, GraphType.DiGraph.name]: + return separate_edges_by_collections_graph(edges, default_node_type) + elif graph_type in [GraphType.MultiGraph.name, GraphType.MultiDiGraph.name]: + return separate_edges_by_collections_multigraph(edges, default_node_type) + else: + raise ValueError(f"Unsupported graph type: {graph_type}") + + +def upsert_collection_edges(db: StandardDatabase, separated: Any) -> Any: + """ + Process each collection in the separated dictionary. + :param db: The ArangoDB database object. + :param separated: A dictionary where the keys are collection names and the + values are lists + of edge documents belonging to those collections. + :return: A list of results from the insert_many operation. 
+ If inserting a document fails, the exception is not raised but + returned as an object in the result list. + """ + + results = [] + + for collection_name, documents_list in separated.items(): + collection = db.collection(collection_name) + results.append( + collection.insert_many( + documents_list, + silent=False, + overwrite_mode="update", + ) + ) + + return results diff --git a/nx_arangodb/classes/graph.py b/nx_arangodb/classes/graph.py index 6a0f13db..954e4dde 100644 --- a/nx_arangodb/classes/graph.py +++ b/nx_arangodb/classes/graph.py @@ -19,6 +19,7 @@ ) from nx_arangodb.logger import logger +from .coreviews import CustomAdjacencyView from .dict import ( adjlist_inner_dict_factory, adjlist_outer_dict_factory, @@ -386,6 +387,14 @@ def nodes(self): return super().nodes + @cached_property + def adj(self): + if self.graph_exists_in_db: + logger.warning("nxadb.CustomAdjacencyView is currently EXPERIMENTAL") + return CustomAdjacencyView(self._adj) + + return super().adj + @cached_property def edges(self): if self.graph_exists_in_db: diff --git a/tests/test.py b/tests/test.py index f69c742d..79527220 100644 --- a/tests/test.py +++ b/tests/test.py @@ -1,8 +1,14 @@ -from typing import Any, Callable, Dict +from typing import Any, Callable, Dict, Union import networkx as nx import pytest from arango import DocumentDeleteError +from phenolrs.networkx.typings import ( + DiGraphAdjDict, + GraphAdjDict, + MultiDiGraphAdjDict, + MultiGraphAdjDict, +) import nx_arangodb as nxadb from nx_arangodb.classes.dict.adj import EdgeAttrDict @@ -13,6 +19,10 @@ G_NX = nx.karate_club_graph() +def extract_arangodb_key(adb_id: str) -> str: + return adb_id.split("/")[1] + + def create_line_graph(load_attributes: set[str]) -> nxadb.Graph: G = nx.Graph() G.add_edge(1, 2, my_custom_weight=1) @@ -271,9 +281,6 @@ def test_node_dict_update_existing_single_collection( # of them using the update method using a single collection G_1 = nxadb.Graph(graph_name="KarateGraph", foo="bar") - def 
extract_arangodb_key(adb_id: str) -> str: - return adb_id.split("/")[1] - nodes_ids_list = G_1.nodes local_nodes_dict = {} @@ -352,6 +359,166 @@ def test_node_dict_update_multiple_collections( assert f"{v_2_name}/{i}" in G_1.nodes +@pytest.mark.parametrize( + "graph_cls", + [ + (nxadb.Graph), + (nxadb.DiGraph), + ], +) +def test_edge_adj_dict_update_existing_single_collection_graph_and_digraph( + load_karate_graph: Any, graph_cls: type[nxadb.Graph] +) -> None: + G_1 = graph_cls(graph_name="KarateGraph", foo="bar") + + local_adj = G_1.adj + local_edges_dict: Union[GraphAdjDict | DiGraphAdjDict] = {} + if graph_cls == nxadb.Graph: + local_edges_dict = GraphAdjDict() + elif graph_cls == nxadb.DiGraph: + local_edges_dict = DiGraphAdjDict() + + for from_doc_id, target_dict in local_adj.items(): + for to_doc_id, edge_doc in target_dict.items(): + edge_doc_id = edge_doc["_id"] + if from_doc_id not in local_edges_dict: + local_edges_dict[from_doc_id] = {} + + local_edges_dict[from_doc_id][to_doc_id] = { + "_id": edge_doc_id, + "extraValue": edge_doc["_key"], + } + + G_1.adj.update(local_edges_dict) + + edge_col = db.collection("knows") + edge_col_docs = edge_col.all() + + # Check if the extraValue attribute was added to each document in the database + for doc in edge_col_docs: + assert "extraValue" in doc + assert doc["extraValue"] == doc["_key"] + + # Check if the extraValue attribute was added to each document in the local cache + for from_doc_id, target_dict in local_edges_dict.items(): + for to_doc_id, edge_doc in target_dict.items(): + assert "extraValue" in G_1._adj[from_doc_id][to_doc_id] + assert G_1.adj[from_doc_id][to_doc_id][ + "extraValue" + ] == extract_arangodb_key(edge_doc["_id"]) + + +@pytest.mark.parametrize( + "graph_cls", + [ + (nxadb.MultiGraph), + (nxadb.MultiDiGraph), + ], +) +def test_edge_adj_dict_update_existing_single_collection_MultiGraph_and_MultiDiGraph( + load_karate_graph: Any, graph_cls: type[nxadb.Graph] +) -> None: + G_1 = 
graph_cls(graph_name="KarateGraph", foo="bar") + + local_adj = G_1.adj + local_edges_dict: Union[MultiGraphAdjDict | MultiDiGraphAdjDict] = {} + if graph_cls == nxadb.MultiGraph: + local_edges_dict = MultiGraphAdjDict() + elif graph_cls == nxadb.MultiDiGraph: + local_edges_dict = MultiDiGraphAdjDict() + + for from_doc_id, target_dict in local_adj.items(): + for to_doc_id, edge_dict in target_dict.items(): + for edge_id, edge_doc in edge_dict.items(): + if from_doc_id not in local_edges_dict: + local_edges_dict[from_doc_id] = {} + + if to_doc_id not in local_edges_dict[from_doc_id]: + local_edges_dict[from_doc_id][to_doc_id] = {} + + local_edges_dict[from_doc_id][to_doc_id][edge_id] = { + "_id": edge_doc["_id"], + "extraValue": edge_doc["_key"], + } + + G_1.adj.update(local_edges_dict) + + edge_col = db.collection("knows") + edge_col_docs = edge_col.all() + + # Check if the extraValue attribute was added to each document in the database + for doc in edge_col_docs: + assert "extraValue" in doc + assert doc["extraValue"] == doc["_key"] + + # Check if the extraValue attribute was added to each document in the local cache + for from_doc_id, target_dict in local_edges_dict.items(): + for to_doc_id, edge_dict in target_dict.items(): + for edge_id, edge_doc in edge_dict.items(): + assert "extraValue" in G_1._adj[from_doc_id][to_doc_id][edge_id] + assert G_1.adj[from_doc_id][to_doc_id][edge_id][ + "extraValue" + ] == extract_arangodb_key(edge_doc["_id"]) + + +def test_edge_dict_update_multiple_collections(load_two_relation_graph: Any) -> None: + graph_name = "IntegrationTestTwoRelationGraph" + v_1_name = graph_name + "_v1" + v_2_name = graph_name + "_v2" + e_1_name = graph_name + "_e1" + e_2_name = graph_name + "_e2" + + assert db.collection(v_1_name).count() == 0 + assert db.collection(v_2_name).count() == 0 + assert db.collection(e_1_name).count() == 0 + assert db.collection(e_2_name).count() == 0 + + G_1 = nxadb.Graph(graph_name=graph_name, default_node_type=v_1_name) + 
assert len(G_1.nodes) == 0 + assert len(G_1.edges) == 0 + + # inserts into first collection (by default) + new_edges_dict: GraphAdjDict = { + graph_name + + "_v1/1": { + graph_name + "_v1/2": {"_id": e_1_name + "/1"}, + graph_name + "_v1/3": {"_id": e_1_name + "/2"}, + }, + graph_name + + "_v2/1": { + graph_name + "_v1/2": {"_id": e_2_name + "/1"}, + graph_name + "_v1/3": {"_id": e_2_name + "/2"}, + }, + } + + G_1.adj.update(new_edges_dict) + + # _adj list is not responsible for maintaining the vertex collections + assert db.collection(v_1_name).count() == 0 + assert db.collection(v_2_name).count() == 0 + + assert db.collection(e_1_name).count() == 2 + assert db.collection(e_2_name).count() == 2 + + # Check that the edge ids are present in the database + assert db.has_document({"_id": e_1_name + "/1"}) + assert db.has_document({"_id": e_1_name + "/2"}) + assert db.has_document({"_id": e_2_name + "/1"}) + assert db.has_document({"_id": e_2_name + "/2"}) + + # Check local state + assert len(G_1.nodes) == 0 + assert len(G_1.edges) == 4 + + local_edge_cache = G_1._adj + assert f"{v_1_name}/{1}" in local_edge_cache + assert f"{v_2_name}/{1}" in local_edge_cache + assert f"{v_1_name}/{2}" in local_edge_cache[f"{v_1_name}/{1}"] + assert f"{v_1_name}/{3}" in local_edge_cache[f"{v_1_name}/{1}"] + assert f"{v_1_name}/{2}" in local_edge_cache[f"{v_2_name}/{1}"] + assert f"{v_1_name}/{3}" in local_edge_cache[f"{v_2_name}/{1}"] + + @pytest.mark.parametrize( "graph_cls", [