Skip to content

Commit 379c06a

Browse files
authored
Strict typed dict validator (#3408)
* Strict typed dict validator * better type hint * work with Required / NotREquired * Make it work with Python 3.11 * correctly handle total=False * fix python 3.9+ * fix python 3.9
1 parent 942fe42 commit 379c06a

File tree

3 files changed

+272
-5
lines changed

3 files changed

+272
-5
lines changed

docs/source/en/package_reference/dataclasses.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,14 @@ The `@strict` decorator enhances a dataclass with strict validation.
188188

189189
[[autodoc]] dataclasses.strict
190190

191+
### `validate_typed_dict`
192+
193+
Method to validate that a dictionary conforms to the types defined in a `TypedDict` class.
194+
195+
This is the equivalent to dataclass validation but for `TypedDict`s. Since typed dicts are never instantiated (only used by static type checkers), validation step must be manually called.
196+
197+
[[autodoc]] dataclasses.validate_typed_dict
198+
191199
### `as_validated_field`
192200

193201
Decorator to create a [`validated_field`]. Recommended for fields with a single validator to avoid boilerplate code.

src/huggingface_hub/dataclasses.py

Lines changed: 132 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,33 @@
11
import inspect
2-
from dataclasses import _MISSING_TYPE, MISSING, Field, field, fields
3-
from functools import wraps
4-
from typing import Any, Callable, ForwardRef, Literal, Optional, Type, TypeVar, Union, get_args, get_origin, overload
2+
from dataclasses import _MISSING_TYPE, MISSING, Field, field, fields, make_dataclass
3+
from functools import lru_cache, wraps
4+
from typing import (
5+
Annotated,
6+
Any,
7+
Callable,
8+
ForwardRef,
9+
Literal,
10+
Optional,
11+
Type,
12+
TypeVar,
13+
Union,
14+
get_args,
15+
get_origin,
16+
overload,
17+
)
18+
19+
20+
try:
21+
# Python 3.11+
22+
from typing import NotRequired, Required # type: ignore
23+
except ImportError:
24+
try:
25+
# In case typing_extensions is installed
26+
from typing_extensions import NotRequired, Required # type: ignore
27+
except ImportError:
28+
# Fallback: create dummy types that will never match
29+
Required = type("Required", (), {}) # type: ignore
30+
NotRequired = type("NotRequired", (), {}) # type: ignore
531

632
from .errors import (
733
StrictDataclassClassValidationError,
@@ -12,6 +38,9 @@
1238

1339
Validator_T = Callable[[Any], None]
1440
T = TypeVar("T")
41+
TypedDictType = TypeVar("TypedDictType", bound=dict[str, Any])
42+
43+
_TYPED_DICT_DEFAULT_VALUE = object() # used as default value in TypedDict fields (to distinguish from None)
1544

1645

1746
# The overload decorator helps type checkers understand the different return types
@@ -223,6 +252,92 @@ def init_with_validate(self, *args, **kwargs) -> None:
223252
return wrap(cls) if cls is not None else wrap
224253

225254

255+
def validate_typed_dict(schema: type[TypedDictType], data: dict) -> None:
256+
"""
257+
Validate that a dictionary conforms to the types defined in a TypedDict class.
258+
259+
Under the hood, the typed dict is converted to a strict dataclass and validated using the `@strict` decorator.
260+
261+
Args:
262+
schema (`type[TypedDictType]`):
263+
The TypedDict class defining the expected structure and types.
264+
data (`dict`):
265+
The dictionary to validate.
266+
267+
Raises:
268+
`StrictDataclassFieldValidationError`:
269+
If any field in the dictionary does not conform to the expected type.
270+
271+
Example:
272+
```py
273+
>>> from typing import Annotated, TypedDict
274+
>>> from huggingface_hub.dataclasses import validate_typed_dict
275+
276+
>>> def positive_int(value: int):
277+
... if not value >= 0:
278+
... raise ValueError(f"Value must be positive, got {value}")
279+
280+
>>> class User(TypedDict):
281+
... name: str
282+
... age: Annotated[int, positive_int]
283+
284+
>>> # Valid data
285+
>>> validate_typed_dict(User, {"name": "John", "age": 30})
286+
287+
>>> # Invalid type for age
288+
>>> validate_typed_dict(User, {"name": "John", "age": "30"})
289+
huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age':
290+
TypeError: Field 'age' expected int, got str (value: '30')
291+
292+
>>> # Invalid value for age
293+
>>> validate_typed_dict(User, {"name": "John", "age": -1})
294+
huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age':
295+
ValueError: Value must be positive, got -1
296+
```
297+
"""
298+
# Convert typed dict to dataclass
299+
strict_cls = _build_strict_cls_from_typed_dict(schema)
300+
301+
# Validate the data by instantiating the strict dataclass
302+
strict_cls(**data) # will raise if validation fails
303+
304+
305+
@lru_cache
306+
def _build_strict_cls_from_typed_dict(schema: type[TypedDictType]) -> Type:
307+
# Extract type hints from the TypedDict class
308+
type_hints = {
309+
# We do not use `get_type_hints` here to avoid evaluating ForwardRefs (which might fail).
310+
# ForwardRefs are not validated by @strict anyway.
311+
name: value if value is not None else type(None)
312+
for name, value in schema.__dict__.get("__annotations__", {}).items()
313+
}
314+
315+
# If the TypedDict is not total, wrap fields as NotRequired (unless explicitly Required or NotRequired)
316+
if not getattr(schema, "__total__", True):
317+
for key, value in type_hints.items():
318+
origin = get_origin(value)
319+
320+
if origin is Annotated:
321+
base, *meta = get_args(value)
322+
if not _is_required_or_notrequired(base):
323+
base = NotRequired[base]
324+
type_hints[key] = Annotated[tuple([base] + list(meta))]
325+
elif not _is_required_or_notrequired(value):
326+
type_hints[key] = NotRequired[value]
327+
328+
# Convert type hints to dataclass fields
329+
fields = []
330+
for key, value in type_hints.items():
331+
if get_origin(value) is Annotated:
332+
base, *meta = get_args(value)
333+
fields.append((key, base, field(default=_TYPED_DICT_DEFAULT_VALUE, metadata={"validator": meta[0]})))
334+
else:
335+
fields.append((key, value, field(default=_TYPED_DICT_DEFAULT_VALUE)))
336+
337+
# Create a strict dataclass from the TypedDict fields
338+
return strict(make_dataclass(schema.__name__, fields))
339+
340+
226341
def validated_field(
227342
validator: Union[list[Validator_T], Validator_T],
228343
default: Union[Any, _MISSING_TYPE] = MISSING,
@@ -313,6 +428,14 @@ def type_validator(name: str, value: Any, expected_type: Any) -> None:
313428
_validate_simple_type(name, value, expected_type)
314429
elif isinstance(expected_type, ForwardRef) or isinstance(expected_type, str):
315430
return
431+
elif origin is Required:
432+
if value is _TYPED_DICT_DEFAULT_VALUE:
433+
raise TypeError(f"Field '{name}' is required but missing.")
434+
_validate_simple_type(name, value, args[0])
435+
elif origin is NotRequired:
436+
if value is _TYPED_DICT_DEFAULT_VALUE:
437+
return
438+
_validate_simple_type(name, value, args[0])
316439
else:
317440
raise TypeError(f"Unsupported type for field '{name}': {expected_type}")
318441

@@ -449,6 +572,11 @@ def _is_validator(validator: Any) -> bool:
449572
return True
450573

451574

575+
def _is_required_or_notrequired(type_hint: Any) -> bool:
576+
"""Helper to check if a type is Required/NotRequired."""
577+
return type_hint in (Required, NotRequired) or (get_origin(type_hint) in (Required, NotRequired))
578+
579+
452580
_BASIC_TYPE_VALIDATORS = {
453581
Union: _validate_union,
454582
Literal: _validate_literal,
@@ -461,6 +589,7 @@ def _is_validator(validator: Any) -> bool:
461589

462590
__all__ = [
463591
"strict",
592+
"validate_typed_dict",
464593
"validated_field",
465594
"Validator_T",
466595
"StrictDataclassClassValidationError",

tests/test_utils_strict_dataclass.py

Lines changed: 132 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,27 @@
11
import inspect
2+
import sys
23
from dataclasses import asdict, astuple, dataclass, is_dataclass
3-
from typing import Any, Literal, Optional, Union, get_type_hints
4+
from typing import Annotated, Any, Literal, Optional, TypedDict, Union, get_type_hints
45

56
import jedi
67
import pytest
78

8-
from huggingface_hub.dataclasses import _is_validator, as_validated_field, strict, type_validator, validated_field
9+
10+
if sys.version_info >= (3, 11):
11+
from typing import NotRequired, Required
12+
else:
13+
# Provide fallbacks or skip the entire module
14+
NotRequired = None
15+
Required = None
16+
from huggingface_hub.dataclasses import (
17+
_build_strict_cls_from_typed_dict,
18+
_is_validator,
19+
as_validated_field,
20+
strict,
21+
type_validator,
22+
validate_typed_dict,
23+
validated_field,
24+
)
925
from huggingface_hub.errors import (
1026
StrictDataclassClassValidationError,
1127
StrictDataclassDefinitionError,
@@ -646,3 +662,117 @@ def validate(self):
646662
@dataclass
647663
class ConfigWithParent(ParentClass): # 'validate' already defined => should raise an error
648664
foo: int = 0
665+
666+
667+
class ConfigDict(TypedDict):
668+
str_value: str
669+
positive_int_value: Annotated[int, positive_int]
670+
forward_ref_value: "ForwardDtype"
671+
optional_value: Optional[int]
672+
673+
674+
@pytest.mark.parametrize(
675+
"data",
676+
[
677+
# All values are valid
678+
{"str_value": "foo", "positive_int_value": 1, "forward_ref_value": "bar", "optional_value": 0},
679+
],
680+
)
681+
def test_typed_dict_valid_data(data: dict):
682+
validate_typed_dict(ConfigDict, data)
683+
684+
685+
@pytest.mark.parametrize(
686+
"data",
687+
[
688+
# Optional value cannot be omitted
689+
{"str_value": "foo", "positive_int_value": 1, "forward_ref_value": "bar"},
690+
# Other fields neither
691+
{"positive_int_value": 1, "forward_ref_value": "bar", "optional_value": 0},
692+
# Not a string
693+
{"str_value": 123, "positive_int_value": 1, "forward_ref_value": "bar", "optional_value": 0},
694+
# Not an integer
695+
{"str_value": "foo", "positive_int_value": "1", "forward_ref_value": "bar", "optional_value": 0},
696+
# Annotated validator is used
697+
{"str_value": "foo", "positive_int_value": -1, "forward_ref_value": "bar", "optional_value": 0},
698+
],
699+
)
700+
def test_typed_dict_invalid_data(data: dict):
701+
with pytest.raises(StrictDataclassFieldValidationError):
702+
validate_typed_dict(ConfigDict, data)
703+
704+
705+
def test_typed_dict_error_message():
706+
with pytest.raises(StrictDataclassFieldValidationError) as exception:
707+
validate_typed_dict(
708+
ConfigDict, {"str_value": 123, "positive_int_value": 1, "forward_ref_value": "bar", "optional_value": 0}
709+
)
710+
assert "Validation error for field 'str_value'" in str(exception.value)
711+
assert "Field 'str_value' expected str, got int (value: 123)" in str(exception.value)
712+
713+
714+
def test_typed_dict_unknown_attribute():
715+
with pytest.raises(TypeError):
716+
validate_typed_dict(
717+
ConfigDict,
718+
{
719+
"str_value": "foo",
720+
"positive_int_value": 1,
721+
"forward_ref_value": "bar",
722+
"optional_value": 0,
723+
"another_value": 0,
724+
},
725+
)
726+
727+
728+
def test_typed_dict_to_dataclass_is_cached():
729+
strict_cls = _build_strict_cls_from_typed_dict(ConfigDict)
730+
strict_cls_bis = _build_strict_cls_from_typed_dict(ConfigDict)
731+
assert strict_cls is strict_cls_bis # "is" because dataclass is built only once
732+
733+
734+
@pytest.mark.skipif(sys.version_info < (3, 11), reason="Requires Python 3.11+")
735+
class TestConfigDictNotRequired:
736+
def __init__(self):
737+
# cannot be defined at class level because of Python<3.11
738+
self.ConfigDictNotRequired = TypedDict(
739+
"ConfigDictNotRequired",
740+
{"required_value": Required[int], "not_required_value": NotRequired[int]},
741+
total=False,
742+
)
743+
744+
@pytest.mark.parametrize(
745+
"data",
746+
[
747+
{"required_value": 1, "not_required_value": 2},
748+
{"required_value": 1}, # not required value is not validated
749+
],
750+
)
751+
def test_typed_dict_not_required_valid_data(self, data: dict):
752+
validate_typed_dict(self.ConfigDictNotRequired, data)
753+
754+
@pytest.mark.parametrize(
755+
"data",
756+
[
757+
# Missing required value
758+
{"not_required_value": 2},
759+
# If exists, the value is validated
760+
{"required_value": 1, "not_required_value": "2"},
761+
],
762+
)
763+
def test_typed_dict_not_required_invalid_data(self, data: dict):
764+
with pytest.raises(StrictDataclassFieldValidationError):
765+
validate_typed_dict(self.ConfigDictNotRequired, data)
766+
767+
768+
def test_typed_dict_total_true():
769+
ConfigDictTotalTrue = TypedDict("ConfigDictTotalTrue", {"value": int}, total=True)
770+
validate_typed_dict(ConfigDictTotalTrue, {"value": 1})
771+
with pytest.raises(StrictDataclassFieldValidationError):
772+
validate_typed_dict(ConfigDictTotalTrue, {})
773+
774+
775+
def test_typed_dict_total_false():
776+
ConfigDictTotalFalse = TypedDict("ConfigDictTotalFalse", {"value": int}, total=False)
777+
validate_typed_dict(ConfigDictTotalFalse, {})
778+
validate_typed_dict(ConfigDictTotalFalse, {"value": 1})

0 commit comments

Comments
 (0)