Skip to content

make error "duplicate" cheaper #950

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 14 additions & 17 deletions src/errors/line_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ impl<'a> ValError<'a> {
}

/// a bit like clone but change the lifetime to match py
pub fn duplicate<'py>(&self, py: Python<'py>) -> ValError<'py> {
pub fn into_owned(self, py: Python<'_>) -> ValError<'_> {
match self {
ValError::LineErrors(errors) => errors.iter().map(|e| e.duplicate(py)).collect::<Vec<_>>().into(),
ValError::InternalErr(err) => ValError::InternalErr(err.clone_ref(py)),
ValError::LineErrors(errors) => errors.into_iter().map(|e| e.into_owned(py)).collect::<Vec<_>>().into(),
ValError::InternalErr(err) => ValError::InternalErr(err),
ValError::Omit => ValError::Omit,
ValError::UseDefault => ValError::UseDefault,
}
Expand Down Expand Up @@ -129,28 +129,26 @@ impl<'a> ValLineError<'a> {
self
}

/// a bit like clone but change the lifetime to match py, used by ValError.duplicate above
pub fn duplicate<'py>(&'a self, py: Python<'py>) -> ValLineError<'py> {
/// a bit like clone but change the lifetime to match py, used by ValError.into_owned above
pub fn into_owned(self, py: Python<'_>) -> ValLineError<'_> {
ValLineError {
error_type: self.error_type.clone(),
input_value: InputValue::<'py>::from(self.input_value.to_object(py)),
location: self.location.clone(),
error_type: self.error_type,
input_value: match self.input_value {
InputValue::PyAny(input) => InputValue::PyAny(input.to_object(py).into_ref(py)),
InputValue::JsonInput(input) => InputValue::JsonInput(input),
InputValue::String(input) => InputValue::PyAny(input.to_object(py).into_ref(py)),
},
location: self.location,
}
}
}

#[cfg_attr(debug_assertions, derive(Debug))]
#[derive(Clone)]
pub enum InputValue<'a> {
PyAny(&'a PyAny),
JsonInput(&'a JsonInput),
JsonInput(JsonInput),
String(&'a str),
PyObject(PyObject),
}

impl<'a> From<PyObject> for InputValue<'a> {
fn from(py_object: PyObject) -> Self {
Self::PyObject(py_object)
}
}

impl<'a> ToPyObject for InputValue<'a> {
Expand All @@ -159,7 +157,6 @@ impl<'a> ToPyObject for InputValue<'a> {
Self::PyAny(input) => input.into_py(py),
Self::JsonInput(input) => input.to_object(py),
Self::String(input) => input.into_py(py),
Self::PyObject(py_obj) => py_obj.into_py(py),
}
}
}
16 changes: 8 additions & 8 deletions src/errors/validation_exception.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use super::line_error::ValLineError;
use super::location::Location;
use super::types::{ErrorMode, ErrorType};
use super::value_exception::PydanticCustomError;
use super::ValError;
use super::{InputValue, ValError};

#[pyclass(extends=PyValueError, module="pydantic_core._pydantic_core")]
#[derive(Clone)]
Expand Down Expand Up @@ -128,11 +128,11 @@ fn get_url_prefix(py: Python, include_url: bool) -> Option<&str> {
}

// used to convert a validation error back to ValError for wrap functions
impl<'a> IntoPy<ValError<'a>> for ValidationError {
fn into_py(self, py: Python) -> ValError<'a> {
impl ValidationError {
pub(crate) fn into_val_error(self, py: Python<'_>) -> ValError<'_> {
self.line_errors
.into_iter()
.map(|e| e.into_py(py))
.map(|e| e.into_val_line_error(py))
.collect::<Vec<_>>()
.into()
}
Expand Down Expand Up @@ -322,13 +322,13 @@ impl<'a> IntoPy<PyLineError> for ValLineError<'a> {
}
}

/// opposite of above, used to extract line errors from a validation error for wrap functions
impl<'a> IntoPy<ValLineError<'a>> for PyLineError {
fn into_py(self, _py: Python) -> ValLineError<'a> {
impl PyLineError {
/// Used to extract line errors from a validation error for wrap functions
fn into_val_line_error(self, py: Python<'_>) -> ValLineError<'_> {
ValLineError {
error_type: self.error_type,
location: self.location,
input_value: self.input_value.into(),
input_value: InputValue::PyAny(self.input_value.into_ref(py)),
}
}
}
Expand Down
7 changes: 4 additions & 3 deletions src/input/input_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ impl<'a> Input<'a> for JsonInput {
}

fn as_error_value(&'a self) -> InputValue<'a> {
InputValue::JsonInput(self)
// cloning JsonInput is cheap due to use of Arc
InputValue::JsonInput(self.clone())
}

fn is_none(&self) -> bool {
Expand Down Expand Up @@ -262,7 +263,7 @@ impl<'a> Input<'a> for JsonInput {
JsonInput::String(s) => Ok(string_to_vec(s).into()),
JsonInput::Object(object) => {
// return keys iterator to match python's behavior
let keys: Vec<JsonInput> = object.keys().map(|k| JsonInput::String(k.clone())).collect();
let keys: JsonArray = JsonArray::new(object.keys().map(|k| JsonInput::String(k.clone())).collect());
Ok(keys.into())
}
_ => Err(ValError::new(ErrorTypeDefaults::IterableType, self)),
Expand Down Expand Up @@ -550,5 +551,5 @@ impl<'a> Input<'a> for String {
}

fn string_to_vec(s: &str) -> JsonArray {
s.chars().map(|c| JsonInput::String(c.to_string())).collect()
JsonArray::new(s.chars().map(|c| JsonInput::String(c.to_string())).collect())
}
14 changes: 8 additions & 6 deletions src/input/parse_json.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use std::fmt;
use std::sync::Arc;

use num_bigint::BigInt;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList};
use serde::de::{Deserialize, DeserializeSeed, Error as SerdeError, MapAccess, SeqAccess, Visitor};
use smallvec::SmallVec;

use crate::lazy_index_map::LazyIndexMap;

Expand All @@ -20,8 +22,8 @@ pub enum JsonInput {
Array(JsonArray),
Object(JsonObject),
}
pub type JsonArray = Vec<JsonInput>;
pub type JsonObject = LazyIndexMap<String, JsonInput>;
pub type JsonArray = Arc<SmallVec<[JsonInput; 8]>>;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume we can have more than 8 items, right? Is 8 just the default that gets allocated?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. The way smallvec works is that you pick a number of entries (in this case 8) which will be stored inline in the struct; beyond that size it will use a heap allocation instead.

pub type JsonObject = Arc<LazyIndexMap<String, JsonInput>>;

impl ToPyObject for JsonInput {
fn to_object(&self, py: Python<'_>) -> PyObject {
Expand Down Expand Up @@ -111,13 +113,13 @@ impl<'de> Deserialize<'de> for JsonInput {
where
V: SeqAccess<'de>,
{
let mut vec = Vec::new();
let mut vec = SmallVec::new();

while let Some(elem) = visitor.next_element()? {
vec.push(elem);
}

Ok(JsonInput::Array(vec))
Ok(JsonInput::Array(JsonArray::new(vec)))
}

fn visit_map<V>(self, mut visitor: V) -> Result<JsonInput, V::Error>
Expand Down Expand Up @@ -171,9 +173,9 @@ impl<'de> Deserialize<'de> for JsonInput {
while let Some((key, value)) = visitor.next_entry()? {
values.insert(key, value);
}
Ok(JsonInput::Object(values))
Ok(JsonInput::Object(Arc::new(values)))
}
None => Ok(JsonInput::Object(LazyIndexMap::new())),
None => Ok(JsonInput::Object(Arc::new(LazyIndexMap::new()))),
}
}
}
Expand Down
9 changes: 4 additions & 5 deletions src/input/return_enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -625,8 +625,8 @@ impl GenericPyIterator {
}
}

pub fn input<'a>(&'a self, py: Python<'a>) -> &'a PyAny {
self.obj.as_ref(py)
pub fn input_as_error_value<'py>(&self, py: Python<'py>) -> InputValue<'py> {
InputValue::PyAny(self.obj.clone_ref(py).into_ref(py))
}

pub fn index(&self) -> usize {
Expand Down Expand Up @@ -654,9 +654,8 @@ impl GenericJsonIterator {
}
}

pub fn input<'a>(&'a self, py: Python<'a>) -> &'a PyAny {
let input = JsonInput::Array(self.array.clone());
input.to_object(py).into_ref(py)
pub fn input_as_error_value<'py>(&self, _py: Python<'py>) -> InputValue<'py> {
InputValue::JsonInput(JsonInput::Array(self.array.clone()))
}

pub fn index(&self) -> usize {
Expand Down
40 changes: 18 additions & 22 deletions src/lazy_index_map.rs
Original file line number Diff line number Diff line change
@@ -1,32 +1,36 @@
use std::borrow::Borrow;
use std::cell::RefCell;
use std::cmp::{Eq, PartialEq};
use std::fmt::Debug;
use std::hash::Hash;
use std::slice::Iter as SliceIter;
use std::sync::OnceLock;

use ahash::AHashMap;
use smallvec::SmallVec;

#[derive(Debug, Clone, Default)]
pub struct LazyIndexMap<K, V> {
vec: Vec<(K, V)>,
map: RefCell<Option<AHashMap<K, usize>>>,
vec: SmallVec<[(K, V); 8]>,
map: OnceLock<AHashMap<K, usize>>,
}

/// Like [IndexMap](https://docs.rs/indexmap/latest/indexmap/) but only builds the lookup map when it's needed.
impl<K, V> LazyIndexMap<K, V>
where
K: Clone + Debug + Eq + Hash,
V: Clone + Debug,
V: Debug,
{
pub fn new() -> Self {
Self {
vec: Vec::new(),
map: RefCell::new(None),
vec: SmallVec::new(),
map: OnceLock::new(),
}
}

pub fn insert(&mut self, key: K, value: V) {
if let Some(map) = self.map.get_mut() {
map.insert(key.clone(), self.vec.len());
}
self.vec.push((key, value));
}

Expand All @@ -39,22 +43,14 @@ where
K: Borrow<Q> + PartialEq<Q>,
Q: Hash + Eq,
{
let mut map = self.map.borrow_mut();
if let Some(map) = map.as_ref() {
map.get(key).map(|&i| &self.vec[i].1)
} else {
let mut new_map = AHashMap::with_capacity(self.vec.len());
let mut value = None;
// reverse here so the last value is the one that's returned
for (index, (k, v)) in self.vec.iter().enumerate().rev() {
if value.is_none() && k == key {
value = Some(v);
}
new_map.insert(k.clone(), index);
}
*map = Some(new_map);
value
}
let map = self.map.get_or_init(|| {
self.vec
.iter()
.enumerate()
.map(|(index, (key, _))| (key.clone(), index))
.collect()
});
map.get(key).map(|&i| &self.vec[i].1)
}

pub fn keys(&self) -> impl Iterator<Item = &K> {
Expand Down
2 changes: 1 addition & 1 deletion src/validators/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@ pub fn convert_err<'a>(py: Python<'a>, err: PyErr, input: &'a impl Input<'a>) ->
} else if let Ok(pydantic_error_type) = err.value(py).extract::<PydanticKnownError>() {
pydantic_error_type.into_val_error(input)
} else if let Ok(validation_error) = err.value(py).extract::<ValidationError>() {
validation_error.into_py(py)
validation_error.into_val_error(py)
} else {
py_err_string!(err.value(py), ValueError, input)
}
Expand Down
8 changes: 4 additions & 4 deletions src/validators/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,14 @@ impl ValidatorIterator {
Some(validator) => {
if let Some(max_length) = max_length {
if index >= max_length {
let val_error = ValError::new(
let val_error = ValError::new_custom_input(
ErrorType::TooLong {
field_type: "Generator".to_string(),
max_length,
actual_length: index + 1,
context: None,
},
$iter.input(py),
$iter.input_as_error_value(py),
);
return Err(ValidationError::from_val_error(
py,
Expand All @@ -153,14 +153,14 @@ impl ValidatorIterator {
None => {
if let Some(min_length) = min_length {
if $iter.index() < min_length {
let val_error = ValError::new(
let val_error = ValError::new_custom_input(
ErrorType::TooShort {
field_type: "Generator".to_string(),
min_length,
actual_length: $iter.index(),
context: None,
},
$iter.input(py),
$iter.input_as_error_value(py),
);
return Err(ValidationError::from_val_error(
py,
Expand Down
2 changes: 1 addition & 1 deletion src/validators/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ impl Validator for JsonValidator {
match self.validator {
Some(ref validator) => match validator.validate(py, &json_value, state) {
Ok(v) => Ok(v),
Err(err) => Err(err.duplicate(py)),
Err(err) => Err(err.into_owned(py)),
},
None => Ok(json_value.to_object(py)),
}
Expand Down