Skip to content

Commit 8d772e1

Browse files
support for optional, list, and other type hints
1 parent 89bfdc8 commit 8d772e1

File tree

3 files changed

+76
-12
lines changed

3 files changed

+76
-12
lines changed

elasticsearch_dsl/document_base.py

Lines changed: 49 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,21 @@
1717

1818
from datetime import date, datetime
1919
from fnmatch import fnmatch
20+
from typing import List, Optional
2021

2122
from .exceptions import ValidationException
22-
from .field import Field, Integer, Float, Boolean, Text, Binary, Date
23+
from .field import (
24+
Binary,
25+
Boolean,
26+
Date,
27+
Field,
28+
Float,
29+
InstrumentedField,
30+
Integer,
31+
Nested,
32+
Object,
33+
Text,
34+
)
2335
from .mapping import Mapping
2436
from .utils import DOC_META_FIELDS, ObjectBase
2537

@@ -35,6 +47,11 @@ def __new__(cls, name, bases, attrs):
3547
attrs["_doc_type"] = DocumentOptions(name, bases, attrs)
3648
return super().__new__(cls, name, bases, attrs)
3749

50+
def __getattr__(cls, attr):
51+
if attr in cls._doc_type.mapping:
52+
return InstrumentedField(attr, cls._doc_type.mapping[attr])
53+
return super().__getattribute__(attr)
54+
3855

3956
class DocumentOptions:
4057
type_annotation_map = {
@@ -53,18 +70,38 @@ def __init__(self, name, bases, attrs):
5370
# create the mapping instance
5471
self.mapping = getattr(meta, "mapping", Mapping())
5572

56-
for name, type_ in attrs.get('__annotations__', {}).items():
57-
if name not in attrs:
58-
if type_ in self.type_annotation_map:
73+
annotations = attrs.get("__annotations__", {})
74+
fields = set([n for n in attrs if isinstance(attrs[n], Field)])
75+
fields.update(annotations.keys())
76+
for name in fields:
77+
if name in attrs:
78+
value = attrs[name]
79+
else:
80+
type_ = annotations[name]
81+
required = True
82+
multi = False
83+
while hasattr(type_, "__origin__"):
84+
if type_.__origin__ == Optional:
85+
required = False
86+
type_ = type_.__args__[0]
87+
elif issubclass(type_.__origin__, List):
88+
multi = True
89+
type_ = type_.__args__[0]
90+
if issubclass(type_, InnerDoc):
91+
field = Nested if multi else Object
92+
field_args = {}
93+
elif type_ in self.type_annotation_map:
5994
field, field_args = self.type_annotation_map[type_]
60-
self.mapping.field(name, field(**field_args))
61-
elif issubclass(type_, Field):
62-
self.mapping.field(name, type_())
63-
64-
# register all declared fields into the mapping
65-
for name, value in list(attrs.items()):
66-
if isinstance(value, Field):
67-
self.mapping.field(name, value)
95+
elif not issubclass(type_, Field):
96+
raise TypeError(f"Cannot map type {type_}")
97+
else:
98+
field = type_
99+
field_args = {}
100+
field_args = {"multi": multi, "required": required, **field_args}
101+
value = field(**field_args)
102+
value._name = name
103+
self.mapping.field(name, value)
104+
if name in attrs:
68105
del attrs[name]
69106

70107
# add all the mappings for meta fields

elasticsearch_dsl/field.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ def __init__(self, multi=False, required=False, *args, **kwargs):
7777
"""
7878
self._multi = multi
7979
self._required = required
80+
self._name = None
81+
self._parent = None
8082
super().__init__(*args, **kwargs)
8183

8284
def __getitem__(self, subfield):
@@ -123,6 +125,25 @@ def to_dict(self):
123125
return value
124126

125127

128+
class InstrumentedField:
129+
def __init__(self, name, field):
130+
self._name = name
131+
self._field = field
132+
133+
def __getattr__(self, attr):
134+
f = None
135+
try:
136+
f = self._field[attr]
137+
except KeyError:
138+
pass
139+
if isinstance(f, Field):
140+
return InstrumentedField(f"{self._name}.{attr}", f)
141+
return getattr(self._field, attr)
142+
143+
def __repr__(self):
144+
return self._name
145+
146+
126147
class CustomField(Field):
127148
name = "custom"
128149
_coerce = True

elasticsearch_dsl/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,12 @@ def __getattr__(self, name):
499499
return value
500500
raise
501501

502+
def __setattr__(self, name, value):
503+
if name in self.__class__._doc_type.mapping:
504+
self._d_[name] = value
505+
else:
506+
super().__setattr__(name, value)
507+
502508
def to_dict(self, skip_empty=True):
503509
out = {}
504510
for k, v in self._d_.items():

0 commit comments

Comments
 (0)