Skip to content

Commit 843bcc9

Browse files
committed
Handling of defaultdicts
1 parent 53f12bb commit 843bcc9

File tree

2 files changed

+12
-9
lines changed

2 files changed

+12
-9
lines changed

test/test_prototype_transforms_consistency.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import inspect
33
import random
44
import re
5+
from collections import defaultdict
56
from importlib.machinery import SourceFileLoader
67
from pathlib import Path
78

@@ -1039,9 +1040,7 @@ def check(self, t, t_ref, data_kwargs=None):
10391040
seg_transforms.RandomCrop(size=480),
10401041
prototype_transforms.Compose(
10411042
[
1042-
PadIfSmaller(
1043-
size=480, fill={features.Mask: 255, features.Image: 0, PIL.Image.Image: 0, torch.Tensor: 0}
1044-
),
1043+
PadIfSmaller(size=480, fill=defaultdict(lambda: 0, {features.Mask: 255})),
10451044
prototype_transforms.RandomCrop(size=480),
10461045
]
10471046
),

torchvision/prototype/transforms/_utils.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,13 @@ def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size:
3333

3434

3535
def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None:
36-
if type(fill) == dict:
37-
# Do exact type check to avoid accepting default dicts from the user. DefaultDict values can be verified only
38-
# at runtime not at construction type.
36+
if isinstance(fill, dict):
3937
for key, value in fill.items():
4038
# Check key for type
4139
_check_fill_arg(value)
40+
if isinstance(fill, defaultdict) and callable(fill.default_factory):
41+
default_value = fill.default_factory()
42+
_check_fill_arg(default_value)
4243
else:
4344
if fill is not None and not isinstance(fill, (numbers.Number, tuple, list)):
4445
raise TypeError("Got inappropriate fill arg, only Numbers, tuples, lists and dicts are allowed.")
@@ -75,10 +76,13 @@ def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, F
7576
_check_fill_arg(fill)
7677

7778
if isinstance(fill, dict):
78-
fill_copy = {}
7979
for k, v in fill.items():
80-
fill_copy[k] = _convert_fill_arg(v)
81-
return fill_copy
80+
fill[k] = _convert_fill_arg(v)
81+
if isinstance(fill, defaultdict) and callable(fill.default_factory):
82+
default_value = fill.default_factory()
83+
sanitized_default = _convert_fill_arg(default_value)
84+
fill.default_factory = functools.partial(_default_arg, sanitized_default)
85+
return fill # type: ignore[return-value]
8286

8387
return _get_defaultdict(_convert_fill_arg(fill))
8488

0 commit comments

Comments
 (0)