Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,6 @@ arango/version.py

# test results
*_results.txt

# devcontainers
.devcontainer
11 changes: 10 additions & 1 deletion arango/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,13 @@
from arango.typings import Fields, Headers, Json, Jsons, Params
from arango.utils import (
build_filter_conditions,
build_sort_expression,
get_batches,
get_doc_id,
is_none_or_bool,
is_none_or_int,
is_none_or_str,
validate_sort_parameters,
)


Expand Down Expand Up @@ -718,6 +720,7 @@ def all(
:return: Document cursor.
:rtype: arango.cursor.Cursor
:raise arango.exceptions.DocumentGetError: If retrieval fails.
:raise arango.exceptions.SortValidationError: If sort parameters are invalid.
"""
assert is_none_or_int(skip), "skip must be a non-negative int"
assert is_none_or_int(limit), "limit must be a non-negative int"
Expand Down Expand Up @@ -753,6 +756,7 @@ def find(
skip: Optional[int] = None,
limit: Optional[int] = None,
allow_dirty_read: bool = False,
sort: Optional[Jsons] = None,
) -> Result[Cursor]:
"""Return all documents that match the given filters.

Expand All @@ -764,23 +768,28 @@ def find(
:type limit: int | None
:param allow_dirty_read: Allow reads from followers in a cluster.
:type allow_dirty_read: bool
:param sort: Document sort parameters
:type sort: Jsons | None
:return: Document cursor.
:rtype: arango.cursor.Cursor
:raise arango.exceptions.DocumentGetError: If retrieval fails.
:raise arango.exceptions.SortValidationError: If sort parameters are invalid.
"""
assert isinstance(filters, dict), "filters must be a dict"
assert is_none_or_int(skip), "skip must be a non-negative int"
assert is_none_or_int(limit), "limit must be a non-negative int"
if sort:
validate_sort_parameters(sort)

skip_val = skip if skip is not None else 0
limit_val = limit if limit is not None else "null"
query = f"""
FOR doc IN @@collection
{build_filter_conditions(filters)}
LIMIT {skip_val}, {limit_val}
{build_sort_expression(sort)}
RETURN doc
"""

bind_vars = {"@collection": self.name}

request = Request(
Expand Down
7 changes: 7 additions & 0 deletions arango/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1074,3 +1074,10 @@ class JWTRefreshError(ArangoClientError):

class JWTExpiredError(ArangoClientError):
"""JWT token has expired."""


###################################
# Parameter Validation Exceptions #
###################################
class SortValidationError(ArangoClientError):
"""Invalid sort parameters."""
43 changes: 41 additions & 2 deletions arango/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from contextlib import contextmanager
from typing import Any, Iterator, Sequence, Union

from arango.exceptions import DocumentParseError
from arango.typings import Json
from arango.exceptions import DocumentParseError, SortValidationError
from arango.typings import Json, Jsons


@contextmanager
Expand Down Expand Up @@ -126,3 +126,42 @@ def build_filter_conditions(filters: Json) -> str:
conditions.append(f"doc.{field} == {json.dumps(v)}")

return "FILTER " + " AND ".join(conditions)


def validate_sort_parameters(sort: Sequence[Json]) -> bool:
"""Validate sort parameters for an AQL query.

:param sort: Document sort parameters.
:type sort: Sequence[Json]
:return: Validation success.
:rtype: bool
:raise arango.exceptions.SortValidationError: If sort parameters are invalid.
"""
assert isinstance(sort, Sequence)
for param in sort:
if "sort_by" not in param or "sort_order" not in param:
raise SortValidationError(
"Each sort parameter must have 'sort_by' and 'sort_order'."
)
if param["sort_order"].upper() not in ["ASC", "DESC"]:
raise SortValidationError("'sort_order' must be either 'ASC' or 'DESC'")
return True


def build_sort_expression(sort: Jsons | None) -> str:
"""Build a sort condition for an AQL query.

:param sort: Document sort parameters.
:type sort: Jsons | None
:return: The complete AQL sort condition.
:rtype: str
"""
if not sort:
return ""

sort_chunks = []
for sort_param in sort:
chunk = f"doc.{sort_param['sort_by']} {sort_param['sort_order']}"
sort_chunks.append(chunk)

return "SORT " + ", ".join(sort_chunks)
6 changes: 6 additions & 0 deletions docs/document.rst
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,12 @@ Standard documents are managed via collection API wrapper:
assert student['GPA'] == 3.6
assert student['last'] == 'Kim'

# Retrieve one or more matching documents, sorted by a field.
for student in students.find({'first': 'John'}, sort=[{'sort_by': 'GPA', 'sort_order': 'DESC'}]):
assert student['_key'] == 'john'
assert student['GPA'] == 3.6
assert student['last'] == 'Kim'

# Retrieve a document by key.
students.get('john')

Expand Down
20 changes: 20 additions & 0 deletions tests/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -1162,6 +1162,26 @@ def test_document_find(col, bad_col, docs):
# Set up test documents
col.import_bulk(docs)

# Test find with sort expression (single field)
found = list(col.find({}, sort=[{"sort_by": "text", "sort_order": "ASC"}]))
assert len(found) == 6
assert found[0]["text"] == "bar"
assert found[-1]["text"] == "foo"

# Test find with sort expression (multiple fields)
found = list(
col.find(
{},
sort=[
{"sort_by": "text", "sort_order": "ASC"},
{"sort_by": "val", "sort_order": "DESC"},
],
)
)
assert len(found) == 6
assert found[0]["val"] == 6
assert found[-1]["val"] == 1

# Test find (single match) with default options
found = list(col.find({"val": 2}))
assert len(found) == 1
Expand Down
Loading