From d6d4088ef97aef05b69bbc72a273a360cdd3d495 Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Thu, 7 Sep 2023 22:35:53 +0100 Subject: [PATCH] make error "duplicate" cheaper --- src/errors/line_error.rs | 31 +++++++++++------------ src/errors/validation_exception.rs | 16 ++++++------ src/input/input_json.rs | 7 +++--- src/input/parse_json.rs | 14 ++++++----- src/input/return_enums.rs | 9 +++---- src/lazy_index_map.rs | 40 ++++++++++++++---------------- src/validators/function.rs | 2 +- src/validators/generator.rs | 8 +++--- src/validators/json.rs | 2 +- 9 files changed, 62 insertions(+), 67 deletions(-) diff --git a/src/errors/line_error.rs b/src/errors/line_error.rs index c8230a27a..e5d3c7bac 100644 --- a/src/errors/line_error.rs +++ b/src/errors/line_error.rs @@ -62,10 +62,10 @@ impl<'a> ValError<'a> { } /// a bit like clone but change the lifetime to match py - pub fn duplicate<'py>(&self, py: Python<'py>) -> ValError<'py> { + pub fn into_owned(self, py: Python<'_>) -> ValError<'_> { match self { - ValError::LineErrors(errors) => errors.iter().map(|e| e.duplicate(py)).collect::>().into(), - ValError::InternalErr(err) => ValError::InternalErr(err.clone_ref(py)), + ValError::LineErrors(errors) => errors.into_iter().map(|e| e.into_owned(py)).collect::>().into(), + ValError::InternalErr(err) => ValError::InternalErr(err), ValError::Omit => ValError::Omit, ValError::UseDefault => ValError::UseDefault, } @@ -129,28 +129,26 @@ impl<'a> ValLineError<'a> { self } - /// a bit like clone but change the lifetime to match py, used by ValError.duplicate above - pub fn duplicate<'py>(&'a self, py: Python<'py>) -> ValLineError<'py> { + /// a bit like clone but change the lifetime to match py, used by ValError.into_owned above + pub fn into_owned(self, py: Python<'_>) -> ValLineError<'_> { ValLineError { - error_type: self.error_type.clone(), - input_value: InputValue::<'py>::from(self.input_value.to_object(py)), - location: self.location.clone(), + error_type: self.error_type, + input_value: match self.input_value { + InputValue::PyAny(input) => InputValue::PyAny(input.to_object(py).into_ref(py)), + InputValue::JsonInput(input) => InputValue::JsonInput(input), + InputValue::String(input) => InputValue::PyAny(input.to_object(py).into_ref(py)), + }, + location: self.location, } } } #[cfg_attr(debug_assertions, derive(Debug))] +#[derive(Clone)] pub enum InputValue<'a> { PyAny(&'a PyAny), - JsonInput(&'a JsonInput), + JsonInput(JsonInput), String(&'a str), - PyObject(PyObject), -} - -impl<'a> From for InputValue<'a> { - fn from(py_object: PyObject) -> Self { - Self::PyObject(py_object) - } } impl<'a> ToPyObject for InputValue<'a> { @@ -159,7 +157,6 @@ impl<'a> ToPyObject for InputValue<'a> { Self::PyAny(input) => input.into_py(py), Self::JsonInput(input) => input.to_object(py), Self::String(input) => input.into_py(py), - Self::PyObject(py_obj) => py_obj.into_py(py), } } } diff --git a/src/errors/validation_exception.rs b/src/errors/validation_exception.rs index 8963c1d81..9dde6551c 100644 --- a/src/errors/validation_exception.rs +++ b/src/errors/validation_exception.rs @@ -22,7 +22,7 @@ use super::line_error::ValLineError; use super::location::Location; use super::types::{ErrorMode, ErrorType}; use super::value_exception::PydanticCustomError; -use super::ValError; +use super::{InputValue, ValError}; #[pyclass(extends=PyValueError, module="pydantic_core._pydantic_core")] #[derive(Clone)] @@ -128,11 +128,11 @@ fn get_url_prefix(py: Python, include_url: bool) -> Option<&str> { } // used to convert a validation error back to ValError for wrap functions -impl<'a> IntoPy> for ValidationError { - fn into_py(self, py: Python) -> ValError<'a> { +impl ValidationError { + pub(crate) fn into_val_error(self, py: Python<'_>) -> ValError<'_> { self.line_errors .into_iter() - .map(|e| e.into_py(py)) + .map(|e| e.into_val_line_error(py)) .collect::>() .into() } @@ -322,13 +322,13 @@ impl<'a> IntoPy for ValLineError<'a> { } } -/// opposite of above, used to extract line errors from a validation error for wrap functions -impl<'a> IntoPy> for PyLineError { - fn into_py(self, _py: Python) -> ValLineError<'a> { +impl PyLineError { + /// Used to extract line errors from a validation error for wrap functions + fn into_val_line_error(self, py: Python<'_>) -> ValLineError<'_> { ValLineError { error_type: self.error_type, location: self.location, - input_value: self.input_value.into(), + input_value: InputValue::PyAny(self.input_value.into_ref(py)), } } } diff --git a/src/input/input_json.rs b/src/input/input_json.rs index 86079da09..d9cb81fe2 100644 --- a/src/input/input_json.rs +++ b/src/input/input_json.rs @@ -31,7 +31,8 @@ impl<'a> Input<'a> for JsonInput { } fn as_error_value(&'a self) -> InputValue<'a> { - InputValue::JsonInput(self) + // cloning JsonInput is cheap due to use of Arc + InputValue::JsonInput(self.clone()) } fn is_none(&self) -> bool { @@ -262,7 +263,7 @@ impl<'a> Input<'a> for JsonInput { JsonInput::String(s) => Ok(string_to_vec(s).into()), JsonInput::Object(object) => { // return keys iterator to match python's behavior - let keys: Vec = object.keys().map(|k| JsonInput::String(k.clone())).collect(); + let keys: JsonArray = JsonArray::new(object.keys().map(|k| JsonInput::String(k.clone())).collect()); Ok(keys.into()) } _ => Err(ValError::new(ErrorTypeDefaults::IterableType, self)), @@ -550,5 +551,5 @@ impl<'a> Input<'a> for String { } fn string_to_vec(s: &str) -> JsonArray { - s.chars().map(|c| JsonInput::String(c.to_string())).collect() + JsonArray::new(s.chars().map(|c| JsonInput::String(c.to_string())).collect()) } diff --git a/src/input/parse_json.rs b/src/input/parse_json.rs index 3bc2d0d46..7603eaf67 100644 --- a/src/input/parse_json.rs +++ b/src/input/parse_json.rs @@ -1,9 +1,11 @@ use std::fmt; +use std::sync::Arc; use num_bigint::BigInt; use pyo3::prelude::*; use pyo3::types::{PyDict, PyList}; use serde::de::{Deserialize, DeserializeSeed, Error as SerdeError, MapAccess, SeqAccess, Visitor}; +use smallvec::SmallVec; use crate::lazy_index_map::LazyIndexMap; @@ -20,8 +22,8 @@ pub enum JsonInput { Array(JsonArray), Object(JsonObject), } -pub type JsonArray = Vec; -pub type JsonObject = LazyIndexMap; +pub type JsonArray = Arc>; +pub type JsonObject = Arc>; impl ToPyObject for JsonInput { fn to_object(&self, py: Python<'_>) -> PyObject { @@ -111,13 +113,13 @@ impl<'de> Deserialize<'de> for JsonInput { where V: SeqAccess<'de>, { - let mut vec = Vec::new(); + let mut vec = SmallVec::new(); while let Some(elem) = visitor.next_element()? { vec.push(elem); } - Ok(JsonInput::Array(vec)) + Ok(JsonInput::Array(JsonArray::new(vec))) } fn visit_map(self, mut visitor: V) -> Result @@ -171,9 +173,9 @@ impl<'de> Deserialize<'de> for JsonInput { while let Some((key, value)) = visitor.next_entry()? { values.insert(key, value); } - Ok(JsonInput::Object(values)) + Ok(JsonInput::Object(Arc::new(values))) } - None => Ok(JsonInput::Object(LazyIndexMap::new())), + None => Ok(JsonInput::Object(Arc::new(LazyIndexMap::new()))), } } } diff --git a/src/input/return_enums.rs b/src/input/return_enums.rs index f1cfb8543..e97fe8b81 100644 --- a/src/input/return_enums.rs +++ b/src/input/return_enums.rs @@ -625,8 +625,8 @@ impl GenericPyIterator { } } - pub fn input<'a>(&'a self, py: Python<'a>) -> &'a PyAny { - self.obj.as_ref(py) + pub fn input_as_error_value<'py>(&self, py: Python<'py>) -> InputValue<'py> { + InputValue::PyAny(self.obj.clone_ref(py).into_ref(py)) } pub fn index(&self) -> usize { @@ -654,9 +654,8 @@ impl GenericJsonIterator { } } - pub fn input<'a>(&'a self, py: Python<'a>) -> &'a PyAny { - let input = JsonInput::Array(self.array.clone()); - input.to_object(py).into_ref(py) + pub fn input_as_error_value<'py>(&self, _py: Python<'py>) -> InputValue<'py> { + InputValue::JsonInput(JsonInput::Array(self.array.clone())) } pub fn index(&self) -> usize { diff --git a/src/lazy_index_map.rs b/src/lazy_index_map.rs index 163421de3..c5621f877 100644 --- a/src/lazy_index_map.rs +++ b/src/lazy_index_map.rs @@ -1,32 +1,36 @@ use std::borrow::Borrow; -use std::cell::RefCell; use std::cmp::{Eq, PartialEq}; use std::fmt::Debug; use std::hash::Hash; use std::slice::Iter as SliceIter; +use std::sync::OnceLock; use ahash::AHashMap; +use smallvec::SmallVec; #[derive(Debug, Clone, Default)] pub struct LazyIndexMap { - vec: Vec<(K, V)>, - map: RefCell>>, + vec: SmallVec<[(K, V); 8]>, + map: OnceLock>, } /// Like [IndexMap](https://docs.rs/indexmap/latest/indexmap/) but only builds the lookup map when it's needed. impl LazyIndexMap where K: Clone + Debug + Eq + Hash, - V: Clone + Debug, + V: Debug, { pub fn new() -> Self { Self { - vec: Vec::new(), - map: RefCell::new(None), + vec: SmallVec::new(), + map: OnceLock::new(), } } pub fn insert(&mut self, key: K, value: V) { + if let Some(map) = self.map.get_mut() { + map.insert(key.clone(), self.vec.len()); + } self.vec.push((key, value)); } @@ -39,22 +43,14 @@ where K: Borrow + PartialEq, Q: Hash + Eq, { - let mut map = self.map.borrow_mut(); - if let Some(map) = map.as_ref() { - map.get(key).map(|&i| &self.vec[i].1) - } else { - let mut new_map = AHashMap::with_capacity(self.vec.len()); - let mut value = None; - // reverse here so the last value is the one that's returned - for (index, (k, v)) in self.vec.iter().enumerate().rev() { - if value.is_none() && k == key { - value = Some(v); - } - new_map.insert(k.clone(), index); - } - *map = Some(new_map); - value - } + let map = self.map.get_or_init(|| { + self.vec + .iter() + .enumerate() + .map(|(index, (key, _))| (key.clone(), index)) + .collect() + }); + map.get(key).map(|&i| &self.vec[i].1) } pub fn keys(&self) -> impl Iterator { diff --git a/src/validators/function.rs b/src/validators/function.rs index 8f9b25d70..fa8a0673b 100644 --- a/src/validators/function.rs +++ b/src/validators/function.rs @@ -504,7 +504,7 @@ pub fn convert_err<'a>(py: Python<'a>, err: PyErr, input: &'a impl Input<'a>) -> } else if let Ok(pydantic_error_type) = err.value(py).extract::() { pydantic_error_type.into_val_error(input) } else if let Ok(validation_error) = err.value(py).extract::() { - validation_error.into_py(py) + validation_error.into_val_error(py) } else { py_err_string!(err.value(py), ValueError, input) } diff --git a/src/validators/generator.rs b/src/validators/generator.rs index 698504a83..1047e31bd 100644 --- a/src/validators/generator.rs +++ b/src/validators/generator.rs @@ -127,14 +127,14 @@ impl ValidatorIterator { Some(validator) => { if let Some(max_length) = max_length { if index >= max_length { - let val_error = ValError::new( + let val_error = ValError::new_custom_input( ErrorType::TooLong { field_type: "Generator".to_string(), max_length, actual_length: index + 1, context: None, }, - $iter.input(py), + $iter.input_as_error_value(py), ); return Err(ValidationError::from_val_error( py, @@ -153,14 +153,14 @@ impl ValidatorIterator { None => { if let Some(min_length) = min_length { if $iter.index() < min_length { - let val_error = ValError::new( + let val_error = ValError::new_custom_input( ErrorType::TooShort { field_type: "Generator".to_string(), min_length, actual_length: $iter.index(), context: None, }, - $iter.input(py), + $iter.input_as_error_value(py), ); return Err(ValidationError::from_val_error( py, diff --git a/src/validators/json.rs b/src/validators/json.rs index f99ac29ae..5eda007be 100644 --- a/src/validators/json.rs +++ b/src/validators/json.rs @@ -55,7 +55,7 @@ impl Validator for JsonValidator { match self.validator { Some(ref validator) => match validator.validate(py, &json_value, state) { Ok(v) => Ok(v), - Err(err) => Err(err.duplicate(py)), + Err(err) => Err(err.into_owned(py)), }, None => Ok(json_value.to_object(py)), }