Skip to content

Recursively call .to_dict() on objects in Search.extras/**kwargs #1458

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions elasticsearch_dsl/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from .exceptions import IllegalOperation
from .query import Bool, Q
from .response import Hit, Response
from .utils import AttrDict, DslBase
from .utils import AttrDict, DslBase, recursive_to_dict


class QueryProxy(object):
Expand Down Expand Up @@ -668,7 +668,7 @@ def to_dict(self, count=False, **kwargs):
if self._sort:
d["sort"] = self._sort

d.update(self._extra)
d.update(recursive_to_dict(self._extra))

if self._source not in (None, {}):
d["_source"] = self._source
Expand All @@ -683,7 +683,7 @@ def to_dict(self, count=False, **kwargs):
if self._script_fields:
d["script_fields"] = self._script_fields

d.update(kwargs)
d.update(recursive_to_dict(kwargs))
return d

def count(self):
Expand Down
6 changes: 3 additions & 3 deletions elasticsearch_dsl/update_by_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .query import Bool, Q
from .response import UpdateByQueryResponse
from .search import ProxyDescriptor, QueryProxy, Request
from .utils import recursive_to_dict


class UpdateByQuery(Request):
Expand Down Expand Up @@ -141,9 +142,8 @@ def to_dict(self, **kwargs):
if self._script:
d["script"] = self._script

d.update(self._extra)

d.update(kwargs)
d.update(recursive_to_dict(self._extra))
d.update(recursive_to_dict(kwargs))
return d

def execute(self):
Expand Down
16 changes: 16 additions & 0 deletions elasticsearch_dsl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,3 +566,19 @@ def merge(data, new_data, raise_on_conflict=False):
raise ValueError("Incompatible data for key %r, cannot be merged." % key)
else:
data[key] = value


def recursive_to_dict(data):
"""Recursively transform objects that potentially have .to_dict()
into dictionary literals by traversing AttrList, AttrDict, list,
tuple, and Mapping types.
"""
if isinstance(data, AttrList):
data = list(data._l_)
elif hasattr(data, "to_dict"):
data = data.to_dict()
if isinstance(data, (list, tuple)):
return type(data)(recursive_to_dict(inner) for inner in data)
elif isinstance(data, collections_abc.Mapping):
return {key: recursive_to_dict(val) for key, val in data.items()}
return data
62 changes: 62 additions & 0 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,3 +515,65 @@ def test_update_from_dict():
"indices_boost": [{"important-documents": 2}],
"_source": ["id", "name"],
} == s.to_dict()


def test_rescore_query_to_dict():
s = search.Search(index="index-name")

positive_query = Q(
"function_score",
query=Q("term", tags="a"),
script_score={"script": "_score * 1"},
)

negative_query = Q(
"function_score",
query=Q("term", tags="b"),
script_score={"script": "_score * -100"},
)

s = s.query(positive_query)
s = s.extra(
rescore={"window_size": 100, "query": {"rescore_query": negative_query}}
)
assert s.to_dict() == {
"query": {
"function_score": {
"query": {"term": {"tags": "a"}},
"functions": [{"script_score": {"script": "_score * 1"}}],
}
},
"rescore": {
"window_size": 100,
"query": {
"rescore_query": {
"function_score": {
"query": {"term": {"tags": "b"}},
"functions": [{"script_score": {"script": "_score * -100"}}],
}
}
},
},
}

assert s.to_dict(
rescore={"window_size": 10, "query": {"rescore_query": positive_query}}
) == {
"query": {
"function_score": {
"query": {"term": {"tags": "a"}},
"functions": [{"script_score": {"script": "_score * 1"}}],
}
},
"rescore": {
"window_size": 10,
"query": {
"rescore_query": {
"function_score": {
"query": {"term": {"tags": "a"}},
"functions": [{"script_score": {"script": "_score * 1"}}],
}
}
},
},
}
3 changes: 3 additions & 0 deletions tests/test_update_by_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ def test_ubq_to_dict():
ubq = UpdateByQuery(extra={"size": 5})
assert {"size": 5} == ubq.to_dict()

ubq = UpdateByQuery(extra={"extra_q": Q("term", category="conference")})
assert {"extra_q": {"term": {"category": "conference"}}} == ubq.to_dict()


def test_complex_example():
ubq = UpdateByQuery()
Expand Down
8 changes: 7 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from pytest import raises

from elasticsearch_dsl import serializer, utils
from elasticsearch_dsl import Q, serializer, utils


def test_attrdict_pickle():
Expand Down Expand Up @@ -94,3 +94,9 @@ def to_dict(self):
return 42

assert serializer.serializer.dumps(MyClass()) == "42"


def test_recursive_to_dict():
assert utils.recursive_to_dict({"k": [1, (1.0, {"v": Q("match", key="val")})]}) == {
"k": [1, (1.0, {"v": {"match": {"key": "val"}}})]
}