Skip to content

Commit 601bea7

Browse files
Autogenerate aggregation responses and hits (#1932)
* typing of aggregation responses * buckets_as_dict property * typing of the `meta` attribute of hits * more minor adjustments to response types
1 parent b08dfdc commit 601bea7

14 files changed

+2084
-571
lines changed

elasticsearch_dsl/response/__init__.py

+92-6
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,18 @@
4040
from ..search_base import Request, SearchBase
4141
from ..update_by_query_base import UpdateByQueryBase
4242

43-
__all__ = ["Response", "AggResponse", "UpdateByQueryResponse", "Hit", "HitMeta"]
43+
__all__ = [
44+
"Response",
45+
"AggResponse",
46+
"UpdateByQueryResponse",
47+
"Hit",
48+
"HitMeta",
49+
"AggregateResponseType",
50+
]
4451

4552

4653
class Response(AttrDict[Any], Generic[_R]):
47-
"""An Elasticsearch response.
54+
"""An Elasticsearch search response.
4855
4956
:arg took: (required)
5057
:arg timed_out: (required)
@@ -195,21 +202,100 @@ def search_after(self) -> "SearchBase[_R]":
195202
return self._search.extra(search_after=self.hits[-1].meta.sort) # type: ignore
196203

197204

205+
AggregateResponseType = Union[
206+
"types.CardinalityAggregate",
207+
"types.HdrPercentilesAggregate",
208+
"types.HdrPercentileRanksAggregate",
209+
"types.TDigestPercentilesAggregate",
210+
"types.TDigestPercentileRanksAggregate",
211+
"types.PercentilesBucketAggregate",
212+
"types.MedianAbsoluteDeviationAggregate",
213+
"types.MinAggregate",
214+
"types.MaxAggregate",
215+
"types.SumAggregate",
216+
"types.AvgAggregate",
217+
"types.WeightedAvgAggregate",
218+
"types.ValueCountAggregate",
219+
"types.SimpleValueAggregate",
220+
"types.DerivativeAggregate",
221+
"types.BucketMetricValueAggregate",
222+
"types.StatsAggregate",
223+
"types.StatsBucketAggregate",
224+
"types.ExtendedStatsAggregate",
225+
"types.ExtendedStatsBucketAggregate",
226+
"types.GeoBoundsAggregate",
227+
"types.GeoCentroidAggregate",
228+
"types.HistogramAggregate",
229+
"types.DateHistogramAggregate",
230+
"types.AutoDateHistogramAggregate",
231+
"types.VariableWidthHistogramAggregate",
232+
"types.StringTermsAggregate",
233+
"types.LongTermsAggregate",
234+
"types.DoubleTermsAggregate",
235+
"types.UnmappedTermsAggregate",
236+
"types.LongRareTermsAggregate",
237+
"types.StringRareTermsAggregate",
238+
"types.UnmappedRareTermsAggregate",
239+
"types.MultiTermsAggregate",
240+
"types.MissingAggregate",
241+
"types.NestedAggregate",
242+
"types.ReverseNestedAggregate",
243+
"types.GlobalAggregate",
244+
"types.FilterAggregate",
245+
"types.ChildrenAggregate",
246+
"types.ParentAggregate",
247+
"types.SamplerAggregate",
248+
"types.UnmappedSamplerAggregate",
249+
"types.GeoHashGridAggregate",
250+
"types.GeoTileGridAggregate",
251+
"types.GeoHexGridAggregate",
252+
"types.RangeAggregate",
253+
"types.DateRangeAggregate",
254+
"types.GeoDistanceAggregate",
255+
"types.IpRangeAggregate",
256+
"types.IpPrefixAggregate",
257+
"types.FiltersAggregate",
258+
"types.AdjacencyMatrixAggregate",
259+
"types.SignificantLongTermsAggregate",
260+
"types.SignificantStringTermsAggregate",
261+
"types.UnmappedSignificantTermsAggregate",
262+
"types.CompositeAggregate",
263+
"types.FrequentItemSetsAggregate",
264+
"types.TimeSeriesAggregate",
265+
"types.ScriptedMetricAggregate",
266+
"types.TopHitsAggregate",
267+
"types.InferenceAggregate",
268+
"types.StringStatsAggregate",
269+
"types.BoxPlotAggregate",
270+
"types.TopMetricsAggregate",
271+
"types.TTestAggregate",
272+
"types.RateAggregate",
273+
"types.CumulativeCardinalityAggregate",
274+
"types.MatrixStatsAggregate",
275+
"types.GeoLineAggregate",
276+
]
277+
278+
198279
class AggResponse(AttrDict[Any], Generic[_R]):
280+
"""An Elasticsearch aggregation response."""
281+
199282
_meta: Dict[str, Any]
200283

201284
def __init__(self, aggs: "Agg[_R]", search: "Request[_R]", data: Dict[str, Any]):
202285
super(AttrDict, self).__setattr__("_meta", {"search": search, "aggs": aggs})
203286
super().__init__(data)
204287

205-
def __getitem__(self, attr_name: str) -> Any:
288+
def __getitem__(self, attr_name: str) -> AggregateResponseType:
206289
if attr_name in self._meta["aggs"]:
207290
# don't do self._meta['aggs'][attr_name] to avoid copying
208291
agg = self._meta["aggs"].aggs[attr_name]
209-
return agg.result(self._meta["search"], self._d_[attr_name])
210-
return super().__getitem__(attr_name)
292+
return cast(
293+
AggregateResponseType,
294+
agg.result(self._meta["search"], self._d_[attr_name]),
295+
)
296+
return super().__getitem__(attr_name) # type: ignore
211297

212-
def __iter__(self) -> Iterator["Agg"]: # type: ignore[override]
298+
def __iter__(self) -> Iterator[AggregateResponseType]: # type: ignore[override]
213299
for name in self._meta["aggs"]:
214300
yield self[name]
215301

0 commit comments

Comments
 (0)