Skip to content

Commit 956a235

Browse files
authored
Model subclass instances (#492)
* Support model subclass instances * fix models to JSON * support dataclasses * no need for from_attributes on model subclass * add abc to model_instance tests * simplify input get_attr logic * tweak * tweak serialization
1 parent 12f596d commit 956a235

File tree

9 files changed

+356
-106
lines changed

9 files changed

+356
-106
lines changed

pydantic_core/core_schema.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ class CoreConfig(TypedDict, total=False):
3131
typed_dict_total: bool # default: True
3232
# used on typed-dicts and tagged union keys
3333
from_attributes: bool
34-
revalidate_models: bool
34+
# whether instances of models and dataclasses (including subclass instances) should re-validate, default False
35+
revalidate_instances: bool
3536
# whether to validate default values during validation, default False
3637
validate_default: bool
3738
# used on typed-dicts and arguments
@@ -2562,6 +2563,7 @@ class ModelSchema(TypedDict, total=False):
25622563
cls: Required[Type[Any]]
25632564
schema: Required[CoreSchema]
25642565
post_init: str
2566+
revalidate_instances: bool
25652567
strict: bool
25662568
frozen: bool
25672569
config: CoreConfig
@@ -2575,6 +2577,7 @@ def model_schema(
25752577
schema: CoreSchema,
25762578
*,
25772579
post_init: str | None = None,
2580+
revalidate_instances: bool | None = None,
25782581
strict: bool | None = None,
25792582
frozen: bool | None = None,
25802583
config: CoreConfig | None = None,
@@ -2612,6 +2615,8 @@ class MyModel:
26122615
cls: The class to use for the model
26132616
schema: The schema to use for the model
26142617
post_init: The call after init to use for the model
2618+
revalidate_instances: whether instances of models and dataclasses (including subclass instances)
2619+
should re-validate; defaults to config.revalidate_instances, else False
26152620
strict: Whether the model is strict
26162621
frozen: Whether the model is frozen
26172622
config: The config to use for the model
@@ -2624,6 +2629,7 @@ class MyModel:
26242629
cls=cls,
26252630
schema=schema,
26262631
post_init=post_init,
2632+
revalidate_instances=revalidate_instances,
26272633
strict=strict,
26282634
frozen=frozen,
26292635
config=config,
@@ -2756,6 +2762,7 @@ class DataclassSchema(TypedDict, total=False):
27562762
cls: Required[Type[Any]]
27572763
schema: Required[CoreSchema]
27582764
post_init: bool # default: False
2765+
revalidate_instances: bool # default: False
27592766
strict: bool # default: False
27602767
ref: str
27612768
metadata: Any
@@ -2767,6 +2774,7 @@ def dataclass_schema(
27672774
schema: CoreSchema,
27682775
*,
27692776
post_init: bool | None = None,
2777+
revalidate_instances: bool | None = None,
27702778
strict: bool | None = None,
27712779
ref: str | None = None,
27722780
metadata: Any = None,
@@ -2780,6 +2788,8 @@ def dataclass_schema(
27802788
cls: The dataclass type, used to perform subclass checks
27812789
schema: The schema to use for the dataclass fields
27822790
post_init: Whether to call `__post_init__` after validation
2791+
revalidate_instances: whether instances of models and dataclasses (including subclass instances)
2792+
should re-validate; defaults to config.revalidate_instances, else False
27832793
strict: Whether to require an exact instance of `cls`
27842794
ref: See [TODO] for details
27852795
metadata: See [TODO] for details
@@ -2790,6 +2800,7 @@ def dataclass_schema(
27902800
cls=cls,
27912801
schema=schema,
27922802
post_init=post_init,
2803+
revalidate_instances=revalidate_instances,
27932804
strict=strict,
27942805
ref=ref,
27952806
metadata=metadata,

src/input/input_abstract.rs

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,24 +40,23 @@ pub trait Input<'a>: fmt::Debug + ToPyObject {
4040
fn is_none(&self) -> bool;
4141

4242
#[cfg_attr(has_no_coverage, no_coverage)]
43-
fn get_attr(&self, _name: &PyString) -> Option<&PyAny> {
43+
fn input_get_attr(&self, _name: &PyString) -> Option<PyResult<&PyAny>> {
4444
None
4545
}
4646

4747
// input_ prefix to differentiate from the function on PyAny
4848
fn input_is_instance(&self, class: &PyAny, json_mask: u8) -> PyResult<bool>;
4949

50-
fn is_exact_instance(&self, _class: &PyType) -> PyResult<bool> {
51-
Ok(false)
50+
fn is_exact_instance(&self, _class: &PyType) -> bool {
51+
false
5252
}
5353

54-
fn input_is_subclass(&self, _class: &PyType) -> PyResult<bool> {
55-
Ok(false)
54+
fn is_python(&self) -> bool {
55+
false
5656
}
5757

58-
// if the input is a subclass of `_class`, return `input.__dict__`, used on dataclasses
59-
fn maybe_subclass_dict(&self, _class: &PyType) -> PyResult<&Self> {
60-
Ok(self)
58+
fn input_is_subclass(&self, _class: &PyType) -> PyResult<bool> {
59+
Ok(false)
6160
}
6261

6362
fn input_as_url(&self) -> Option<PyUrl> {

src/input/input_python.rs

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,8 @@ impl<'a> Input<'a> for PyAny {
9898
self.is_none()
9999
}
100100

101-
fn get_attr(&self, name: &PyString) -> Option<&PyAny> {
102-
self.getattr(name).ok()
101+
fn input_get_attr(&self, name: &PyString) -> Option<PyResult<&PyAny>> {
102+
Some(self.getattr(name))
103103
}
104104

105105
fn input_is_instance(&self, class: &PyAny, _json_mask: u8) -> PyResult<bool> {
@@ -110,8 +110,12 @@ impl<'a> Input<'a> for PyAny {
110110
Ok(result == 1)
111111
}
112112

113-
fn is_exact_instance(&self, class: &PyType) -> PyResult<bool> {
114-
self.get_type().eq(class)
113+
fn is_exact_instance(&self, class: &PyType) -> bool {
114+
self.get_type().is(class)
115+
}
116+
117+
fn is_python(&self) -> bool {
118+
true
115119
}
116120

117121
fn input_is_subclass(&self, class: &PyType) -> PyResult<bool> {
@@ -121,14 +125,6 @@ impl<'a> Input<'a> for PyAny {
121125
}
122126
}
123127

124-
fn maybe_subclass_dict(&self, class: &PyType) -> PyResult<&Self> {
125-
if matches!(self.is_instance(class), Ok(true)) {
126-
self.getattr(intern!(self.py(), "__dict__"))
127-
} else {
128-
Ok(self)
129-
}
130-
}
131-
132128
fn input_as_url(&self) -> Option<PyUrl> {
133129
self.extract::<PyUrl>().ok()
134130
}

src/serializers/type_serializers/model.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ impl BuildSerializer for ModelSerializer {
4747
impl ModelSerializer {
4848
fn allow_value(&self, value: &PyAny, extra: &Extra) -> PyResult<bool> {
4949
match extra.check {
50-
SerCheck::Strict => value.get_type().eq(self.class.as_ref(value.py())),
50+
SerCheck::Strict => Ok(value.get_type().is(self.class.as_ref(value.py()))),
5151
SerCheck::Lax => value.is_instance(self.class.as_ref(value.py())),
5252
SerCheck::None => value.hasattr(intern!(value.py(), "__dict__")),
5353
}

src/validators/dataclass.rs

Lines changed: 43 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
use pyo3::exceptions::{PyKeyError, PyTypeError};
1+
use pyo3::exceptions::PyKeyError;
22
use pyo3::intern;
33
use pyo3::prelude::*;
44
use pyo3::types::{PyDict, PyList, PyString, PyTuple, PyType};
55

66
use ahash::AHashSet;
77

8-
use crate::build_tools::{is_strict, py_err, safe_repr, schema_or_config_same, SchemaDict};
8+
use crate::build_tools::{is_strict, py_err, schema_or_config_same, SchemaDict};
99
use crate::errors::{ErrorType, ValError, ValLineError, ValResult};
1010
use crate::input::{GenericArguments, Input};
1111
use crate::lookup_key::LookupKey;
@@ -291,16 +291,6 @@ impl Validator for DataclassArgsValidator {
291291
}
292292
}
293293

294-
fn get_name(&self) -> &str {
295-
&self.validator_name
296-
}
297-
298-
fn complete(&mut self, build_context: &BuildContext<CombinedValidator>) -> PyResult<()> {
299-
self.fields
300-
.iter_mut()
301-
.try_for_each(|field| field.validator.complete(build_context))
302-
}
303-
304294
fn validate_assignment<'s, 'data: 's>(
305295
&'s self,
306296
py: Python<'data>,
@@ -354,6 +344,16 @@ impl Validator for DataclassArgsValidator {
354344
))
355345
}
356346
}
347+
348+
fn get_name(&self) -> &str {
349+
&self.validator_name
350+
}
351+
352+
fn complete(&mut self, build_context: &BuildContext<CombinedValidator>) -> PyResult<()> {
353+
self.fields
354+
.iter_mut()
355+
.try_for_each(|field| field.validator.complete(build_context))
356+
}
357357
}
358358

359359
#[derive(Debug, Clone)]
@@ -362,6 +362,7 @@ pub struct DataclassValidator {
362362
validator: Box<CombinedValidator>,
363363
class: Py<PyType>,
364364
post_init: Option<Py<PyString>>,
365+
revalidate: bool,
365366
name: String,
366367
}
367368

@@ -390,6 +391,7 @@ impl BuildValidator for DataclassValidator {
390391
validator: Box::new(validator),
391392
class: class.into(),
392393
post_init,
394+
revalidate: schema_or_config_same(schema, config, intern!(py, "revalidate_instances"))?.unwrap_or(false),
393395
// as with model, get the class's `__name__`, not using `class.name()` since it uses `__qualname__`
394396
// which is not what we want here
395397
name: class.getattr(intern!(py, "__name__"))?.extract()?,
@@ -411,33 +413,43 @@ impl Validator for DataclassValidator {
411413
// in the case that self_instance is Some, we're calling validation from within `BaseModel.__init__`
412414
return self.validate_init(py, self_instance, input, extra, slots, recursion_guard);
413415
}
414-
let class = self.class.as_ref(py);
415416

416-
// we only do the is_exact_instance in strict mode
417-
// we run validation even if input is an exact class to cover the case where a vanilla dataclass has been
418-
// created with invalid types
419-
// in theory we could have a flag to skip validation for an exact type in some scenarios, but I'm not sure
420-
// that's a good idea
421-
if extra.strict.unwrap_or(self.strict) && !input.is_exact_instance(class)? {
417+
// same logic as on models
418+
let class = self.class.as_ref(py);
419+
if input.input_is_instance(class, 0)? {
420+
if input.is_exact_instance(class) || !extra.strict.unwrap_or(self.strict) {
421+
if self.revalidate {
422+
let input = input.input_get_attr(intern!(py, "__dict__")).unwrap()?;
423+
let val_output = self.validator.validate(py, input, extra, slots, recursion_guard)?;
424+
let dc = create_class(self.class.as_ref(py))?;
425+
self.set_dict_call(py, dc.as_ref(py), val_output, input)?;
426+
Ok(dc)
427+
} else {
428+
Ok(input.to_object(py))
429+
}
430+
} else {
431+
Err(ValError::new(
432+
ErrorType::ModelClassType {
433+
class_name: self.get_name().to_string(),
434+
},
435+
input,
436+
))
437+
}
438+
} else if extra.strict.unwrap_or(self.strict) && input.is_python() {
422439
Err(ValError::new(
423440
ErrorType::ModelClassType {
424441
class_name: self.get_name().to_string(),
425442
},
426443
input,
427444
))
428445
} else {
429-
let input = input.maybe_subclass_dict(class)?;
430446
let val_output = self.validator.validate(py, input, extra, slots, recursion_guard)?;
431447
let dc = create_class(self.class.as_ref(py))?;
432448
self.set_dict_call(py, dc.as_ref(py), val_output, input)?;
433449
Ok(dc)
434450
}
435451
}
436452

437-
fn get_name(&self) -> &str {
438-
&self.name
439-
}
440-
441453
fn validate_assignment<'s, 'data: 's>(
442454
&'s self,
443455
py: Python<'data>,
@@ -448,11 +460,8 @@ impl Validator for DataclassValidator {
448460
slots: &'data [CombinedValidator],
449461
recursion_guard: &'s mut RecursionGuard,
450462
) -> ValResult<'data, PyObject> {
451-
let dict_attr = intern!(py, "__dict__");
452-
let dict: &PyDict = match obj.get_attr(dict_attr) {
453-
Some(v) => v.downcast()?,
454-
None => return Err(PyTypeError::new_err(format!("{} is not a model instance", safe_repr(obj))).into()),
455-
};
463+
let dict_py_str = intern!(py, "__dict__");
464+
let dict: &PyDict = obj.getattr(dict_py_str)?.downcast()?;
456465

457466
let new_dict = dict.copy()?;
458467
new_dict.set_item(field_name, field_value)?;
@@ -461,10 +470,14 @@ impl Validator for DataclassValidator {
461470
self.validator
462471
.validate_assignment(py, new_dict, field_name, field_value, extra, slots, recursion_guard)?;
463472

464-
force_setattr(py, obj, dict_attr, dc_dict)?;
473+
force_setattr(py, obj, dict_py_str, dc_dict)?;
465474

466475
Ok(obj.to_object(py))
467476
}
477+
478+
fn get_name(&self) -> &str {
479+
&self.name
480+
}
468481
}
469482

470483
impl DataclassValidator {

0 commit comments

Comments
 (0)