Skip to content

Commit d7946da

Browse files
Reuse re.Pattern object in regex patterns (#1318)
1 parent 8afaa45 commit d7946da

File tree

4 files changed

+63
-27
lines changed

4 files changed

+63
-27
lines changed

generate_self_schema.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,11 @@
1010
import decimal
1111
import importlib.util
1212
import re
13+
import sys
1314
from collections.abc import Callable
1415
from datetime import date, datetime, time, timedelta
1516
from pathlib import Path
16-
from typing import TYPE_CHECKING, Any, Dict, ForwardRef, List, Set, Type, Union
17+
from typing import TYPE_CHECKING, Any, Dict, ForwardRef, List, Pattern, Set, Type, Union
1718

1819
from typing_extensions import TypedDict, get_args, get_origin, is_typeddict
1920

@@ -46,7 +47,7 @@
4647
schema_ref_validator = {'type': 'definition-ref', 'schema_ref': 'root-schema'}
4748

4849

49-
def get_schema(obj: Any, definitions: dict[str, core_schema.CoreSchema]) -> core_schema.CoreSchema:
50+
def get_schema(obj: Any, definitions: dict[str, core_schema.CoreSchema]) -> core_schema.CoreSchema: # noqa: C901
5051
if isinstance(obj, str):
5152
return {'type': obj}
5253
elif obj in (datetime, timedelta, date, time, bool, int, float, str, decimal.Decimal):
@@ -81,6 +82,9 @@ def get_schema(obj: Any, definitions: dict[str, core_schema.CoreSchema]) -> core
8182
elif issubclass(origin, Type):
8283
# can't really use 'is-instance' since this is used for the class_ parameter of 'is-instance' validators
8384
return {'type': 'any'}
85+
elif origin in (Pattern, re.Pattern):
86+
# can't really use 'is-instance' easily with Pattern, so we use `any` as a placeholder for now
87+
return {'type': 'any'}
8488
else:
8589
# debug(obj)
8690
raise TypeError(f'Unknown type: {obj!r}')
@@ -189,16 +193,12 @@ def all_literal_values(type_: type[core_schema.Literal]) -> list[any]:
189193

190194

191195
def eval_forward_ref(type_: Any) -> Any:
192-
try:
193-
try:
194-
# Python 3.12+
195-
return type_._evaluate(core_schema.__dict__, None, type_params=set(), recursive_guard=set())
196-
except TypeError:
197-
# Python 3.9+
198-
return type_._evaluate(core_schema.__dict__, None, set())
199-
except TypeError:
200-
# for Python 3.8
196+
if sys.version_info < (3, 9):
201197
return type_._evaluate(core_schema.__dict__, None)
198+
elif sys.version_info < (3, 12, 4):
199+
return type_._evaluate(core_schema.__dict__, None, recursive_guard=set())
200+
else:
201+
return type_._evaluate(core_schema.__dict__, None, type_params=set(), recursive_guard=set())
202202

203203

204204
def main() -> None:

python/pydantic_core/core_schema.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from collections.abc import Mapping
1111
from datetime import date, datetime, time, timedelta
1212
from decimal import Decimal
13-
from typing import TYPE_CHECKING, Any, Callable, Dict, Hashable, List, Set, Tuple, Type, Union
13+
from typing import TYPE_CHECKING, Any, Callable, Dict, Hashable, List, Pattern, Set, Tuple, Type, Union
1414

1515
from typing_extensions import deprecated
1616

@@ -744,7 +744,7 @@ def decimal_schema(
744744

745745
class StringSchema(TypedDict, total=False):
746746
type: Required[Literal['str']]
747-
pattern: str
747+
pattern: Union[str, Pattern[str]]
748748
max_length: int
749749
min_length: int
750750
strip_whitespace: bool
@@ -760,7 +760,7 @@ class StringSchema(TypedDict, total=False):
760760

761761
def str_schema(
762762
*,
763-
pattern: str | None = None,
763+
pattern: str | Pattern[str] | None = None,
764764
max_length: int | None = None,
765765
min_length: int | None = None,
766766
strip_whitespace: bool | None = None,

src/validators/string.rs

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ impl StrConstrainedValidator {
164164
.map(|s| s.to_str())
165165
.transpose()?
166166
.unwrap_or(RegexEngine::RUST_REGEX);
167-
Pattern::compile(py, s, regex_engine)
167+
Pattern::compile(s, regex_engine)
168168
})
169169
.transpose()?;
170170
let min_length: Option<usize> =
@@ -230,18 +230,47 @@ impl RegexEngine {
230230
}
231231

232232
impl Pattern {
233-
fn compile(py: Python<'_>, pattern: String, engine: &str) -> PyResult<Self> {
234-
let engine = match engine {
235-
RegexEngine::RUST_REGEX => {
236-
RegexEngine::RustRegex(Regex::new(&pattern).map_err(|e| py_schema_error_type!("{}", e))?)
237-
}
238-
RegexEngine::PYTHON_RE => {
239-
let re_compile = py.import_bound(intern!(py, "re"))?.getattr(intern!(py, "compile"))?;
240-
RegexEngine::PythonRe(re_compile.call1((&pattern,))?.into())
241-
}
242-
_ => return Err(py_schema_error_type!("Invalid regex engine: {}", engine)),
243-
};
244-
Ok(Self { pattern, engine })
233+
/// Returns the regex source text for a schema `pattern` value.
///
/// If `pattern` is a Python `str` (`PyString`), its string form is returned.
/// Otherwise the object's `pattern` attribute is read and extracted as a `String`.
///
/// # Errors
///
/// When the attribute lookup or the `String` extraction fails, the underlying error is
/// discarded and replaced with an error built by `py_schema_error_type!` carrying the
/// message "Invalid pattern, must be str or re.Pattern: ...".
fn extract_pattern_str(pattern: &Bound<'_, PyAny>) -> PyResult<String> {
234+
if pattern.is_instance_of::<PyString>() {
235+
Ok(pattern.to_string())
236+
} else {
237+
pattern
238+
.getattr("pattern")
239+
.and_then(|attr| attr.extract::<String>())
240+
.map_err(|_| py_schema_error_type!("Invalid pattern, must be str or re.Pattern: {}", pattern))
241+
}
242+
}
243+
244+
/// Builds a `Pattern` from a schema `pattern` value and the requested engine name.
///
/// The regex source text is obtained through `Self::extract_pattern_str` and stored in
/// the `pattern` field in every case.
///
/// * If `pattern` is an instance of Python's `re.Pattern`, that object itself is stored
///   as `RegexEngine::PythonRe` and the `engine` argument is not consulted.
/// * Otherwise `engine` selects the backend: `RegexEngine::RUST_REGEX` compiles the text
///   with `Regex::new`, and `RegexEngine::PYTHON_RE` calls Python's `re.compile` on the
///   original object.
///
/// # Errors
///
/// Returns an error if the pattern text cannot be extracted, if importing `re` or looking
/// up its attributes fails, if `Regex::new` or `re.compile` rejects the pattern, or if
/// `engine` is neither of the two known names ("Invalid regex engine: ...").
fn compile(pattern: Bound<'_, PyAny>, engine: &str) -> PyResult<Self> {
245+
let pattern_str = Self::extract_pattern_str(&pattern)?;
246+
247+
let py = pattern.py();
248+
249+
// Both `re.compile` and `re.Pattern` are looked up up front, regardless of which
// branch below ends up running.
let re_module = py.import_bound(intern!(py, "re"))?;
250+
let re_compile = re_module.getattr(intern!(py, "compile"))?;
251+
let re_pattern = re_module.getattr(intern!(py, "Pattern"))?;
252+
253+
if pattern.is_instance(&re_pattern)? {
254+
// if the pattern is already a compiled regex object, we default to using the python re engine
255+
// so that any flags, etc. are preserved
256+
Ok(Self {
257+
pattern: pattern_str,
258+
engine: RegexEngine::PythonRe(pattern.to_object(py)),
259+
})
260+
} else {
261+
let engine = match engine {
262+
RegexEngine::RUST_REGEX => {
263+
RegexEngine::RustRegex(Regex::new(&pattern_str).map_err(|e| py_schema_error_type!("{}", e))?)
264+
}
265+
RegexEngine::PYTHON_RE => RegexEngine::PythonRe(re_compile.call1((pattern,))?.into()),
266+
_ => return Err(py_schema_error_type!("Invalid regex engine: {}", engine)),
267+
};
268+
269+
Ok(Self {
270+
pattern: pattern_str,
271+
engine,
272+
})
273+
}
245274
}
246275

247276
fn is_match(&self, py: Python<'_>, target: &str) -> PyResult<bool> {

tests/validators/test_string.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,3 +398,10 @@ def test_coerce_numbers_to_str_schema_with_strict_mode(number: int):
398398
v.validate_python(number)
399399
with pytest.raises(ValidationError):
400400
v.validate_json(str(number))
401+
402+
403+
# Passes a pre-compiled `re.Pattern` (built with `re.IGNORECASE`) as the `pattern` of a
# `str_schema`, once for each `regex_engine` value: None, 'rust-regex' and 'python-re'.
# For every engine the validator must accept both 'abc' and 'ABC', i.e. the
# case-insensitive flag of the compiled pattern has to be honoured whichever engine
# name was requested.
@pytest.mark.parametrize('engine', [None, 'rust-regex', 'python-re'])
404+
def test_compiled_regex(engine) -> None:
405+
v = SchemaValidator(core_schema.str_schema(pattern=re.compile('abc', re.IGNORECASE), regex_engine=engine))
406+
assert v.validate_python('abc') == 'abc'
407+
assert v.validate_python('ABC') == 'ABC'

0 commit comments

Comments
 (0)