Skip to content

Commit ec3680b

Browse files
authored
Vector similarity search support (#1986)
1 parent 9680353 commit ec3680b

File tree

2 files changed

+88
-7
lines changed

2 files changed

+88
-7
lines changed

redis/commands/search/commands.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import itertools
22
import time
3+
from typing import Dict, Union
34

45
from ..helpers import parse_to_dict
56
from ._util import to_string
@@ -377,7 +378,17 @@ def info(self):
377378
it = map(to_string, res)
378379
return dict(zip(it, it))
379380

380-
def _mk_query_args(self, query):
381+
def get_params_args(self, query_params: Dict[str, Union[str, int, float]]):
382+
args = []
383+
if len(query_params) > 0:
384+
args.append("params")
385+
args.append(len(query_params) * 2)
386+
for key, value in query_params.items():
387+
args.append(key)
388+
args.append(value)
389+
return args
390+
391+
def _mk_query_args(self, query, query_params: Dict[str, Union[str, int, float]]):
381392
args = [self.index_name]
382393

383394
if isinstance(query, str):
@@ -387,9 +398,16 @@ def _mk_query_args(self, query):
387398
raise ValueError(f"Bad query type {type(query)}")
388399

389400
args += query.get_args()
401+
if query_params is not None:
402+
args += self.get_params_args(query_params)
403+
390404
return args, query
391405

392-
def search(self, query):
406+
def search(
407+
self,
408+
query: Union[str, Query],
409+
query_params: Dict[str, Union[str, int, float]] = None,
410+
):
393411
"""
394412
Search the index for a given query, and return a result of documents
395413
@@ -401,7 +419,7 @@ def search(self, query):
401419
402420
For more information: https://oss.redis.com/redisearch/Commands/#ftsearch
403421
""" # noqa
404-
args, query = self._mk_query_args(query)
422+
args, query = self._mk_query_args(query, query_params=query_params)
405423
st = time.time()
406424
res = self.execute_command(SEARCH_CMD, *args)
407425

@@ -413,18 +431,26 @@ def search(self, query):
413431
with_scores=query._with_scores,
414432
)
415433

416-
def explain(self, query):
434+
def explain(
435+
self,
436+
query=Union[str, Query],
437+
query_params: Dict[str, Union[str, int, float]] = None,
438+
):
417439
"""Returns the execution plan for a complex query.
418440
419441
For more information: https://oss.redis.com/redisearch/Commands/#ftexplain
420442
""" # noqa
421-
args, query_text = self._mk_query_args(query)
443+
args, query_text = self._mk_query_args(query, query_params=query_params)
422444
return self.execute_command(EXPLAIN_CMD, *args)
423445

424-
def explain_cli(self, query): # noqa
446+
def explain_cli(self, query: Union[str, Query]): # noqa
425447
raise NotImplementedError("EXPLAINCLI will not be implemented.")
426448

427-
def aggregate(self, query):
449+
def aggregate(
450+
self,
451+
query: Union[str, Query],
452+
query_params: Dict[str, Union[str, int, float]] = None,
453+
):
428454
"""
429455
Issue an aggregation query.
430456
@@ -445,6 +471,8 @@ def aggregate(self, query):
445471
cmd = [CURSOR_CMD, "READ", self.index_name] + query.build_args()
446472
else:
447473
raise ValueError("Bad query", query)
474+
if query_params is not None:
475+
cmd += self.get_params_args(query_params)
448476

449477
raw = self.execute_command(*cmd)
450478
return self._get_aggregate_result(raw, query, has_cursor)

tests/test_search.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1521,3 +1521,56 @@ def test_profile_limited(client):
15211521
)
15221522
assert det["Iterators profile"]["Type"] == "INTERSECT"
15231523
assert len(res.docs) == 3 # check also the search result
1524+
1525+
1526+
@pytest.mark.redismod
1527+
def test_text_params(modclient):
1528+
modclient.flushdb()
1529+
modclient.ft().create_index((TextField("name"),))
1530+
1531+
modclient.ft().add_document("doc1", name="Alice")
1532+
modclient.ft().add_document("doc2", name="Bob")
1533+
modclient.ft().add_document("doc3", name="Carol")
1534+
1535+
params_dict = {"name1": "Alice", "name2": "Bob"}
1536+
q = Query("@name:($name1 | $name2 )")
1537+
res = modclient.ft().search(q, query_params=params_dict)
1538+
assert 2 == res.total
1539+
assert "doc1" == res.docs[0].id
1540+
assert "doc2" == res.docs[1].id
1541+
1542+
1543+
@pytest.mark.redismod
1544+
def test_numeric_params(modclient):
1545+
modclient.flushdb()
1546+
modclient.ft().create_index((NumericField("numval"),))
1547+
1548+
modclient.ft().add_document("doc1", numval=101)
1549+
modclient.ft().add_document("doc2", numval=102)
1550+
modclient.ft().add_document("doc3", numval=103)
1551+
1552+
params_dict = {"min": 101, "max": 102}
1553+
q = Query("@numval:[$min $max]")
1554+
res = modclient.ft().search(q, query_params=params_dict)
1555+
1556+
assert 2 == res.total
1557+
assert "doc1" == res.docs[0].id
1558+
assert "doc2" == res.docs[1].id
1559+
1560+
1561+
@pytest.mark.redismod
1562+
def test_geo_params(modclient):
1563+
1564+
modclient.flushdb()
1565+
modclient.ft().create_index((GeoField("g")))
1566+
modclient.ft().add_document("doc1", g="29.69465, 34.95126")
1567+
modclient.ft().add_document("doc2", g="29.69350, 34.94737")
1568+
modclient.ft().add_document("doc3", g="29.68746, 34.94882")
1569+
1570+
params_dict = {"lat": "34.95126", "lon": "29.69465", "radius": 1000, "units": "km"}
1571+
q = Query("@g:[$lon $lat $radius $units]")
1572+
res = modclient.ft().search(q, query_params=params_dict)
1573+
assert 3 == res.total
1574+
assert "doc1" == res.docs[0].id
1575+
assert "doc2" == res.docs[1].id
1576+
assert "doc3" == res.docs[2].id

0 commit comments

Comments
 (0)