Skip to content

Commit fed7453

Browse files
committed
>> simplified ModelMetaclass.__new__()
>> fix larray-project#866 : added Parameters class
1 parent 5b0079e commit fed7453

File tree

1 file changed

+94
-28
lines changed

1 file changed

+94
-28
lines changed

larray/core/constrained.py

Lines changed: 94 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,12 @@
44

55
import numpy as np
66

7+
from types import FunctionType
78
from typing import TYPE_CHECKING, Type, Optional, Any, Union, Dict, Set, List, Tuple, Callable, no_type_check
89

910
from pydantic.fields import ModelField, Undefined
1011
from pydantic.class_validators import extract_validators, Validator
1112
from pydantic.typing import is_classvar, resolve_annotations
12-
from pydantic.types import PyObject
13-
from pydantic.utils import validate_field_name, lenient_issubclass
14-
from pydantic.main import UNTOUCHED_TYPES, inherit_config, is_valid_field
1513

1614
from larray.core.metadata import Metadata
1715
from larray.core.axis import AxisCollection
@@ -168,12 +166,11 @@ def inherit_validators(base_validators: Dict[str, List[Validator]],
168166
for base in reversed(bases):
169167
if issubclass(base, ConstrainedSession) and base != ConstrainedSession:
170168
fields.update(deepcopy(base.__fields__))
171-
config = inherit_config(base.__config__, config)
172169
validators = inherit_validators(base.__validators__, validators)
173170

174-
config = inherit_config(namespace.get('Config'), config)
175171
validators = inherit_validators(extract_validators(namespace), validators)
176172

173+
# update fields inherited from base classes
177174
for field in fields.values():
178175
field.set_config(config)
179176
extra_validators = validators.get(field.name, [])
@@ -182,22 +179,34 @@ def inherit_validators(base_validators: Dict[str, List[Validator]],
182179
# re-run prepare to add extra validators
183180
field.populate_validators()
184181

182+
untouched_types = (FunctionType, property, type, classmethod, staticmethod)
183+
184+
def validate_field_name(bases: List[Any], field_name: str) -> None:
185+
"""
186+
Ensure that the field's name does not shadow an existing attribute of the model.
187+
"""
188+
for base in bases:
189+
if getattr(base, field_name, None):
190+
raise NameError(f"Variable name '{field_name}' shadows a '{base.__name__}' attribute.")
191+
192+
# extract and build fields
185193
class_vars = set()
186194
if (namespace.get('__module__'), namespace.get('__qualname__')) != \
187195
('larray.core.constrained', 'ConstrainedSession'):
188-
annotations = resolve_annotations(namespace.get('__annotations__', {}), namespace.get('__module__', None))
189-
untouched_types = UNTOUCHED_TYPES
196+
190197
# annotation only fields need to come first in fields
198+
annotations = resolve_annotations(namespace.get('__annotations__', {}), namespace.get('__module__', None))
191199
for ann_name, ann_type in annotations.items():
192200
if is_classvar(ann_type):
193201
class_vars.add(ann_name)
194-
elif is_valid_field(ann_name):
202+
elif not ann_name.startswith('_'):
195203
validate_field_name(bases, ann_name)
196204
value = namespace.get(ann_name, Undefined)
205+
cls = getattr(ann_type, '__origin__', None)
197206
if (
198-
isinstance(value, untouched_types)
199-
and ann_type != PyObject
200-
and not lenient_issubclass(getattr(ann_type, '__origin__', None), Type)
207+
isinstance(value, untouched_types)
208+
# and ann_type != PyObject (from pydantic.types import PyObject)
209+
and not (isinstance(cls, type) and issubclass(cls, Type))
201210
):
202211
continue
203212
fields[ann_name] = ModelField.infer(
@@ -210,10 +219,12 @@ def inherit_validators(base_validators: Dict[str, List[Validator]],
210219

211220
for var_name, value in namespace.items():
212221
if (
213-
var_name not in annotations
214-
and is_valid_field(var_name)
215-
and not isinstance(value, untouched_types)
216-
and var_name not in class_vars
222+
# namespace.items() contains annotated fields with default values
223+
var_name not in annotations
224+
and not var_name.startswith('_')
225+
and not isinstance(value, untouched_types)
226+
# avoid to update a field if it was redeclared (by mistake)
227+
and var_name not in class_vars
217228
):
218229
validate_field_name(bases, var_name)
219230
inferred = ModelField.infer(
@@ -269,10 +280,6 @@ class ConstrainedSession(Session, metaclass=ModelMetaclass):
269280
* meta : list of pairs or dict or OrderedDict or Metadata, optional
270281
Metadata (title, description, author, creation_date, ...) associated with the array.
271282
Keys must be strings. Values must be of type string, int, float, date, time or datetime.
272-
* allow_extra: bool, optional
273-
Whether to allow or forbid extra variables during session initialization (and after). Defaults to True.
274-
* allow_mutation: bool, optional
275-
Whether or not variables can be modified after initialization. Defaults to True.
276283
277284
Warnings
278285
--------
@@ -283,7 +290,7 @@ class ConstrainedSession(Session, metaclass=ModelMetaclass):
283290
284291
See Also
285292
--------
286-
Session
293+
Session, Parameters
287294
288295
Examples
289296
--------
@@ -308,7 +315,7 @@ class ConstrainedSession(Session, metaclass=ModelMetaclass):
308315
... # --- declare variables with a default value ---
309316
... # The default value will be used to set the variable if no value is passed at instantiation (see below).
310317
... # Such variable will be constrained by the type deduced from its default value.
311-
... # target_age: Group = AGE[18:] >> 'adults'
318+
... target_age: Group = AGE[:2] >> '0-2'
312319
... population = zeros((AGE, GENDER, TIME), dtype=int)
313320
... # --- declare constrained arrays ---
314321
... # the constrained arrays have axes assumed to be "frozen", meaning they are
@@ -397,6 +404,7 @@ class ConstrainedSession(Session, metaclass=ModelMetaclass):
397404
dumping variant_name ... done
398405
dumping birth_rate ... done
399406
dumping births ... done
407+
dumping target_age ... done
400408
dumping mortality_rate ... done
401409
dumping deaths ... done
402410
dumping population ... done
@@ -409,18 +417,16 @@ class ConstrainedSession(Session, metaclass=ModelMetaclass):
409417
__validators__: Dict[str, List[Validator]] = {}
410418
__config__: Type[BaseConfig] = BaseConfig
411419

420+
__allow_extra__ = True
421+
__allow_mutation__ = True
422+
412423
# Warning: order of fields is not preserved.
413424
# As of v1.0 of pydantic all fields with annotations (whether annotation-only or with a default value)
414425
# will precede all fields without an annotation. Within their respective groups, fields remain in the
415426
# order they were defined.
416427
# See https://pydantic-docs.helpmanual.io/usage/models/#field-ordering
417-
# Furthermore, among fields with annotations those with default values are put after
418-
# Uses something other than `self` the first arg to allow "self" as a settable attribute
419428
def __init__(self, *args, **kwargs):
420429
meta = kwargs.pop('meta', Metadata())
421-
self.__config__.allow_extra = kwargs.pop('allow_extra', True)
422-
self.__config__.allow_mutation = kwargs.pop('allow_mutation', True)
423-
424430
Session.__init__(self, meta=meta)
425431

426432
# create an intermediate Session object to not call the __setattr__
@@ -449,11 +455,12 @@ def __init__(self, *args, **kwargs):
449455

450456
# code of the method below has been partly borrowed from pydantic.BaseModel.__setattr__()
451457
def _check_key_value(self, name: str, value: Any, skip_allow_mutation: bool, skip_validation: bool) -> Any:
452-
if not self.__config__.allow_extra and name not in self.__fields__:
458+
cls = self.__class__
459+
if not cls.__allow_extra__ and name not in self.__fields__:
453460
raise ValueError(f"Variable '{name}' is not declared in '{self.__class__.__name__}'. "
454461
f"Adding undeclared variables is forbidden. "
455462
f"List of declared variables is: {list(self.__fields__.keys())}.")
456-
if not skip_allow_mutation and not self.__config__.allow_mutation:
463+
if not skip_allow_mutation and not cls.__allow_mutation__:
457464
raise TypeError(f"Cannot change the value of the variable '{name}' since '{self.__class__.__name__}' "
458465
f"is immutable and does not support item assignment")
459466
known_field = self.__fields__.get(name, None)
@@ -492,3 +499,62 @@ def dict(self, exclude: Set[str] = None):
492499
if name in d:
493500
del d[name]
494501
return d
502+
503+
504+
class Parameters(ConstrainedSession):
505+
"""
506+
Same as py:class:`ConstrainedSession` but:
507+
508+
* declared variables cannot be modified after initialization
509+
* adding undeclared variables after initialization is forbidden.
510+
511+
Parameters
512+
----------
513+
*args : str or dict of {str: object} or iterable of tuples (str, object)
514+
Path to the file containing the session to load or
515+
list/tuple/dictionary containing couples (name, object).
516+
**kwargs : dict of {str: object}
517+
518+
* Objects to add written as name=object
519+
* meta : list of pairs or dict or OrderedDict or Metadata, optional
520+
Metadata (title, description, author, creation_date, ...) associated with the array.
521+
Keys must be strings. Values must be of type string, int, float, date, time or datetime.
522+
523+
See Also
524+
--------
525+
ConstrainedSession
526+
527+
Examples
528+
--------
529+
530+
Content of file 'parameters.py'
531+
532+
>>> from larray import *
533+
>>> class ModelParameters(Parameters):
534+
... # --- declare variables with fixed values ---
535+
... # The given values can never be changed
536+
... FIRST_YEAR = 2020
537+
... LAST_YEAR = 2030
538+
... AGE = Axis('age=0..10')
539+
... GENDER = Axis('gender=male,female')
540+
... TIME = Axis(f'time={FIRST_YEAR}..{LAST_YEAR}')
541+
... # --- declare variables with defined types ---
542+
... # Their values must be defined at initialized and will be frozen after.
543+
... variant_name: str
544+
545+
Content of file 'model.py'
546+
547+
>>> # instantiation --> create an instance of the ModelVariables class
548+
>>> # all variables declared without value must be set
549+
>>> P = ModelParameters(variant_name='variant_1')
550+
>>> # once an instance is create, its variables can be accessed but not modified
551+
>>> P.variant_name
552+
'variant_1'
553+
>>> P.variant_name = 'new_variant' # doctest: +NORMALIZE_WHITESPACE
554+
Traceback (most recent call last):
555+
...
556+
TypeError: Cannot change the value of the variable 'variant_name' since 'ModelParameters'
557+
is immutable and does not support item assignment
558+
"""
559+
__allow_extra__ = False
560+
__allow_mutation__ = False

0 commit comments

Comments
 (0)