Skip to content

Commit 9f2d54d

Browse files
authored
Merge pull request #1 from rbs333/feat/iss-3408/add-scorer-aggregate
Feat/iss 3408/add scorer aggregate
2 parents 700045c + bb95f1c commit 9f2d54d

File tree

6 files changed

+354
-1
lines changed

6 files changed

+354
-1
lines changed

dev_requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ packaging>=20.4
99
pytest
1010
pytest-asyncio>=0.23.0,<0.24.0
1111
pytest-cov
12-
pytest-profiling
12+
pytest-profiling==1.7.0
1313
pytest-timeout
1414
ujson>=4.2.0
1515
uvloop

doctests/query_agg.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
# EXAMPLE: query_agg
# NOTE(review): the EXAMPLE / HIDE_* / STEP_* / REMOVE_* comment markers drive
# doc extraction — keep them byte-identical when editing this file.
# HIDE_START
import json
import redis
from redis.commands.json.path import Path
from redis.commands.search import Search
from redis.commands.search.aggregation import AggregateRequest
from redis.commands.search.field import NumericField, TagField
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
import redis.commands.search.reducers as reducers

# assumes a Redis Stack server (RediSearch + RedisJSON) on the default
# localhost port — TODO confirm against the doctest harness
r = redis.Redis(decode_responses=True)

# create index
schema = (
    TagField("$.condition", as_name="condition"),
    NumericField("$.price", as_name="price"),
)

index = r.ft("idx:bicycle")
index.create_index(
    schema,
    definition=IndexDefinition(prefix=["bicycle:"], index_type=IndexType.JSON),
)

# load data
with open("data/query_em.json") as f:
    bicycles = json.load(f)

# transaction=False: the pipeline is used only to batch the writes,
# no atomicity is required
pipeline = r.pipeline(transaction=False)
for bid, bicycle in enumerate(bicycles):
    pipeline.json().set(f'bicycle:{bid}', Path.root_path(), bicycle)
pipeline.execute()
# HIDE_END

# STEP_START agg1
search = Search(r, index_name="idx:bicycle")
# LOAD pulls document attributes into the pipeline; APPLY computes a
# per-row 10% discount from the loaded price
aggregate_request = AggregateRequest(query='@condition:{new}') \
    .load('__key', 'price') \
    .apply(discounted='@price - (@price * 0.1)')
res = search.aggregate(aggregate_request)
print(len(res.rows)) # >>> 5
print(res.rows) # >>> [['__key', 'bicycle:0', ...
#[['__key', 'bicycle:0', 'price', '270', 'discounted', '243'],
# ['__key', 'bicycle:5', 'price', '810', 'discounted', '729'],
# ['__key', 'bicycle:6', 'price', '2300', 'discounted', '2070'],
# ['__key', 'bicycle:7', 'price', '430', 'discounted', '387'],
# ['__key', 'bicycle:8', 'price', '1200', 'discounted', '1080']]
# REMOVE_START
assert len(res.rows) == 5
# REMOVE_END
# STEP_END

# STEP_START agg2
search = Search(r, index_name="idx:bicycle")
# '@price<1000' evaluates to 1/0, so SUM over it counts affordable bikes
# per condition group
aggregate_request = AggregateRequest(query='*') \
    .load('price') \
    .apply(price_category='@price<1000') \
    .group_by('@condition', reducers.sum('@price_category').alias('num_affordable'))
res = search.aggregate(aggregate_request)
print(len(res.rows)) # >>> 3
print(res.rows) # >>>
#[['condition', 'refurbished', 'num_affordable', '1'],
# ['condition', 'used', 'num_affordable', '1'],
# ['condition', 'new', 'num_affordable', '3']]
# REMOVE_START
assert len(res.rows) == 3
# REMOVE_END
# STEP_END

# STEP_START agg3
search = Search(r, index_name="idx:bicycle")
# APPLY with a constant expression tags every row, so grouping on it
# yields a single overall COUNT
aggregate_request = AggregateRequest(query='*') \
    .apply(type="'bicycle'") \
    .group_by('@type', reducers.count().alias('num_total'))
res = search.aggregate(aggregate_request)
print(len(res.rows)) # >>> 1
print(res.rows) # >>> [['type', 'bicycle', 'num_total', '10']]
# REMOVE_START
assert len(res.rows) == 1
# REMOVE_END
# STEP_END

# STEP_START agg4
search = Search(r, index_name="idx:bicycle")
# TOLIST collects the document keys belonging to each condition group
aggregate_request = AggregateRequest(query='*') \
    .load('__key') \
    .group_by('@condition', reducers.tolist('__key').alias('bicycles'))
res = search.aggregate(aggregate_request)
print(len(res.rows)) # >>> 3
print(res.rows) # >>>
#[['condition', 'refurbished', 'bicycles', ['bicycle:9']],
# ['condition', 'used', 'bicycles', ['bicycle:1', 'bicycle:2', 'bicycle:3', 'bicycle:4']],
# ['condition', 'new', 'bicycles', ['bicycle:5', 'bicycle:6', 'bicycle:7', 'bicycle:0', 'bicycle:8']]]
# REMOVE_START
assert len(res.rows) == 3
# REMOVE_END
# STEP_END

# REMOVE_START
# destroy index and data
r.ft("idx:bicycle").dropindex(delete_documents=True)
# REMOVE_END

doctests/query_combined.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
# EXAMPLE: query_combined
# NOTE(review): the EXAMPLE / HIDE_* / STEP_* / REMOVE_* comment markers drive
# doc extraction — keep them byte-identical when editing this file.
# HIDE_START
import json
import numpy as np
import redis
import warnings
from redis.commands.json.path import Path
from redis.commands.search.field import NumericField, TagField, TextField, VectorField
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
from redis.commands.search.query import Query
from sentence_transformers import SentenceTransformer


def embed_text(model, text):
    # RediSearch vector fields take raw little-endian float32 bytes
    return np.array(model.encode(text)).astype(np.float32).tobytes()

# silence a transformers deprecation warning unrelated to this example
warnings.filterwarnings("ignore", category=FutureWarning, message=r".*clean_up_tokenization_spaces.*")
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
query = "Bike for small kids"
query_vector = embed_text(model, query)

# assumes a Redis Stack server (RediSearch + RedisJSON) on the default
# localhost port — TODO confirm against the doctest harness
r = redis.Redis(decode_responses=True)

# create index
schema = (
    TextField("$.description", no_stem=True, as_name="model"),
    TagField("$.condition", as_name="condition"),
    NumericField("$.price", as_name="price"),
    VectorField(
        "$.description_embeddings",
        "FLAT",
        {
            "TYPE": "FLOAT32",
            # DIM must match the embedding size of all-MiniLM-L6-v2 (384)
            "DIM": 384,
            "DISTANCE_METRIC": "COSINE",
        },
        as_name="vector",
    ),
)

index = r.ft("idx:bicycle")
index.create_index(
    schema,
    definition=IndexDefinition(prefix=["bicycle:"], index_type=IndexType.JSON),
)

# load data
with open("data/query_vector.json") as f:
    bicycles = json.load(f)

# transaction=False: the pipeline is used only to batch the writes,
# no atomicity is required
pipeline = r.pipeline(transaction=False)
for bid, bicycle in enumerate(bicycles):
    pipeline.json().set(f'bicycle:{bid}', Path.root_path(), bicycle)
pipeline.execute()
# HIDE_END

# STEP_START combined1
# numeric range AND tag filter
q = Query("@price:[500 1000] @condition:{new}")
res = index.search(q)
print(res.total) # >>> 1
# REMOVE_START
assert res.total == 1
# REMOVE_END
# STEP_END

# STEP_START combined2
# free-text term combined with numeric and tag filters
q = Query("kids @price:[500 1000] @condition:{used}")
res = index.search(q)
print(res.total) # >>> 1
# REMOVE_START
assert res.total == 1
# REMOVE_END
# STEP_END

# STEP_START combined3
# OR of two terms, restricted to used bikes
q = Query("(kids | small) @condition:{used}")
res = index.search(q)
print(res.total) # >>> 2
# REMOVE_START
assert res.total == 2
# REMOVE_END
# STEP_END

# STEP_START combined4
# same OR, but scoped to a single text field
q = Query("@description:(kids | small) @condition:{used}")
res = index.search(q)
print(res.total) # >>> 0
# REMOVE_START
assert res.total == 0
# REMOVE_END
# STEP_END

# STEP_START combined5
# tag field with multiple accepted values
q = Query("@description:(kids | small) @condition:{new | used}")
res = index.search(q)
print(res.total) # >>> 0
# REMOVE_START
assert res.total == 0
# REMOVE_END
# STEP_END

# STEP_START combined6
# negation: in range but NOT new
q = Query("@price:[500 1000] -@condition:{new}")
res = index.search(q)
print(res.total) # >>> 2
# REMOVE_START
assert res.total == 2
# REMOVE_END
# STEP_END

# STEP_START combined7
# hybrid query: filter expression pre-filters the KNN vector search;
# DIALECT 2+ is required for the vector query syntax
q = Query("(@price:[500 1000] -@condition:{new})=>[KNN 3 @vector $query_vector]").dialect(2)
# put query string here
res = index.search(q,{ 'query_vector': query_vector })
print(res.total) # >>> 2
# REMOVE_START
assert res.total == 2
# REMOVE_END
# STEP_END

# REMOVE_START
# destroy index and data
r.ft("idx:bicycle").dropindex(delete_documents=True)
# REMOVE_END

redis/commands/search/aggregation.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def __init__(self, query: str = "*") -> None:
112112
self._cursor = []
113113
self._dialect = None
114114
self._add_scores = False
115+
self._scorer = "TFIDF"
115116

116117
def load(self, *fields: List[str]) -> "AggregateRequest":
117118
"""
@@ -300,6 +301,17 @@ def add_scores(self) -> "AggregateRequest":
300301
self._add_scores = True
301302
return self
302303

304+
def scorer(self, scorer: str) -> "AggregateRequest":
    """
    Use a different scoring function to evaluate document relevance.
    Default is `TFIDF`.

    :param scorer: The scoring function to use
        (e.g. `TFIDF.DOCNORM` or `BM25`)
    """
    # NOTE(review): per the surrounding diff hunks, __init__ defaults
    # self._scorer to "TFIDF" and build_args() emits SCORER whenever
    # self._scorer is truthy — so a SCORER clause is sent on every
    # FT.AGGREGATE even when this method is never called. Confirm that is
    # intended; it changes the arguments of all existing aggregate requests.
    self._scorer = scorer
    return self
314+
303315
def verbatim(self) -> "AggregateRequest":
304316
self._verbatim = True
305317
return self
@@ -323,6 +335,9 @@ def build_args(self) -> List[str]:
323335
if self._verbatim:
324336
ret.append("VERBATIM")
325337

338+
if self._scorer:
339+
ret.extend(["SCORER", self._scorer])
340+
326341
if self._add_scores:
327342
ret.append("ADDSCORES")
328343

@@ -332,6 +347,7 @@ def build_args(self) -> List[str]:
332347
if self._loadall:
333348
ret.append("LOAD")
334349
ret.append("*")
350+
335351
elif self._loadfields:
336352
ret.append("LOAD")
337353
ret.append(str(len(self._loadfields)))

tests/test_asyncio/test_search.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1556,6 +1556,61 @@ async def test_aggregations_add_scores(decoded_r: redis.Redis):
15561556
assert res.rows[1] == ["__score", "0.2"]
15571557

15581558

1559+
@pytest.mark.redismod
@skip_ifmodversion_lt("2.10.05", "search")
async def test_aggregations_hybrid_scoring(decoded_r: redis.Redis):
    """Hybrid (full-text + vector KNN) scoring in FT.AGGREGATE.

    Builds an index with two text fields and a 2-d HNSW vector field, inserts
    two matching docs, then runs an aggregate request that combines SCORER
    ("BM25") with ADDSCORES so @__score can be summed with the KNN distance
    into a single hybrid_score via APPLY.
    """
    assert await decoded_r.ft().create_index(
        (
            TextField("name", sortable=True, weight=5.0),
            TextField("description", sortable=True, weight=5.0),
            VectorField(
                "vector",
                "HNSW",
                {"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "COSINE"},
            ),
        )
    )

    # Two documents that both match @description:animal, each with a
    # 2-d float32 embedding stored as raw bytes.
    assert await decoded_r.hset(
        "doc1",
        mapping={
            "name": "cat book",
            "description": "an animal book about cats",
            "vector": np.array([0.1, 0.2]).astype(np.float32).tobytes(),
        },
    )
    assert await decoded_r.hset(
        "doc2",
        mapping={
            "name": "dog book",
            "description": "an animal book about dogs",
            "vector": np.array([0.2, 0.1]).astype(np.float32).tobytes(),
        },
    )

    query_string = "(@description:animal)=>[KNN 3 @vector $vec_param AS dist]"
    req = (
        aggregations.AggregateRequest(query_string)
        .scorer("BM25")
        .add_scores()
        .apply(hybrid_score="@__score + @dist")
        .load("*")
        .dialect(4)
    )

    res = await decoded_r.ft().aggregate(
        req,
        query_params={"vec_param": np.array([0.11, 0.22]).astype(np.float32).tobytes()},
    )

    # RESP3 returns a plain dict; RESP2 returns an AggregateResult.
    if isinstance(res, dict):
        assert len(res["results"]) == 2
    else:
        assert len(res.rows) == 2
        for row in res.rows:
            # BUG FIX: this was a bare `len(row) == 6` expression whose
            # result was discarded — it asserted nothing. Each row should
            # carry 3 key/value pairs (hybrid_score, __score, dist) = 6 items.
            assert len(row) == 6
1612+
1613+
15591614
@pytest.mark.redismod
15601615
@skip_if_redis_enterprise()
15611616
async def test_search_commands_in_pipeline(decoded_r: redis.Redis):

0 commit comments

Comments
 (0)