Skip to content

Commit e82dc8e

Browse files
committed
Added support for ADDSCORES modifier
1 parent fd0b0d3 commit e82dc8e

File tree

3 files changed

+45
-0
lines changed

3 files changed

+45
-0
lines changed

redis/commands/search/aggregation.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ def __init__(self, query: str = "*") -> None:
111111
self._verbatim = False
112112
self._cursor = []
113113
self._dialect = None
114+
self._add_scores = False
114115

115116
def load(self, *fields: List[str]) -> "AggregateRequest":
116117
"""
@@ -292,6 +293,13 @@ def with_schema(self) -> "AggregateRequest":
292293
self._with_schema = True
293294
return self
294295

296+
def add_scores(self) -> "AggregateRequest":
297+
"""
298+
If set, includes the score as an ordinary field of the row.
299+
"""
300+
self._add_scores = True
301+
return self
302+
295303
def verbatim(self) -> "AggregateRequest":
296304
self._verbatim = True
297305
return self
@@ -315,6 +323,9 @@ def build_args(self) -> List[str]:
315323
if self._verbatim:
316324
ret.append("VERBATIM")
317325

326+
if self._add_scores:
327+
ret.append("ADDSCORES")
328+
318329
if self._cursor:
319330
ret += self._cursor
320331

tests/test_asyncio/test_search.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1530,6 +1530,23 @@ async def test_withsuffixtrie(decoded_r: redis.Redis):
15301530
assert "WITHSUFFIXTRIE" in info["attributes"][0]["flags"]
15311531

15321532

1533+
@pytest.mark.redismod
1534+
@skip_ifmodversion_lt("2.10.05", "search")
1535+
async def test_aggregations_add_scores(decoded_r: redis.Redis):
1536+
assert await decoded_r.ft().create_index(
1537+
(TextField("name", sortable=True, weight=5.0), NumericField("age", sortable=True))
1538+
)
1539+
1540+
assert await decoded_r.ft().client.hset("doc1", mapping={"name": "bar", "age": "25"})
1541+
assert await decoded_r.ft().client.hset("doc2", mapping={"name": "foo", "age": "19"})
1542+
1543+
req = (aggregations.AggregateRequest("*").add_scores())
1544+
res = await decoded_r.ft().aggregate(req)
1545+
assert len(res.rows) == 2
1546+
assert res.rows[0] == ["__score", "0.2"]
1547+
assert res.rows[1] == ["__score", "0.2"]
1548+
1549+
15331550
@pytest.mark.redismod
15341551
@skip_if_redis_enterprise()
15351552
async def test_search_commands_in_pipeline(decoded_r: redis.Redis):

tests/test_search.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1440,6 +1440,23 @@ def test_aggregations_filter(client):
14401440
assert res["results"][1]["extra_attributes"] == {"age": "25"}
14411441

14421442

1443+
@pytest.mark.redismod
1444+
@skip_ifmodversion_lt("2.10.05", "search")
1445+
def test_aggregations_add_scores(client):
1446+
client.ft().create_index(
1447+
(TextField("name", sortable=True, weight=5.0), NumericField("age", sortable=True))
1448+
)
1449+
1450+
client.ft().client.hset("doc1", mapping={"name": "bar", "age": "25"})
1451+
client.ft().client.hset("doc2", mapping={"name": "foo", "age": "19"})
1452+
1453+
req = (aggregations.AggregateRequest("*").add_scores())
1454+
res = client.ft().aggregate(req)
1455+
assert len(res.rows) == 2
1456+
assert res.rows[0] == ["__score", "0.2"]
1457+
assert res.rows[1] == ["__score", "0.2"]
1458+
1459+
14431460
@pytest.mark.redismod
14441461
@skip_ifmodversion_lt("2.0.0", "search")
14451462
def test_index_definition(client):

0 commit comments

Comments
 (0)