Skip to content

Model subclass instances #492

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Mar 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ class CoreConfig(TypedDict, total=False):
typed_dict_total: bool # default: True
# used on typed-dicts and tagged union keys
from_attributes: bool
revalidate_models: bool
# whether instances of models and dataclasses (including subclass instances) should re-validate, default False
revalidate_instances: bool
# whether to validate default values during validation, default False
validate_default: bool
# used on typed-dicts and arguments
Expand Down Expand Up @@ -2562,6 +2563,7 @@ class ModelSchema(TypedDict, total=False):
cls: Required[Type[Any]]
schema: Required[CoreSchema]
post_init: str
revalidate_instances: bool
strict: bool
frozen: bool
config: CoreConfig
Expand All @@ -2575,6 +2577,7 @@ def model_schema(
schema: CoreSchema,
*,
post_init: str | None = None,
revalidate_instances: bool | None = None,
strict: bool | None = None,
frozen: bool | None = None,
config: CoreConfig | None = None,
Expand Down Expand Up @@ -2612,6 +2615,8 @@ class MyModel:
cls: The class to use for the model
schema: The schema to use for the model
post_init: The call after init to use for the model
revalidate_instances: whether instances of models and dataclasses (including subclass instances)
should re-validate defaults to config.revalidate_instances, else False
strict: Whether the model is strict
frozen: Whether the model is frozen
config: The config to use for the model
Expand All @@ -2624,6 +2629,7 @@ class MyModel:
cls=cls,
schema=schema,
post_init=post_init,
revalidate_instances=revalidate_instances,
strict=strict,
frozen=frozen,
config=config,
Expand Down Expand Up @@ -2756,6 +2762,7 @@ class DataclassSchema(TypedDict, total=False):
cls: Required[Type[Any]]
schema: Required[CoreSchema]
post_init: bool # default: False
revalidate_instances: bool # default: False
strict: bool # default: False
ref: str
metadata: Any
Expand All @@ -2767,6 +2774,7 @@ def dataclass_schema(
schema: CoreSchema,
*,
post_init: bool | None = None,
revalidate_instances: bool | None = None,
strict: bool | None = None,
ref: str | None = None,
metadata: Any = None,
Expand All @@ -2780,6 +2788,8 @@ def dataclass_schema(
cls: The dataclass type, used to to perform subclass checks
schema: The schema to use for the dataclass fields
post_init: Whether to call `__post_init__` after validation
revalidate_instances: whether instances of models and dataclasses (including subclass instances)
should re-validate defaults to config.revalidate_instances, else False
strict: Whether to require an exact instance of `cls`
ref: See [TODO] for details
metadata: See [TODO] for details
Expand All @@ -2790,6 +2800,7 @@ def dataclass_schema(
cls=cls,
schema=schema,
post_init=post_init,
revalidate_instances=revalidate_instances,
strict=strict,
ref=ref,
metadata=metadata,
Expand Down
15 changes: 7 additions & 8 deletions src/input/input_abstract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,24 +40,23 @@ pub trait Input<'a>: fmt::Debug + ToPyObject {
fn is_none(&self) -> bool;

#[cfg_attr(has_no_coverage, no_coverage)]
fn get_attr(&self, _name: &PyString) -> Option<&PyAny> {
fn input_get_attr(&self, _name: &PyString) -> Option<PyResult<&PyAny>> {
None
}

// input_ prefix to differentiate from the function on PyAny
fn input_is_instance(&self, class: &PyAny, json_mask: u8) -> PyResult<bool>;

fn is_exact_instance(&self, _class: &PyType) -> PyResult<bool> {
Ok(false)
fn is_exact_instance(&self, _class: &PyType) -> bool {
false
}

fn input_is_subclass(&self, _class: &PyType) -> PyResult<bool> {
Ok(false)
fn is_python(&self) -> bool {
false
}

// if the input is a subclass of `_class`, return `input.__dict__`, used on dataclasses
fn maybe_subclass_dict(&self, _class: &PyType) -> PyResult<&Self> {
Ok(self)
fn input_is_subclass(&self, _class: &PyType) -> PyResult<bool> {
Ok(false)
}

fn input_as_url(&self) -> Option<PyUrl> {
Expand Down
20 changes: 8 additions & 12 deletions src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ impl<'a> Input<'a> for PyAny {
self.is_none()
}

fn get_attr(&self, name: &PyString) -> Option<&PyAny> {
self.getattr(name).ok()
fn input_get_attr(&self, name: &PyString) -> Option<PyResult<&PyAny>> {
Some(self.getattr(name))
}

fn input_is_instance(&self, class: &PyAny, _json_mask: u8) -> PyResult<bool> {
Expand All @@ -110,8 +110,12 @@ impl<'a> Input<'a> for PyAny {
Ok(result == 1)
}

fn is_exact_instance(&self, class: &PyType) -> PyResult<bool> {
self.get_type().eq(class)
fn is_exact_instance(&self, class: &PyType) -> bool {
self.get_type().is(class)
}

fn is_python(&self) -> bool {
true
}

fn input_is_subclass(&self, class: &PyType) -> PyResult<bool> {
Expand All @@ -121,14 +125,6 @@ impl<'a> Input<'a> for PyAny {
}
}

fn maybe_subclass_dict(&self, class: &PyType) -> PyResult<&Self> {
if matches!(self.is_instance(class), Ok(true)) {
self.getattr(intern!(self.py(), "__dict__"))
} else {
Ok(self)
}
}

fn input_as_url(&self) -> Option<PyUrl> {
self.extract::<PyUrl>().ok()
}
Expand Down
2 changes: 1 addition & 1 deletion src/serializers/type_serializers/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ impl BuildSerializer for ModelSerializer {
impl ModelSerializer {
fn allow_value(&self, value: &PyAny, extra: &Extra) -> PyResult<bool> {
match extra.check {
SerCheck::Strict => value.get_type().eq(self.class.as_ref(value.py())),
SerCheck::Strict => Ok(value.get_type().is(self.class.as_ref(value.py()))),
SerCheck::Lax => value.is_instance(self.class.as_ref(value.py())),
SerCheck::None => value.hasattr(intern!(value.py(), "__dict__")),
}
Expand Down
73 changes: 43 additions & 30 deletions src/validators/dataclass.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use pyo3::exceptions::{PyKeyError, PyTypeError};
use pyo3::exceptions::PyKeyError;
use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList, PyString, PyTuple, PyType};

use ahash::AHashSet;

use crate::build_tools::{is_strict, py_err, safe_repr, schema_or_config_same, SchemaDict};
use crate::build_tools::{is_strict, py_err, schema_or_config_same, SchemaDict};
use crate::errors::{ErrorType, ValError, ValLineError, ValResult};
use crate::input::{GenericArguments, Input};
use crate::lookup_key::LookupKey;
Expand Down Expand Up @@ -291,16 +291,6 @@ impl Validator for DataclassArgsValidator {
}
}

fn get_name(&self) -> &str {
&self.validator_name
}

fn complete(&mut self, build_context: &BuildContext<CombinedValidator>) -> PyResult<()> {
self.fields
.iter_mut()
.try_for_each(|field| field.validator.complete(build_context))
}

fn validate_assignment<'s, 'data: 's>(
&'s self,
py: Python<'data>,
Expand Down Expand Up @@ -354,6 +344,16 @@ impl Validator for DataclassArgsValidator {
))
}
}

fn get_name(&self) -> &str {
&self.validator_name
}

fn complete(&mut self, build_context: &BuildContext<CombinedValidator>) -> PyResult<()> {
self.fields
.iter_mut()
.try_for_each(|field| field.validator.complete(build_context))
}
}

#[derive(Debug, Clone)]
Expand All @@ -362,6 +362,7 @@ pub struct DataclassValidator {
validator: Box<CombinedValidator>,
class: Py<PyType>,
post_init: Option<Py<PyString>>,
revalidate: bool,
name: String,
}

Expand Down Expand Up @@ -390,6 +391,7 @@ impl BuildValidator for DataclassValidator {
validator: Box::new(validator),
class: class.into(),
post_init,
revalidate: schema_or_config_same(schema, config, intern!(py, "revalidate_instances"))?.unwrap_or(false),
// as with model, get the class's `__name__`, not using `class.name()` since it uses `__qualname__`
// which is not what we want here
name: class.getattr(intern!(py, "__name__"))?.extract()?,
Expand All @@ -411,33 +413,43 @@ impl Validator for DataclassValidator {
// in the case that self_instance is Some, we're calling validation from within `BaseModel.__init__`
return self.validate_init(py, self_instance, input, extra, slots, recursion_guard);
}
let class = self.class.as_ref(py);

// we only do the is_exact_instance in strict mode
// we run validation even if input is an exact class to cover the case where a vanilla dataclass has been
// created with invalid types
// in theory we could have a flag to skip validation for an exact type in some scenarios, but I'm not sure
// that's a good idea
if extra.strict.unwrap_or(self.strict) && !input.is_exact_instance(class)? {
// same logic as on models
let class = self.class.as_ref(py);
if input.input_is_instance(class, 0)? {
if input.is_exact_instance(class) || !extra.strict.unwrap_or(self.strict) {
if self.revalidate {
let input = input.input_get_attr(intern!(py, "__dict__")).unwrap()?;
let val_output = self.validator.validate(py, input, extra, slots, recursion_guard)?;
let dc = create_class(self.class.as_ref(py))?;
self.set_dict_call(py, dc.as_ref(py), val_output, input)?;
Ok(dc)
} else {
Ok(input.to_object(py))
}
} else {
Err(ValError::new(
ErrorType::ModelClassType {
class_name: self.get_name().to_string(),
},
input,
))
}
} else if extra.strict.unwrap_or(self.strict) && input.is_python() {
Err(ValError::new(
ErrorType::ModelClassType {
class_name: self.get_name().to_string(),
},
input,
))
} else {
let input = input.maybe_subclass_dict(class)?;
let val_output = self.validator.validate(py, input, extra, slots, recursion_guard)?;
let dc = create_class(self.class.as_ref(py))?;
self.set_dict_call(py, dc.as_ref(py), val_output, input)?;
Ok(dc)
}
}

fn get_name(&self) -> &str {
&self.name
}

fn validate_assignment<'s, 'data: 's>(
&'s self,
py: Python<'data>,
Expand All @@ -448,11 +460,8 @@ impl Validator for DataclassValidator {
slots: &'data [CombinedValidator],
recursion_guard: &'s mut RecursionGuard,
) -> ValResult<'data, PyObject> {
let dict_attr = intern!(py, "__dict__");
let dict: &PyDict = match obj.get_attr(dict_attr) {
Some(v) => v.downcast()?,
None => return Err(PyTypeError::new_err(format!("{} is not a model instance", safe_repr(obj))).into()),
};
let dict_py_str = intern!(py, "__dict__");
let dict: &PyDict = obj.getattr(dict_py_str)?.downcast()?;

let new_dict = dict.copy()?;
new_dict.set_item(field_name, field_value)?;
Expand All @@ -461,10 +470,14 @@ impl Validator for DataclassValidator {
self.validator
.validate_assignment(py, new_dict, field_name, field_value, extra, slots, recursion_guard)?;

force_setattr(py, obj, dict_attr, dc_dict)?;
force_setattr(py, obj, dict_py_str, dc_dict)?;

Ok(obj.to_object(py))
}

fn get_name(&self) -> &str {
&self.name
}
}

impl DataclassValidator {
Expand Down
Loading