diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py index f76c0f93d5e..2ac7e78e6a2 100644 --- a/test/test_prototype_transforms_consistency.py +++ b/test/test_prototype_transforms_consistency.py @@ -74,6 +74,7 @@ def __init__( legacy_transforms.Resize, [ ArgsKwargs(32), + ArgsKwargs([32]), ArgsKwargs((32, 29)), ArgsKwargs((31, 28), interpolation=prototype_transforms.InterpolationMode.NEAREST), ArgsKwargs((33, 26), interpolation=prototype_transforms.InterpolationMode.BICUBIC), diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 1cbf02d5ae2..a62fbf4263f 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -46,11 +46,16 @@ def __init__( ) -> None: super().__init__() - self.size = ( - [size] - if isinstance(size, int) - else _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") - ) + if isinstance(size, int): + size = [size] + elif isinstance(size, (list, tuple)) and len(size) in {1, 2}: + size = list(size) + else: + raise ValueError( + f"size can either be an integer or a list or tuple of one or two integers, " f"but got {size} instead." + ) + self.size = size + self.interpolation = interpolation self.max_size = max_size self.antialias = antialias