diff --git a/redisvl/query/query.py b/redisvl/query/query.py index a76a85b2..1cb5aebf 100644 --- a/redisvl/query/query.py +++ b/redisvl/query/query.py @@ -136,6 +136,7 @@ def __init__( return_fields: Optional[List[str]] = None, num_results: int = 10, dialect: int = 2, + sort_by: Optional[str] = None, params: Optional[Dict[str, Any]] = None, ): """A query for a running a filtered search with a filter expression. @@ -146,6 +147,8 @@ def __init__( return_fields (Optional[List[str]], optional): The fields to return. num_results (Optional[int], optional): The number of results to return. Defaults to 10. + sort_by (Optional[str]): The field to order the results by. Defaults + to None, in which case no explicit ordering is applied. params (Optional[Dict[str, Any]], optional): The parameters for the query. Defaults to None. @@ -164,6 +167,7 @@ def __init__( """ super().__init__(return_fields, num_results, dialect) self.set_filter(filter_expression) + self._sort_by = sort_by self._params = params or {} @property @@ -180,6 +184,8 @@ def query(self) -> Query: .paging(self._first, self._limit) .dialect(self._dialect) ) + if self._sort_by: + query = query.sort_by(self._sort_by) return query @@ -201,12 +207,14 @@ def __init__( num_results: int = 10, return_score: bool = True, dialect: int = 2, + sort_by: Optional[str] = None, ): super().__init__(return_fields, num_results, dialect) self.set_filter(filter_expression) self._vector = vector self._field = vector_field_name self._dtype = dtype.lower() + self._sort_by = sort_by if return_score: self._return_fields.append(self.DISTANCE_ID) @@ -223,6 +231,7 @@ def __init__( num_results: int = 10, return_score: bool = True, dialect: int = 2, + sort_by: Optional[str] = None, ): """A query for running a vector search along with an optional filter expression. @@ -243,6 +252,8 @@ def __init__( distance. Defaults to True. dialect (int, optional): The RediSearch query dialect. Defaults to 2. 
+ sort_by (Optional[str]): The field to order the results by. Defaults + to None, in which case results are ordered by vector distance. Raises: TypeError: If filter_expression is not of type redisvl.query.FilterExpression @@ -259,6 +270,7 @@ def __init__( num_results, return_score, dialect, + sort_by, ) @property @@ -272,10 +284,13 @@ def query(self) -> Query: query = ( Query(base_query) .return_fields(*self._return_fields) - .sort_by(self.DISTANCE_ID) .paging(self._first, self._limit) .dialect(self._dialect) ) + if self._sort_by: + query = query.sort_by(self._sort_by) + else: + query = query.sort_by(self.DISTANCE_ID) return query @property @@ -307,6 +322,7 @@ def __init__( num_results: int = 10, return_score: bool = True, dialect: int = 2, + sort_by: Optional[str] = None, ): """A query for running a filtered vector search based on semantic distance threshold. @@ -330,7 +346,8 @@ def __init__( distance. Defaults to True. dialect (int, optional): The RediSearch query dialect. Defaults to 2. - + sort_by (Optional[str]): The field to order the results by. Defaults + to None, in which case results are ordered by vector distance. 
Raises: TypeError: If filter_expression is not of type redisvl.query.FilterExpression @@ -347,6 +364,7 @@ def __init__( num_results, return_score, dialect, + sort_by, ) self.set_distance_threshold(distance_threshold) @@ -390,10 +408,13 @@ def query(self) -> Query: query = ( Query(base_query) .return_fields(*self._return_fields) - .sort_by(self.DISTANCE_ID) .paging(self._first, self._limit) .dialect(self._dialect) ) + if self._sort_by: + query = query.sort_by(self._sort_by) + else: + query = query.sort_by(self.DISTANCE_ID) return query @property diff --git a/tests/integration/test_query.py b/tests/integration/test_query.py index a38a4151..3cef9d83 100644 --- a/tests/integration/test_query.py +++ b/tests/integration/test_query.py @@ -18,6 +18,16 @@ def vector_query(): ) +@pytest.fixture +def sorted_vector_query(): + return VectorQuery( + vector=[0.1, 0.1, 0.5], + vector_field_name="user_embedding", + return_fields=["user", "credit_score", "age", "job", "location"], + sort_by="age", + ) + + @pytest.fixture def filter_query(): return FilterQuery( @@ -26,6 +36,15 @@ def filter_query(): ) +@pytest.fixture +def sorted_filter_query(): + return FilterQuery( + return_fields=["user", "credit_score", "age", "job", "location"], + filter_expression=Tag("credit_score") == "high", + sort_by="age", + ) + + @pytest.fixture def range_query(): return RangeQuery( @@ -36,6 +55,17 @@ def range_query(): ) +@pytest.fixture +def sorted_range_query(): + return RangeQuery( + vector=[0.1, 0.1, 0.5], + vector_field_name="user_embedding", + return_fields=["user", "credit_score", "age", "job", "location"], + distance_threshold=0.2, + sort_by="age", + ) + + @pytest.fixture def index(sample_data, redis_url): # construct a search index from the schema @@ -160,6 +190,7 @@ def search( age_range=None, location=None, distance_threshold=0.2, + sort=False, ): """Utility function to test filters.""" @@ -199,6 +230,21 @@ def search( else: assert len(results.docs) == expected_count + # check results are in 
sorted order + if sort: + if isinstance(query, RangeQuery): + assert [int(doc.age) for doc in results.docs] == [12, 14, 18, 100] + else: + assert [int(doc.age) for doc in results.docs] == [ + 12, + 14, + 15, + 18, + 35, + 94, + 100, + ] + @pytest.fixture( params=["vector_query", "filter_query", "range_query"], @@ -339,3 +385,18 @@ def test_paginate_range_query(index, range_query): assert len(all_results) == expected_count assert i == expected_iterations assert all(float(item["vector_distance"]) <= 0.2 for item in all_results) + + +def test_sort_filter_query(index, sorted_filter_query): + t = Text("job") % "" + search(sorted_filter_query, index, t, 7, sort=True) + + +def test_sort_vector_query(index, sorted_vector_query): + t = Text("job") % "" + search(sorted_vector_query, index, t, 7, sort=True) + + +def test_sort_range_query(index, sorted_range_query): + t = Text("job") % "" + search(sorted_range_query, index, t, 7, sort=True) diff --git a/tests/unit/test_query_types.py b/tests/unit/test_query_types.py index 2093564e..09d63aca 100644 --- a/tests/unit/test_query_types.py +++ b/tests/unit/test_query_types.py @@ -45,6 +45,7 @@ def test_filter_query(): assert isinstance(filter_query.params, dict) assert filter_query.params == {} assert filter_query._dialect == 2 + assert filter_query._sort_by == None # Test set_filter functionality new_filter_expression = Tag("category") == "Sportswear" @@ -57,6 +58,12 @@ def test_filter_query(): assert filter_query._limit == 7 assert filter_query._num_results == 10 + # Test sort_by functionality + filter_query = FilterQuery( + filter_expression, return_fields, num_results=10, sort_by="price" + ) + assert filter_query._sort_by == "price" + def test_vector_query(): # Create a vector query @@ -73,6 +80,7 @@ def test_vector_query(): assert isinstance(vector_query.params, dict) assert vector_query.params != {} assert vector_query._dialect == 3 + assert vector_query._sort_by == None # Test set_filter functionality new_filter_expression = 
Tag("category") == "Sportswear" @@ -85,6 +93,17 @@ def test_vector_query(): assert vector_query._limit == 7 assert vector_query._num_results == 10 + # Test sort_by functionality + vector_query = VectorQuery( + sample_vector, + "vector_field", + ["field1", "field2"], + dialect=3, + num_results=10, + sort_by="field2", + ) + assert vector_query._sort_by == "field2" + def test_range_query(): # Create a filter expression @@ -104,6 +123,7 @@ def test_range_query(): assert isinstance(range_query.query, Query) assert isinstance(range_query.params, dict) assert range_query.params != {} + assert range_query._sort_by == None # Test set_filter functionality new_filter_expression = Tag("category") == "Outdoor" @@ -115,3 +135,14 @@ def test_range_query(): assert range_query._first == 5 assert range_query._limit == 7 assert range_query._num_results == 10 + + # Test sort_by functionality + range_query = RangeQuery( + sample_vector, + "vector_field", + ["field1"], + filter_expression, + num_results=10, + sort_by="field1", + ) + assert range_query._sort_by == "field1"