@@ -33,12 +33,13 @@ def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size:
33
33
34
34
35
35
def _check_fill_arg (fill : Union [FillType , Dict [Type , FillType ]]) -> None :
36
- if type (fill ) == dict :
37
- # Do exact type check to avoid accepting default dicts from the user. DefaultDict values can be verified only
38
- # at runtime not at construction type.
36
+ if isinstance (fill , dict ):
39
37
for key , value in fill .items ():
40
38
# Check key for type
41
39
_check_fill_arg (value )
40
+ if isinstance (fill , defaultdict ) and callable (fill .default_factory ):
41
+ default_value = fill .default_factory ()
42
+ _check_fill_arg (default_value )
42
43
else :
43
44
if fill is not None and not isinstance (fill , (numbers .Number , tuple , list )):
44
45
raise TypeError ("Got inappropriate fill arg, only Numbers, tuples, lists and dicts are allowed." )
@@ -75,10 +76,13 @@ def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, F
75
76
_check_fill_arg (fill )
76
77
77
78
if isinstance (fill , dict ):
78
- fill_copy = {}
79
79
for k , v in fill .items ():
80
- fill_copy [k ] = _convert_fill_arg (v )
81
- return fill_copy
80
+ fill [k ] = _convert_fill_arg (v )
81
+ if isinstance (fill , defaultdict ) and callable (fill .default_factory ):
82
+ default_value = fill .default_factory ()
83
+ sanitized_default = _convert_fill_arg (default_value )
84
+ fill .default_factory = functools .partial (_default_arg , sanitized_default )
85
+ return fill # type: ignore[return-value]
82
86
83
87
return _get_defaultdict (_convert_fill_arg (fill ))
84
88
0 commit comments