diff --git a/src/input/return_enums.rs b/src/input/return_enums.rs index e46b8ffe3..b63c12aa0 100644 --- a/src/input/return_enums.rs +++ b/src/input/return_enums.rs @@ -64,9 +64,16 @@ fn validate_iter_to_vec<'a, 's>( match validator.validate(py, item, extra, slots, recursion_guard) { Ok(item) => output.push(item), Err(ValError::LineErrors(line_errors)) => { + if !extra.exhaustive { + return Err(ValError::Omit); + } errors.extend(line_errors.into_iter().map(|err| err.with_outer_location(index.into()))); } - Err(ValError::Omit) => (), + Err(ValError::Omit) => { + if !extra.exhaustive { + return Err(ValError::Omit); + } + } Err(err) => return Err(err), } } diff --git a/src/validators/generator.rs b/src/validators/generator.rs index f79ac1703..b7df4ba1f 100644 --- a/src/validators/generator.rs +++ b/src/validators/generator.rs @@ -239,6 +239,7 @@ impl InternalValidator { field: self.field.as_deref(), strict: self.strict, context: self.context.as_ref().map(|data| data.as_ref(py)), + exhaustive: true, }; self.validator .validate(py, input, &extra, &self.slots, &mut self.recursion_guard) diff --git a/src/validators/mod.rs b/src/validators/mod.rs index 40c45afc1..40dc72c8c 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -196,6 +196,7 @@ impl SchemaValidator { field: Some(field.as_str()), strict, context, + exhaustive: true, }; let r = self .validator @@ -450,6 +451,8 @@ pub struct Extra<'a> { pub strict: Option, /// context used in validator functions pub context: Option<&'a PyAny>, + // if we should do exhaustive validation + pub exhaustive: bool, } impl<'a> Extra<'a> { @@ -457,6 +460,7 @@ impl<'a> Extra<'a> { Extra { strict, context, + exhaustive: true, ..Default::default() } } @@ -469,6 +473,16 @@ impl<'a> Extra<'a> { field: self.field, strict: Some(true), context: self.context, + exhaustive: self.exhaustive, + } + } + pub fn with_exhaustiveness(&self, exhaustive: bool) -> Self { + Self { + data: self.data, + field: self.field, + strict: Some(true), + context: self.context, + exhaustive, } } } diff --git a/src/validators/typed_dict.rs b/src/validators/typed_dict.rs index 189e1ff98..d239fb571 100644 --- a/src/validators/typed_dict.rs +++ b/src/validators/typed_dict.rs @@ -188,6 +188,7 @@ impl Validator for TypedDictValidator { field: None, strict: extra.strict, context: extra.context, + exhaustive: true, }; macro_rules! process { diff --git a/src/validators/union.rs b/src/validators/union.rs index 324b7e3d8..c6a4cee40 100644 --- a/src/validators/union.rs +++ b/src/validators/union.rs @@ -85,6 +85,17 @@ impl Validator for UnionValidator { recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { if extra.strict.unwrap_or(self.strict) { + // 1st pass: non exhaustive + let non_exhaustive_extra = extra.with_exhaustiveness(false); + if let Some(res) = self + .choices + .iter() + .map(|validator| validator.validate(py, input, &non_exhaustive_extra, slots, recursion_guard)) + .find(ValResult::is_ok) + { + return res; + } + let mut errors: Option> = match self.custom_error { None => Some(Vec::with_capacity(self.choices.len())), _ => None, @@ -110,7 +121,7 @@ impl Validator for UnionValidator { } else { // 1st pass: check if the value is an exact instance of one of the Union types, // e.g. use validate in strict mode - let strict_extra = extra.as_strict(); + let strict_extra = extra.as_strict().with_exhaustiveness(false); if let Some(res) = self .choices .iter() diff --git a/tests/benchmarks/test_micro_benchmarks.py b/tests/benchmarks/test_micro_benchmarks.py index 815a3498e..a2afcdb6f 100644 --- a/tests/benchmarks/test_micro_benchmarks.py +++ b/tests/benchmarks/test_micro_benchmarks.py @@ -837,6 +837,20 @@ def test_smart_union_core(self, benchmark): benchmark(v.validate_python, 1) + @pytest.mark.benchmark(group='smart-union') + def test_smart_union_deep(self, benchmark): + v = SchemaValidator( + { + 'type': 'union', + 'choices': [ + {'type': 'list', 'items_schema': {'type': 'str'}}, + {'type': 'list', 'items_schema': {'type': 'int'}}, + ], + } + ) + data = [1] * 1_000 + benchmark(v.validate_python, data) + @skip_pydantic @pytest.mark.benchmark(group='smart-union') def test_smart_union_pyd(self, benchmark):