diff --git a/redisvl/query/query.py b/redisvl/query/query.py index 1cb5aebf..a1b3832b 100644 --- a/redisvl/query/query.py +++ b/redisvl/query/query.py @@ -13,6 +13,8 @@ def __init__( return_fields: Optional[List[str]] = None, num_results: int = 10, dialect: int = 2, + sort_by: Optional[str] = None, + in_order: bool = False, ): """Base query class used to subclass many query types.""" self._return_fields = return_fields if return_fields is not None else [] @@ -20,6 +22,8 @@ def __init__( self._dialect = dialect self._first = 0 self._limit = num_results + self._sort_by = sort_by + self._in_order = in_order def __str__(self) -> str: return " ".join([str(x) for x in self.query.get_args()]) @@ -137,6 +141,7 @@ def __init__( num_results: int = 10, dialect: int = 2, sort_by: Optional[str] = None, + in_order: bool = False, params: Optional[Dict[str, Any]] = None, ): """A query for a running a filtered search with a filter expression. @@ -149,6 +154,9 @@ def __init__( return. Defaults to 10. sort_by (Optional[str]): The field to order the results by. Defaults to None. Results will be ordered by vector distance. + in_order (bool): Requires the terms in the field to have + the same order as the terms in the query filter, regardless of + the offsets between them. Defaults to False. params (Optional[Dict[str, Any]], optional): The parameters for the query. Defaults to None. @@ -165,9 +173,8 @@ def __init__( q = FilterQuery(return_fields=["brand", "price"], filter_expression=t) """ - super().__init__(return_fields, num_results, dialect) + super().__init__(return_fields, num_results, dialect, sort_by, in_order) self.set_filter(filter_expression) - self._sort_by = sort_by self._params = params or {} @property @@ -186,6 +193,10 @@ def query(self) -> Query: ) if self._sort_by: query = query.sort_by(self._sort_by) + + if self._in_order: + query = query.in_order() + return query @@ -208,13 +219,13 @@ def __init__( return_score: bool = True, dialect: int = 2, sort_by: Optional[str] = None, + in_order: bool = False, ): - super().__init__(return_fields, num_results, dialect) + super().__init__(return_fields, num_results, dialect, sort_by, in_order) self.set_filter(filter_expression) self._vector = vector self._field = vector_field_name self._dtype = dtype.lower() - self._sort_by = sort_by if return_score: self._return_fields.append(self.DISTANCE_ID) @@ -232,6 +243,7 @@ def __init__( return_score: bool = True, dialect: int = 2, sort_by: Optional[str] = None, + in_order: bool = False, ): """A query for running a vector search along with an optional filter expression. @@ -254,6 +266,9 @@ def __init__( Defaults to 2. sort_by (Optional[str]): The field to order the results by. Defaults to None. Results will be ordered by vector distance. + in_order (bool): Requires the terms in the field to have + the same order as the terms in the query filter, regardless of + the offsets between them. Defaults to False. Raises: TypeError: If filter_expression is not of type redisvl.query.FilterExpression @@ -271,6 +286,7 @@ def __init__( return_score, dialect, sort_by, + in_order, ) @property @@ -291,6 +307,10 @@ def query(self) -> Query: query = query.sort_by(self._sort_by) else: query = query.sort_by(self.DISTANCE_ID) + + if self._in_order: + query = query.in_order() + return query @property @@ -323,6 +343,7 @@ def __init__( return_score: bool = True, dialect: int = 2, sort_by: Optional[str] = None, + in_order: bool = False, ): """A query for running a filtered vector search based on semantic distance threshold. @@ -348,6 +369,10 @@ def __init__( Defaults to 2. sort_by (Optional[str]): The field to order the results by. Defaults to None. Results will be ordered by vector distance. + in_order (bool): Requires the terms in the field to have + the same order as the terms in the query filter, regardless of + the offsets between them. Defaults to False. + Raises: TypeError: If filter_expression is not of type redisvl.query.FilterExpression @@ -365,6 +390,7 @@ def __init__( return_score, dialect, sort_by, + in_order, ) self.set_distance_threshold(distance_threshold) @@ -415,6 +441,10 @@ def query(self) -> Query: query = query.sort_by(self._sort_by) else: query = query.sort_by(self.DISTANCE_ID) + + if self._in_order: + query = query.in_order() + return query @property diff --git a/tests/unit/test_query_types.py b/tests/unit/test_query_types.py index 09d63aca..e0fd4f17 100644 --- a/tests/unit/test_query_types.py +++ b/tests/unit/test_query_types.py @@ -46,6 +46,7 @@ def test_filter_query(): assert filter_query.params == {} assert filter_query._dialect == 2 assert filter_query._sort_by == None + assert filter_query._in_order == False # Test set_filter functionality new_filter_expression = Tag("category") == "Sportswear" @@ -64,6 +65,12 @@ def test_filter_query(): ) assert filter_query._sort_by == "price" + # Test in_order functionality + filter_query = FilterQuery( + filter_expression, return_fields, num_results=10, in_order=True + ) + assert filter_query._in_order + def test_vector_query(): # Create a vector query @@ -81,6 +88,7 @@ def test_vector_query(): assert vector_query.params != {} assert vector_query._dialect == 3 assert vector_query._sort_by == None + assert vector_query._in_order == False # Test set_filter functionality new_filter_expression = Tag("category") == "Sportswear" @@ -124,6 +132,7 @@ def test_range_query(): assert isinstance(range_query.params, dict) assert range_query.params != {} assert range_query._sort_by == None + assert range_query._sort_by == None # Test set_filter functionality new_filter_expression = Tag("category") == "Outdoor"