Skip to content
203 changes: 185 additions & 18 deletions nx_arangodb/classes/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from __future__ import annotations

import json
from collections import UserDict, defaultdict
from collections.abc import Iterator
from typing import Any, Callable, Generator
Expand Down Expand Up @@ -36,6 +37,7 @@
get_arangodb_graph,
get_node_id,
get_node_type_and_id,
json_serializable,
key_is_not_reserved,
key_is_string,
keys_are_not_reserved,
Expand All @@ -54,6 +56,12 @@ def graph_dict_factory(
return lambda: GraphDict(db, graph_name)


def graph_attr_dict_factory(
db: StandardDatabase, graph: Graph, graph_id: str
) -> Callable[..., GraphAttrDict]:
return lambda: GraphAttrDict(db, graph, graph_id)


def node_dict_factory(
db: StandardDatabase, graph: Graph, default_node_type: str
) -> Callable[..., NodeDict]:
Expand Down Expand Up @@ -98,6 +106,36 @@ def edge_attr_dict_factory(
#########


def build_graph_attr_dict_data(
parent: GraphAttrDict, data: dict[str, Any]
) -> dict[str, Any | GraphAttrDict]:
"""Recursively build a GraphAttrDict from a dict.

It's possible that **value** is a nested dict, so we need to
recursively build a GraphAttrDict for each nested dict.

Returns the parent GraphAttrDict.
"""
graph_attr_dict_data = {}
for key, value in data.items():
graph_attr_dict_value = process_graph_attr_dict_value(parent, key, value)
graph_attr_dict_data[key] = graph_attr_dict_value

return graph_attr_dict_data


def process_graph_attr_dict_value(parent: GraphAttrDict, key: str, value: Any) -> Any:
if not isinstance(value, dict):
return value

graph_attr_dict = parent.graph_attr_dict_factory()
graph_attr_dict.root = parent.root or parent
graph_attr_dict.parent_keys = parent.parent_keys + [key]
graph_attr_dict.data = build_graph_attr_dict_data(graph_attr_dict, value)

return graph_attr_dict


class GraphDict(UserDict[str, Any]):
"""A dictionary-like object for storing graph attributes.

Expand All @@ -110,8 +148,6 @@ class GraphDict(UserDict[str, Any]):
:type graph_name: str
"""

COLLECTION_NAME = "nxadb_graphs"

@logger_debug
def __init__(
self, db: StandardDatabase, graph_name: str, *args: Any, **kwargs: Any
Expand All @@ -121,13 +157,28 @@ def __init__(

self.db = db
self.graph_name = graph_name
self.COLLECTION_NAME = "nxadb_graphs"
self.graph_id = f"{self.COLLECTION_NAME}/{graph_name}"

self.adb_graph = db.graph(graph_name)
self.collection = create_collection(db, self.COLLECTION_NAME)
self.graph_attr_dict_factory = graph_attr_dict_factory(
self.db, self.adb_graph, self.graph_id
)

result = doc_get_or_insert(self.db, self.COLLECTION_NAME, self.graph_id)
for k, v in result.items():
self.data[k] = self.__process_graph_dict_value(k, v)

def __process_graph_dict_value(self, key: str, value: Any) -> Any:
if not isinstance(value, dict):
return value

data = doc_get_or_insert(self.db, self.COLLECTION_NAME, self.graph_id)
self.data.update(data)
graph_attr_dict = self.graph_attr_dict_factory()
graph_attr_dict.parent_keys = [key]
graph_attr_dict.data = build_graph_attr_dict_data(graph_attr_dict, value)

return graph_attr_dict

@key_is_string
@logger_debug
Expand All @@ -148,20 +199,25 @@ def __getitem__(self, key: str) -> Any:

result = aql_doc_get_key(self.db, self.graph_id, key)

if not result:
if result is None:
raise KeyError(key)

self.data[key] = result
graph_dict_value = self.__process_graph_dict_value(key, result)
self.data[key] = graph_dict_value

return result
return graph_dict_value

@key_is_string
@key_is_not_reserved
@logger_debug
# @value_is_json_serializable # TODO?
def __setitem__(self, key: str, value: Any) -> None:
"""G.graph['foo'] = 'bar'"""
self.data[key] = value
if value is None:
self.__delitem__(key)
return

graph_dict_value = self.__process_graph_dict_value(key, value)
self.data[key] = graph_dict_value
self.data["_rev"] = doc_update(self.db, self.graph_id, {key: value})

@key_is_string
Expand All @@ -172,25 +228,128 @@ def __delitem__(self, key: str) -> None:
self.data.pop(key, None)
self.data["_rev"] = doc_update(self.db, self.graph_id, {key: None})

@keys_are_strings
@keys_are_not_reserved
# @values_are_json_serializable # TODO?
@logger_debug
def update(self, attrs: Any) -> None:
"""G.graph.update({'foo': 'bar'})"""

if not attrs:
return

self.data.update(attrs)
graph_attr_dict = self.graph_attr_dict_factory()
graph_attr_dict_data = build_graph_attr_dict_data(graph_attr_dict, attrs)
graph_attr_dict.data = graph_attr_dict_data

self.data.update(graph_attr_dict_data)
self.data["_rev"] = doc_update(self.db, self.graph_id, attrs)

# @logger_debug
# def clear(self) -> None:
# """G.graph.clear()"""
# self.data.clear()
@logger_debug
def clear(self) -> None:
"""G.graph.clear()"""
self.data.clear()


@json_serializable
class GraphAttrDict(UserDict[str, Any]):
"""The inner-level of the dict of dict structure
representing the attributes of a graph stored in the database.

Only used if the value associated with a GraphDict key is a dict.

:param db: The ArangoDB database.
:type db: StandardDatabase
:param graph: The ArangoDB graph.
:type graph: Graph
:param graph_id: The ArangoDB graph ID.
:type graph_id: str
"""

@logger_debug
def __init__(
self,
db: StandardDatabase,
graph: Graph,
graph_id: str,
*args: Any,
**kwargs: Any,
):
super().__init__(*args, **kwargs)
self.data: dict[str, Any] = {}

self.db = db
self.graph = graph
self.graph_id: str = graph_id

self.root: GraphAttrDict | None = None
self.parent_keys: list[str] = []
self.graph_attr_dict_factory = graph_attr_dict_factory(
self.db, self.graph, self.graph_id
)

# # if clear_remote:
# # doc_insert(self.db, self.COLLECTION_NAME, self.graph_id, silent=True)
@key_is_string
@logger_debug
def __contains__(self, key: str) -> bool:
"""'bar' in G.graph['foo']"""
if key in self.data:
return True

return aql_doc_has_key(self.db, self.graph.name, key)

@key_is_string
@logger_debug
def __getitem__(self, key: str) -> Any:
"""G.graph['foo']['bar']"""

if value := self.data.get(key):
return value

result = aql_doc_get_key(self.db, self.graph_id, key, self.parent_keys)

if result is None:
raise KeyError(key)

graph_attr_dict_value = process_graph_attr_dict_value(self, key, result)
self.data[key] = graph_attr_dict_value

return graph_attr_dict_value

@key_is_string
@logger_debug
def __setitem__(self, key, value):
"""
G.graph['foo'] = 'bar'
G.graph['object'] = {'foo': 'bar'}
G._node['object']['foo'] = 'baz'
"""
if value is None:
self.__delitem__(key)
return

graph_attr_dict_value = process_graph_attr_dict_value(self, key, value)
update_dict = get_update_dict(self.parent_keys, {key: value})
self.data[key] = graph_attr_dict_value
root_data = self.root.data if self.root else self.data
root_data["_rev"] = doc_update(self.db, self.graph_id, update_dict)

@key_is_string
@logger_debug
def __delitem__(self, key):
"""del G.graph['foo']['bar']"""
self.data.pop(key, None)
update_dict = get_update_dict(self.parent_keys, {key: None})
root_data = self.root.data if self.root else self.data
root_data["_rev"] = doc_update(self.db, self.graph_id, update_dict)

@logger_debug
def update(self, attrs: Any) -> None:
"""G.graph['foo'].update({'bar': 'baz'})"""
if not attrs:
return

self.data.update(build_graph_attr_dict_data(self, attrs))
updated_dict = get_update_dict(self.parent_keys, attrs)
root_data = self.root.data if self.root else self.data
root_data["_rev"] = doc_update(self.db, self.graph_id, updated_dict)


########
Expand Down Expand Up @@ -304,6 +463,10 @@ def __setitem__(self, key: str, value: Any) -> None:
G._node['node/1']['object'] = {'foo': 'bar'}
G._node['node/1']['object']['foo'] = 'baz'
"""
if value is None:
self.__delitem__(key)
return

assert self.node_id
node_attr_dict_value = process_node_attr_dict_value(self, key, value)
update_dict = get_update_dict(self.parent_keys, {key: value})
Expand Down Expand Up @@ -656,6 +819,10 @@ def __getitem__(self, key: str) -> Any:
@logger_debug
def __setitem__(self, key: str, value: Any) -> None:
"""G._adj['node/1']['node/2']['foo'] = 'bar'"""
if value is None:
self.__delitem__(key)
return

assert self.edge_id
edge_attr_dict_value = process_edge_attr_dict_value(self, key, value)
update_dict = get_update_dict(self.parent_keys, {key: value})
Expand Down
11 changes: 11 additions & 0 deletions nx_arangodb/classes/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,17 @@ def get_arangodb_graph(
)


def json_serializable(cls):
def to_dict(self):
return {
key: (value.to_dict() if isinstance(value, cls) else value)
for key, value in self.items()
}

cls.to_dict = to_dict
return cls


def key_is_string(func: Callable[..., Any]) -> Any:
"""Decorator to check if the key is a string."""

Expand Down
Loading