diff --git a/test/prototype_common_utils.py b/test/prototype_common_utils.py index 18664eb0945..1cea10603ec 100644 --- a/test/prototype_common_utils.py +++ b/test/prototype_common_utils.py @@ -238,7 +238,6 @@ def load(self, device): @dataclasses.dataclass class ImageLoader(TensorLoader): - color_space: datapoints.ColorSpace spatial_size: Tuple[int, int] = dataclasses.field(init=False) num_channels: int = dataclasses.field(init=False) @@ -248,10 +247,10 @@ def __post_init__(self): NUM_CHANNELS_MAP = { - datapoints.ColorSpace.GRAY: 1, - datapoints.ColorSpace.GRAY_ALPHA: 2, - datapoints.ColorSpace.RGB: 3, - datapoints.ColorSpace.RGB_ALPHA: 4, + "GRAY": 1, + "GRAY_ALPHA": 2, + "RGB": 3, + "RGBA": 4, } @@ -265,7 +264,7 @@ def get_num_channels(color_space): def make_image_loader( size="random", *, - color_space=datapoints.ColorSpace.RGB, + color_space="RGB", extra_dims=(), dtype=torch.float32, constant_alpha=True, @@ -276,11 +275,11 @@ def make_image_loader( def fn(shape, dtype, device): max_value = get_max_value(dtype) data = torch.testing.make_tensor(shape, low=0, high=max_value, dtype=dtype, device=device) - if color_space in {datapoints.ColorSpace.GRAY_ALPHA, datapoints.ColorSpace.RGB_ALPHA} and constant_alpha: + if color_space in {"GRAY_ALPHA", "RGBA"} and constant_alpha: data[..., -1, :, :] = max_value - return datapoints.Image(data, color_space=color_space) + return datapoints.Image(data) - return ImageLoader(fn, shape=(*extra_dims, num_channels, *size), dtype=dtype, color_space=color_space) + return ImageLoader(fn, shape=(*extra_dims, num_channels, *size), dtype=dtype) make_image = from_loader(make_image_loader) @@ -290,10 +289,10 @@ def make_image_loaders( *, sizes=DEFAULT_SPATIAL_SIZES, color_spaces=( - datapoints.ColorSpace.GRAY, - datapoints.ColorSpace.GRAY_ALPHA, - datapoints.ColorSpace.RGB, - datapoints.ColorSpace.RGB_ALPHA, + "GRAY", + "GRAY_ALPHA", + "RGB", + "RGBA", ), extra_dims=DEFAULT_EXTRA_DIMS, dtypes=(torch.float32, torch.uint8), @@ -306,7 +305,7 @@ def make_image_loaders( make_images = from_loaders(make_image_loaders) -def make_image_loader_for_interpolation(size="random", *, color_space=datapoints.ColorSpace.RGB, dtype=torch.uint8): +def make_image_loader_for_interpolation(size="random", *, color_space="RGB", dtype=torch.uint8): size = _parse_spatial_size(size) num_channels = get_num_channels(color_space) @@ -318,24 +317,24 @@ def fn(shape, dtype, device): .resize((width, height)) .convert( { - datapoints.ColorSpace.GRAY: "L", - datapoints.ColorSpace.GRAY_ALPHA: "LA", - datapoints.ColorSpace.RGB: "RGB", - datapoints.ColorSpace.RGB_ALPHA: "RGBA", + "GRAY": "L", + "GRAY_ALPHA": "LA", + "RGB": "RGB", + "RGBA": "RGBA", }[color_space] ) ) image_tensor = convert_dtype_image_tensor(to_image_tensor(image_pil).to(device=device), dtype=dtype) - return datapoints.Image(image_tensor, color_space=color_space) + return datapoints.Image(image_tensor) - return ImageLoader(fn, shape=(num_channels, *size), dtype=dtype, color_space=color_space) + return ImageLoader(fn, shape=(num_channels, *size), dtype=dtype) def make_image_loaders_for_interpolation( sizes=((233, 147),), - color_spaces=(datapoints.ColorSpace.RGB,), + color_spaces=("RGB",), dtypes=(torch.uint8,), ): for params in combinations_grid(size=sizes, color_space=color_spaces, dtype=dtypes): @@ -583,7 +582,7 @@ class VideoLoader(ImageLoader): def make_video_loader( size="random", *, - color_space=datapoints.ColorSpace.RGB, + color_space="RGB", num_frames="random", extra_dims=(), dtype=torch.uint8, @@ -592,12 +591,10 @@ def 
make_video_loader( num_frames = int(torch.randint(1, 5, ())) if num_frames == "random" else num_frames def fn(shape, dtype, device): - video = make_image(size=shape[-2:], color_space=color_space, extra_dims=shape[:-3], dtype=dtype, device=device) - return datapoints.Video(video, color_space=color_space) + video = make_image(size=shape[-2:], extra_dims=shape[:-3], dtype=dtype, device=device) + return datapoints.Video(video) - return VideoLoader( - fn, shape=(*extra_dims, num_frames, get_num_channels(color_space), *size), dtype=dtype, color_space=color_space - ) + return VideoLoader(fn, shape=(*extra_dims, num_frames, get_num_channels(color_space), *size), dtype=dtype) make_video = from_loader(make_video_loader) @@ -607,8 +604,8 @@ def make_video_loaders( *, sizes=DEFAULT_SPATIAL_SIZES, color_spaces=( - datapoints.ColorSpace.GRAY, - datapoints.ColorSpace.RGB, + "GRAY", + "RGB", ), num_frames=(1, 0, "random"), extra_dims=DEFAULT_EXTRA_DIMS, diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index e1420d1cc7b..1fac1526248 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -9,7 +9,6 @@ import torch.testing import torchvision.ops import torchvision.prototype.transforms.functional as F -from common_utils import cycle_over from datasets_utils import combinations_grid from prototype_common_utils import ( ArgsKwargs, @@ -261,14 +260,12 @@ def _get_resize_sizes(spatial_size): def sample_inputs_resize_image_tensor(): - for image_loader in make_image_loaders( - sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32] - ): + for image_loader in make_image_loaders(sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32]): for size in _get_resize_sizes(image_loader.spatial_size): yield ArgsKwargs(image_loader, size=size) for image_loader, interpolation in itertools.product( - make_image_loaders(sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB]), + make_image_loaders(sizes=["random"], color_spaces=["RGB"]), [ F.InterpolationMode.NEAREST, F.InterpolationMode.BILINEAR, @@ -472,7 +469,7 @@ def float32_vs_uint8_fill_adapter(other_args, kwargs): def sample_inputs_affine_image_tensor(): make_affine_image_loaders = functools.partial( - make_image_loaders, sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32] + make_image_loaders, sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32] ) for image_loader, affine_params in itertools.product(make_affine_image_loaders(), _DIVERSE_AFFINE_PARAMS): @@ -684,69 +681,6 @@ def reference_inputs_convert_format_bounding_box(): ) -def sample_inputs_convert_color_space_image_tensor(): - color_spaces = sorted( - set(datapoints.ColorSpace) - {datapoints.ColorSpace.OTHER}, key=lambda color_space: color_space.value - ) - - for old_color_space, new_color_space in cycle_over(color_spaces): - for image_loader in make_image_loaders(sizes=["random"], color_spaces=[old_color_space], constant_alpha=True): - yield ArgsKwargs(image_loader, old_color_space=old_color_space, new_color_space=new_color_space) - - for color_space in color_spaces: - for image_loader in make_image_loaders( - sizes=["random"], color_spaces=[color_space], dtypes=[torch.float32], constant_alpha=True - ): - yield ArgsKwargs(image_loader, old_color_space=color_space, new_color_space=color_space) - - -@pil_reference_wrapper -def reference_convert_color_space_image_tensor(image_pil, old_color_space, new_color_space): - color_space_pil = 
datapoints.ColorSpace.from_pil_mode(image_pil.mode) - if color_space_pil != old_color_space: - raise pytest.UsageError( - f"Converting the tensor image into an PIL image changed the colorspace " - f"from {old_color_space} to {color_space_pil}" - ) - - return F.convert_color_space_image_pil(image_pil, color_space=new_color_space) - - -def reference_inputs_convert_color_space_image_tensor(): - for args_kwargs in sample_inputs_convert_color_space_image_tensor(): - (image_loader, *other_args), kwargs = args_kwargs - if len(image_loader.shape) == 3 and image_loader.dtype == torch.uint8: - yield args_kwargs - - -def sample_inputs_convert_color_space_video(): - color_spaces = [datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB] - - for old_color_space, new_color_space in cycle_over(color_spaces): - for video_loader in make_video_loaders(sizes=["random"], color_spaces=[old_color_space], num_frames=["random"]): - yield ArgsKwargs(video_loader, old_color_space=old_color_space, new_color_space=new_color_space) - - -KERNEL_INFOS.extend( - [ - KernelInfo( - F.convert_color_space_image_tensor, - sample_inputs_fn=sample_inputs_convert_color_space_image_tensor, - reference_fn=reference_convert_color_space_image_tensor, - reference_inputs_fn=reference_inputs_convert_color_space_image_tensor, - closeness_kwargs={ - **pil_reference_pixel_difference(), - **float32_vs_uint8_pixel_difference(), - }, - ), - KernelInfo( - F.convert_color_space_video, - sample_inputs_fn=sample_inputs_convert_color_space_video, - ), - ] -) - - def sample_inputs_vertical_flip_image_tensor(): for image_loader in make_image_loaders(sizes=["random"], dtypes=[torch.float32]): yield ArgsKwargs(image_loader) @@ -822,7 +756,7 @@ def reference_vertical_flip_bounding_box(bounding_box, *, format, spatial_size): def sample_inputs_rotate_image_tensor(): make_rotate_image_loaders = functools.partial( - make_image_loaders, sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32] + make_image_loaders, sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32] ) for image_loader in make_rotate_image_loaders(): @@ -904,7 +838,7 @@ def sample_inputs_rotate_video(): def sample_inputs_crop_image_tensor(): for image_loader, params in itertools.product( - make_image_loaders(sizes=[(16, 17)], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]), + make_image_loaders(sizes=[(16, 17)], color_spaces=["RGB"], dtypes=[torch.float32]), [ dict(top=4, left=3, height=7, width=8), dict(top=-1, left=3, height=7, width=8), @@ -1090,7 +1024,7 @@ def sample_inputs_resized_crop_video(): def sample_inputs_pad_image_tensor(): make_pad_image_loaders = functools.partial( - make_image_loaders, sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32] + make_image_loaders, sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32] ) for image_loader, padding in itertools.product( @@ -1406,7 +1340,7 @@ def sample_inputs_elastic_video(): def sample_inputs_center_crop_image_tensor(): for image_loader, output_size in itertools.product( - make_image_loaders(sizes=[(16, 17)], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]), + make_image_loaders(sizes=[(16, 17)], color_spaces=["RGB"], dtypes=[torch.float32]), [ # valid `output_size` types for which cropping is applied to both dimensions *[5, (4,), (2, 3), [6], [3, 2]], @@ -1492,9 +1426,7 @@ def sample_inputs_center_crop_video(): def sample_inputs_gaussian_blur_image_tensor(): - make_gaussian_blur_image_loaders = functools.partial( - 
make_image_loaders, sizes=[(7, 33)], color_spaces=[datapoints.ColorSpace.RGB] - ) + make_gaussian_blur_image_loaders = functools.partial(make_image_loaders, sizes=[(7, 33)], color_spaces=["RGB"]) for image_loader, kernel_size in itertools.product(make_gaussian_blur_image_loaders(), [5, (3, 3), [3, 3]]): yield ArgsKwargs(image_loader, kernel_size=kernel_size) @@ -1531,9 +1463,7 @@ def sample_inputs_gaussian_blur_video(): def sample_inputs_equalize_image_tensor(): - for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB) - ): + for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")): yield ArgsKwargs(image_loader) @@ -1560,7 +1490,7 @@ def make_beta_distributed_image(shape, dtype, device, *, alpha, beta): spatial_size = (256, 256) for dtype, color_space, fn in itertools.product( [torch.uint8], - [datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB], + ["GRAY", "RGB"], [ lambda shape, dtype, device: torch.zeros(shape, dtype=dtype, device=device), lambda shape, dtype, device: torch.full( @@ -1585,9 +1515,7 @@ def make_beta_distributed_image(shape, dtype, device, *, alpha, beta): ], ], ): - image_loader = ImageLoader( - fn, shape=(get_num_channels(color_space), *spatial_size), dtype=dtype, color_space=color_space - ) + image_loader = ImageLoader(fn, shape=(get_num_channels(color_space), *spatial_size), dtype=dtype) yield ArgsKwargs(image_loader) @@ -1615,16 +1543,12 @@ def sample_inputs_equalize_video(): def sample_inputs_invert_image_tensor(): - for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB) - ): + for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")): yield ArgsKwargs(image_loader) def reference_inputs_invert_image_tensor(): - for image_loader in make_image_loaders( - color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] - ): + for image_loader in make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]): yield ArgsKwargs(image_loader) @@ -1655,17 +1579,13 @@ def sample_inputs_invert_video(): def sample_inputs_posterize_image_tensor(): - for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB) - ): + for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")): yield ArgsKwargs(image_loader, bits=_POSTERIZE_BITS[0]) def reference_inputs_posterize_image_tensor(): for image_loader, bits in itertools.product( - make_image_loaders( - color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] - ), + make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]), _POSTERIZE_BITS, ): yield ArgsKwargs(image_loader, bits=bits) @@ -1702,16 +1622,12 @@ def _get_solarize_thresholds(dtype): def sample_inputs_solarize_image_tensor(): - for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB) - ): + for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")): yield ArgsKwargs(image_loader, threshold=next(_get_solarize_thresholds(image_loader.dtype))) def reference_inputs_solarize_image_tensor(): - for image_loader in make_image_loaders( - color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], 
dtypes=[torch.uint8] - ): + for image_loader in make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]): for threshold in _get_solarize_thresholds(image_loader.dtype): yield ArgsKwargs(image_loader, threshold=threshold) @@ -1745,16 +1661,12 @@ def sample_inputs_solarize_video(): def sample_inputs_autocontrast_image_tensor(): - for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB) - ): + for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")): yield ArgsKwargs(image_loader) def reference_inputs_autocontrast_image_tensor(): - for image_loader in make_image_loaders( - color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] - ): + for image_loader in make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]): yield ArgsKwargs(image_loader) @@ -1790,16 +1702,14 @@ def sample_inputs_autocontrast_video(): def sample_inputs_adjust_sharpness_image_tensor(): for image_loader in make_image_loaders( sizes=["random", (2, 2)], - color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), + color_spaces=("GRAY", "RGB"), ): yield ArgsKwargs(image_loader, sharpness_factor=_ADJUST_SHARPNESS_FACTORS[0]) def reference_inputs_adjust_sharpness_image_tensor(): for image_loader, sharpness_factor in itertools.product( - make_image_loaders( - color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] - ), + make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]), _ADJUST_SHARPNESS_FACTORS, ): yield ArgsKwargs(image_loader, sharpness_factor=sharpness_factor) @@ -1863,17 +1773,13 @@ def sample_inputs_erase_video(): def sample_inputs_adjust_brightness_image_tensor(): - for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB) - ): + for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")): yield ArgsKwargs(image_loader, brightness_factor=_ADJUST_BRIGHTNESS_FACTORS[0]) def reference_inputs_adjust_brightness_image_tensor(): for image_loader, brightness_factor in itertools.product( - make_image_loaders( - color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] - ), + make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]), _ADJUST_BRIGHTNESS_FACTORS, ): yield ArgsKwargs(image_loader, brightness_factor=brightness_factor) @@ -1907,17 +1813,13 @@ def sample_inputs_adjust_brightness_video(): def sample_inputs_adjust_contrast_image_tensor(): - for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB) - ): + for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")): yield ArgsKwargs(image_loader, contrast_factor=_ADJUST_CONTRAST_FACTORS[0]) def reference_inputs_adjust_contrast_image_tensor(): for image_loader, contrast_factor in itertools.product( - make_image_loaders( - color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] - ), + make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]), _ADJUST_CONTRAST_FACTORS, ): yield ArgsKwargs(image_loader, contrast_factor=contrast_factor) @@ -1959,17 +1861,13 @@ def sample_inputs_adjust_contrast_video(): def 
sample_inputs_adjust_gamma_image_tensor(): gamma, gain = _ADJUST_GAMMA_GAMMAS_GAINS[0] - for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB) - ): + for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")): yield ArgsKwargs(image_loader, gamma=gamma, gain=gain) def reference_inputs_adjust_gamma_image_tensor(): for image_loader, (gamma, gain) in itertools.product( - make_image_loaders( - color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] - ), + make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]), _ADJUST_GAMMA_GAMMAS_GAINS, ): yield ArgsKwargs(image_loader, gamma=gamma, gain=gain) @@ -2007,17 +1905,13 @@ def sample_inputs_adjust_gamma_video(): def sample_inputs_adjust_hue_image_tensor(): - for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB) - ): + for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")): yield ArgsKwargs(image_loader, hue_factor=_ADJUST_HUE_FACTORS[0]) def reference_inputs_adjust_hue_image_tensor(): for image_loader, hue_factor in itertools.product( - make_image_loaders( - color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] - ), + make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]), _ADJUST_HUE_FACTORS, ): yield ArgsKwargs(image_loader, hue_factor=hue_factor) @@ -2053,17 +1947,13 @@ def sample_inputs_adjust_hue_video(): def sample_inputs_adjust_saturation_image_tensor(): - for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB) - ): + for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")): yield ArgsKwargs(image_loader, saturation_factor=_ADJUST_SATURATION_FACTORS[0]) def reference_inputs_adjust_saturation_image_tensor(): for image_loader, saturation_factor in itertools.product( - make_image_loaders( - color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] - ), + make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]), _ADJUST_SATURATION_FACTORS, ): yield ArgsKwargs(image_loader, saturation_factor=saturation_factor) @@ -2128,7 +2018,7 @@ def sample_inputs_five_crop_image_tensor(): for size in _FIVE_TEN_CROP_SIZES: for image_loader in make_image_loaders( sizes=[_get_five_ten_crop_spatial_size(size)], - color_spaces=[datapoints.ColorSpace.RGB], + color_spaces=["RGB"], dtypes=[torch.float32], ): yield ArgsKwargs(image_loader, size=size) @@ -2152,7 +2042,7 @@ def sample_inputs_ten_crop_image_tensor(): for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]): for image_loader in make_image_loaders( sizes=[_get_five_ten_crop_spatial_size(size)], - color_spaces=[datapoints.ColorSpace.RGB], + color_spaces=["RGB"], dtypes=[torch.float32], ): yield ArgsKwargs(image_loader, size=size, vertical_flip=vertical_flip) @@ -2226,7 +2116,7 @@ def wrapper(input_tensor, *other_args, **kwargs): def sample_inputs_normalize_image_tensor(): for image_loader, (mean, std) in itertools.product( - make_image_loaders(sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]), + make_image_loaders(sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32]), 
_NORMALIZE_MEANS_STDS, ): yield ArgsKwargs(image_loader, mean=mean, std=std) @@ -2242,7 +2132,7 @@ def reference_normalize_image_tensor(image, mean, std, inplace=False): def reference_inputs_normalize_image_tensor(): yield ArgsKwargs( - make_image_loader(size=(32, 32), color_space=datapoints.ColorSpace.RGB, extra_dims=[1]), + make_image_loader(size=(32, 32), color_space="RGB", extra_dims=[1]), mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0], ) @@ -2251,7 +2141,7 @@ def reference_inputs_normalize_image_tensor(): def sample_inputs_normalize_video(): mean, std = _NORMALIZE_MEANS_STDS[0] for video_loader in make_video_loaders( - sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], num_frames=["random"], dtypes=[torch.float32] + sizes=["random"], color_spaces=["RGB"], num_frames=["random"], dtypes=[torch.float32] ): yield ArgsKwargs(video_loader, mean=mean, std=std) @@ -2285,9 +2175,7 @@ def sample_inputs_convert_dtype_image_tensor(): # conversion cannot be performed safely continue - for image_loader in make_image_loaders( - sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[input_dtype] - ): + for image_loader in make_image_loaders(sizes=["random"], color_spaces=["RGB"], dtypes=[input_dtype]): yield ArgsKwargs(image_loader, dtype=output_dtype) @@ -2414,7 +2302,7 @@ def reference_uniform_temporal_subsample_video(x, num_samples, temporal_dim=-4): def reference_inputs_uniform_temporal_subsample_video(): - for video_loader in make_video_loaders(sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], num_frames=[10]): + for video_loader in make_video_loaders(sizes=["random"], color_spaces=["RGB"], num_frames=[10]): for num_samples in range(1, video_loader.shape[-4] + 1): yield ArgsKwargs(video_loader, num_samples) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 3826293f3ed..335fbfd4fe3 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -161,8 +161,8 @@ def test_mixup_cutmix(self, transform, input): itertools.chain.from_iterable( fn( color_spaces=[ - datapoints.ColorSpace.GRAY, - datapoints.ColorSpace.RGB, + "GRAY", + "RGB", ], dtypes=[torch.uint8], extra_dims=[(), (4,)], @@ -192,7 +192,7 @@ def test_auto_augment(self, transform, input): ( transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]), itertools.chain.from_iterable( - fn(color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]) + fn(color_spaces=["RGB"], dtypes=[torch.float32]) for fn in [ make_images, make_vanilla_tensor_images, @@ -221,45 +221,6 @@ def test_normalize(self, transform, input): def test_random_resized_crop(self, transform, input): transform(input) - @parametrize( - [ - ( - transforms.ConvertColorSpace(color_space=new_color_space, old_color_space=old_color_space), - itertools.chain.from_iterable( - [ - fn(color_spaces=[old_color_space]) - for fn in ( - make_images, - make_vanilla_tensor_images, - make_pil_images, - make_videos, - ) - ] - ), - ) - for old_color_space, new_color_space in itertools.product( - [ - datapoints.ColorSpace.GRAY, - datapoints.ColorSpace.GRAY_ALPHA, - datapoints.ColorSpace.RGB, - datapoints.ColorSpace.RGB_ALPHA, - ], - repeat=2, - ) - ] - ) - def test_convert_color_space(self, transform, input): - transform(input) - - def test_convert_color_space_unsupported_types(self): - transform = transforms.ConvertColorSpace( - color_space=datapoints.ColorSpace.RGB, old_color_space=datapoints.ColorSpace.GRAY - ) - - for inpt in [make_bounding_box(format="XYXY"), make_masks()]: - output = 
transform(inpt) - assert output is inpt - @pytest.mark.parametrize("p", [0.0, 1.0]) class TestRandomHorizontalFlip: @@ -1558,7 +1519,7 @@ def test__get_params(self, mocker): transform = transforms.FixedSizeCrop(size=crop_size) flat_inputs = [ - make_image(size=spatial_size, color_space=datapoints.ColorSpace.RGB), + make_image(size=spatial_size, color_space="RGB"), make_bounding_box( format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=batch_shape ), diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py index 00dc40fb06d..3b69b72dd4f 100644 --- a/test/test_prototype_transforms_consistency.py +++ b/test/test_prototype_transforms_consistency.py @@ -31,7 +31,7 @@ from torchvision.prototype.transforms.utils import query_spatial_size from torchvision.transforms import functional as legacy_F -DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=[datapoints.ColorSpace.RGB], extra_dims=[(4,)]) +DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=["RGB"], extra_dims=[(4,)]) class ConsistencyConfig: @@ -138,9 +138,7 @@ def __init__( ], # Make sure that the product of the height, width and number of channels matches the number of elements in # `LINEAR_TRANSFORMATION_MEAN`. For example 2 * 6 * 3 == 4 * 3 * 3 == 36. - make_images_kwargs=dict( - DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=[datapoints.ColorSpace.RGB] - ), + make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=["RGB"]), supports_pil=False, ), ConsistencyConfig( @@ -150,9 +148,7 @@ def __init__( ArgsKwargs(num_output_channels=1), ArgsKwargs(num_output_channels=3), ], - make_images_kwargs=dict( - DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=[datapoints.ColorSpace.RGB, datapoints.ColorSpace.GRAY] - ), + make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=["RGB", "GRAY"]), ), ConsistencyConfig( prototype_transforms.ConvertDtype, @@ -174,10 +170,10 @@ def __init__( [ArgsKwargs()], make_images_kwargs=dict( color_spaces=[ - datapoints.ColorSpace.GRAY, - datapoints.ColorSpace.GRAY_ALPHA, - datapoints.ColorSpace.RGB, - datapoints.ColorSpace.RGB_ALPHA, + "GRAY", + "GRAY_ALPHA", + "RGB", + "RGBA", ], extra_dims=[()], ), @@ -911,7 +907,7 @@ def make_datapoints(self, with_mask=True): size = (600, 800) num_objects = 22 - pil_image = to_image_pil(make_image(size=size, color_space=datapoints.ColorSpace.RGB)) + pil_image = to_image_pil(make_image(size=size, color_space="RGB")) target = { "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float), "labels": make_label(extra_dims=(num_objects,), categories=80), @@ -921,7 +917,7 @@ def make_datapoints(self, with_mask=True): yield (pil_image, target) - tensor_image = torch.Tensor(make_image(size=size, color_space=datapoints.ColorSpace.RGB)) + tensor_image = torch.Tensor(make_image(size=size, color_space="RGB")) target = { "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float), "labels": make_label(extra_dims=(num_objects,), categories=80), @@ -931,7 +927,7 @@ def make_datapoints(self, with_mask=True): yield (tensor_image, target) - datapoint_image = make_image(size=size, color_space=datapoints.ColorSpace.RGB) + datapoint_image = make_image(size=size, color_space="RGB") target = { "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float), "labels": make_label(extra_dims=(num_objects,), categories=80), @@ -1015,7 +1011,7 @@ def 
make_datapoints(self, supports_pil=True, image_dtype=torch.uint8): conv_fns.extend([torch.Tensor, lambda x: x]) for conv_fn in conv_fns: - datapoint_image = make_image(size=size, color_space=datapoints.ColorSpace.RGB, dtype=image_dtype) + datapoint_image = make_image(size=size, color_space="RGB", dtype=image_dtype) datapoint_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8) dp = (conv_fn(datapoint_image), datapoint_mask) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 7f0781fb010..649620eda62 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -340,7 +340,6 @@ def test_scripted_smoke(self, info, args_kwargs, device): "dispatcher", [ F.clamp_bounding_box, - F.convert_color_space, F.get_dimensions, F.get_image_num_channels, F.get_image_size, diff --git a/test/test_prototype_transforms_utils.py b/test/test_prototype_transforms_utils.py index 8774b3bb8c5..befccf0bea3 100644 --- a/test/test_prototype_transforms_utils.py +++ b/test/test_prototype_transforms_utils.py @@ -11,7 +11,7 @@ from torchvision.prototype.transforms.utils import has_all, has_any -IMAGE = make_image(color_space=datapoints.ColorSpace.RGB) +IMAGE = make_image(color_space="RGB") BOUNDING_BOX = make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY, spatial_size=IMAGE.spatial_size) MASK = make_detection_mask(size=IMAGE.spatial_size) diff --git a/torchvision/prototype/datapoints/__init__.py b/torchvision/prototype/datapoints/__init__.py index 92f345e20bd..f85cb3dd596 100644 --- a/torchvision/prototype/datapoints/__init__.py +++ b/torchvision/prototype/datapoints/__init__.py @@ -1,6 +1,6 @@ from ._bounding_box import BoundingBox, BoundingBoxFormat from ._datapoint import FillType, FillTypeJIT, InputType, InputTypeJIT -from ._image import ColorSpace, Image, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT +from ._image import Image, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT from ._label import Label, OneHotLabel from ._mask import Mask from ._video import TensorVideoType, TensorVideoTypeJIT, Video, VideoType, VideoTypeJIT diff --git a/torchvision/prototype/datapoints/_image.py b/torchvision/prototype/datapoints/_image.py index d674745a716..ece95169ac3 100644 --- a/torchvision/prototype/datapoints/_image.py +++ b/torchvision/prototype/datapoints/_image.py @@ -1,76 +1,24 @@ from __future__ import annotations -import warnings from typing import Any, List, Optional, Tuple, Union import PIL.Image import torch -from torchvision._utils import StrEnum from torchvision.transforms.functional import InterpolationMode from ._datapoint import Datapoint, FillTypeJIT -class ColorSpace(StrEnum): - OTHER = StrEnum.auto() - GRAY = StrEnum.auto() - GRAY_ALPHA = StrEnum.auto() - RGB = StrEnum.auto() - RGB_ALPHA = StrEnum.auto() - - @classmethod - def from_pil_mode(cls, mode: str) -> ColorSpace: - if mode == "L": - return cls.GRAY - elif mode == "LA": - return cls.GRAY_ALPHA - elif mode == "RGB": - return cls.RGB - elif mode == "RGBA": - return cls.RGB_ALPHA - else: - return cls.OTHER - - @staticmethod - def from_tensor_shape(shape: List[int]) -> ColorSpace: - return _from_tensor_shape(shape) - - -def _from_tensor_shape(shape: List[int]) -> ColorSpace: - # Needed as a standalone method for JIT - ndim = len(shape) - if ndim < 2: - return ColorSpace.OTHER - elif ndim == 2: - return ColorSpace.GRAY - - num_channels = shape[-3] - if num_channels == 1: - return 
ColorSpace.GRAY - elif num_channels == 2: - return ColorSpace.GRAY_ALPHA - elif num_channels == 3: - return ColorSpace.RGB - elif num_channels == 4: - return ColorSpace.RGB_ALPHA - else: - return ColorSpace.OTHER - - class Image(Datapoint): - color_space: ColorSpace - @classmethod - def _wrap(cls, tensor: torch.Tensor, *, color_space: ColorSpace) -> Image: + def _wrap(cls, tensor: torch.Tensor) -> Image: image = tensor.as_subclass(cls) - image.color_space = color_space return image def __new__( cls, data: Any, *, - color_space: Optional[Union[ColorSpace, str]] = None, dtype: Optional[torch.dtype] = None, device: Optional[Union[torch.device, str, int]] = None, requires_grad: bool = False, @@ -81,26 +29,14 @@ def __new__( elif tensor.ndim == 2: tensor = tensor.unsqueeze(0) - if color_space is None: - color_space = ColorSpace.from_tensor_shape(tensor.shape) # type: ignore[arg-type] - if color_space == ColorSpace.OTHER: - warnings.warn("Unable to guess a specific color space. Consider passing it explicitly.") - elif isinstance(color_space, str): - color_space = ColorSpace.from_str(color_space.upper()) - elif not isinstance(color_space, ColorSpace): - raise ValueError - - return cls._wrap(tensor, color_space=color_space) + return cls._wrap(tensor) @classmethod - def wrap_like(cls, other: Image, tensor: torch.Tensor, *, color_space: Optional[ColorSpace] = None) -> Image: - return cls._wrap( - tensor, - color_space=color_space if color_space is not None else other.color_space, - ) + def wrap_like(cls, other: Image, tensor: torch.Tensor) -> Image: + return cls._wrap(tensor) def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] - return self._make_repr(color_space=self.color_space) + return self._make_repr() @property def spatial_size(self) -> Tuple[int, int]: diff --git a/torchvision/prototype/datapoints/_video.py b/torchvision/prototype/datapoints/_video.py index c7273874655..5a73d35368a 100644 --- a/torchvision/prototype/datapoints/_video.py +++ b/torchvision/prototype/datapoints/_video.py @@ -1,29 +1,23 @@ from __future__ import annotations -import warnings from typing import Any, List, Optional, Tuple, Union import torch from torchvision.transforms.functional import InterpolationMode from ._datapoint import Datapoint, FillTypeJIT -from ._image import ColorSpace class Video(Datapoint): - color_space: ColorSpace - @classmethod - def _wrap(cls, tensor: torch.Tensor, *, color_space: ColorSpace) -> Video: + def _wrap(cls, tensor: torch.Tensor) -> Video: video = tensor.as_subclass(cls) - video.color_space = color_space return video def __new__( cls, data: Any, *, - color_space: Optional[Union[ColorSpace, str]] = None, dtype: Optional[torch.dtype] = None, device: Optional[Union[torch.device, str, int]] = None, requires_grad: bool = False, @@ -31,28 +25,14 @@ def __new__( tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) if data.ndim < 4: raise ValueError - video = super().__new__(cls, data, requires_grad=requires_grad) - - if color_space is None: - color_space = ColorSpace.from_tensor_shape(video.shape) # type: ignore[arg-type] - if color_space == ColorSpace.OTHER: - warnings.warn("Unable to guess a specific color space. 
Consider passing it explicitly.")
-        elif isinstance(color_space, str):
-            color_space = ColorSpace.from_str(color_space.upper())
-        elif not isinstance(color_space, ColorSpace):
-            raise ValueError
-
-        return cls._wrap(tensor, color_space=color_space)
+        return cls._wrap(tensor)
 
     @classmethod
-    def wrap_like(cls, other: Video, tensor: torch.Tensor, *, color_space: Optional[ColorSpace] = None) -> Video:
-        return cls._wrap(
-            tensor,
-            color_space=color_space if color_space is not None else other.color_space,
-        )
+    def wrap_like(cls, other: Video, tensor: torch.Tensor) -> Video:
+        return cls._wrap(tensor)
 
     def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
-        return self._make_repr(color_space=self.color_space)
+        return self._make_repr()
 
     @property
     def spatial_size(self) -> Tuple[int, int]:
diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py
index 04b007190b8..fa75cf63339 100644
--- a/torchvision/prototype/transforms/__init__.py
+++ b/torchvision/prototype/transforms/__init__.py
@@ -39,7 +39,7 @@
     ScaleJitter,
     TenCrop,
 )
-from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat, ConvertColorSpace, ConvertDtype, ConvertImageDtype
+from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat, ConvertDtype, ConvertImageDtype
 from ._misc import (
     GaussianBlur,
     Identity,
diff --git a/torchvision/prototype/transforms/_deprecated.py b/torchvision/prototype/transforms/_deprecated.py
index 3247a8051a3..974fe2b2741 100644
--- a/torchvision/prototype/transforms/_deprecated.py
+++ b/torchvision/prototype/transforms/_deprecated.py
@@ -28,6 +28,7 @@ def _transform(self, inpt: Union[PIL.Image.Image, np.ndarray], params: Dict[str,
         return _F.to_tensor(inpt)
 
 
+# TODO: in a follow-up PR, consider un-deprecating these and making them use _rgb_to_gray
class Grayscale(Transform): _transformed_types = ( datapoints.Image, @@ -62,7 +63,7 @@ def _transform( ) -> Union[datapoints.ImageType, datapoints.VideoType]: output = _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels) if isinstance(inpt, (datapoints.Image, datapoints.Video)): - output = inpt.wrap_like(inpt, output, color_space=datapoints.ColorSpace.GRAY) # type: ignore[arg-type] + output = inpt.wrap_like(inpt, output) # type: ignore[arg-type] return output @@ -98,5 +99,5 @@ def _transform( ) -> Union[datapoints.ImageType, datapoints.VideoType]: output = _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"]) if isinstance(inpt, (datapoints.Image, datapoints.Video)): - output = inpt.wrap_like(inpt, output, color_space=datapoints.ColorSpace.GRAY) # type: ignore[arg-type] + output = inpt.wrap_like(inpt, output) # type: ignore[arg-type] return output diff --git a/torchvision/prototype/transforms/_meta.py b/torchvision/prototype/transforms/_meta.py index 6ad9e041098..0373ee1baf3 100644 --- a/torchvision/prototype/transforms/_meta.py +++ b/torchvision/prototype/transforms/_meta.py @@ -1,6 +1,4 @@ -from typing import Any, Dict, Optional, Union - -import PIL.Image +from typing import Any, Dict, Union import torch @@ -46,35 +44,6 @@ def _transform( ConvertImageDtype = ConvertDtype -class ConvertColorSpace(Transform): - _transformed_types = ( - is_simple_tensor, - datapoints.Image, - PIL.Image.Image, - datapoints.Video, - ) - - def __init__( - self, - color_space: Union[str, datapoints.ColorSpace], - old_color_space: Optional[Union[str, datapoints.ColorSpace]] = None, - ) -> None: - super().__init__() - - if isinstance(color_space, str): - color_space = datapoints.ColorSpace.from_str(color_space) - self.color_space = color_space - - if isinstance(old_color_space, str): - old_color_space = datapoints.ColorSpace.from_str(old_color_space) - self.old_color_space = old_color_space - - def _transform( - self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any] - ) -> Union[datapoints.ImageType, datapoints.VideoType]: - return F.convert_color_space(inpt, color_space=self.color_space, old_color_space=self.old_color_space) - - class ClampBoundingBoxes(Transform): _transformed_types = (datapoints.BoundingBox,) diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index 30ef6e3fc99..57b4cc4423a 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -7,10 +7,6 @@ from ._meta import ( clamp_bounding_box, convert_format_bounding_box, - convert_color_space_image_tensor, - convert_color_space_image_pil, - convert_color_space_video, - convert_color_space, convert_dtype_image_tensor, convert_dtype, convert_dtype_video, diff --git a/torchvision/prototype/transforms/functional/_deprecated.py b/torchvision/prototype/transforms/functional/_deprecated.py index f6fb0af0ae9..a89bcae7b90 100644 --- a/torchvision/prototype/transforms/functional/_deprecated.py +++ b/torchvision/prototype/transforms/functional/_deprecated.py @@ -27,13 +27,11 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima def rgb_to_grayscale( inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], num_output_channels: int = 1 ) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]: - if torch.jit.is_scripting() or is_simple_tensor(inpt): - old_color_space = 
datapoints._image._from_tensor_shape(inpt.shape) # type: ignore[arg-type] - else: - old_color_space = None - - if isinstance(inpt, (datapoints.Image, datapoints.Video)): - inpt = inpt.as_subclass(torch.Tensor) + old_color_space = None # TODO: remove when un-deprecating + if not (torch.jit.is_scripting() or is_simple_tensor(inpt)) and isinstance( + inpt, (datapoints.Image, datapoints.Video) + ): + inpt = inpt.as_subclass(torch.Tensor) call = ", num_output_channels=3" if num_output_channels == 3 else "" replacement = ( diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py index 62f9664fc47..b76dc7d7b68 100644 --- a/torchvision/prototype/transforms/functional/_meta.py +++ b/torchvision/prototype/transforms/functional/_meta.py @@ -1,9 +1,9 @@ -from typing import List, Optional, Tuple, Union +from typing import List, Tuple, Union import PIL.Image import torch from torchvision.prototype import datapoints -from torchvision.prototype.datapoints import BoundingBoxFormat, ColorSpace +from torchvision.prototype.datapoints import BoundingBoxFormat from torchvision.transforms import functional_pil as _FP from torchvision.transforms.functional_tensor import _max_value @@ -225,29 +225,6 @@ def clamp_bounding_box( return convert_format_bounding_box(xyxy_boxes, old_format=BoundingBoxFormat.XYXY, new_format=format, inplace=True) -def _strip_alpha(image: torch.Tensor) -> torch.Tensor: - image, alpha = torch.tensor_split(image, indices=(-1,), dim=-3) - if not torch.all(alpha == _max_value(alpha.dtype)): - raise RuntimeError( - "Stripping the alpha channel if it contains values other than the max value is not supported." - ) - return image - - -def _add_alpha(image: torch.Tensor, alpha: Optional[torch.Tensor] = None) -> torch.Tensor: - if alpha is None: - shape = list(image.shape) - shape[-3] = 1 - alpha = torch.full(shape, _max_value(image.dtype), dtype=image.dtype, device=image.device) - return torch.cat((image, alpha), dim=-3) - - -def _gray_to_rgb(grayscale: torch.Tensor) -> torch.Tensor: - repeats = [1] * grayscale.ndim - repeats[-3] = 3 - return grayscale.repeat(repeats) - - def _rgb_to_gray(image: torch.Tensor, cast: bool = True) -> torch.Tensor: r, g, b = image.unbind(dim=-3) l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114) @@ -257,107 +234,6 @@ def _rgb_to_gray(image: torch.Tensor, cast: bool = True) -> torch.Tensor: return l_img -def convert_color_space_image_tensor( - image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace -) -> torch.Tensor: - if new_color_space == old_color_space: - return image - - if old_color_space == ColorSpace.OTHER or new_color_space == ColorSpace.OTHER: - raise RuntimeError(f"Conversion to or from {ColorSpace.OTHER} is not supported.") - - if old_color_space == ColorSpace.GRAY and new_color_space == ColorSpace.GRAY_ALPHA: - return _add_alpha(image) - elif old_color_space == ColorSpace.GRAY and new_color_space == ColorSpace.RGB: - return _gray_to_rgb(image) - elif old_color_space == ColorSpace.GRAY and new_color_space == ColorSpace.RGB_ALPHA: - return _add_alpha(_gray_to_rgb(image)) - elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.GRAY: - return _strip_alpha(image) - elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.RGB: - return _gray_to_rgb(_strip_alpha(image)) - elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.RGB_ALPHA: - image, alpha = torch.tensor_split(image, indices=(-1,), 
dim=-3) - return _add_alpha(_gray_to_rgb(image), alpha) - elif old_color_space == ColorSpace.RGB and new_color_space == ColorSpace.GRAY: - return _rgb_to_gray(image) - elif old_color_space == ColorSpace.RGB and new_color_space == ColorSpace.GRAY_ALPHA: - return _add_alpha(_rgb_to_gray(image)) - elif old_color_space == ColorSpace.RGB and new_color_space == ColorSpace.RGB_ALPHA: - return _add_alpha(image) - elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.GRAY: - return _rgb_to_gray(_strip_alpha(image)) - elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.GRAY_ALPHA: - image, alpha = torch.tensor_split(image, indices=(-1,), dim=-3) - return _add_alpha(_rgb_to_gray(image), alpha) - elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.RGB: - return _strip_alpha(image) - else: - raise RuntimeError(f"Conversion from {old_color_space} to {new_color_space} is not supported.") - - -_COLOR_SPACE_TO_PIL_MODE = { - ColorSpace.GRAY: "L", - ColorSpace.GRAY_ALPHA: "LA", - ColorSpace.RGB: "RGB", - ColorSpace.RGB_ALPHA: "RGBA", -} - - -@torch.jit.unused -def convert_color_space_image_pil(image: PIL.Image.Image, color_space: ColorSpace) -> PIL.Image.Image: - old_mode = image.mode - try: - new_mode = _COLOR_SPACE_TO_PIL_MODE[color_space] - except KeyError: - raise ValueError(f"Conversion from {ColorSpace.from_pil_mode(old_mode)} to {color_space} is not supported.") - - if image.mode == new_mode: - return image - - return image.convert(new_mode) - - -def convert_color_space_video( - video: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace -) -> torch.Tensor: - return convert_color_space_image_tensor(video, old_color_space=old_color_space, new_color_space=new_color_space) - - -def convert_color_space( - inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], - color_space: ColorSpace, - old_color_space: Optional[ColorSpace] = None, -) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]: - if not torch.jit.is_scripting(): - _log_api_usage_once(convert_color_space) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): - if old_color_space is None: - raise RuntimeError( - "In order to convert the color space of simple tensors, " - "the `old_color_space=...` parameter needs to be passed." - ) - return convert_color_space_image_tensor(inpt, old_color_space=old_color_space, new_color_space=color_space) - elif isinstance(inpt, datapoints.Image): - output = convert_color_space_image_tensor( - inpt.as_subclass(torch.Tensor), old_color_space=inpt.color_space, new_color_space=color_space - ) - return datapoints.Image.wrap_like(inpt, output, color_space=color_space) - elif isinstance(inpt, datapoints.Video): - output = convert_color_space_video( - inpt.as_subclass(torch.Tensor), old_color_space=inpt.color_space, new_color_space=color_space - ) - return datapoints.Video.wrap_like(inpt, output, color_space=color_space) - elif isinstance(inpt, PIL.Image.Image): - return convert_color_space_image_pil(inpt, color_space=color_space) - else: - raise TypeError( - f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, " - f"but got {type(inpt)} instead." 
-        )
-
-
 def _num_value_bits(dtype: torch.dtype) -> int:
     if dtype == torch.uint8:
         return 8
diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index 5d662a2c1d1..69965126d8c 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -1234,6 +1234,9 @@ def affine(
     return F_t.affine(img, matrix=matrix, interpolation=interpolation.value, fill=fill)
 
 
+# Looks like to_grayscale() is a stand-alone functional that is never called
+# from the transform classes. Perhaps it is only kept for backward compatibility.
+# Either way, it can be deprecated as we migrate to V2.
 @torch.jit.unused
 def to_grayscale(img, num_output_channels=1):
     """Convert PIL image of any mode (RGB, HSV, LAB, etc) to grayscale version of image.