11import inspect
2- from dataclasses import _MISSING_TYPE , MISSING , Field , field , fields
3- from functools import wraps
4- from typing import Any , Callable , ForwardRef , Literal , Optional , Type , TypeVar , Union , get_args , get_origin , overload
2+ from dataclasses import _MISSING_TYPE , MISSING , Field , field , fields , make_dataclass
3+ from functools import lru_cache , wraps
4+ from typing import (
5+ Annotated ,
6+ Any ,
7+ Callable ,
8+ ForwardRef ,
9+ Literal ,
10+ Optional ,
11+ Type ,
12+ TypeVar ,
13+ Union ,
14+ get_args ,
15+ get_origin ,
16+ overload ,
17+ )
18+
19+
20+ try :
21+ # Python 3.11+
22+ from typing import NotRequired , Required # type: ignore
23+ except ImportError :
24+ try :
25+ # In case typing_extensions is installed
26+ from typing_extensions import NotRequired , Required # type: ignore
27+ except ImportError :
28+ # Fallback: create dummy types that will never match
29+ Required = type ("Required" , (), {}) # type: ignore
30+ NotRequired = type ("NotRequired" , (), {}) # type: ignore
531
632from .errors import (
733 StrictDataclassClassValidationError ,
1238
1339Validator_T = Callable [[Any ], None ]
1440T = TypeVar ("T" )
41+ TypedDictType = TypeVar ("TypedDictType" , bound = dict [str , Any ])
42+
43+ _TYPED_DICT_DEFAULT_VALUE = object () # used as default value in TypedDict fields (to distinguish from None)
1544
1645
1746# The overload decorator helps type checkers understand the different return types
@@ -223,6 +252,92 @@ def init_with_validate(self, *args, **kwargs) -> None:
223252 return wrap (cls ) if cls is not None else wrap
224253
225254
def validate_typed_dict(schema: type[TypedDictType], data: dict) -> None:
    """
    Validate that a dictionary conforms to the types defined in a TypedDict class.

    Under the hood, the typed dict is converted to a strict dataclass and validated using the `@strict` decorator.
    The dataclass instance created for the check is discarded: this function only validates, it does not
    return or modify `data`.

    Args:
        schema (`type[TypedDictType]`):
            The TypedDict class defining the expected structure and types.
        data (`dict`):
            The dictionary to validate.

    Returns:
        `None`. The function returns normally when `data` is valid.

    Raises:
        `StrictDataclassFieldValidationError`:
            If any field in the dictionary does not conform to the expected type.

    Example:
    ```py
    >>> from typing import Annotated, TypedDict
    >>> from huggingface_hub.dataclasses import validate_typed_dict

    >>> def non_negative_int(value: int):
    ...     if not value >= 0:
    ...         raise ValueError(f"Value must be non-negative, got {value}")

    >>> class User(TypedDict):
    ...     name: str
    ...     age: Annotated[int, non_negative_int]

    >>> # Valid data
    >>> validate_typed_dict(User, {"name": "John", "age": 30})

    >>> # Invalid type for age
    >>> validate_typed_dict(User, {"name": "John", "age": "30"})
    huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age':
        TypeError: Field 'age' expected int, got str (value: '30')

    >>> # Invalid value for age
    >>> validate_typed_dict(User, {"name": "John", "age": -1})
    huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age':
        ValueError: Value must be non-negative, got -1
    ```
    """
    # Convert the typed dict to a strict dataclass (the conversion is cached per schema).
    strict_cls = _build_strict_cls_from_typed_dict(schema)

    # Validate the data by instantiating the strict dataclass: each key of `data` is passed
    # as a keyword argument, and instantiation raises if validation fails.
    strict_cls(**data)
303+
304+
@lru_cache
def _build_strict_cls_from_typed_dict(schema: type[TypedDictType]) -> Type:
    """Build (and cache) a strict dataclass whose fields mirror a TypedDict class.

    Every generated field defaults to `_TYPED_DICT_DEFAULT_VALUE`, so a key missing from the
    validated dict shows up as that sentinel instead of failing at instantiation. For a field
    annotated with `Annotated[...]`, the first metadata item is registered as the field's
    validator.

    Args:
        schema (`type[TypedDictType]`):
            The TypedDict class to convert. It is used as the cache key.

    Returns:
        `Type`: The dataclass built with `make_dataclass` and wrapped with `strict`.
    """
    # We do not use `get_type_hints` here to avoid evaluating ForwardRefs (which might fail).
    # ForwardRefs are not validated by @strict anyway.
    raw_annotations = schema.__dict__.get("__annotations__", {})
    # A bare `None` annotation is normalized to `NoneType`.
    type_hints = {name: type(None) if hint is None else hint for name, hint in raw_annotations.items()}

    # If the TypedDict is not total, wrap fields as NotRequired (unless explicitly Required or NotRequired).
    # NOTE(review): when `Required`/`NotRequired` fall back to the dummy classes defined at import
    # time, `NotRequired[...]` is not subscriptable and this branch would raise — confirm that
    # configuration is not expected to be supported.
    if not getattr(schema, "__total__", True):
        type_hints = {name: _wrap_as_not_required(hint) for name, hint in type_hints.items()}

    # Convert type hints to dataclass field specs: (name, type, field(...)).
    # Named `dataclass_fields` so the imported `dataclasses.fields` is not shadowed.
    dataclass_fields = []
    for name, hint in type_hints.items():
        if get_origin(hint) is Annotated:
            base, *meta = get_args(hint)
            # NOTE(review): only the first `Annotated` metadata item is used as validator;
            # any additional metadata items are ignored.
            spec = field(default=_TYPED_DICT_DEFAULT_VALUE, metadata={"validator": meta[0]})
            dataclass_fields.append((name, base, spec))
        else:
            dataclass_fields.append((name, hint, field(default=_TYPED_DICT_DEFAULT_VALUE)))

    # Create a strict dataclass from the TypedDict fields
    return strict(make_dataclass(schema.__name__, dataclass_fields))


def _wrap_as_not_required(type_hint: Any) -> Any:
    """Wrap `type_hint` in `NotRequired` unless it is already marked Required/NotRequired.

    For `Annotated[base, *meta]`, the wrapping is applied to `base` and the metadata is kept.
    """
    if get_origin(type_hint) is Annotated:
        base, *meta = get_args(type_hint)
        if not _is_required_or_notrequired(base):
            base = NotRequired[base]
        # Subscripting with a tuple is equivalent to `Annotated[base, meta_0, meta_1, ...]`.
        return Annotated[(base, *meta)]
    if _is_required_or_notrequired(type_hint):
        return type_hint
    return NotRequired[type_hint]
339+
340+
226341def validated_field (
227342 validator : Union [list [Validator_T ], Validator_T ],
228343 default : Union [Any , _MISSING_TYPE ] = MISSING ,
@@ -313,6 +428,14 @@ def type_validator(name: str, value: Any, expected_type: Any) -> None:
313428 _validate_simple_type (name , value , expected_type )
314429 elif isinstance (expected_type , ForwardRef ) or isinstance (expected_type , str ):
315430 return
431+ elif origin is Required :
432+ if value is _TYPED_DICT_DEFAULT_VALUE :
433+ raise TypeError (f"Field '{ name } ' is required but missing." )
434+ _validate_simple_type (name , value , args [0 ])
435+ elif origin is NotRequired :
436+ if value is _TYPED_DICT_DEFAULT_VALUE :
437+ return
438+ _validate_simple_type (name , value , args [0 ])
316439 else :
317440 raise TypeError (f"Unsupported type for field '{ name } ': { expected_type } " )
318441
@@ -449,6 +572,11 @@ def _is_validator(validator: Any) -> bool:
449572 return True
450573
451574
575+ def _is_required_or_notrequired (type_hint : Any ) -> bool :
576+ """Helper to check if a type is Required/NotRequired."""
577+ return type_hint in (Required , NotRequired ) or (get_origin (type_hint ) in (Required , NotRequired ))
578+
579+
452580_BASIC_TYPE_VALIDATORS = {
453581 Union : _validate_union ,
454582 Literal : _validate_literal ,
@@ -461,6 +589,7 @@ def _is_validator(validator: Any) -> bool:
461589
462590__all__ = [
463591 "strict" ,
592+ "validate_typed_dict" ,
464593 "validated_field" ,
465594 "Validator_T" ,
466595 "StrictDataclassClassValidationError" ,
0 commit comments