Skip to content

Commit 2d6e663

Browse files
pmeierNicolasHug
andauthored
make transforms v2 get_params a staticmethod (#7177)
Co-authored-by: Nicolas Hug <[email protected]>
1 parent bac678c commit 2d6e663

File tree

2 files changed

+48
-27
lines changed

2 files changed

+48
-27
lines changed

test/test_prototype_transforms_consistency.py

Lines changed: 47 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -649,37 +649,58 @@ def test_call_consistency(config, args_kwargs):
649649
)
650650

651651

652-
@pytest.mark.parametrize(
653-
"config",
654-
[config for config in CONSISTENCY_CONFIGS if hasattr(config.legacy_cls, "get_params")],
655-
ids=lambda config: config.legacy_cls.__name__,
652+
get_params_parametrization = pytest.mark.parametrize(
653+
("config", "get_params_args_kwargs"),
654+
[
655+
pytest.param(
656+
next(config for config in CONSISTENCY_CONFIGS if config.prototype_cls is transform_cls),
657+
get_params_args_kwargs,
658+
id=transform_cls.__name__,
659+
)
660+
for transform_cls, get_params_args_kwargs in [
661+
(prototype_transforms.RandomResizedCrop, ArgsKwargs(make_image(), scale=[0.3, 0.7], ratio=[0.5, 1.5])),
662+
(prototype_transforms.RandomErasing, ArgsKwargs(make_image(), scale=(0.3, 0.7), ratio=(0.5, 1.5))),
663+
(prototype_transforms.ColorJitter, ArgsKwargs(brightness=None, contrast=None, saturation=None, hue=None)),
664+
(prototype_transforms.ElasticTransform, ArgsKwargs(alpha=[15.3, 27.2], sigma=[2.5, 3.9], size=[17, 31])),
665+
(prototype_transforms.GaussianBlur, ArgsKwargs(0.3, 1.4)),
666+
(
667+
prototype_transforms.RandomAffine,
668+
ArgsKwargs(degrees=[-20.0, 10.0], translate=None, scale_ranges=None, shears=None, img_size=[15, 29]),
669+
),
670+
(prototype_transforms.RandomCrop, ArgsKwargs(make_image(size=(61, 47)), output_size=(19, 25))),
671+
(prototype_transforms.RandomPerspective, ArgsKwargs(23, 17, 0.5)),
672+
(prototype_transforms.RandomRotation, ArgsKwargs(degrees=[-20.0, 10.0])),
673+
(prototype_transforms.AutoAugment, ArgsKwargs(5)),
674+
]
675+
],
656676
)
657-
def test_get_params_alias(config):
677+
678+
679+
@get_paramsl_parametrization
680+
def test_get_params_alias(config, get_params_args_kwargs):
658681
assert config.prototype_cls.get_params is config.legacy_cls.get_params
659682

683+
if not config.args_kwargs:
684+
return
685+
args, kwargs = config.args_kwargs[0]
686+
legacy_transform = config.legacy_cls(*args, **kwargs)
687+
prototype_transform = config.prototype_cls(*args, **kwargs)
660688

661-
@pytest.mark.parametrize(
662-
("transform_cls", "args_kwargs"),
663-
[
664-
(prototype_transforms.RandomResizedCrop, ArgsKwargs(make_image(), scale=[0.3, 0.7], ratio=[0.5, 1.5])),
665-
(prototype_transforms.RandomErasing, ArgsKwargs(make_image(), scale=(0.3, 0.7), ratio=(0.5, 1.5))),
666-
(prototype_transforms.ColorJitter, ArgsKwargs(brightness=None, contrast=None, saturation=None, hue=None)),
667-
(prototype_transforms.ElasticTransform, ArgsKwargs(alpha=[15.3, 27.2], sigma=[2.5, 3.9], size=[17, 31])),
668-
(prototype_transforms.GaussianBlur, ArgsKwargs(0.3, 1.4)),
669-
(
670-
prototype_transforms.RandomAffine,
671-
ArgsKwargs(degrees=[-20.0, 10.0], translate=None, scale_ranges=None, shears=None, img_size=[15, 29]),
672-
),
673-
(prototype_transforms.RandomCrop, ArgsKwargs(make_image(size=(61, 47)), output_size=(19, 25))),
674-
(prototype_transforms.RandomPerspective, ArgsKwargs(23, 17, 0.5)),
675-
(prototype_transforms.RandomRotation, ArgsKwargs(degrees=[-20.0, 10.0])),
676-
(prototype_transforms.AutoAugment, ArgsKwargs(5)),
677-
],
678-
)
679-
def test_get_params_jit(transform_cls, args_kwargs):
680-
args, kwargs = args_kwargs
689+
assert prototype_transform.get_params is legacy_transform.get_params
690+
691+
692+
@get_paramsl_parametrization
693+
def test_get_params_jit(config, get_params_args_kwargs):
694+
get_params_args, get_params_kwargs = get_params_args_kwargs
695+
696+
torch.jit.script(config.prototype_cls.get_params)(*get_params_args, **get_params_kwargs)
697+
698+
if not config.args_kwargs:
699+
return
700+
args, kwargs = config.args_kwargs[0]
701+
transform = config.prototype_cls(*args, **kwargs)
681702

682-
torch.jit.script(transform_cls.get_params)(*args, **kwargs)
703+
torch.jit.script(transform.get_params)(*get_params_args, **get_params_kwargs)
683704

684705

685706
@pytest.mark.parametrize(

torchvision/prototype/transforms/_transform.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def __init_subclass__(cls) -> None:
6767
# Since `get_params` is a `@staticmethod`, we have to bind it to the class itself rather than to an instance.
6868
# This method is called after subclassing has happened, i.e. `cls` is the subclass, e.g. `Resize`.
6969
if cls._v1_transform_cls is not None and hasattr(cls._v1_transform_cls, "get_params"):
70-
cls.get_params = cls._v1_transform_cls.get_params # type: ignore[attr-defined]
70+
cls.get_params = staticmethod(cls._v1_transform_cls.get_params) # type: ignore[attr-defined]
7171

7272
def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
7373
# This method is called by `__prepare_scriptable__` to instantiate the equivalent v1 transform from the current

0 commit comments

Comments
 (0)