diff --git a/elasticsearch_dsl/search.py b/elasticsearch_dsl/search.py index b8323c180..761f6611e 100644 --- a/elasticsearch_dsl/search.py +++ b/elasticsearch_dsl/search.py @@ -31,7 +31,7 @@ from .exceptions import IllegalOperation from .query import Bool, Q from .response import Hit, Response -from .utils import AttrDict, DslBase +from .utils import AttrDict, DslBase, recursive_to_dict class QueryProxy(object): @@ -668,7 +668,7 @@ def to_dict(self, count=False, **kwargs): if self._sort: d["sort"] = self._sort - d.update(self._extra) + d.update(recursive_to_dict(self._extra)) if self._source not in (None, {}): d["_source"] = self._source @@ -683,7 +683,7 @@ def to_dict(self, count=False, **kwargs): if self._script_fields: d["script_fields"] = self._script_fields - d.update(kwargs) + d.update(recursive_to_dict(kwargs)) return d def count(self): diff --git a/elasticsearch_dsl/update_by_query.py b/elasticsearch_dsl/update_by_query.py index 1d257b92f..b46b482b1 100644 --- a/elasticsearch_dsl/update_by_query.py +++ b/elasticsearch_dsl/update_by_query.py @@ -19,6 +19,7 @@ from .query import Bool, Q from .response import UpdateByQueryResponse from .search import ProxyDescriptor, QueryProxy, Request +from .utils import recursive_to_dict class UpdateByQuery(Request): @@ -141,9 +142,8 @@ def to_dict(self, **kwargs): if self._script: d["script"] = self._script - d.update(self._extra) - - d.update(kwargs) + d.update(recursive_to_dict(self._extra)) + d.update(recursive_to_dict(kwargs)) return d def execute(self): diff --git a/elasticsearch_dsl/utils.py b/elasticsearch_dsl/utils.py index 50849773a..a081670e0 100644 --- a/elasticsearch_dsl/utils.py +++ b/elasticsearch_dsl/utils.py @@ -566,3 +566,19 @@ def merge(data, new_data, raise_on_conflict=False): raise ValueError("Incompatible data for key %r, cannot be merged." % key) else: data[key] = value + + +def recursive_to_dict(data): + """Recursively transform objects that potentially have .to_dict() + into dictionary literals by traversing AttrList, AttrDict, list, + tuple, and Mapping types. + """ + if isinstance(data, AttrList): + data = list(data._l_) + elif hasattr(data, "to_dict"): + data = data.to_dict() + if isinstance(data, (list, tuple)): + return type(data)(recursive_to_dict(inner) for inner in data) + elif isinstance(data, collections_abc.Mapping): + return {key: recursive_to_dict(val) for key, val in data.items()} + return data diff --git a/tests/test_search.py b/tests/test_search.py index ec2370ece..35395c53e 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -515,3 +515,65 @@ def test_update_from_dict(): "indices_boost": [{"important-documents": 2}], "_source": ["id", "name"], } == s.to_dict() + + +def test_rescore_query_to_dict(): + s = search.Search(index="index-name") + + positive_query = Q( + "function_score", + query=Q("term", tags="a"), + script_score={"script": "_score * 1"}, + ) + + negative_query = Q( + "function_score", + query=Q("term", tags="b"), + script_score={"script": "_score * -100"}, + ) + + s = s.query(positive_query) + s = s.extra( + rescore={"window_size": 100, "query": {"rescore_query": negative_query}} + ) + assert s.to_dict() == { + "query": { + "function_score": { + "query": {"term": {"tags": "a"}}, + "functions": [{"script_score": {"script": "_score * 1"}}], + } + }, + "rescore": { + "window_size": 100, + "query": { + "rescore_query": { + "function_score": { + "query": {"term": {"tags": "b"}}, + "functions": [{"script_score": {"script": "_score * -100"}}], + } + } + }, + }, + } + + assert s.to_dict( + rescore={"window_size": 10, "query": {"rescore_query": positive_query}} + ) == { + "query": { + "function_score": { + "query": {"term": {"tags": "a"}}, + "functions": [{"script_score": {"script": "_score * 1"}}], + } + }, + "rescore": { + "window_size": 10, + "query": { + "rescore_query": { + "function_score": { + "query": {"term": {"tags": "a"}}, + "functions": [{"script_score": {"script": "_score * 1"}}], + } + } + }, + }, + } diff --git a/tests/test_update_by_query.py b/tests/test_update_by_query.py index 4f2393e1c..9e4db5c70 100644 --- a/tests/test_update_by_query.py +++ b/tests/test_update_by_query.py @@ -38,6 +38,9 @@ def test_ubq_to_dict(): ubq = UpdateByQuery(extra={"size": 5}) assert {"size": 5} == ubq.to_dict() + ubq = UpdateByQuery(extra={"extra_q": Q("term", category="conference")}) + assert {"extra_q": {"term": {"category": "conference"}}} == ubq.to_dict() + def test_complex_example(): ubq = UpdateByQuery() diff --git a/tests/test_utils.py b/tests/test_utils.py index 82f88c60e..38caad45a 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -19,7 +19,7 @@ from pytest import raises -from elasticsearch_dsl import serializer, utils +from elasticsearch_dsl import Q, serializer, utils def test_attrdict_pickle(): @@ -94,3 +94,9 @@ def to_dict(self): return 42 assert serializer.serializer.dumps(MyClass()) == "42" + + +def test_recursive_to_dict(): + assert utils.recursive_to_dict({"k": [1, (1.0, {"v": Q("match", key="val")})]}) == { + "k": [1, (1.0, {"v": {"match": {"key": "val"}}})] + }