diff --git a/elasticsearch_dsl/search.py b/elasticsearch_dsl/search.py index 13b93bd8..4ffff188 100644 --- a/elasticsearch_dsl/search.py +++ b/elasticsearch_dsl/search.py @@ -318,8 +318,9 @@ def __init__(self, **kwargs): self.aggs = AggsProxy(self) self._sort = [] - self._collapse = {} self._knn = [] + self._rank = {} + self._collapse = {} self._source = None self._highlight = {} self._highlight_opts = {} @@ -408,6 +409,7 @@ def _clone(self): s._response_class = self._response_class s._knn = [knn.copy() for knn in self._knn] + s._rank = self._rank.copy() s._collapse = self._collapse.copy() s._sort = self._sort[:] s._source = copy.copy(self._source) if self._source is not None else None @@ -451,6 +453,8 @@ def update_from_dict(self, d): self._knn = d.pop("knn") if isinstance(self._knn, dict): self._knn = [self._knn] + if "rank" in d: + self._rank = d.pop("rank") if "collapse" in d: self._collapse = d.pop("collapse") if "sort" in d: @@ -558,6 +562,27 @@ def knn( s._knn[-1]["similarity"] = similarity return s + def rank(self, rrf=None): + """ + Defines a method for combining and ranking results sets from a combination + of searches. Requires a minimum of 2 results sets. + + :arg rrf: Set to ``True`` or an options dictionary to set the rank method to reciprocal rank fusion (RRF). + + Example:: + s = Search() + s = s.query('match', content='search text') + s = s.knn(field='embedding', k=5, num_candidates=10, query_vector=vector) + s = s.rank(rrf=True) + + Note: This option is in technical preview and may change in the future. The syntax will likely change before GA. + """ + s = self._clone() + s._rank = {} + if rrf is not None and rrf is not False: + s._rank["rrf"] = {} if rrf is True else rrf + return s + def source(self, fields=None, **kwargs): """ Selectively control how the _source field is returned. @@ -747,6 +772,9 @@ def to_dict(self, count=False, **kwargs): else: d["knn"] = self._knn + if self._rank: + d["rank"] = self._rank + # count request doesn't care for sorting and other things if not count: if self.post_filter: diff --git a/tests/test_search.py b/tests/test_search.py index 3b47b821..841caa7c 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -288,6 +288,18 @@ def test_knn(): } == s.to_dict() +def test_rank(): + s = search.Search() + s.rank(rrf=False) + assert {} == s.to_dict() + + s = s.rank(rrf=True) + assert {"rank": {"rrf": {}}} == s.to_dict() + + s = s.rank(rrf={"window_size": 50, "rank_constant": 20}) + assert {"rank": {"rrf": {"window_size": 50, "rank_constant": 20}}} == s.to_dict() + + def test_sort(): s = search.Search() s = s.sort("fielda", "-fieldb")