1
- use pyo3:: exceptions:: { PyKeyError , PyTypeError } ;
1
+ use pyo3:: exceptions:: PyKeyError ;
2
2
use pyo3:: intern;
3
3
use pyo3:: prelude:: * ;
4
4
use pyo3:: types:: { PyDict , PyList , PyString , PyTuple , PyType } ;
5
5
6
6
use ahash:: AHashSet ;
7
7
8
- use crate :: build_tools:: { is_strict, py_err, safe_repr , schema_or_config_same, SchemaDict } ;
8
+ use crate :: build_tools:: { is_strict, py_err, schema_or_config_same, SchemaDict } ;
9
9
use crate :: errors:: { ErrorType , ValError , ValLineError , ValResult } ;
10
10
use crate :: input:: { GenericArguments , Input } ;
11
11
use crate :: lookup_key:: LookupKey ;
@@ -291,16 +291,6 @@ impl Validator for DataclassArgsValidator {
291
291
}
292
292
}
293
293
294
- fn get_name ( & self ) -> & str {
295
- & self . validator_name
296
- }
297
-
298
- fn complete ( & mut self , build_context : & BuildContext < CombinedValidator > ) -> PyResult < ( ) > {
299
- self . fields
300
- . iter_mut ( )
301
- . try_for_each ( |field| field. validator . complete ( build_context) )
302
- }
303
-
304
294
fn validate_assignment < ' s , ' data : ' s > (
305
295
& ' s self ,
306
296
py : Python < ' data > ,
@@ -354,6 +344,16 @@ impl Validator for DataclassArgsValidator {
354
344
) )
355
345
}
356
346
}
347
+
348
+ fn get_name ( & self ) -> & str {
349
+ & self . validator_name
350
+ }
351
+
352
+ fn complete ( & mut self , build_context : & BuildContext < CombinedValidator > ) -> PyResult < ( ) > {
353
+ self . fields
354
+ . iter_mut ( )
355
+ . try_for_each ( |field| field. validator . complete ( build_context) )
356
+ }
357
357
}
358
358
359
359
#[ derive( Debug , Clone ) ]
@@ -362,6 +362,7 @@ pub struct DataclassValidator {
362
362
validator : Box < CombinedValidator > ,
363
363
class : Py < PyType > ,
364
364
post_init : Option < Py < PyString > > ,
365
+ revalidate : bool ,
365
366
name : String ,
366
367
}
367
368
@@ -390,6 +391,7 @@ impl BuildValidator for DataclassValidator {
390
391
validator : Box :: new ( validator) ,
391
392
class : class. into ( ) ,
392
393
post_init,
394
+ revalidate : schema_or_config_same ( schema, config, intern ! ( py, "revalidate_instances" ) ) ?. unwrap_or ( false ) ,
393
395
// as with model, get the class's `__name__`, not using `class.name()` since it uses `__qualname__`
394
396
// which is not what we want here
395
397
name : class. getattr ( intern ! ( py, "__name__" ) ) ?. extract ( ) ?,
@@ -411,33 +413,43 @@ impl Validator for DataclassValidator {
411
413
// in the case that self_instance is Some, we're calling validation from within `BaseModel.__init__`
412
414
return self . validate_init ( py, self_instance, input, extra, slots, recursion_guard) ;
413
415
}
414
- let class = self . class . as_ref ( py) ;
415
416
416
- // we only do the is_exact_instance in strict mode
417
- // we run validation even if input is an exact class to cover the case where a vanilla dataclass has been
418
- // created with invalid types
419
- // in theory we could have a flag to skip validation for an exact type in some scenarios, but I'm not sure
420
- // that's a good idea
421
- if extra. strict . unwrap_or ( self . strict ) && !input. is_exact_instance ( class) ? {
417
+ // same logic as on models
418
+ let class = self . class . as_ref ( py) ;
419
+ if input. input_is_instance ( class, 0 ) ? {
420
+ if input. is_exact_instance ( class) || !extra. strict . unwrap_or ( self . strict ) {
421
+ if self . revalidate {
422
+ let input = input. input_get_attr ( intern ! ( py, "__dict__" ) ) . unwrap ( ) ?;
423
+ let val_output = self . validator . validate ( py, input, extra, slots, recursion_guard) ?;
424
+ let dc = create_class ( self . class . as_ref ( py) ) ?;
425
+ self . set_dict_call ( py, dc. as_ref ( py) , val_output, input) ?;
426
+ Ok ( dc)
427
+ } else {
428
+ Ok ( input. to_object ( py) )
429
+ }
430
+ } else {
431
+ Err ( ValError :: new (
432
+ ErrorType :: ModelClassType {
433
+ class_name : self . get_name ( ) . to_string ( ) ,
434
+ } ,
435
+ input,
436
+ ) )
437
+ }
438
+ } else if extra. strict . unwrap_or ( self . strict ) && input. is_python ( ) {
422
439
Err ( ValError :: new (
423
440
ErrorType :: ModelClassType {
424
441
class_name : self . get_name ( ) . to_string ( ) ,
425
442
} ,
426
443
input,
427
444
) )
428
445
} else {
429
- let input = input. maybe_subclass_dict ( class) ?;
430
446
let val_output = self . validator . validate ( py, input, extra, slots, recursion_guard) ?;
431
447
let dc = create_class ( self . class . as_ref ( py) ) ?;
432
448
self . set_dict_call ( py, dc. as_ref ( py) , val_output, input) ?;
433
449
Ok ( dc)
434
450
}
435
451
}
436
452
437
- fn get_name ( & self ) -> & str {
438
- & self . name
439
- }
440
-
441
453
fn validate_assignment < ' s , ' data : ' s > (
442
454
& ' s self ,
443
455
py : Python < ' data > ,
@@ -448,11 +460,8 @@ impl Validator for DataclassValidator {
448
460
slots : & ' data [ CombinedValidator ] ,
449
461
recursion_guard : & ' s mut RecursionGuard ,
450
462
) -> ValResult < ' data , PyObject > {
451
- let dict_attr = intern ! ( py, "__dict__" ) ;
452
- let dict: & PyDict = match obj. get_attr ( dict_attr) {
453
- Some ( v) => v. downcast ( ) ?,
454
- None => return Err ( PyTypeError :: new_err ( format ! ( "{} is not a model instance" , safe_repr( obj) ) ) . into ( ) ) ,
455
- } ;
463
+ let dict_py_str = intern ! ( py, "__dict__" ) ;
464
+ let dict: & PyDict = obj. getattr ( dict_py_str) ?. downcast ( ) ?;
456
465
457
466
let new_dict = dict. copy ( ) ?;
458
467
new_dict. set_item ( field_name, field_value) ?;
@@ -461,10 +470,14 @@ impl Validator for DataclassValidator {
461
470
self . validator
462
471
. validate_assignment ( py, new_dict, field_name, field_value, extra, slots, recursion_guard) ?;
463
472
464
- force_setattr ( py, obj, dict_attr , dc_dict) ?;
473
+ force_setattr ( py, obj, dict_py_str , dc_dict) ?;
465
474
466
475
Ok ( obj. to_object ( py) )
467
476
}
477
+
478
+ fn get_name ( & self ) -> & str {
479
+ & self . name
480
+ }
468
481
}
469
482
470
483
impl DataclassValidator {
0 commit comments