Skip to content

Commit d7cf72d

Browse files
authored
Int extraction (#1155)
1 parent 5d3aa43 commit d7cf72d

File tree

11 files changed

+64
-25
lines changed

11 files changed

+64
-25
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,6 @@ node_modules/
3636
/foobar.py
3737
/python/pydantic_core/*.so
3838
/src/self_schema.py
39+
40+
# samply
41+
/profile.json

src/errors/types.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -786,7 +786,7 @@ impl From<Int> for Number {
786786

787787
impl FromPyObject<'_> for Number {
788788
fn extract(obj: &PyAny) -> PyResult<Self> {
789-
if let Ok(int) = extract_i64(obj) {
789+
if let Some(int) = extract_i64(obj) {
790790
Ok(Number::Int(int))
791791
} else if let Ok(float) = obj.extract::<f64>() {
792792
Ok(Number::Float(float))

src/errors/value_exception.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ impl PydanticCustomError {
122122
let key: &PyString = key.downcast()?;
123123
if let Ok(py_str) = value.downcast::<PyString>() {
124124
message = message.replace(&format!("{{{}}}", key.to_str()?), py_str.to_str()?);
125-
} else if let Ok(value_int) = extract_i64(value) {
125+
} else if let Some(value_int) = extract_i64(value) {
126126
message = message.replace(&format!("{{{}}}", key.to_str()?), &value_int.to_string());
127127
} else {
128128
// fallback for anything else just in case

src/input/input_python.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ impl AsLocItem for PyAny {
9696
fn as_loc_item(&self) -> LocItem {
9797
if let Ok(py_str) = self.downcast::<PyString>() {
9898
py_str.to_string_lossy().as_ref().into()
99-
} else if let Ok(key_int) = extract_i64(self) {
99+
} else if let Some(key_int) = extract_i64(self) {
100100
key_int.into()
101101
} else {
102102
safe_repr(self).to_string().into()
@@ -292,7 +292,7 @@ impl<'a> Input<'a> for PyAny {
292292
if !strict {
293293
if let Some(cow_str) = maybe_as_string(self, ErrorTypeDefaults::BoolParsing)? {
294294
return str_as_bool(self, &cow_str).map(ValidationMatch::lax);
295-
} else if let Ok(int) = extract_i64(self) {
295+
} else if let Some(int) = extract_i64(self) {
296296
return int_as_bool(self, int).map(ValidationMatch::lax);
297297
} else if let Ok(float) = self.extract::<f64>() {
298298
if let Ok(int) = float_as_int(self, float) {
@@ -635,7 +635,7 @@ impl<'a> Input<'a> for PyAny {
635635
bytes_as_time(self, py_bytes.as_bytes(), microseconds_overflow_behavior)
636636
} else if PyBool::is_exact_type_of(self) {
637637
Err(ValError::new(ErrorTypeDefaults::TimeType, self))
638-
} else if let Ok(int) = extract_i64(self) {
638+
} else if let Some(int) = extract_i64(self) {
639639
int_as_time(self, int, 0)
640640
} else if let Ok(float) = self.extract::<f64>() {
641641
float_as_time(self, float)
@@ -669,7 +669,7 @@ impl<'a> Input<'a> for PyAny {
669669
bytes_as_datetime(self, py_bytes.as_bytes(), microseconds_overflow_behavior)
670670
} else if PyBool::is_exact_type_of(self) {
671671
Err(ValError::new(ErrorTypeDefaults::DatetimeType, self))
672-
} else if let Ok(int) = extract_i64(self) {
672+
} else if let Some(int) = extract_i64(self) {
673673
int_as_datetime(self, int, 0)
674674
} else if let Ok(float) = self.extract::<f64>() {
675675
float_as_datetime(self, float)
@@ -706,7 +706,7 @@ impl<'a> Input<'a> for PyAny {
706706
bytes_as_timedelta(self, str.as_bytes(), microseconds_overflow_behavior)
707707
} else if let Ok(py_bytes) = self.downcast::<PyBytes>() {
708708
bytes_as_timedelta(self, py_bytes.as_bytes(), microseconds_overflow_behavior)
709-
} else if let Ok(int) = extract_i64(self) {
709+
} else if let Some(int) = extract_i64(self) {
710710
Ok(int_as_duration(self, int)?.into())
711711
} else if let Ok(float) = self.extract::<f64>() {
712712
Ok(float_as_duration(self, float)?.into())

src/input/return_enums.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ use pyo3::PyTypeInfo;
2323
use serde::{ser::Error, Serialize, Serializer};
2424

2525
use crate::errors::{py_err_string, ErrorType, ErrorTypeDefaults, InputValue, ValError, ValLineError, ValResult};
26-
use crate::tools::py_err;
26+
use crate::tools::{extract_i64, py_err};
2727
use crate::validators::{CombinedValidator, Exactness, ValidationState, Validator};
2828

2929
use super::input_string::StringMapping;
@@ -863,7 +863,7 @@ pub enum EitherInt<'a> {
863863
impl<'a> EitherInt<'a> {
864864
pub fn upcast(py_any: &'a PyAny) -> ValResult<Self> {
865865
// Safety: we know that py_any is a python int
866-
if let Ok(int_64) = py_any.extract::<i64>() {
866+
if let Some(int_64) = extract_i64(py_any) {
867867
Ok(Self::I64(int_64))
868868
} else {
869869
let big_int: BigInt = py_any.extract()?;
@@ -1021,7 +1021,7 @@ impl<'a> Rem for &'a Int {
10211021

10221022
impl<'a> FromPyObject<'a> for Int {
10231023
fn extract(obj: &'a PyAny) -> PyResult<Self> {
1024-
if let Ok(i) = obj.extract::<i64>() {
1024+
if let Some(i) = extract_i64(obj) {
10251025
Ok(Int::I64(i))
10261026
} else if let Ok(b) = obj.extract::<BigInt>() {
10271027
Ok(Int::Big(b))

src/lookup_key.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ impl PathItem {
429429
} else {
430430
Ok(Self::Pos(usize_key))
431431
}
432-
} else if let Ok(int_key) = extract_i64(obj) {
432+
} else if let Some(int_key) = extract_i64(obj) {
433433
if index == 0 {
434434
py_err!(PyTypeError; "The first item in an alias path should be a string")
435435
} else {

src/serializers/infer.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,10 @@ pub(crate) fn infer_to_python_known(
123123
// `bool` and `None` can't be subclasses, `ObType::Int`, `ObType::Float`, `ObType::Str` refer to exact types
124124
ObType::None | ObType::Bool | ObType::Int | ObType::Str => value.into_py(py),
125125
// have to do this to make sure subclasses of for example str are upcast to `str`
126-
ObType::IntSubclass => extract_i64(value)?.into_py(py),
126+
ObType::IntSubclass => match extract_i64(value) {
127+
Some(v) => v.into_py(py),
128+
None => return py_err!(PyTypeError; "expected int, got {}", safe_repr(value)),
129+
},
127130
ObType::Float | ObType::FloatSubclass => {
128131
let v = value.extract::<f64>()?;
129132
if (v.is_nan() || v.is_infinite()) && extra.config.inf_nan_mode == InfNanMode::Null {

src/serializers/type_serializers/literal.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ impl BuildSerializer for LiteralSerializer {
4646
repr_args.push(item.repr()?.extract()?);
4747
if let Ok(bool) = item.downcast::<PyBool>() {
4848
expected_py.append(bool)?;
49-
} else if let Ok(int) = extract_i64(item) {
49+
} else if let Some(int) = extract_i64(item) {
5050
expected_int.insert(int);
5151
} else if let Ok(py_str) = item.downcast::<PyString>() {
5252
expected_str.insert(py_str.to_str()?.to_string());
@@ -79,7 +79,7 @@ impl LiteralSerializer {
7979
fn check<'a>(&self, value: &'a PyAny, extra: &Extra) -> PyResult<OutputValue<'a>> {
8080
if extra.check.enabled() {
8181
if !self.expected_int.is_empty() && !PyBool::is_type_of(value) {
82-
if let Ok(int) = extract_i64(value) {
82+
if let Some(int) = extract_i64(value) {
8383
if self.expected_int.contains(&int) {
8484
return Ok(OutputValue::OkInt(int));
8585
}

src/tools.rs

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
use std::borrow::Cow;
22

3-
use pyo3::exceptions::{PyKeyError, PyTypeError};
3+
use pyo3::exceptions::PyKeyError;
44
use pyo3::prelude::*;
5-
use pyo3::types::{PyDict, PyInt, PyString};
6-
use pyo3::{intern, FromPyObject, PyTypeInfo};
5+
use pyo3::types::{PyDict, PyString};
6+
use pyo3::{ffi, intern, FromPyObject};
77

88
pub trait SchemaDict<'py> {
99
fn get_as<T>(&'py self, key: &PyString) -> PyResult<Option<T>>
@@ -99,10 +99,24 @@ pub fn safe_repr(v: &PyAny) -> Cow<str> {
9999
}
100100
}
101101

102-
pub fn extract_i64(v: &PyAny) -> PyResult<i64> {
103-
if PyInt::is_type_of(v) {
104-
v.extract()
102+
/// Extract an i64 from a python object more quickly, see
103+
/// https://github.com/PyO3/pyo3/pull/3742#discussion_r1451763928
104+
#[cfg(not(any(target_pointer_width = "32", windows, PyPy)))]
105+
pub fn extract_i64(obj: &PyAny) -> Option<i64> {
106+
let val = unsafe { ffi::PyLong_AsLong(obj.as_ptr()) };
107+
if val == -1 && PyErr::occurred(obj.py()) {
108+
unsafe { ffi::PyErr_Clear() };
109+
None
105110
} else {
106-
py_err!(PyTypeError; "expected int, got {}", safe_repr(v))
111+
Some(val)
112+
}
113+
}
114+
115+
#[cfg(any(target_pointer_width = "32", windows, PyPy))]
116+
pub fn extract_i64(v: &PyAny) -> Option<i64> {
117+
if v.is_instance_of::<pyo3::types::PyInt>() {
118+
v.extract().ok()
119+
} else {
120+
None
107121
}
108122
}

tests/benchmarks/test_micro_benchmarks.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1232,6 +1232,18 @@ def test_strict_int(benchmark):
12321232
benchmark(v.validate_python, 42)
12331233

12341234

1235+
@pytest.mark.benchmark(group='strict_int')
1236+
def test_strict_int_fails(benchmark):
1237+
v = SchemaValidator(core_schema.int_schema(strict=True))
1238+
1239+
@benchmark
1240+
def t():
1241+
try:
1242+
v.validate_python(())
1243+
except ValidationError:
1244+
pass
1245+
1246+
12351247
@pytest.mark.benchmark(group='int_range')
12361248
def test_int_range(benchmark):
12371249
v = SchemaValidator(core_schema.int_schema(gt=0, lt=100))

tests/validators/test_int.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
('123456789123456.00001', Err('Input should be a valid integer, unable to parse string as an integer')),
3030
(int(1e10), int(1e10)),
3131
(i64_max, i64_max),
32+
(i64_max + 1, i64_max + 1),
33+
(i64_max * 2, i64_max * 2),
3234
pytest.param(
3335
12.5,
3436
Err('Input should be a valid integer, got a number with a fractional part [type=int_from_float'),
@@ -106,10 +108,15 @@ def test_int(input_value, expected):
106108
@pytest.mark.parametrize(
107109
'input_value,expected',
108110
[
109-
(Decimal('1'), 1),
110-
(Decimal('1.0'), 1),
111-
(i64_max, i64_max),
112-
(i64_max + 1, i64_max + 1),
111+
pytest.param(Decimal('1'), 1),
112+
pytest.param(Decimal('1.0'), 1),
113+
pytest.param(i64_max, i64_max, id='i64_max'),
114+
pytest.param(i64_max + 1, i64_max + 1, id='i64_max+1'),
115+
pytest.param(
116+
-1,
117+
Err('Input should be greater than 0 [type=greater_than, input_value=-1, input_type=int]'),
118+
id='-1',
119+
),
113120
(
114121
-i64_max + 1,
115122
Err('Input should be greater than 0 [type=greater_than, input_value=-9223372036854775806, input_type=int]'),

0 commit comments

Comments
 (0)