From 885be62b58117b08ebc0f3bb09585ec937b38384 Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Thu, 4 Jul 2024 15:43:50 -0400 Subject: [PATCH 1/6] GA-147 | initial commit --- nx_arangodb/classes/dict.py | 510 ++++++++++++++++++-------------- nx_arangodb/classes/function.py | 45 ++- tests/test.py | 39 ++- 3 files changed, 360 insertions(+), 234 deletions(-) diff --git a/nx_arangodb/classes/dict.py b/nx_arangodb/classes/dict.py index f70c6e70..aadb3cbb 100644 --- a/nx_arangodb/classes/dict.py +++ b/nx_arangodb/classes/dict.py @@ -39,8 +39,13 @@ key_is_string, keys_are_not_reserved, keys_are_strings, + logger_debug, ) +############# +# Factories # +############# + def graph_dict_factory( db: StandardDatabase, graph_name: str @@ -87,6 +92,11 @@ def edge_attr_dict_factory( return lambda: EdgeAttrDict(db, graph) +######### +# Graph # +######### + + class GraphDict(UserDict[str, Any]): """A dictionary-like object for storing graph attributes. @@ -101,10 +111,10 @@ class GraphDict(UserDict[str, Any]): COLLECTION_NAME = "nxadb_graphs" + @logger_debug def __init__( self, db: StandardDatabase, graph_name: str, *args: Any, **kwargs: Any ): - logger.debug("GraphDict.__init__") super().__init__(*args, **kwargs) self.data: dict[str, Any] = {} @@ -119,22 +129,22 @@ def __init__( self.data.update(data) @key_is_string + @logger_debug def __contains__(self, key: str) -> bool: """'foo' in G.graph""" if key in self.data: - logger.debug(f"cached in GraphDict.__contains__({key})") return True - logger.debug("aql_doc_has_key in GraphDict.__contains__") return aql_doc_has_key(self.db, self.graph_id, key) @key_is_string + @logger_debug def __getitem__(self, key: str) -> Any: """G.graph['foo']""" + if value := self.data.get(key): return value - logger.debug("aql_doc_get_key in GraphDict.__getitem__") result = aql_doc_get_key(self.db, self.graph_id, key) if not result: @@ -146,40 +156,87 @@ def __getitem__(self, key: str) -> Any: @key_is_string @key_is_not_reserved + @logger_debug # @value_is_json_serializable # TODO? def __setitem__(self, key: str, value: Any) -> None: """G.graph['foo'] = 'bar'""" self.data[key] = value - logger.debug(f"doc_update in GraphDict.__setitem__({key})") - doc_update(self.db, self.graph_id, {key: value}) + self.data["_rev"] = doc_update(self.db, self.graph_id, {key: value}) @key_is_string @key_is_not_reserved + @logger_debug def __delitem__(self, key: str) -> None: """del G.graph['foo']""" self.data.pop(key, None) - logger.debug(f"doc_update in GraphDict.__delitem__({key})") - doc_update(self.db, self.graph_id, {key: None}) + self.data["_rev"] = doc_update(self.db, self.graph_id, {key: None}) @keys_are_strings @keys_are_not_reserved # @values_are_json_serializable # TODO? + @logger_debug def update(self, attrs: Any) -> None: """G.graph.update({'foo': 'bar'})""" - if attrs: - self.data.update(attrs) - logger.debug(f"doc_update in GraphDict.update({attrs})") - doc_update(self.db, self.graph_id, attrs) + if not attrs: + return + + self.data.update(attrs) + self.data["_rev"] = doc_update(self.db, self.graph_id, attrs) + @logger_debug def clear(self) -> None: """G.graph.clear()""" self.data.clear() - logger.debug("cleared GraphDict") # if clear_remote: # doc_insert(self.db, self.COLLECTION_NAME, self.graph_id, silent=True) +######## +# Node # +######## + + +def process_node_attr_dict_value(parent: NodeAttrDict, key: str, value: Any) -> Any: + if not isinstance(value, dict): + return value + + node_attr_dict = parent.node_attr_dict_factory() + node_attr_dict.node_id = parent.node_id + node_attr_dict.parent_keys = parent.parent_keys + [key] + node_attr_dict.data = build_node_attr_dict_data(node_attr_dict, value) + + return node_attr_dict + + +def build_node_attr_dict_data( + parent: NodeAttrDict, data: dict[str, Any] +) -> dict[str, Any | NodeAttrDict]: + """Recursively build a NodeAttrDict from a dict. + + It's possible that **value** is a nested dict, so we need to + recursively build a NodeAttrDict for each nested dict. + + Returns the parent NodeAttrDict. + """ + node_attr_dict_data = {} + for key, value in data.items(): + node_attr_dict_value = process_node_attr_dict_value(parent, key, value) + node_attr_dict_data[key] = node_attr_dict_value + + return node_attr_dict_data + + +def get_update_dict( + parent_keys: list[str], update_dict: dict[str, Any] +) -> dict[str, Any]: + if parent_keys: + for key in reversed(parent_keys): + update_dict = {key: update_dict} + + return update_dict + + class NodeAttrDict(UserDict[str, Any]): """The inner-level of the dict of dict structure representing the nodes (vertices) of a graph. @@ -190,96 +247,123 @@ class NodeAttrDict(UserDict[str, Any]): :type graph: Graph """ + @logger_debug def __init__(self, db: StandardDatabase, graph: Graph, *args: Any, **kwargs: Any): - logger.debug("NodeAttrDict.__init__") self.db = db self.graph = graph - self.node_id: str + self.node_id: str | None = None + + # NodeAttrDict may be a child of another NodeAttrDict + # e.g G._node['node/1']['object']['foo'] = 'bar' + # In this case, parent_keys would be ['object'] + self.parent_keys: list[str] = [] super().__init__(*args, **kwargs) self.data: dict[str, Any] = {} + self.node_attr_dict_factory = node_attr_dict_factory(self.db, self.graph) @key_is_string + @logger_debug def __contains__(self, key: str) -> bool: """'foo' in G._node['node/1']""" if key in self.data: - logger.debug(f"cached in NodeAttrDict.__contains__({key})") return True - logger.debug("aql_doc_has_key in NodeAttrDict.__contains__") + assert self.node_id return aql_doc_has_key(self.db, self.node_id, key) @key_is_string + @logger_debug def __getitem__(self, key: str) -> Any: """G._node['node/1']['foo']""" if value := self.data.get(key): - logger.debug(f"cached in NodeAttrDict.__getitem__({key})") return value - logger.debug(f"aql_doc_get_key in NodeAttrDict.__getitem__({key})") + assert self.node_id result = aql_doc_get_key(self.db, self.node_id, key) if not result: raise KeyError(key) - self.data[key] = result + node_attr_dict_value = process_node_attr_dict_value(self, key, result) + self.data[key] = node_attr_dict_value - return result + return node_attr_dict_value @key_is_string @key_is_not_reserved # @value_is_json_serializable # TODO? + @logger_debug def __setitem__(self, key: str, value: Any) -> None: - """G._node['node/1']['foo'] = 'bar'""" - self.data[key] = value - logger.debug(f"doc_update in NodeAttrDict.__setitem__({key})") - doc_update(self.db, self.node_id, {key: value}) + """ + G._node['node/1']['foo'] = 'bar' + G._node['node/1']['object'] = {'foo': 'bar'} + G._node['node/1']['object']['foo'] = 'baz' + """ + assert self.node_id + node_attr_dict_value = process_node_attr_dict_value(self, key, value) + update_dict = get_update_dict(self.parent_keys, {key: value}) + self.data[key] = node_attr_dict_value + self.data["_rev"] = doc_update(self.db, self.node_id, update_dict) @key_is_string @key_is_not_reserved + @logger_debug def __delitem__(self, key: str) -> None: """del G._node['node/1']['foo']""" + assert self.node_id self.data.pop(key, None) - logger.debug(f"doc_update in NodeAttrDict({self.node_id}).__delitem__({key})") - doc_update(self.db, self.node_id, {key: None}) - - def __iter__(self) -> Iterator[str]: - """for key in G._node['node/1']""" - logger.debug(f"NodeAttrDict({self.node_id}).__iter__") - yield from aql_doc_get_keys(self.db, self.node_id) - - def __len__(self) -> int: - """len(G._node['node/1'])""" - logger.debug(f"NodeAttrDict({self.node_id}).__len__") - return aql_doc_get_length(self.db, self.node_id) - - # TODO: Revisit typing of return value - from collections.abc import KeysView - - def keys(self) -> Any: - """G._node['node/1'].keys()""" - logger.debug(f"NodeAttrDict({self.node_id}).keys()") - yield from self.__iter__() - - # TODO: Revisit typing of return value - def values(self) -> Any: - """G._node['node/1'].values()""" - logger.debug(f"NodeAttrDict({self.node_id}).values()") - self.data = self.db.document(self.node_id) - yield from self.data.values() - - # TODO: Revisit typing of return value - def items(self) -> Any: - """G._node['node/1'].items()""" - logger.debug(f"NodeAttrDict({self.node_id}).items()") - self.data = self.db.document(self.node_id) - yield from self.data.items() - + update_dict = get_update_dict(self.parent_keys, {key: None}) + self.data["_rev"] = doc_update(self.db, self.node_id, update_dict) + + # @logger_debug + # def __iter__(self) -> Iterator[str]: + # """for key in G._node['node/1']""" + # yield from aql_doc_get_keys(self.db, self.node_id, self.parent_keys) + + # @logger_debug + # def __len__(self) -> int: + # """len(G._node['node/1'])""" + # return aql_doc_get_length(self.db, self.node_id, self.parent_keys) + + # @logger_debug + # def keys(self) -> Any: + # """G._node['node/1'].keys()""" + # yield from self.__iter__() + + # @logger_debug + # # TODO: Revisit typing of return value + # def values(self) -> Any: + # """G._node['node/1'].values()""" + # breakpoint() + # self.data = self.db.document(self.node_id) + # yield from self.data.values() + + # @logger_debug + # # TODO: Revisit typing of return value + # def items(self) -> Any: + # """G._node['node/1'].items()""" + + # # TODO: Revisit this lazy hack + # if self.parent_keys: + # yield from self.data.items() + # else: + # self.data = self.db.document(self.node_id) + # yield from self.data.items() + + # ? + # def pull(): + # pass + + # ? + # def push(): + # pass + + @logger_debug def clear(self) -> None: """G._node['node/1'].clear()""" self.data.clear() - logger.debug(f"cleared NodeAttrDict({self.node_id})") # if clear_remote: # doc_insert(self.db, self.node_id, silent=True, overwrite=True) @@ -287,17 +371,20 @@ def clear(self) -> None: @keys_are_strings @keys_are_not_reserved # @values_are_json_serializable # TODO? + @logger_debug def update(self, attrs: Any) -> None: """G._node['node/1'].update({'foo': 'bar'})""" - if attrs: - self.data.update(attrs) + if not attrs: + return - if not self.node_id: - logger.debug("Node ID not set, skipping NodeAttrDict(?).update()") - return + self.data.update(build_node_attr_dict_data(self, attrs)) - logger.debug(f"NodeAttrDict({self.node_id}).update({attrs})") - doc_update(self.db, self.node_id, attrs) + if not self.node_id: + logger.warning("Node ID not set, skipping NodeAttrDict(?).update()") + return + + update_dict = get_update_dict(self.parent_keys, attrs) + self.data["_rev"] = doc_update(self.db, self.node_id, update_dict) class NodeDict(UserDict[str, NodeAttrDict]): @@ -316,6 +403,7 @@ class NodeDict(UserDict[str, NodeAttrDict]): :type default_node_type: str """ + @logger_debug def __init__( self, db: StandardDatabase, @@ -324,7 +412,6 @@ def __init__( *args: Any, **kwargs: Any, ): - logger.debug("NodeDict.__init__") super().__init__(*args, **kwargs) self.data: dict[str, NodeAttrDict] = {} @@ -334,32 +421,29 @@ def __init__( self.node_attr_dict_factory = node_attr_dict_factory(self.db, self.graph) @key_is_string + @logger_debug def __contains__(self, key: str) -> bool: """'node/1' in G._node""" node_id = get_node_id(key, self.default_node_type) if node_id in self.data: - logger.debug(f"cached in NodeDict.__contains__({node_id})") return True - logger.debug(f"graph.has_vertex in NodeDict.__contains__({node_id})") return bool(self.graph.has_vertex(node_id)) @key_is_string + @logger_debug def __getitem__(self, key: str) -> NodeAttrDict: """G._node['node/1']""" node_id = get_node_id(key, self.default_node_type) if value := self.data.get(node_id): - logger.debug(f"cached in NodeDict.__getitem__({node_id})") return value if value := self.graph.vertex(node_id): - logger.debug(f"graph.vertex in NodeDict.__getitem__({node_id})") node_attr_dict: NodeAttrDict = self.node_attr_dict_factory() node_attr_dict.node_id = node_id - node_attr_dict.data = value - + node_attr_dict.data = build_node_attr_dict_data(node_attr_dict, value) self.data[node_id] = node_attr_dict return node_attr_dict @@ -367,6 +451,7 @@ def __getitem__(self, key: str) -> NodeAttrDict: raise KeyError(key) @key_is_string + @logger_debug def __setitem__(self, key: str, value: NodeAttrDict) -> None: """G._node['node/1'] = {'foo': 'bar'} @@ -377,7 +462,6 @@ def __setitem__(self, key: str, value: NodeAttrDict) -> None: node_type, node_id = get_node_type_and_id(key, self.default_node_type) - logger.debug(f"doc_insert in NodeDict.__setitem__({key})") result = doc_insert(self.db, node_type, node_id, value.data) node_attr_dict = self.node_attr_dict_factory() @@ -387,6 +471,7 @@ def __setitem__(self, key: str, value: NodeAttrDict) -> None: self.data[node_id] = node_attr_dict @key_is_string + @logger_debug def __delitem__(self, key: str) -> None: """del g._node['node/1']""" node_id = get_node_id(key, self.default_node_type) @@ -394,6 +479,7 @@ def __delitem__(self, key: str) -> None: if not self.graph.has_vertex(node_id): raise KeyError(key) + # TODO: wrap in edges_delete() method remove_statements = "\n".join( f"REMOVE e IN `{edge_def['edge_collection']}` OPTIONS {{ignoreErrors: true}}" # noqa for edge_def in self.graph.edge_definitions() @@ -406,17 +492,16 @@ def __delitem__(self, key: str) -> None: bind_vars = {"src_node_id": node_id, "graph_name": self.graph.name} - logger.debug(f"remove_edges in NodeDict.__delitem__({node_id})") aql(self.db, query, bind_vars) + ##### - logger.debug(f"doc_delete in NodeDict.__delitem__({node_id})") doc_delete(self.db, node_id) self.data.pop(node_id, None) + @logger_debug def __len__(self) -> int: """len(g._node)""" - logger.debug("NodeDict.__len__") return sum( [ self.graph.vertex_collection(c).count() @@ -424,64 +509,53 @@ def __len__(self) -> int: ] ) + @logger_debug def __iter__(self) -> Iterator[str]: """iter(g._node)""" - logger.debug("NodeDict.__iter__") for collection in self.graph.vertex_collections(): yield from self.graph.vertex_collection(collection).ids() + @logger_debug def clear(self) -> None: """g._node.clear()""" self.data.clear() - logger.debug("cleared NodeDict") # if clear_remote: # for collection in self.graph.vertex_collections(): # self.graph.vertex_collection(collection).truncate() @keys_are_strings + @logger_debug def update(self, nodes: Any) -> None: """g._node.update({'node/1': {'foo': 'bar'}, 'node/2': {'baz': 'qux'}})""" raise NotImplementedError("NodeDict.update()") - # for node_id, attrs in nodes.items(): - # node_id = get_node_id(node_id, self.default_node_type) - - # result = doc_update(self.db, node_id, attrs) - - # node_attr_dict = self.node_attr_dict_factory() - # node_attr_dict.node_id = node_id - # node_attr_dict.data = result - - # self.data[node_id] = node_attr_dict + @logger_debug def keys(self) -> Any: """g._node.keys()""" - logger.debug("NodeDict.keys()") return self.__iter__() # TODO: Revisit typing of return value + @logger_debug def values(self) -> Any: """g._node.values()""" - logger.debug("NodeDict.values()") self.__fetch_all() yield from self.data.values() # TODO: Revisit typing of return value + @logger_debug def items(self, data: str | None = None, default: Any | None = None) -> Any: """g._node.items() or G._node.items(data='foo')""" if data is None: - logger.debug("NodeDict.items(data=None)") self.__fetch_all() yield from self.data.items() else: - logger.debug(f"NodeDict.items(data={data})") v_cols = list(self.graph.vertex_collections()) result = aql_fetch_data(self.db, v_cols, data, default) yield from result.items() + @logger_debug def __fetch_all(self): - logger.debug("NodeDict.__fetch_all()") - self.data.clear() for collection in self.graph.vertex_collections(): for doc in self.graph.vertex_collection(collection).all(): @@ -494,6 +568,11 @@ def __fetch_all(self): self.data[node_id] = node_attr_dict +############# +# Adjacency # +############# + + class EdgeAttrDict(UserDict[str, Any]): """The innermost-level of the dict of dict of dict structure representing the Adjacency List of a graph. @@ -506,6 +585,7 @@ class EdgeAttrDict(UserDict[str, Any]): :type graph: Graph """ + @logger_debug def __init__( self, db: StandardDatabase, @@ -513,35 +593,31 @@ def __init__( *args: Any, **kwargs: Any, ) -> None: - logger.debug("EdgeAttrDict.__init__") - super().__init__(*args, **kwargs) self.data: dict[str, Any] = {} self.db = db self.graph = graph - self.edge_id: str + self.edge_id: str | None = None @key_is_string + @logger_debug def __contains__(self, key: str) -> bool: """'foo' in G._adj['node/1']['node/2']""" if key in self.data: - logger.debug(f"cached in EdgeAttrDict({self.edge_id}).__contains__({key})") return True - logger.debug(f"aql_doc_has_key in EdgeAttrDict({self.edge_id}).__contains__") + assert self.edge_id return aql_doc_has_key(self.db, self.edge_id, key) @key_is_string + @logger_debug def __getitem__(self, key: str) -> Any: """G._adj['node/1']['node/2']['foo']""" if value := self.data.get(key): - logger.debug(f"cached in EdgeAttrDict({self.edge_id}).__getitem__({key})") return value - logger.debug( - f"aql_doc_get_key in EdgeAttrDict({self.edge_id}).__getitem__({key})" - ) + assert self.edge_id result = aql_doc_get_key(self.db, self.edge_id, key) if not result: @@ -554,68 +630,75 @@ def __getitem__(self, key: str) -> Any: @key_is_string @key_is_not_reserved # @value_is_json_serializable # TODO? + @logger_debug def __setitem__(self, key: str, value: Any) -> None: """G._adj['node/1']['node/2']['foo'] = 'bar'""" + assert self.edge_id self.data[key] = value - logger.debug(f"doc_update in EdgeAttrDict({self.edge_id}).__setitem__({key})") - doc_update(self.db, self.edge_id, {key: value}) + self.data["_rev"] = doc_update(self.db, self.edge_id, {key: value}) @key_is_string @key_is_not_reserved + @logger_debug def __delitem__(self, key: str) -> None: """del G._adj['node/1']['node/2']['foo']""" + assert self.edge_id self.data.pop(key, None) - logger.debug(f"doc_update in EdgeAttrDict({self.edge_id}).__delitem__({key})") - doc_update(self.db, self.edge_id, {key: None}) - - def __iter__(self) -> Iterator[str]: - """for key in G._adj['node/1']['node/2']""" - logger.debug(f"EEdgeAttrDict({self.edge_id}).__iter__") - yield from aql_doc_get_keys(self.db, self.edge_id) - - def __len__(self) -> int: - """len(G._adj['node/1']['node/'2])""" - logger.debug(f"EdgeAttrDict({self.edge_id}).__len__") - return aql_doc_get_length(self.db, self.edge_id) - - # TODO: Revisit typing of return value - def keys(self) -> Any: - """G._adj['node/1']['node/'2].keys()""" - logger.debug(f"EdgeAttrDict({self.edge_id}).keys()") - return self.__iter__() - - # TODO: Revisit typing of return value - def values(self) -> Any: - """G._adj['node/1']['node/'2].values()""" - logger.debug(f"EdgeAttrDict({self.edge_id}).values()") - self.data = self.db.document(self.edge_id) - yield from self.data.values() - - # TODO: Revisit typing of return value - def items(self) -> Any: - """G._adj['node/1']['node/'2].items()""" - logger.debug(f"EdgeAttrDict({self.edge_id}).items()") - self.data = self.db.document(self.edge_id) - yield from self.data.items() - + self.data["_rev"] = doc_update(self.db, self.edge_id, {key: None}) + + # @logger_debug + # def __iter__(self) -> Iterator[str]: + # """for key in G._adj['node/1']['node/2']""" + # assert self.edge_id + # yield from aql_doc_get_keys(self.db, self.edge_id) + + # @logger_debug + # def __len__(self) -> int: + # """len(G._adj['node/1']['node/'2])""" + # assert self.edge_id + # return aql_doc_get_length(self.db, self.edge_id) + + # # TODO: Revisit typing of return value + # @logger_debug + # def keys(self) -> Any: + # """G._adj['node/1']['node/'2].keys()""" + # return self.__iter__() + + # # TODO: Revisit typing of return value + # @logger_debug + # def values(self) -> Any: + # """G._adj['node/1']['node/'2].values()""" + # self.data = self.db.document(self.edge_id) + # yield from self.data.values() + + # # TODO: Revisit typing of return value + # @logger_debug + # def items(self) -> Any: + # """G._adj['node/1']['node/'2].items()""" + # self.data = self.db.document(self.edge_id) + # yield from self.data.items() + + @logger_debug def clear(self) -> None: """G._adj['node/1']['node/'2].clear()""" self.data.clear() - logger.debug(f"cleared EdgeAttrDict({self.edge_id})") @keys_are_strings @keys_are_not_reserved + @logger_debug def update(self, attrs: Any) -> None: """G._adj['node/1']['node/'2].update({'foo': 'bar'})""" - if attrs: - self.data.update(attrs) + if not attrs: + return - if not hasattr(self, "edge_id"): - logger.debug("Edge ID not set, skipping EdgeAttrDict(?).update()") - return + self.data.update(attrs) - logger.debug(f"EdgeAttrDict({self.edge_id}).update({attrs})") - doc_update(self.db, self.edge_id, attrs) + if not self.edge_id: + logger.warning("Edge ID not set, skipping EdgeAttrDict(?).update()") + return + + assert self.edge_id + self.data["_rev"] = doc_update(self.db, self.edge_id, attrs) class AdjListInnerDict(UserDict[str, EdgeAttrDict]): @@ -634,6 +717,7 @@ class AdjListInnerDict(UserDict[str, EdgeAttrDict]): :type edge_type_func: Callable[[str, str], str] """ + @logger_debug def __init__( self, db: StandardDatabase, @@ -644,8 +728,6 @@ def __init__( *args: Any, **kwargs: Any, ): - logger.debug("AdjListInnerDict.__init__") - super().__init__(*args, **kwargs) self.data: dict[str, EdgeAttrDict] = {} @@ -655,40 +737,41 @@ def __init__( self.edge_type_func = edge_type_func self.adjlist_outer_dict = adjlist_outer_dict - self.src_node_id: str + self.src_node_id: str | None = None self.edge_attr_dict_factory = edge_attr_dict_factory(self.db, self.graph) self.FETCHED_ALL_DATA = False + @logger_debug def __get_mirrored_edge_attr_dict(self, dst_node_id: str) -> EdgeAttrDict | None: if self.adjlist_outer_dict is None: return None - logger.debug(f"checking for mirrored edge ({self.src_node_id}, {dst_node_id})") if dst_node_id in self.adjlist_outer_dict.data: if self.src_node_id in self.adjlist_outer_dict.data[dst_node_id].data: return self.adjlist_outer_dict.data[dst_node_id].data[self.src_node_id] return None + @logger_debug def __repr__(self) -> str: return f"'{self.src_node_id}'" + @logger_debug def __str__(self) -> str: return f"'{self.src_node_id}'" @key_is_string + @logger_debug def __contains__(self, key: str) -> bool: """'node/2' in G.adj['node/1']""" + assert self.src_node_id dst_node_id = get_node_id(key, self.default_node_type) if dst_node_id in self.data: - logger.debug(f"cached in AdjListInnerDict.__contains__({dst_node_id})") return True - logger.debug(f"aql_edge_exists in AdjListInnerDict.__contains__({dst_node_id})") - result = aql_edge_exists( self.db, self.src_node_id, @@ -700,21 +783,19 @@ def __contains__(self, key: str) -> bool: return result if result else False @key_is_string + @logger_debug def __getitem__(self, key: str) -> EdgeAttrDict: """g._adj['node/1']['node/2']""" + assert self.src_node_id dst_node_id = get_node_id(key, self.default_node_type) if dst_node_id in self.data: - m = f"cached in AdjListInnerDict({self.src_node_id}).__getitem__({dst_node_id})" # noqa - logger.debug(m) return self.data[dst_node_id] if mirrored_edge_attr_dict := self.__get_mirrored_edge_attr_dict(dst_node_id): - logger.debug("No need to fetch the edge, as it is already cached") self.data[dst_node_id] = mirrored_edge_attr_dict - return mirrored_edge_attr_dict + return mirrored_edge_attr_dict # type: ignore # false positive - m = f"aql_edge_get in AdjListInnerDict({self.src_node_id}).__getitem__({dst_node_id})" # noqa edge = aql_edge_get( self.db, self.src_node_id, @@ -735,28 +816,25 @@ def __getitem__(self, key: str) -> EdgeAttrDict: return edge_attr_dict @key_is_string + @logger_debug def __setitem__(self, key: str, value: dict[str, Any] | EdgeAttrDict) -> None: """g._adj['node/1']['node/2'] = {'foo': 'bar'}""" assert isinstance(value, EdgeAttrDict) - logger.debug(f"AdjListInnerDict({self.src_node_id}).__setitem__({key})") + assert self.src_node_id src_node_type = self.src_node_id.split("/")[0] dst_node_type, dst_node_id = get_node_type_and_id(key, self.default_node_type) if mirrored_edge_attr_dict := self.__get_mirrored_edge_attr_dict(dst_node_id): - logger.debug("No need to create a new edge, as it already exists") self.data[dst_node_id] = mirrored_edge_attr_dict return edge_type = value.data.get("_edge_type") if edge_type is None: edge_type = self.edge_type_func(src_node_type, dst_node_type) - logger.debug(f"No edge type specified, so generated: {edge_type})") edge_id: str | None - if hasattr(value, "edge_id"): - m = f"edge id found, deleting ({self.src_node_id, dst_node_id})" - logger.debug(m) + if value.edge_id: self.graph.delete_edge(value.edge_id) elif edge_id := aql_edge_id( @@ -766,12 +844,9 @@ def __setitem__(self, key: str, value: dict[str, Any] | EdgeAttrDict) -> None: self.graph.name, direction="ANY", ): - m = f"existing edge found, deleting ({self.src_node_id, dst_node_id})" - logger.debug(m) self.graph.delete_edge(edge_id) edge_data = value.data - logger.debug(f"graph.link({self.src_node_id}, {dst_node_id})") edge = self.graph.link(edge_type, self.src_node_id, dst_node_id, edge_data) edge_attr_dict = self.edge_attr_dict_factory() @@ -786,17 +861,16 @@ def __setitem__(self, key: str, value: dict[str, Any] | EdgeAttrDict) -> None: self.data[dst_node_id] = edge_attr_dict @key_is_string + @logger_debug def __delitem__(self, key: str) -> None: """del g._adj['node/1']['node/2']""" + assert self.src_node_id dst_node_id = get_node_id(key, self.default_node_type) self.data.pop(dst_node_id, None) if self.__get_mirrored_edge_attr_dict(dst_node_id): - m = "No need to delete the edge, as the next del will take care of it" - logger.debug(m) return - logger.debug(f"fetching edge ({self.src_node_id, dst_node_id})") edge_id = aql_edge_id( self.db, self.src_node_id, @@ -806,22 +880,19 @@ def __delitem__(self, key: str) -> None: ) if not edge_id: - m = f"edge not found, AdjListInnerDict({self.src_node_id}).__delitem__({dst_node_id})" # noqa - logger.debug(m) return - logger.debug(f"graph.delete_edge({edge_id})") self.graph.delete_edge(edge_id) + @logger_debug def __len__(self) -> int: """len(g._adj['node/1'])""" assert self.src_node_id if self.FETCHED_ALL_DATA: - m = f"Already fetched data, skipping AdjListInnerDict({self.src_node_id}).__len__" # noqa - logger.debug(m) return len(self.data) + # TODO: Create aql_edge_count() function query = """ RETURN LENGTH( FOR v, e IN 1..1 OUTBOUND @src_node_id GRAPH @graph_name @@ -831,19 +902,18 @@ def __len__(self) -> int: bind_vars = {"src_node_id": self.src_node_id, "graph_name": self.graph.name} - logger.debug(f"aql_single in AdjListInnerDict({self.src_node_id}).__len__") result = aql_single(self.db, query, bind_vars) + ##### if result is None: return 0 return int(result) + @logger_debug def __iter__(self) -> Iterator[str]: """for k in g._adj['node/1']""" if self.FETCHED_ALL_DATA: - m = f"Already fetched data, skipping AdjListInnerDict({self.src_node_id}).__iter__" # noqa - logger.debug(m) yield from self.data.keys() else: @@ -854,45 +924,43 @@ def __iter__(self) -> Iterator[str]: bind_vars = {"src_node_id": self.src_node_id, "graph_name": self.graph.name} - logger.debug(f"aql in AdjListInnerDict({self.src_node_id}).__iter__") yield from aql(self.db, query, bind_vars) # TODO: Revisit typing of return value + @logger_debug def keys(self) -> Any: """g._adj['node/1'].keys()""" - logger.debug(f"AdjListInnerDict({self.src_node_id}).keys()") return self.__iter__() + @logger_debug def clear(self) -> None: """G._adj['node/1'].clear()""" self.data.clear() self.FETCHED_ALL_DATA = False - logger.debug(f"cleared AdjListInnerDict({self.src_node_id})") @keys_are_strings + @logger_debug def update(self, edges: Any) -> None: """g._adj['node/1'].update({'node/2': {'foo': 'bar'}})""" raise NotImplementedError("AdjListInnerDict.update()") # TODO: Revisit typing of return value + @logger_debug def values(self) -> Any: """g._adj['node/1'].values()""" - logger.debug(f"AdjListInnerDict({self.src_node_id}).values()") self.__fetch_all() yield from self.data.values() # TODO: Revisit typing of return value + @logger_debug def items(self) -> Any: """g._adj['node/1'].items()""" - logger.debug(f"AdjListInnerDict({self.src_node_id}).items()") self.__fetch_all() yield from self.data.items() + @logger_debug def __fetch_all(self) -> None: - logger.debug(f"AdjListInnerDict({self.src_node_id}).__fetch_all()") - if self.FETCHED_ALL_DATA: - logger.debug("Already fetched data, skipping fetch") return self.clear() @@ -930,6 +998,7 @@ class AdjListOuterDict(UserDict[str, AdjListInnerDict]): :type edge_type_func: Callable[[str, str], str] """ + @logger_debug def __init__( self, db: StandardDatabase, @@ -939,8 +1008,6 @@ def __init__( *args: Any, **kwargs: Any, ): - logger.debug("AdjListOuterDict.__init__") - super().__init__(*args, **kwargs) self.data: dict[str, AdjListInnerDict] = {} @@ -954,35 +1021,35 @@ def __init__( self.FETCHED_ALL_DATA = False + @logger_debug def __repr__(self) -> str: return f"'{self.graph.name}'" + @logger_debug def __str__(self) -> str: return f"'{self.graph.name}'" @key_is_string + @logger_debug def __contains__(self, key: str) -> bool: """'node/1' in G.adj""" node_id = get_node_id(key, self.default_node_type) if node_id in self.data: - logger.debug(f"cached in AdjListOuterDict.__contains__({node_id})") return True - logger.debug("graph.has_vertex in AdjListOuterDict.__contains__") return bool(self.graph.has_vertex(node_id)) @key_is_string + @logger_debug def __getitem__(self, key: str) -> AdjListInnerDict: """G.adj["node/1"]""" - node_type, node_id = get_node_type_and_id(key, self.default_node_type) + node_id = get_node_id(key, self.default_node_type) if value := self.data.get(node_id): - logger.debug(f"cached in AdjListOuterDict.__getitem__({node_id})") return value if self.graph.has_vertex(node_id): - logger.debug(f"graph.vertex in AdjListOuterDict.__getitem__({node_id})") adjlist_inner_dict: AdjListInnerDict = self.adjlist_inner_dict_factory() adjlist_inner_dict.src_node_id = node_id @@ -993,41 +1060,40 @@ def __getitem__(self, key: str) -> AdjListInnerDict: raise KeyError(key) @key_is_string + @logger_debug def __setitem__(self, src_key: str, adjlist_inner_dict: AdjListInnerDict) -> None: """ g._adj['node/1'] = AdjListInnerDict() """ assert isinstance(adjlist_inner_dict, AdjListInnerDict) - assert not hasattr(adjlist_inner_dict, "src_node_id") - - logger.debug(f"AdjListOuterDict.__setitem__({src_key})") + assert len(adjlist_inner_dict.data) == 0 # See NOTE below src_node_type, src_node_id = get_node_type_and_id( src_key, self.default_node_type ) # NOTE: this might not actually be needed... - results = {} - for dst_key, edge_dict in adjlist_inner_dict.data.items(): - dst_node_type, dst_node_id = get_node_type_and_id( - dst_key, self.default_node_type - ) + # results = {} + # for dst_key, edge_dict in adjlist_inner_dict.data.items(): + # dst_node_type, dst_node_id = get_node_type_and_id( + # dst_key, self.default_node_type + # ) - edge_type = edge_dict.get("_edge_type") - if edge_type is None: - edge_type = self.edge_type_func(src_node_type, dst_node_type) + # edge_type = edge_dict.get("_edge_type") + # if edge_type is None: + # edge_type = self.edge_type_func(src_node_type, dst_node_type) - logger.debug(f"graph.link({src_key}, {dst_key})") - results[dst_key] = self.graph.link( - edge_type, src_node_id, dst_node_id, edge_dict - ) + # results[dst_key] = self.graph.link( + # edge_type, src_node_id, dst_node_id, edge_dict + # ) adjlist_inner_dict.src_node_id = src_node_id - adjlist_inner_dict.data = results + # adjlist_inner_dict.data = results self.data[src_node_id] = adjlist_inner_dict @key_is_string + @logger_debug def __delitem__(self, key: str) -> None: """ del G._adj['node/1'] @@ -1035,13 +1101,12 @@ def __delitem__(self, key: str) -> None: # Nothing else to do here, as this delete is always invoked by # G.remove_node(), which already removes all edges via # del G._node['node/1'] - logger.debug(f"AdjListOuterDict.__delitem__({key}) (just cache)") node_id = get_node_id(key, self.default_node_type) self.data.pop(node_id, None) + @logger_debug def __len__(self) -> int: """len(g._adj)""" - logger.debug("AdjListOuterDict.__len__") return sum( [ self.graph.vertex_collection(c).count() @@ -1049,10 +1114,9 @@ def __len__(self) -> int: ] ) + @logger_debug def __iter__(self) -> Iterator[str]: """for k in g._adj""" - logger.debug("AdjListOuterDict.__iter__") - if self.FETCHED_ALL_DATA: yield from self.data.keys() @@ -1061,34 +1125,36 @@ def __iter__(self) -> Iterator[str]: yield from self.graph.vertex_collection(collection).ids() # TODO: Revisit typing of return value + @logger_debug def keys(self) -> Any: """g._adj.keys()""" - logger.debug("AdjListOuterDict.keys()") return self.__iter__() + @logger_debug def clear(self) -> None: """g._node.clear()""" self.data.clear() self.FETCHED_ALL_DATA = False - logger.debug("cleared AdjListOuterDict") # if clear_remote: # for ed in self.graph.edge_definitions(): # self.graph.edge_collection(ed["edge_collection"]).truncate() @keys_are_strings + @logger_debug def update(self, edges: Any) -> None: """g._adj.update({'node/1': {'node/2': {'foo': 'bar'}})""" raise NotImplementedError("AdjListOuterDict.update()") # TODO: Revisit typing of return value + @logger_debug def values(self) -> Any: """g._adj.values()""" - logger.debug("AdjListOuterDict.values()") self.__fetch_all() yield from self.data.values() # TODO: Revisit typing of return value + @logger_debug def items(self, data: str | None = None, default: Any | None = None) -> Any: # TODO: Revisit typing # -> ( @@ -1097,22 +1163,18 @@ def items(self, data: str | None = None, default: Any | None = None) -> Any: # ): """g._adj.items() or G._adj.items(data='foo')""" if data is None: - logger.debug("AdjListOuterDict.items(data=None)") self.__fetch_all() yield from self.data.items() else: - logger.debug(f"AdjListOuterDict.items(data={data})") e_cols = [ed["edge_collection"] for ed in self.graph.edge_definitions()] result = aql_fetch_data_edge(self.db, e_cols, data, default) yield from result # TODO: Revisit this logic + @logger_debug def __fetch_all(self) -> None: - logger.debug("AdjListOuterDict.__fetch_all()") - if self.FETCHED_ALL_DATA: - logger.debug("Already fetched data, skipping fetch") return self.clear() diff --git a/nx_arangodb/classes/function.py b/nx_arangodb/classes/function.py index cada904a..e94cf97c 100644 --- a/nx_arangodb/classes/function.py +++ b/nx_arangodb/classes/function.py @@ -15,6 +15,7 @@ from arango.database import StandardDatabase import nx_arangodb as nxadb +from nx_arangodb.logger import logger from ..exceptions import ( AQLMultipleResultsFound, @@ -98,6 +99,16 @@ def wrapper(self: Any, key: Any, *args: Any, **kwargs: Any) -> Any: return wrapper +def logger_debug(func: Callable[..., Any]) -> Any: + """Decorator to log debug messages.""" + + def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + logger.debug(f"{func.__name__} - {args} - {kwargs}") + return func(self, *args, **kwargs) + + return wrapper + + def keys_are_strings(func: Callable[..., Any]) -> Any: """Decorator to check if the keys are strings.""" @@ -219,20 +230,37 @@ def aql_doc_get_key(db: StandardDatabase, id: str, key: str) -> Any: return aql_single(db, query, bind_vars) -def aql_doc_get_keys(db: StandardDatabase, id: str) -> list[str]: +def aql_doc_get_items( + db: StandardDatabase, id: str, nested_key: list[str] = [] +) -> dict[str, Any]: + """Gets the items of a document.""" + nested_key_str = "." + ".".join(nested_key) if nested_key else "" + query = f"RETURN DOCUMENT(@id){nested_key_str}" + bind_vars = {"id": id} + result = aql_single(db, query, bind_vars) + return result or {} + + +def aql_doc_get_keys( + db: StandardDatabase, id: str, nested_keys: list[str] = [] +) -> list[str]: """Gets the keys of a document.""" - query = "RETURN ATTRIBUTES(DOCUMENT(@id))" + nested_keys_str = "." + ".".join(nested_keys) if nested_keys else "" + query = f"RETURN ATTRIBUTES(DOCUMENT(@id){nested_keys_str})" bind_vars = {"id": id} result = aql_single(db, query, bind_vars) - return list(result) if result is not None else [] + return list(result or []) -def aql_doc_get_length(db: StandardDatabase, id: str) -> int: +def aql_doc_get_length( + db: StandardDatabase, id: str, nested_keys: list[str] = [] +) -> int: """Gets the length of a document.""" - query = "RETURN LENGTH(DOCUMENT(@id))" + nested_keys_str = "." + ".".join(nested_keys) if nested_keys else "" + query = f"RETURN LENGTH(DOCUMENT(@id){nested_keys_str})" bind_vars = {"id": id} result = aql_single(db, query, bind_vars) - return int(result) if result is not None else 0 + return int(result or 0) def aql_edge_exists( @@ -375,9 +403,10 @@ def aql_fetch_data_edge( def doc_update( db: StandardDatabase, id: str, data: dict[str, Any], **kwargs: Any -) -> None: +) -> str: """Updates a document in the collection.""" - db.update_document({**data, "_id": id}, keep_none=False, silent=True, **kwargs) + res = db.update_document({**data, "_id": id}, keep_none=False, **kwargs) + return str(res["_rev"]) def doc_delete(db: StandardDatabase, id: str, **kwargs: Any) -> None: diff --git a/tests/test.py b/tests/test.py index a589ef74..fb48eb9e 100644 --- a/tests/test.py +++ b/tests/test.py @@ -217,8 +217,6 @@ def test_graph_nodes_crud(load_graph: Any) -> None: len(G_1.nodes) == len(G_2.nodes) + 1 G_1.clear() assert G_1.nodes["person/35"]["foo"] == {"bar": "baz"} - # TODO: Support this use case: - # G_1.nodes["person/35"]["foo"]["bar"] = "baz2" G_1.add_nodes_from(["1", "2", "3"], foo="bar") G_1.clear() @@ -256,6 +254,24 @@ def test_graph_nodes_crud(load_graph: Any) -> None: assert not db.has_document("person/1") assert not db.has_document(edge_id) + G_1.nodes["person/2"]["object"] = {"foo": "bar", "bar": "foo"} + assert db.document("person/2")["object"] == {"foo": "bar", "bar": "foo"} + + G_1.nodes["person/2"]["object"]["foo"] = "baz" + assert db.document("person/2")["object"]["foo"] == "baz" + + del G_1.nodes["person/2"]["object"]["foo"] + assert "foo" not in db.document("person/2")["object"] + + G_1.nodes["person/2"]["object"].update({"sub_object": {"foo": "bar"}}) + assert db.document("person/2")["object"]["sub_object"]["foo"] == "bar" + + G_1.clear() + + assert G_1.nodes["person/2"]["object"]["sub_object"]["foo"] == "bar" + G_1.nodes["person/2"]["object"]["sub_object"]["foo"] = "baz" + assert db.document("person/2")["object"]["sub_object"]["foo"] == "baz" + def test_graph_edges_crud(load_graph: Any) -> None: G_1 = nxadb.Graph(graph_name="KarateGraph") @@ -372,6 +388,25 @@ def test_graph_edges_crud(load_graph: Any) -> None: assert G_1["person/1"]["person/2"]["weight"] == new_weight assert G_1["person/2"]["person/1"]["weight"] == new_weight + edge_id = G_1["person/1"]["person/2"]["_id"] + G_1["person/1"]["person/2"]["object"] = {"foo": "bar", "bar": "foo"} + assert db.document(edge_id)["object"] == {"foo": "bar", "bar": "foo"} + + G_1["person/1"]["person/2"]["object"]["foo"] = "baz" + assert db.document(edge_id)["object"]["foo"] == "baz" + + del G_1["person/1"]["person/2"]["object"]["foo"] + assert "foo" not in db.document(edge_id)["object"] + + G_1["person/1"]["person/2"]["object"].update({"sub_object": {"foo": "bar"}}) + assert db.document(edge_id)["object"]["sub_object"]["foo"] == "bar" + + G_1.clear() + + assert G_1["person/1"]["person/2"]["object"]["sub_object"]["foo"] == "bar" + G_1["person/1"]["person/2"]["object"]["sub_object"]["foo"] = "baz" + assert db.document(edge_id)["object"]["sub_object"]["foo"] == "baz" + def test_readme(load_graph: Any) -> None: G = nxadb.Graph(graph_name="KarateGraph", default_node_type="person") From 8c10c6e2c5b0c70fb837257a21728d94e5bb9eaa Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Thu, 4 Jul 2024 16:09:15 -0400 Subject: [PATCH 2/6] new: recursive `EdgeAttrDict` --- nx_arangodb/classes/dict.py | 111 +++++++++++++++++++++++------------- 1 file changed, 71 insertions(+), 40 deletions(-) diff --git a/nx_arangodb/classes/dict.py b/nx_arangodb/classes/dict.py index aadb3cbb..2ffc276c 100644 --- a/nx_arangodb/classes/dict.py +++ b/nx_arangodb/classes/dict.py @@ -249,6 +249,8 @@ class NodeAttrDict(UserDict[str, Any]): @logger_debug def __init__(self, db: StandardDatabase, graph: Graph, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + self.data: dict[str, Any] = {} self.db = db self.graph = graph @@ -258,9 +260,6 @@ def __init__(self, db: StandardDatabase, graph: Graph, *args: Any, **kwargs: Any # e.g G._node['node/1']['object']['foo'] = 'bar' # In this case, parent_keys would be ['object'] self.parent_keys: list[str] = [] - - super().__init__(*args, **kwargs) - self.data: dict[str, Any] = {} self.node_attr_dict_factory = node_attr_dict_factory(self.db, self.graph) @key_is_string @@ -336,7 +335,6 @@ def __delitem__(self, key: str) -> None: # # TODO: Revisit typing of return value # def values(self) -> Any: # """G._node['node/1'].values()""" - # breakpoint() # self.data = self.db.document(self.node_id) # yield from self.data.values() @@ -380,7 +378,7 @@ def update(self, attrs: Any) -> None: self.data.update(build_node_attr_dict_data(self, attrs)) if not self.node_id: - logger.warning("Node ID not set, skipping NodeAttrDict(?).update()") + logger.debug("Node ID not set, skipping NodeAttrDict(?).update()") return update_dict = get_update_dict(self.parent_keys, attrs) @@ -437,13 +435,13 @@ def __getitem__(self, key: str) -> NodeAttrDict: """G._node['node/1']""" node_id = get_node_id(key, self.default_node_type) - if value := self.data.get(node_id): - return value + if vertex := self.data.get(node_id): + return vertex - if value := self.graph.vertex(node_id): + if vertex := self.graph.vertex(node_id): node_attr_dict: NodeAttrDict = self.node_attr_dict_factory() node_attr_dict.node_id = node_id - node_attr_dict.data = build_node_attr_dict_data(node_attr_dict, value) + node_attr_dict.data = build_node_attr_dict_data(node_attr_dict, vertex) self.data[node_id] = node_attr_dict return node_attr_dict @@ -573,6 +571,36 @@ def __fetch_all(self): ############# +def process_edge_attr_dict_value(parent: EdgeAttrDict, key: str, value: Any) -> Any: + if not isinstance(value, dict): + return value + + edge_attr_dict = parent.edge_attr_dict_factory() + edge_attr_dict.edge_id = parent.edge_id + edge_attr_dict.parent_keys = parent.parent_keys + [key] + edge_attr_dict.data = build_edge_attr_dict_data(edge_attr_dict, value) + + return edge_attr_dict + + +def build_edge_attr_dict_data( + parent: EdgeAttrDict, data: dict[str, Any] +) -> dict[str, Any | EdgeAttrDict]: + """Recursively build an EdgeAttrDict from a dict. + + It's possible that **value** is a nested dict, so we need to + recursively build a EdgeAttrDict for each nested dict. + + Returns the parent EdgeAttrDict. + """ + edge_attr_dict_data = {} + for key, value in data.items(): + edge_attr_dict_value = process_edge_attr_dict_value(parent, key, value) + edge_attr_dict_data[key] = edge_attr_dict_value + + return edge_attr_dict_data + + class EdgeAttrDict(UserDict[str, Any]): """The innermost-level of the dict of dict of dict structure representing the Adjacency List of a graph. @@ -600,6 +628,12 @@ def __init__( self.graph = graph self.edge_id: str | None = None + # NodeAttrDict may be a child of another NodeAttrDict + # e.g G._adj['node/1']['node/2']['object']['foo'] = 'bar' + # In this case, parent_keys would be ['object'] + self.parent_keys: list[str] = [] + self.edge_attr_dict_factory = edge_attr_dict_factory(self.db, self.graph) + @key_is_string @logger_debug def __contains__(self, key: str) -> bool: @@ -623,9 +657,10 @@ def __getitem__(self, key: str) -> Any: if not result: raise KeyError(key) - self.data[key] = result + edge_attr_dict_value = process_edge_attr_dict_value(self, key, result) + self.data[key] = edge_attr_dict_value - return result + return edge_attr_dict_value @key_is_string @key_is_not_reserved @@ -634,8 +669,10 @@ def __getitem__(self, key: str) -> Any: def __setitem__(self, key: str, value: Any) -> None: """G._adj['node/1']['node/2']['foo'] = 'bar'""" assert self.edge_id - self.data[key] = value - self.data["_rev"] = doc_update(self.db, self.edge_id, {key: value}) + edge_attr_dict_value = process_edge_attr_dict_value(self, key, value) + update_dict = get_update_dict(self.parent_keys, {key: value}) + self.data[key] = edge_attr_dict_value + self.data["_rev"] = doc_update(self.db, self.edge_id, update_dict) @key_is_string @key_is_not_reserved @@ -644,7 +681,8 @@ def __delitem__(self, key: str) -> None: """del G._adj['node/1']['node/2']['foo']""" assert self.edge_id self.data.pop(key, None) - self.data["_rev"] = doc_update(self.db, self.edge_id, {key: None}) + update_dict = get_update_dict(self.parent_keys, {key: None}) + self.data["_rev"] = doc_update(self.db, self.edge_id, update_dict) # @logger_debug # def __iter__(self) -> Iterator[str]: @@ -691,14 +729,14 @@ def update(self, attrs: Any) -> None: if not attrs: return - self.data.update(attrs) + self.data.update(build_edge_attr_dict_data(self, attrs)) if not self.edge_id: - logger.warning("Edge ID not set, skipping EdgeAttrDict(?).update()") + logger.debug("Edge ID not set, skipping EdgeAttrDict(?).update()") return - assert self.edge_id - self.data["_rev"] = doc_update(self.db, self.edge_id, attrs) + update_dict = get_update_dict(self.parent_keys, attrs) + self.data["_rev"] = doc_update(self.db, self.edge_id, update_dict) class AdjListInnerDict(UserDict[str, EdgeAttrDict]): @@ -733,13 +771,12 @@ def __init__( self.db = db self.graph = graph - self.default_node_type = default_node_type self.edge_type_func = edge_type_func - self.adjlist_outer_dict = adjlist_outer_dict + self.default_node_type = default_node_type + self.edge_attr_dict_factory = edge_attr_dict_factory(self.db, self.graph) self.src_node_id: str | None = None - - self.edge_attr_dict_factory = edge_attr_dict_factory(self.db, self.graph) + self.adjlist_outer_dict = adjlist_outer_dict self.FETCHED_ALL_DATA = False @@ -786,16 +823,16 @@ def __contains__(self, key: str) -> bool: @logger_debug def __getitem__(self, key: str) -> EdgeAttrDict: """g._adj['node/1']['node/2']""" - assert self.src_node_id dst_node_id = get_node_id(key, self.default_node_type) - if dst_node_id in self.data: - return self.data[dst_node_id] + if edge := self.data.get(dst_node_id): + return edge - if mirrored_edge_attr_dict := self.__get_mirrored_edge_attr_dict(dst_node_id): - self.data[dst_node_id] = mirrored_edge_attr_dict - return mirrored_edge_attr_dict # type: ignore # false positive + if edge := self.__get_mirrored_edge_attr_dict(dst_node_id): + self.data[dst_node_id] = edge + return edge # type: ignore # false positive + assert self.src_node_id edge = aql_edge_get( self.db, self.src_node_id, @@ -809,8 +846,7 @@ def __getitem__(self, key: str) -> EdgeAttrDict: edge_attr_dict = self.edge_attr_dict_factory() edge_attr_dict.edge_id = edge["_id"] - edge_attr_dict.data = edge - + edge_attr_dict.data = build_edge_attr_dict_data(edge_attr_dict, edge) self.data[dst_node_id] = edge_attr_dict return edge_attr_dict @@ -825,8 +861,8 @@ def __setitem__(self, key: str, value: dict[str, Any] | EdgeAttrDict) -> None: src_node_type = self.src_node_id.split("/")[0] dst_node_type, dst_node_id = get_node_type_and_id(key, self.default_node_type) - if mirrored_edge_attr_dict := self.__get_mirrored_edge_attr_dict(dst_node_id): - self.data[dst_node_id] = mirrored_edge_attr_dict + if edge := self.__get_mirrored_edge_attr_dict(dst_node_id): + self.data[dst_node_id] = edge return edge_type = value.data.get("_edge_type") @@ -851,13 +887,8 @@ def __setitem__(self, key: str, value: dict[str, Any] | EdgeAttrDict) -> None: edge_attr_dict = self.edge_attr_dict_factory() edge_attr_dict.edge_id = edge["_id"] - edge_attr_dict.data = { - **edge_data, - **edge, - "_from": self.src_node_id, - "_to": dst_node_id, - } - + edge_data = {**edge_data, **edge, "_from": self.src_node_id, "_to": dst_node_id} + edge_attr_dict.data = build_edge_attr_dict_data(edge_attr_dict, edge_data) self.data[dst_node_id] = edge_attr_dict @key_is_string @@ -1013,8 +1044,8 @@ def __init__( self.db = db self.graph = graph - self.default_node_type = default_node_type self.edge_type_func = edge_type_func + self.default_node_type = default_node_type self.adjlist_inner_dict_factory = adjlist_inner_dict_factory( db, graph, default_node_type, edge_type_func, self ) From e62dbff2c1ee0ed47e6dfe8743ce755836872e6f Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Thu, 4 Jul 2024 16:44:22 -0400 Subject: [PATCH 3/6] fix: `nested_keys` param --- nx_arangodb/classes/function.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/nx_arangodb/classes/function.py b/nx_arangodb/classes/function.py index e94cf97c..f48e338b 100644 --- a/nx_arangodb/classes/function.py +++ b/nx_arangodb/classes/function.py @@ -215,17 +215,23 @@ def aql_single( return result[0] -def aql_doc_has_key(db: StandardDatabase, id: str, key: str) -> bool: +def aql_doc_has_key( + db: StandardDatabase, id: str, key: str, nested_keys: list[str] = [] +) -> bool: """Checks if a document has a key.""" - query = "RETURN HAS(DOCUMENT(@id), @key)" + nested_keys_str = "." + ".".join(nested_keys) if nested_keys else "" + query = f"RETURN HAS(DOCUMENT(@id){nested_keys_str}, @key)" bind_vars = {"id": id, "key": key} result = aql_single(db, query, bind_vars) return bool(result) if result is not None else False -def aql_doc_get_key(db: StandardDatabase, id: str, key: str) -> Any: +def aql_doc_get_key( + db: StandardDatabase, id: str, key: str, nested_keys: list[str] = [] +) -> Any: """Gets a key from a document.""" - query = "RETURN DOCUMENT(@id).@key" + nested_keys_str = "." + ".".join(nested_keys) if nested_keys else "" + query = f"RETURN DOCUMENT(@id){nested_keys_str}.@key" bind_vars = {"id": id, "key": key} return aql_single(db, query, bind_vars) From 0ee6580ba39497cfb27c4bf490396c4e0a4ee57a Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Thu, 4 Jul 2024 16:44:29 -0400 Subject: [PATCH 4/6] update tests --- tests/test.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/test.py b/tests/test.py index fb48eb9e..0c9da34c 100644 --- a/tests/test.py +++ b/tests/test.py @@ -5,6 +5,7 @@ import pytest import nx_arangodb as nxadb +from nx_arangodb.classes.dict import EdgeAttrDict, NodeAttrDict from .conftest import db @@ -255,21 +256,28 @@ def test_graph_nodes_crud(load_graph: Any) -> None: assert not db.has_document(edge_id) G_1.nodes["person/2"]["object"] = {"foo": "bar", "bar": "foo"} + assert "_rev" not in G_1.nodes["person/2"]["object"] + assert isinstance(G_1.nodes["person/2"]["object"], NodeAttrDict) assert db.document("person/2")["object"] == {"foo": "bar", "bar": "foo"} G_1.nodes["person/2"]["object"]["foo"] = "baz" assert db.document("person/2")["object"]["foo"] == "baz" del G_1.nodes["person/2"]["object"]["foo"] + assert "_rev" not in G_1.nodes["person/2"]["object"] + assert isinstance(G_1.nodes["person/2"]["object"], NodeAttrDict) assert "foo" not in db.document("person/2")["object"] G_1.nodes["person/2"]["object"].update({"sub_object": {"foo": "bar"}}) + assert "_rev" not in G_1.nodes["person/2"]["object"]["sub_object"] + assert isinstance(G_1.nodes["person/2"]["object"]["sub_object"], NodeAttrDict) assert db.document("person/2")["object"]["sub_object"]["foo"] == "bar" G_1.clear() assert G_1.nodes["person/2"]["object"]["sub_object"]["foo"] == "bar" G_1.nodes["person/2"]["object"]["sub_object"]["foo"] = "baz" + assert "_rev" not in G_1.nodes["person/2"]["object"]["sub_object"] assert db.document("person/2")["object"]["sub_object"]["foo"] == "baz" @@ -390,21 +398,28 @@ def test_graph_edges_crud(load_graph: Any) -> None: edge_id = G_1["person/1"]["person/2"]["_id"] G_1["person/1"]["person/2"]["object"] = {"foo": "bar", "bar": "foo"} + assert "_rev" not in G_1["person/1"]["person/2"]["object"] + assert isinstance(G_1["person/1"]["person/2"]["object"], EdgeAttrDict) assert db.document(edge_id)["object"] == {"foo": "bar", "bar": "foo"} G_1["person/1"]["person/2"]["object"]["foo"] = "baz" assert db.document(edge_id)["object"]["foo"] == "baz" del G_1["person/1"]["person/2"]["object"]["foo"] + assert "_rev" not in G_1["person/1"]["person/2"]["object"] + assert isinstance(G_1["person/1"]["person/2"]["object"], EdgeAttrDict) assert "foo" not in db.document(edge_id)["object"] G_1["person/1"]["person/2"]["object"].update({"sub_object": {"foo": "bar"}}) + assert "_rev" not in G_1["person/1"]["person/2"]["object"]["sub_object"] + assert isinstance(G_1["person/1"]["person/2"]["object"]["sub_object"], EdgeAttrDict) assert db.document(edge_id)["object"]["sub_object"]["foo"] == "bar" G_1.clear() assert G_1["person/1"]["person/2"]["object"]["sub_object"]["foo"] == "bar" G_1["person/1"]["person/2"]["object"]["sub_object"]["foo"] = "baz" + assert "_rev" not in G_1["person/1"]["person/2"]["object"]["sub_object"] assert db.document(edge_id)["object"]["sub_object"]["foo"] == "baz" From 47bea2a0af6f5d71f4004bb8b786d909ad64ff63 Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Thu, 4 Jul 2024 16:45:39 -0400 Subject: [PATCH 5/6] new: `AttrDict.root` --- nx_arangodb/classes/dict.py | 41 +++++++++++++++++++++++++------------ 1 file changed, 28 insertions(+), 13 deletions(-) diff --git a/nx_arangodb/classes/dict.py b/nx_arangodb/classes/dict.py index 2ffc276c..764ab24f 100644 --- a/nx_arangodb/classes/dict.py +++ b/nx_arangodb/classes/dict.py @@ -202,6 +202,7 @@ def process_node_attr_dict_value(parent: NodeAttrDict, key: str, value: Any) -> return value node_attr_dict = parent.node_attr_dict_factory() + node_attr_dict.root = parent.root or parent node_attr_dict.node_id = parent.node_id node_attr_dict.parent_keys = parent.parent_keys + [key] node_attr_dict.data = build_node_attr_dict_data(node_attr_dict, value) @@ -258,7 +259,9 @@ def __init__(self, db: StandardDatabase, graph: Graph, *args: Any, **kwargs: Any # NodeAttrDict may be a child of another NodeAttrDict # e.g G._node['node/1']['object']['foo'] = 'bar' - # In this case, parent_keys would be ['object'] + # In this case, **parent_keys** would be ['object'] + # and **root** would be G._node['node/1'] + self.root: NodeAttrDict | None = None self.parent_keys: list[str] = [] self.node_attr_dict_factory = node_attr_dict_factory(self.db, self.graph) @@ -270,7 +273,7 @@ def __contains__(self, key: str) -> bool: return True assert self.node_id - return aql_doc_has_key(self.db, self.node_id, key) + return aql_doc_has_key(self.db, self.node_id, key, self.parent_keys) @key_is_string @logger_debug @@ -280,7 +283,7 @@ def __getitem__(self, key: str) -> Any: return value assert self.node_id - result = aql_doc_get_key(self.db, self.node_id, key) + result = aql_doc_get_key(self.db, self.node_id, key, self.parent_keys) if not result: raise KeyError(key) @@ -304,7 +307,8 @@ def __setitem__(self, key: str, value: Any) -> None: node_attr_dict_value = process_node_attr_dict_value(self, key, value) update_dict = get_update_dict(self.parent_keys, {key: value}) self.data[key] = node_attr_dict_value - self.data["_rev"] = doc_update(self.db, self.node_id, update_dict) + root_data = self.root.data if self.root else self.data + root_data["_rev"] = doc_update(self.db, self.node_id, update_dict) @key_is_string @key_is_not_reserved @@ -314,7 +318,8 @@ def __delitem__(self, key: str) -> None: assert self.node_id self.data.pop(key, None) update_dict = get_update_dict(self.parent_keys, {key: None}) - self.data["_rev"] = doc_update(self.db, self.node_id, update_dict) + root_data = self.root.data if self.root else self.data + root_data["_rev"] = doc_update(self.db, self.node_id, update_dict) # @logger_debug # def __iter__(self) -> Iterator[str]: @@ -382,7 +387,8 @@ def update(self, attrs: Any) -> None: return update_dict = get_update_dict(self.parent_keys, attrs) - self.data["_rev"] = doc_update(self.db, self.node_id, update_dict) + root_data = self.root.data if self.root else self.data + root_data["_rev"] = doc_update(self.db, self.node_id, update_dict) class NodeDict(UserDict[str, NodeAttrDict]): @@ -576,6 +582,7 @@ def process_edge_attr_dict_value(parent: EdgeAttrDict, key: str, value: Any) -> return value edge_attr_dict = parent.edge_attr_dict_factory() + edge_attr_dict.root = parent.root or parent edge_attr_dict.edge_id = parent.edge_id edge_attr_dict.parent_keys = parent.parent_keys + [key] edge_attr_dict.data = build_edge_attr_dict_data(edge_attr_dict, value) @@ -591,7 +598,10 @@ def build_edge_attr_dict_data( It's possible that **value** is a nested dict, so we need to recursively build a EdgeAttrDict for each nested dict. - Returns the parent EdgeAttrDict. + :param parent: The parent EdgeAttrDict. + :type parent: EdgeAttrDict + :param data: The data to build the EdgeAttrDict from. + :type data: dict[str, Any] """ edge_attr_dict_data = {} for key, value in data.items(): @@ -630,7 +640,9 @@ def __init__( # NodeAttrDict may be a child of another NodeAttrDict # e.g G._adj['node/1']['node/2']['object']['foo'] = 'bar' - # In this case, parent_keys would be ['object'] + # In this case, **parent_keys** would be ['object'] + # and **root** would be G._adj['node/1']['node/2'] + self.root: EdgeAttrDict | None = None self.parent_keys: list[str] = [] self.edge_attr_dict_factory = edge_attr_dict_factory(self.db, self.graph) @@ -642,7 +654,7 @@ def __contains__(self, key: str) -> bool: return True assert self.edge_id - return aql_doc_has_key(self.db, self.edge_id, key) + return aql_doc_has_key(self.db, self.edge_id, key, self.parent_keys) @key_is_string @logger_debug @@ -652,7 +664,7 @@ def __getitem__(self, key: str) -> Any: return value assert self.edge_id - result = aql_doc_get_key(self.db, self.edge_id, key) + result = aql_doc_get_key(self.db, self.edge_id, key, self.parent_keys) if not result: raise KeyError(key) @@ -672,7 +684,8 @@ def __setitem__(self, key: str, value: Any) -> None: edge_attr_dict_value = process_edge_attr_dict_value(self, key, value) update_dict = get_update_dict(self.parent_keys, {key: value}) self.data[key] = edge_attr_dict_value - self.data["_rev"] = doc_update(self.db, self.edge_id, update_dict) + root_data = self.root.data if self.root else self.data + root_data["_rev"] = doc_update(self.db, self.edge_id, update_dict) @key_is_string @key_is_not_reserved @@ -682,7 +695,8 @@ def __delitem__(self, key: str) -> None: assert self.edge_id self.data.pop(key, None) update_dict = get_update_dict(self.parent_keys, {key: None}) - self.data["_rev"] = doc_update(self.db, self.edge_id, update_dict) + root_data = self.root.data if self.root else self.data + root_data["_rev"] = doc_update(self.db, self.edge_id, update_dict) # @logger_debug # def __iter__(self) -> Iterator[str]: @@ -736,7 +750,8 @@ def update(self, attrs: Any) -> None: return update_dict = get_update_dict(self.parent_keys, attrs) - self.data["_rev"] = doc_update(self.db, self.edge_id, update_dict) + root_data = self.root.data if self.root else self.data + root_data["_rev"] = doc_update(self.db, self.edge_id, update_dict) class AdjListInnerDict(UserDict[str, EdgeAttrDict]): From 5466bab8cb7f90a760fe6c26056afef45bedd054 Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Thu, 4 Jul 2024 17:04:37 -0400 Subject: [PATCH 6/6] fix: `FETCHED_ALL_DATA` --- nx_arangodb/classes/dict.py | 78 +++++++++++++++++++++++-------------- 1 file changed, 49 insertions(+), 29 deletions(-) diff --git a/nx_arangodb/classes/dict.py b/nx_arangodb/classes/dict.py index 764ab24f..76071d91 100644 --- a/nx_arangodb/classes/dict.py +++ b/nx_arangodb/classes/dict.py @@ -183,13 +183,13 @@ def update(self, attrs: Any) -> None: self.data.update(attrs) self.data["_rev"] = doc_update(self.db, self.graph_id, attrs) - @logger_debug - def clear(self) -> None: - """G.graph.clear()""" - self.data.clear() + # @logger_debug + # def clear(self) -> None: + # """G.graph.clear()""" + # self.data.clear() - # if clear_remote: - # doc_insert(self.db, self.COLLECTION_NAME, self.graph_id, silent=True) + # # if clear_remote: + # # doc_insert(self.db, self.COLLECTION_NAME, self.graph_id, silent=True) ######## @@ -363,13 +363,13 @@ def __delitem__(self, key: str) -> None: # def push(): # pass - @logger_debug - def clear(self) -> None: - """G._node['node/1'].clear()""" - self.data.clear() + # @logger_debug + # def clear(self) -> None: + # """G._node['node/1'].clear()""" + # self.data.clear() - # if clear_remote: - # doc_insert(self.db, self.node_id, silent=True, overwrite=True) + # # if clear_remote: + # # doc_insert(self.db, self.node_id, silent=True, overwrite=True) @keys_are_strings @keys_are_not_reserved @@ -424,6 +424,8 @@ def __init__( self.default_node_type = default_node_type self.node_attr_dict_factory = node_attr_dict_factory(self.db, self.graph) + self.FETCHED_ALL_DATA = False + @key_is_string @logger_debug def __contains__(self, key: str) -> bool: @@ -516,13 +518,22 @@ def __len__(self) -> int: @logger_debug def __iter__(self) -> Iterator[str]: """iter(g._node)""" - for collection in self.graph.vertex_collections(): - yield from self.graph.vertex_collection(collection).ids() + if self.FETCHED_ALL_DATA: + yield from self.data.keys() + else: + for collection in self.graph.vertex_collections(): + yield from self.graph.vertex_collection(collection).ids() + + @logger_debug + def keys(self) -> Any: + """g._node.keys()""" + return self.__iter__() @logger_debug def clear(self) -> None: """g._node.clear()""" self.data.clear() + self.FETCHED_ALL_DATA = False # if clear_remote: # for collection in self.graph.vertex_collections(): @@ -534,16 +545,13 @@ def update(self, nodes: Any) -> None: """g._node.update({'node/1': {'foo': 'bar'}, 'node/2': {'baz': 'qux'}})""" raise NotImplementedError("NodeDict.update()") - @logger_debug - def keys(self) -> Any: - """g._node.keys()""" - return self.__iter__() - # TODO: Revisit typing of return value @logger_debug def values(self) -> Any: """g._node.values()""" - self.__fetch_all() + if not self.FETCHED_ALL_DATA: + self.__fetch_all() + yield from self.data.values() # TODO: Revisit typing of return value @@ -551,7 +559,9 @@ def values(self) -> Any: def items(self, data: str | None = None, default: Any | None = None) -> Any: """g._node.items() or G._node.items(data='foo')""" if data is None: - self.__fetch_all() + if not self.FETCHED_ALL_DATA: + self.__fetch_all() + yield from self.data.items() else: v_cols = list(self.graph.vertex_collections()) @@ -571,6 +581,8 @@ def __fetch_all(self): self.data[node_id] = node_attr_dict + self.FETCHED_ALL_DATA = True + ############# # Adjacency # @@ -730,10 +742,10 @@ def __delitem__(self, key: str) -> None: # self.data = self.db.document(self.edge_id) # yield from self.data.items() - @logger_debug - def clear(self) -> None: - """G._adj['node/1']['node/'2].clear()""" - self.data.clear() + # @logger_debug + # def clear(self) -> None: + # """G._adj['node/1']['node/'2].clear()""" + # self.data.clear() @keys_are_strings @keys_are_not_reserved @@ -994,14 +1006,18 @@ def update(self, edges: Any) -> None: @logger_debug def values(self) -> Any: """g._adj['node/1'].values()""" - self.__fetch_all() + if not self.FETCHED_ALL_DATA: + self.__fetch_all() + yield from self.data.values() # TODO: Revisit typing of return value @logger_debug def items(self) -> Any: """g._adj['node/1'].items()""" - self.__fetch_all() + if not self.FETCHED_ALL_DATA: + self.__fetch_all() + yield from self.data.items() @logger_debug @@ -1196,7 +1212,9 @@ def update(self, edges: Any) -> None: @logger_debug def values(self) -> Any: """g._adj.values()""" - self.__fetch_all() + if not self.FETCHED_ALL_DATA: + self.__fetch_all() + yield from self.data.values() # TODO: Revisit typing of return value @@ -1209,7 +1227,9 @@ def items(self, data: str | None = None, default: Any | None = None) -> Any: # ): """g._adj.items() or G._adj.items(data='foo')""" if data is None: - self.__fetch_all() + if not self.FETCHED_ALL_DATA: + self.__fetch_all() + yield from self.data.items() else: