diff --git a/nx_arangodb/classes/dict.py b/nx_arangodb/classes/dict.py index d15bfcc5..cc79ec29 100644 --- a/nx_arangodb/classes/dict.py +++ b/nx_arangodb/classes/dict.py @@ -5,22 +5,30 @@ from __future__ import annotations -from collections import UserDict, defaultdict +import warnings +from collections import UserDict from collections.abc import Iterator -from typing import Any, Callable, Generator +from typing import Any, Callable, Dict, List from arango.database import StandardDatabase -from arango.exceptions import DocumentInsertError from arango.graph import Graph from nx_arangodb.logger import logger +from ..typing import AdjDict +from ..utils.arangodb import ( + ArangoDBBatchError, + check_list_for_errors, + is_arangodb_id, + read_collection_name_from_local_id, + separate_edges_by_collections, + separate_nodes_by_collections, + upsert_collection_documents, + upsert_collection_edges, +) from .function import ( aql, - aql_as_list, aql_doc_get_key, - aql_doc_get_keys, - aql_doc_get_length, aql_doc_has_key, aql_edge_exists, aql_edge_get, @@ -496,11 +504,45 @@ def clear(self) -> None: # for collection in self.graph.vertex_collections(): # self.graph.vertex_collection(collection).truncate() + @keys_are_strings + @logger_debug + def update_local_nodes(self, nodes: Any) -> None: + for node_id, node_data in nodes.items(): + node_attr_dict = self.node_attr_dict_factory() + node_attr_dict.node_id = node_id + node_attr_dict.data = build_node_attr_dict_data(node_attr_dict, node_data) + + self.data[node_id] = node_attr_dict + @keys_are_strings @logger_debug def update(self, nodes: Any) -> None: """g._node.update({'node/1': {'foo': 'bar'}, 'node/2': {'baz': 'qux'}})""" - raise NotImplementedError("NodeDict.update()") + separated_by_collection = separate_nodes_by_collections( + nodes, self.default_node_type + ) + + result = upsert_collection_documents(self.db, separated_by_collection) + + all_good = check_list_for_errors(result) + if all_good: + # Means no single operation failed, in this case we update the local cache + self.update_local_nodes(nodes) + else: + # In this case some or all documents failed. Right now we will not + # update the local cache, but raise an error instead. + # Reason: We cannot set silent to True, because we need as it does + # not report errors then. We need to update the driver to also pass + # the errors back to the user, then we can adjust the behavior here. + # This will also save network traffic and local computation time. + errors = [] + for collections_results in result: + for collection_result in collections_results: + errors.append(collection_result) + warnings.warn( + "Failed to insert at least one node. Will not update local cache." + ) + raise ArangoDBBatchError(errors) # TODO: Revisit typing of return value @logger_debug @@ -614,7 +656,7 @@ def __init__( self.graph = graph self.edge_id: str | None = None - # NodeAttrDict may be a child of another NodeAttrDict + # EdgeAttrDict may be a child of another EdgeAttrDict # e.g G._adj['node/1']['node/2']['object']['foo'] = 'bar' # In this case, **parent_keys** would be ['object'] # and **root** would be G._adj['node/1']['node/2'] @@ -933,7 +975,62 @@ def clear(self) -> None: @logger_debug def update(self, edges: Any) -> None: """g._adj['node/1'].update({'node/2': {'foo': 'bar'}})""" - raise NotImplementedError("AdjListInnerDict.update()") + from_col_name = read_collection_name_from_local_id( + self.src_node_id, self.default_node_type + ) + + to_upsert: Dict[str, List[Dict[str, Any]]] = {from_col_name: []} + + for edge_id, edge_data in edges.items(): + edge_doc = edge_data + edge_doc["_from"] = self.src_node_id + edge_doc["_to"] = edge_id + + edge_doc_id = edge_data.get("_id") + # TODO: @Anthony please check if the implementation is correct + # of the default for edge_type_func, which is right now: + # * edge_type_func: Callable[[str, str], str] = lambda u, v: f"{u}_to_{v}", + # + # How does that help to identify the edge's collection name? + # The below implementation I wanted to use but returns in my example: + # "person/9_to_person/34" which is not a valid or requested collection name. + # + # edge_type = edge_data.get("_edge_type") + # if edge_type is None: + # edge_type = self.edge_type_func(self.src_node_id, edge_id) + # + # -> Therefore right now I need to assume that this is always a + # valid ArangoDB document ID + assert is_arangodb_id(edge_doc_id) + edge_col_name = read_collection_name_from_local_id(edge_doc_id, "") + + if to_upsert.get(edge_col_name) is None: + to_upsert[edge_col_name] = [edge_doc] + else: + to_upsert[edge_col_name].append(edge_doc) + + # perform write to ArangoDB + result = upsert_collection_edges(self.db, to_upsert) + + all_good = check_list_for_errors(result) + if all_good: + # Means no single operation failed, in this case we update the local cache + self.__set_adj_elements(edges) + else: + # In this case some or all documents failed. Right now we will not + # update the local cache, but raise an error instead. + # Reason: We cannot set silent to True, because we need as it does + # not report errors then. We need to update the driver to also pass + # the errors back to the user, then we can adjust the behavior here. + # This will also save network traffic and local computation time. + errors = [] + for collections_results in result: + for collection_result in collections_results: + errors.append(collection_result) + warnings.warn( + "Failed to insert at least one node. Will not update local cache." + ) + raise ArangoDBBatchError(errors) # TODO: Revisit typing of return value @logger_debug @@ -974,6 +1071,14 @@ def __fetch_all(self) -> None: self.FETCHED_ALL_DATA = True + def __set_adj_elements(self, edges): + for dst_node_id, edge in edges.items(): + # Copied from above, from __fetch_all + edge_attr_dict = self.edge_attr_dict_factory() + edge_attr_dict.edge_id = edge["_id"] + edge_attr_dict.data = build_edge_attr_dict_data(edge_attr_dict, edge) + self.data[edge["_to"]] = edge_attr_dict + class AdjListOuterDict(UserDict[str, AdjListInnerDict]): """The outer-level of the dict of dict of dict structure @@ -1138,8 +1243,29 @@ def clear(self) -> None: @keys_are_strings @logger_debug def update(self, edges: Any) -> None: - """g._adj.update({'node/1': {'node/2': {'foo': 'bar'}})""" - raise NotImplementedError("AdjListOuterDict.update()") + """g._adj.update({'node/1': {'node/2': {'_id': 'foo/bar', 'foo': "bar"}})""" + separated_by_edge_collection = separate_edges_by_collections(edges) + result = upsert_collection_edges(self.db, separated_by_edge_collection) + + all_good = check_list_for_errors(result) + if all_good: + # Means no single operation failed, in this case we update the local cache + self.__set_adj_elements(edges) + else: + # In this case some or all documents failed. Right now we will not + # update the local cache, but raise an error instead. + # Reason: We cannot set silent to True, because we need as it does + # not report errors then. We need to update the driver to also pass + # the errors back to the user, then we can adjust the behavior here. + # This will also save network traffic and local computation time. + errors = [] + for collections_results in result: + for collection_result in collections_results: + errors.append(collection_result) + warnings.warn( + "Failed to insert at least one node. Will not update local cache." + ) + raise ArangoDBBatchError(errors) # TODO: Revisit typing of return value @logger_debug @@ -1171,25 +1297,15 @@ def items(self, data: str | None = None, default: Any | None = None) -> Any: yield from result @logger_debug - def __fetch_all(self) -> None: - self.clear() - - _, adj_dict, _, _, _ = get_arangodb_graph( - self.graph, - load_node_dict=False, - load_adj_dict=True, - is_directed=False, # TODO: Abstract based on Graph type - is_multigraph=False, # TODO: Abstract based on Graph type - load_coo=False, - ) - - for src_node_id, inner_dict in adj_dict.items(): + def __set_adj_elements(self, edges: AdjDict) -> None: + for src_node_id, inner_dict in edges.items(): for dst_node_id, edge in inner_dict.items(): if src_node_id in self.data: if dst_node_id in self.data[src_node_id].data: continue + # TODO: Clean up those two if/else statements later if src_node_id in self.data: src_inner_dict = self.data[src_node_id] else: @@ -1209,8 +1325,21 @@ def __fetch_all(self) -> None: edge_attr_dict = src_inner_dict.edge_attr_dict_factory() edge_attr_dict.edge_id = edge["_id"] edge_attr_dict.data = build_edge_attr_dict_data(edge_attr_dict, edge) - self.data[src_node_id].data[dst_node_id] = edge_attr_dict self.data[dst_node_id].data[src_node_id] = edge_attr_dict + @logger_debug + def __fetch_all(self) -> None: + self.clear() + + _, adj_dict, _, _, _ = get_arangodb_graph( + self.graph, + load_node_dict=False, + load_adj_dict=True, + is_directed=False, # TODO: Abstract based on Graph type + is_multigraph=False, # TODO: Abstract based on Graph type + load_coo=False, + ) + + self.__set_adj_elements(adj_dict) self.FETCHED_ALL_DATA = True diff --git a/nx_arangodb/classes/function.py b/nx_arangodb/classes/function.py index 44a2ef98..a7c3ce91 100644 --- a/nx_arangodb/classes/function.py +++ b/nx_arangodb/classes/function.py @@ -24,6 +24,7 @@ GraphDoesNotExist, InvalidTraversalDirection, ) +from ..typing import AdjDict def get_arangodb_graph( @@ -35,7 +36,7 @@ def get_arangodb_graph( load_coo: bool, ) -> Tuple[ dict[str, dict[str, Any]], - dict[str, dict[str, dict[str, Any]]], + AdjDict, npt.NDArray[np.int64], npt.NDArray[np.int64], dict[str, int], @@ -152,7 +153,7 @@ def wrapper( return wrapper -RESERVED_KEYS = {"_id", "_key", "_rev"} +RESERVED_KEYS = {"_id", "_key", "_rev", "_from", "_to"} def key_is_not_reserved(func: Callable[..., Any]) -> Any: diff --git a/nx_arangodb/typing.py b/nx_arangodb/typing.py index cdd3cefa..4f030c8b 100644 --- a/nx_arangodb/typing.py +++ b/nx_arangodb/typing.py @@ -3,12 +3,19 @@ from __future__ import annotations from collections.abc import Hashable -from typing import TypeVar +from typing import Any, Dict, TypeVar -import cupy as cp -import numpy as np import numpy.typing as npt +from nx_arangodb.logger import logger + +try: + import cupy as cp +except ModuleNotFoundError as e: + GPU_ENABLED = False + logger.info(f"NXCG is disabled. {e}.") + + AttrKey = TypeVar("AttrKey", bound=Hashable) EdgeKey = TypeVar("EdgeKey", bound=Hashable) NodeKey = TypeVar("NodeKey", bound=Hashable) @@ -18,6 +25,32 @@ IndexValue = TypeVar("IndexValue") Dtype = TypeVar("Dtype") +# AdjDict is a dictionary of dictionaries of dictionaries +# The outer dict is holding _from_id(s) as keys +# - It may or may not hold valid ArangoDB document _id(s) +# The inner dict is holding _to_id(s) as keys +# - It may or may not hold valid ArangoDB document _id(s) +# The next inner dict contains then the actual edges data (key, val) +# Example +# { +# 'person/1': { +# 'person/32': { +# '_id': 'knows/16', +# 'extraValue': '16' +# }, +# 'person/33': { +# '_id': 'knows/17', +# 'extraValue': '17' +# } +# ... +# } +# ... +# } +# The above example is a graph with 2 edges from person/1 to person/32 and person/33 +AdjDictEdge = Dict[str, Any] +AdjDictInner = Dict[str, AdjDictEdge] +AdjDict = Dict[str, AdjDictInner] + class any_ndarray: def __class_getitem__(cls, item): diff --git a/nx_arangodb/utils/arangodb.py b/nx_arangodb/utils/arangodb.py new file mode 100644 index 00000000..1a2ad116 --- /dev/null +++ b/nx_arangodb/utils/arangodb.py @@ -0,0 +1,197 @@ +from typing import Any, Optional + +from arango import ArangoError, DocumentInsertError +from arango.database import StandardDatabase + + +class ArangoDBBatchError(ArangoError): + def __init__(self, errors): + self.errors = errors + super().__init__(self._format_errors()) + + def _format_errors(self): + return "\n".join(str(error) for error in self.errors) + + +def check_list_for_errors(lst): + for element in lst: + if element is type(bool): + if element is False: + return False + + elif isinstance(element, list): + for sub_element in element: + if isinstance(sub_element, DocumentInsertError): + return False + + return True + + +def extract_arangodb_key(arangodb_id): + assert "/" in arangodb_id + return arangodb_id.split("/")[1] + + +def extract_arangodb_collection_name(arangodb_id: str) -> str: + assert is_arangodb_id(arangodb_id) + return arangodb_id.split("/")[0] + + +def is_arangodb_id(key): + return "/" in key + + +def read_collection_name_from_local_id( + local_id: Optional[str], default_collection: str +) -> str: + if local_id is None: + print("local_id is None, cannot read collection name.") + return "" + + if is_arangodb_id(local_id): + return extract_arangodb_collection_name(local_id) + + assert default_collection is not None + assert default_collection != "" + return default_collection + + +def get_arangodb_collection_key_tuple(key): + assert is_arangodb_id(key) + if is_arangodb_id(key): + return key.split("/", 1) + + +def separate_nodes_by_collections(nodes: Any, default_collection: str) -> Any: + """ + Separate the dictionary into collections based on whether keys contain '/'. + + :param nodes: + The input dictionary with keys that may or may not contain '/'. + :param default_collection: + The name of the default collection for keys without '/'. + :return: A dictionary where the keys are collection names and the + values are dictionaries of key-value pairs belonging to those + collections. + """ + separated: Any = {} + + for key, value in nodes.items(): + if is_arangodb_id(key): + collection, doc_key = get_arangodb_collection_key_tuple(key) + if collection not in separated: + separated[collection] = {} + separated[collection][doc_key] = value + else: + if default_collection not in separated: + separated[default_collection] = {} + separated[default_collection][key] = value + + return separated + + +def separate_edges_by_collections(edges: Any) -> Any: + """ + Separate the dictionary into collections based on whether keys contain '/'. + + :param edges: + The input dictionary with keys that must contain the real doc id. + :return: A dictionary where the keys are collection names and the + values are dictionaries of key-value pairs belonging to those + collections. + """ + separated: Any = {} + + for from_doc_id, target_dict in edges.items(): + for to_doc_id, edge_doc in target_dict.items(): + assert edge_doc is not None + assert "_id" in edge_doc + edge_collection_name = extract_arangodb_collection_name(edge_doc["_id"]) + + if separated.get(edge_collection_name) is None: + separated[edge_collection_name] = [] + + edge_doc["_from"] = from_doc_id + edge_doc["_to"] = to_doc_id + + separated[edge_collection_name].append(edge_doc) + + return separated + + +def transform_local_documents_for_adb(original_documents): + """ + Transform original documents into a format suitable for UPSERT + operations in ArangoDB. + + :param original_documents: Original documents in the format + {'key': {'any-attr-key': 'any-attr-value'}}. + :return: List of documents with '_key' attribute and additional attributes. + """ + transformed_documents = [] + + for key, values in original_documents.items(): + transformed_doc = {"_key": key} + transformed_doc.update(values) + transformed_documents.append(transformed_doc) + + return transformed_documents + + +def upsert_collection_documents(db: StandardDatabase, separated: Any) -> Any: + """ + Process each collection in the separated dictionary. + + :param db: The ArangoDB database object. + :param separated: A dictionary where the keys are collection names and the + values are dictionaries + of key-value pairs belonging to those collections. + :return: A list of results from the insert_many operation. + + If inserting a document fails, the exception is not raised but + returned as an object in the result list. + """ + + results = [] + + for collection_name, documents in separated.items(): + collection = db.collection(collection_name) + transformed_documents = transform_local_documents_for_adb(documents) + results.append( + collection.insert_many( + transformed_documents, silent=False, overwrite_mode="update" + ) + ) + + return results + + +def upsert_collection_edges(db: StandardDatabase, separated: Any) -> Any: + """ + Process each collection in the separated dictionary. + + :param db: The ArangoDB database object. + :param separated: A dictionary where the keys are collection names and the + values are dictionaries + of key-value pairs belonging to those collections. + :return: A list of results from the insert_many operation. + + If inserting a document fails, the exception is not raised but + returned as an object in the result list. + """ + + results = [] + + for collection_name, documents_list in separated.items(): + collection = db.collection(collection_name) + + if documents_list: + results.append( + collection.insert_many( + documents_list, + silent=False, + overwrite_mode="update", + ) + ) + + return results diff --git a/tests/conftest.py b/tests/conftest.py index d2353c7a..da4f4624 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -49,7 +49,7 @@ def pytest_configure(config: Any) -> None: @pytest.fixture(scope="function") -def load_graph() -> None: +def load_karate_graph() -> None: global db db.delete_graph("KarateGraph", drop_collections=True, ignore_missing=True) adapter = ADBNX_Adapter(db) @@ -64,3 +64,24 @@ def load_graph() -> None: } ], ) + + +@pytest.fixture(scope="function") +def load_two_relation_graph() -> None: + global db + graph_name = "IntegrationTestTwoRelationGraph" + v1 = graph_name + "_v1" + v2 = graph_name + "_v2" + e1 = graph_name + "_e1" + e2 = graph_name + "_e2" + + if db.has_graph(graph_name): + db.delete_graph(graph_name, drop_collections=True) + + g = db.create_graph(graph_name) + g.create_edge_definition( + e1, from_vertex_collections=[v1], to_vertex_collections=[v2] + ) + g.create_edge_definition( + e2, from_vertex_collections=[v2], to_vertex_collections=[v1] + ) diff --git a/tests/test.py b/tests/test.py index 0c9da34c..007a3bc9 100644 --- a/tests/test.py +++ b/tests/test.py @@ -1,18 +1,19 @@ -from typing import Any +from typing import Any, Dict import networkx as nx -import pandas as pd import pytest import nx_arangodb as nxadb from nx_arangodb.classes.dict import EdgeAttrDict, NodeAttrDict +from nx_arangodb.typing import AdjDict, AdjDictEdge +from nx_arangodb.utils.arangodb import extract_arangodb_key from .conftest import db G_NX = nx.karate_club_graph() -def test_db(load_graph: Any) -> None: +def test_db(load_karate_graph: Any) -> None: assert db.version() @@ -36,7 +37,7 @@ def test_load_graph_from_nxadb(): db.delete_graph(graph_name, drop_collections=True) -def test_bc(load_graph): +def test_bc(load_karate_graph): G_1 = G_NX G_2 = nxadb.Graph(incoming_graph_data=G_1) G_3 = nxadb.Graph(graph_name="KarateGraph") @@ -72,7 +73,7 @@ def test_bc(load_graph): assert len(r_6) == len(r_7) == len(r_8) == len(G_4) > 0 -def test_pagerank(load_graph: Any) -> None: +def test_pagerank(load_karate_graph: Any) -> None: G_1 = G_NX G_2 = nxadb.Graph(incoming_graph_data=G_1) G_3 = nxadb.Graph(graph_name="KarateGraph") @@ -106,7 +107,7 @@ def test_pagerank(load_graph: Any) -> None: assert len(r_6) == len(r_7) == len(r_8) == len(G_4) > 0 -def test_louvain(load_graph: Any) -> None: +def test_louvain(load_karate_graph: Any) -> None: G_1 = G_NX G_2 = nxadb.Graph(incoming_graph_data=G_1) G_3 = nxadb.Graph(graph_name="KarateGraph") @@ -143,7 +144,7 @@ def test_louvain(load_graph: Any) -> None: assert len(r_8) > 0 -def test_shortest_path(load_graph: Any) -> None: +def test_shortest_path(load_karate_graph: Any) -> None: G_1 = nxadb.Graph(graph_name="KarateGraph") G_2 = nxadb.DiGraph(graph_name="KarateGraph") @@ -158,7 +159,210 @@ def test_shortest_path(load_graph: Any) -> None: assert r_3 != r_4 -def test_graph_nodes_crud(load_graph: Any) -> None: +def test_node_dict_update_existing_single_collection(load_karate_graph: Any) -> None: + # This tests uses the existing nodes and updates each + # of them using the update method using a single collection + G_1 = nxadb.Graph(graph_name="KarateGraph", foo="bar") + + nodes_ids_list = G_1.nodes + local_nodes_dict = {} + + for node_id in nodes_ids_list: + local_nodes_dict[node_id] = {"extraValue": extract_arangodb_key(node_id)} + + G_1._node.update(local_nodes_dict) + + col = db.collection("person") + col_docs = col.all() + + # Check if the extraValue attribute was added to each document in the database + for doc in col_docs: + assert "extraValue" in doc + assert doc["extraValue"] == doc["_key"] + + # Check if the extraValue attribute was added to each document in the local cache + for node_id in nodes_ids_list: + assert "extraValue" in G_1._node.data[node_id] + assert G_1.nodes[node_id]["extraValue"] == extract_arangodb_key(node_id) + + +def test_node_dict_update_multiple_collections(load_two_relation_graph: Any) -> None: + # This tests uses the existing nodes and updates each + # of them using the update method using two collections + graph_name = "IntegrationTestTwoRelationGraph" + v_1_name = graph_name + "_v1" + v_2_name = graph_name + "_v2" + e_1_name = graph_name + "_e1" + e_2_name = graph_name + "_e2" + + # assert that those collections are empty + assert db.collection(v_1_name).count() == 0 + assert db.collection(v_2_name).count() == 0 + assert db.collection(e_1_name).count() == 0 + assert db.collection(e_2_name).count() == 0 + + G_1 = nxadb.Graph(graph_name=graph_name, default_node_type=v_1_name) + assert len(G_1.nodes) == 0 + assert len(G_1.edges) == 0 + + # inserts into first collection (by default) + new_nodes_v1: Dict[str, Dict[str, Any]] = {"1": {}, "2": {}, "3": {}} + # needs to be inserted into second collection + new_nodes_v2: Dict[str, Dict[str, Any]] = { + f"{v_2_name}/4": {}, + f"{v_2_name}/5": {}, + f"{v_2_name}/6": {}, + } + + G_1._node.update(new_nodes_v1) + G_1._node.update(new_nodes_v2) + + assert db.collection(v_1_name).count() == 3 + assert db.collection(v_2_name).count() == 3 + assert len(G_1.nodes) == 6 + + for i in range(1, 4): + assert f"{v_1_name}/{str(i)}" in G_1.nodes + + for i in range(4, 7): + assert f"{v_2_name}/{i}" in G_1.nodes + + +def test_edge_adj_dict_update_existing_single_collection( + load_karate_graph: Any, +) -> None: + G_1 = nxadb.Graph(graph_name="KarateGraph", foo="bar") + + local_adj = G_1.adj + local_edges_dict: AdjDict = {} + + for from_doc_id, target_dict in local_adj.items(): + for to_doc_id, edge_doc in target_dict.items(): + edge_doc_id = edge_doc["_id"] + if from_doc_id not in local_edges_dict: + local_edges_dict[from_doc_id] = {} + + local_edges_dict[from_doc_id][to_doc_id] = { + "_id": edge_doc_id, + "extraValue": edge_doc["_key"], + } + + G_1._adj.update(local_edges_dict) + + edge_col = db.collection("knows") + edge_col_docs = edge_col.all() + + # Check if the extraValue attribute was added to each document in the database + for doc in edge_col_docs: + assert "extraValue" in doc + assert doc["extraValue"] == doc["_key"] + + # Check if the extraValue attribute was added to each document in the local cache + for from_doc_id, target_dict in local_edges_dict.items(): + for to_doc_id, edge_doc in target_dict.items(): + assert "extraValue" in G_1._adj[from_doc_id][to_doc_id] + assert G_1.adj[from_doc_id][to_doc_id][ + "extraValue" + ] == extract_arangodb_key(edge_doc["_id"]) + + +def test_edge_adj_inner_dict_update_existing_single_collection( + load_karate_graph: Any, +) -> None: + G_1 = nxadb.Graph(graph_name="KarateGraph", foo="bar") + + local_adj = G_1.adj + local_inner_edges_dict: Dict[str, AdjDictEdge] = {} + from_doc_id_to_use: str = "person/9" + + target_dict = local_adj[from_doc_id_to_use] + for to_doc_id, edge_doc in target_dict.items(): + # will contain three items/documents + edge_doc_id = edge_doc["_id"] + local_inner_edges_dict[to_doc_id] = { + "_id": edge_doc_id, + "extraValue": edge_doc["_key"], + } + + G_1._adj[from_doc_id_to_use].update(local_inner_edges_dict) + + edge_col = db.collection("knows") + edge_col_docs = edge_col.all() + + # Check if the extraValue attribute was added to requested docs in ADB + for doc in edge_col_docs: + if doc["_from"] == from_doc_id_to_use: + assert "extraValue" in doc + assert doc["extraValue"] == doc["_key"] + + # Check if the extraValue attribute was added to each document in the local cache + for to_doc_id in local_inner_edges_dict.keys(): + assert "extraValue" in G_1._adj[from_doc_id_to_use][to_doc_id] + assert G_1.adj[from_doc_id_to_use][to_doc_id][ + "extraValue" + ] == extract_arangodb_key(local_inner_edges_dict[to_doc_id]["_id"]) + return + + +def test_edge_dict_update_multiple_collections(load_two_relation_graph: Any) -> None: + graph_name = "IntegrationTestTwoRelationGraph" + v_1_name = graph_name + "_v1" + v_2_name = graph_name + "_v2" + e_1_name = graph_name + "_e1" + e_2_name = graph_name + "_e2" + + assert db.collection(v_1_name).count() == 0 + assert db.collection(v_2_name).count() == 0 + assert db.collection(e_1_name).count() == 0 + assert db.collection(e_2_name).count() == 0 + + G_1 = nxadb.Graph(graph_name=graph_name, default_node_type=v_1_name) + assert len(G_1.nodes) == 0 + assert len(G_1.edges) == 0 + + # inserts into first collection (by default) + new_edges_dict: AdjDict = { + graph_name + + "_v1/1": { + graph_name + "_v1/2": {"_id": e_1_name + "/1"}, + graph_name + "_v1/3": {"_id": e_1_name + "/2"}, + }, + graph_name + + "_v2/1": { + graph_name + "_v1/2": {"_id": e_2_name + "/1"}, + graph_name + "_v1/3": {"_id": e_2_name + "/2"}, + }, + } + + G_1._adj.update(new_edges_dict) + + # _adj list is not responsible for maintaining the vertex collections + assert db.collection(v_1_name).count() == 0 + assert db.collection(v_2_name).count() == 0 + + assert db.collection(e_1_name).count() == 2 + assert db.collection(e_2_name).count() == 2 + + # Check that the edge ids are present in the database + assert db.has_document({"_id": e_1_name + "/1"}) + assert db.has_document({"_id": e_1_name + "/2"}) + assert db.has_document({"_id": e_2_name + "/1"}) + assert db.has_document({"_id": e_2_name + "/2"}) + + # Check local state + assert len(G_1.nodes) == 0 + assert len(G_1.edges) == 4 + + local_edge_cache = G_1._adj + assert f"{v_1_name}/{1}" in local_edge_cache + assert f"{v_2_name}/{1}" in local_edge_cache + assert f"{v_1_name}/{2}" in local_edge_cache[f"{v_1_name}/{1}"] + assert f"{v_1_name}/{3}" in local_edge_cache[f"{v_1_name}/{1}"] + assert f"{v_1_name}/{2}" in local_edge_cache[f"{v_2_name}/{1}"] + assert f"{v_1_name}/{3}" in local_edge_cache[f"{v_2_name}/{1}"] + + +def test_graph_nodes_crud(load_karate_graph: Any) -> None: G_1 = nxadb.Graph(graph_name="KarateGraph", foo="bar") G_2 = nx.Graph(G_NX) @@ -281,7 +485,7 @@ def test_graph_nodes_crud(load_graph: Any) -> None: assert db.document("person/2")["object"]["sub_object"]["foo"] == "baz" -def test_graph_edges_crud(load_graph: Any) -> None: +def test_graph_edges_crud(load_karate_graph: Any) -> None: G_1 = nxadb.Graph(graph_name="KarateGraph") G_2 = G_NX @@ -423,7 +627,7 @@ def test_graph_edges_crud(load_graph: Any) -> None: assert db.document(edge_id)["object"]["sub_object"]["foo"] == "baz" -def test_readme(load_graph: Any) -> None: +def test_readme(load_karate_graph: Any) -> None: G = nxadb.Graph(graph_name="KarateGraph", default_node_type="person") assert len(G.nodes) == len(G_NX.nodes)