Skip to content

Commit a65f327

Browse files
Fix union validation logic when extra='allow' (#1334)
1 parent fcc77f8 commit a65f327

File tree

4 files changed

+62
-6
lines changed

4 files changed

+62
-6
lines changed

src/validators/model.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,6 @@ impl Validator for ModelValidator {
204204
for field_name in validated_fields_set {
205205
fields_set.add(field_name)?;
206206
}
207-
state.add_fields_set(fields_set.len());
208207
}
209208

210209
force_setattr(py, model, intern!(py, DUNDER_DICT), validated_dict.to_object(py))?;
@@ -244,11 +243,9 @@ impl ModelValidator {
244243
};
245244
force_setattr(py, self_instance, intern!(py, DUNDER_FIELDS_SET_KEY), &fields_set)?;
246245
force_setattr(py, self_instance, intern!(py, ROOT_FIELD), &output)?;
247-
state.add_fields_set(fields_set.len());
248246
} else {
249247
let (model_dict, model_extra, fields_set): (Bound<PyAny>, Bound<PyAny>, Bound<PyAny>) =
250248
output.extract(py)?;
251-
state.add_fields_set(fields_set.len().unwrap_or(0));
252249
set_model_attrs(self_instance, &model_dict, &model_extra, &fields_set)?;
253250
}
254251
self.call_post_init(py, self_instance.clone(), input, state.extra())
@@ -287,11 +284,10 @@ impl ModelValidator {
287284
};
288285
force_setattr(py, &instance, intern!(py, DUNDER_FIELDS_SET_KEY), &fields_set)?;
289286
force_setattr(py, &instance, intern!(py, ROOT_FIELD), output)?;
290-
state.add_fields_set(fields_set.len());
291287
} else {
292-
let (model_dict, model_extra, val_fields_set) = output.extract(py)?;
288+
let (model_dict, model_extra, val_fields_set): (Bound<PyAny>, Bound<PyAny>, Bound<PyAny>) =
289+
output.extract(py)?;
293290
let fields_set = existing_fields_set.unwrap_or(&val_fields_set);
294-
state.add_fields_set(fields_set.len().unwrap_or(0));
295291
set_model_attrs(&instance, &model_dict, &model_extra, fields_set)?;
296292
}
297293
self.call_post_init(py, instance, input, state.extra())

src/validators/model_fields.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ impl Validator for ModelFieldsValidator {
150150
let mut model_extra_dict_op: Option<Bound<PyDict>> = None;
151151
let mut errors: Vec<ValLineError> = Vec::with_capacity(self.fields.len());
152152
let mut fields_set_vec: Vec<Py<PyString>> = Vec::with_capacity(self.fields.len());
153+
let mut fields_set_count: usize = 0;
153154

154155
// we only care about which keys have been used if we're iterating over the object for extra after
155156
// the first pass
@@ -184,6 +185,7 @@ impl Validator for ModelFieldsValidator {
184185
Ok(value) => {
185186
model_dict.set_item(&field.name_py, value)?;
186187
fields_set_vec.push(field.name_py.clone_ref(py));
188+
fields_set_count += 1;
187189
}
188190
Err(ValError::Omit) => continue,
189191
Err(ValError::LineErrors(line_errors)) => {
@@ -327,6 +329,7 @@ impl Validator for ModelFieldsValidator {
327329
Err(ValError::LineErrors(errors))
328330
} else {
329331
let fields_set = PySet::new_bound(py, &fields_set_vec)?;
332+
state.add_fields_set(fields_set_count);
330333

331334
// if we have extra=allow, but we didn't create a dict because we were validating
332335
// from attributes, set it now so __pydantic_extra__ is always a dict if extra=allow

src/validators/validation_state.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ pub enum Exactness {
1818
pub struct ValidationState<'a, 'py> {
1919
pub recursion_guard: &'a mut RecursionState,
2020
pub exactness: Option<Exactness>,
21+
// This is used as a tie-breaking mechanism for union validation.
22+
// Note: the count of the fields set is not always equivalent to the length of the
23+
// `model_fields_set` attached to a model. `model_fields_set` includes extra fields
24+
// when extra='allow', whereas this tally does not.
2125
pub fields_set_count: Option<usize>,
2226
// deliberately make Extra readonly
2327
extra: Extra<'a, 'py>,

tests/validators/test_union.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1280,3 +1280,56 @@ class ModelB:
12801280
)
12811281
assert isinstance(result, ModelB)
12821282
assert isinstance(result.b, SubModelW)
1283+
1284+
1285+
@pytest.mark.parametrize('extra_behavior', ['forbid', 'ignore', 'allow'])
1286+
def test_smart_union_extra_behavior(extra_behavior) -> None:
1287+
class Foo:
1288+
foo: str = 'foo'
1289+
1290+
class Bar:
1291+
bar: str = 'bar'
1292+
1293+
class Model:
1294+
x: Union[Foo, Bar]
1295+
1296+
validator = SchemaValidator(
1297+
core_schema.model_schema(
1298+
Model,
1299+
core_schema.model_fields_schema(
1300+
fields={
1301+
'x': core_schema.model_field(
1302+
core_schema.union_schema(
1303+
[
1304+
core_schema.model_schema(
1305+
Foo,
1306+
core_schema.model_fields_schema(
1307+
fields={
1308+
'foo': core_schema.model_field(
1309+
core_schema.with_default_schema(core_schema.str_schema(), default='foo')
1310+
)
1311+
}
1312+
),
1313+
extra_behavior=extra_behavior,
1314+
),
1315+
core_schema.model_schema(
1316+
Bar,
1317+
core_schema.model_fields_schema(
1318+
fields={
1319+
'bar': core_schema.model_field(
1320+
core_schema.with_default_schema(core_schema.str_schema(), default='bar')
1321+
)
1322+
}
1323+
),
1324+
extra_behavior=extra_behavior,
1325+
),
1326+
]
1327+
)
1328+
)
1329+
}
1330+
),
1331+
)
1332+
)
1333+
1334+
assert isinstance(validator.validate_python({'x': {'foo': 'foo'}}).x, Foo)
1335+
assert isinstance(validator.validate_python({'x': {'bar': 'bar'}}).x, Bar)

0 commit comments

Comments
 (0)