diff --git a/elasticsearch_dsl/search.py b/elasticsearch_dsl/search.py
index 724c06a52..1fa5b5906 100644
--- a/elasticsearch_dsl/search.py
+++ b/elasticsearch_dsl/search.py
@@ -120,6 +120,7 @@ def __init__(self, using="default", index=None, doc_type=None, extra=None):
 
         self._doc_type = []
         self._doc_type_map = {}
+        self._collapse = {}
         if isinstance(doc_type, (tuple, list)):
             self._doc_type.extend(doc_type)
         elif isinstance(doc_type, collections.abc.Mapping):
@@ -293,6 +294,7 @@ def _clone(self):
         s = self.__class__(
             using=self._using, index=self._index, doc_type=self._doc_type
         )
+        s._collapse = self._collapse.copy()
         s._doc_type_map = self._doc_type_map.copy()
         s._extra = self._extra.copy()
         s._params = self._params.copy()
@@ -318,6 +320,7 @@ def __init__(self, **kwargs):
 
         self.aggs = AggsProxy(self)
         self._sort = []
+        self._collapse = {}
         self._source = None
         self._highlight = {}
         self._highlight_opts = {}
@@ -568,6 +571,27 @@ def sort(self, *keys):
             s._sort.append(k)
         return s
 
+    def collapse(self, field=None, inner_hits=None, max_concurrent_group_searches=None):
+        """
+        Add collapsing information to the search request.
+        If called without providing ``field``, it will remove all collapse
+        requirements, otherwise it will replace them with the provided
+        arguments.
+        The API returns a copy of the Search object and can thus be chained.
+        """
+        s = self._clone()
+        s._collapse = {}
+
+        if field is None:
+            return s
+
+        s._collapse["field"] = field
+        if inner_hits:
+            s._collapse["inner_hits"] = inner_hits
+        if max_concurrent_group_searches:
+            s._collapse["max_concurrent_group_searches"] = max_concurrent_group_searches
+        return s
+
     def highlight_options(self, **kwargs):
         """
         Update the global highlighting options used for this request. For
@@ -663,6 +687,9 @@ def to_dict(self, count=False, **kwargs):
         if self._sort:
             d["sort"] = self._sort
 
+        if self._collapse:
+            d["collapse"] = self._collapse
+
         d.update(recursive_to_dict(self._extra))
 
         if self._source not in (None, {}):
diff --git a/tests/test_search.py b/tests/test_search.py
index 4da824182..ff1eed430 100644
--- a/tests/test_search.py
+++ b/tests/test_search.py
@@ -256,6 +256,38 @@ def test_sort_by_score():
     s.sort("-_score")
 
 
+def test_collapse():
+    s = search.Search()
+
+    inner_hits = {"name": "most_recent", "size": 5, "sort": [{"@timestamp": "desc"}]}
+    s = s.collapse("user.id", inner_hits=inner_hits, max_concurrent_group_searches=4)
+
+    assert {
+        "field": "user.id",
+        "inner_hits": {
+            "name": "most_recent",
+            "size": 5,
+            "sort": [{"@timestamp": "desc"}],
+        },
+        "max_concurrent_group_searches": 4,
+    } == s._collapse
+    assert {
+        "collapse": {
+            "field": "user.id",
+            "inner_hits": {
+                "name": "most_recent",
+                "size": 5,
+                "sort": [{"@timestamp": "desc"}],
+            },
+            "max_concurrent_group_searches": 4,
+        }
+    } == s.to_dict()
+
+    s = s.collapse()
+    assert {} == s._collapse
+    assert search.Search().to_dict() == s.to_dict()
+
+
 def test_slice():
     s = search.Search()
     assert {"from": 3, "size": 7} == s[3:10].to_dict()
@@ -305,6 +337,7 @@ def test_complex_example():
         s.query("match", title="python")
         .query(~Q("match", title="ruby"))
         .filter(Q("term", category="meetup") | Q("term", category="conference"))
+        .collapse("user_id")
        .post_filter("terms", tags=["prague", "czech"])
         .script_fields(more_attendees="doc['attendees'].value + 42")
     )
@@ -342,6 +375,7 @@ def test_complex_example():
                 "aggs": {"avg_attendees": {"avg": {"field": "attendees"}}},
             }
         },
+        "collapse": {"field": "user_id"},
         "highlight": {
             "order": "score",
             "fields": {"title": {"fragment_size": 50}, "body": {"fragment_size": 50}},