Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 154 additions & 25 deletions nx_arangodb/classes/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,30 @@

from __future__ import annotations

from collections import UserDict, defaultdict
import warnings
from collections import UserDict
from collections.abc import Iterator
from typing import Any, Callable, Generator
from typing import Any, Callable, Dict, List

from arango.database import StandardDatabase
from arango.exceptions import DocumentInsertError
from arango.graph import Graph

from nx_arangodb.logger import logger

from ..typing import AdjDict
from ..utils.arangodb import (
ArangoDBBatchError,
check_list_for_errors,
is_arangodb_id,
read_collection_name_from_local_id,
separate_edges_by_collections,
separate_nodes_by_collections,
upsert_collection_documents,
upsert_collection_edges,
)
from .function import (
aql,
aql_as_list,
aql_doc_get_key,
aql_doc_get_keys,
aql_doc_get_length,
aql_doc_has_key,
aql_edge_exists,
aql_edge_get,
Expand Down Expand Up @@ -496,11 +504,45 @@ def clear(self) -> None:
# for collection in self.graph.vertex_collections():
# self.graph.vertex_collection(collection).truncate()

@keys_are_strings
@logger_debug
def update_local_nodes(self, nodes: Any) -> None:
for node_id, node_data in nodes.items():
node_attr_dict = self.node_attr_dict_factory()
node_attr_dict.node_id = node_id
node_attr_dict.data = build_node_attr_dict_data(node_attr_dict, node_data)

self.data[node_id] = node_attr_dict

@keys_are_strings
@logger_debug
def update(self, nodes: Any) -> None:
"""g._node.update({'node/1': {'foo': 'bar'}, 'node/2': {'baz': 'qux'}})"""
raise NotImplementedError("NodeDict.update()")
separated_by_collection = separate_nodes_by_collections(
nodes, self.default_node_type
)

result = upsert_collection_documents(self.db, separated_by_collection)

all_good = check_list_for_errors(result)
if all_good:
# Means no single operation failed, in this case we update the local cache
self.update_local_nodes(nodes)
else:
# In this case some or all documents failed. Right now we will not
# update the local cache, but raise an error instead.
# Reason: We cannot set silent to True, because we need as it does
# not report errors then. We need to update the driver to also pass
# the errors back to the user, then we can adjust the behavior here.
# This will also save network traffic and local computation time.
errors = []
for collections_results in result:
for collection_result in collections_results:
errors.append(collection_result)
warnings.warn(
"Failed to insert at least one node. Will not update local cache."
)
raise ArangoDBBatchError(errors)

# TODO: Revisit typing of return value
@logger_debug
Expand Down Expand Up @@ -614,7 +656,7 @@ def __init__(
self.graph = graph
self.edge_id: str | None = None

# NodeAttrDict may be a child of another NodeAttrDict
# EdgeAttrDict may be a child of another EdgeAttrDict
# e.g G._adj['node/1']['node/2']['object']['foo'] = 'bar'
# In this case, **parent_keys** would be ['object']
# and **root** would be G._adj['node/1']['node/2']
Expand Down Expand Up @@ -933,7 +975,62 @@ def clear(self) -> None:
@logger_debug
def update(self, edges: Any) -> None:
"""g._adj['node/1'].update({'node/2': {'foo': 'bar'}})"""
raise NotImplementedError("AdjListInnerDict.update()")
from_col_name = read_collection_name_from_local_id(
self.src_node_id, self.default_node_type
)

to_upsert: Dict[str, List[Dict[str, Any]]] = {from_col_name: []}

for edge_id, edge_data in edges.items():
edge_doc = edge_data
edge_doc["_from"] = self.src_node_id
edge_doc["_to"] = edge_id

edge_doc_id = edge_data.get("_id")
# TODO: @Anthony please check if the implementation is correct
# of the default for edge_type_func, which is right now:
# * edge_type_func: Callable[[str, str], str] = lambda u, v: f"{u}_to_{v}",
#
# How does that help to identify the edge's collection name?
# The below implementation I wanted to use but returns in my example:
# "person/9_to_person/34" which is not a valid or requested collection name.
#
# edge_type = edge_data.get("_edge_type")
# if edge_type is None:
# edge_type = self.edge_type_func(self.src_node_id, edge_id)
#
# -> Therefore right now I need to assume that this is always a
# valid ArangoDB document ID
assert is_arangodb_id(edge_doc_id)
edge_col_name = read_collection_name_from_local_id(edge_doc_id, "")

if to_upsert.get(edge_col_name) is None:
to_upsert[edge_col_name] = [edge_doc]
else:
to_upsert[edge_col_name].append(edge_doc)

# perform write to ArangoDB
result = upsert_collection_edges(self.db, to_upsert)

all_good = check_list_for_errors(result)
if all_good:
# Means no single operation failed, in this case we update the local cache
self.__set_adj_elements(edges)
else:
# In this case some or all documents failed. Right now we will not
# update the local cache, but raise an error instead.
# Reason: We cannot set silent to True, because we need as it does
# not report errors then. We need to update the driver to also pass
# the errors back to the user, then we can adjust the behavior here.
# This will also save network traffic and local computation time.
errors = []
for collections_results in result:
for collection_result in collections_results:
errors.append(collection_result)
warnings.warn(
"Failed to insert at least one node. Will not update local cache."
)
raise ArangoDBBatchError(errors)

# TODO: Revisit typing of return value
@logger_debug
Expand Down Expand Up @@ -974,6 +1071,14 @@ def __fetch_all(self) -> None:

self.FETCHED_ALL_DATA = True

def __set_adj_elements(self, edges):
for dst_node_id, edge in edges.items():
# Copied from above, from __fetch_all
edge_attr_dict = self.edge_attr_dict_factory()
edge_attr_dict.edge_id = edge["_id"]
edge_attr_dict.data = build_edge_attr_dict_data(edge_attr_dict, edge)
self.data[edge["_to"]] = edge_attr_dict


class AdjListOuterDict(UserDict[str, AdjListInnerDict]):
"""The outer-level of the dict of dict of dict structure
Expand Down Expand Up @@ -1138,8 +1243,29 @@ def clear(self) -> None:
@keys_are_strings
@logger_debug
def update(self, edges: Any) -> None:
"""g._adj.update({'node/1': {'node/2': {'foo': 'bar'}})"""
raise NotImplementedError("AdjListOuterDict.update()")
"""g._adj.update({'node/1': {'node/2': {'_id': 'foo/bar', 'foo': "bar"}})"""
separated_by_edge_collection = separate_edges_by_collections(edges)
result = upsert_collection_edges(self.db, separated_by_edge_collection)

all_good = check_list_for_errors(result)
if all_good:
# Means no single operation failed, in this case we update the local cache
self.__set_adj_elements(edges)
else:
# In this case some or all documents failed. Right now we will not
# update the local cache, but raise an error instead.
# Reason: We cannot set silent to True, because we need as it does
# not report errors then. We need to update the driver to also pass
# the errors back to the user, then we can adjust the behavior here.
# This will also save network traffic and local computation time.
errors = []
for collections_results in result:
for collection_result in collections_results:
errors.append(collection_result)
warnings.warn(
"Failed to insert at least one node. Will not update local cache."
)
raise ArangoDBBatchError(errors)

# TODO: Revisit typing of return value
@logger_debug
Expand Down Expand Up @@ -1171,25 +1297,15 @@ def items(self, data: str | None = None, default: Any | None = None) -> Any:
yield from result

@logger_debug
def __fetch_all(self) -> None:
self.clear()

_, adj_dict, _, _, _ = get_arangodb_graph(
self.graph,
load_node_dict=False,
load_adj_dict=True,
is_directed=False, # TODO: Abstract based on Graph type
is_multigraph=False, # TODO: Abstract based on Graph type
load_coo=False,
)

for src_node_id, inner_dict in adj_dict.items():
def __set_adj_elements(self, edges: AdjDict) -> None:
for src_node_id, inner_dict in edges.items():
for dst_node_id, edge in inner_dict.items():

if src_node_id in self.data:
if dst_node_id in self.data[src_node_id].data:
continue

# TODO: Clean up those two if/else statements later
if src_node_id in self.data:
src_inner_dict = self.data[src_node_id]
else:
Expand All @@ -1209,8 +1325,21 @@ def __fetch_all(self) -> None:
edge_attr_dict = src_inner_dict.edge_attr_dict_factory()
edge_attr_dict.edge_id = edge["_id"]
edge_attr_dict.data = build_edge_attr_dict_data(edge_attr_dict, edge)

self.data[src_node_id].data[dst_node_id] = edge_attr_dict
self.data[dst_node_id].data[src_node_id] = edge_attr_dict

@logger_debug
def __fetch_all(self) -> None:
self.clear()

_, adj_dict, _, _, _ = get_arangodb_graph(
self.graph,
load_node_dict=False,
load_adj_dict=True,
is_directed=False, # TODO: Abstract based on Graph type
is_multigraph=False, # TODO: Abstract based on Graph type
load_coo=False,
)

self.__set_adj_elements(adj_dict)
self.FETCHED_ALL_DATA = True
5 changes: 3 additions & 2 deletions nx_arangodb/classes/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
GraphDoesNotExist,
InvalidTraversalDirection,
)
from ..typing import AdjDict


def get_arangodb_graph(
Expand All @@ -35,7 +36,7 @@ def get_arangodb_graph(
load_coo: bool,
) -> Tuple[
dict[str, dict[str, Any]],
dict[str, dict[str, dict[str, Any]]],
AdjDict,
npt.NDArray[np.int64],
npt.NDArray[np.int64],
dict[str, int],
Expand Down Expand Up @@ -152,7 +153,7 @@ def wrapper(
return wrapper


RESERVED_KEYS = {"_id", "_key", "_rev"}
RESERVED_KEYS = {"_id", "_key", "_rev", "_from", "_to"}


def key_is_not_reserved(func: Callable[..., Any]) -> Any:
Expand Down
39 changes: 36 additions & 3 deletions nx_arangodb/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,19 @@
from __future__ import annotations

from collections.abc import Hashable
from typing import TypeVar
from typing import Any, Dict, TypeVar

import cupy as cp
import numpy as np
import numpy.typing as npt

from nx_arangodb.logger import logger

try:
import cupy as cp
except ModuleNotFoundError as e:
GPU_ENABLED = False
logger.info(f"NXCG is disabled. {e}.")


AttrKey = TypeVar("AttrKey", bound=Hashable)
EdgeKey = TypeVar("EdgeKey", bound=Hashable)
NodeKey = TypeVar("NodeKey", bound=Hashable)
Expand All @@ -18,6 +25,32 @@
IndexValue = TypeVar("IndexValue")
Dtype = TypeVar("Dtype")

# AdjDict is a dictionary of dictionaries of dictionaries
# The outer dict is holding _from_id(s) as keys
# - It may or may not hold valid ArangoDB document _id(s)
# The inner dict is holding _to_id(s) as keys
# - It may or may not hold valid ArangoDB document _id(s)
# The next inner dict contains then the actual edges data (key, val)
# Example
# {
# 'person/1': {
# 'person/32': {
# '_id': 'knows/16',
# 'extraValue': '16'
# },
# 'person/33': {
# '_id': 'knows/17',
# 'extraValue': '17'
# }
# ...
# }
# ...
# }
# The above example is a graph with 2 edges from person/1 to person/32 and person/33
AdjDictEdge = Dict[str, Any]
AdjDictInner = Dict[str, AdjDictEdge]
AdjDict = Dict[str, AdjDictInner]


class any_ndarray:
def __class_getitem__(cls, item):
Expand Down
Loading