Skip to content

added ability to load edge attrs. #25

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions nx_arangodb/classes/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,7 @@
from .enum import DIRECTED_GRAPH_TYPES, MULTIGRAPH_TYPES, GraphType, TraversalDirection
from .function import (
aql,
aql_as_list,
aql_doc_get_key,
aql_doc_get_keys,
aql_doc_get_length,
aql_doc_has_key,
aql_edge_count_src,
aql_edge_count_src_dst,
Expand All @@ -33,7 +30,6 @@
aql_edge_id,
aql_fetch_data,
aql_fetch_data_edge,
aql_single,
create_collection,
doc_delete,
doc_get_or_insert,
Expand All @@ -45,7 +41,6 @@
get_update_dict,
json_serializable,
key_is_adb_id_or_int,
key_is_int,
key_is_not_reserved,
key_is_string,
keys_are_not_reserved,
Expand Down Expand Up @@ -752,6 +747,7 @@ def _fetch_all(self):
load_node_dict=True,
load_adj_dict=False,
load_coo=False,
edge_collections_attributes=set(),
load_all_vertex_attributes=True,
load_all_edge_attributes=False, # not used
is_directed=False, # not used
Expand Down Expand Up @@ -2254,6 +2250,7 @@ def set_edge_multigraph(
load_node_dict=False,
load_adj_dict=True,
load_coo=False,
edge_collections_attributes=set(),
load_all_vertex_attributes=False, # not used
load_all_edge_attributes=True,
is_directed=self.is_directed,
Expand Down
2 changes: 2 additions & 0 deletions nx_arangodb/classes/digraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(
default_node_type: str | None = None,
edge_type_key: str = "_edge_type",
edge_type_func: Callable[[str, str], str] | None = None,
edge_collections_attributes: set[str] | None = None,
db: StandardDatabase | None = None,
read_parallelism: int = 10,
read_batch_size: int = 100000,
Expand All @@ -41,6 +42,7 @@ def __init__(
default_node_type,
edge_type_key,
edge_type_func,
edge_collections_attributes,
db,
read_parallelism,
read_batch_size,
Expand Down
30 changes: 28 additions & 2 deletions nx_arangodb/classes/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
DiGraphAdjDict,
DstIndices,
EdgeIndices,
EdgeValuesDict,
GraphAdjDict,
MultiDiGraphAdjDict,
MultiGraphAdjDict,
Expand All @@ -38,11 +39,19 @@
)


def do_load_all_edge_attributes(attributes: set[str]) -> bool:
if len(attributes) == 0:
return True

return False


def get_arangodb_graph(
adb_graph: Graph,
load_node_dict: bool,
load_adj_dict: bool,
load_coo: bool,
edge_collections_attributes: set[str],
load_all_vertex_attributes: bool,
load_all_edge_attributes: bool,
is_directed: bool,
Expand All @@ -55,6 +64,7 @@ def get_arangodb_graph(
DstIndices,
EdgeIndices,
ArangoIDtoIndex,
EdgeValuesDict,
]:
"""Pulls the graph from the database, assuming the graph exists.

Expand All @@ -71,7 +81,7 @@ def get_arangodb_graph(

metagraph: dict[str, dict[str, Any]] = {
"vertexCollections": {col: set() for col in v_cols},
"edgeCollections": {col: set() for col in e_cols},
"edgeCollections": {col: edge_collections_attributes for col in e_cols},
}

if not any((load_node_dict, load_adj_dict, load_coo)):
Expand All @@ -89,6 +99,21 @@ def get_arangodb_graph(
assert config.username
assert config.password

res_do_load_all_edge_attributes = do_load_all_edge_attributes(
edge_collections_attributes
)

if res_do_load_all_edge_attributes is not load_all_edge_attributes:
if len(edge_collections_attributes) > 0:
raise ValueError(
"You have specified to load at least one specific edge attribute"
" and at the same time set the parameter `load_all_vertex_attributes`"
" to true. This combination is not allowed."
)
else:
# We need this case as the user wants by purpose to not load any edge data
res_do_load_all_edge_attributes = load_all_edge_attributes

(
node_dict,
adj_dict,
Expand All @@ -106,7 +131,7 @@ def get_arangodb_graph(
load_adj_dict=load_adj_dict,
load_coo=load_coo,
load_all_vertex_attributes=load_all_vertex_attributes,
load_all_edge_attributes=load_all_edge_attributes,
load_all_edge_attributes=res_do_load_all_edge_attributes,
is_directed=is_directed,
is_multigraph=is_multigraph,
symmetrize_edges_if_directed=symmetrize_edges_if_directed,
Expand All @@ -121,6 +146,7 @@ def get_arangodb_graph(
dst_indices,
edge_indices,
vertex_ids_to_index,
edge_values,
)


Expand Down
19 changes: 19 additions & 0 deletions nx_arangodb/classes/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(
default_node_type: str | None = None,
edge_type_key: str = "_edge_type",
edge_type_func: Callable[[str, str], str] | None = None,
edge_collections_attributes: set[str] | None = None,
db: StandardDatabase | None = None,
read_parallelism: int = 10,
read_batch_size: int = 100000,
Expand All @@ -69,6 +70,8 @@ def __init__(
self.read_batch_size = read_batch_size
self.write_batch_size = write_batch_size

self._set_edge_collections_attributes_to_fetch(edge_collections_attributes)

# NOTE: Need to revisit these...
# self.maintain_node_dict_cache = False
# self.maintain_adj_dict_cache = False
Expand All @@ -80,6 +83,7 @@ def __init__(
self.dst_indices: npt.NDArray[np.int64] | None = None
self.edge_indices: npt.NDArray[np.int64] | None = None
self.vertex_ids_to_index: dict[str, int] | None = None
self.edge_values: dict[str, list[int | float]] | None = None

# Does not apply to undirected graphs
self.symmetrize_edges = symmetrize_edges
Expand Down Expand Up @@ -236,6 +240,17 @@ def _set_factory_methods(self) -> None:
*adj_args, self.symmetrize_edges
)

def _set_edge_collections_attributes_to_fetch(
self, attributes: set[str] | None
) -> None:
if attributes is None:
self._edge_collections_attributes = set()
return
if len(attributes) > 0:
self._edge_collections_attributes = attributes
if "_id" not in attributes:
self._edge_collections_attributes.add("_id")

###########
# Getters #
###########
Expand All @@ -258,6 +273,10 @@ def graph_name(self) -> str:
def graph_exists_in_db(self) -> bool:
return self._graph_exists_in_db

@property
def get_edge_attributes(self) -> set[str]:
return self._edge_collections_attributes

###########
# Setters #
###########
Expand Down
2 changes: 2 additions & 0 deletions nx_arangodb/classes/multidigraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(
default_node_type: str | None = None,
edge_type_key: str = "_edge_type",
edge_type_func: Callable[[str, str], str] | None = None,
edge_collections_attributes: set[str] | None = None,
db: StandardDatabase | None = None,
read_parallelism: int = 10,
read_batch_size: int = 100000,
Expand All @@ -40,6 +41,7 @@ def __init__(
default_node_type,
edge_type_key,
edge_type_func,
edge_collections_attributes,
db,
read_parallelism,
read_batch_size,
Expand Down
2 changes: 2 additions & 0 deletions nx_arangodb/classes/multigraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(
default_node_type: str | None = None,
edge_type_key: str = "_edge_type",
edge_type_func: Callable[[str, str], str] | None = None,
edge_collections_attributes: set[str] | None = None,
db: StandardDatabase | None = None,
read_parallelism: int = 10,
read_batch_size: int = 100000,
Expand All @@ -40,6 +41,7 @@ def __init__(
default_node_type,
edge_type_key,
edge_type_func,
edge_collections_attributes,
db,
read_parallelism,
read_batch_size,
Expand Down
48 changes: 30 additions & 18 deletions nx_arangodb/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import networkx as nx

import nx_arangodb as nxadb
from nx_arangodb.classes.function import do_load_all_edge_attributes
from nx_arangodb.logger import logger

try:
Expand Down Expand Up @@ -126,9 +127,9 @@ def nxadb_to_nx(G: nxadb.Graph) -> nx.Graph:
load_node_dict=True,
load_adj_dict=True,
load_coo=False,
edge_collections_attributes=G.get_edge_attributes,
load_all_vertex_attributes=False,
# TODO: Only return the edge attributes that are needed
load_all_edge_attributes=True,
load_all_edge_attributes=do_load_all_edge_attributes(G.get_edge_attributes),
is_directed=G.is_directed(),
is_multigraph=G.is_multigraph(),
symmetrize_edges_if_directed=G.symmetrize_edges if G.is_directed() else False,
Expand Down Expand Up @@ -158,27 +159,37 @@ def nxadb_to_nxcg(G: nxadb.Graph, as_directed: bool = False) -> nxcg.Graph:
and G.dst_indices is not None
and G.edge_indices is not None
and G.vertex_ids_to_index is not None
and G.edge_values is not None
):
m = "**use_coo_cache** is enabled. using cached COO data. no pull required."
logger.debug(m)

else:
start_time = time.time()

_, _, src_indices, dst_indices, edge_indices, vertex_ids_to_index = (
nxadb.classes.function.get_arangodb_graph(
adb_graph=G.adb_graph,
load_node_dict=False,
load_adj_dict=False,
load_coo=True,
load_all_vertex_attributes=False, # not used
load_all_edge_attributes=False, # not used
is_directed=G.is_directed(),
is_multigraph=G.is_multigraph(),
symmetrize_edges_if_directed=(
G.symmetrize_edges if G.is_directed() else False
),
)
(
_,
_,
src_indices,
dst_indices,
edge_indices,
vertex_ids_to_index,
edge_values,
) = nxadb.classes.function.get_arangodb_graph(
adb_graph=G.adb_graph,
load_node_dict=False,
load_adj_dict=False,
load_coo=True,
edge_collections_attributes=G.get_edge_attributes,
load_all_vertex_attributes=False, # not used
load_all_edge_attributes=do_load_all_edge_attributes(
G.get_edge_attributes
),
is_directed=G.is_directed(),
is_multigraph=G.is_multigraph(),
symmetrize_edges_if_directed=(
G.symmetrize_edges if G.is_directed() else False
),
)

print(f"ADB -> COO load took {time.time() - start_time}s")
Expand All @@ -187,6 +198,7 @@ def nxadb_to_nxcg(G: nxadb.Graph, as_directed: bool = False) -> nxcg.Graph:
G.dst_indices = dst_indices
G.edge_indices = edge_indices
G.vertex_ids_to_index = vertex_ids_to_index
G.edge_values = edge_values

N = len(G.vertex_ids_to_index) # type: ignore
src_indices_cp = cp.array(G.src_indices)
Expand All @@ -204,7 +216,7 @@ def nxadb_to_nxcg(G: nxadb.Graph, as_directed: bool = False) -> nxcg.Graph:
src_indices=src_indices_cp,
dst_indices=dst_indices_cp,
edge_indices=edge_indices_cp,
# edge_values,
edge_values=G.edge_values,
# edge_masks,
# node_values,
# node_masks,
Expand All @@ -222,7 +234,7 @@ def nxadb_to_nxcg(G: nxadb.Graph, as_directed: bool = False) -> nxcg.Graph:
N=N,
src_indices=src_indices_cp,
dst_indices=dst_indices_cp,
# edge_values,
edge_values=G.edge_values,
# edge_masks,
# node_values,
# node_masks,
Expand Down
Loading