diff --git a/pydantic_core/core_schema.py b/pydantic_core/core_schema.py index 9b890da53..25ab6f5aa 100644 --- a/pydantic_core/core_schema.py +++ b/pydantic_core/core_schema.py @@ -31,7 +31,8 @@ class CoreConfig(TypedDict, total=False): typed_dict_total: bool # default: True # used on typed-dicts and tagged union keys from_attributes: bool - revalidate_models: bool + # whether instances of models and dataclasses (including subclass instances) should re-validate, default False + revalidate_instances: bool # whether to validate default values during validation, default False validate_default: bool # used on typed-dicts and arguments @@ -2562,6 +2563,7 @@ class ModelSchema(TypedDict, total=False): cls: Required[Type[Any]] schema: Required[CoreSchema] post_init: str + revalidate_instances: bool strict: bool frozen: bool config: CoreConfig @@ -2575,6 +2577,7 @@ def model_schema( schema: CoreSchema, *, post_init: str | None = None, + revalidate_instances: bool | None = None, strict: bool | None = None, frozen: bool | None = None, config: CoreConfig | None = None, @@ -2612,6 +2615,8 @@ class MyModel: cls: The class to use for the model schema: The schema to use for the model post_init: The call after init to use for the model + revalidate_instances: whether instances of models and dataclasses (including subclass instances) + should re-validate defaults to config.revalidate_instances, else False strict: Whether the model is strict frozen: Whether the model is frozen config: The config to use for the model @@ -2624,6 +2629,7 @@ class MyModel: cls=cls, schema=schema, post_init=post_init, + revalidate_instances=revalidate_instances, strict=strict, frozen=frozen, config=config, @@ -2756,6 +2762,7 @@ class DataclassSchema(TypedDict, total=False): cls: Required[Type[Any]] schema: Required[CoreSchema] post_init: bool # default: False + revalidate_instances: bool # default: False strict: bool # default: False ref: str metadata: Any @@ -2767,6 +2774,7 @@ def dataclass_schema( schema: CoreSchema, *, post_init: bool | None = None, + revalidate_instances: bool | None = None, strict: bool | None = None, ref: str | None = None, metadata: Any = None, @@ -2780,6 +2788,8 @@ def dataclass_schema( cls: The dataclass type, used to to perform subclass checks schema: The schema to use for the dataclass fields post_init: Whether to call `__post_init__` after validation + revalidate_instances: whether instances of models and dataclasses (including subclass instances) + should re-validate defaults to config.revalidate_instances, else False strict: Whether to require an exact instance of `cls` ref: See [TODO] for details metadata: See [TODO] for details @@ -2790,6 +2800,7 @@ def dataclass_schema( cls=cls, schema=schema, post_init=post_init, + revalidate_instances=revalidate_instances, strict=strict, ref=ref, metadata=metadata, diff --git a/src/input/input_abstract.rs b/src/input/input_abstract.rs index 518a50ead..ddbdff2d2 100644 --- a/src/input/input_abstract.rs +++ b/src/input/input_abstract.rs @@ -40,24 +40,23 @@ pub trait Input<'a>: fmt::Debug + ToPyObject { fn is_none(&self) -> bool; #[cfg_attr(has_no_coverage, no_coverage)] - fn get_attr(&self, _name: &PyString) -> Option<&PyAny> { + fn input_get_attr(&self, _name: &PyString) -> Option> { None } // input_ prefix to differentiate from the function on PyAny fn input_is_instance(&self, class: &PyAny, json_mask: u8) -> PyResult; - fn is_exact_instance(&self, _class: &PyType) -> PyResult { - Ok(false) + fn is_exact_instance(&self, _class: &PyType) -> bool { + false } - fn input_is_subclass(&self, _class: &PyType) -> PyResult { - Ok(false) + fn is_python(&self) -> bool { + false } - // if the input is a subclass of `_class`, return `input.__dict__`, used on dataclasses - fn maybe_subclass_dict(&self, _class: &PyType) -> PyResult<&Self> { - Ok(self) + fn input_is_subclass(&self, _class: &PyType) -> PyResult { + Ok(false) } fn input_as_url(&self) -> Option { diff --git a/src/input/input_python.rs b/src/input/input_python.rs index 60e4e7140..f1fa2ff85 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -98,8 +98,8 @@ impl<'a> Input<'a> for PyAny { self.is_none() } - fn get_attr(&self, name: &PyString) -> Option<&PyAny> { - self.getattr(name).ok() + fn input_get_attr(&self, name: &PyString) -> Option> { + Some(self.getattr(name)) } fn input_is_instance(&self, class: &PyAny, _json_mask: u8) -> PyResult { @@ -110,8 +110,12 @@ impl<'a> Input<'a> for PyAny { Ok(result == 1) } - fn is_exact_instance(&self, class: &PyType) -> PyResult { - self.get_type().eq(class) + fn is_exact_instance(&self, class: &PyType) -> bool { + self.get_type().is(class) + } + + fn is_python(&self) -> bool { + true } fn input_is_subclass(&self, class: &PyType) -> PyResult { @@ -121,14 +125,6 @@ impl<'a> Input<'a> for PyAny { } } - fn maybe_subclass_dict(&self, class: &PyType) -> PyResult<&Self> { - if matches!(self.is_instance(class), Ok(true)) { - self.getattr(intern!(self.py(), "__dict__")) - } else { - Ok(self) - } - } - fn input_as_url(&self) -> Option { self.extract::().ok() } diff --git a/src/serializers/type_serializers/model.rs b/src/serializers/type_serializers/model.rs index 0a03d0c39..ad6247e52 100644 --- a/src/serializers/type_serializers/model.rs +++ b/src/serializers/type_serializers/model.rs @@ -47,7 +47,7 @@ impl BuildSerializer for ModelSerializer { impl ModelSerializer { fn allow_value(&self, value: &PyAny, extra: &Extra) -> PyResult { match extra.check { - SerCheck::Strict => value.get_type().eq(self.class.as_ref(value.py())), + SerCheck::Strict => Ok(value.get_type().is(self.class.as_ref(value.py()))), SerCheck::Lax => value.is_instance(self.class.as_ref(value.py())), SerCheck::None => value.hasattr(intern!(value.py(), "__dict__")), } diff --git a/src/validators/dataclass.rs b/src/validators/dataclass.rs index b5bb17b72..9109e7f81 100644 --- a/src/validators/dataclass.rs +++ b/src/validators/dataclass.rs @@ -1,11 +1,11 @@ -use pyo3::exceptions::{PyKeyError, PyTypeError}; +use pyo3::exceptions::PyKeyError; use pyo3::intern; use pyo3::prelude::*; use pyo3::types::{PyDict, PyList, PyString, PyTuple, PyType}; use ahash::AHashSet; -use crate::build_tools::{is_strict, py_err, safe_repr, schema_or_config_same, SchemaDict}; +use crate::build_tools::{is_strict, py_err, schema_or_config_same, SchemaDict}; use crate::errors::{ErrorType, ValError, ValLineError, ValResult}; use crate::input::{GenericArguments, Input}; use crate::lookup_key::LookupKey; @@ -291,16 +291,6 @@ impl Validator for DataclassArgsValidator { } } - fn get_name(&self) -> &str { - &self.validator_name - } - - fn complete(&mut self, build_context: &BuildContext) -> PyResult<()> { - self.fields - .iter_mut() - .try_for_each(|field| field.validator.complete(build_context)) - } - fn validate_assignment<'s, 'data: 's>( &'s self, py: Python<'data>, @@ -354,6 +344,16 @@ impl Validator for DataclassArgsValidator { )) } } + + fn get_name(&self) -> &str { + &self.validator_name + } + + fn complete(&mut self, build_context: &BuildContext) -> PyResult<()> { + self.fields + .iter_mut() + .try_for_each(|field| field.validator.complete(build_context)) + } } #[derive(Debug, Clone)] @@ -362,6 +362,7 @@ pub struct DataclassValidator { validator: Box, class: Py, post_init: Option>, + revalidate: bool, name: String, } @@ -390,6 +391,7 @@ impl BuildValidator for DataclassValidator { validator: Box::new(validator), class: class.into(), post_init, + revalidate: schema_or_config_same(schema, config, intern!(py, "revalidate_instances"))?.unwrap_or(false), // as with model, get the class's `__name__`, not using `class.name()` since it uses `__qualname__` // which is not what we want here name: class.getattr(intern!(py, "__name__"))?.extract()?, @@ -411,14 +413,29 @@ impl Validator for DataclassValidator { // in the case that self_instance is Some, we're calling validation from within `BaseModel.__init__` return self.validate_init(py, self_instance, input, extra, slots, recursion_guard); } - let class = self.class.as_ref(py); - // we only do the is_exact_instance in strict mode - // we run validation even if input is an exact class to cover the case where a vanilla dataclass has been - // created with invalid types - // in theory we could have a flag to skip validation for an exact type in some scenarios, but I'm not sure - // that's a good idea - if extra.strict.unwrap_or(self.strict) && !input.is_exact_instance(class)? { + // same logic as on models + let class = self.class.as_ref(py); + if input.input_is_instance(class, 0)? { + if input.is_exact_instance(class) || !extra.strict.unwrap_or(self.strict) { + if self.revalidate { + let input = input.input_get_attr(intern!(py, "__dict__")).unwrap()?; + let val_output = self.validator.validate(py, input, extra, slots, recursion_guard)?; + let dc = create_class(self.class.as_ref(py))?; + self.set_dict_call(py, dc.as_ref(py), val_output, input)?; + Ok(dc) + } else { + Ok(input.to_object(py)) + } + } else { + Err(ValError::new( + ErrorType::ModelClassType { + class_name: self.get_name().to_string(), + }, + input, + )) + } + } else if extra.strict.unwrap_or(self.strict) && input.is_python() { Err(ValError::new( ErrorType::ModelClassType { class_name: self.get_name().to_string(), @@ -426,7 +443,6 @@ impl Validator for DataclassValidator { input, )) } else { - let input = input.maybe_subclass_dict(class)?; let val_output = self.validator.validate(py, input, extra, slots, recursion_guard)?; let dc = create_class(self.class.as_ref(py))?; self.set_dict_call(py, dc.as_ref(py), val_output, input)?; @@ -434,10 +450,6 @@ impl Validator for DataclassValidator { } } - fn get_name(&self) -> &str { - &self.name - } - fn validate_assignment<'s, 'data: 's>( &'s self, py: Python<'data>, @@ -448,11 +460,8 @@ impl Validator for DataclassValidator { slots: &'data [CombinedValidator], recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { - let dict_attr = intern!(py, "__dict__"); - let dict: &PyDict = match obj.get_attr(dict_attr) { - Some(v) => v.downcast()?, - None => return Err(PyTypeError::new_err(format!("{} is not a model instance", safe_repr(obj))).into()), - }; + let dict_py_str = intern!(py, "__dict__"); + let dict: &PyDict = obj.getattr(dict_py_str)?.downcast()?; let new_dict = dict.copy()?; new_dict.set_item(field_name, field_value)?; @@ -461,10 +470,14 @@ impl Validator for DataclassValidator { self.validator .validate_assignment(py, new_dict, field_name, field_value, extra, slots, recursion_guard)?; - force_setattr(py, obj, dict_attr, dc_dict)?; + force_setattr(py, obj, dict_py_str, dc_dict)?; Ok(obj.to_object(py)) } + + fn get_name(&self) -> &str { + &self.name + } } impl DataclassValidator { diff --git a/src/validators/model.rs b/src/validators/model.rs index c40fb197e..454250e18 100644 --- a/src/validators/model.rs +++ b/src/validators/model.rs @@ -7,7 +7,7 @@ use pyo3::prelude::*; use pyo3::types::{PyDict, PySet, PyString, PyTuple, PyType}; use pyo3::{ffi, intern}; -use crate::build_tools::{py_err, safe_repr, SchemaDict}; +use crate::build_tools::{py_err, schema_or_config_same, SchemaDict}; use crate::errors::{ErrorType, ValError, ValResult}; use crate::input::{py_error_on_minusone, Input}; use crate::questions::Question; @@ -50,7 +50,7 @@ impl BuildValidator for ModelValidator { // we don't use is_strict here since we don't want validation to be strict in this case if // `config.strict` is set, only if this specific field is strict strict: schema.get_as(intern!(py, "strict"))?.unwrap_or(false), - revalidate: config.get_as(intern!(py, "revalidate_models"))?.unwrap_or(false), + revalidate: schema_or_config_same(schema, config, intern!(py, "revalidate_instances"))?.unwrap_or(false), validator: Box::new(validator), class: class.into(), post_init: schema @@ -86,44 +86,62 @@ impl Validator for ModelValidator { return self.validate_init(py, self_instance, input, extra, slots, recursion_guard); } + // if we're in strict mode, we require an exact instance of the class (from python, with JSON an object is ok) + // if we're not in strict mode, instances subclasses are okay, as well as dicts, mappings, from attributes etc. + // if the input is an instance of the class, we "revalidate" it - e.g. we extract and reuse `__fields_set__` + // but use from attributes to create a new instance of the model field type let class = self.class.as_ref(py); - let instance = if input.is_exact_instance(class)? { - if self.revalidate { - let fields_set = input.get_attr(intern!(py, "__fields_set__")); - let output = self.validator.validate(py, input, extra, slots, recursion_guard)?; - if self.expect_fields_set { - let (model_dict, validation_fields_set): (&PyAny, &PyAny) = output.extract(py)?; - let fields_set = fields_set.unwrap_or(validation_fields_set); - self.create_class(model_dict, Some(fields_set))? + // mask 0 so JSON is input is never true here + if input.input_is_instance(class, 0)? { + // if the input is an exact instance OR we're not in strict mode, then progress + // which means raise ane error in the case of an instance of a subclass in strict mode + if input.is_exact_instance(class) || !extra.strict.unwrap_or(self.strict) { + if self.revalidate { + let fields_set = match input.input_get_attr(intern!(py, "__fields_set__")) { + Some(fields_set) => fields_set.ok(), + None => None, + }; + // get dict here so from_attributes logic doesn't apply + let dict = input.input_get_attr(intern!(py, "__dict__")).unwrap()?; + let output = self.validator.validate(py, dict, extra, slots, recursion_guard)?; + let instance = if self.expect_fields_set { + let (model_dict, validation_fields_set): (&PyAny, &PyAny) = output.extract(py)?; + let fields_set = fields_set.unwrap_or(validation_fields_set); + self.create_class(model_dict, Some(fields_set))? + } else { + self.create_class(output.as_ref(py), fields_set)? + }; + self.call_post_init(py, instance, input, extra) } else { - self.create_class(output.as_ref(py), fields_set)? + Ok(input.to_object(py)) } } else { - return Ok(input.to_object(py)); + Err(ValError::new( + ErrorType::ModelClassType { + class_name: self.get_name().to_string(), + }, + input, + )) } - } else if extra.strict.unwrap_or(self.strict) { - return Err(ValError::new( + } else if extra.strict.unwrap_or(self.strict) && input.is_python() { + Err(ValError::new( ErrorType::ModelClassType { class_name: self.get_name().to_string(), }, input, - )); + )) } else { let output = self.validator.validate(py, input, extra, slots, recursion_guard)?; - if self.expect_fields_set { + let instance = if self.expect_fields_set { let (model_dict, fields_set): (&PyAny, &PyAny) = output.extract(py)?; self.create_class(model_dict, Some(fields_set))? } else { self.create_class(output.as_ref(py), None)? - } - }; - if let Some(ref post_init) = self.post_init { - instance - .call_method1(py, post_init.as_ref(py), (extra.context,)) - .map_err(|e| convert_err(py, e, input))?; + }; + self.call_post_init(py, instance, input, extra) } - Ok(instance) } + fn get_name(&self) -> &str { &self.name } @@ -145,10 +163,8 @@ impl Validator for ModelValidator { if self.frozen { return Err(ValError::new(ErrorType::FrozenInstance, field_value)); } - let dict: &PyDict = match model.get_attr(intern!(py, "__dict__")) { - Some(v) => v.downcast()?, - None => return Err(PyTypeError::new_err(format!("{} is not a model instance", safe_repr(model))).into()), - }; + let dict_py_str = intern!(py, "__dict__"); + let dict: &PyDict = model.getattr(dict_py_str)?.downcast()?; let new_dict = dict.copy()?; new_dict.set_item(field_name, field_value)?; @@ -158,7 +174,7 @@ impl Validator for ModelValidator { .validate_assignment(py, new_dict, field_name, field_value, extra, slots, recursion_guard)?; let output = if self.expect_fields_set { let (output, updated_fields_set): (&PyDict, &PySet) = output.extract(py)?; - if let Some(fields_set) = model.get_attr(intern!(py, "__fields_set__")) { + if let Ok(fields_set) = model.input_get_attr(intern!(py, "__fields_set__")).unwrap() { let fields_set: &PySet = fields_set.downcast()?; for field_name in updated_fields_set { fields_set.add(field_name)?; @@ -168,7 +184,7 @@ impl Validator for ModelValidator { } else { output }; - force_setattr(py, model, intern!(py, "__dict__"), output)?; + force_setattr(py, model, dict_py_str, output)?; Ok(model.into_py(py)) } } @@ -197,12 +213,22 @@ impl ModelValidator { } else { set_model_attrs(self_instance, output.as_ref(py), None)?; }; + self.call_post_init(py, self_instance.into_py(py), input, extra) + } + + fn call_post_init<'s, 'data>( + &'s self, + py: Python<'data>, + instance: PyObject, + input: &'data impl Input<'data>, + extra: &Extra, + ) -> ValResult<'data, PyObject> { if let Some(ref post_init) = self.post_init { - self_instance - .call_method1(post_init.as_ref(py), (extra.context,)) + instance + .call_method1(py, post_init.as_ref(py), (extra.context,)) .map_err(|e| convert_err(py, e, input))?; } - Ok(self_instance.into_py(py)) + Ok(instance) } fn create_class(&self, model_dict: &PyAny, fields_set: Option<&PyAny>) -> PyResult { diff --git a/tests/benchmarks/test_micro_benchmarks.py b/tests/benchmarks/test_micro_benchmarks.py index ec66111d4..b556c57c4 100644 --- a/tests/benchmarks/test_micro_benchmarks.py +++ b/tests/benchmarks/test_micro_benchmarks.py @@ -1217,3 +1217,75 @@ def test_definition_out_of_tree(benchmark): ) values = [1, 2, 3.0, '4', '5', '6'] * 1000 benchmark(validator.validate_python, values) + + +@pytest.mark.benchmark(group='model_instance') +def test_model_instance(benchmark): + class MyModel: + __slots__ = '__dict__', '__fields_set__' + + def __init__(self, **d): + self.__dict__ = d + self.__fields_set__ = set(d) + + validator = SchemaValidator( + core_schema.model_schema( + MyModel, + core_schema.typed_dict_schema( + { + 'foo': core_schema.typed_dict_field(core_schema.int_schema()), + 'bar': core_schema.typed_dict_field(core_schema.int_schema()), + } + ), + revalidate_instances=True, + ) + ) + m1 = MyModel(foo=1, bar='2') + m2 = validator.validate_python(m1) + assert m1 is not m2 + assert m2.foo == 1 + assert m2.bar == 2 + + benchmark(validator.validate_python, m1) + + +@pytest.mark.benchmark(group='model_instance') +def test_model_instance_abc(benchmark): + import abc + + class MyMeta(abc.ABCMeta): + def __instancecheck__(self, instance) -> bool: + return hasattr(instance, '__pydantic_validator__') and super().__instancecheck__(instance) + + class BaseModel(metaclass=MyMeta): + __slots__ = '__dict__', '__fields_set__' + __pydantic_validator__ = True + + def __init__(self, **d): + self.__dict__ = d + self.__fields_set__ = set(d) + + class MyModel(BaseModel): + pass + + validator = SchemaValidator( + core_schema.model_schema( + MyModel, + core_schema.typed_dict_schema( + { + 'foo': core_schema.typed_dict_field(core_schema.int_schema()), + 'bar': core_schema.typed_dict_field(core_schema.int_schema()), + } + ), + revalidate_instances=True, + ) + ) + m1 = MyModel(foo=1, bar='2') + m2 = validator.validate_python(m1) + assert m1 is not m2 + assert m2.foo == 1 + assert m2.bar == 2 + + assert validator.isinstance_python(m1) + + benchmark(validator.validate_python, m1) diff --git a/tests/validators/test_dataclasses.py b/tests/validators/test_dataclasses.py index f8987638b..88415b0ad 100644 --- a/tests/validators/test_dataclasses.py +++ b/tests/validators/test_dataclasses.py @@ -267,16 +267,23 @@ class DuplicateDifferent: @pytest.mark.parametrize( - 'input_value,expected', + 'revalidate_instances,input_value,expected', [ - ({'a': 'hello', 'b': True}, {'a': 'hello', 'b': True}), - (FooDataclass(a='hello', b=True), {'a': 'hello', 'b': True}), - (FooDataclassSame(a='hello', b=True), {'a': 'hello', 'b': True}), - (FooDataclassMore(a='hello', b=True, c='more'), Err(r'c\s+Unexpected keyword argument')), - (DuplicateDifferent(a='hello', b=True), Err('Input should be a dictionary or an instance of FooDataclass')), + (True, {'a': 'hello', 'b': True}, {'a': 'hello', 'b': True}), + (True, FooDataclass(a='hello', b=True), {'a': 'hello', 'b': True}), + (True, FooDataclassSame(a='hello', b=True), {'a': 'hello', 'b': True}), + (True, FooDataclassMore(a='hello', b=True, c='more'), Err(r'c\s+Unexpected keyword argument')), + (True, DuplicateDifferent(a='hello', b=True), Err('should be a dictionary or an instance of FooDataclass')), + # revalidate_instances=False + (False, {'a': 'hello', 'b': True}, {'a': 'hello', 'b': True}), + (False, FooDataclass(a='hello', b=True), {'a': 'hello', 'b': True}), + (False, FooDataclassSame(a='hello', b=True), {'a': 'hello', 'b': True}), + (False, FooDataclassMore(a='hello', b=True, c='more'), {'a': 'hello', 'b': True, 'c': 'more'}), + (False, FooDataclassMore(a='hello', b='wrong', c='more'), {'a': 'hello', 'b': 'wrong', 'c': 'more'}), + (False, DuplicateDifferent(a='hello', b=True), Err('should be a dictionary or an instance of FooDataclass')), ], ) -def test_dataclass_subclass(input_value, expected): +def test_dataclass_subclass(revalidate_instances, input_value, expected): schema = core_schema.dataclass_schema( FooDataclass, core_schema.dataclass_args_schema( @@ -286,12 +293,13 @@ def test_dataclass_subclass(input_value, expected): core_schema.dataclass_field(name='b', schema=core_schema.bool_schema()), ], ), + revalidate_instances=revalidate_instances, ) v = SchemaValidator(schema) if isinstance(expected, Err): with pytest.raises(ValidationError, match=expected.message) as exc_info: - v.validate_python(input_value) + print(v.validate_python(input_value)) # debug(exc_info.value.errors()) if expected.errors is not None: @@ -398,14 +406,17 @@ def __post_init__(self, *args): @pytest.mark.parametrize( - 'input_value,expected', + 'revalidate_instances,input_value,expected', [ - ({'a': b'hello', 'b': 'true'}, {'a': 'hello', 'b': True}), - (FooDataclass(a='hello', b=True), {'a': 'hello', 'b': True}), - (FooDataclass(a=b'hello', b='true'), {'a': 'hello', 'b': True}), + (True, {'a': b'hello', 'b': 'true'}, {'a': 'hello', 'b': True}), + (True, FooDataclass(a='hello', b=True), {'a': 'hello', 'b': True}), + (True, FooDataclass(a=b'hello', b='true'), {'a': 'hello', 'b': True}), + (False, {'a': b'hello', 'b': 'true'}, {'a': 'hello', 'b': True}), + (False, FooDataclass(a='hello', b=True), {'a': 'hello', 'b': True}), + (False, FooDataclass(a=b'hello', b='true'), {'a': b'hello', 'b': 'true'}), ], ) -def test_dataclass_exact_validation(input_value, expected): +def test_dataclass_exact_validation(revalidate_instances, input_value, expected): schema = core_schema.dataclass_schema( FooDataclass, core_schema.dataclass_args_schema( @@ -415,6 +426,7 @@ def test_dataclass_exact_validation(input_value, expected): core_schema.dataclass_field(name='b', schema=core_schema.bool_schema()), ], ), + revalidate_instances=revalidate_instances, ) v = SchemaValidator(schema) @@ -714,7 +726,7 @@ def test_dataclass_validate_assignment(): ] # wrong arguments - with pytest.raises(TypeError, match="'field_a' is not a model instance"): + with pytest.raises(AttributeError, match="'str' object has no attribute '__dict__'"): v.validate_assignment('field_a', 'c', 123) diff --git a/tests/validators/test_model.py b/tests/validators/test_model.py index d1a627c07..9377a7c88 100644 --- a/tests/validators/test_model.py +++ b/tests/validators/test_model.py @@ -3,6 +3,7 @@ from typing import Any, Callable, Dict, List, Set, Tuple import pytest +from dirty_equals import HasRepr, IsStr from pydantic_core import SchemaError, SchemaValidator, ValidationError, core_schema @@ -442,6 +443,8 @@ def __init__(self): def test_model_class_instance_subclass(): + post_init_calls = [] + class MyModel: __slots__ = '__dict__', '__fields_set__' field_a: str @@ -449,6 +452,9 @@ class MyModel: def __init__(self): self.field_a = 'init_a' + def model_post_init(self, context): + post_init_calls.append(context) + class MySubModel(MyModel): field_b: str @@ -465,16 +471,75 @@ def __init__(self): 'return_fields_set': True, 'fields': {'field_a': {'type': 'typed-dict-field', 'schema': {'type': 'str'}}}, }, - 'config': {'from_attributes': True}, + 'post_init': 'model_post_init', } ) m2 = MySubModel() assert m2.field_a - m3 = v.validate_python(m2) - assert m2 != m3 + m3 = v.validate_python(m2, context='call1') + assert m2 is m3 + assert m3.field_a == 'init_a' + assert m3.field_b == 'init_b' + assert post_init_calls == [] + + m4 = v.validate_python({'field_a': b'hello'}, context='call2') + assert isinstance(m4, MyModel) + assert m4.field_a == 'hello' + assert m4.__fields_set__ == {'field_a'} + assert post_init_calls == ['call2'] + + +def test_model_class_instance_subclass_revalidate(): + post_init_calls = [] + + class MyModel: + __slots__ = '__dict__', '__fields_set__' + field_a: str + + def __init__(self): + self.field_a = 'init_a' + + def model_post_init(self, context): + post_init_calls.append(context) + + class MySubModel(MyModel): + field_b: str + + def __init__(self): + super().__init__() + self.field_b = 'init_b' + + v = SchemaValidator( + { + 'type': 'model', + 'cls': MyModel, + 'schema': { + 'type': 'typed-dict', + 'return_fields_set': True, + 'fields': {'field_a': {'type': 'typed-dict-field', 'schema': {'type': 'str'}}}, + }, + 'post_init': 'model_post_init', + 'revalidate_instances': True, + } + ) + + m2 = MySubModel() + assert m2.field_a + m3 = v.validate_python(m2, context='call1') + assert m2 is not m3 assert m3.field_a == 'init_a' assert not hasattr(m3, 'field_b') + assert post_init_calls == ['call1'] + + m4 = MySubModel() + m4.__fields_set__ = {'fruit_loop'} + m5 = v.validate_python(m4, context='call2') + assert m4 is not m5 + assert m5.__fields_set__ == {'fruit_loop'} + assert m5.field_a == 'init_a' + assert not hasattr(m5, 'field_b') + assert post_init_calls == ['call1', 'call2'] def test_model_class_strict(): @@ -506,7 +571,7 @@ def __init__(self): assert m.field_a == 'init_a' # note that since dict validation was not run here, there has been no check this is an int assert m.field_b == 'init_b' - with pytest.raises(ValidationError) as exc_info: + with pytest.raises(ValidationError, match='^1 validation error for MyModel\n') as exc_info: v.validate_python({'field_a': 'test', 'field_b': 12}) assert exc_info.value.errors() == [ { @@ -519,6 +584,62 @@ def __init__(self): ] assert str(exc_info.value).startswith('1 validation error for MyModel\n') + class MySubModel(MyModel): + field_c: str + + def __init__(self): + super().__init__() + self.field_c = 'init_c' + + # instances of subclasses are not supported in strict mode + m3 = MySubModel() + with pytest.raises(ValidationError, match='^1 validation error for MyModel\n') as exc_info: + v.validate_python(m3) + # insert_assert(exc_info.value.errors()) + assert exc_info.value.errors() == [ + { + 'type': 'model_class_type', + 'loc': (), + 'msg': 'Input should be an instance of MyModel', + 'input': HasRepr(IsStr(regex='.+MySubModel object at.+')), + 'ctx': {'class_name': 'MyModel'}, + } + ] + + +def test_model_class_strict_json(): + class MyModel: + __slots__ = '__dict__', '__fields_set__' + field_a: str + field_b: int + field_c: int + + v = SchemaValidator( + { + 'type': 'model', + 'strict': True, + 'cls': MyModel, + 'schema': { + 'type': 'typed-dict', + 'return_fields_set': True, + 'fields': { + 'field_a': {'type': 'typed-dict-field', 'schema': {'type': 'str'}}, + 'field_b': {'type': 'typed-dict-field', 'schema': {'type': 'int'}}, + 'field_c': { + 'type': 'typed-dict-field', + 'schema': {'type': 'default', 'default': 42, 'schema': {'type': 'int'}}, + }, + }, + }, + } + ) + m = v.validate_json('{"field_a": "foobar", "field_b": "123"}') + assert isinstance(m, MyModel) + assert m.field_a == 'foobar' + assert m.field_b == 123 + assert m.field_c == 42 + assert m.__fields_set__ == {'field_a', 'field_b'} + def test_internal_error(): v = SchemaValidator( @@ -550,6 +671,7 @@ def __init__(self, a, b, fields_set): { 'type': 'model', 'cls': MyModel, + 'revalidate_instances': True, 'schema': { 'type': 'typed-dict', 'return_fields_set': True, @@ -559,7 +681,6 @@ def __init__(self, a, b, fields_set): 'field_b': {'type': 'typed-dict-field', 'schema': {'type': 'int'}}, }, }, - 'config': {'revalidate_models': True}, } ) assert re.search(r'revalidate: \w+', repr(v)).group(0) == 'revalidate: true' @@ -617,7 +738,7 @@ def __init__(self, **kwargs): 'field_b': {'type': 'typed-dict-field', 'schema': {'type': 'int'}}, }, }, - 'config': {'revalidate_models': True}, + 'config': {'revalidate_instances': True}, } ) @@ -697,7 +818,7 @@ def call_me_maybe(self, context): 'field_b': {'type': 'typed-dict-field', 'schema': {'type': 'int'}}, }, }, - 'config': {'revalidate_models': True}, + 'config': {'revalidate_instances': True}, } ) assert re.search(r'revalidate: \w+', repr(v)).group(0) == 'revalidate: true' @@ -927,7 +1048,7 @@ class MyModel: assert not hasattr(m, '__fields_set__') # wrong arguments - with pytest.raises(TypeError, match="'field_a' is not a model instance"): + with pytest.raises(AttributeError, match="'str' object has no attribute '__dict__'"): v.validate_assignment('field_a', 'field_a', b'different')