Skip to content

Commit 99b787c

Browse files
authored
Recursively call .to_dict() on objects in Search.extras/**kwargs
1 parent 5487df0 commit 99b787c

File tree

6 files changed

+94
-7
lines changed

6 files changed

+94
-7
lines changed

elasticsearch_dsl/search.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from .exceptions import IllegalOperation
3232
from .query import Bool, Q
3333
from .response import Hit, Response
34-
from .utils import AttrDict, DslBase
34+
from .utils import AttrDict, DslBase, recursive_to_dict
3535

3636

3737
class QueryProxy(object):
@@ -668,7 +668,7 @@ def to_dict(self, count=False, **kwargs):
668668
if self._sort:
669669
d["sort"] = self._sort
670670

671-
d.update(self._extra)
671+
d.update(recursive_to_dict(self._extra))
672672

673673
if self._source not in (None, {}):
674674
d["_source"] = self._source
@@ -683,7 +683,7 @@ def to_dict(self, count=False, **kwargs):
683683
if self._script_fields:
684684
d["script_fields"] = self._script_fields
685685

686-
d.update(kwargs)
686+
d.update(recursive_to_dict(kwargs))
687687
return d
688688

689689
def count(self):

elasticsearch_dsl/update_by_query.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .query import Bool, Q
2020
from .response import UpdateByQueryResponse
2121
from .search import ProxyDescriptor, QueryProxy, Request
22+
from .utils import recursive_to_dict
2223

2324

2425
class UpdateByQuery(Request):
@@ -141,9 +142,8 @@ def to_dict(self, **kwargs):
141142
if self._script:
142143
d["script"] = self._script
143144

144-
d.update(self._extra)
145-
146-
d.update(kwargs)
145+
d.update(recursive_to_dict(self._extra))
146+
d.update(recursive_to_dict(kwargs))
147147
return d
148148

149149
def execute(self):

elasticsearch_dsl/utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,3 +566,19 @@ def merge(data, new_data, raise_on_conflict=False):
566566
raise ValueError("Incompatible data for key %r, cannot be merged." % key)
567567
else:
568568
data[key] = value
569+
570+
571+
def recursive_to_dict(data):
572+
"""Recursively transform objects that potentially have .to_dict()
573+
into dictionary literals by traversing AttrList, AttrDict, list,
574+
tuple, and Mapping types.
575+
"""
576+
if isinstance(data, AttrList):
577+
data = list(data._l_)
578+
elif hasattr(data, "to_dict"):
579+
data = data.to_dict()
580+
if isinstance(data, (list, tuple)):
581+
return type(data)(recursive_to_dict(inner) for inner in data)
582+
elif isinstance(data, collections_abc.Mapping):
583+
return {key: recursive_to_dict(val) for key, val in data.items()}
584+
return data

tests/test_search.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,3 +515,65 @@ def test_update_from_dict():
515515
"indices_boost": [{"important-documents": 2}],
516516
"_source": ["id", "name"],
517517
} == s.to_dict()
518+
519+
520+
def test_rescore_query_to_dict():
521+
s = search.Search(index="index-name")
522+
523+
positive_query = Q(
524+
"function_score",
525+
query=Q("term", tags="a"),
526+
script_score={"script": "_score * 1"},
527+
)
528+
529+
negative_query = Q(
530+
"function_score",
531+
query=Q("term", tags="b"),
532+
script_score={"script": "_score * -100"},
533+
)
534+
535+
s = s.query(positive_query)
536+
s = s.extra(
537+
rescore={"window_size": 100, "query": {"rescore_query": negative_query}}
538+
)
539+
assert s.to_dict() == {
540+
"query": {
541+
"function_score": {
542+
"query": {"term": {"tags": "a"}},
543+
"functions": [{"script_score": {"script": "_score * 1"}}],
544+
}
545+
},
546+
"rescore": {
547+
"window_size": 100,
548+
"query": {
549+
"rescore_query": {
550+
"function_score": {
551+
"query": {"term": {"tags": "b"}},
552+
"functions": [{"script_score": {"script": "_score * -100"}}],
553+
}
554+
}
555+
},
556+
},
557+
}
558+
559+
assert s.to_dict(
560+
rescore={"window_size": 10, "query": {"rescore_query": positive_query}}
561+
) == {
562+
"query": {
563+
"function_score": {
564+
"query": {"term": {"tags": "a"}},
565+
"functions": [{"script_score": {"script": "_score * 1"}}],
566+
}
567+
},
568+
"rescore": {
569+
"window_size": 10,
570+
"query": {
571+
"rescore_query": {
572+
"function_score": {
573+
"query": {"term": {"tags": "a"}},
574+
"functions": [{"script_score": {"script": "_score * 1"}}],
575+
}
576+
}
577+
},
578+
},
579+
}

tests/test_update_by_query.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ def test_ubq_to_dict():
3838
ubq = UpdateByQuery(extra={"size": 5})
3939
assert {"size": 5} == ubq.to_dict()
4040

41+
ubq = UpdateByQuery(extra={"extra_q": Q("term", category="conference")})
42+
assert {"extra_q": {"term": {"category": "conference"}}} == ubq.to_dict()
43+
4144

4245
def test_complex_example():
4346
ubq = UpdateByQuery()

tests/test_utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from pytest import raises
2121

22-
from elasticsearch_dsl import serializer, utils
22+
from elasticsearch_dsl import Q, serializer, utils
2323

2424

2525
def test_attrdict_pickle():
@@ -94,3 +94,9 @@ def to_dict(self):
9494
return 42
9595

9696
assert serializer.serializer.dumps(MyClass()) == "42"
97+
98+
99+
def test_recursive_to_dict():
100+
assert utils.recursive_to_dict({"k": [1, (1.0, {"v": Q("match", key="val")})]}) == {
101+
"k": [1, (1.0, {"v": {"match": {"key": "val"}}})]
102+
}

0 commit comments

Comments
 (0)