diff --git a/dev_requirements.txt b/dev_requirements.txt index 3715599af0..7daefb8ca6 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -4,6 +4,7 @@ flake8==5.0.4 flake8-isort==6.0.0 flynt~=0.69.0 mock==4.0.3 +numpy>=1.21.0 packaging>=20.4 pytest==7.2.0 pytest-timeout==2.1.0 diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py index 2df2b5a754..b308c68f5f 100644 --- a/redis/commands/search/commands.py +++ b/redis/commands/search/commands.py @@ -80,6 +80,7 @@ def _parse_search(self, res, **kwargs): duration=kwargs["duration"], has_payload=kwargs["query"]._with_payloads, with_scores=kwargs["query"]._with_scores, + decode_fields=kwargs["decode_fields"], ) def _parse_aggregate(self, res, **kwargs): @@ -484,18 +485,27 @@ def search( self, query: Union[str, Query], query_params: Union[Dict[str, Union[str, int, float, bytes]], None] = None, + decode_fields: bool = True, ): """ - Search the index for a given query, and return a result of documents + Search the index for a given query, and return a result of documents. + + Args: + query: The search query. This can be a simple text string for basic queries, + or a Query object for more complex queries. Refer to RediSearch's + documentation for details on the query format. + query_params: Additional parameters for the query. These parameters are used + to replace placeholders in the query string. This is useful + for safely including user input in a search query. + decode_fields: If `True`, which is the default, decodes the fields in the + search results. If `False`, fields are returned in their raw + binary form. - ### Parameters - - - **query**: the search query. Either a text for simple queries with - default parameters, or a Query object for complex queries. - See RediSearch's documentation on query format + Returns: + A result set of documents matching the query. - For more information see `FT.SEARCH `_. - """ # noqa + For more information see https://redis.io/commands/ft.search + """ args, query = self._mk_query_args(query, query_params=query_params) st = time.time() res = self.execute_command(SEARCH_CMD, *args) @@ -504,7 +514,11 @@ def search( return res return self._parse_results( - SEARCH_CMD, res, query=query, duration=(time.time() - st) * 1000.0 + SEARCH_CMD, + res, + query=query, + duration=(time.time() - st) * 1000.0, + decode_fields=decode_fields, ) def explain( @@ -911,18 +925,27 @@ async def search( self, query: Union[str, Query], query_params: Dict[str, Union[str, int, float]] = None, + decode_fields: bool = True, ): """ - Search the index for a given query, and return a result of documents + Search the index for a given query, and return a result of documents. + + Args: + query: The search query. This can be a simple text string for basic queries, + or a Query object for more complex queries. Refer to RediSearch's + documentation for details on the query format. + query_params: Additional parameters for the query. These parameters are used + to replace placeholders in the query string. This is useful + for safely including user input in a search query. + decode_fields: If `True`, which is the default, decodes the fields in the + search results. If `False`, fields are returned in their raw + binary form. - ### Parameters - - - **query**: the search query. Either a text for simple queries with - default parameters, or a Query object for complex queries. - See RediSearch's documentation on query format + Returns: + A result set of documents matching the query. - For more information see `FT.SEARCH `_. - """ # noqa + For more information see https://redis.io/commands/ft.search + """ args, query = self._mk_query_args(query, query_params=query_params) st = time.time() res = await self.execute_command(SEARCH_CMD, *args) @@ -931,7 +954,11 @@ async def search( return res return self._parse_results( - SEARCH_CMD, res, query=query, duration=(time.time() - st) * 1000.0 + SEARCH_CMD, + res, + query=query, + duration=(time.time() - st) * 1000.0, + decode_fields=decode_fields, ) async def aggregate( diff --git a/redis/commands/search/result.py b/redis/commands/search/result.py index 5b19e6faa4..36bb5802a9 100644 --- a/redis/commands/search/result.py +++ b/redis/commands/search/result.py @@ -9,7 +9,13 @@ class Result: """ def __init__( - self, res, hascontent, duration=0, has_payload=False, with_scores=False + self, + res, + hascontent, + duration=0, + has_payload=False, + with_scores=False, + decode_fields=False, ): """ - **snippets**: An optional dictionary of the form @@ -32,24 +38,26 @@ def __init__( for i in range(1, len(res), step): id = to_string(res[i]) - payload = to_string(res[i + offset]) if has_payload else None + if has_payload: + payload_data = res[i + offset] + payload = to_string(payload_data) if decode_fields else payload_data + else: + payload = None # fields_offset = 2 if has_payload else 1 fields_offset = offset + 1 if has_payload else offset score = float(res[i + 1]) if with_scores else None fields = {} if hascontent and res[i + fields_offset] is not None: - fields = ( + keys = res[i + fields_offset][::2] + values = res[i + fields_offset][1::2] + fields = dict( dict( - dict( - zip( - map(to_string, res[i + fields_offset][::2]), - map(to_string, res[i + fields_offset][1::2]), - ) + zip( + map(to_string, keys), + map(to_string, values) if decode_fields else values, ) ) - if hascontent - else {} ) try: del fields["id"] diff --git a/tests/test_search.py b/tests/test_search.py index bfe204254c..f74975a914 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -4,6 +4,7 @@ import time from io import TextIOWrapper +import numpy as np import pytest import redis import redis.commands.search @@ -2284,3 +2285,37 @@ def test_geoshape(client: redis.Redis): assert result.docs[0]["id"] == "small" result = client.ft().search(q2, query_params=qp2) assert len(result.docs) == 2 + + +@pytest.mark.redismod +def test_vector_storage_and_retrieval(r: redis.Redis): + r.ft("vector_index").create_index( + ( + VectorField( + "my_vector", + "FLAT", + { + "TYPE": "FLOAT32", + "DIM": 4, + "DISTANCE_METRIC": "COSINE", + }, + ), + ), + definition=IndexDefinition(prefix=["doc:"], index_type=IndexType.HASH), + ) + + vector_data = [0.1, 0.2, 0.3, 0.4] + r.hset( + "doc:1", + mapping={"my_vector": np.array(vector_data, dtype=np.float32).tobytes()}, + ) + + query = Query("*").with_payloads().return_fields("my_vector").dialect(2) + res = r.ft("vector_index").search(query, decode_fields=False) + + assert res.total == 1 + assert res.docs[0].id == "doc:1" + retrieved_vector_data = np.frombuffer( + res.docs[0].__dict__["my_vector"], dtype=np.float32 + ) + assert np.allclose(retrieved_vector_data, vector_data)