Skip to content

Commit 25dfba9

Browse files
committed
Added support for ADDSCORES modifier
1 parent fd0b0d3 commit 25dfba9

File tree

3 files changed

+172
-125
lines changed

3 files changed

+172
-125
lines changed

redis/commands/search/aggregation.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ def __init__(self, query: str = "*") -> None:
111111
self._verbatim = False
112112
self._cursor = []
113113
self._dialect = None
114+
self._add_scores = False
114115

115116
def load(self, *fields: List[str]) -> "AggregateRequest":
116117
"""
@@ -292,6 +293,13 @@ def with_schema(self) -> "AggregateRequest":
292293
self._with_schema = True
293294
return self
294295

296+
def add_scores(self) -> "AggregateRequest":
297+
"""
298+
If set, includes the score as an ordinary field of the row.
299+
"""
300+
self._add_scores = True
301+
return self
302+
295303
def verbatim(self) -> "AggregateRequest":
296304
self._verbatim = True
297305
return self
@@ -315,6 +323,9 @@ def build_args(self) -> List[str]:
315323
if self._verbatim:
316324
ret.append("VERBATIM")
317325

326+
if self._add_scores:
327+
ret.append("ADDSCORES")
328+
318329
if self._cursor:
319330
ret += self._cursor
320331

tests/test_asyncio/test_search.py

Lines changed: 75 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -202,16 +202,16 @@ async def test_client(decoded_r: redis.Redis):
202202
# test slop and in order
203203
assert 193 == (await decoded_r.ft().search(Query("henry king"))).total
204204
assert (
205-
3
206-
== (
207-
await decoded_r.ft().search(Query("henry king").slop(0).in_order())
208-
).total
205+
3
206+
== (
207+
await decoded_r.ft().search(Query("henry king").slop(0).in_order())
208+
).total
209209
)
210210
assert (
211-
52
212-
== (
213-
await decoded_r.ft().search(Query("king henry").slop(0).in_order())
214-
).total
211+
52
212+
== (
213+
await decoded_r.ft().search(Query("king henry").slop(0).in_order())
214+
).total
215215
)
216216
assert 53 == (await decoded_r.ft().search(Query("henry king").slop(0))).total
217217
assert 167 == (await decoded_r.ft().search(Query("henry king").slop(100))).total
@@ -294,31 +294,31 @@ async def test_client(decoded_r: redis.Redis):
294294

295295
# test slop and in order
296296
assert (
297-
193 == (await decoded_r.ft().search(Query("henry king")))["total_results"]
297+
193 == (await decoded_r.ft().search(Query("henry king")))["total_results"]
298298
)
299299
assert (
300-
3
301-
== (await decoded_r.ft().search(Query("henry king").slop(0).in_order()))[
302-
"total_results"
303-
]
300+
3
301+
== (await decoded_r.ft().search(Query("henry king").slop(0).in_order()))[
302+
"total_results"
303+
]
304304
)
305305
assert (
306-
52
307-
== (await decoded_r.ft().search(Query("king henry").slop(0).in_order()))[
308-
"total_results"
309-
]
306+
52
307+
== (await decoded_r.ft().search(Query("king henry").slop(0).in_order()))[
308+
"total_results"
309+
]
310310
)
311311
assert (
312-
53
313-
== (await decoded_r.ft().search(Query("henry king").slop(0)))[
314-
"total_results"
315-
]
312+
53
313+
== (await decoded_r.ft().search(Query("henry king").slop(0)))[
314+
"total_results"
315+
]
316316
)
317317
assert (
318-
167
319-
== (await decoded_r.ft().search(Query("henry king").slop(100)))[
320-
"total_results"
321-
]
318+
167
319+
== (await decoded_r.ft().search(Query("henry king").slop(100)))[
320+
"total_results"
321+
]
322322
)
323323

324324
# test delete document
@@ -669,33 +669,33 @@ async def test_summarize(decoded_r: redis.Redis):
669669
doc = sorted((await decoded_r.ft().search(q)).docs)[0]
670670
assert "<b>Henry</b> IV" == doc.play
671671
assert (
672-
"ACT I SCENE I. London. The palace. Enter <b>KING</b> <b>HENRY</b>, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa
673-
== doc.txt
672+
"ACT I SCENE I. London. The palace. Enter <b>KING</b> <b>HENRY</b>, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa
673+
== doc.txt
674674
)
675675

676676
q = Query("king henry").paging(0, 1).summarize().highlight()
677677

678678
doc = sorted((await decoded_r.ft().search(q)).docs)[0]
679679
assert "<b>Henry</b> ... " == doc.play
680680
assert (
681-
"ACT I SCENE I. London. The palace. Enter <b>KING</b> <b>HENRY</b>, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa
682-
== doc.txt
681+
"ACT I SCENE I. London. The palace. Enter <b>KING</b> <b>HENRY</b>, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa
682+
== doc.txt
683683
)
684684
else:
685685
doc = sorted((await decoded_r.ft().search(q))["results"])[0]
686686
assert "<b>Henry</b> IV" == doc["extra_attributes"]["play"]
687687
assert (
688-
"ACT I SCENE I. London. The palace. Enter <b>KING</b> <b>HENRY</b>, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa
689-
== doc["extra_attributes"]["txt"]
688+
"ACT I SCENE I. London. The palace. Enter <b>KING</b> <b>HENRY</b>, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa
689+
== doc["extra_attributes"]["txt"]
690690
)
691691

692692
q = Query("king henry").paging(0, 1).summarize().highlight()
693693

694694
doc = sorted((await decoded_r.ft().search(q))["results"])[0]
695695
assert "<b>Henry</b> ... " == doc["extra_attributes"]["play"]
696696
assert (
697-
"ACT I SCENE I. London. The palace. Enter <b>KING</b> <b>HENRY</b>, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa
698-
== doc["extra_attributes"]["txt"]
697+
"ACT I SCENE I. London. The palace. Enter <b>KING</b> <b>HENRY</b>, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa
698+
== doc["extra_attributes"]["txt"]
699699
)
700700

701701

@@ -932,10 +932,10 @@ async def test_spell_check(decoded_r: redis.Redis):
932932
res = await decoded_r.ft().spellcheck("lorm", include="dict")
933933
assert len(res["lorm"]) == 3
934934
assert (
935-
res["lorm"][0]["suggestion"],
936-
res["lorm"][1]["suggestion"],
937-
res["lorm"][2]["suggestion"],
938-
) == ("lorem", "lore", "lorm")
935+
res["lorm"][0]["suggestion"],
936+
res["lorm"][1]["suggestion"],
937+
res["lorm"][2]["suggestion"],
938+
) == ("lorem", "lore", "lorm")
939939
assert (res["lorm"][0]["score"], res["lorm"][1]["score"]) == ("0.5", "0")
940940

941941
# test spellcheck exclude
@@ -963,9 +963,9 @@ async def test_spell_check(decoded_r: redis.Redis):
963963
assert "lore" in res["results"]["lorm"][1].keys()
964964
assert "lorm" in res["results"]["lorm"][2].keys()
965965
assert (
966-
res["results"]["lorm"][0]["lorem"],
967-
res["results"]["lorm"][1]["lore"],
968-
) == (0.5, 0)
966+
res["results"]["lorm"][0]["lorem"],
967+
res["results"]["lorm"][1]["lore"],
968+
) == (0.5, 0)
969969

970970
# test spellcheck exclude
971971
res = await decoded_r.ft().spellcheck("lorm", exclude="dict")
@@ -1034,7 +1034,8 @@ async def test_scorer(decoded_r: redis.Redis):
10341034
await decoded_r.hset(
10351035
"doc2",
10361036
mapping={
1037-
"description": "Quick alice was beginning to get very tired of sitting by her quick sister on the bank, and of having nothing to do." # noqa
1037+
"description": "Quick alice was beginning to get very tired of sitting by her quick sister on the bank, and of having nothing to do."
1038+
# noqa
10381039
},
10391040
)
10401041

@@ -1098,12 +1099,12 @@ async def test_get(decoded_r: redis.Redis):
10981099
)
10991100

11001101
assert [
1101-
["f1", "some valid content dd2", "f2", "this is sample text f2"]
1102-
] == await decoded_r.ft().get("doc2")
1102+
["f1", "some valid content dd2", "f2", "this is sample text f2"]
1103+
] == await decoded_r.ft().get("doc2")
11031104
assert [
1104-
["f1", "some valid content dd1", "f2", "this is sample text f1"],
1105-
["f1", "some valid content dd2", "f2", "this is sample text f2"],
1106-
] == await decoded_r.ft().get("doc1", "doc2")
1105+
["f1", "some valid content dd1", "f2", "this is sample text f1"],
1106+
["f1", "some valid content dd2", "f2", "this is sample text f2"],
1107+
] == await decoded_r.ft().get("doc1", "doc2")
11071108

11081109

11091110
@pytest.mark.redismod
@@ -1305,7 +1306,7 @@ async def test_aggregations_groupby(decoded_r: redis.Redis):
13051306
res = (await decoded_r.ft().aggregate(req))["results"][0]
13061307
assert res["extra_attributes"]["parent"] == "redis"
13071308
assert (
1308-
res["extra_attributes"]["__generated_aliascount_distincttitle"] == "3"
1309+
res["extra_attributes"]["__generated_aliascount_distincttitle"] == "3"
13091310
)
13101311

13111312
req = (
@@ -1317,8 +1318,8 @@ async def test_aggregations_groupby(decoded_r: redis.Redis):
13171318
res = (await decoded_r.ft().aggregate(req))["results"][0]
13181319
assert res["extra_attributes"]["parent"] == "redis"
13191320
assert (
1320-
res["extra_attributes"]["__generated_aliascount_distinctishtitle"]
1321-
== "3"
1321+
res["extra_attributes"]["__generated_aliascount_distinctishtitle"]
1322+
== "3"
13221323
)
13231324

13241325
req = (
@@ -1370,8 +1371,8 @@ async def test_aggregations_groupby(decoded_r: redis.Redis):
13701371
res = (await decoded_r.ft().aggregate(req))["results"][0]
13711372
assert res["extra_attributes"]["parent"] == "redis"
13721373
assert (
1373-
res["extra_attributes"]["__generated_aliasstddevrandom_num"]
1374-
== "3.60555127546"
1374+
res["extra_attributes"]["__generated_aliasstddevrandom_num"]
1375+
== "3.60555127546"
13751376
)
13761377

13771378
req = (
@@ -1383,8 +1384,8 @@ async def test_aggregations_groupby(decoded_r: redis.Redis):
13831384
res = (await decoded_r.ft().aggregate(req))["results"][0]
13841385
assert res["extra_attributes"]["parent"] == "redis"
13851386
assert (
1386-
res["extra_attributes"]["__generated_aliasquantilerandom_num,0.5"]
1387-
== "8"
1387+
res["extra_attributes"]["__generated_aliasquantilerandom_num,0.5"]
1388+
== "8"
13881389
)
13891390

13901391
req = (
@@ -1530,6 +1531,23 @@ async def test_withsuffixtrie(decoded_r: redis.Redis):
15301531
assert "WITHSUFFIXTRIE" in info["attributes"][0]["flags"]
15311532

15321533

1534+
@pytest.mark.redismod
1535+
@skip_ifmodversion_lt("2.10.05", "search")
1536+
async def test_aggregations_add_scores(decoded_r: redis.Redis):
1537+
assert await decoded_r.ft().create_index(
1538+
(TextField("name", sortable=True, weight=5.0), NumericField("age", sortable=True))
1539+
)
1540+
1541+
assert await decoded_r.ft().client.hset("doc1", mapping={"name": "bar", "age": "25"})
1542+
assert await decoded_r.ft().client.hset("doc2", mapping={"name": "foo", "age": "19"})
1543+
1544+
req = (aggregations.AggregateRequest("*").add_scores())
1545+
res = await decoded_r.ft().aggregate(req)
1546+
assert len(res.rows) == 2
1547+
assert res.rows[0] == ["__score", "0.2"]
1548+
assert res.rows[1] == ["__score", "0.2"]
1549+
1550+
15331551
@pytest.mark.redismod
15341552
@skip_if_redis_enterprise()
15351553
async def test_search_commands_in_pipeline(decoded_r: redis.Redis):
@@ -1554,9 +1572,9 @@ async def test_search_commands_in_pipeline(decoded_r: redis.Redis):
15541572
assert "doc2" == res[3]["results"][1]["id"]
15551573
assert res[3]["results"][0]["payload"] is None
15561574
assert (
1557-
res[3]["results"][0]["extra_attributes"]
1558-
== res[3]["results"][1]["extra_attributes"]
1559-
== {"txt": "foo bar"}
1575+
res[3]["results"][0]["extra_attributes"]
1576+
== res[3]["results"][1]["extra_attributes"]
1577+
== {"txt": "foo bar"}
15601578
)
15611579

15621580

@@ -1615,5 +1633,5 @@ async def test_binary_and_text_fields(decoded_r: redis.Redis):
16151633
), "The vectors are not equal"
16161634

16171635
assert (
1618-
docs[0]["first_name"] == mixed_data["first_name"]
1636+
docs[0]["first_name"] == mixed_data["first_name"]
16191637
), "The text field is not decoded correctly"

0 commit comments

Comments
 (0)