From dc9a866b69caf6b5aedade7aaad94bdf680b5ca9 Mon Sep 17 00:00:00 2001 From: Robert Shelton Date: Tue, 8 Oct 2024 09:50:27 -0400 Subject: [PATCH 1/8] adds scorer to AggregateRequest --- redis/commands/search/aggregation.py | 18 ++++++++- tests/test_asyncio/test_search.py | 60 ++++++++++++++++++++++++++++ tests/test_search.py | 60 ++++++++++++++++++++++++++++ 3 files changed, 137 insertions(+), 1 deletion(-) diff --git a/redis/commands/search/aggregation.py b/redis/commands/search/aggregation.py index 42c3547b0b..1629b7198d 100644 --- a/redis/commands/search/aggregation.py +++ b/redis/commands/search/aggregation.py @@ -1,4 +1,4 @@ -from typing import List, Union +from typing import List, Union, Optional FIELDNAME = object() @@ -112,6 +112,7 @@ def __init__(self, query: str = "*") -> None: self._cursor = [] self._dialect = None self._add_scores = False + self._scorer = Optional[str] = None def load(self, *fields: List[str]) -> "AggregateRequest": """ @@ -300,6 +301,17 @@ def add_scores(self) -> "AggregateRequest": self._add_scores = True return self + def scorer(self, scorer: str) -> "AggregateRequest": + """ + Use a different scoring function to evaluate document relevance. + Default is `TFIDF`. + + :param scorer: The scoring function to use + (e.g. 
`TFIDF.DOCNORM` or `BM25`) + """ + self._scorer = scorer + return self + def verbatim(self) -> "AggregateRequest": self._verbatim = True return self @@ -323,6 +335,9 @@ def build_args(self) -> List[str]: if self._verbatim: ret.append("VERBATIM") + if self._scorer: + ret.extend(["SCORER", self._scorer]) + if self._add_scores: ret.append("ADDSCORES") @@ -332,6 +347,7 @@ def build_args(self) -> List[str]: if self._loadall: ret.append("LOAD") ret.append("*") + elif self._loadfields: ret.append("LOAD") ret.append(str(len(self._loadfields))) diff --git a/tests/test_asyncio/test_search.py b/tests/test_asyncio/test_search.py index 0e6fe22131..0071760ffe 100644 --- a/tests/test_asyncio/test_search.py +++ b/tests/test_asyncio/test_search.py @@ -1556,6 +1556,66 @@ async def test_aggregations_add_scores(decoded_r: redis.Redis): assert res.rows[1] == ["__score", "0.2"] +@pytest.mark.redismod +@skip_ifmodversion_lt("2.10.05", "search") +async def test_aggregations_hybrid_scoring(decoded_r: redis.Redis): + assert await decoded_r.ft().create_index( + ( + TextField("name", sortable=True, weight=5.0), + TextField("description", sortable=True, weight=5.0), + VectorField( + "vector", + "HNSW", + {"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "COSINE"}, + ), + ) + ) + + assert await decoded_r.hset( + "doc1", + mapping={ + "name": "cat book", + "description": "a book about cats", + "vector": np.array([0.1, 0.2]).astype(np.float32).tobytes(), + }, + ) + assert await decoded_r.hset( + "doc2", + mapping={ + "name": "dog book", + "description": "a book about dogs", + "vector": np.array([0.2, 0.1]).astype(np.float32).tobytes(), + }, + ) + + query_string = "(@description:cat)=>[KNN 3 @vector $vec_param AS dist]" + req = ( + aggregations.AggregateRequest(query_string) + .scorer("BM25") + .add_scores() + .apply(hybrid_score="@__score + @dist") + .load("*") + .dialect(4) + ) + + res = ( + await decoded_r.ft() + .aggregate( + req, + query_params={ + "vec_param": np.array([0.11, 
0.21]).astype(np.float32).tobytes() + }, + ) + .rows[0] + ) + + assert len(res) == 6 + assert b"hybrid_score" in res + assert b"__score" in res + assert b"__dist" in res + assert float(res[1]) + float(res[3]) == float(res[5]) + + @pytest.mark.redismod @skip_if_redis_enterprise() async def test_search_commands_in_pipeline(decoded_r: redis.Redis): diff --git a/tests/test_search.py b/tests/test_search.py index dde59f0f87..dfe625a810 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -1466,6 +1466,66 @@ def test_aggregations_add_scores(client): assert res.rows[1] == ["__score", "0.2"] +@pytest.mark.redismod +@skip_ifmodversion_lt("2.10.05", "search") +async def test_aggregations_hybrid_scoring(client): + client.ft().create_index( + ( + TextField("name", sortable=True, weight=5.0), + TextField("description", sortable=True, weight=5.0), + VectorField( + "vector", + "HNSW", + {"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "COSINE"}, + ), + ) + ) + + client.hset( + "doc1", + mapping={ + "name": "cat book", + "description": "a book about cats", + "vector": np.array([0.1, 0.2]).astype(np.float32).tobytes(), + }, + ) + client.hset( + "doc2", + mapping={ + "name": "dog book", + "description": "a book about dogs", + "vector": np.array([0.2, 0.1]).astype(np.float32).tobytes(), + }, + ) + + query_string = "(@description:cat)=>[KNN 3 @vector $vec_param AS dist]" + req = ( + aggregations.AggregateRequest(query_string) + .scorer("BM25") + .add_scores() + .apply(hybrid_score="@__score + @dist") + .load("*") + .dialect(4) + ) + + res = ( + client.ft() + .aggregate( + req, + query_params={ + "vec_param": np.array([0.11, 0.21]).astype(np.float32).tobytes() + }, + ) + .rows[0] + ) + + assert len(res) == 6 + assert b"hybrid_score" in res + assert b"__score" in res + assert b"__dist" in res + assert float(res[1]) + float(res[3]) == float(res[5]) + + @pytest.mark.redismod @skip_ifmodversion_lt("2.0.0", "search") def test_index_definition(client): From 
2172c583e24cd0e3c40acaf2b0c709971a2aabcd Mon Sep 17 00:00:00 2001 From: Robert Shelton Date: Wed, 16 Oct 2024 08:12:25 -0400 Subject: [PATCH 2/8] fix linting --- redis/commands/search/aggregation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/redis/commands/search/aggregation.py b/redis/commands/search/aggregation.py index 1629b7198d..09849e66ea 100644 --- a/redis/commands/search/aggregation.py +++ b/redis/commands/search/aggregation.py @@ -1,4 +1,4 @@ -from typing import List, Union, Optional +from typing import List, Union FIELDNAME = object() @@ -112,7 +112,7 @@ def __init__(self, query: str = "*") -> None: self._cursor = [] self._dialect = None self._add_scores = False - self._scorer = Optional[str] = None + self._scorer = None def load(self, *fields: List[str]) -> "AggregateRequest": """ From 4b71d01290778c52f8dff9c8479338aad6b797bc Mon Sep 17 00:00:00 2001 From: Robert Shelton Date: Mon, 21 Oct 2024 10:50:08 -0400 Subject: [PATCH 3/8] update tests for BM25 --- tests/test_asyncio/test_search.py | 29 ++++++++++++----------------- tests/test_search.py | 29 ++++++++++++----------------- 2 files changed, 24 insertions(+), 34 deletions(-) diff --git a/tests/test_asyncio/test_search.py b/tests/test_asyncio/test_search.py index 0071760ffe..fb813b0bc7 100644 --- a/tests/test_asyncio/test_search.py +++ b/tests/test_asyncio/test_search.py @@ -1575,7 +1575,7 @@ async def test_aggregations_hybrid_scoring(decoded_r: redis.Redis): "doc1", mapping={ "name": "cat book", - "description": "a book about cats", + "description": "an animal book about cats", "vector": np.array([0.1, 0.2]).astype(np.float32).tobytes(), }, ) @@ -1583,12 +1583,12 @@ async def test_aggregations_hybrid_scoring(decoded_r: redis.Redis): "doc2", mapping={ "name": "dog book", - "description": "a book about dogs", + "description": "an animal book about dogs", "vector": np.array([0.2, 0.1]).astype(np.float32).tobytes(), }, ) - query_string = "(@description:cat)=>[KNN 3 @vector 
$vec_param AS dist]" + query_string = "(@description:animal)=>[KNN 3 @vector $vec_param AS dist]" req = ( aggregations.AggregateRequest(query_string) .scorer("BM25") @@ -1598,22 +1598,17 @@ async def test_aggregations_hybrid_scoring(decoded_r: redis.Redis): .dialect(4) ) - res = ( - await decoded_r.ft() - .aggregate( - req, - query_params={ - "vec_param": np.array([0.11, 0.21]).astype(np.float32).tobytes() - }, - ) - .rows[0] + res = await decoded_r.ft().aggregate( + req, + query_params={"vec_param": np.array([0.11, 0.22]).astype(np.float32).tobytes()}, ) - assert len(res) == 6 - assert b"hybrid_score" in res - assert b"__score" in res - assert b"__dist" in res - assert float(res[1]) + float(res[3]) == float(res[5]) + if isinstance(res, dict): + assert len(res["results"]) == 2 + else: + assert len(res.rows) == 2 + for row in res.rows: + len(row) == 6 @pytest.mark.redismod diff --git a/tests/test_search.py b/tests/test_search.py index dfe625a810..0f0e7bb309 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -1485,7 +1485,7 @@ async def test_aggregations_hybrid_scoring(client): "doc1", mapping={ "name": "cat book", - "description": "a book about cats", + "description": "an animal book about cats", "vector": np.array([0.1, 0.2]).astype(np.float32).tobytes(), }, ) @@ -1493,12 +1493,12 @@ async def test_aggregations_hybrid_scoring(client): "doc2", mapping={ "name": "dog book", - "description": "a book about dogs", + "description": "an animal book about dogs", "vector": np.array([0.2, 0.1]).astype(np.float32).tobytes(), }, ) - query_string = "(@description:cat)=>[KNN 3 @vector $vec_param AS dist]" + query_string = "(@description:animal)=>[KNN 3 @vector $vec_param AS dist]" req = ( aggregations.AggregateRequest(query_string) .scorer("BM25") @@ -1508,22 +1508,17 @@ async def test_aggregations_hybrid_scoring(client): .dialect(4) ) - res = ( - client.ft() - .aggregate( - req, - query_params={ - "vec_param": np.array([0.11, 0.21]).astype(np.float32).tobytes() - 
}, - ) - .rows[0] + res = client.ft().aggregate( + req, + query_params={"vec_param": np.array([0.11, 0.21]).astype(np.float32).tobytes()}, ) - assert len(res) == 6 - assert b"hybrid_score" in res - assert b"__score" in res - assert b"__dist" in res - assert float(res[1]) + float(res[3]) == float(res[5]) + if isinstance(res, dict): + assert len(res["results"]) == 2 + else: + assert len(res.rows) == 2 + for row in res.rows: + len(row) == 6 @pytest.mark.redismod From b28656d967c0bbd701b6f8ccf3eeda82f7e0ab61 Mon Sep 17 00:00:00 2001 From: Robert Shelton Date: Tue, 22 Oct 2024 05:21:54 -0400 Subject: [PATCH 4/8] enum for aggregation scorer --- redis/commands/search/aggregation.py | 14 ++++++++++++-- tests/test_asyncio/test_search.py | 8 ++++++++ tests/test_search.py | 8 ++++++++ 3 files changed, 28 insertions(+), 2 deletions(-) diff --git a/redis/commands/search/aggregation.py b/redis/commands/search/aggregation.py index 09849e66ea..0c2dcb4191 100644 --- a/redis/commands/search/aggregation.py +++ b/redis/commands/search/aggregation.py @@ -1,8 +1,18 @@ +from enum import Enum from typing import List, Union FIELDNAME = object() +class Scorers(Enum): + TFIDF = "TFIDF" + TFIDF_DOCNORM = "TFIDF.DOCNORM" + BM25 = "BM25" + DISMAX = "DISMAX" + DOCSCORE = "DOCSCORE" + HAMMING = "HAMMING" + + class Limit: def __init__(self, offset: int = 0, count: int = 0) -> None: self.offset = offset @@ -112,7 +122,7 @@ def __init__(self, query: str = "*") -> None: self._cursor = [] self._dialect = None self._add_scores = False - self._scorer = None + self._scorer = Scorers.TFIDF.value def load(self, *fields: List[str]) -> "AggregateRequest": """ @@ -309,7 +319,7 @@ def scorer(self, scorer: str) -> "AggregateRequest": :param scorer: The scoring function to use (e.g. 
`TFIDF.DOCNORM` or `BM25`) """ - self._scorer = scorer + self._scorer = Scorers(scorer).value return self def verbatim(self) -> "AggregateRequest": diff --git a/tests/test_asyncio/test_search.py b/tests/test_asyncio/test_search.py index fb813b0bc7..88efa0c0bf 100644 --- a/tests/test_asyncio/test_search.py +++ b/tests/test_asyncio/test_search.py @@ -1611,6 +1611,14 @@ async def test_aggregations_hybrid_scoring(decoded_r: redis.Redis): len(row) == 6 +@pytest.mark.redismod +@skip_ifmodversion_lt("2.10.05", "search") +async def test_invalid_scorer(): + + with pytest.raises(ValueError): + aggregations.AggregateRequest("*").scorer("blah") + + @pytest.mark.redismod @skip_if_redis_enterprise() async def test_search_commands_in_pipeline(decoded_r: redis.Redis): diff --git a/tests/test_search.py b/tests/test_search.py index 0f0e7bb309..30c3a30d22 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -1521,6 +1521,14 @@ async def test_aggregations_hybrid_scoring(client): len(row) == 6 +@pytest.mark.redismod +@skip_ifmodversion_lt("2.10.05", "search") +async def test_invalid_scorer(): + + with pytest.raises(ValueError): + aggregations.AggregateRequest("*").scorer("blah") + + @pytest.mark.redismod @skip_ifmodversion_lt("2.0.0", "search") def test_index_definition(client): From 5e1cadbf24c4e81130d46b3d0ed53f85399dd56a Mon Sep 17 00:00:00 2001 From: Robert Shelton Date: Tue, 22 Oct 2024 06:06:05 -0400 Subject: [PATCH 5/8] update signature --- redis/commands/search/aggregation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis/commands/search/aggregation.py b/redis/commands/search/aggregation.py index 0c2dcb4191..0dfe816b11 100644 --- a/redis/commands/search/aggregation.py +++ b/redis/commands/search/aggregation.py @@ -311,7 +311,7 @@ def add_scores(self) -> "AggregateRequest": self._add_scores = True return self - def scorer(self, scorer: str) -> "AggregateRequest": + def scorer(self, scorer: Scorers) -> "AggregateRequest": """ Use a different 
scoring function to evaluate document relevance. Default is `TFIDF`. From bb95f1ce0f1e0671c3bb1dcef6a81f99b0a9b279 Mon Sep 17 00:00:00 2001 From: Robert Shelton Date: Tue, 22 Oct 2024 09:36:36 -0400 Subject: [PATCH 6/8] revert back to string input --- redis/commands/search/aggregation.py | 16 +++------------- tests/test_asyncio/test_search.py | 8 -------- tests/test_search.py | 8 -------- 3 files changed, 3 insertions(+), 29 deletions(-) diff --git a/redis/commands/search/aggregation.py b/redis/commands/search/aggregation.py index 0dfe816b11..5638f1d662 100644 --- a/redis/commands/search/aggregation.py +++ b/redis/commands/search/aggregation.py @@ -1,18 +1,8 @@ -from enum import Enum from typing import List, Union FIELDNAME = object() -class Scorers(Enum): - TFIDF = "TFIDF" - TFIDF_DOCNORM = "TFIDF.DOCNORM" - BM25 = "BM25" - DISMAX = "DISMAX" - DOCSCORE = "DOCSCORE" - HAMMING = "HAMMING" - - class Limit: def __init__(self, offset: int = 0, count: int = 0) -> None: self.offset = offset @@ -122,7 +112,7 @@ def __init__(self, query: str = "*") -> None: self._cursor = [] self._dialect = None self._add_scores = False - self._scorer = Scorers.TFIDF.value + self._scorer = "TFIDF" def load(self, *fields: List[str]) -> "AggregateRequest": """ @@ -311,7 +301,7 @@ def add_scores(self) -> "AggregateRequest": self._add_scores = True return self - def scorer(self, scorer: Scorers) -> "AggregateRequest": + def scorer(self, scorer: str) -> "AggregateRequest": """ Use a different scoring function to evaluate document relevance. Default is `TFIDF`. @@ -319,7 +309,7 @@ def scorer(self, scorer: Scorers) -> "AggregateRequest": :param scorer: The scoring function to use (e.g. 
`TFIDF.DOCNORM` or `BM25`) """ - self._scorer = Scorers(scorer).value + self._scorer = scorer return self def verbatim(self) -> "AggregateRequest": diff --git a/tests/test_asyncio/test_search.py b/tests/test_asyncio/test_search.py index 88efa0c0bf..fb813b0bc7 100644 --- a/tests/test_asyncio/test_search.py +++ b/tests/test_asyncio/test_search.py @@ -1611,14 +1611,6 @@ async def test_aggregations_hybrid_scoring(decoded_r: redis.Redis): len(row) == 6 -@pytest.mark.redismod -@skip_ifmodversion_lt("2.10.05", "search") -async def test_invalid_scorer(): - - with pytest.raises(ValueError): - aggregations.AggregateRequest("*").scorer("blah") - - @pytest.mark.redismod @skip_if_redis_enterprise() async def test_search_commands_in_pipeline(decoded_r: redis.Redis): diff --git a/tests/test_search.py b/tests/test_search.py index 30c3a30d22..0f0e7bb309 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -1521,14 +1521,6 @@ async def test_aggregations_hybrid_scoring(client): len(row) == 6 -@pytest.mark.redismod -@skip_ifmodversion_lt("2.10.05", "search") -async def test_invalid_scorer(): - - with pytest.raises(ValueError): - aggregations.AggregateRequest("*").scorer("blah") - - @pytest.mark.redismod @skip_ifmodversion_lt("2.0.0", "search") def test_index_definition(client): From 63c691645dfdb043997d04e95e43537d64683a8b Mon Sep 17 00:00:00 2001 From: Robert Shelton Date: Mon, 31 Mar 2025 16:47:47 -0400 Subject: [PATCH 7/8] remove unneded file --- docs/examples/search_vector_similarity_examples.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/examples/search_vector_similarity_examples.ipynb b/docs/examples/search_vector_similarity_examples.ipynb index 809dbda4ea..dc8cf278cc 100644 --- a/docs/examples/search_vector_similarity_examples.ipynb +++ b/docs/examples/search_vector_similarity_examples.ipynb @@ -638,7 +638,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Find more example apps, tutorials, and projects using Redis Vector 
Similarity Search [in this GitHub organization](https://github.com/RedisVentures)." + "Find more example apps, tutorials, and projects using Redis Vector Similarity Search [Redis AI resources](https://github.com/redis-developer/redis-ai-resources/tree/main)." ] } ], From a3bf5ef0664b2860633a57151f11f77a4463cf95 Mon Sep 17 00:00:00 2001 From: Robert Shelton Date: Mon, 31 Mar 2025 16:49:24 -0400 Subject: [PATCH 8/8] update wording --- docs/examples/search_vector_similarity_examples.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/examples/search_vector_similarity_examples.ipynb b/docs/examples/search_vector_similarity_examples.ipynb index dc8cf278cc..af6d825129 100644 --- a/docs/examples/search_vector_similarity_examples.ipynb +++ b/docs/examples/search_vector_similarity_examples.ipynb @@ -638,7 +638,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Find more example apps, tutorials, and projects using Redis Vector Similarity Search [Redis AI resources](https://github.com/redis-developer/redis-ai-resources/tree/main)." + "To find more example apps, tutorials, and projects using Redis Vector Similarity Search, check out the [Redis AI resources repo](https://github.com/redis-developer/redis-ai-resources/tree/main)." ] } ],