diff --git a/.circleci/config.yml b/.circleci/config.yml
index 28e367aa..7d162293 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -3,7 +3,7 @@ version: 2.1
 executors:
   python-executor:
     docker:
-      - image: circleci/python:3.10
+      - image: cimg/python:3.10
     environment:
       PACKAGE_DIR: nx_arangodb
       TESTS_DIR: tests
diff --git a/nx_arangodb/classes/dict/node.py b/nx_arangodb/classes/dict/node.py
index 5c4cba2f..62314aa7 100644
--- a/nx_arangodb/classes/dict/node.py
+++ b/nx_arangodb/classes/dict/node.py
@@ -10,10 +10,12 @@
 from nx_arangodb.logger import logger
 
 from ..function import (
+    ArangoDBBatchError,
     aql,
     aql_doc_get_key,
     aql_doc_has_key,
     aql_fetch_data,
+    check_list_for_errors,
     doc_delete,
     doc_insert,
     doc_update,
@@ -27,6 +29,8 @@
     keys_are_not_reserved,
     keys_are_strings,
     logger_debug,
+    separate_nodes_by_collections,
+    upsert_collection_documents,
 )
 
 #############
@@ -125,7 +129,8 @@ def __contains__(self, key: str) -> bool:
             return True
 
         assert self.node_id
-        return aql_doc_has_key(self.db, self.node_id, key, self.parent_keys)
+        result: bool = aql_doc_has_key(self.db, self.node_id, key, self.parent_keys)
+        return result
 
     @key_is_string
     @logger_debug
@@ -372,11 +377,45 @@ def clear(self) -> None:
         self.FETCHED_ALL_DATA = False
         self.FETCHED_ALL_IDS = False
 
+    @keys_are_strings
+    @logger_debug
+    def __update_local_nodes(self, nodes: Any) -> None:
+        """Rebuild the cache entry in ``self.data`` for every node in ``nodes``."""
+        for node_id, node_data in nodes.items():
+            node_attr_dict = self.node_attr_dict_factory()
+            node_attr_dict.node_id = node_id
+            node_attr_dict.data = build_node_attr_dict_data(node_attr_dict, node_data)
+
+            self.data[node_id] = node_attr_dict
+
     @keys_are_strings
     @logger_debug
     def update(self, nodes: Any) -> None:
         """g._node.update({'node/1': {'foo': 'bar'}, 'node/2': {'baz': 'qux'}})"""
-        raise NotImplementedError("NodeDict.update()")
+        separated_by_collection = separate_nodes_by_collections(
+            nodes, self.default_node_type
+        )
+
+        result = upsert_collection_documents(self.db, separated_by_collection)
+
+        all_good = check_list_for_errors(result)
+        if all_good:
+            # No single operation failed, so the local cache can be updated.
+            self.__update_local_nodes(nodes)
+        else:
+            # Some or all documents failed. Right now we will not update the
+            # local cache in that case, but raise an error instead.
+            # Reason: we cannot set silent to True, because the driver does not
+            # report errors then. We need to update the driver to also pass
+            # the errors back to the user, then we can adjust the behavior here.
+            # This will also save network traffic and local computation time.
+            errors = []
+            for collections_results in result:
+                for collection_result in collections_results:
+                    errors.append(collection_result)
+            m = "Failed to insert at least one node. Will not update local cache."
+            logger.warning(m)
+            raise ArangoDBBatchError(errors)
 
     @logger_debug
     def values(self) -> Any:
diff --git a/nx_arangodb/classes/function.py b/nx_arangodb/classes/function.py
index ef87dc6d..f0985b05 100644
--- a/nx_arangodb/classes/function.py
+++ b/nx_arangodb/classes/function.py
@@ -5,12 +5,10 @@
 
 from __future__ import annotations
 
-from collections import UserDict
 from typing import Any, Callable, Tuple
 
 import networkx as nx
-import numpy as np
-import numpy.typing as npt
+from arango import ArangoError, DocumentInsertError
 from arango.collection import StandardCollection
 from arango.cursor import Cursor
 from arango.database import StandardDatabase
@@ -29,14 +27,9 @@
     SrcIndices,
 )
 
-import nx_arangodb as nxadb
 from nx_arangodb.logger import logger
 
-from ..exceptions import (
-    AQLMultipleResultsFound,
-    GraphDoesNotExist,
-    InvalidTraversalDirection,
-)
+from ..exceptions import AQLMultipleResultsFound, InvalidTraversalDirection
 
 
 def do_load_all_edge_attributes(attributes: set[str]) -> bool:
@@ -68,7 +61,7 @@ def get_arangodb_graph(
 ]:
     """Pulls the graph from the database, assuming the graph exists.
 
-    Returns the folowing representations:
+    Returns the following representations:
     - Node dictionary (nx.Graph)
     - Adjacency dictionary (nx.Graph)
     - Source Indices (COO)
@@ -647,3 +640,120 @@ def get_update_dict(
             update_dict = {key: update_dict}
 
     return update_dict
+
+
+class ArangoDBBatchError(ArangoError):
+    """Error carrying a list of errors; its message shows one error per line."""
+
+    def __init__(self, errors):
+        self.errors = errors
+        super().__init__(self._format_errors())
+
+    def _format_errors(self):
+        return "\n".join(str(error) for error in self.errors)
+
+
+def check_list_for_errors(lst):
+    """Return False if any entry of ``lst`` signals a failure, else True.
+
+    An entry signals a failure when it is the boolean ``False`` or when it is
+    a list that contains at least one ``DocumentInsertError``.
+    """
+    for element in lst:
+        if isinstance(element, bool):
+            if element is False:
+                return False
+
+        elif isinstance(element, list):
+            for sub_element in element:
+                if isinstance(sub_element, DocumentInsertError):
+                    return False
+
+    return True
+
+
+def is_arangodb_id(key):
+    """Return True if ``key`` contains a '/', i.e. has the 'collection/key' form."""
+    return "/" in key
+
+
+def get_arangodb_collection_key_tuple(key):
+    """Split ``key`` at its first '/' into the collection name and document key.
+
+    :return: The two-element list ``[collection, document_key]``.
+    :raises ValueError: If ``key`` does not contain a '/'.
+    """
+    if not is_arangodb_id(key):
+        raise ValueError(f"Invalid ArangoDB key: {key}")
+    return key.split("/", 1)
+
+
+def separate_nodes_by_collections(nodes: Any, default_collection: str) -> Any:
+    """
+    Separate the dictionary into collections based on whether keys contain '/'.
+    :param nodes:
+        The input dictionary with keys that may or may not contain '/'.
+    :param default_collection:
+        The name of the default collection for keys without '/'.
+    :return: A dictionary where the keys are collection names and the
+        values are dictionaries of key-value pairs belonging to those
+        collections.
+    """
+    separated: Any = {}  # collection name -> {document key: attributes}
+
+    for key, value in nodes.items():
+        if is_arangodb_id(key):  # key has the form "collection/key"
+            collection, doc_key = get_arangodb_collection_key_tuple(key)
+            if collection not in separated:
+                separated[collection] = {}
+            separated[collection][doc_key] = value
+        else:  # plain key: store it under the default collection
+            if default_collection not in separated:
+                separated[default_collection] = {}
+            separated[default_collection][key] = value
+
+    return separated
+
+
+def transform_local_documents_for_adb(original_documents):
+    """
+    Transform original documents into a format suitable for UPSERT
+    operations in ArangoDB.
+    :param original_documents: Original documents in the format
+        {'key': {'any-attr-key': 'any-attr-value'}}.
+    :return: List of documents with '_key' attribute and additional attributes.
+    """
+    transformed_documents = []
+
+    for key, values in original_documents.items():
+        transformed_doc = {"_key": key}
+        transformed_doc.update(values)  # a "_key" entry in values replaces key
+        transformed_documents.append(transformed_doc)
+
+    return transformed_documents
+
+
+def upsert_collection_documents(db: StandardDatabase, separated: Any) -> Any:
+    """
+    Process each collection in the separated dictionary.
+    :param db: The ArangoDB database object.
+    :param separated: A dictionary where the keys are collection names and the
+        values are dictionaries of key-value pairs (document key and its
+        attributes) belonging to those collections.
+    :return: A list of results from the insert_many operation.
+        If inserting a document fails, the exception is not raised but
+        returned as an object in the result list.
+    """
+
+    results = []
+
+    for collection_name, documents in separated.items():
+        collection = db.collection(collection_name)
+        transformed_documents = transform_local_documents_for_adb(documents)
+        results.append(
+            collection.insert_many(
+                transformed_documents, silent=False, overwrite_mode="update"
+            )
+        )
+
+    return results
diff --git a/nx_arangodb/classes/reportviews.py b/nx_arangodb/classes/reportviews.py
index 03763255..3cd3a2f4 100644
--- a/nx_arangodb/classes/reportviews.py
+++ b/nx_arangodb/classes/reportviews.py
@@ -21,6 +21,10 @@ def data(self, data=True, default=None):
             return self
         return CustomNodeDataView(self._nodes, data, default)
 
+    def update(self, data):
+        """Forward ``data`` to ``self._nodes.update`` and return its result."""
+        return self._nodes.update(data)
+
 
 class CustomNodeDataView(nx.classes.reportviews.NodeDataView):
     def __iter__(self):
diff --git a/tests/conftest.py b/tests/conftest.py
index d2353c7a..da4f4624 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -49,7 +49,7 @@ def pytest_configure(config: Any) -> None:
 
 
 @pytest.fixture(scope="function")
-def load_graph() -> None:
+def load_karate_graph() -> None:
     global db
     db.delete_graph("KarateGraph", drop_collections=True, ignore_missing=True)
     adapter = ADBNX_Adapter(db)
@@ -64,3 +64,25 @@ def load_graph() -> None:
             }
         ],
     )
+
+
+@pytest.fixture(scope="function")
+def load_two_relation_graph() -> None:
+    """Recreate the graph with edge definitions e1 (v1 -> v2) and e2 (v2 -> v1)."""
+    global db
+    graph_name = "IntegrationTestTwoRelationGraph"
+    v1 = graph_name + "_v1"
+    v2 = graph_name + "_v2"
+    e1 = graph_name + "_e1"
+    e2 = graph_name + "_e2"
+
+    if db.has_graph(graph_name):
+        db.delete_graph(graph_name, drop_collections=True)
+
+    g = db.create_graph(graph_name)
+    g.create_edge_definition(
+        e1, from_vertex_collections=[v1], to_vertex_collections=[v2]
+    )
+    g.create_edge_definition(
+        e2, from_vertex_collections=[v2], to_vertex_collections=[v1]
+    )
diff --git a/tests/test.py b/tests/test.py
index 8a2f2ac0..f69c742d 100644
--- a/tests/test.py
+++ b/tests/test.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable
+from typing import Any, Callable, Dict
 
 import networkx as nx
 import pytest
@@ -74,7 +74,7 @@ def assert_k_components(
     assert d1 == d2
 
 
-def test_db(load_graph: Any) -> None:
+def test_db(load_karate_graph: Any) -> None:
     assert db.version()
 
 
@@ -172,7 +172,7 @@ def test_load_graph_with_non_default_weight_attribute():
 def test_algorithm(
     algorithm_func: Callable[..., Any],
     assert_func: Callable[..., Any],
-    load_graph: Any,
+    load_karate_graph: Any,
 ) -> None:
     G_1 = G_NX
     G_2 = nxadb.Graph(incoming_graph_data=G_1)
@@ -240,7 +240,7 @@ def test_algorithm(
     assert_func(r_9_orig, r_13_orig)
 
 
-def test_shortest_path_remote_algorithm(load_graph: Any) -> None:
+def test_shortest_path_remote_algorithm(load_karate_graph: Any) -> None:
     G_1 = nxadb.Graph(graph_name="KarateGraph")
     G_2 = nxadb.DiGraph(graph_name="KarateGraph")
 
@@ -264,7 +264,104 @@ def test_shortest_path_remote_algorithm(load_graph: Any) -> None:
         (nxadb.MultiDiGraph),
     ],
 )
-def test_nodes_crud(load_graph: Any, graph_cls: type[nxadb.Graph]) -> None:
+def test_node_dict_update_existing_single_collection(
+    load_karate_graph: Any, graph_cls: type[nxadb.Graph]
+) -> None:
+    # This test uses the existing nodes and updates each
+    # of them via the update method, using a single collection
+    G_1 = graph_cls(graph_name="KarateGraph", foo="bar")
+
+    def extract_arangodb_key(adb_id: str) -> str:
+        return adb_id.split("/")[1]
+
+    nodes_ids_list = G_1.nodes
+    local_nodes_dict = {}
+
+    for node_id in nodes_ids_list:
+        local_nodes_dict[node_id] = {"extraValue": extract_arangodb_key(node_id)}
+
+    G_1._node.update(local_nodes_dict)
+
+    col = db.collection("person")
+    col_docs = col.all()
+
+    # Check if the extraValue attribute was added to each document in the database
+    for doc in col_docs:
+        assert "extraValue" in doc
+        assert doc["extraValue"] == doc["_key"]
+
+    # Check if the extraValue attribute was added to each document in the local cache
+    for node_id in nodes_ids_list:
+        assert "extraValue" in G_1._node.data[node_id]
+        assert G_1.nodes[node_id]["extraValue"] == extract_arangodb_key(node_id)
+
+
+@pytest.mark.parametrize(
+    "graph_cls",
+    [
+        (nxadb.Graph),
+        (nxadb.DiGraph),
+        (nxadb.MultiGraph),
+        (nxadb.MultiDiGraph),
+    ],
+)
+def test_node_dict_update_multiple_collections(
+    load_two_relation_graph: Any, graph_cls: type[nxadb.Graph]
+) -> None:
+    # This test inserts new nodes into an empty graph via the
+    # update method, spreading them over two collections
+    graph_name = "IntegrationTestTwoRelationGraph"
+    v_1_name = graph_name + "_v1"
+    v_2_name = graph_name + "_v2"
+    e_1_name = graph_name + "_e1"
+    e_2_name = graph_name + "_e2"
+
+    # assert that those collections are empty
+    assert db.collection(v_1_name).count() == 0
+    assert db.collection(v_2_name).count() == 0
+    assert db.collection(e_1_name).count() == 0
+    assert db.collection(e_2_name).count() == 0
+
+    G_1 = graph_cls(graph_name=graph_name, default_node_type=v_1_name)
+    assert len(G_1.nodes) == 0
+    assert len(G_1.edges) == 0
+
+    # inserts into first collection (by default)
+    new_nodes_v1: Dict[str, Dict[str, Any]] = {"1": {}, "2": {}, "3": {}}
+    # needs to be inserted into second collection
+    new_nodes_v2: Dict[str, Dict[str, Any]] = {
+        f"{v_2_name}/4": {},
+        f"{v_2_name}/5": {},
+        f"{v_2_name}/6": {},
+    }
+
+    G_1._node.update(new_nodes_v1)
+    G_1._node.update(new_nodes_v2)
+
+    assert db.collection(v_1_name).count() == 3
+    assert db.collection(v_2_name).count() == 3
+
+    # check that local nodes in cache must have 6 elements
+    assert len(G_1.nodes) == 6
+    # check that keys are present
+    # loop three times
+    for i in range(1, 4):
+        assert f"{v_1_name}/{i}" in G_1.nodes
+
+    for i in range(4, 7):
+        assert f"{v_2_name}/{i}" in G_1.nodes
+
+
+@pytest.mark.parametrize(
+    "graph_cls",
+    [
+        (nxadb.Graph),
+        (nxadb.DiGraph),
+        (nxadb.MultiGraph),
+        (nxadb.MultiDiGraph),
+    ],
+)
+def test_nodes_crud(load_karate_graph: Any, graph_cls: type[nxadb.Graph]) -> None:
     G_1 = graph_cls(graph_name="KarateGraph", foo="bar")
     G_2 =
nx.Graph(G_NX) @@ -388,7 +485,7 @@ def test_nodes_crud(load_graph: Any, graph_cls: type[nxadb.Graph]) -> None: assert db.document("person/2")["object"]["sub_object"]["foo"] == "baz" -def test_graph_edges_crud(load_graph: Any) -> None: +def test_graph_edges_crud(load_karate_graph: Any) -> None: G_1 = nxadb.Graph(graph_name="KarateGraph") G_2 = G_NX @@ -535,7 +632,7 @@ def test_graph_edges_crud(load_graph: Any) -> None: assert db.document(edge_id)["object"]["sub_object"]["foo"] == "baz" -def test_digraph_edges_crud(load_graph: Any) -> None: +def test_digraph_edges_crud(load_karate_graph: Any) -> None: G_1 = nxadb.DiGraph(graph_name="KarateGraph") G_2 = G_NX @@ -684,7 +781,7 @@ def test_digraph_edges_crud(load_graph: Any) -> None: assert db.document(edge_id)["object"]["sub_object"]["foo"] == "baz" -def test_multigraph_edges_crud(load_graph: Any) -> None: +def test_multigraph_edges_crud(load_karate_graph: Any) -> None: G_1 = nxadb.MultiGraph(graph_name="KarateGraph") G_2 = G_NX @@ -846,7 +943,7 @@ def test_multigraph_edges_crud(load_graph: Any) -> None: assert db.document(edge_id)["object"]["sub_object"]["foo"] == "baz" -def test_multidigraph_edges_crud(load_graph: Any) -> None: +def test_multidigraph_edges_crud(load_karate_graph: Any) -> None: G_1 = nxadb.MultiDiGraph(graph_name="KarateGraph") G_2 = G_NX @@ -1016,7 +1113,7 @@ def test_multidigraph_edges_crud(load_graph: Any) -> None: assert db.document(edge_id)["object"]["sub_object"]["foo"] == "baz" -def test_graph_dict_init(load_graph: Any) -> None: +def test_graph_dict_init(load_karate_graph: Any) -> None: G = nxadb.Graph(graph_name="KarateGraph", default_node_type="person") assert db.collection("_graphs").has("KarateGraph") graph_document = db.collection("_graphs").get("KarateGraph") @@ -1030,7 +1127,7 @@ def test_graph_dict_init(load_graph: Any) -> None: assert db.has_document(graph_doc_id) -def test_graph_dict_init_extended(load_graph: Any) -> None: +def test_graph_dict_init_extended(load_karate_graph: Any) -> 
None: # Tests that available data (especially dicts) will be properly # stored as GraphDicts in the internal cache. G = nxadb.Graph(graph_name="KarateGraph", foo="bar", bar={"baz": True}) @@ -1041,7 +1138,7 @@ def test_graph_dict_init_extended(load_graph: Any) -> None: assert "baz" not in db.document(G.graph.graph_id) -def test_graph_dict_clear_will_not_remove_remote_data(load_graph: Any) -> None: +def test_graph_dict_clear_will_not_remove_remote_data(load_karate_graph: Any) -> None: G_adb = nxadb.Graph( graph_name="KarateGraph", foo="bar", @@ -1059,7 +1156,7 @@ def test_graph_dict_clear_will_not_remove_remote_data(load_graph: Any) -> None: assert db.document(G_adb.graph.graph_id)["ant"] == {"b": 6} -def test_graph_dict_set_item(load_graph: Any) -> None: +def test_graph_dict_set_item(load_karate_graph: Any) -> None: G = nxadb.Graph(graph_name="KarateGraph", default_node_type="person") try: db.collection(G.graph.COLLECTION_NAME).delete(G.name) @@ -1091,7 +1188,7 @@ def test_graph_dict_set_item(load_graph: Any) -> None: assert db.document(G.graph.graph_id)["json"] == value -def test_graph_dict_update(load_graph: Any) -> None: +def test_graph_dict_update(load_karate_graph: Any) -> None: G = nxadb.Graph(graph_name="KarateGraph", default_node_type="person") G.clear() @@ -1109,7 +1206,7 @@ def test_graph_dict_update(load_graph: Any) -> None: assert adb_doc["c"] == "d" -def test_graph_attr_dict_nested_update(load_graph: Any) -> None: +def test_graph_attr_dict_nested_update(load_karate_graph: Any) -> None: G = nxadb.Graph(graph_name="KarateGraph", default_node_type="person") G.clear() @@ -1121,7 +1218,7 @@ def test_graph_attr_dict_nested_update(load_graph: Any) -> None: assert db.document(G.graph.graph_id)["a"]["d"] == "e" -def test_graph_dict_nested_1(load_graph: Any) -> None: +def test_graph_dict_nested_1(load_karate_graph: Any) -> None: G = nxadb.Graph(graph_name="KarateGraph", default_node_type="person") G.clear() icon = {"football_icon": "MJ7"} @@ -1131,7 +1228,7 @@ 
def test_graph_dict_nested_1(load_graph: Any) -> None: assert db.document(G.graph.graph_id)["a"]["b"] == icon -def test_graph_dict_nested_2(load_graph: Any) -> None: +def test_graph_dict_nested_2(load_karate_graph: Any) -> None: G = nxadb.Graph(graph_name="KarateGraph", default_node_type="person") G.clear() icon = {"football_icon": "MJ7"} @@ -1143,7 +1240,7 @@ def test_graph_dict_nested_2(load_graph: Any) -> None: assert db.document(G.graph.graph_id)["x"]["y"]["amount_of_goals"] == 1337 -def test_graph_dict_empty_values(load_graph: Any) -> None: +def test_graph_dict_empty_values(load_karate_graph: Any) -> None: G = nxadb.Graph(graph_name="KarateGraph", default_node_type="person") G.clear() @@ -1156,7 +1253,7 @@ def test_graph_dict_empty_values(load_graph: Any) -> None: assert "none" not in G.graph -def test_graph_dict_nested_overwrite(load_graph: Any) -> None: +def test_graph_dict_nested_overwrite(load_karate_graph: Any) -> None: G = nxadb.Graph(graph_name="KarateGraph", default_node_type="person") G.clear() icon1 = {"football_icon": "MJ7"} @@ -1173,7 +1270,7 @@ def test_graph_dict_nested_overwrite(load_graph: Any) -> None: assert db.document(G.graph.graph_id)["a"]["b"]["basketball_icon"] == "MJ23" -def test_graph_dict_complex_nested(load_graph: Any) -> None: +def test_graph_dict_complex_nested(load_karate_graph: Any) -> None: G = nxadb.Graph(graph_name="KarateGraph", default_node_type="person") G.clear() @@ -1187,7 +1284,7 @@ def test_graph_dict_complex_nested(load_graph: Any) -> None: ) -def test_graph_dict_nested_deletion(load_graph: Any) -> None: +def test_graph_dict_nested_deletion(load_karate_graph: Any) -> None: G = nxadb.Graph(graph_name="KarateGraph", default_node_type="person") G.clear() icon = {"football_icon": "MJ7", "amount_of_goals": 1337} @@ -1203,7 +1300,7 @@ def test_graph_dict_nested_deletion(load_graph: Any) -> None: assert "x" not in db.document(G.graph.graph_id) -def test_readme(load_graph: Any) -> None: +def test_readme(load_karate_graph: Any) 
-> None: G = nxadb.Graph(graph_name="KarateGraph", default_node_type="person") assert len(G.nodes) == len(G_NX.nodes)