diff --git a/nx_arangodb/classes/dict/adj.py b/nx_arangodb/classes/dict/adj.py index dab0825b..a97791c9 100644 --- a/nx_arangodb/classes/dict/adj.py +++ b/nx_arangodb/classes/dict/adj.py @@ -39,16 +39,15 @@ edge_link, get_arangodb_graph, get_node_id, + get_node_type, get_node_type_and_id, get_update_dict, - is_arangodb_id, json_serializable, key_is_adb_id_or_int, key_is_not_reserved, key_is_string, keys_are_not_reserved, keys_are_strings, - read_collection_name_from_local_id, separate_edges_by_collections, upsert_collection_edges, ) @@ -1180,11 +1179,10 @@ def copy(self) -> Any: return {key: value.copy() for key, value in self.data.items()} @keys_are_strings - def update(self, edges: Any) -> None: + def update(self, edges: dict[str, dict[str, Any]]) -> None: """g._adj['node/1'].update({'node/2': {'foo': 'bar'}})""" - from_col_name = read_collection_name_from_local_id( - self.src_node_id, self.default_node_type - ) + assert self.src_node_id + from_col_name = get_node_type(self.src_node_id, self.default_node_type) to_upsert: Dict[str, List[Dict[str, Any]]] = {from_col_name: []} @@ -1194,10 +1192,10 @@ def update(self, edges: Any) -> None: edge_doc["_to"] = edge_id edge_doc_id = edge_data.get("_id") - assert is_arangodb_id(edge_doc_id) - edge_col_name = read_collection_name_from_local_id( - edge_doc_id, self.default_node_type - ) + if not edge_doc_id: + raise ValueError("Edge _id field is required for update.") + + edge_col_name = get_node_type(edge_doc_id, self.default_node_type) if to_upsert.get(edge_col_name) is None: to_upsert[edge_col_name] = [edge_doc] diff --git a/nx_arangodb/classes/function.py b/nx_arangodb/classes/function.py index 1709beaa..19f09324 100644 --- a/nx_arangodb/classes/function.py +++ b/nx_arangodb/classes/function.py @@ -637,17 +637,36 @@ def edge_link( return edge +def is_arangodb_id(key): + return "/" in key + + +def get_node_type(key: str, default_node_type: str) -> str: + """Gets the node type.""" + return key.split("/")[0] if is_arangodb_id(key) else default_node_type + + def get_node_id(key: str, default_node_type: str) -> str: """Gets the node ID.""" - return key if "/" in key else f"{default_node_type}/{key}" + return key if is_arangodb_id(key) else f"{default_node_type}/{key}" def get_node_type_and_id(key: str, default_node_type: str) -> tuple[str, str]: """Gets the node type and ID.""" - if "/" in key: - return key.split("/")[0], key + return ( + (key.split("/")[0], key) + if is_arangodb_id(key) + else (default_node_type, f"{default_node_type}/{key}") + ) + + +def get_node_type_and_key(key: str, default_node_type: str) -> tuple[str, str]: + """Gets the node type and key.""" + if is_arangodb_id(key): + col, key = key.split("/", 1) + return col, key - return default_node_type, f"{default_node_type}/{key}" + return default_node_type, key def get_update_dict( @@ -683,38 +702,9 @@ def check_list_for_errors(lst): return True -def is_arangodb_id(key): - return "/" in key - - -def get_arangodb_collection_key_tuple(key): - if not is_arangodb_id(key): - raise ValueError(f"Invalid ArangoDB key: {key}") - return key.split("/", 1) - - -def extract_arangodb_collection_name(arangodb_id: str) -> str: - if not is_arangodb_id(arangodb_id): - raise ValueError(f"Invalid ArangoDB key: {arangodb_id}") - return arangodb_id.split("/")[0] - - -def read_collection_name_from_local_id( - local_id: Optional[str], default_collection: str -) -> str: - if local_id is None: - print("local_id is None, cannot read collection name.") - return "" - - if is_arangodb_id(local_id): - return extract_arangodb_collection_name(local_id) - - assert default_collection is not None - assert default_collection != "" - return default_collection - - -def separate_nodes_by_collections(nodes: Any, default_collection: str) -> Any: +def separate_nodes_by_collections( + nodes: dict[str, Any], default_collection: str +) -> Any: """ Separate the dictionary into collections based on whether keys contain '/'. :param nodes: @@ -728,15 +718,12 @@ def separate_nodes_by_collections(nodes: Any, default_collection: str) -> Any: separated: Any = {} for key, value in nodes.items(): - if is_arangodb_id(key): - collection, doc_key = get_arangodb_collection_key_tuple(key) - if collection not in separated: - separated[collection] = {} - separated[collection][doc_key] = value - else: - if default_collection not in separated: - separated[default_collection] = {} - separated[default_collection][key] = value + collection, doc_key = get_node_type_and_key(key, default_collection) + + if collection not in separated: + separated[collection] = {} + + separated[collection][doc_key] = value return separated diff --git a/tests/test.py b/tests/test.py index 31f8d10d..145feb6a 100644 --- a/tests/test.py +++ b/tests/test.py @@ -113,6 +113,7 @@ def test_load_graph_from_nxadb(): name=graph_name, incoming_graph_data=G_NX, default_node_type="person", + write_async=False, ) assert db.has_graph(graph_name) @@ -134,6 +135,7 @@ def test_load_graph_from_nxadb_w_specific_edge_attribute(): incoming_graph_data=G_NX, default_node_type="person", edge_collections_attributes={"weight"}, + write_async=False, ) # TODO: re-enable this line as soon as CPU based data caching is implemented # graph._adj._fetch_all() @@ -163,6 +165,7 @@ def test_load_graph_from_nxadb_w_not_available_edge_attribute(): default_node_type="person", # This will lead to weight not being loaded into the edge data edge_collections_attributes={"_id"}, + write_async=False, ) # Should just succeed without any errors (fallback to weight: 1 as above) @@ -1592,15 +1595,9 @@ def test_graph_dict_clear_will_not_remove_remote_data(load_karate_graph: Any) -> def test_graph_dict_set_item(load_karate_graph: Any) -> None: - try: - db.collection("nxadb_graphs").delete("KarateGraph") - except DocumentDeleteError: - pass - except Exception as e: - print(f"An unexpected error occurred: {e}") - raise - - G = nxadb.Graph(name="KarateGraph", default_node_type="person") + name = "KarateGraph" + db.collection("nxadb_graphs").delete(name, ignore_missing=True) + G = nxadb.Graph(name=name, default_node_type="person") json_values = [ "aString", @@ -1819,7 +1816,9 @@ def test_incoming_graph_data_not_nx_graph( name = "KarateGraph" db.delete_graph(name, drop_collections=True, ignore_missing=True) - G = nxadb.Graph(incoming_graph_data=incoming_graph_data, name=name) + G = nxadb.Graph( + incoming_graph_data=incoming_graph_data, name=name, write_async=False + ) assert len(G.adj) == len(G_NX.adj) == db.collection(G.default_node_type).count() assert ( @@ -1870,7 +1869,9 @@ def test_incoming_graph_data_not_nx_graph_digraph( name = "KarateGraph" db.delete_graph(name, drop_collections=True, ignore_missing=True) - G = nxadb.DiGraph(incoming_graph_data=incoming_graph_data, name=name) + G = nxadb.DiGraph( + incoming_graph_data=incoming_graph_data, name=name, write_async=False + ) assert ( len(G.adj)