From b9aaf4330787e715ced25338496563b96ea82a20 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 19 Jan 2023 16:44:17 +0000 Subject: [PATCH 01/15] Let Normalize() and RandomPhotometricDistort return datapoints instead of tensors --- test/test_prototype_transforms_functional.py | 4 +-- torchvision/prototype/datapoints/_image.py | 4 +++ torchvision/prototype/datapoints/_video.py | 4 +++ torchvision/prototype/transforms/_color.py | 10 +++--- .../prototype/transforms/functional/_misc.py | 31 ++++++++++++------- 5 files changed, 34 insertions(+), 19 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index d199625df0f..b9f68481d54 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -1185,15 +1185,15 @@ def test_correctness_gaussian_blur_image_tensor(device, spatial_size, dt, ksize, torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}") +# TODO: I guess we need to change the name of this test. Should we have a +# _correctness test as well like the rest? def test_normalize_output_type(): inpt = torch.rand(1, 3, 32, 32) output = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0]) - assert type(output) is torch.Tensor torch.testing.assert_close(inpt - 0.5, output) inpt = make_image(color_space=datapoints.ColorSpace.RGB) output = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0]) - assert type(output) is torch.Tensor torch.testing.assert_close(inpt - 0.5, output) diff --git a/torchvision/prototype/datapoints/_image.py b/torchvision/prototype/datapoints/_image.py index fc20691100f..a6a8c928334 100644 --- a/torchvision/prototype/datapoints/_image.py +++ b/torchvision/prototype/datapoints/_image.py @@ -289,6 +289,10 @@ def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = N ) return Image.wrap_like(self, output) + def normalize(self, mean: List[float], std: List[float], inplace: bool = False): + output = self._F.normalize_image_tensor(self.as_subclass(torch.Tensor), mean=mean, std=std, inplace=inplace) + return Image.wrap_like(self, output) + ImageType = Union[torch.Tensor, PIL.Image.Image, Image] ImageTypeJIT = torch.Tensor diff --git a/torchvision/prototype/datapoints/_video.py b/torchvision/prototype/datapoints/_video.py index 5c55d23a149..af58a5d2c1f 100644 --- a/torchvision/prototype/datapoints/_video.py +++ b/torchvision/prototype/datapoints/_video.py @@ -241,6 +241,10 @@ def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = N output = self._F.gaussian_blur_video(self.as_subclass(torch.Tensor), kernel_size=kernel_size, sigma=sigma) return Video.wrap_like(self, output) + def normalize(self, mean: List[float], std: List[float], inplace: bool = False): + output = self._F.normalize_video(self.as_subclass(torch.Tensor), mean=mean, std=std, inplace=inplace) + return Video.wrap_like(self, output) + VideoType = Union[torch.Tensor, Video] VideoTypeJIT = torch.Tensor diff --git a/torchvision/prototype/transforms/_color.py b/torchvision/prototype/transforms/_color.py index 0254dd7c225..4109ccfccf7 100644 --- a/torchvision/prototype/transforms/_color.py +++ b/torchvision/prototype/transforms/_color.py @@ -82,6 +82,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return output +# TODO: Are there tests for this class? 
class RandomPhotometricDistort(Transform): _transformed_types = ( datapoints.Image, @@ -119,15 +120,14 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _permute_channels( self, inpt: Union[datapoints.ImageType, datapoints.VideoType], permutation: torch.Tensor ) -> Union[datapoints.ImageType, datapoints.VideoType]: - if isinstance(inpt, PIL.Image.Image): + + orig_inpt = inpt + if isinstance(orig_inpt, PIL.Image.Image): inpt = F.pil_to_tensor(inpt) output = inpt[..., permutation, :, :] - if isinstance(inpt, (datapoints.Image, datapoints.Video)): - output = inpt.wrap_like(inpt, output, color_space=datapoints.ColorSpace.OTHER) # type: ignore[arg-type] - - elif isinstance(inpt, PIL.Image.Image): + if isinstance(orig_inpt, PIL.Image.Image): output = F.to_image_pil(output) return output diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index 59570768160..0a0ede90940 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -58,20 +58,27 @@ def normalize( std: List[float], inplace: bool = False, ) -> torch.Tensor: + # if torch.jit.is_scripting() or is_simple_tensor(inpt): + # return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma) + # elif isinstance(inpt, datapoints._datapoint.Datapoint): + # return inpt.gaussian_blur(kernel_size=kernel_size, sigma=sigma) + # elif isinstance(inpt, PIL.Image.Image): + # return gaussian_blur_image_pil(inpt, kernel_size=kernel_size, sigma=sigma) + # else: + # raise TypeError( + # f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " + # f"but got {type(inpt)} instead." + # ) if not torch.jit.is_scripting(): _log_api_usage_once(normalize) - - if isinstance(inpt, (datapoints.Image, datapoints.Video)): - inpt = inpt.as_subclass(torch.Tensor) - elif not is_simple_tensor(inpt): - raise TypeError( - f"Input can either be a plain tensor or an `Image` or `Video` datapoint, " - f"but got {type(inpt)} instead." - ) - - # Image or Video type should not be retained after normalization due to unknown data range - # Thus we return Tensor for input Image - return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace) + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace) + elif isinstance(inpt, (datapoints.Image, datapoints.Video)): + return inpt.normalize(mean=mean, std=std, inplace=inplace) + else: + raise TypeError( + f"Input can either be a plain tensor or an `Image` or `Video` datapoint, " f"but got {type(inpt)} instead." 
+ ) def _get_gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> torch.Tensor: From 5a3c51ac119fa547ac03233f326f130eaade45ef Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 19 Jan 2023 16:45:46 +0000 Subject: [PATCH 02/15] cleanup --- torchvision/prototype/transforms/functional/_misc.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index 0a0ede90940..9d0a00f88c3 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -58,17 +58,6 @@ def normalize( std: List[float], inplace: bool = False, ) -> torch.Tensor: - # if torch.jit.is_scripting() or is_simple_tensor(inpt): - # return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma) - # elif isinstance(inpt, datapoints._datapoint.Datapoint): - # return inpt.gaussian_blur(kernel_size=kernel_size, sigma=sigma) - # elif isinstance(inpt, PIL.Image.Image): - # return gaussian_blur_image_pil(inpt, kernel_size=kernel_size, sigma=sigma) - # else: - # raise TypeError( - # f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - # f"but got {type(inpt)} instead." - # ) if not torch.jit.is_scripting(): _log_api_usage_once(normalize) if torch.jit.is_scripting() or is_simple_tensor(inpt): From cb4c0f4996395cc0b75c8f84e55860b9fc2635da Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 20 Jan 2023 08:04:12 +0000 Subject: [PATCH 03/15] Address comments --- test/prototype_transforms_dispatcher_infos.py | 1 - test/test_prototype_transforms_functional.py | 12 ------------ torchvision/prototype/transforms/_color.py | 2 +- 3 files changed, 1 insertion(+), 14 deletions(-) diff --git a/test/prototype_transforms_dispatcher_infos.py b/test/prototype_transforms_dispatcher_infos.py index b92278fef56..90e2e7f570f 100644 --- a/test/prototype_transforms_dispatcher_infos.py +++ b/test/prototype_transforms_dispatcher_infos.py @@ -426,7 +426,6 @@ def fill_sequence_needs_broadcast(args_kwargs): datapoints.Video: F.normalize_video, }, test_marks=[ - skip_dispatch_feature, xfail_jit_python_scalar_arg("mean"), xfail_jit_python_scalar_arg("std"), ], diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index b9f68481d54..0b3570250a6 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -1185,18 +1185,6 @@ def test_correctness_gaussian_blur_image_tensor(device, spatial_size, dt, ksize, torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}") -# TODO: I guess we need to change the name of this test. Should we have a -# _correctness test as well like the rest? 
-def test_normalize_output_type(): - inpt = torch.rand(1, 3, 32, 32) - output = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0]) - torch.testing.assert_close(inpt - 0.5, output) - - inpt = make_image(color_space=datapoints.ColorSpace.RGB) - output = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0]) - torch.testing.assert_close(inpt - 0.5, output) - - @pytest.mark.parametrize( "inpt", [ diff --git a/torchvision/prototype/transforms/_color.py b/torchvision/prototype/transforms/_color.py index 4109ccfccf7..0eb20e57764 100644 --- a/torchvision/prototype/transforms/_color.py +++ b/torchvision/prototype/transforms/_color.py @@ -82,7 +82,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return output -# TODO: Are there tests for this class? +# TODO: This class seems to be untested class RandomPhotometricDistort(Transform): _transformed_types = ( datapoints.Image, From ac94d7259a564d5cd52777811e1677ec18bf7a32 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 20 Jan 2023 09:00:20 +0000 Subject: [PATCH 04/15] lint + types --- test/test_prototype_transforms_functional.py | 2 +- torchvision/prototype/datapoints/_image.py | 2 +- torchvision/prototype/datapoints/_video.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 0b3570250a6..a80e0f4570d 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -13,7 +13,7 @@ import torchvision.prototype.transforms.utils from common_utils import cache, cpu_and_gpu, needs_cuda, set_rng_seed -from prototype_common_utils import assert_close, make_bounding_boxes, make_image, parametrized_error_message +from prototype_common_utils import assert_close, make_bounding_boxes, parametrized_error_message from prototype_transforms_dispatcher_infos import DISPATCHER_INFOS from prototype_transforms_kernel_infos import KERNEL_INFOS from torch.utils._pytree import tree_map diff --git a/torchvision/prototype/datapoints/_image.py b/torchvision/prototype/datapoints/_image.py index a6a8c928334..d674745a716 100644 --- a/torchvision/prototype/datapoints/_image.py +++ b/torchvision/prototype/datapoints/_image.py @@ -289,7 +289,7 @@ def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = N ) return Image.wrap_like(self, output) - def normalize(self, mean: List[float], std: List[float], inplace: bool = False): + def normalize(self, mean: List[float], std: List[float], inplace: bool = False) -> Image: output = self._F.normalize_image_tensor(self.as_subclass(torch.Tensor), mean=mean, std=std, inplace=inplace) return Image.wrap_like(self, output) diff --git a/torchvision/prototype/datapoints/_video.py b/torchvision/prototype/datapoints/_video.py index af58a5d2c1f..c7273874655 100644 --- a/torchvision/prototype/datapoints/_video.py +++ b/torchvision/prototype/datapoints/_video.py @@ -241,7 +241,7 @@ def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = N output = self._F.gaussian_blur_video(self.as_subclass(torch.Tensor), kernel_size=kernel_size, sigma=sigma) return Video.wrap_like(self, output) - def normalize(self, mean: List[float], std: List[float], inplace: bool = False): + def normalize(self, mean: List[float], std: List[float], inplace: bool = False) -> Video: output = self._F.normalize_video(self.as_subclass(torch.Tensor), mean=mean, std=std, inplace=inplace) return Video.wrap_like(self, output) From 
5a407af977a045bfcae2c52e9dbdaa4163226ebc Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 20 Jan 2023 14:03:47 +0000 Subject: [PATCH 05/15] Remove color_space meta-data from Image and Video --- test/prototype_common_utils.py | 8 ++--- torchvision/prototype/datapoints/_image.py | 25 +++---------- torchvision/prototype/datapoints/_video.py | 28 ++++----------- .../prototype/transforms/_deprecated.py | 35 +++---------------- torchvision/transforms/functional.py | 3 ++ 5 files changed, 22 insertions(+), 77 deletions(-) diff --git a/test/prototype_common_utils.py b/test/prototype_common_utils.py index 18664eb0945..78c016ef9f8 100644 --- a/test/prototype_common_utils.py +++ b/test/prototype_common_utils.py @@ -278,7 +278,7 @@ def fn(shape, dtype, device): data = torch.testing.make_tensor(shape, low=0, high=max_value, dtype=dtype, device=device) if color_space in {datapoints.ColorSpace.GRAY_ALPHA, datapoints.ColorSpace.RGB_ALPHA} and constant_alpha: data[..., -1, :, :] = max_value - return datapoints.Image(data, color_space=color_space) + return datapoints.Image(data) return ImageLoader(fn, shape=(*extra_dims, num_channels, *size), dtype=dtype, color_space=color_space) @@ -328,7 +328,7 @@ def fn(shape, dtype, device): image_tensor = convert_dtype_image_tensor(to_image_tensor(image_pil).to(device=device), dtype=dtype) - return datapoints.Image(image_tensor, color_space=color_space) + return datapoints.Image(image_tensor) return ImageLoader(fn, shape=(num_channels, *size), dtype=dtype, color_space=color_space) @@ -592,8 +592,8 @@ def make_video_loader( num_frames = int(torch.randint(1, 5, ())) if num_frames == "random" else num_frames def fn(shape, dtype, device): - video = make_image(size=shape[-2:], color_space=color_space, extra_dims=shape[:-3], dtype=dtype, device=device) - return datapoints.Video(video, color_space=color_space) + video = make_image(size=shape[-2:], extra_dims=shape[:-3], dtype=dtype, device=device) + return datapoints.Video(video) return VideoLoader( fn, shape=(*extra_dims, num_frames, get_num_channels(color_space), *size), dtype=dtype, color_space=color_space diff --git a/torchvision/prototype/datapoints/_image.py b/torchvision/prototype/datapoints/_image.py index d674745a716..ba8f30fbd5a 100644 --- a/torchvision/prototype/datapoints/_image.py +++ b/torchvision/prototype/datapoints/_image.py @@ -58,19 +58,16 @@ def _from_tensor_shape(shape: List[int]) -> ColorSpace: class Image(Datapoint): - color_space: ColorSpace @classmethod - def _wrap(cls, tensor: torch.Tensor, *, color_space: ColorSpace) -> Image: + def _wrap(cls, tensor: torch.Tensor) -> Image: image = tensor.as_subclass(cls) - image.color_space = color_space return image def __new__( cls, data: Any, *, - color_space: Optional[Union[ColorSpace, str]] = None, dtype: Optional[torch.dtype] = None, device: Optional[Union[torch.device, str, int]] = None, requires_grad: bool = False, @@ -81,26 +78,14 @@ def __new__( elif tensor.ndim == 2: tensor = tensor.unsqueeze(0) - if color_space is None: - color_space = ColorSpace.from_tensor_shape(tensor.shape) # type: ignore[arg-type] - if color_space == ColorSpace.OTHER: - warnings.warn("Unable to guess a specific color space. 
Consider passing it explicitly.") - elif isinstance(color_space, str): - color_space = ColorSpace.from_str(color_space.upper()) - elif not isinstance(color_space, ColorSpace): - raise ValueError - - return cls._wrap(tensor, color_space=color_space) + return cls._wrap(tensor) @classmethod - def wrap_like(cls, other: Image, tensor: torch.Tensor, *, color_space: Optional[ColorSpace] = None) -> Image: - return cls._wrap( - tensor, - color_space=color_space if color_space is not None else other.color_space, - ) + def wrap_like(cls, other: Image, tensor: torch.Tensor) -> Image: + return cls._wrap(tensor) def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] - return self._make_repr(color_space=self.color_space) + return self._make_repr() @property def spatial_size(self) -> Tuple[int, int]: diff --git a/torchvision/prototype/datapoints/_video.py b/torchvision/prototype/datapoints/_video.py index c7273874655..fff56d79f0d 100644 --- a/torchvision/prototype/datapoints/_video.py +++ b/torchvision/prototype/datapoints/_video.py @@ -7,23 +7,19 @@ from torchvision.transforms.functional import InterpolationMode from ._datapoint import Datapoint, FillTypeJIT -from ._image import ColorSpace class Video(Datapoint): - color_space: ColorSpace @classmethod - def _wrap(cls, tensor: torch.Tensor, *, color_space: ColorSpace) -> Video: + def _wrap(cls, tensor: torch.Tensor) -> Video: video = tensor.as_subclass(cls) - video.color_space = color_space return video def __new__( cls, data: Any, *, - color_space: Optional[Union[ColorSpace, str]] = None, dtype: Optional[torch.dtype] = None, device: Optional[Union[torch.device, str, int]] = None, requires_grad: bool = False, @@ -32,27 +28,15 @@ def __new__( if data.ndim < 4: raise ValueError video = super().__new__(cls, data, requires_grad=requires_grad) - - if color_space is None: - color_space = ColorSpace.from_tensor_shape(video.shape) # type: ignore[arg-type] - if color_space == ColorSpace.OTHER: - warnings.warn("Unable to guess a specific color space. Consider passing it explicitly.") - elif isinstance(color_space, str): - color_space = ColorSpace.from_str(color_space.upper()) - elif not isinstance(color_space, ColorSpace): - raise ValueError - - return cls._wrap(tensor, color_space=color_space) + # TODO: Should this be `video` or can we remove it? + return cls._wrap(tensor) @classmethod - def wrap_like(cls, other: Video, tensor: torch.Tensor, *, color_space: Optional[ColorSpace] = None) -> Video: - return cls._wrap( - tensor, - color_space=color_space if color_space is not None else other.color_space, - ) + def wrap_like(cls, other: Video, tensor: torch.Tensor) -> Video: + return cls._wrap(tensor) def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] - return self._make_repr(color_space=self.color_space) + return self._make_repr() @property def spatial_size(self) -> Tuple[int, int]: diff --git a/torchvision/prototype/transforms/_deprecated.py b/torchvision/prototype/transforms/_deprecated.py index 3247a8051a3..aa65a79dcad 100644 --- a/torchvision/prototype/transforms/_deprecated.py +++ b/torchvision/prototype/transforms/_deprecated.py @@ -37,23 +37,6 @@ class Grayscale(Transform): ) def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None: - deprecation_msg = ( - f"The transform `Grayscale(num_output_channels={num_output_channels})` " - f"is deprecated and will be removed in a future release." 
- ) - if num_output_channels == 1: - replacement_msg = ( - "transforms.ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY)" - ) - else: - replacement_msg = ( - "transforms.Compose(\n" - " transforms.ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY),\n" - " transforms.ConvertImageColorSpace(old_color_space=ColorSpace.GRAY, color_space=ColorSpace.RGB),\n" - ")" - ) - warnings.warn(f"{deprecation_msg} Instead, please use\n\n{replacement_msg}") - super().__init__() self.num_output_channels = num_output_channels @@ -62,7 +45,8 @@ def _transform( ) -> Union[datapoints.ImageType, datapoints.VideoType]: output = _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels) if isinstance(inpt, (datapoints.Image, datapoints.Video)): - output = inpt.wrap_like(inpt, output, color_space=datapoints.ColorSpace.GRAY) # type: ignore[arg-type] + # TODO: Q: is the wrapping still needed? Is the type ignore still needed? + output = inpt.wrap_like(inpt, output) # type: ignore[arg-type] return output @@ -75,18 +59,6 @@ class RandomGrayscale(_RandomApplyTransform): ) def __init__(self, p: float = 0.1) -> None: - warnings.warn( - "The transform `RandomGrayscale(p=...)` is deprecated and will be removed in a future release. " - "Instead, please use\n\n" - "transforms.RandomApply(\n" - " transforms.Compose(\n" - " transforms.ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY),\n" - " transforms.ConvertImageColorSpace(old_color_space=ColorSpace.GRAY, color_space=ColorSpace.RGB),\n" - " )\n" - " p=...,\n" - ")" - ) - super().__init__(p=p) def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: @@ -98,5 +70,6 @@ def _transform( ) -> Union[datapoints.ImageType, datapoints.VideoType]: output = _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"]) if isinstance(inpt, (datapoints.Image, datapoints.Video)): - output = inpt.wrap_like(inpt, output, color_space=datapoints.ColorSpace.GRAY) # type: ignore[arg-type] + # TODO: Same as the other TODO above + output = inpt.wrap_like(inpt, output) # type: ignore[arg-type] return output diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 5d662a2c1d1..69965126d8c 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1234,6 +1234,9 @@ def affine( return F_t.affine(img, matrix=matrix, interpolation=interpolation.value, fill=fill) +# Looks like to_grayscale() is a stand-alone functional that is never called +# from the transform classes. Perhaps it's still here for BC? I can't be +# bothered to dig. Anyway, this can be deprecated as we migrate to V2. @torch.jit.unused def to_grayscale(img, num_output_channels=1): """Convert PIL image of any mode (RGB, HSV, LAB, etc) to grayscale version of image. 
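(Rough usage sketch of the behaviour the patches above aim for — not part of the patch series itself. It assumes the prototype torchvision.prototype.datapoints and torchvision.prototype.transforms.functional APIs exactly as they appear in the hunks above; these are prototype modules and subject to change. The point illustrated: F.normalize() now dispatches to the datapoint's own normalize() method and re-wraps the result, so Image/Video inputs keep their datapoint type, and Image no longer carries or accepts color_space metadata.)

import torch
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F

# Plain tensors still go through normalize_image_tensor() and come back as plain tensors.
plain = torch.rand(1, 3, 32, 32)
out_plain = F.normalize(plain, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0])
assert type(out_plain) is torch.Tensor

# Image datapoints are re-wrapped via Image.wrap_like() instead of being downgraded
# to plain tensors. Note: no color_space argument anymore (removed in PATCH 05/15).
image = datapoints.Image(torch.rand(3, 32, 32))
out_image = F.normalize(image, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0])
assert isinstance(out_image, datapoints.Image)

The same dispatch pattern applies to Video inputs via normalize_video() and Video.wrap_like().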
From 07cc5e9f9ba6867d9d58c2980e5d33041e3efe1c Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 20 Jan 2023 15:26:09 +0000 Subject: [PATCH 06/15] Remove functional part --- test/prototype_transforms_kernel_infos.py | 63 --------- test/test_prototype_transforms.py | 39 ------ test/test_prototype_transforms_functional.py | 1 - torchvision/prototype/transforms/__init__.py | 2 +- torchvision/prototype/transforms/_meta.py | 29 ---- .../transforms/functional/__init__.py | 4 - .../prototype/transforms/functional/_meta.py | 124 ------------------ 7 files changed, 1 insertion(+), 261 deletions(-) diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index ded888a4a00..5f471e6e4bd 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -684,69 +684,6 @@ def reference_inputs_convert_format_bounding_box(): ) -def sample_inputs_convert_color_space_image_tensor(): - color_spaces = sorted( - set(datapoints.ColorSpace) - {datapoints.ColorSpace.OTHER}, key=lambda color_space: color_space.value - ) - - for old_color_space, new_color_space in cycle_over(color_spaces): - for image_loader in make_image_loaders(sizes=["random"], color_spaces=[old_color_space], constant_alpha=True): - yield ArgsKwargs(image_loader, old_color_space=old_color_space, new_color_space=new_color_space) - - for color_space in color_spaces: - for image_loader in make_image_loaders( - sizes=["random"], color_spaces=[color_space], dtypes=[torch.float32], constant_alpha=True - ): - yield ArgsKwargs(image_loader, old_color_space=color_space, new_color_space=color_space) - - -@pil_reference_wrapper -def reference_convert_color_space_image_tensor(image_pil, old_color_space, new_color_space): - color_space_pil = datapoints.ColorSpace.from_pil_mode(image_pil.mode) - if color_space_pil != old_color_space: - raise pytest.UsageError( - f"Converting the tensor image into an PIL image changed the colorspace " - f"from {old_color_space} to {color_space_pil}" - ) - - return F.convert_color_space_image_pil(image_pil, color_space=new_color_space) - - -def reference_inputs_convert_color_space_image_tensor(): - for args_kwargs in sample_inputs_convert_color_space_image_tensor(): - (image_loader, *other_args), kwargs = args_kwargs - if len(image_loader.shape) == 3 and image_loader.dtype == torch.uint8: - yield args_kwargs - - -def sample_inputs_convert_color_space_video(): - color_spaces = [datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB] - - for old_color_space, new_color_space in cycle_over(color_spaces): - for video_loader in make_video_loaders(sizes=["random"], color_spaces=[old_color_space], num_frames=["random"]): - yield ArgsKwargs(video_loader, old_color_space=old_color_space, new_color_space=new_color_space) - - -KERNEL_INFOS.extend( - [ - KernelInfo( - F.convert_color_space_image_tensor, - sample_inputs_fn=sample_inputs_convert_color_space_image_tensor, - reference_fn=reference_convert_color_space_image_tensor, - reference_inputs_fn=reference_inputs_convert_color_space_image_tensor, - closeness_kwargs={ - **pil_reference_pixel_difference(), - **float32_vs_uint8_pixel_difference(), - }, - ), - KernelInfo( - F.convert_color_space_video, - sample_inputs_fn=sample_inputs_convert_color_space_video, - ), - ] -) - - def sample_inputs_vertical_flip_image_tensor(): for image_loader in make_image_loaders(sizes=["random"], dtypes=[torch.float32]): yield ArgsKwargs(image_loader) diff --git a/test/test_prototype_transforms.py 
b/test/test_prototype_transforms.py index 3826293f3ed..097b3edf010 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -221,45 +221,6 @@ def test_normalize(self, transform, input): def test_random_resized_crop(self, transform, input): transform(input) - @parametrize( - [ - ( - transforms.ConvertColorSpace(color_space=new_color_space, old_color_space=old_color_space), - itertools.chain.from_iterable( - [ - fn(color_spaces=[old_color_space]) - for fn in ( - make_images, - make_vanilla_tensor_images, - make_pil_images, - make_videos, - ) - ] - ), - ) - for old_color_space, new_color_space in itertools.product( - [ - datapoints.ColorSpace.GRAY, - datapoints.ColorSpace.GRAY_ALPHA, - datapoints.ColorSpace.RGB, - datapoints.ColorSpace.RGB_ALPHA, - ], - repeat=2, - ) - ] - ) - def test_convert_color_space(self, transform, input): - transform(input) - - def test_convert_color_space_unsupported_types(self): - transform = transforms.ConvertColorSpace( - color_space=datapoints.ColorSpace.RGB, old_color_space=datapoints.ColorSpace.GRAY - ) - - for inpt in [make_bounding_box(format="XYXY"), make_masks()]: - output = transform(inpt) - assert output is inpt - @pytest.mark.parametrize("p", [0.0, 1.0]) class TestRandomHorizontalFlip: diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 102f78e6e11..a098f762687 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -335,7 +335,6 @@ def test_scripted_smoke(self, info, args_kwargs, device): "dispatcher", [ F.clamp_bounding_box, - F.convert_color_space, F.get_dimensions, F.get_image_num_channels, F.get_image_size, diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index 04b007190b8..fa75cf63339 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -39,7 +39,7 @@ ScaleJitter, TenCrop, ) -from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat, ConvertColorSpace, ConvertDtype, ConvertImageDtype +from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat, ConvertDtype, ConvertImageDtype from ._misc import ( GaussianBlur, Identity, diff --git a/torchvision/prototype/transforms/_meta.py b/torchvision/prototype/transforms/_meta.py index 6ad9e041098..3c95c9b8e66 100644 --- a/torchvision/prototype/transforms/_meta.py +++ b/torchvision/prototype/transforms/_meta.py @@ -46,35 +46,6 @@ def _transform( ConvertImageDtype = ConvertDtype -class ConvertColorSpace(Transform): - _transformed_types = ( - is_simple_tensor, - datapoints.Image, - PIL.Image.Image, - datapoints.Video, - ) - - def __init__( - self, - color_space: Union[str, datapoints.ColorSpace], - old_color_space: Optional[Union[str, datapoints.ColorSpace]] = None, - ) -> None: - super().__init__() - - if isinstance(color_space, str): - color_space = datapoints.ColorSpace.from_str(color_space) - self.color_space = color_space - - if isinstance(old_color_space, str): - old_color_space = datapoints.ColorSpace.from_str(old_color_space) - self.old_color_space = old_color_space - - def _transform( - self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any] - ) -> Union[datapoints.ImageType, datapoints.VideoType]: - return F.convert_color_space(inpt, color_space=self.color_space, old_color_space=self.old_color_space) - - class ClampBoundingBoxes(Transform): _transformed_types = (datapoints.BoundingBox,) diff --git 
a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index 30ef6e3fc99..57b4cc4423a 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -7,10 +7,6 @@ from ._meta import ( clamp_bounding_box, convert_format_bounding_box, - convert_color_space_image_tensor, - convert_color_space_image_pil, - convert_color_space_video, - convert_color_space, convert_dtype_image_tensor, convert_dtype, convert_dtype_video, diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py index 62f9664fc47..547159f30c0 100644 --- a/torchvision/prototype/transforms/functional/_meta.py +++ b/torchvision/prototype/transforms/functional/_meta.py @@ -225,29 +225,6 @@ def clamp_bounding_box( return convert_format_bounding_box(xyxy_boxes, old_format=BoundingBoxFormat.XYXY, new_format=format, inplace=True) -def _strip_alpha(image: torch.Tensor) -> torch.Tensor: - image, alpha = torch.tensor_split(image, indices=(-1,), dim=-3) - if not torch.all(alpha == _max_value(alpha.dtype)): - raise RuntimeError( - "Stripping the alpha channel if it contains values other than the max value is not supported." - ) - return image - - -def _add_alpha(image: torch.Tensor, alpha: Optional[torch.Tensor] = None) -> torch.Tensor: - if alpha is None: - shape = list(image.shape) - shape[-3] = 1 - alpha = torch.full(shape, _max_value(image.dtype), dtype=image.dtype, device=image.device) - return torch.cat((image, alpha), dim=-3) - - -def _gray_to_rgb(grayscale: torch.Tensor) -> torch.Tensor: - repeats = [1] * grayscale.ndim - repeats[-3] = 3 - return grayscale.repeat(repeats) - - def _rgb_to_gray(image: torch.Tensor, cast: bool = True) -> torch.Tensor: r, g, b = image.unbind(dim=-3) l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114) @@ -257,107 +234,6 @@ def _rgb_to_gray(image: torch.Tensor, cast: bool = True) -> torch.Tensor: return l_img -def convert_color_space_image_tensor( - image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace -) -> torch.Tensor: - if new_color_space == old_color_space: - return image - - if old_color_space == ColorSpace.OTHER or new_color_space == ColorSpace.OTHER: - raise RuntimeError(f"Conversion to or from {ColorSpace.OTHER} is not supported.") - - if old_color_space == ColorSpace.GRAY and new_color_space == ColorSpace.GRAY_ALPHA: - return _add_alpha(image) - elif old_color_space == ColorSpace.GRAY and new_color_space == ColorSpace.RGB: - return _gray_to_rgb(image) - elif old_color_space == ColorSpace.GRAY and new_color_space == ColorSpace.RGB_ALPHA: - return _add_alpha(_gray_to_rgb(image)) - elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.GRAY: - return _strip_alpha(image) - elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.RGB: - return _gray_to_rgb(_strip_alpha(image)) - elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.RGB_ALPHA: - image, alpha = torch.tensor_split(image, indices=(-1,), dim=-3) - return _add_alpha(_gray_to_rgb(image), alpha) - elif old_color_space == ColorSpace.RGB and new_color_space == ColorSpace.GRAY: - return _rgb_to_gray(image) - elif old_color_space == ColorSpace.RGB and new_color_space == ColorSpace.GRAY_ALPHA: - return _add_alpha(_rgb_to_gray(image)) - elif old_color_space == ColorSpace.RGB and new_color_space == ColorSpace.RGB_ALPHA: - return _add_alpha(image) - 
elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.GRAY: - return _rgb_to_gray(_strip_alpha(image)) - elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.GRAY_ALPHA: - image, alpha = torch.tensor_split(image, indices=(-1,), dim=-3) - return _add_alpha(_rgb_to_gray(image), alpha) - elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.RGB: - return _strip_alpha(image) - else: - raise RuntimeError(f"Conversion from {old_color_space} to {new_color_space} is not supported.") - - -_COLOR_SPACE_TO_PIL_MODE = { - ColorSpace.GRAY: "L", - ColorSpace.GRAY_ALPHA: "LA", - ColorSpace.RGB: "RGB", - ColorSpace.RGB_ALPHA: "RGBA", -} - - -@torch.jit.unused -def convert_color_space_image_pil(image: PIL.Image.Image, color_space: ColorSpace) -> PIL.Image.Image: - old_mode = image.mode - try: - new_mode = _COLOR_SPACE_TO_PIL_MODE[color_space] - except KeyError: - raise ValueError(f"Conversion from {ColorSpace.from_pil_mode(old_mode)} to {color_space} is not supported.") - - if image.mode == new_mode: - return image - - return image.convert(new_mode) - - -def convert_color_space_video( - video: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace -) -> torch.Tensor: - return convert_color_space_image_tensor(video, old_color_space=old_color_space, new_color_space=new_color_space) - - -def convert_color_space( - inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], - color_space: ColorSpace, - old_color_space: Optional[ColorSpace] = None, -) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]: - if not torch.jit.is_scripting(): - _log_api_usage_once(convert_color_space) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): - if old_color_space is None: - raise RuntimeError( - "In order to convert the color space of simple tensors, " - "the `old_color_space=...` parameter needs to be passed." - ) - return convert_color_space_image_tensor(inpt, old_color_space=old_color_space, new_color_space=color_space) - elif isinstance(inpt, datapoints.Image): - output = convert_color_space_image_tensor( - inpt.as_subclass(torch.Tensor), old_color_space=inpt.color_space, new_color_space=color_space - ) - return datapoints.Image.wrap_like(inpt, output, color_space=color_space) - elif isinstance(inpt, datapoints.Video): - output = convert_color_space_video( - inpt.as_subclass(torch.Tensor), old_color_space=inpt.color_space, new_color_space=color_space - ) - return datapoints.Video.wrap_like(inpt, output, color_space=color_space) - elif isinstance(inpt, PIL.Image.Image): - return convert_color_space_image_pil(inpt, color_space=color_space) - else: - raise TypeError( - f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, " - f"but got {type(inpt)} instead." 
- ) - - def _num_value_bits(dtype: torch.dtype) -> int: if dtype == torch.uint8: return 8 From e83ad89472c1f78e3a793794c1269a08d3b34e83 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 20 Jan 2023 15:27:19 +0000 Subject: [PATCH 07/15] Undo changes to Grayscale, will address in another PR --- .../prototype/transforms/_deprecated.py | 36 ++++++++++++++++--- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/torchvision/prototype/transforms/_deprecated.py b/torchvision/prototype/transforms/_deprecated.py index aa65a79dcad..ff0c6306d51 100644 --- a/torchvision/prototype/transforms/_deprecated.py +++ b/torchvision/prototype/transforms/_deprecated.py @@ -28,6 +28,7 @@ def _transform(self, inpt: Union[PIL.Image.Image, np.ndarray], params: Dict[str, return _F.to_tensor(inpt) +# TODO: in other PR (?) undeprecate those and make them use _rgb_to_gray? class Grayscale(Transform): _transformed_types = ( datapoints.Image, @@ -37,6 +38,23 @@ class Grayscale(Transform): ) def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None: + deprecation_msg = ( + f"The transform `Grayscale(num_output_channels={num_output_channels})` " + f"is deprecated and will be removed in a future release." + ) + if num_output_channels == 1: + replacement_msg = ( + "transforms.ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY)" + ) + else: + replacement_msg = ( + "transforms.Compose(\n" + " transforms.ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY),\n" + " transforms.ConvertImageColorSpace(old_color_space=ColorSpace.GRAY, color_space=ColorSpace.RGB),\n" + ")" + ) + warnings.warn(f"{deprecation_msg} Instead, please use\n\n{replacement_msg}") + super().__init__() self.num_output_channels = num_output_channels @@ -45,8 +63,7 @@ def _transform( ) -> Union[datapoints.ImageType, datapoints.VideoType]: output = _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels) if isinstance(inpt, (datapoints.Image, datapoints.Video)): - # TODO: Q: is the wrapping still needed? Is the type ignore still needed? - output = inpt.wrap_like(inpt, output) # type: ignore[arg-type] + output = inpt.wrap_like(inpt, output, color_space=datapoints.ColorSpace.GRAY) # type: ignore[arg-type] return output @@ -59,6 +76,18 @@ class RandomGrayscale(_RandomApplyTransform): ) def __init__(self, p: float = 0.1) -> None: + warnings.warn( + "The transform `RandomGrayscale(p=...)` is deprecated and will be removed in a future release. 
" + "Instead, please use\n\n" + "transforms.RandomApply(\n" + " transforms.Compose(\n" + " transforms.ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY),\n" + " transforms.ConvertImageColorSpace(old_color_space=ColorSpace.GRAY, color_space=ColorSpace.RGB),\n" + " )\n" + " p=...,\n" + ")" + ) + super().__init__(p=p) def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: @@ -70,6 +99,5 @@ def _transform( ) -> Union[datapoints.ImageType, datapoints.VideoType]: output = _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"]) if isinstance(inpt, (datapoints.Image, datapoints.Video)): - # TODO: Same as the other TODO above - output = inpt.wrap_like(inpt, output) # type: ignore[arg-type] + output = inpt.wrap_like(inpt, output, color_space=datapoints.ColorSpace.GRAY) # type: ignore[arg-type] return output From f9151fa3f877dea032cdcf54255dca5aca2856d0 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 20 Jan 2023 16:03:38 +0000 Subject: [PATCH 08/15] Updated tests to not use datapoints.ColorSpace --- test/prototype_common_utils.py | 48 +++++++----- test/prototype_transforms_kernel_infos.py | 73 ++++++++++--------- test/test_prototype_transforms.py | 9 ++- test/test_prototype_transforms_consistency.py | 23 +++--- test/test_prototype_transforms_utils.py | 4 +- torchvision/prototype/datapoints/__init__.py | 2 +- torchvision/prototype/datapoints/_image.py | 46 ------------ .../prototype/transforms/_deprecated.py | 4 +- .../prototype/transforms/functional/_meta.py | 2 +- 9 files changed, 88 insertions(+), 123 deletions(-) diff --git a/test/prototype_common_utils.py b/test/prototype_common_utils.py index 78c016ef9f8..433b6775d84 100644 --- a/test/prototype_common_utils.py +++ b/test/prototype_common_utils.py @@ -206,6 +206,14 @@ def _parse_spatial_size(size, *, name="size"): DEFAULT_EXTRA_DIMS = (*VALID_EXTRA_DIMS, *DEGENERATE_BATCH_DIMS) +from enum import Enum +class ColorSpace(Enum): + OTHER = "OTHER" + GRAY = "GRAY" + GRAY_ALPHA = "GRAY_ALPHA" + RGB = "RGB" + RGB_ALPHA = "RGB_ALPHA" + def from_loader(loader_fn): def wrapper(*args, **kwargs): @@ -238,7 +246,7 @@ def load(self, device): @dataclasses.dataclass class ImageLoader(TensorLoader): - color_space: datapoints.ColorSpace + color_space: ColorSpace spatial_size: Tuple[int, int] = dataclasses.field(init=False) num_channels: int = dataclasses.field(init=False) @@ -248,10 +256,10 @@ def __post_init__(self): NUM_CHANNELS_MAP = { - datapoints.ColorSpace.GRAY: 1, - datapoints.ColorSpace.GRAY_ALPHA: 2, - datapoints.ColorSpace.RGB: 3, - datapoints.ColorSpace.RGB_ALPHA: 4, + ColorSpace.GRAY: 1, + ColorSpace.GRAY_ALPHA: 2, + ColorSpace.RGB: 3, + ColorSpace.RGB_ALPHA: 4, } @@ -265,7 +273,7 @@ def get_num_channels(color_space): def make_image_loader( size="random", *, - color_space=datapoints.ColorSpace.RGB, + color_space=ColorSpace.RGB, extra_dims=(), dtype=torch.float32, constant_alpha=True, @@ -276,7 +284,7 @@ def make_image_loader( def fn(shape, dtype, device): max_value = get_max_value(dtype) data = torch.testing.make_tensor(shape, low=0, high=max_value, dtype=dtype, device=device) - if color_space in {datapoints.ColorSpace.GRAY_ALPHA, datapoints.ColorSpace.RGB_ALPHA} and constant_alpha: + if color_space in {ColorSpace.GRAY_ALPHA, ColorSpace.RGB_ALPHA} and constant_alpha: data[..., -1, :, :] = max_value return datapoints.Image(data) @@ -290,10 +298,10 @@ def make_image_loaders( *, sizes=DEFAULT_SPATIAL_SIZES, color_spaces=( - datapoints.ColorSpace.GRAY, - datapoints.ColorSpace.GRAY_ALPHA, - 
datapoints.ColorSpace.RGB, - datapoints.ColorSpace.RGB_ALPHA, + ColorSpace.GRAY, + ColorSpace.GRAY_ALPHA, + ColorSpace.RGB, + ColorSpace.RGB_ALPHA, ), extra_dims=DEFAULT_EXTRA_DIMS, dtypes=(torch.float32, torch.uint8), @@ -306,7 +314,7 @@ def make_image_loaders( make_images = from_loaders(make_image_loaders) -def make_image_loader_for_interpolation(size="random", *, color_space=datapoints.ColorSpace.RGB, dtype=torch.uint8): +def make_image_loader_for_interpolation(size="random", *, color_space=ColorSpace.RGB, dtype=torch.uint8): size = _parse_spatial_size(size) num_channels = get_num_channels(color_space) @@ -318,10 +326,10 @@ def fn(shape, dtype, device): .resize((width, height)) .convert( { - datapoints.ColorSpace.GRAY: "L", - datapoints.ColorSpace.GRAY_ALPHA: "LA", - datapoints.ColorSpace.RGB: "RGB", - datapoints.ColorSpace.RGB_ALPHA: "RGBA", + ColorSpace.GRAY: "L", + ColorSpace.GRAY_ALPHA: "LA", + ColorSpace.RGB: "RGB", + ColorSpace.RGB_ALPHA: "RGBA", }[color_space] ) ) @@ -335,7 +343,7 @@ def fn(shape, dtype, device): def make_image_loaders_for_interpolation( sizes=((233, 147),), - color_spaces=(datapoints.ColorSpace.RGB,), + color_spaces=(ColorSpace.RGB,), dtypes=(torch.uint8,), ): for params in combinations_grid(size=sizes, color_space=color_spaces, dtype=dtypes): @@ -583,7 +591,7 @@ class VideoLoader(ImageLoader): def make_video_loader( size="random", *, - color_space=datapoints.ColorSpace.RGB, + color_space=ColorSpace.RGB, num_frames="random", extra_dims=(), dtype=torch.uint8, @@ -607,8 +615,8 @@ def make_video_loaders( *, sizes=DEFAULT_SPATIAL_SIZES, color_spaces=( - datapoints.ColorSpace.GRAY, - datapoints.ColorSpace.RGB, + ColorSpace.GRAY, + ColorSpace.RGB, ), num_frames=(1, 0, "random"), extra_dims=DEFAULT_EXTRA_DIMS, diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index 5f471e6e4bd..1e38e5c8d2b 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -13,6 +13,7 @@ from datasets_utils import combinations_grid from prototype_common_utils import ( ArgsKwargs, + ColorSpace, get_num_channels, ImageLoader, InfoBase, @@ -262,13 +263,13 @@ def _get_resize_sizes(spatial_size): def sample_inputs_resize_image_tensor(): for image_loader in make_image_loaders( - sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32] + sizes=["random"], color_spaces=[ColorSpace.RGB], dtypes=[torch.float32] ): for size in _get_resize_sizes(image_loader.spatial_size): yield ArgsKwargs(image_loader, size=size) for image_loader, interpolation in itertools.product( - make_image_loaders(sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB]), + make_image_loaders(sizes=["random"], color_spaces=[ColorSpace.RGB]), [ F.InterpolationMode.NEAREST, F.InterpolationMode.BILINEAR, @@ -472,7 +473,7 @@ def float32_vs_uint8_fill_adapter(other_args, kwargs): def sample_inputs_affine_image_tensor(): make_affine_image_loaders = functools.partial( - make_image_loaders, sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32] + make_image_loaders, sizes=["random"], color_spaces=[ColorSpace.RGB], dtypes=[torch.float32] ) for image_loader, affine_params in itertools.product(make_affine_image_loaders(), _DIVERSE_AFFINE_PARAMS): @@ -759,7 +760,7 @@ def reference_vertical_flip_bounding_box(bounding_box, *, format, spatial_size): def sample_inputs_rotate_image_tensor(): make_rotate_image_loaders = functools.partial( - make_image_loaders, sizes=["random"], 
color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32] + make_image_loaders, sizes=["random"], color_spaces=[ColorSpace.RGB], dtypes=[torch.float32] ) for image_loader in make_rotate_image_loaders(): @@ -841,7 +842,7 @@ def sample_inputs_rotate_video(): def sample_inputs_crop_image_tensor(): for image_loader, params in itertools.product( - make_image_loaders(sizes=[(16, 17)], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]), + make_image_loaders(sizes=[(16, 17)], color_spaces=[ColorSpace.RGB], dtypes=[torch.float32]), [ dict(top=4, left=3, height=7, width=8), dict(top=-1, left=3, height=7, width=8), @@ -1027,7 +1028,7 @@ def sample_inputs_resized_crop_video(): def sample_inputs_pad_image_tensor(): make_pad_image_loaders = functools.partial( - make_image_loaders, sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32] + make_image_loaders, sizes=["random"], color_spaces=[ColorSpace.RGB], dtypes=[torch.float32] ) for image_loader, padding in itertools.product( @@ -1343,7 +1344,7 @@ def sample_inputs_elastic_video(): def sample_inputs_center_crop_image_tensor(): for image_loader, output_size in itertools.product( - make_image_loaders(sizes=[(16, 17)], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]), + make_image_loaders(sizes=[(16, 17)], color_spaces=[ColorSpace.RGB], dtypes=[torch.float32]), [ # valid `output_size` types for which cropping is applied to both dimensions *[5, (4,), (2, 3), [6], [3, 2]], @@ -1430,7 +1431,7 @@ def sample_inputs_center_crop_video(): def sample_inputs_gaussian_blur_image_tensor(): make_gaussian_blur_image_loaders = functools.partial( - make_image_loaders, sizes=[(7, 33)], color_spaces=[datapoints.ColorSpace.RGB] + make_image_loaders, sizes=[(7, 33)], color_spaces=[ColorSpace.RGB] ) for image_loader, kernel_size in itertools.product(make_gaussian_blur_image_loaders(), [5, (3, 3), [3, 3]]): @@ -1469,7 +1470,7 @@ def sample_inputs_gaussian_blur_video(): def sample_inputs_equalize_image_tensor(): for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB) + sizes=["random"], color_spaces=(ColorSpace.GRAY, ColorSpace.RGB) ): yield ArgsKwargs(image_loader) @@ -1497,7 +1498,7 @@ def make_beta_distributed_image(shape, dtype, device, *, alpha, beta): spatial_size = (256, 256) for dtype, color_space, fn in itertools.product( [torch.uint8], - [datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB], + [ColorSpace.GRAY, ColorSpace.RGB], [ lambda shape, dtype, device: torch.zeros(shape, dtype=dtype, device=device), lambda shape, dtype, device: torch.full( @@ -1553,14 +1554,14 @@ def sample_inputs_equalize_video(): def sample_inputs_invert_image_tensor(): for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB) + sizes=["random"], color_spaces=(ColorSpace.GRAY, ColorSpace.RGB) ): yield ArgsKwargs(image_loader) def reference_inputs_invert_image_tensor(): for image_loader in make_image_loaders( - color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] + color_spaces=(ColorSpace.GRAY, ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] ): yield ArgsKwargs(image_loader) @@ -1593,7 +1594,7 @@ def sample_inputs_invert_video(): def sample_inputs_posterize_image_tensor(): for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB) + sizes=["random"], 
color_spaces=(ColorSpace.GRAY, ColorSpace.RGB) ): yield ArgsKwargs(image_loader, bits=_POSTERIZE_BITS[0]) @@ -1601,7 +1602,7 @@ def sample_inputs_posterize_image_tensor(): def reference_inputs_posterize_image_tensor(): for image_loader, bits in itertools.product( make_image_loaders( - color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] + color_spaces=(ColorSpace.GRAY, ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] ), _POSTERIZE_BITS, ): @@ -1640,14 +1641,14 @@ def _get_solarize_thresholds(dtype): def sample_inputs_solarize_image_tensor(): for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB) + sizes=["random"], color_spaces=(ColorSpace.GRAY, ColorSpace.RGB) ): yield ArgsKwargs(image_loader, threshold=next(_get_solarize_thresholds(image_loader.dtype))) def reference_inputs_solarize_image_tensor(): for image_loader in make_image_loaders( - color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] + color_spaces=(ColorSpace.GRAY, ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] ): for threshold in _get_solarize_thresholds(image_loader.dtype): yield ArgsKwargs(image_loader, threshold=threshold) @@ -1683,14 +1684,14 @@ def sample_inputs_solarize_video(): def sample_inputs_autocontrast_image_tensor(): for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB) + sizes=["random"], color_spaces=(ColorSpace.GRAY, ColorSpace.RGB) ): yield ArgsKwargs(image_loader) def reference_inputs_autocontrast_image_tensor(): for image_loader in make_image_loaders( - color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] + color_spaces=(ColorSpace.GRAY, ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] ): yield ArgsKwargs(image_loader) @@ -1727,7 +1728,7 @@ def sample_inputs_autocontrast_video(): def sample_inputs_adjust_sharpness_image_tensor(): for image_loader in make_image_loaders( sizes=["random", (2, 2)], - color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), + color_spaces=(ColorSpace.GRAY, ColorSpace.RGB), ): yield ArgsKwargs(image_loader, sharpness_factor=_ADJUST_SHARPNESS_FACTORS[0]) @@ -1735,7 +1736,7 @@ def sample_inputs_adjust_sharpness_image_tensor(): def reference_inputs_adjust_sharpness_image_tensor(): for image_loader, sharpness_factor in itertools.product( make_image_loaders( - color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] + color_spaces=(ColorSpace.GRAY, ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] ), _ADJUST_SHARPNESS_FACTORS, ): @@ -1801,7 +1802,7 @@ def sample_inputs_erase_video(): def sample_inputs_adjust_brightness_image_tensor(): for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB) + sizes=["random"], color_spaces=(ColorSpace.GRAY, ColorSpace.RGB) ): yield ArgsKwargs(image_loader, brightness_factor=_ADJUST_BRIGHTNESS_FACTORS[0]) @@ -1809,7 +1810,7 @@ def sample_inputs_adjust_brightness_image_tensor(): def reference_inputs_adjust_brightness_image_tensor(): for image_loader, brightness_factor in itertools.product( make_image_loaders( - color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] + color_spaces=(ColorSpace.GRAY, ColorSpace.RGB), extra_dims=[()], 
dtypes=[torch.uint8] ), _ADJUST_BRIGHTNESS_FACTORS, ): @@ -1845,7 +1846,7 @@ def sample_inputs_adjust_brightness_video(): def sample_inputs_adjust_contrast_image_tensor(): for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB) + sizes=["random"], color_spaces=(ColorSpace.GRAY, ColorSpace.RGB) ): yield ArgsKwargs(image_loader, contrast_factor=_ADJUST_CONTRAST_FACTORS[0]) @@ -1853,7 +1854,7 @@ def sample_inputs_adjust_contrast_image_tensor(): def reference_inputs_adjust_contrast_image_tensor(): for image_loader, contrast_factor in itertools.product( make_image_loaders( - color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] + color_spaces=(ColorSpace.GRAY, ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] ), _ADJUST_CONTRAST_FACTORS, ): @@ -1897,7 +1898,7 @@ def sample_inputs_adjust_contrast_video(): def sample_inputs_adjust_gamma_image_tensor(): gamma, gain = _ADJUST_GAMMA_GAMMAS_GAINS[0] for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB) + sizes=["random"], color_spaces=(ColorSpace.GRAY, ColorSpace.RGB) ): yield ArgsKwargs(image_loader, gamma=gamma, gain=gain) @@ -1905,7 +1906,7 @@ def sample_inputs_adjust_gamma_image_tensor(): def reference_inputs_adjust_gamma_image_tensor(): for image_loader, (gamma, gain) in itertools.product( make_image_loaders( - color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] + color_spaces=(ColorSpace.GRAY, ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] ), _ADJUST_GAMMA_GAMMAS_GAINS, ): @@ -1945,7 +1946,7 @@ def sample_inputs_adjust_gamma_video(): def sample_inputs_adjust_hue_image_tensor(): for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB) + sizes=["random"], color_spaces=(ColorSpace.GRAY, ColorSpace.RGB) ): yield ArgsKwargs(image_loader, hue_factor=_ADJUST_HUE_FACTORS[0]) @@ -1953,7 +1954,7 @@ def sample_inputs_adjust_hue_image_tensor(): def reference_inputs_adjust_hue_image_tensor(): for image_loader, hue_factor in itertools.product( make_image_loaders( - color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] + color_spaces=(ColorSpace.GRAY, ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] ), _ADJUST_HUE_FACTORS, ): @@ -1991,7 +1992,7 @@ def sample_inputs_adjust_hue_video(): def sample_inputs_adjust_saturation_image_tensor(): for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB) + sizes=["random"], color_spaces=(ColorSpace.GRAY, ColorSpace.RGB) ): yield ArgsKwargs(image_loader, saturation_factor=_ADJUST_SATURATION_FACTORS[0]) @@ -1999,7 +2000,7 @@ def sample_inputs_adjust_saturation_image_tensor(): def reference_inputs_adjust_saturation_image_tensor(): for image_loader, saturation_factor in itertools.product( make_image_loaders( - color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] + color_spaces=(ColorSpace.GRAY, ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] ), _ADJUST_SATURATION_FACTORS, ): @@ -2065,7 +2066,7 @@ def sample_inputs_five_crop_image_tensor(): for size in _FIVE_TEN_CROP_SIZES: for image_loader in make_image_loaders( sizes=[_get_five_ten_crop_spatial_size(size)], - 
color_spaces=[datapoints.ColorSpace.RGB], + color_spaces=[ColorSpace.RGB], dtypes=[torch.float32], ): yield ArgsKwargs(image_loader, size=size) @@ -2089,7 +2090,7 @@ def sample_inputs_ten_crop_image_tensor(): for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]): for image_loader in make_image_loaders( sizes=[_get_five_ten_crop_spatial_size(size)], - color_spaces=[datapoints.ColorSpace.RGB], + color_spaces=[ColorSpace.RGB], dtypes=[torch.float32], ): yield ArgsKwargs(image_loader, size=size, vertical_flip=vertical_flip) @@ -2163,7 +2164,7 @@ def wrapper(input_tensor, *other_args, **kwargs): def sample_inputs_normalize_image_tensor(): for image_loader, (mean, std) in itertools.product( - make_image_loaders(sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]), + make_image_loaders(sizes=["random"], color_spaces=[ColorSpace.RGB], dtypes=[torch.float32]), _NORMALIZE_MEANS_STDS, ): yield ArgsKwargs(image_loader, mean=mean, std=std) @@ -2172,7 +2173,7 @@ def sample_inputs_normalize_image_tensor(): def sample_inputs_normalize_video(): mean, std = _NORMALIZE_MEANS_STDS[0] for video_loader in make_video_loaders( - sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], num_frames=["random"], dtypes=[torch.float32] + sizes=["random"], color_spaces=[ColorSpace.RGB], num_frames=["random"], dtypes=[torch.float32] ): yield ArgsKwargs(video_loader, mean=mean, std=std) @@ -2205,7 +2206,7 @@ def sample_inputs_convert_dtype_image_tensor(): continue for image_loader in make_image_loaders( - sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[input_dtype] + sizes=["random"], color_spaces=[ColorSpace.RGB], dtypes=[input_dtype] ): yield ArgsKwargs(image_loader, dtype=output_dtype) @@ -2333,7 +2334,7 @@ def reference_uniform_temporal_subsample_video(x, num_samples, temporal_dim=-4): def reference_inputs_uniform_temporal_subsample_video(): - for video_loader in make_video_loaders(sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], num_frames=[10]): + for video_loader in make_video_loaders(sizes=["random"], color_spaces=[ColorSpace.RGB], num_frames=[10]): for num_samples in range(1, video_loader.shape[-4] + 1): yield ArgsKwargs(video_loader, num_samples) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 097b3edf010..fede1bc3b15 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -10,6 +10,7 @@ import torchvision.prototype.transforms.utils from common_utils import assert_equal, cpu_and_gpu from prototype_common_utils import ( + ColorSpace, DEFAULT_EXTRA_DIMS, make_bounding_box, make_bounding_boxes, @@ -161,8 +162,8 @@ def test_mixup_cutmix(self, transform, input): itertools.chain.from_iterable( fn( color_spaces=[ - datapoints.ColorSpace.GRAY, - datapoints.ColorSpace.RGB, + ColorSpace.GRAY, + ColorSpace.RGB, ], dtypes=[torch.uint8], extra_dims=[(), (4,)], @@ -192,7 +193,7 @@ def test_auto_augment(self, transform, input): ( transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]), itertools.chain.from_iterable( - fn(color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]) + fn(color_spaces=[ColorSpace.RGB], dtypes=[torch.float32]) for fn in [ make_images, make_vanilla_tensor_images, @@ -1519,7 +1520,7 @@ def test__get_params(self, mocker): transform = transforms.FixedSizeCrop(size=crop_size) flat_inputs = [ - make_image(size=spatial_size, color_space=datapoints.ColorSpace.RGB), + make_image(size=spatial_size, color_space=ColorSpace.RGB), 
make_bounding_box( format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=batch_shape ), diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py index 00dc40fb06d..022fa864eb5 100644 --- a/test/test_prototype_transforms_consistency.py +++ b/test/test_prototype_transforms_consistency.py @@ -16,6 +16,7 @@ ArgsKwargs, assert_close, assert_equal, + ColorSpace, make_bounding_box, make_detection_mask, make_image, @@ -31,7 +32,7 @@ from torchvision.prototype.transforms.utils import query_spatial_size from torchvision.transforms import functional as legacy_F -DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=[datapoints.ColorSpace.RGB], extra_dims=[(4,)]) +DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=[ColorSpace.RGB], extra_dims=[(4,)]) class ConsistencyConfig: @@ -139,7 +140,7 @@ def __init__( # Make sure that the product of the height, width and number of channels matches the number of elements in # `LINEAR_TRANSFORMATION_MEAN`. For example 2 * 6 * 3 == 4 * 3 * 3 == 36. make_images_kwargs=dict( - DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=[datapoints.ColorSpace.RGB] + DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=[ColorSpace.RGB] ), supports_pil=False, ), @@ -151,7 +152,7 @@ def __init__( ArgsKwargs(num_output_channels=3), ], make_images_kwargs=dict( - DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=[datapoints.ColorSpace.RGB, datapoints.ColorSpace.GRAY] + DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=[ColorSpace.RGB, ColorSpace.GRAY] ), ), ConsistencyConfig( @@ -174,10 +175,10 @@ def __init__( [ArgsKwargs()], make_images_kwargs=dict( color_spaces=[ - datapoints.ColorSpace.GRAY, - datapoints.ColorSpace.GRAY_ALPHA, - datapoints.ColorSpace.RGB, - datapoints.ColorSpace.RGB_ALPHA, + ColorSpace.GRAY, + ColorSpace.GRAY_ALPHA, + ColorSpace.RGB, + ColorSpace.RGB_ALPHA, ], extra_dims=[()], ), @@ -911,7 +912,7 @@ def make_datapoints(self, with_mask=True): size = (600, 800) num_objects = 22 - pil_image = to_image_pil(make_image(size=size, color_space=datapoints.ColorSpace.RGB)) + pil_image = to_image_pil(make_image(size=size, color_space=ColorSpace.RGB)) target = { "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float), "labels": make_label(extra_dims=(num_objects,), categories=80), @@ -921,7 +922,7 @@ def make_datapoints(self, with_mask=True): yield (pil_image, target) - tensor_image = torch.Tensor(make_image(size=size, color_space=datapoints.ColorSpace.RGB)) + tensor_image = torch.Tensor(make_image(size=size, color_space=ColorSpace.RGB)) target = { "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float), "labels": make_label(extra_dims=(num_objects,), categories=80), @@ -931,7 +932,7 @@ def make_datapoints(self, with_mask=True): yield (tensor_image, target) - datapoint_image = make_image(size=size, color_space=datapoints.ColorSpace.RGB) + datapoint_image = make_image(size=size, color_space=ColorSpace.RGB) target = { "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float), "labels": make_label(extra_dims=(num_objects,), categories=80), @@ -1015,7 +1016,7 @@ def make_datapoints(self, supports_pil=True, image_dtype=torch.uint8): conv_fns.extend([torch.Tensor, lambda x: x]) for conv_fn in conv_fns: - datapoint_image = make_image(size=size, color_space=datapoints.ColorSpace.RGB, dtype=image_dtype) + datapoint_image = make_image(size=size, 
color_space=ColorSpace.RGB, dtype=image_dtype) datapoint_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8) dp = (conv_fn(datapoint_image), datapoint_mask) diff --git a/test/test_prototype_transforms_utils.py b/test/test_prototype_transforms_utils.py index 8774b3bb8c5..09c5d297147 100644 --- a/test/test_prototype_transforms_utils.py +++ b/test/test_prototype_transforms_utils.py @@ -4,14 +4,14 @@ import torch import torchvision.prototype.transforms.utils -from prototype_common_utils import make_bounding_box, make_detection_mask, make_image +from prototype_common_utils import make_bounding_box, make_detection_mask, make_image, ColorSpace from torchvision.prototype import datapoints from torchvision.prototype.transforms.functional import to_image_pil from torchvision.prototype.transforms.utils import has_all, has_any -IMAGE = make_image(color_space=datapoints.ColorSpace.RGB) +IMAGE = make_image(color_space=ColorSpace.RGB) BOUNDING_BOX = make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY, spatial_size=IMAGE.spatial_size) MASK = make_detection_mask(size=IMAGE.spatial_size) diff --git a/torchvision/prototype/datapoints/__init__.py b/torchvision/prototype/datapoints/__init__.py index 92f345e20bd..f85cb3dd596 100644 --- a/torchvision/prototype/datapoints/__init__.py +++ b/torchvision/prototype/datapoints/__init__.py @@ -1,6 +1,6 @@ from ._bounding_box import BoundingBox, BoundingBoxFormat from ._datapoint import FillType, FillTypeJIT, InputType, InputTypeJIT -from ._image import ColorSpace, Image, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT +from ._image import Image, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT from ._label import Label, OneHotLabel from ._mask import Mask from ._video import TensorVideoType, TensorVideoTypeJIT, Video, VideoType, VideoTypeJIT diff --git a/torchvision/prototype/datapoints/_image.py b/torchvision/prototype/datapoints/_image.py index ba8f30fbd5a..d3bfd7d5e4a 100644 --- a/torchvision/prototype/datapoints/_image.py +++ b/torchvision/prototype/datapoints/_image.py @@ -11,52 +11,6 @@ from ._datapoint import Datapoint, FillTypeJIT -class ColorSpace(StrEnum): - OTHER = StrEnum.auto() - GRAY = StrEnum.auto() - GRAY_ALPHA = StrEnum.auto() - RGB = StrEnum.auto() - RGB_ALPHA = StrEnum.auto() - - @classmethod - def from_pil_mode(cls, mode: str) -> ColorSpace: - if mode == "L": - return cls.GRAY - elif mode == "LA": - return cls.GRAY_ALPHA - elif mode == "RGB": - return cls.RGB - elif mode == "RGBA": - return cls.RGB_ALPHA - else: - return cls.OTHER - - @staticmethod - def from_tensor_shape(shape: List[int]) -> ColorSpace: - return _from_tensor_shape(shape) - - -def _from_tensor_shape(shape: List[int]) -> ColorSpace: - # Needed as a standalone method for JIT - ndim = len(shape) - if ndim < 2: - return ColorSpace.OTHER - elif ndim == 2: - return ColorSpace.GRAY - - num_channels = shape[-3] - if num_channels == 1: - return ColorSpace.GRAY - elif num_channels == 2: - return ColorSpace.GRAY_ALPHA - elif num_channels == 3: - return ColorSpace.RGB - elif num_channels == 4: - return ColorSpace.RGB_ALPHA - else: - return ColorSpace.OTHER - - class Image(Datapoint): @classmethod diff --git a/torchvision/prototype/transforms/_deprecated.py b/torchvision/prototype/transforms/_deprecated.py index ff0c6306d51..974fe2b2741 100644 --- a/torchvision/prototype/transforms/_deprecated.py +++ b/torchvision/prototype/transforms/_deprecated.py @@ -63,7 +63,7 @@ def _transform( ) -> Union[datapoints.ImageType, 
datapoints.VideoType]: output = _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels) if isinstance(inpt, (datapoints.Image, datapoints.Video)): - output = inpt.wrap_like(inpt, output, color_space=datapoints.ColorSpace.GRAY) # type: ignore[arg-type] + output = inpt.wrap_like(inpt, output) # type: ignore[arg-type] return output @@ -99,5 +99,5 @@ def _transform( ) -> Union[datapoints.ImageType, datapoints.VideoType]: output = _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"]) if isinstance(inpt, (datapoints.Image, datapoints.Video)): - output = inpt.wrap_like(inpt, output, color_space=datapoints.ColorSpace.GRAY) # type: ignore[arg-type] + output = inpt.wrap_like(inpt, output) # type: ignore[arg-type] return output diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py index 547159f30c0..464fbada237 100644 --- a/torchvision/prototype/transforms/functional/_meta.py +++ b/torchvision/prototype/transforms/functional/_meta.py @@ -3,7 +3,7 @@ import PIL.Image import torch from torchvision.prototype import datapoints -from torchvision.prototype.datapoints import BoundingBoxFormat, ColorSpace +from torchvision.prototype.datapoints import BoundingBoxFormat from torchvision.transforms import functional_pil as _FP from torchvision.transforms.functional_tensor import _max_value From 874b00f56298bc6020cd275c42a6a7135e8152f6 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 20 Jan 2023 16:04:17 +0000 Subject: [PATCH 09/15] Formatting --- test/prototype_common_utils.py | 2 + test/prototype_transforms_kernel_infos.py | 76 +++++-------------- test/test_prototype_transforms_consistency.py | 8 +- test/test_prototype_transforms_utils.py | 2 +- torchvision/prototype/datapoints/_image.py | 1 - torchvision/prototype/datapoints/_video.py | 1 - 6 files changed, 24 insertions(+), 66 deletions(-) diff --git a/test/prototype_common_utils.py b/test/prototype_common_utils.py index 433b6775d84..479d0fa06c3 100644 --- a/test/prototype_common_utils.py +++ b/test/prototype_common_utils.py @@ -207,6 +207,8 @@ def _parse_spatial_size(size, *, name="size"): DEFAULT_EXTRA_DIMS = (*VALID_EXTRA_DIMS, *DEGENERATE_BATCH_DIMS) from enum import Enum + + class ColorSpace(Enum): OTHER = "OTHER" GRAY = "GRAY" diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index 1e38e5c8d2b..b86cd68a58b 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -262,9 +262,7 @@ def _get_resize_sizes(spatial_size): def sample_inputs_resize_image_tensor(): - for image_loader in make_image_loaders( - sizes=["random"], color_spaces=[ColorSpace.RGB], dtypes=[torch.float32] - ): + for image_loader in make_image_loaders(sizes=["random"], color_spaces=[ColorSpace.RGB], dtypes=[torch.float32]): for size in _get_resize_sizes(image_loader.spatial_size): yield ArgsKwargs(image_loader, size=size) @@ -1469,9 +1467,7 @@ def sample_inputs_gaussian_blur_video(): def sample_inputs_equalize_image_tensor(): - for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(ColorSpace.GRAY, ColorSpace.RGB) - ): + for image_loader in make_image_loaders(sizes=["random"], color_spaces=(ColorSpace.GRAY, ColorSpace.RGB)): yield ArgsKwargs(image_loader) @@ -1553,9 +1549,7 @@ def sample_inputs_equalize_video(): def sample_inputs_invert_image_tensor(): - for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(ColorSpace.GRAY, ColorSpace.RGB) - ): + 
for image_loader in make_image_loaders(sizes=["random"], color_spaces=(ColorSpace.GRAY, ColorSpace.RGB)): yield ArgsKwargs(image_loader) @@ -1593,17 +1587,13 @@ def sample_inputs_invert_video(): def sample_inputs_posterize_image_tensor(): - for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(ColorSpace.GRAY, ColorSpace.RGB) - ): + for image_loader in make_image_loaders(sizes=["random"], color_spaces=(ColorSpace.GRAY, ColorSpace.RGB)): yield ArgsKwargs(image_loader, bits=_POSTERIZE_BITS[0]) def reference_inputs_posterize_image_tensor(): for image_loader, bits in itertools.product( - make_image_loaders( - color_spaces=(ColorSpace.GRAY, ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] - ), + make_image_loaders(color_spaces=(ColorSpace.GRAY, ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]), _POSTERIZE_BITS, ): yield ArgsKwargs(image_loader, bits=bits) @@ -1640,9 +1630,7 @@ def _get_solarize_thresholds(dtype): def sample_inputs_solarize_image_tensor(): - for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(ColorSpace.GRAY, ColorSpace.RGB) - ): + for image_loader in make_image_loaders(sizes=["random"], color_spaces=(ColorSpace.GRAY, ColorSpace.RGB)): yield ArgsKwargs(image_loader, threshold=next(_get_solarize_thresholds(image_loader.dtype))) @@ -1683,9 +1671,7 @@ def sample_inputs_solarize_video(): def sample_inputs_autocontrast_image_tensor(): - for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(ColorSpace.GRAY, ColorSpace.RGB) - ): + for image_loader in make_image_loaders(sizes=["random"], color_spaces=(ColorSpace.GRAY, ColorSpace.RGB)): yield ArgsKwargs(image_loader) @@ -1735,9 +1721,7 @@ def sample_inputs_adjust_sharpness_image_tensor(): def reference_inputs_adjust_sharpness_image_tensor(): for image_loader, sharpness_factor in itertools.product( - make_image_loaders( - color_spaces=(ColorSpace.GRAY, ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] - ), + make_image_loaders(color_spaces=(ColorSpace.GRAY, ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]), _ADJUST_SHARPNESS_FACTORS, ): yield ArgsKwargs(image_loader, sharpness_factor=sharpness_factor) @@ -1801,17 +1785,13 @@ def sample_inputs_erase_video(): def sample_inputs_adjust_brightness_image_tensor(): - for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(ColorSpace.GRAY, ColorSpace.RGB) - ): + for image_loader in make_image_loaders(sizes=["random"], color_spaces=(ColorSpace.GRAY, ColorSpace.RGB)): yield ArgsKwargs(image_loader, brightness_factor=_ADJUST_BRIGHTNESS_FACTORS[0]) def reference_inputs_adjust_brightness_image_tensor(): for image_loader, brightness_factor in itertools.product( - make_image_loaders( - color_spaces=(ColorSpace.GRAY, ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] - ), + make_image_loaders(color_spaces=(ColorSpace.GRAY, ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]), _ADJUST_BRIGHTNESS_FACTORS, ): yield ArgsKwargs(image_loader, brightness_factor=brightness_factor) @@ -1845,17 +1825,13 @@ def sample_inputs_adjust_brightness_video(): def sample_inputs_adjust_contrast_image_tensor(): - for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(ColorSpace.GRAY, ColorSpace.RGB) - ): + for image_loader in make_image_loaders(sizes=["random"], color_spaces=(ColorSpace.GRAY, ColorSpace.RGB)): yield ArgsKwargs(image_loader, contrast_factor=_ADJUST_CONTRAST_FACTORS[0]) def reference_inputs_adjust_contrast_image_tensor(): for image_loader, contrast_factor in itertools.product( - 
make_image_loaders( - color_spaces=(ColorSpace.GRAY, ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] - ), + make_image_loaders(color_spaces=(ColorSpace.GRAY, ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]), _ADJUST_CONTRAST_FACTORS, ): yield ArgsKwargs(image_loader, contrast_factor=contrast_factor) @@ -1897,17 +1873,13 @@ def sample_inputs_adjust_contrast_video(): def sample_inputs_adjust_gamma_image_tensor(): gamma, gain = _ADJUST_GAMMA_GAMMAS_GAINS[0] - for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(ColorSpace.GRAY, ColorSpace.RGB) - ): + for image_loader in make_image_loaders(sizes=["random"], color_spaces=(ColorSpace.GRAY, ColorSpace.RGB)): yield ArgsKwargs(image_loader, gamma=gamma, gain=gain) def reference_inputs_adjust_gamma_image_tensor(): for image_loader, (gamma, gain) in itertools.product( - make_image_loaders( - color_spaces=(ColorSpace.GRAY, ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] - ), + make_image_loaders(color_spaces=(ColorSpace.GRAY, ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]), _ADJUST_GAMMA_GAMMAS_GAINS, ): yield ArgsKwargs(image_loader, gamma=gamma, gain=gain) @@ -1945,17 +1917,13 @@ def sample_inputs_adjust_gamma_video(): def sample_inputs_adjust_hue_image_tensor(): - for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(ColorSpace.GRAY, ColorSpace.RGB) - ): + for image_loader in make_image_loaders(sizes=["random"], color_spaces=(ColorSpace.GRAY, ColorSpace.RGB)): yield ArgsKwargs(image_loader, hue_factor=_ADJUST_HUE_FACTORS[0]) def reference_inputs_adjust_hue_image_tensor(): for image_loader, hue_factor in itertools.product( - make_image_loaders( - color_spaces=(ColorSpace.GRAY, ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] - ), + make_image_loaders(color_spaces=(ColorSpace.GRAY, ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]), _ADJUST_HUE_FACTORS, ): yield ArgsKwargs(image_loader, hue_factor=hue_factor) @@ -1991,17 +1959,13 @@ def sample_inputs_adjust_hue_video(): def sample_inputs_adjust_saturation_image_tensor(): - for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(ColorSpace.GRAY, ColorSpace.RGB) - ): + for image_loader in make_image_loaders(sizes=["random"], color_spaces=(ColorSpace.GRAY, ColorSpace.RGB)): yield ArgsKwargs(image_loader, saturation_factor=_ADJUST_SATURATION_FACTORS[0]) def reference_inputs_adjust_saturation_image_tensor(): for image_loader, saturation_factor in itertools.product( - make_image_loaders( - color_spaces=(ColorSpace.GRAY, ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] - ), + make_image_loaders(color_spaces=(ColorSpace.GRAY, ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]), _ADJUST_SATURATION_FACTORS, ): yield ArgsKwargs(image_loader, saturation_factor=saturation_factor) @@ -2205,9 +2169,7 @@ def sample_inputs_convert_dtype_image_tensor(): # conversion cannot be performed safely continue - for image_loader in make_image_loaders( - sizes=["random"], color_spaces=[ColorSpace.RGB], dtypes=[input_dtype] - ): + for image_loader in make_image_loaders(sizes=["random"], color_spaces=[ColorSpace.RGB], dtypes=[input_dtype]): yield ArgsKwargs(image_loader, dtype=output_dtype) diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py index 022fa864eb5..a6970792ddf 100644 --- a/test/test_prototype_transforms_consistency.py +++ b/test/test_prototype_transforms_consistency.py @@ -139,9 +139,7 @@ def __init__( ], # Make sure that the product of the height, width 
and number of channels matches the number of elements in # `LINEAR_TRANSFORMATION_MEAN`. For example 2 * 6 * 3 == 4 * 3 * 3 == 36. - make_images_kwargs=dict( - DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=[ColorSpace.RGB] - ), + make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=[ColorSpace.RGB]), supports_pil=False, ), ConsistencyConfig( @@ -151,9 +149,7 @@ def __init__( ArgsKwargs(num_output_channels=1), ArgsKwargs(num_output_channels=3), ], - make_images_kwargs=dict( - DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=[ColorSpace.RGB, ColorSpace.GRAY] - ), + make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=[ColorSpace.RGB, ColorSpace.GRAY]), ), ConsistencyConfig( prototype_transforms.ConvertDtype, diff --git a/test/test_prototype_transforms_utils.py b/test/test_prototype_transforms_utils.py index 09c5d297147..e8e123ae811 100644 --- a/test/test_prototype_transforms_utils.py +++ b/test/test_prototype_transforms_utils.py @@ -4,7 +4,7 @@ import torch import torchvision.prototype.transforms.utils -from prototype_common_utils import make_bounding_box, make_detection_mask, make_image, ColorSpace +from prototype_common_utils import ColorSpace, make_bounding_box, make_detection_mask, make_image from torchvision.prototype import datapoints from torchvision.prototype.transforms.functional import to_image_pil diff --git a/torchvision/prototype/datapoints/_image.py b/torchvision/prototype/datapoints/_image.py index d3bfd7d5e4a..aff91e0e73a 100644 --- a/torchvision/prototype/datapoints/_image.py +++ b/torchvision/prototype/datapoints/_image.py @@ -12,7 +12,6 @@ class Image(Datapoint): - @classmethod def _wrap(cls, tensor: torch.Tensor) -> Image: image = tensor.as_subclass(cls) diff --git a/torchvision/prototype/datapoints/_video.py b/torchvision/prototype/datapoints/_video.py index fff56d79f0d..0e8ac4c7c17 100644 --- a/torchvision/prototype/datapoints/_video.py +++ b/torchvision/prototype/datapoints/_video.py @@ -10,7 +10,6 @@ class Video(Datapoint): - @classmethod def _wrap(cls, tensor: torch.Tensor) -> Video: video = tensor.as_subclass(cls) From b26a14a2e6c214cec9304ea5ccf61e12ce7d5ecf Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 24 Jan 2023 10:45:48 +0000 Subject: [PATCH 10/15] renaming --- test/prototype_common_utils.py | 55 +++++------- test/prototype_transforms_kernel_infos.py | 85 ++++++++----------- test/test_prototype_transforms.py | 9 +- test/test_prototype_transforms_consistency.py | 23 +++-- test/test_prototype_transforms_utils.py | 4 +- 5 files changed, 76 insertions(+), 100 deletions(-) diff --git a/test/prototype_common_utils.py b/test/prototype_common_utils.py index 479d0fa06c3..80302a1ab85 100644 --- a/test/prototype_common_utils.py +++ b/test/prototype_common_utils.py @@ -209,14 +209,6 @@ def _parse_spatial_size(size, *, name="size"): from enum import Enum -class ColorSpace(Enum): - OTHER = "OTHER" - GRAY = "GRAY" - GRAY_ALPHA = "GRAY_ALPHA" - RGB = "RGB" - RGB_ALPHA = "RGB_ALPHA" - - def from_loader(loader_fn): def wrapper(*args, **kwargs): device = kwargs.pop("device", "cpu") @@ -248,7 +240,6 @@ def load(self, device): @dataclasses.dataclass class ImageLoader(TensorLoader): - color_space: ColorSpace spatial_size: Tuple[int, int] = dataclasses.field(init=False) num_channels: int = dataclasses.field(init=False) @@ -258,10 +249,10 @@ def __post_init__(self): NUM_CHANNELS_MAP = { - ColorSpace.GRAY: 1, - ColorSpace.GRAY_ALPHA: 2, - ColorSpace.RGB: 3, - ColorSpace.RGB_ALPHA: 4, + "GRAY": 1, + "GRAY_ALPHA": 2, + 
"RGB": 3, + "RGBA": 4, } @@ -275,7 +266,7 @@ def get_num_channels(color_space): def make_image_loader( size="random", *, - color_space=ColorSpace.RGB, + color_space="RGB", extra_dims=(), dtype=torch.float32, constant_alpha=True, @@ -286,11 +277,11 @@ def make_image_loader( def fn(shape, dtype, device): max_value = get_max_value(dtype) data = torch.testing.make_tensor(shape, low=0, high=max_value, dtype=dtype, device=device) - if color_space in {ColorSpace.GRAY_ALPHA, ColorSpace.RGB_ALPHA} and constant_alpha: + if color_space in {"GRAY_ALPHA", "RGBA"} and constant_alpha: data[..., -1, :, :] = max_value return datapoints.Image(data) - return ImageLoader(fn, shape=(*extra_dims, num_channels, *size), dtype=dtype, color_space=color_space) + return ImageLoader(fn, shape=(*extra_dims, num_channels, *size), dtype=dtype) make_image = from_loader(make_image_loader) @@ -300,10 +291,10 @@ def make_image_loaders( *, sizes=DEFAULT_SPATIAL_SIZES, color_spaces=( - ColorSpace.GRAY, - ColorSpace.GRAY_ALPHA, - ColorSpace.RGB, - ColorSpace.RGB_ALPHA, + "GRAY", + "GRAY_ALPHA", + "RGB", + "RGBA", ), extra_dims=DEFAULT_EXTRA_DIMS, dtypes=(torch.float32, torch.uint8), @@ -316,7 +307,7 @@ def make_image_loaders( make_images = from_loaders(make_image_loaders) -def make_image_loader_for_interpolation(size="random", *, color_space=ColorSpace.RGB, dtype=torch.uint8): +def make_image_loader_for_interpolation(size="random", *, color_space="RGB", dtype=torch.uint8): size = _parse_spatial_size(size) num_channels = get_num_channels(color_space) @@ -328,10 +319,10 @@ def fn(shape, dtype, device): .resize((width, height)) .convert( { - ColorSpace.GRAY: "L", - ColorSpace.GRAY_ALPHA: "LA", - ColorSpace.RGB: "RGB", - ColorSpace.RGB_ALPHA: "RGBA", + "GRAY": "L", + "GRAY_ALPHA": "LA", + "RGB": "RGB", + "RGBA": "RGBA", }[color_space] ) ) @@ -340,12 +331,12 @@ def fn(shape, dtype, device): return datapoints.Image(image_tensor) - return ImageLoader(fn, shape=(num_channels, *size), dtype=dtype, color_space=color_space) + return ImageLoader(fn, shape=(num_channels, *size), dtype=dtype) def make_image_loaders_for_interpolation( sizes=((233, 147),), - color_spaces=(ColorSpace.RGB,), + color_spaces=("RGB",), dtypes=(torch.uint8,), ): for params in combinations_grid(size=sizes, color_space=color_spaces, dtype=dtypes): @@ -593,7 +584,7 @@ class VideoLoader(ImageLoader): def make_video_loader( size="random", *, - color_space=ColorSpace.RGB, + color_space="RGB", num_frames="random", extra_dims=(), dtype=torch.uint8, @@ -605,9 +596,7 @@ def fn(shape, dtype, device): video = make_image(size=shape[-2:], extra_dims=shape[:-3], dtype=dtype, device=device) return datapoints.Video(video) - return VideoLoader( - fn, shape=(*extra_dims, num_frames, get_num_channels(color_space), *size), dtype=dtype, color_space=color_space - ) + return VideoLoader(fn, shape=(*extra_dims, num_frames, get_num_channels(color_space), *size), dtype=dtype) make_video = from_loader(make_video_loader) @@ -617,8 +606,8 @@ def make_video_loaders( *, sizes=DEFAULT_SPATIAL_SIZES, color_spaces=( - ColorSpace.GRAY, - ColorSpace.RGB, + "GRAY", + "RGB", ), num_frames=(1, 0, "random"), extra_dims=DEFAULT_EXTRA_DIMS, diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index b86cd68a58b..14e06fd76bb 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -13,7 +13,6 @@ from datasets_utils import combinations_grid from prototype_common_utils import ( ArgsKwargs, - ColorSpace, 
get_num_channels, ImageLoader, InfoBase, @@ -262,12 +261,12 @@ def _get_resize_sizes(spatial_size): def sample_inputs_resize_image_tensor(): - for image_loader in make_image_loaders(sizes=["random"], color_spaces=[ColorSpace.RGB], dtypes=[torch.float32]): + for image_loader in make_image_loaders(sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32]): for size in _get_resize_sizes(image_loader.spatial_size): yield ArgsKwargs(image_loader, size=size) for image_loader, interpolation in itertools.product( - make_image_loaders(sizes=["random"], color_spaces=[ColorSpace.RGB]), + make_image_loaders(sizes=["random"], color_spaces=["RGB"]), [ F.InterpolationMode.NEAREST, F.InterpolationMode.BILINEAR, @@ -471,7 +470,7 @@ def float32_vs_uint8_fill_adapter(other_args, kwargs): def sample_inputs_affine_image_tensor(): make_affine_image_loaders = functools.partial( - make_image_loaders, sizes=["random"], color_spaces=[ColorSpace.RGB], dtypes=[torch.float32] + make_image_loaders, sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32] ) for image_loader, affine_params in itertools.product(make_affine_image_loaders(), _DIVERSE_AFFINE_PARAMS): @@ -758,7 +757,7 @@ def reference_vertical_flip_bounding_box(bounding_box, *, format, spatial_size): def sample_inputs_rotate_image_tensor(): make_rotate_image_loaders = functools.partial( - make_image_loaders, sizes=["random"], color_spaces=[ColorSpace.RGB], dtypes=[torch.float32] + make_image_loaders, sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32] ) for image_loader in make_rotate_image_loaders(): @@ -840,7 +839,7 @@ def sample_inputs_rotate_video(): def sample_inputs_crop_image_tensor(): for image_loader, params in itertools.product( - make_image_loaders(sizes=[(16, 17)], color_spaces=[ColorSpace.RGB], dtypes=[torch.float32]), + make_image_loaders(sizes=[(16, 17)], color_spaces=["RGB"], dtypes=[torch.float32]), [ dict(top=4, left=3, height=7, width=8), dict(top=-1, left=3, height=7, width=8), @@ -1026,7 +1025,7 @@ def sample_inputs_resized_crop_video(): def sample_inputs_pad_image_tensor(): make_pad_image_loaders = functools.partial( - make_image_loaders, sizes=["random"], color_spaces=[ColorSpace.RGB], dtypes=[torch.float32] + make_image_loaders, sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32] ) for image_loader, padding in itertools.product( @@ -1342,7 +1341,7 @@ def sample_inputs_elastic_video(): def sample_inputs_center_crop_image_tensor(): for image_loader, output_size in itertools.product( - make_image_loaders(sizes=[(16, 17)], color_spaces=[ColorSpace.RGB], dtypes=[torch.float32]), + make_image_loaders(sizes=[(16, 17)], color_spaces=["RGB"], dtypes=[torch.float32]), [ # valid `output_size` types for which cropping is applied to both dimensions *[5, (4,), (2, 3), [6], [3, 2]], @@ -1428,9 +1427,7 @@ def sample_inputs_center_crop_video(): def sample_inputs_gaussian_blur_image_tensor(): - make_gaussian_blur_image_loaders = functools.partial( - make_image_loaders, sizes=[(7, 33)], color_spaces=[ColorSpace.RGB] - ) + make_gaussian_blur_image_loaders = functools.partial(make_image_loaders, sizes=[(7, 33)], color_spaces=["RGB"]) for image_loader, kernel_size in itertools.product(make_gaussian_blur_image_loaders(), [5, (3, 3), [3, 3]]): yield ArgsKwargs(image_loader, kernel_size=kernel_size) @@ -1467,7 +1464,7 @@ def sample_inputs_gaussian_blur_video(): def sample_inputs_equalize_image_tensor(): - for image_loader in make_image_loaders(sizes=["random"], color_spaces=(ColorSpace.GRAY, ColorSpace.RGB)): + for image_loader 
in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")): yield ArgsKwargs(image_loader) @@ -1494,7 +1491,7 @@ def make_beta_distributed_image(shape, dtype, device, *, alpha, beta): spatial_size = (256, 256) for dtype, color_space, fn in itertools.product( [torch.uint8], - [ColorSpace.GRAY, ColorSpace.RGB], + ["GRAY", "RGB"], [ lambda shape, dtype, device: torch.zeros(shape, dtype=dtype, device=device), lambda shape, dtype, device: torch.full( @@ -1519,9 +1516,7 @@ def make_beta_distributed_image(shape, dtype, device, *, alpha, beta): ], ], ): - image_loader = ImageLoader( - fn, shape=(get_num_channels(color_space), *spatial_size), dtype=dtype, color_space=color_space - ) + image_loader = ImageLoader(fn, shape=(get_num_channels(color_space), *spatial_size), dtype=dtype) yield ArgsKwargs(image_loader) @@ -1549,14 +1544,12 @@ def sample_inputs_equalize_video(): def sample_inputs_invert_image_tensor(): - for image_loader in make_image_loaders(sizes=["random"], color_spaces=(ColorSpace.GRAY, ColorSpace.RGB)): + for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")): yield ArgsKwargs(image_loader) def reference_inputs_invert_image_tensor(): - for image_loader in make_image_loaders( - color_spaces=(ColorSpace.GRAY, ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] - ): + for image_loader in make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]): yield ArgsKwargs(image_loader) @@ -1587,13 +1580,13 @@ def sample_inputs_invert_video(): def sample_inputs_posterize_image_tensor(): - for image_loader in make_image_loaders(sizes=["random"], color_spaces=(ColorSpace.GRAY, ColorSpace.RGB)): + for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")): yield ArgsKwargs(image_loader, bits=_POSTERIZE_BITS[0]) def reference_inputs_posterize_image_tensor(): for image_loader, bits in itertools.product( - make_image_loaders(color_spaces=(ColorSpace.GRAY, ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]), + make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]), _POSTERIZE_BITS, ): yield ArgsKwargs(image_loader, bits=bits) @@ -1630,14 +1623,12 @@ def _get_solarize_thresholds(dtype): def sample_inputs_solarize_image_tensor(): - for image_loader in make_image_loaders(sizes=["random"], color_spaces=(ColorSpace.GRAY, ColorSpace.RGB)): + for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")): yield ArgsKwargs(image_loader, threshold=next(_get_solarize_thresholds(image_loader.dtype))) def reference_inputs_solarize_image_tensor(): - for image_loader in make_image_loaders( - color_spaces=(ColorSpace.GRAY, ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] - ): + for image_loader in make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]): for threshold in _get_solarize_thresholds(image_loader.dtype): yield ArgsKwargs(image_loader, threshold=threshold) @@ -1671,14 +1662,12 @@ def sample_inputs_solarize_video(): def sample_inputs_autocontrast_image_tensor(): - for image_loader in make_image_loaders(sizes=["random"], color_spaces=(ColorSpace.GRAY, ColorSpace.RGB)): + for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")): yield ArgsKwargs(image_loader) def reference_inputs_autocontrast_image_tensor(): - for image_loader in make_image_loaders( - color_spaces=(ColorSpace.GRAY, ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] - ): + for image_loader in 
make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]): yield ArgsKwargs(image_loader) @@ -1714,14 +1703,14 @@ def sample_inputs_autocontrast_video(): def sample_inputs_adjust_sharpness_image_tensor(): for image_loader in make_image_loaders( sizes=["random", (2, 2)], - color_spaces=(ColorSpace.GRAY, ColorSpace.RGB), + color_spaces=("GRAY", "RGB"), ): yield ArgsKwargs(image_loader, sharpness_factor=_ADJUST_SHARPNESS_FACTORS[0]) def reference_inputs_adjust_sharpness_image_tensor(): for image_loader, sharpness_factor in itertools.product( - make_image_loaders(color_spaces=(ColorSpace.GRAY, ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]), + make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]), _ADJUST_SHARPNESS_FACTORS, ): yield ArgsKwargs(image_loader, sharpness_factor=sharpness_factor) @@ -1785,13 +1774,13 @@ def sample_inputs_erase_video(): def sample_inputs_adjust_brightness_image_tensor(): - for image_loader in make_image_loaders(sizes=["random"], color_spaces=(ColorSpace.GRAY, ColorSpace.RGB)): + for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")): yield ArgsKwargs(image_loader, brightness_factor=_ADJUST_BRIGHTNESS_FACTORS[0]) def reference_inputs_adjust_brightness_image_tensor(): for image_loader, brightness_factor in itertools.product( - make_image_loaders(color_spaces=(ColorSpace.GRAY, ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]), + make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]), _ADJUST_BRIGHTNESS_FACTORS, ): yield ArgsKwargs(image_loader, brightness_factor=brightness_factor) @@ -1825,13 +1814,13 @@ def sample_inputs_adjust_brightness_video(): def sample_inputs_adjust_contrast_image_tensor(): - for image_loader in make_image_loaders(sizes=["random"], color_spaces=(ColorSpace.GRAY, ColorSpace.RGB)): + for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")): yield ArgsKwargs(image_loader, contrast_factor=_ADJUST_CONTRAST_FACTORS[0]) def reference_inputs_adjust_contrast_image_tensor(): for image_loader, contrast_factor in itertools.product( - make_image_loaders(color_spaces=(ColorSpace.GRAY, ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]), + make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]), _ADJUST_CONTRAST_FACTORS, ): yield ArgsKwargs(image_loader, contrast_factor=contrast_factor) @@ -1873,13 +1862,13 @@ def sample_inputs_adjust_contrast_video(): def sample_inputs_adjust_gamma_image_tensor(): gamma, gain = _ADJUST_GAMMA_GAMMAS_GAINS[0] - for image_loader in make_image_loaders(sizes=["random"], color_spaces=(ColorSpace.GRAY, ColorSpace.RGB)): + for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")): yield ArgsKwargs(image_loader, gamma=gamma, gain=gain) def reference_inputs_adjust_gamma_image_tensor(): for image_loader, (gamma, gain) in itertools.product( - make_image_loaders(color_spaces=(ColorSpace.GRAY, ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]), + make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]), _ADJUST_GAMMA_GAMMAS_GAINS, ): yield ArgsKwargs(image_loader, gamma=gamma, gain=gain) @@ -1917,13 +1906,13 @@ def sample_inputs_adjust_gamma_video(): def sample_inputs_adjust_hue_image_tensor(): - for image_loader in make_image_loaders(sizes=["random"], color_spaces=(ColorSpace.GRAY, ColorSpace.RGB)): + for image_loader in make_image_loaders(sizes=["random"], 
color_spaces=("GRAY", "RGB")): yield ArgsKwargs(image_loader, hue_factor=_ADJUST_HUE_FACTORS[0]) def reference_inputs_adjust_hue_image_tensor(): for image_loader, hue_factor in itertools.product( - make_image_loaders(color_spaces=(ColorSpace.GRAY, ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]), + make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]), _ADJUST_HUE_FACTORS, ): yield ArgsKwargs(image_loader, hue_factor=hue_factor) @@ -1959,13 +1948,13 @@ def sample_inputs_adjust_hue_video(): def sample_inputs_adjust_saturation_image_tensor(): - for image_loader in make_image_loaders(sizes=["random"], color_spaces=(ColorSpace.GRAY, ColorSpace.RGB)): + for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")): yield ArgsKwargs(image_loader, saturation_factor=_ADJUST_SATURATION_FACTORS[0]) def reference_inputs_adjust_saturation_image_tensor(): for image_loader, saturation_factor in itertools.product( - make_image_loaders(color_spaces=(ColorSpace.GRAY, ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]), + make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]), _ADJUST_SATURATION_FACTORS, ): yield ArgsKwargs(image_loader, saturation_factor=saturation_factor) @@ -2030,7 +2019,7 @@ def sample_inputs_five_crop_image_tensor(): for size in _FIVE_TEN_CROP_SIZES: for image_loader in make_image_loaders( sizes=[_get_five_ten_crop_spatial_size(size)], - color_spaces=[ColorSpace.RGB], + color_spaces=["RGB"], dtypes=[torch.float32], ): yield ArgsKwargs(image_loader, size=size) @@ -2054,7 +2043,7 @@ def sample_inputs_ten_crop_image_tensor(): for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]): for image_loader in make_image_loaders( sizes=[_get_five_ten_crop_spatial_size(size)], - color_spaces=[ColorSpace.RGB], + color_spaces=["RGB"], dtypes=[torch.float32], ): yield ArgsKwargs(image_loader, size=size, vertical_flip=vertical_flip) @@ -2128,7 +2117,7 @@ def wrapper(input_tensor, *other_args, **kwargs): def sample_inputs_normalize_image_tensor(): for image_loader, (mean, std) in itertools.product( - make_image_loaders(sizes=["random"], color_spaces=[ColorSpace.RGB], dtypes=[torch.float32]), + make_image_loaders(sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32]), _NORMALIZE_MEANS_STDS, ): yield ArgsKwargs(image_loader, mean=mean, std=std) @@ -2137,7 +2126,7 @@ def sample_inputs_normalize_image_tensor(): def sample_inputs_normalize_video(): mean, std = _NORMALIZE_MEANS_STDS[0] for video_loader in make_video_loaders( - sizes=["random"], color_spaces=[ColorSpace.RGB], num_frames=["random"], dtypes=[torch.float32] + sizes=["random"], color_spaces=["RGB"], num_frames=["random"], dtypes=[torch.float32] ): yield ArgsKwargs(video_loader, mean=mean, std=std) @@ -2169,7 +2158,7 @@ def sample_inputs_convert_dtype_image_tensor(): # conversion cannot be performed safely continue - for image_loader in make_image_loaders(sizes=["random"], color_spaces=[ColorSpace.RGB], dtypes=[input_dtype]): + for image_loader in make_image_loaders(sizes=["random"], color_spaces=["RGB"], dtypes=[input_dtype]): yield ArgsKwargs(image_loader, dtype=output_dtype) @@ -2296,7 +2285,7 @@ def reference_uniform_temporal_subsample_video(x, num_samples, temporal_dim=-4): def reference_inputs_uniform_temporal_subsample_video(): - for video_loader in make_video_loaders(sizes=["random"], color_spaces=[ColorSpace.RGB], num_frames=[10]): + for video_loader in make_video_loaders(sizes=["random"], 
color_spaces=["RGB"], num_frames=[10]): for num_samples in range(1, video_loader.shape[-4] + 1): yield ArgsKwargs(video_loader, num_samples) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index fede1bc3b15..335fbfd4fe3 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -10,7 +10,6 @@ import torchvision.prototype.transforms.utils from common_utils import assert_equal, cpu_and_gpu from prototype_common_utils import ( - ColorSpace, DEFAULT_EXTRA_DIMS, make_bounding_box, make_bounding_boxes, @@ -162,8 +161,8 @@ def test_mixup_cutmix(self, transform, input): itertools.chain.from_iterable( fn( color_spaces=[ - ColorSpace.GRAY, - ColorSpace.RGB, + "GRAY", + "RGB", ], dtypes=[torch.uint8], extra_dims=[(), (4,)], @@ -193,7 +192,7 @@ def test_auto_augment(self, transform, input): ( transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]), itertools.chain.from_iterable( - fn(color_spaces=[ColorSpace.RGB], dtypes=[torch.float32]) + fn(color_spaces=["RGB"], dtypes=[torch.float32]) for fn in [ make_images, make_vanilla_tensor_images, @@ -1520,7 +1519,7 @@ def test__get_params(self, mocker): transform = transforms.FixedSizeCrop(size=crop_size) flat_inputs = [ - make_image(size=spatial_size, color_space=ColorSpace.RGB), + make_image(size=spatial_size, color_space="RGB"), make_bounding_box( format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=batch_shape ), diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py index a6970792ddf..3b69b72dd4f 100644 --- a/test/test_prototype_transforms_consistency.py +++ b/test/test_prototype_transforms_consistency.py @@ -16,7 +16,6 @@ ArgsKwargs, assert_close, assert_equal, - ColorSpace, make_bounding_box, make_detection_mask, make_image, @@ -32,7 +31,7 @@ from torchvision.prototype.transforms.utils import query_spatial_size from torchvision.transforms import functional as legacy_F -DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=[ColorSpace.RGB], extra_dims=[(4,)]) +DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=["RGB"], extra_dims=[(4,)]) class ConsistencyConfig: @@ -139,7 +138,7 @@ def __init__( ], # Make sure that the product of the height, width and number of channels matches the number of elements in # `LINEAR_TRANSFORMATION_MEAN`. For example 2 * 6 * 3 == 4 * 3 * 3 == 36. 
- make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=[ColorSpace.RGB]), + make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=["RGB"]), supports_pil=False, ), ConsistencyConfig( @@ -149,7 +148,7 @@ def __init__( ArgsKwargs(num_output_channels=1), ArgsKwargs(num_output_channels=3), ], - make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=[ColorSpace.RGB, ColorSpace.GRAY]), + make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=["RGB", "GRAY"]), ), ConsistencyConfig( prototype_transforms.ConvertDtype, @@ -171,10 +170,10 @@ def __init__( [ArgsKwargs()], make_images_kwargs=dict( color_spaces=[ - ColorSpace.GRAY, - ColorSpace.GRAY_ALPHA, - ColorSpace.RGB, - ColorSpace.RGB_ALPHA, + "GRAY", + "GRAY_ALPHA", + "RGB", + "RGBA", ], extra_dims=[()], ), @@ -908,7 +907,7 @@ def make_datapoints(self, with_mask=True): size = (600, 800) num_objects = 22 - pil_image = to_image_pil(make_image(size=size, color_space=ColorSpace.RGB)) + pil_image = to_image_pil(make_image(size=size, color_space="RGB")) target = { "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float), "labels": make_label(extra_dims=(num_objects,), categories=80), @@ -918,7 +917,7 @@ def make_datapoints(self, with_mask=True): yield (pil_image, target) - tensor_image = torch.Tensor(make_image(size=size, color_space=ColorSpace.RGB)) + tensor_image = torch.Tensor(make_image(size=size, color_space="RGB")) target = { "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float), "labels": make_label(extra_dims=(num_objects,), categories=80), @@ -928,7 +927,7 @@ def make_datapoints(self, with_mask=True): yield (tensor_image, target) - datapoint_image = make_image(size=size, color_space=ColorSpace.RGB) + datapoint_image = make_image(size=size, color_space="RGB") target = { "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float), "labels": make_label(extra_dims=(num_objects,), categories=80), @@ -1012,7 +1011,7 @@ def make_datapoints(self, supports_pil=True, image_dtype=torch.uint8): conv_fns.extend([torch.Tensor, lambda x: x]) for conv_fn in conv_fns: - datapoint_image = make_image(size=size, color_space=ColorSpace.RGB, dtype=image_dtype) + datapoint_image = make_image(size=size, color_space="RGB", dtype=image_dtype) datapoint_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8) dp = (conv_fn(datapoint_image), datapoint_mask) diff --git a/test/test_prototype_transforms_utils.py b/test/test_prototype_transforms_utils.py index e8e123ae811..befccf0bea3 100644 --- a/test/test_prototype_transforms_utils.py +++ b/test/test_prototype_transforms_utils.py @@ -4,14 +4,14 @@ import torch import torchvision.prototype.transforms.utils -from prototype_common_utils import ColorSpace, make_bounding_box, make_detection_mask, make_image +from prototype_common_utils import make_bounding_box, make_detection_mask, make_image from torchvision.prototype import datapoints from torchvision.prototype.transforms.functional import to_image_pil from torchvision.prototype.transforms.utils import has_all, has_any -IMAGE = make_image(color_space=ColorSpace.RGB) +IMAGE = make_image(color_space="RGB") BOUNDING_BOX = make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY, spatial_size=IMAGE.spatial_size) MASK = make_detection_mask(size=IMAGE.spatial_size) From 51a15be3aa8f7112024fb8671f631f55efc14aad 
Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 24 Jan 2023 10:52:06 +0000 Subject: [PATCH 11/15] avoid call to removed datapoints._image._from_tensor_shape --- .../prototype/transforms/functional/_deprecated.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_deprecated.py b/torchvision/prototype/transforms/functional/_deprecated.py index f6fb0af0ae9..b60d143ca06 100644 --- a/torchvision/prototype/transforms/functional/_deprecated.py +++ b/torchvision/prototype/transforms/functional/_deprecated.py @@ -27,13 +27,9 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima def rgb_to_grayscale( inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], num_output_channels: int = 1 ) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]: - if torch.jit.is_scripting() or is_simple_tensor(inpt): - old_color_space = datapoints._image._from_tensor_shape(inpt.shape) # type: ignore[arg-type] - else: - old_color_space = None - - if isinstance(inpt, (datapoints.Image, datapoints.Video)): - inpt = inpt.as_subclass(torch.Tensor) + old_color_space = None # TODO: remove when un-deprecating + if not (torch.jit.is_scripting() or is_simple_tensor(inpt)) and isinstance(inpt, (datapoints.Image, datapoints.Video)): + inpt = inpt.as_subclass(torch.Tensor) call = ", num_output_channels=3" if num_output_channels == 3 else "" replacement = ( From 9651553324cb211f074b70b3a8c8a53ab4e4f7f3 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 24 Jan 2023 10:52:41 +0000 Subject: [PATCH 12/15] Address comments --- torchvision/prototype/datapoints/_video.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchvision/prototype/datapoints/_video.py b/torchvision/prototype/datapoints/_video.py index 0e8ac4c7c17..f91a0b18228 100644 --- a/torchvision/prototype/datapoints/_video.py +++ b/torchvision/prototype/datapoints/_video.py @@ -26,8 +26,6 @@ def __new__( tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) if data.ndim < 4: raise ValueError - video = super().__new__(cls, data, requires_grad=requires_grad) - # TODO: Should this be `video` or can we remove it? 
return cls._wrap(tensor) @classmethod From 7c3019d611cda37cc0de840e9188e403db11e34f Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 24 Jan 2023 10:53:25 +0000 Subject: [PATCH 13/15] formatting --- torchvision/prototype/transforms/functional/_deprecated.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchvision/prototype/transforms/functional/_deprecated.py b/torchvision/prototype/transforms/functional/_deprecated.py index b60d143ca06..a89bcae7b90 100644 --- a/torchvision/prototype/transforms/functional/_deprecated.py +++ b/torchvision/prototype/transforms/functional/_deprecated.py @@ -28,7 +28,9 @@ def rgb_to_grayscale( inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], num_output_channels: int = 1 ) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]: old_color_space = None # TODO: remove when un-deprecating - if not (torch.jit.is_scripting() or is_simple_tensor(inpt)) and isinstance(inpt, (datapoints.Image, datapoints.Video)): + if not (torch.jit.is_scripting() or is_simple_tensor(inpt)) and isinstance( + inpt, (datapoints.Image, datapoints.Video) + ): inpt = inpt.as_subclass(torch.Tensor) call = ", num_output_channels=3" if num_output_channels == 3 else "" From b8053f16bea6a2be92a6d5dced4d3537e388e4a3 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 24 Jan 2023 11:27:02 +0000 Subject: [PATCH 14/15] Merge --- test/prototype_transforms_kernel_infos.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index 26e5405ae00..963ea99eca8 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -2133,7 +2133,7 @@ def reference_normalize_image_tensor(image, mean, std, inplace=False): def reference_inputs_normalize_image_tensor(): yield ArgsKwargs( - make_image_loader(size=(32, 32), color_space=datapoints.ColorSpace.RGB, extra_dims=[1]), + make_image_loader(size=(32, 32), color_space="RGB", extra_dims=[1]), mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0], ) From 0693957ae00ffc602fe98cae85f1600f6e7e9d4f Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 24 Jan 2023 12:58:32 +0000 Subject: [PATCH 15/15] lint --- test/prototype_common_utils.py | 2 -- test/prototype_transforms_kernel_infos.py | 1 - torchvision/prototype/datapoints/_image.py | 2 -- torchvision/prototype/datapoints/_video.py | 1 - torchvision/prototype/transforms/_meta.py | 4 +--- torchvision/prototype/transforms/functional/_meta.py | 2 +- 6 files changed, 2 insertions(+), 10 deletions(-) diff --git a/test/prototype_common_utils.py b/test/prototype_common_utils.py index 80302a1ab85..1cea10603ec 100644 --- a/test/prototype_common_utils.py +++ b/test/prototype_common_utils.py @@ -206,8 +206,6 @@ def _parse_spatial_size(size, *, name="size"): DEFAULT_EXTRA_DIMS = (*VALID_EXTRA_DIMS, *DEGENERATE_BATCH_DIMS) -from enum import Enum - def from_loader(loader_fn): def wrapper(*args, **kwargs): diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index 963ea99eca8..1fac1526248 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -9,7 +9,6 @@ import torch.testing import torchvision.ops import torchvision.prototype.transforms.functional as F -from common_utils import cycle_over from datasets_utils import combinations_grid from prototype_common_utils import ( ArgsKwargs, diff --git a/torchvision/prototype/datapoints/_image.py 
b/torchvision/prototype/datapoints/_image.py index aff91e0e73a..ece95169ac3 100644 --- a/torchvision/prototype/datapoints/_image.py +++ b/torchvision/prototype/datapoints/_image.py @@ -1,11 +1,9 @@ from __future__ import annotations -import warnings from typing import Any, List, Optional, Tuple, Union import PIL.Image import torch -from torchvision._utils import StrEnum from torchvision.transforms.functional import InterpolationMode from ._datapoint import Datapoint, FillTypeJIT diff --git a/torchvision/prototype/datapoints/_video.py b/torchvision/prototype/datapoints/_video.py index f91a0b18228..5a73d35368a 100644 --- a/torchvision/prototype/datapoints/_video.py +++ b/torchvision/prototype/datapoints/_video.py @@ -1,6 +1,5 @@ from __future__ import annotations -import warnings from typing import Any, List, Optional, Tuple, Union import torch diff --git a/torchvision/prototype/transforms/_meta.py b/torchvision/prototype/transforms/_meta.py index 3c95c9b8e66..0373ee1baf3 100644 --- a/torchvision/prototype/transforms/_meta.py +++ b/torchvision/prototype/transforms/_meta.py @@ -1,6 +1,4 @@ -from typing import Any, Dict, Optional, Union - -import PIL.Image +from typing import Any, Dict, Union import torch diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py index 464fbada237..b76dc7d7b68 100644 --- a/torchvision/prototype/transforms/functional/_meta.py +++ b/torchvision/prototype/transforms/functional/_meta.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple, Union +from typing import List, Tuple, Union import PIL.Image import torch
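A minimal usage sketch, not part of the patch series above: after the "renaming" commit ([PATCH 10/15]), the test helpers in test/prototype_common_utils.py take plain color-space strings ("GRAY", "GRAY_ALPHA", "RGB", "RGBA") in place of the removed datapoints.ColorSpace enum. The calls below mirror ones that appear in the hunks; the exact helper signatures are assumed from those hunks rather than confirmed here.

    import torch
    from prototype_common_utils import make_image, make_image_loaders, make_video_loaders

    # Sketch only: plain strings replace the old ColorSpace enum members,
    # assuming the helper signatures shown in the diffs above.
    image = make_image(size=(32, 32), color_space="RGB")

    image_loaders = make_image_loaders(
        sizes=["random"], color_spaces=("GRAY", "RGB"), dtypes=[torch.uint8]
    )
    video_loaders = make_video_loaders(sizes=["random"], color_spaces=["RGB"], num_frames=[10])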