diff --git a/nx_arangodb/classes/coreviews.py b/nx_arangodb/classes/coreviews.py index 08abcd48..794a648e 100644 --- a/nx_arangodb/classes/coreviews.py +++ b/nx_arangodb/classes/coreviews.py @@ -5,3 +5,12 @@ class CustomAdjacencyView(nx.classes.coreviews.AdjacencyView): def update(self, data): return self._atlas.update(data) + + def __getitem__(self, name): + return CustomAtlasView(self._atlas[name]) + + +class CustomAtlasView(nx.classes.coreviews.AtlasView): + + def update(self, data): + return self._atlas.update(data) diff --git a/nx_arangodb/classes/dict/adj.py b/nx_arangodb/classes/dict/adj.py index c2543eaa..5af99abe 100644 --- a/nx_arangodb/classes/dict/adj.py +++ b/nx_arangodb/classes/dict/adj.py @@ -4,7 +4,7 @@ from collections import UserDict from collections.abc import Iterator from itertools import islice -from typing import Any, Callable +from typing import Any, Callable, Dict, List from arango.database import StandardDatabase from arango.exceptions import DocumentDeleteError @@ -40,6 +40,7 @@ get_node_id, get_node_type_and_id, get_update_dict, + is_arangodb_id, json_serializable, key_is_adb_id_or_int, key_is_not_reserved, @@ -47,6 +48,7 @@ keys_are_not_reserved, keys_are_strings, logger_debug, + read_collection_name_from_local_id, separate_edges_by_collections, upsert_collection_edges, ) @@ -1196,7 +1198,50 @@ def clear(self) -> None: @logger_debug def update(self, edges: Any) -> None: """g._adj['node/1'].update({'node/2': {'foo': 'bar'}})""" - raise NotImplementedError("AdjListInnerDict.update()") + from_col_name = read_collection_name_from_local_id( + self.src_node_id, self.default_node_type + ) + + to_upsert: Dict[str, List[Dict[str, Any]]] = {from_col_name: []} + + for edge_id, edge_data in edges.items(): + edge_doc = edge_data + edge_doc["_from"] = self.src_node_id + edge_doc["_to"] = edge_id + + edge_doc_id = edge_data.get("_id") + assert is_arangodb_id(edge_doc_id) + edge_col_name = read_collection_name_from_local_id( + edge_doc_id, self.default_node_type + ) + + if to_upsert.get(edge_col_name) is None: + to_upsert[edge_col_name] = [edge_doc] + else: + to_upsert[edge_col_name].append(edge_doc) + + # perform write to ArangoDB + result = upsert_collection_edges(self.db, to_upsert) + + all_good = check_list_for_errors(result) + if all_good: + # Means no single operation failed, in this case we update the local cache + self.__set_adj_elements(edges) + else: + # In this case some or all documents failed. Right now we will not + # update the local cache, but raise an error instead. + # Reason: We cannot set silent to True, because we need as it does + # not report errors then. We need to update the driver to also pass + # the errors back to the user, then we can adjust the behavior here. + # This will also save network traffic and local computation time. + errors = [] + for collections_results in result: + for collection_result in collections_results: + errors.append(collection_result) + logger.warning( + "Failed to insert at least one node. Will not update local cache." + ) + raise ArangoDBBatchError(errors) @logger_debug def values(self) -> Any: @@ -1242,14 +1287,25 @@ def _fetch_all(self) -> None: self.FETCHED_ALL_DATA = True self.FETCHED_ALL_IDS = True + def __set_adj_elements(self, edges): + for dst_node_id, edge in edges.items(): + edge_attr_dict: EdgeAttrDict = self._create_edge_attr_dict(edge) + + self.__fetch_all_helper(edge_attr_dict, dst_node_id, is_update=True) + @logger_debug - def __fetch_all_graph(self, edge_attr_dict: EdgeAttrDict, dst_node_id: str) -> None: + def __fetch_all_graph( + self, edge_attr_dict: EdgeAttrDict, dst_node_id: str, is_update: bool = False + ) -> None: """Helper function for _fetch_all() in Graphs.""" if dst_node_id in self.data: # Don't raise an error if it's a self-loop if self.data[dst_node_id] == edge_attr_dict: return + if is_update: + return + m = "Multiple edges between the same nodes are not supported in Graphs." m += f" Found 2 edges between {self.src_node_id} & {dst_node_id}." m += " Consider using a MultiGraph." @@ -1259,7 +1315,7 @@ def __fetch_all_graph(self, edge_attr_dict: EdgeAttrDict, dst_node_id: str) -> N @logger_debug def __fetch_all_multigraph( - self, edge_attr_dict: EdgeAttrDict, dst_node_id: str + self, edge_attr_dict: EdgeAttrDict, dst_node_id: str, is_update: bool = False ) -> None: """Helper function for _fetch_all() in MultiGraphs.""" edge_key_dict = self.data.get(dst_node_id) diff --git a/nx_arangodb/classes/function.py b/nx_arangodb/classes/function.py index 52808c50..58b13e74 100644 --- a/nx_arangodb/classes/function.py +++ b/nx_arangodb/classes/function.py @@ -5,7 +5,7 @@ from __future__ import annotations -from typing import Any, Callable, Tuple +from typing import Any, Callable, Optional, Tuple import networkx as nx from arango import ArangoError, DocumentInsertError @@ -710,6 +710,27 @@ def get_arangodb_collection_key_tuple(key): return key.split("/", 1) +def extract_arangodb_collection_name(arangodb_id: str) -> str: + if not is_arangodb_id(arangodb_id): + raise ValueError(f"Invalid ArangoDB key: {arangodb_id}") + return arangodb_id.split("/")[0] + + +def read_collection_name_from_local_id( + local_id: Optional[str], default_collection: str +) -> str: + if local_id is None: + print("local_id is None, cannot read collection name.") + return "" + + if is_arangodb_id(local_id): + return extract_arangodb_collection_name(local_id) + + assert default_collection is not None + assert default_collection != "" + return default_collection + + def separate_nodes_by_collections(nodes: Any, default_collection: str) -> Any: """ Separate the dictionary into collections based on whether keys contain '/'. diff --git a/tests/test.py b/tests/test.py index 2da38d58..25306b8f 100644 --- a/tests/test.py +++ b/tests/test.py @@ -12,7 +12,7 @@ ) import nx_arangodb as nxadb -from nx_arangodb.classes.dict.adj import AdjListOuterDict, EdgeAttrDict +from nx_arangodb.classes.dict.adj import AdjListOuterDict, EdgeAttrDict, EdgeKeyDict from nx_arangodb.classes.dict.node import NodeAttrDict, NodeDict from .conftest import create_line_graph, db @@ -626,6 +626,99 @@ def test_edge_dict_update_multiple_collections(load_two_relation_graph: Any) -> assert f"{v_1_name}/{3}" in local_edge_cache[f"{v_2_name}/{1}"] +@pytest.mark.parametrize( + "graph_cls", + [ + (nxadb.Graph), + (nxadb.DiGraph), + ], +) +def test_edge_adj_inner_dict_update_existing_single_collection( + load_karate_graph: Any, graph_cls: type[nxadb.Graph] +) -> None: + G_1 = graph_cls(name="KarateGraph", foo="bar", use_experimental_views=True) + + local_adj = G_1.adj + local_inner_edges_dict: GraphAdjDict = {} + from_doc_id_to_use: str = "person/9" + + target_dict = local_adj[from_doc_id_to_use] + for to_doc_id, edge_doc in target_dict.items(): + # will contain three items/documents + edge_doc_id = edge_doc["_id"] + local_inner_edges_dict[to_doc_id] = { + "_id": edge_doc_id, + "extraValue": edge_doc["_key"], + } + + G_1.adj[from_doc_id_to_use].update(local_inner_edges_dict) + + edge_col = db.collection("knows") + edge_col_docs = edge_col.all() + + # Check if the extraValue attribute was added to requested docs in ADB + for doc in edge_col_docs: + if doc["_from"] == from_doc_id_to_use: + assert "extraValue" in doc + assert doc["extraValue"] == doc["_key"] + + # Check if the extraValue attribute was added to each document in the local cache + for to_doc_id in local_inner_edges_dict.keys(): + assert "extraValue" in G_1._adj[from_doc_id_to_use][to_doc_id] + assert G_1.adj[from_doc_id_to_use][to_doc_id][ + "extraValue" + ] == extract_arangodb_key(local_inner_edges_dict[to_doc_id]["_id"]) + return + + +@pytest.mark.parametrize( + "graph_cls", + [ + (nxadb.MultiGraph), + (nxadb.MultiDiGraph), + ], +) +def test_edge_adj_inner_dict_update_existing_single_collection_multi_graphs( + load_karate_graph: Any, graph_cls: type[nxadb.Graph] +) -> None: + G_1 = graph_cls(name="KarateGraph", foo="bar", use_experimental_views=True) + + local_adj = G_1.adj + local_inner_edges_dict: GraphAdjDict = {} + from_doc_id_to_use: str = "person/9" + + target_dict = local_adj[from_doc_id_to_use] + for outer_to_doc_id, edge_key_dict in target_dict.items(): + assert isinstance(edge_key_dict, EdgeKeyDict) + + for to_doc_id, edge_doc in edge_key_dict.items(): + edge_doc_id = edge_doc["_id"] + local_inner_edges_dict[to_doc_id] = { + "_id": edge_doc_id, + "extraValue": edge_doc["_key"], + } + + G_1.adj[from_doc_id_to_use].update(local_inner_edges_dict) + + edge_col = db.collection("knows") + edge_col_docs = edge_col.all() + + # Check if the extraValue attribute was added to requested docs in ADB + for doc in edge_col_docs: + if doc["_from"] == from_doc_id_to_use: + assert "extraValue" in doc + assert doc["extraValue"] == doc["_key"] + + # Check if the extraValue attribute was added to each document in the local cache + for to_doc_id in local_inner_edges_dict.keys(): + assert to_doc_id in G_1.adj[from_doc_id_to_use][to_doc_id] + + assert "extraValue" in G_1.adj[from_doc_id_to_use][to_doc_id][to_doc_id] + assert G_1.adj[from_doc_id_to_use][to_doc_id][to_doc_id][ + "extraValue" + ] == extract_arangodb_key(local_inner_edges_dict[to_doc_id]["_id"]) + + @pytest.mark.parametrize( "graph_cls", [