Skip to content

Commit dd81975

Browse files
committed
PYCBC-1680: Add support for FTS Vector Search pre-filtering
Changes
=======
* Added prefilter option to `VectorQuery`
* Updated unit tests to confirm search JSON is encoded correctly

Results
=======
All tests pass

Change-Id: I5107074fea06b4486fce150dd868de0079db5804
Reviewed-on: https://review.couchbase.org/c/couchbase-python-client/+/226340
Reviewed-by: Dimitris Christodoulou <[email protected]>
Tested-by: Build Bot <[email protected]>
1 parent 56bde35 commit dd81975

File tree

3 files changed

+70
-3
lines changed

3 files changed

+70
-3
lines changed

couchbase/logic/search.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -969,6 +969,8 @@ def encode_vector_search(self) -> Optional[List[Dict[str, Any]]]:
969969
encoded_query['vector_base64'] = query.vector_base64
970970
if query.boost is not None:
971971
encoded_query['boost'] = query.boost
972+
if query.prefilter is not None:
973+
encoded_query['filter'] = query.prefilter.encodable
972974
encoded_queries.append(encoded_query)
973975

974976
return encoded_queries

couchbase/logic/vector_search.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
from __future__ import annotations
22

33
from enum import Enum
4-
from typing import (List,
4+
from typing import (TYPE_CHECKING,
5+
List,
56
Optional,
67
Union)
78

89
from couchbase._utils import is_null_or_empty
910
from couchbase.exceptions import InvalidArgumentException
1011
from couchbase.options import VectorSearchOptions
1112

13+
if TYPE_CHECKING:
14+
from couchbase.logic.search_queries import SearchQuery
15+
1216

1317
class VectorQueryCombination(Enum):
1418
""" Specifies how multiple vector searches are combined.
@@ -31,6 +35,7 @@ class VectorQuery:
3135
vector (Union[List[float], str]): The vector to use in the query.
3236
num_candidates (int, optional): Specifies the number of results returned. If provided, must be greater or equal to 1.
3337
boost (float, optional): Add boost to query.
38+
prefilter (`~couchbase.search.SearchQuery`, optional): Specifies a pre-filter to use for the vector query.
3439
3540
Raises:
3641
:class:`~couchbase.exceptions.InvalidArgumentException`: If the vector is not provided.
@@ -46,18 +51,21 @@ def __init__(self,
4651
vector, # type: Union[List[float], str]
4752
num_candidates=None, # type: Optional[int]
4853
boost=None, # type: Optional[float]
54+
prefilter=None, # type: Optional[SearchQuery]
4955
):
5056
if is_null_or_empty(field_name):
5157
raise InvalidArgumentException('Must provide a field name.')
5258
self._field_name = field_name
5359
self._vector = None
5460
self._vector_base64 = None
5561
self._validate_and_set_vector(vector)
56-
self._num_candidates = self._boost = None
62+
self._num_candidates = self._boost = self._prefilter = None
5763
if num_candidates is not None:
5864
self.num_candidates = num_candidates
5965
if boost is not None:
6066
self.boost = boost
67+
if prefilter is not None:
68+
self.prefilter = prefilter
6169

6270
@property
6371
def boost(self) -> Optional[float]:
@@ -98,6 +106,23 @@ def num_candidates(self,
98106
raise InvalidArgumentException('num_candidates must be >= 1.')
99107
self._num_candidates = value
100108

109+
@property
110+
def prefilter(self) -> Optional[SearchQuery]:
111+
"""
112+
Optional[SearchQuery]: Returns vector query's prefilter query, if it exists.
113+
"""
114+
return self._prefilter
115+
116+
@prefilter.setter
117+
def prefilter(self,
118+
value # type: SearchQuery
119+
):
120+
# avoid circular import
121+
from couchbase.logic.search_queries import SearchQuery
122+
if not isinstance(value, SearchQuery):
123+
raise InvalidArgumentException('prefilter must be a SearchQuery.')
124+
self._prefilter = value
125+
101126
@property
102127
def vector(self) -> Optional[List[float]]:
103128
"""
@@ -138,6 +163,7 @@ def create(cls,
138163
vector, # type: Union[List[float], str]
139164
num_candidates=None, # type: Optional[int]
140165
boost=None, # type: Optional[float]
166+
prefilter=None, # type: Optional[SearchQuery]
141167
) -> VectorQuery:
142168
""" Creates a :class:`~couchbase.vector_search.VectorQuery`.
143169
@@ -146,6 +172,7 @@ def create(cls,
146172
vector (Union[List[float], str]): The vector to use in the query.
147173
num_candidates (int, optional): Specifies the number of results returned. If provided, must be greater or equal to 1.
148174
boost (float, optional): Add boost to query.
175+
prefilter (`~couchbase.search.SearchQuery`, optional): Specifies a pre-filter to use for the vector query.
149176
150177
Raises:
151178
:class:`~couchbase.exceptions.InvalidArgumentException`: If the vector is not provided.
@@ -155,7 +182,7 @@ def create(cls,
155182
Returns:
156183
:class:`~couchbase.vector_search.VectorQuery`: The created vector query.
157184
""" # noqa: E501
158-
return cls(field_name, vector, num_candidates=num_candidates, boost=boost)
185+
return cls(field_name, vector, num_candidates=num_candidates, boost=boost, prefilter=prefilter)
159186

160187

161188
class VectorSearch:

couchbase/tests/search_params_t.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -955,6 +955,7 @@ class VectorSearchParamTestSuite:
955955
'test_vector_query_invalid_num_candidates',
956956
'test_vector_query_invalid_vector',
957957
'test_vector_search',
958+
'test_vector_search_with_prefilter',
958959
'test_vector_search_base64',
959960
'test_vector_search_invalid',
960961
'test_vector_search_multiple_queries'
@@ -1047,6 +1048,43 @@ def test_vector_search(self, cb_env):
10471048
encoded_q = cb_env.get_encoded_query(search_query)
10481049
assert exp_json == encoded_q
10491050

1051+
def test_vector_search_with_prefilter(self, cb_env):
1052+
exp_json = {
1053+
'query': {'match_none': None},
1054+
'index_name': cb_env.TEST_INDEX_NAME,
1055+
'metrics': True,
1056+
'show_request': False,
1057+
'vector_search': [
1058+
{
1059+
'field': 'vector_field',
1060+
'vector': self.TEST_VECTOR,
1061+
'k': 3,
1062+
'filter': {
1063+
'match': 'salty beers',
1064+
'analyzer': 'analyzer',
1065+
'boost': 1.5,
1066+
'field': 'field',
1067+
'fuzziness': 1234,
1068+
'prefix_length': 4,
1069+
'operator': 'or'
1070+
}
1071+
}
1072+
]
1073+
}
1074+
1075+
q = search.MatchQuery('salty beers', boost=1.5, analyzer='analyzer',
1076+
field='field', fuzziness=1234, prefix_length=4, match_operator=MatchOperator.OR)
1077+
vector_search = VectorSearch.from_vector_query(VectorQuery('vector_field',
1078+
self.TEST_VECTOR,
1079+
prefilter=q))
1080+
req = SearchRequest.create(vector_search)
1081+
search_query = search.SearchQueryBuilder.create_search_query_from_request(
1082+
cb_env.TEST_INDEX_NAME,
1083+
req
1084+
)
1085+
encoded_q = cb_env.get_encoded_query(search_query)
1086+
assert exp_json == encoded_q
1087+
10501088
def test_vector_search_base64(self, cb_env):
10511089
exp_json = {
10521090
'query': {'match_none': None},

0 commit comments

Comments
 (0)