4
4
5
5
import numpy as np
6
6
7
+ from types import FunctionType
7
8
from typing import TYPE_CHECKING , Type , Optional , Any , Union , Dict , Set , List , Tuple , Callable , no_type_check
8
9
9
10
from pydantic .fields import ModelField , Undefined
10
11
from pydantic .class_validators import extract_validators , Validator
11
12
from pydantic .typing import is_classvar , resolve_annotations
12
- from pydantic .types import PyObject
13
- from pydantic .utils import validate_field_name , lenient_issubclass
14
- from pydantic .main import UNTOUCHED_TYPES , inherit_config , is_valid_field
15
13
16
14
from larray .core .metadata import Metadata
17
15
from larray .core .axis import AxisCollection
@@ -168,12 +166,11 @@ def inherit_validators(base_validators: Dict[str, List[Validator]],
168
166
for base in reversed (bases ):
169
167
if issubclass (base , ConstrainedSession ) and base != ConstrainedSession :
170
168
fields .update (deepcopy (base .__fields__ ))
171
- config = inherit_config (base .__config__ , config )
172
169
validators = inherit_validators (base .__validators__ , validators )
173
170
174
- config = inherit_config (namespace .get ('Config' ), config )
175
171
validators = inherit_validators (extract_validators (namespace ), validators )
176
172
173
+ # update fields inherited from base classes
177
174
for field in fields .values ():
178
175
field .set_config (config )
179
176
extra_validators = validators .get (field .name , [])
@@ -182,22 +179,34 @@ def inherit_validators(base_validators: Dict[str, List[Validator]],
182
179
# re-run prepare to add extra validators
183
180
field .populate_validators ()
184
181
182
+ untouched_types = (FunctionType , property , type , classmethod , staticmethod )
183
+
184
+ def validate_field_name (bases : List [Any ], field_name : str ) -> None :
185
+ """
186
+ Ensure that the field's name does not shadow an existing attribute of the model.
187
+ """
188
+ for base in bases :
189
+ if getattr (base , field_name , None ):
190
+ raise NameError (f"Variable name '{ field_name } ' shadows a '{ base .__name__ } ' attribute." )
191
+
192
+ # extract and build fields
185
193
class_vars = set ()
186
194
if (namespace .get ('__module__' ), namespace .get ('__qualname__' )) != \
187
195
('larray.core.constrained' , 'ConstrainedSession' ):
188
- annotations = resolve_annotations (namespace .get ('__annotations__' , {}), namespace .get ('__module__' , None ))
189
- untouched_types = UNTOUCHED_TYPES
196
+
190
197
# annotation only fields need to come first in fields
198
+ annotations = resolve_annotations (namespace .get ('__annotations__' , {}), namespace .get ('__module__' , None ))
191
199
for ann_name , ann_type in annotations .items ():
192
200
if is_classvar (ann_type ):
193
201
class_vars .add (ann_name )
194
- elif is_valid_field ( ann_name ):
202
+ elif not ann_name . startswith ( '_' ):
195
203
validate_field_name (bases , ann_name )
196
204
value = namespace .get (ann_name , Undefined )
205
+ cls = getattr (ann_type , '__origin__' , None )
197
206
if (
198
- isinstance (value , untouched_types )
199
- and ann_type != PyObject
200
- and not lenient_issubclass ( getattr ( ann_type , '__origin__' , None ) , Type )
207
+ isinstance (value , untouched_types )
208
+ # and ann_type != PyObject (from pydantic.types import PyObject)
209
+ and not ( isinstance ( cls , type ) and issubclass ( cls , Type ) )
201
210
):
202
211
continue
203
212
fields [ann_name ] = ModelField .infer (
@@ -210,10 +219,12 @@ def inherit_validators(base_validators: Dict[str, List[Validator]],
210
219
211
220
for var_name , value in namespace .items ():
212
221
if (
213
- var_name not in annotations
214
- and is_valid_field (var_name )
215
- and not isinstance (value , untouched_types )
216
- and var_name not in class_vars
222
+ # namespace.items() contains annotated fields with default values
223
+ var_name not in annotations
224
+ and not var_name .startswith ('_' )
225
+ and not isinstance (value , untouched_types )
226
+ # avoid to update a field if it was redeclared (by mistake)
227
+ and var_name not in class_vars
217
228
):
218
229
validate_field_name (bases , var_name )
219
230
inferred = ModelField .infer (
@@ -269,10 +280,6 @@ class ConstrainedSession(Session, metaclass=ModelMetaclass):
269
280
* meta : list of pairs or dict or OrderedDict or Metadata, optional
270
281
Metadata (title, description, author, creation_date, ...) associated with the array.
271
282
Keys must be strings. Values must be of type string, int, float, date, time or datetime.
272
- * allow_extra: bool, optional
273
- Whether to allow or forbid extra variables during session initialization (and after). Defaults to True.
274
- * allow_mutation: bool, optional
275
- Whether or not variables can be modified after initialization. Defaults to True.
276
283
277
284
Warnings
278
285
--------
@@ -283,7 +290,7 @@ class ConstrainedSession(Session, metaclass=ModelMetaclass):
283
290
284
291
See Also
285
292
--------
286
- Session
293
+ Session, Parameters
287
294
288
295
Examples
289
296
--------
@@ -308,7 +315,7 @@ class ConstrainedSession(Session, metaclass=ModelMetaclass):
308
315
... # --- declare variables with a default value ---
309
316
... # The default value will be used to set the variable if no value is passed at instantiation (see below).
310
317
... # Such variable will be constrained by the type deduced from its default value.
311
- ... # target_age: Group = AGE[18: ] >> 'adults '
318
+ ... target_age: Group = AGE[:2 ] >> '0-2 '
312
319
... population = zeros((AGE, GENDER, TIME), dtype=int)
313
320
... # --- declare constrained arrays ---
314
321
... # the constrained arrays have axes assumed to be "frozen", meaning they are
@@ -397,6 +404,7 @@ class ConstrainedSession(Session, metaclass=ModelMetaclass):
397
404
dumping variant_name ... done
398
405
dumping birth_rate ... done
399
406
dumping births ... done
407
+ dumping target_age ... done
400
408
dumping mortality_rate ... done
401
409
dumping deaths ... done
402
410
dumping population ... done
@@ -409,18 +417,16 @@ class ConstrainedSession(Session, metaclass=ModelMetaclass):
409
417
__validators__ : Dict [str , List [Validator ]] = {}
410
418
__config__ : Type [BaseConfig ] = BaseConfig
411
419
420
+ __allow_extra__ = True
421
+ __allow_mutation__ = True
422
+
412
423
# Warning: order of fields is not preserved.
413
424
# As of v1.0 of pydantic all fields with annotations (whether annotation-only or with a default value)
414
425
# will precede all fields without an annotation. Within their respective groups, fields remain in the
415
426
# order they were defined.
416
427
# See https://pydantic-docs.helpmanual.io/usage/models/#field-ordering
417
- # Furthermore, among fields with annotations those with default values are put after
418
- # Uses something other than `self` the first arg to allow "self" as a settable attribute
419
428
def __init__ (self , * args , ** kwargs ):
420
429
meta = kwargs .pop ('meta' , Metadata ())
421
- self .__config__ .allow_extra = kwargs .pop ('allow_extra' , True )
422
- self .__config__ .allow_mutation = kwargs .pop ('allow_mutation' , True )
423
-
424
430
Session .__init__ (self , meta = meta )
425
431
426
432
# create an intermediate Session object to not call the __setattr__
@@ -449,11 +455,12 @@ def __init__(self, *args, **kwargs):
449
455
450
456
# code of the method below has been partly borrowed from pydantic.BaseModel.__setattr__()
451
457
def _check_key_value (self , name : str , value : Any , skip_allow_mutation : bool , skip_validation : bool ) -> Any :
452
- if not self .__config__ .allow_extra and name not in self .__fields__ :
458
+ cls = self .__class__
459
+ if not cls .__allow_extra__ and name not in self .__fields__ :
453
460
raise ValueError (f"Variable '{ name } ' is not declared in '{ self .__class__ .__name__ } '. "
454
461
f"Adding undeclared variables is forbidden. "
455
462
f"List of declared variables is: { list (self .__fields__ .keys ())} ." )
456
- if not skip_allow_mutation and not self . __config__ . allow_mutation :
463
+ if not skip_allow_mutation and not cls . __allow_mutation__ :
457
464
raise TypeError (f"Cannot change the value of the variable '{ name } ' since '{ self .__class__ .__name__ } ' "
458
465
f"is immutable and does not support item assignment" )
459
466
known_field = self .__fields__ .get (name , None )
@@ -492,3 +499,62 @@ def dict(self, exclude: Set[str] = None):
492
499
if name in d :
493
500
del d [name ]
494
501
return d
502
+
503
+
504
+ class Parameters (ConstrainedSession ):
505
+ """
506
+ Same as py:class:`ConstrainedSession` but:
507
+
508
+ * declared variables cannot be modified after initialization
509
+ * adding undeclared variables after initialization is forbidden.
510
+
511
+ Parameters
512
+ ----------
513
+ *args : str or dict of {str: object} or iterable of tuples (str, object)
514
+ Path to the file containing the session to load or
515
+ list/tuple/dictionary containing couples (name, object).
516
+ **kwargs : dict of {str: object}
517
+
518
+ * Objects to add written as name=object
519
+ * meta : list of pairs or dict or OrderedDict or Metadata, optional
520
+ Metadata (title, description, author, creation_date, ...) associated with the array.
521
+ Keys must be strings. Values must be of type string, int, float, date, time or datetime.
522
+
523
+ See Also
524
+ --------
525
+ ConstrainedSession
526
+
527
+ Examples
528
+ --------
529
+
530
+ Content of file 'parameters.py'
531
+
532
+ >>> from larray import *
533
+ >>> class ModelParameters(Parameters):
534
+ ... # --- declare variables with fixed values ---
535
+ ... # The given values can never be changed
536
+ ... FIRST_YEAR = 2020
537
+ ... LAST_YEAR = 2030
538
+ ... AGE = Axis('age=0..10')
539
+ ... GENDER = Axis('gender=male,female')
540
+ ... TIME = Axis(f'time={FIRST_YEAR}..{LAST_YEAR}')
541
+ ... # --- declare variables with defined types ---
542
+ ... # Their values must be defined at initialized and will be frozen after.
543
+ ... variant_name: str
544
+
545
+ Content of file 'model.py'
546
+
547
+ >>> # instantiation --> create an instance of the ModelVariables class
548
+ >>> # all variables declared without value must be set
549
+ >>> P = ModelParameters(variant_name='variant_1')
550
+ >>> # once an instance is create, its variables can be accessed but not modified
551
+ >>> P.variant_name
552
+ 'variant_1'
553
+ >>> P.variant_name = 'new_variant' # doctest: +NORMALIZE_WHITESPACE
554
+ Traceback (most recent call last):
555
+ ...
556
+ TypeError: Cannot change the value of the variable 'variant_name' since 'ModelParameters'
557
+ is immutable and does not support item assignment
558
+ """
559
+ __allow_extra__ = False
560
+ __allow_mutation__ = False
0 commit comments