Skip to content

Commit 82c51c4

Browse files
authored
enable get_params alias for transforms v2 (#7153)
1 parent 6bd04f6 commit 82c51c4

File tree

2 files changed

+44
-2
lines changed

2 files changed

+44
-2
lines changed

test/test_prototype_transforms_consistency.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,39 @@ def test_call_consistency(config, args_kwargs):
655655
)
656656

657657

658+
@pytest.mark.parametrize(
659+
"config",
660+
[config for config in CONSISTENCY_CONFIGS if hasattr(config.legacy_cls, "get_params")],
661+
ids=lambda config: config.legacy_cls.__name__,
662+
)
663+
def test_get_params_alias(config):
664+
assert config.prototype_cls.get_params is config.legacy_cls.get_params
665+
666+
667+
@pytest.mark.parametrize(
668+
("transform_cls", "args_kwargs"),
669+
[
670+
(prototype_transforms.RandomResizedCrop, ArgsKwargs(make_image(), scale=[0.3, 0.7], ratio=[0.5, 1.5])),
671+
(prototype_transforms.RandomErasing, ArgsKwargs(make_image(), scale=(0.3, 0.7), ratio=(0.5, 1.5))),
672+
(prototype_transforms.ColorJitter, ArgsKwargs(brightness=None, contrast=None, saturation=None, hue=None)),
673+
(prototype_transforms.ElasticTransform, ArgsKwargs(alpha=[15.3, 27.2], sigma=[2.5, 3.9], size=[17, 31])),
674+
(prototype_transforms.GaussianBlur, ArgsKwargs(0.3, 1.4)),
675+
(
676+
prototype_transforms.RandomAffine,
677+
ArgsKwargs(degrees=[-20.0, 10.0], translate=None, scale_ranges=None, shears=None, img_size=[15, 29]),
678+
),
679+
(prototype_transforms.RandomCrop, ArgsKwargs(make_image(size=(61, 47)), output_size=(19, 25))),
680+
(prototype_transforms.RandomPerspective, ArgsKwargs(23, 17, 0.5)),
681+
(prototype_transforms.RandomRotation, ArgsKwargs(degrees=[-20.0, 10.0])),
682+
(prototype_transforms.AutoAugment, ArgsKwargs(5)),
683+
],
684+
)
685+
def test_get_params_jit(transform_cls, args_kwargs):
686+
args, kwargs = args_kwargs
687+
688+
torch.jit.script(transform_cls.get_params)(*args, **kwargs)
689+
690+
658691
@pytest.mark.parametrize(
659692
("config", "args_kwargs"),
660693
[

torchvision/prototype/transforms/_transform.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,19 @@ def extra_repr(self) -> str:
5656

5757
return ", ".join(extra)
5858

59-
# This attribute should be set on all transforms that have a v1 equivalent. Doing so enables the v2 transformation
60-
# to be scriptable. See `_extract_params_for_v1_transform()` and `__prepare_scriptable__` for details.
59+
# This attribute should be set on all transforms that have a v1 equivalent. Doing so enables two things:
60+
# 1. In case the v1 transform has a static `get_params` method, it will also be available under the same name on
61+
# the v2 transform. See `__init_subclass__` for details.
62+
# 2. The v2 transform will be JIT scriptable. See `_extract_params_for_v1_transform` and `__prepare_scriptable__`
63+
# for details.
6164
_v1_transform_cls: Optional[Type[nn.Module]] = None
6265

66+
def __init_subclass__(cls) -> None:
67+
# Since `get_params` is a `@staticmethod`, we have to bind it to the class itself rather than to an instance.
68+
# This method is called after subclassing has happened, i.e. `cls` is the subclass, e.g. `Resize`.
69+
if cls._v1_transform_cls is not None and hasattr(cls._v1_transform_cls, "get_params"):
70+
cls.get_params = cls._v1_transform_cls.get_params # type: ignore[attr-defined]
71+
6372
def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
6473
# This method is called by `__prepare_scriptable__` to instantiate the equivalent v1 transform from the current
6574
# v2 transform instance. It does two things:

0 commit comments

Comments
 (0)