
Add batch_size parameter in import_bulk method #207


Merged · 17 commits · Jun 24, 2022
Changes from 10 commits
arango/collection.py (19 additions, 11 deletions)

@@ -42,7 +42,7 @@
from arango.response import Response
from arango.result import Result
from arango.typings import Fields, Headers, Json, Params
from arango.utils import get_doc_id, is_none_or_int, is_none_or_str
from arango.utils import get_batches, get_doc_id, is_none_or_int, is_none_or_str


class Collection(ApiGroup):
@@ -1934,7 +1934,8 @@ def import_bulk(
overwrite: Optional[bool] = None,
on_duplicate: Optional[str] = None,
sync: Optional[bool] = None,
) -> Result[Json]:
batch_size: Optional[int] = None,
) -> Union[Result[Json], List[Result[Json]]]:
"""Insert multiple documents into the collection.

.. note::
@@ -1984,6 +1985,9 @@
:type on_duplicate: str
:param sync: Block until operation is synchronized to disk.
:type sync: bool | None
:param batch_size: Max number of documents to import at once. If
unspecified, will import all documents at once.
:type batch_size: int | None
        :return: Result of the bulk import. If **batch_size** is specified and
            more than one batch is needed, a list of results (one per batch)
            is returned instead.
        :rtype: dict | list
:raise arango.exceptions.DocumentInsertError: If import fails.
@@ -2006,21 +2010,25 @@
if sync is not None:
params["waitForSync"] = sync

request = Request(
method="post",
endpoint="/_api/import",
data=documents,
params=params,
write=self.name,
)

def response_handler(resp: Response) -> Json:
if resp.is_success:
result: Json = resp.body
return result
raise DocumentInsertError(resp, request)

return self._execute(request, response_handler)
        # Send one import request per batch (a single batch if batch_size
        # is unspecified or invalid).
        results = []
        for batch in get_batches(documents, batch_size):
            request = Request(
                method="post",
                endpoint="/_api/import",
                data=batch,
                params=params,
                write=self.name,
            )
            results.append(self._execute(request, response_handler))

        return results[0] if len(results) == 1 else results


class StandardCollection(Collection):
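A quick sketch of the new calling convention, for review purposes. The connection details, collection name, and documents below are made up for illustration; the assertions assume 250 documents split into batches of 100.

.. code-block:: python

    from arango import ArangoClient

    client = ArangoClient()
    db = client.db("test", username="root", password="passwd")
    col = db.collection("students")

    docs = [{"_key": str(i), "value": i} for i in range(250)]

    # With batch_size, documents are imported 100 at a time and a list of
    # per-batch results is returned (250 docs -> 3 batches -> 3 results).
    results = col.import_bulk(docs, batch_size=100)
    assert isinstance(results, list) and len(results) == 3

    # Without batch_size, or when a single batch suffices, the behavior is
    # unchanged and a single result dict is returned.
    result = col.import_bulk(docs, batch_size=500, on_duplicate="replace")
    assert isinstance(result, dict)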
arango/database.py (44 additions, 4 deletions)

@@ -1170,6 +1170,7 @@ def create_graph(
shard_count: Optional[int] = None,
replication_factor: Optional[int] = None,
write_concern: Optional[int] = None,
collections: Optional[Json] = None,
) -> Result[Graph]:
"""Create a new graph.

@@ -1217,18 +1218,49 @@
parameter cannot be larger than that of **replication_factor**.
Default value is 1. Used for clusters only.
:type write_concern: int
        :param collections: A mapping of collection names to the documents
            (and optional import options) used to provision the graph. See
            below for an example.
        :type collections: dict | None
:return: Graph API wrapper.
:rtype: arango.graph.Graph
:raise arango.exceptions.GraphCreateError: If create fails.

Here is an example entry for parameter **edge_definitions**:

.. code-block:: python

[
{
'edge_collection': 'teach',
'from_vertex_collections': ['teachers'],
'to_vertex_collections': ['lectures']
}
]

        Here is an example entry for parameter **collections**:

        TODO: Rework **collections** data structure?

        .. code-block:: python

            {
                'teachers': {
                    'docs': teacher_vertices_to_insert,
                    'options': {
                        'overwrite': True,
                        'sync': True,
                        'batch_size': 50
                    }
                },
                'lectures': {
                    'docs': lecture_vertices_to_insert,
                    'options': {
                        'overwrite': False,
                        'sync': False,
                        'batch_size': 4
                    }
                },
                'teach': {
                    'docs': teach_edges_to_insert
                }
            }
"""
data: Json = {"name": name, "options": dict()}
@@ -1263,7 +1295,15 @@ def response_handler(resp: Response) -> Graph:
return Graph(self._conn, self._executor, name)
raise GraphCreateError(resp, request)

return self._execute(request, response_handler)
graph = self._execute(request, response_handler)

        if collections is not None:
            # Bulk-import the given documents into each named collection,
            # without shadowing the ``name`` and ``data`` variables above
            # (``name`` is still referenced by the response_handler closure).
            for col_name, col_data in collections.items():
                self.collection(col_name).import_bulk(
                    col_data["docs"], **col_data.get("options", {})
                )

return graph

def delete_graph(
self,
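The new test below exercises this end to end; here is a minimal sketch of provisioning a graph at creation time, reusing the hypothetical ``db`` handle from the earlier sketch (collection names and documents are again made up).

.. code-block:: python

    teachers = [{"_key": "jane"}, {"_key": "joe"}]
    lectures = [{"_key": "CSC101"}, {"_key": "MAT223"}]
    teach = [
        {"_from": "teachers/jane", "_to": "lectures/CSC101"},
        {"_from": "teachers/joe", "_to": "lectures/MAT223"},
    ]

    # The edge definition creates all three collections; **collections**
    # then bulk-imports the documents into each of them by name.
    graph = db.create_graph(
        name="school",
        edge_definitions=[
            {
                "edge_collection": "teach",
                "from_vertex_collections": ["teachers"],
                "to_vertex_collections": ["lectures"],
            }
        ],
        collections={
            "teachers": {"docs": teachers, "options": {"batch_size": 1}},
            "lectures": {"docs": lectures},
            "teach": {"docs": teach, "options": {"sync": True}},
        },
    )
    assert graph.vertex_collection("teachers").count() == 2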
arango/utils.py (25 additions, 1 deletion)

@@ -8,7 +8,7 @@

import logging
from contextlib import contextmanager
from typing import Any, Iterator, Union
from typing import Any, Iterator, List, Optional, Sequence, Union

from arango.exceptions import DocumentParseError
from arango.typings import Json
@@ -82,3 +82,27 @@ def is_none_or_str(obj: Any) -> bool:
:rtype: bool
"""
return obj is None or isinstance(obj, str)


def get_batches(
    seq: Sequence[Json], batch_size: Optional[int] = None
) -> Union[List[Sequence[Json]], Iterator[Sequence[Json]]]:
    """Split a sequence into batches of at most **batch_size** elements each.

    If **batch_size** is None, non-positive, or at least as large as the
    sequence, the entire sequence is returned as a single batch.

    :param seq: The sequence of elements.
    :type seq: list
    :param batch_size: Max number of elements per batch.
    :type batch_size: int | None
    :return: A single-element list containing **seq**, or a generator of
        batches.
    :rtype: list | generator
    """
    if batch_size is None or batch_size <= 0 or batch_size >= len(seq):
        return [seq]

    def generator() -> Iterator[Sequence[Json]]:
        n = int(batch_size)  # type: ignore # (mypy false positive)
        for i in range(0, len(seq), n):
            yield seq[i : i + n]

    return generator()
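A quick check of the helper's semantics: a valid **batch_size** yields a generator of slices, while a missing or invalid one returns the whole sequence as a single batch.

.. code-block:: python

    from arango.utils import get_batches

    docs = [{"_key": str(i)} for i in range(5)]

    assert list(get_batches(docs, 2)) == [docs[0:2], docs[2:4], docs[4:5]]
    assert get_batches(docs) == [docs]      # no batch_size: single batch
    assert get_batches(docs, 0) == [docs]   # invalid batch_size: single batch
    assert get_batches(docs, 10) == [docs]  # batch_size >= len: single batch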
tests/test_document.py (10 additions, 0 deletions)

@@ -1832,6 +1832,16 @@ def test_document_import_bulk(col, bad_col, docs):
assert col[doc_key]["loc"] == doc["loc"]
empty_collection(col)

# Test import bulk with batch_size
results = col.import_bulk(docs, batch_size=len(docs) // 2)
assert type(results) is list
assert len(results) == 2
empty_collection(col)

result = col.import_bulk(docs, batch_size=len(docs) * 2)
assert type(result) is dict
empty_collection(col)

# Test import bulk on_duplicate actions
doc = docs[0]
doc_key = doc["_key"]
tests/test_graph.py (31 additions, 0 deletions)

@@ -57,6 +57,37 @@ def test_graph_properties(graph, bad_graph, db):
assert isinstance(properties["revision"], str)


def test_graph_provision(db):
vertices = [{"_key": str(i)} for i in range(1, 101)]
edges = [
{"_from": f"numbers/{j}", "_to": f"numbers/{i}", "result": j / i}
for i in range(1, 101)
for j in range(1, 101)
if j % i == 0
]
    edge_definitions = [
{
"edge_collection": "is_divisible_by",
"from_vertex_collections": ["numbers"],
"to_vertex_collections": ["numbers"],
}
]

name = "divisibility-graph"
db.delete_graph(name, drop_collections=True, ignore_missing=True)
graph = db.create_graph(
name=name,
        edge_definitions=edge_definitions,
collections={
"numbers": {"docs": vertices, "options": {"batch_size": 5}},
"is_divisible_by": {"docs": edges, "options": {"sync": True}},
},
)

assert graph.vertex_collection("numbers").count() == len(vertices)
assert graph.edge_collection("is_divisible_by").count() == len(edges)


def test_graph_management(db, bad_db):
# Test create graph
graph_name = generate_graph_name()