diff --git a/test/prototype_common_utils.py b/test/prototype_common_utils.py index 333e11fb227..c10cec94c31 100644 --- a/test/prototype_common_utils.py +++ b/test/prototype_common_utils.py @@ -45,6 +45,8 @@ "make_segmentation_masks", "make_mask_loaders", "make_masks", + "make_video", + "make_videos", ] @@ -210,17 +212,19 @@ def _parse_image_size(size, *, name="size"): def from_loader(loader_fn): def wrapper(*args, **kwargs): + device = kwargs.pop("device", "cpu") loader = loader_fn(*args, **kwargs) - return loader.load(kwargs.get("device", "cpu")) + return loader.load(device) return wrapper def from_loaders(loaders_fn): def wrapper(*args, **kwargs): + device = kwargs.pop("device", "cpu") loaders = loaders_fn(*args, **kwargs) for loader in loaders: - yield loader.load(kwargs.get("device", "cpu")) + yield loader.load(device) return wrapper @@ -246,6 +250,21 @@ def __post_init__(self): self.num_channels = self.shape[-3] +NUM_CHANNELS_MAP = { + features.ColorSpace.GRAY: 1, + features.ColorSpace.GRAY_ALPHA: 2, + features.ColorSpace.RGB: 3, + features.ColorSpace.RGB_ALPHA: 4, +} + + +def get_num_channels(color_space): + num_channels = NUM_CHANNELS_MAP.get(color_space) + if not num_channels: + raise pytest.UsageError(f"Can't determine the number of channels for color space {color_space}") + return num_channels + + def make_image_loader( size="random", *, @@ -255,16 +274,7 @@ def make_image_loader( constant_alpha=True, ): size = _parse_image_size(size) - - try: - num_channels = { - features.ColorSpace.GRAY: 1, - features.ColorSpace.GRAY_ALPHA: 2, - features.ColorSpace.RGB: 3, - features.ColorSpace.RGB_ALPHA: 4, - }[color_space] - except KeyError as error: - raise pytest.UsageError(f"Can't determine the number of channels for color space {color_space}") from error + num_channels = get_num_channels(color_space) def fn(shape, dtype, device): max_value = get_max_value(dtype) @@ -531,3 +541,50 @@ def make_mask_loaders( make_masks = from_loaders(make_mask_loaders) + + +class VideoLoader(ImageLoader): + pass + + +def make_video_loader( + size="random", + *, + color_space=features.ColorSpace.RGB, + num_frames="random", + extra_dims=(), + dtype=torch.uint8, +): + size = _parse_image_size(size) + num_frames = int(torch.randint(1, 5, ())) if num_frames == "random" else num_frames + + def fn(shape, dtype, device): + video = make_image(size=shape[-2:], color_space=color_space, extra_dims=shape[:-3], dtype=dtype, device=device) + return features.Video(video, color_space=color_space) + + return VideoLoader( + fn, shape=(*extra_dims, num_frames, get_num_channels(color_space), *size), dtype=dtype, color_space=color_space + ) + + +make_video = from_loader(make_video_loader) + + +def make_video_loaders( + *, + sizes=DEFAULT_IMAGE_SIZES, + color_spaces=( + features.ColorSpace.GRAY, + features.ColorSpace.RGB, + ), + num_frames=(1, 0, "random"), + extra_dims=DEFAULT_EXTRA_DIMS, + dtypes=(torch.uint8,), +): + for params in combinations_grid( + size=sizes, color_space=color_spaces, num_frames=num_frames, extra_dims=extra_dims, dtype=dtypes + ): + yield make_video_loader(**params) + + +make_videos = from_loaders(make_video_loaders) diff --git a/test/prototype_transforms_dispatcher_infos.py b/test/prototype_transforms_dispatcher_infos.py index 9678249aa0b..be8bd3002c1 100644 --- a/test/prototype_transforms_dispatcher_infos.py +++ b/test/prototype_transforms_dispatcher_infos.py @@ -127,6 +127,23 @@ def fill_sequence_needs_broadcast(args_kwargs): ) +def xfail_all_tests(*, reason, condition): + return [ + 
TestMark(("TestDispatchers", test_name), pytest.mark.xfail(reason=reason), condition=condition) + for test_name in [ + "test_scripted_smoke", + "test_dispatch_simple_tensor", + "test_dispatch_feature", + ] + ] + + +xfails_degenerate_or_multi_batch_dims = xfail_all_tests( + reason="See https://github.com/pytorch/vision/issues/6670 for details.", + condition=lambda args_kwargs: len(args_kwargs.args[0].shape) > 4 or not all(args_kwargs.args[0].shape[:-3]), +) + + DISPATCHER_INFOS = [ DispatcherInfo( F.horizontal_flip, @@ -243,6 +260,7 @@ def fill_sequence_needs_broadcast(args_kwargs): pil_kernel_info=PILKernelInfo(F.perspective_image_pil), test_marks=[ xfail_dispatch_pil_if_fill_sequence_needs_broadcast, + *xfails_degenerate_or_multi_batch_dims, ], ), DispatcherInfo( @@ -253,6 +271,7 @@ def fill_sequence_needs_broadcast(args_kwargs): features.Mask: F.elastic_mask, }, pil_kernel_info=PILKernelInfo(F.elastic_image_pil), + test_marks=xfails_degenerate_or_multi_batch_dims, ), DispatcherInfo( F.center_crop, @@ -275,6 +294,7 @@ def fill_sequence_needs_broadcast(args_kwargs): test_marks=[ xfail_jit_python_scalar_arg("kernel_size"), xfail_jit_python_scalar_arg("sigma"), + *xfails_degenerate_or_multi_batch_dims, ], ), DispatcherInfo( diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index c0e7bf5bff4..d90d3bf68be 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -20,6 +20,7 @@ make_image_loader, make_image_loaders, make_mask_loaders, + make_video_loaders, VALID_EXTRA_DIMS, ) from torchvision.prototype import features @@ -142,6 +143,25 @@ def xfail_jit_list_of_ints(name, *, reason=None): ) +def xfail_all_tests(*, reason, condition): + return [ + TestMark(("TestKernels", test_name), pytest.mark.xfail(reason=reason), condition=condition) + for test_name in [ + "test_scripted_vs_eager", + "test_batched_vs_single", + "test_no_inplace", + "test_cuda_vs_cpu", + "test_dtype_and_device_consistency", + ] + ] + + +xfails_image_degenerate_or_multi_batch_dims = xfail_all_tests( + reason="See https://github.com/pytorch/vision/issues/6670 for details.", + condition=lambda args_kwargs: len(args_kwargs.args[0].shape) > 4 or not all(args_kwargs.args[0].shape[:-3]), +) + + KERNEL_INFOS = [] @@ -169,6 +189,11 @@ def sample_inputs_horizontal_flip_mask(): yield ArgsKwargs(image_loader) +def sample_inputs_horizontal_flip_video(): + for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): + yield ArgsKwargs(video_loader) + + KERNEL_INFOS.extend( [ KernelInfo( @@ -187,6 +212,10 @@ def sample_inputs_horizontal_flip_mask(): F.horizontal_flip_mask, sample_inputs_fn=sample_inputs_horizontal_flip_mask, ), + KernelInfo( + F.horizontal_flip_video, + sample_inputs_fn=sample_inputs_horizontal_flip_video, + ), ] ) @@ -287,6 +316,11 @@ def reference_inputs_resize_mask(): yield ArgsKwargs(mask_loader, size=size) +def sample_inputs_resize_video(): + for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): + yield ArgsKwargs(video_loader, size=[min(video_loader.shape[-2:]) + 1]) + + KERNEL_INFOS.extend( [ KernelInfo( @@ -316,6 +350,10 @@ def reference_inputs_resize_mask(): xfail_jit_integer_size(), ], ), + KernelInfo( + F.resize_video, + sample_inputs_fn=sample_inputs_resize_video, + ), ] ) @@ -485,7 +523,7 @@ def reference_inputs_affine_bounding_box(): ) -def sample_inputs_affine_image_mask(): +def sample_inputs_affine_mask(): for mask_loader in 
make_mask_loaders(sizes=["random"], num_categories=["random"], num_objects=["random"]): yield ArgsKwargs(mask_loader, **_full_affine_params()) @@ -502,6 +540,11 @@ def reference_inputs_resize_mask(): yield ArgsKwargs(mask_loader, **affine_kwargs) +def sample_inputs_affine_video(): + for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): + yield ArgsKwargs(video_loader, **_full_affine_params()) + + KERNEL_INFOS.extend( [ KernelInfo( @@ -529,7 +572,7 @@ def reference_inputs_resize_mask(): ), KernelInfo( F.affine_mask, - sample_inputs_fn=sample_inputs_affine_image_mask, + sample_inputs_fn=sample_inputs_affine_mask, reference_fn=reference_affine_mask, reference_inputs_fn=reference_inputs_resize_mask, closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, @@ -537,6 +580,10 @@ def reference_inputs_resize_mask(): xfail_jit_python_scalar_arg("shear"), ], ), + KernelInfo( + F.affine_video, + sample_inputs_fn=sample_inputs_affine_video, + ), ] ) @@ -608,14 +655,28 @@ def reference_inputs_convert_color_space_image_tensor(): yield args_kwargs -KERNEL_INFOS.append( - KernelInfo( - F.convert_color_space_image_tensor, - sample_inputs_fn=sample_inputs_convert_color_space_image_tensor, - reference_fn=reference_convert_color_space_image_tensor, - reference_inputs_fn=reference_inputs_convert_color_space_image_tensor, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - ), +def sample_inputs_convert_color_space_video(): + color_spaces = [features.ColorSpace.GRAY, features.ColorSpace.RGB] + + for old_color_space, new_color_space in cycle_over(color_spaces): + for video_loader in make_video_loaders(sizes=["random"], color_spaces=[old_color_space], num_frames=["random"]): + yield ArgsKwargs(video_loader, old_color_space=old_color_space, new_color_space=new_color_space) + + +KERNEL_INFOS.extend( + [ + KernelInfo( + F.convert_color_space_image_tensor, + sample_inputs_fn=sample_inputs_convert_color_space_image_tensor, + reference_fn=reference_convert_color_space_image_tensor, + reference_inputs_fn=reference_inputs_convert_color_space_image_tensor, + closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, + ), + KernelInfo( + F.convert_color_space_video, + sample_inputs_fn=sample_inputs_convert_color_space_video, + ), + ] ) @@ -643,6 +704,11 @@ def sample_inputs_vertical_flip_mask(): yield ArgsKwargs(image_loader) +def sample_inputs_vertical_flip_video(): + for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): + yield ArgsKwargs(video_loader) + + KERNEL_INFOS.extend( [ KernelInfo( @@ -661,6 +727,10 @@ def sample_inputs_vertical_flip_mask(): F.vertical_flip_mask, sample_inputs_fn=sample_inputs_vertical_flip_mask, ), + KernelInfo( + F.vertical_flip_video, + sample_inputs_fn=sample_inputs_vertical_flip_video, + ), ] ) @@ -724,6 +794,11 @@ def reference_inputs_rotate_mask(): yield ArgsKwargs(mask_loader, angle=angle) +def sample_inputs_rotate_video(): + for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): + yield ArgsKwargs(video_loader, angle=15.0) + + KERNEL_INFOS.extend( [ KernelInfo( @@ -749,6 +824,10 @@ def reference_inputs_rotate_mask(): reference_inputs_fn=reference_inputs_rotate_mask, closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, ), + KernelInfo( + F.rotate_video, + sample_inputs_fn=sample_inputs_rotate_video, + ), ] ) @@ -791,6 +870,11 @@ def reference_inputs_crop_mask(): yield ArgsKwargs(mask_loader, **params) +def sample_inputs_crop_video(): + for video_loader in make_video_loaders(sizes=[(16, 17)], num_frames=["random"]): + yield 
ArgsKwargs(video_loader, top=4, left=3, height=7, width=8) + + KERNEL_INFOS.extend( [ KernelInfo( @@ -812,6 +896,10 @@ def reference_inputs_crop_mask(): reference_inputs_fn=reference_inputs_crop_mask, closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, ), + KernelInfo( + F.crop_video, + sample_inputs_fn=sample_inputs_crop_video, + ), ] ) @@ -872,6 +960,11 @@ def reference_inputs_resized_crop_mask(): yield ArgsKwargs(mask_loader, **params) +def sample_inputs_resized_crop_video(): + for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): + yield ArgsKwargs(video_loader, **_RESIZED_CROP_PARAMS[0]) + + KERNEL_INFOS.extend( [ KernelInfo( @@ -892,6 +985,10 @@ def reference_inputs_resized_crop_mask(): reference_inputs_fn=reference_inputs_resized_crop_mask, closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, ), + KernelInfo( + F.resized_crop_video, + sample_inputs_fn=sample_inputs_resized_crop_video, + ), ] ) @@ -965,6 +1062,11 @@ def reference_inputs_pad_mask(): yield ArgsKwargs(image_loader, fill=fill, **params) +def sample_inputs_pad_video(): + for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): + yield ArgsKwargs(video_loader, padding=[1]) + + KERNEL_INFOS.extend( [ KernelInfo( @@ -996,6 +1098,10 @@ def reference_inputs_pad_mask(): reference_inputs_fn=reference_inputs_pad_mask, closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, ), + KernelInfo( + F.pad_video, + sample_inputs_fn=sample_inputs_pad_video, + ), ] ) @@ -1006,11 +1112,7 @@ def reference_inputs_pad_mask(): def sample_inputs_perspective_image_tensor(): - for image_loader in make_image_loaders( - sizes=["random"], - # FIXME: kernel should support arbitrary batch sizes - extra_dims=[(), (4,)], - ): + for image_loader in make_image_loaders(sizes=["random"]): for fill in [None, 128.0, 128, [12.0], [12.0 + c for c in range(image_loader.num_channels)]]: yield ArgsKwargs(image_loader, fill=fill, perspective_coeffs=_PERSPECTIVE_COEFFS[0]) @@ -1030,11 +1132,7 @@ def sample_inputs_perspective_bounding_box(): def sample_inputs_perspective_mask(): - for mask_loader in make_mask_loaders( - sizes=["random"], - # FIXME: kernel should support arbitrary batch sizes - extra_dims=[(), (4,)], - ): + for mask_loader in make_mask_loaders(sizes=["random"]): yield ArgsKwargs(mask_loader, perspective_coeffs=_PERSPECTIVE_COEFFS[0]) @@ -1045,6 +1143,11 @@ def reference_inputs_perspective_mask(): yield ArgsKwargs(mask_loader, perspective_coeffs=perspective_coeffs) +def sample_inputs_perspective_video(): + for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): + yield ArgsKwargs(video_loader, perspective_coeffs=_PERSPECTIVE_COEFFS[0]) + + KERNEL_INFOS.extend( [ KernelInfo( @@ -1053,6 +1156,7 @@ def reference_inputs_perspective_mask(): reference_fn=pil_reference_wrapper(F.perspective_image_pil), reference_inputs_fn=reference_inputs_perspective_image_tensor, closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, + test_marks=xfails_image_degenerate_or_multi_batch_dims, ), KernelInfo( F.perspective_bounding_box, @@ -1064,6 +1168,11 @@ def reference_inputs_perspective_mask(): reference_fn=pil_reference_wrapper(F.perspective_image_pil), reference_inputs_fn=reference_inputs_perspective_mask, closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, + test_marks=xfails_image_degenerate_or_multi_batch_dims, + ), + KernelInfo( + F.perspective_video, + sample_inputs_fn=sample_inputs_perspective_video, ), ] ) @@ -1074,11 +1183,7 @@ def _get_elastic_displacement(image_size): def sample_inputs_elastic_image_tensor(): - 
for image_loader in make_image_loaders( - sizes=["random"], - # FIXME: kernel should support arbitrary batch sizes - extra_dims=[(), (4,)], - ): + for image_loader in make_image_loaders(sizes=["random"]): displacement = _get_elastic_displacement(image_loader.image_size) for fill in [None, 128.0, 128, [12.0], [12.0 + c for c in range(image_loader.num_channels)]]: yield ArgsKwargs(image_loader, displacement=displacement, fill=fill) @@ -1109,11 +1214,7 @@ def sample_inputs_elastic_bounding_box(): def sample_inputs_elastic_mask(): - for mask_loader in make_mask_loaders( - sizes=["random"], - # FIXME: kernel should support arbitrary batch sizes - extra_dims=[(), (4,)], - ): + for mask_loader in make_mask_loaders(sizes=["random"]): displacement = _get_elastic_displacement(mask_loader.shape[-2:]) yield ArgsKwargs(mask_loader, displacement=displacement) @@ -1124,6 +1225,12 @@ def reference_inputs_elastic_mask(): yield ArgsKwargs(mask_loader, displacement=displacement) +def sample_inputs_elastic_video(): + for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): + displacement = _get_elastic_displacement(video_loader.shape[-2:]) + yield ArgsKwargs(video_loader, displacement=displacement) + + KERNEL_INFOS.extend( [ KernelInfo( @@ -1132,6 +1239,7 @@ def reference_inputs_elastic_mask(): reference_fn=pil_reference_wrapper(F.elastic_image_pil), reference_inputs_fn=reference_inputs_elastic_image_tensor, closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, + test_marks=xfails_image_degenerate_or_multi_batch_dims, ), KernelInfo( F.elastic_bounding_box, @@ -1143,6 +1251,11 @@ def reference_inputs_elastic_mask(): reference_fn=pil_reference_wrapper(F.elastic_image_pil), reference_inputs_fn=reference_inputs_elastic_mask, closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, + test_marks=xfails_image_degenerate_or_multi_batch_dims, + ), + KernelInfo( + F.elastic_video, + sample_inputs_fn=sample_inputs_elastic_video, ), ] ) @@ -1195,6 +1308,12 @@ def reference_inputs_center_crop_mask(): yield ArgsKwargs(mask_loader, output_size=output_size) +def sample_inputs_center_crop_video(): + for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): + height, width = video_loader.shape[-2:] + yield ArgsKwargs(video_loader, output_size=(height // 2, width // 2)) + + KERNEL_INFOS.extend( [ KernelInfo( @@ -1224,17 +1343,17 @@ def reference_inputs_center_crop_mask(): xfail_jit_integer_size("output_size"), ], ), + KernelInfo( + F.center_crop_video, + sample_inputs_fn=sample_inputs_center_crop_video, + ), ] ) def sample_inputs_gaussian_blur_image_tensor(): make_gaussian_blur_image_loaders = functools.partial( - make_image_loaders, - sizes=["random"], - color_spaces=[features.ColorSpace.RGB], - # FIXME: kernel should support arbitrary batch sizes - extra_dims=[(), (4,)], + make_image_loaders, sizes=["random"], color_spaces=[features.ColorSpace.RGB] ) for image_loader, kernel_size in itertools.product(make_gaussian_blur_image_loaders(), [5, (3, 3), [3, 3]]): @@ -1246,26 +1365,34 @@ def sample_inputs_gaussian_blur_image_tensor(): yield ArgsKwargs(image_loader, kernel_size=5, sigma=sigma) -KERNEL_INFOS.append( - KernelInfo( - F.gaussian_blur_image_tensor, - sample_inputs_fn=sample_inputs_gaussian_blur_image_tensor, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - test_marks=[ - xfail_jit_python_scalar_arg("kernel_size"), - xfail_jit_python_scalar_arg("sigma"), - ], - ) +def sample_inputs_gaussian_blur_video(): + for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): 
+ yield ArgsKwargs(video_loader, kernel_size=[3, 3]) + + +KERNEL_INFOS.extend( + [ + KernelInfo( + F.gaussian_blur_image_tensor, + sample_inputs_fn=sample_inputs_gaussian_blur_image_tensor, + closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, + test_marks=[ + xfail_jit_python_scalar_arg("kernel_size"), + xfail_jit_python_scalar_arg("sigma"), + *xfails_image_degenerate_or_multi_batch_dims, + ], + ), + KernelInfo( + F.gaussian_blur_video, + sample_inputs_fn=sample_inputs_gaussian_blur_video, + ), + ] ) def sample_inputs_equalize_image_tensor(): for image_loader in make_image_loaders( - sizes=["random"], - # FIXME: kernel should support arbitrary batch sizes - extra_dims=[(), (4,)], - color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), - dtypes=[torch.uint8], + sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), dtypes=[torch.uint8] ): yield ArgsKwargs(image_loader) @@ -1277,15 +1404,26 @@ def reference_inputs_equalize_image_tensor(): yield ArgsKwargs(image_loader) -KERNEL_INFOS.append( - KernelInfo( - F.equalize_image_tensor, - kernel_name="equalize_image_tensor", - sample_inputs_fn=sample_inputs_equalize_image_tensor, - reference_fn=pil_reference_wrapper(F.equalize_image_pil), - reference_inputs_fn=reference_inputs_equalize_image_tensor, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - ) +def sample_inputs_equalize_video(): + for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): + yield ArgsKwargs(video_loader) + + +KERNEL_INFOS.extend( + [ + KernelInfo( + F.equalize_image_tensor, + kernel_name="equalize_image_tensor", + sample_inputs_fn=sample_inputs_equalize_image_tensor, + reference_fn=pil_reference_wrapper(F.equalize_image_pil), + reference_inputs_fn=reference_inputs_equalize_image_tensor, + closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, + ), + KernelInfo( + F.equalize_video, + sample_inputs_fn=sample_inputs_equalize_video, + ), + ] ) @@ -1303,15 +1441,26 @@ def reference_inputs_invert_image_tensor(): yield ArgsKwargs(image_loader) -KERNEL_INFOS.append( - KernelInfo( - F.invert_image_tensor, - kernel_name="invert_image_tensor", - sample_inputs_fn=sample_inputs_invert_image_tensor, - reference_fn=pil_reference_wrapper(F.invert_image_pil), - reference_inputs_fn=reference_inputs_invert_image_tensor, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - ) +def sample_inputs_invert_video(): + for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): + yield ArgsKwargs(video_loader) + + +KERNEL_INFOS.extend( + [ + KernelInfo( + F.invert_image_tensor, + kernel_name="invert_image_tensor", + sample_inputs_fn=sample_inputs_invert_image_tensor, + reference_fn=pil_reference_wrapper(F.invert_image_pil), + reference_inputs_fn=reference_inputs_invert_image_tensor, + closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, + ), + KernelInfo( + F.invert_video, + sample_inputs_fn=sample_inputs_invert_video, + ), + ] ) @@ -1335,15 +1484,26 @@ def reference_inputs_posterize_image_tensor(): yield ArgsKwargs(image_loader, bits=bits) -KERNEL_INFOS.append( - KernelInfo( - F.posterize_image_tensor, - kernel_name="posterize_image_tensor", - sample_inputs_fn=sample_inputs_posterize_image_tensor, - reference_fn=pil_reference_wrapper(F.posterize_image_pil), - reference_inputs_fn=reference_inputs_posterize_image_tensor, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - ) +def sample_inputs_posterize_video(): + for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): + yield ArgsKwargs(video_loader, 
bits=_POSTERIZE_BITS[0]) + + +KERNEL_INFOS.extend( + [ + KernelInfo( + F.posterize_image_tensor, + kernel_name="posterize_image_tensor", + sample_inputs_fn=sample_inputs_posterize_image_tensor, + reference_fn=pil_reference_wrapper(F.posterize_image_pil), + reference_inputs_fn=reference_inputs_posterize_image_tensor, + closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, + ), + KernelInfo( + F.posterize_video, + sample_inputs_fn=sample_inputs_posterize_video, + ), + ] ) @@ -1368,15 +1528,26 @@ def reference_inputs_solarize_image_tensor(): yield ArgsKwargs(image_loader, threshold=threshold) -KERNEL_INFOS.append( - KernelInfo( - F.solarize_image_tensor, - kernel_name="solarize_image_tensor", - sample_inputs_fn=sample_inputs_solarize_image_tensor, - reference_fn=pil_reference_wrapper(F.solarize_image_pil), - reference_inputs_fn=reference_inputs_solarize_image_tensor, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - ) +def sample_inputs_solarize_video(): + for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): + yield ArgsKwargs(video_loader, threshold=next(_get_solarize_thresholds(video_loader.dtype))) + + +KERNEL_INFOS.extend( + [ + KernelInfo( + F.solarize_image_tensor, + kernel_name="solarize_image_tensor", + sample_inputs_fn=sample_inputs_solarize_image_tensor, + reference_fn=pil_reference_wrapper(F.solarize_image_pil), + reference_inputs_fn=reference_inputs_solarize_image_tensor, + closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, + ), + KernelInfo( + F.solarize_video, + sample_inputs_fn=sample_inputs_solarize_video, + ), + ] ) @@ -1394,15 +1565,26 @@ def reference_inputs_autocontrast_image_tensor(): yield ArgsKwargs(image_loader) -KERNEL_INFOS.append( - KernelInfo( - F.autocontrast_image_tensor, - kernel_name="autocontrast_image_tensor", - sample_inputs_fn=sample_inputs_autocontrast_image_tensor, - reference_fn=pil_reference_wrapper(F.autocontrast_image_pil), - reference_inputs_fn=reference_inputs_autocontrast_image_tensor, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - ) +def sample_inputs_autocontrast_video(): + for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): + yield ArgsKwargs(video_loader) + + +KERNEL_INFOS.extend( + [ + KernelInfo( + F.autocontrast_image_tensor, + kernel_name="autocontrast_image_tensor", + sample_inputs_fn=sample_inputs_autocontrast_image_tensor, + reference_fn=pil_reference_wrapper(F.autocontrast_image_pil), + reference_inputs_fn=reference_inputs_autocontrast_image_tensor, + closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, + ), + KernelInfo( + F.autocontrast_video, + sample_inputs_fn=sample_inputs_autocontrast_video, + ), + ] ) _ADJUST_SHARPNESS_FACTORS = [0.1, 0.5] @@ -1412,8 +1594,6 @@ def sample_inputs_adjust_sharpness_image_tensor(): for image_loader in make_image_loaders( sizes=["random", (2, 2)], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), - # FIXME: kernel should support arbitrary batch sizes - extra_dims=[(), (4,)], ): yield ArgsKwargs(image_loader, sharpness_factor=_ADJUST_SHARPNESS_FACTORS[0]) @@ -1426,15 +1606,26 @@ def reference_inputs_adjust_sharpness_image_tensor(): yield ArgsKwargs(image_loader, sharpness_factor=sharpness_factor) -KERNEL_INFOS.append( - KernelInfo( - F.adjust_sharpness_image_tensor, - kernel_name="adjust_sharpness_image_tensor", - sample_inputs_fn=sample_inputs_adjust_sharpness_image_tensor, - reference_fn=pil_reference_wrapper(F.adjust_sharpness_image_pil), - reference_inputs_fn=reference_inputs_adjust_sharpness_image_tensor, - 
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - ) +def sample_inputs_adjust_sharpness_video(): + for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): + yield ArgsKwargs(video_loader, sharpness_factor=_ADJUST_SHARPNESS_FACTORS[0]) + + +KERNEL_INFOS.extend( + [ + KernelInfo( + F.adjust_sharpness_image_tensor, + kernel_name="adjust_sharpness_image_tensor", + sample_inputs_fn=sample_inputs_adjust_sharpness_image_tensor, + reference_fn=pil_reference_wrapper(F.adjust_sharpness_image_pil), + reference_inputs_fn=reference_inputs_adjust_sharpness_image_tensor, + closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, + ), + KernelInfo( + F.adjust_sharpness_video, + sample_inputs_fn=sample_inputs_adjust_sharpness_video, + ), + ] ) @@ -1446,12 +1637,26 @@ def sample_inputs_erase_image_tensor(): yield ArgsKwargs(image_loader, i=1, j=2, h=h, w=w, v=v) -KERNEL_INFOS.append( - KernelInfo( - F.erase_image_tensor, - kernel_name="erase_image_tensor", - sample_inputs_fn=sample_inputs_erase_image_tensor, - ) +def sample_inputs_erase_video(): + for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): + # FIXME: make the parameters more diverse + h, w = 6, 7 + v = torch.rand(video_loader.num_channels, h, w) + yield ArgsKwargs(video_loader, i=1, j=2, h=h, w=w, v=v) + + +KERNEL_INFOS.extend( + [ + KernelInfo( + F.erase_image_tensor, + kernel_name="erase_image_tensor", + sample_inputs_fn=sample_inputs_erase_image_tensor, + ), + KernelInfo( + F.erase_video, + sample_inputs_fn=sample_inputs_erase_video, + ), + ] ) _ADJUST_BRIGHTNESS_FACTORS = [0.1, 0.5] @@ -1472,15 +1677,26 @@ def reference_inputs_adjust_brightness_image_tensor(): yield ArgsKwargs(image_loader, brightness_factor=brightness_factor) -KERNEL_INFOS.append( - KernelInfo( - F.adjust_brightness_image_tensor, - kernel_name="adjust_brightness_image_tensor", - sample_inputs_fn=sample_inputs_adjust_brightness_image_tensor, - reference_fn=pil_reference_wrapper(F.adjust_brightness_image_pil), - reference_inputs_fn=reference_inputs_adjust_brightness_image_tensor, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - ) +def sample_inputs_adjust_brightness_video(): + for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): + yield ArgsKwargs(video_loader, brightness_factor=_ADJUST_BRIGHTNESS_FACTORS[0]) + + +KERNEL_INFOS.extend( + [ + KernelInfo( + F.adjust_brightness_image_tensor, + kernel_name="adjust_brightness_image_tensor", + sample_inputs_fn=sample_inputs_adjust_brightness_image_tensor, + reference_fn=pil_reference_wrapper(F.adjust_brightness_image_pil), + reference_inputs_fn=reference_inputs_adjust_brightness_image_tensor, + closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, + ), + KernelInfo( + F.adjust_brightness_video, + sample_inputs_fn=sample_inputs_adjust_brightness_video, + ), + ] ) @@ -1502,15 +1718,26 @@ def reference_inputs_adjust_contrast_image_tensor(): yield ArgsKwargs(image_loader, contrast_factor=contrast_factor) -KERNEL_INFOS.append( - KernelInfo( - F.adjust_contrast_image_tensor, - kernel_name="adjust_contrast_image_tensor", - sample_inputs_fn=sample_inputs_adjust_contrast_image_tensor, - reference_fn=pil_reference_wrapper(F.adjust_contrast_image_pil), - reference_inputs_fn=reference_inputs_adjust_contrast_image_tensor, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - ) +def sample_inputs_adjust_contrast_video(): + for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): + yield ArgsKwargs(video_loader, 
contrast_factor=_ADJUST_CONTRAST_FACTORS[0]) + + +KERNEL_INFOS.extend( + [ + KernelInfo( + F.adjust_contrast_image_tensor, + kernel_name="adjust_contrast_image_tensor", + sample_inputs_fn=sample_inputs_adjust_contrast_image_tensor, + reference_fn=pil_reference_wrapper(F.adjust_contrast_image_pil), + reference_inputs_fn=reference_inputs_adjust_contrast_image_tensor, + closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, + ), + KernelInfo( + F.adjust_contrast_video, + sample_inputs_fn=sample_inputs_adjust_contrast_video, + ), + ] ) _ADJUST_GAMMA_GAMMAS_GAINS = [ @@ -1535,15 +1762,27 @@ def reference_inputs_adjust_gamma_image_tensor(): yield ArgsKwargs(image_loader, gamma=gamma, gain=gain) -KERNEL_INFOS.append( - KernelInfo( - F.adjust_gamma_image_tensor, - kernel_name="adjust_gamma_image_tensor", - sample_inputs_fn=sample_inputs_adjust_gamma_image_tensor, - reference_fn=pil_reference_wrapper(F.adjust_gamma_image_pil), - reference_inputs_fn=reference_inputs_adjust_gamma_image_tensor, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - ) +def sample_inputs_adjust_gamma_video(): + gamma, gain = _ADJUST_GAMMA_GAMMAS_GAINS[0] + for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): + yield ArgsKwargs(video_loader, gamma=gamma, gain=gain) + + +KERNEL_INFOS.extend( + [ + KernelInfo( + F.adjust_gamma_image_tensor, + kernel_name="adjust_gamma_image_tensor", + sample_inputs_fn=sample_inputs_adjust_gamma_image_tensor, + reference_fn=pil_reference_wrapper(F.adjust_gamma_image_pil), + reference_inputs_fn=reference_inputs_adjust_gamma_image_tensor, + closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, + ), + KernelInfo( + F.adjust_gamma_video, + sample_inputs_fn=sample_inputs_adjust_gamma_video, + ), + ] ) @@ -1565,15 +1804,26 @@ def reference_inputs_adjust_hue_image_tensor(): yield ArgsKwargs(image_loader, hue_factor=hue_factor) -KERNEL_INFOS.append( - KernelInfo( - F.adjust_hue_image_tensor, - kernel_name="adjust_hue_image_tensor", - sample_inputs_fn=sample_inputs_adjust_hue_image_tensor, - reference_fn=pil_reference_wrapper(F.adjust_hue_image_pil), - reference_inputs_fn=reference_inputs_adjust_hue_image_tensor, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - ) +def sample_inputs_adjust_hue_video(): + for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): + yield ArgsKwargs(video_loader, hue_factor=_ADJUST_HUE_FACTORS[0]) + + +KERNEL_INFOS.extend( + [ + KernelInfo( + F.adjust_hue_image_tensor, + kernel_name="adjust_hue_image_tensor", + sample_inputs_fn=sample_inputs_adjust_hue_image_tensor, + reference_fn=pil_reference_wrapper(F.adjust_hue_image_pil), + reference_inputs_fn=reference_inputs_adjust_hue_image_tensor, + closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, + ), + KernelInfo( + F.adjust_hue_video, + sample_inputs_fn=sample_inputs_adjust_hue_video, + ), + ] ) _ADJUST_SATURATION_FACTORS = [0.1, 0.5] @@ -1594,15 +1844,26 @@ def reference_inputs_adjust_saturation_image_tensor(): yield ArgsKwargs(image_loader, saturation_factor=saturation_factor) -KERNEL_INFOS.append( - KernelInfo( - F.adjust_saturation_image_tensor, - kernel_name="adjust_saturation_image_tensor", - sample_inputs_fn=sample_inputs_adjust_saturation_image_tensor, - reference_fn=pil_reference_wrapper(F.adjust_saturation_image_pil), - reference_inputs_fn=reference_inputs_adjust_saturation_image_tensor, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - ) +def sample_inputs_adjust_saturation_video(): + for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): + 
yield ArgsKwargs(video_loader, saturation_factor=_ADJUST_SATURATION_FACTORS[0]) + + +KERNEL_INFOS.extend( + [ + KernelInfo( + F.adjust_saturation_image_tensor, + kernel_name="adjust_saturation_image_tensor", + sample_inputs_fn=sample_inputs_adjust_saturation_image_tensor, + reference_fn=pil_reference_wrapper(F.adjust_saturation_image_pil), + reference_inputs_fn=reference_inputs_adjust_saturation_image_tensor, + closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, + ), + KernelInfo( + F.adjust_saturation_video, + sample_inputs_fn=sample_inputs_adjust_saturation_video, + ), + ] ) @@ -1702,10 +1963,24 @@ def sample_inputs_normalize_image_tensor(): yield ArgsKwargs(image_loader, mean=mean, std=std) -KERNEL_INFOS.append( - KernelInfo( - F.normalize_image_tensor, - kernel_name="normalize_image_tensor", - sample_inputs_fn=sample_inputs_normalize_image_tensor, - ) +def sample_inputs_normalize_video(): + mean, std = _NORMALIZE_MEANS_STDS[0] + for video_loader in make_video_loaders( + sizes=["random"], color_spaces=[features.ColorSpace.RGB], num_frames=["random"], dtypes=[torch.float32] + ): + yield ArgsKwargs(video_loader, mean=mean, std=std) + + +KERNEL_INFOS.extend( + [ + KernelInfo( + F.normalize_image_tensor, + kernel_name="normalize_image_tensor", + sample_inputs_fn=sample_inputs_normalize_image_tensor, + ), + KernelInfo( + F.normalize_video, + sample_inputs_fn=sample_inputs_normalize_video, + ), + ] ) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 9734a5dc30a..916861f4e04 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -17,6 +17,7 @@ make_masks, make_one_hot_labels, make_segmentation_mask, + make_videos, ) from torchvision.ops.boxes import box_iou from torchvision.prototype import features, transforms @@ -65,6 +66,7 @@ def parametrize_from_transforms(*transforms): make_vanilla_tensor_images, make_pil_images, make_masks, + make_videos, ]: inputs = list(creation_fn()) try: @@ -155,12 +157,14 @@ def test_mixup_cutmix(self, transform, input): features.ColorSpace.RGB, ], dtypes=[torch.uint8], - extra_dims=[(4,)], + extra_dims=[(), (4,)], + **(dict(num_frames=["random"]) if fn is make_videos else dict()), ) for fn in [ make_images, make_vanilla_tensor_images, make_pil_images, + make_videos, ] ), ) @@ -184,6 +188,7 @@ def test_auto_augment(self, transform, input): for fn in [ make_images, make_vanilla_tensor_images, + make_videos, ] ), ), @@ -200,6 +205,7 @@ def test_normalize(self, transform, input): make_images(extra_dims=[(4,)]), make_vanilla_tensor_images(), make_pil_images(), + make_videos(extra_dims=[()]), ), ) ] @@ -218,6 +224,7 @@ def test_random_resized_crop(self, transform, input): make_images, make_vanilla_tensor_images, make_pil_images, + make_videos, ) ] ), diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index a6523045c2d..5adea4d2663 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -129,6 +129,7 @@ def test_batched_vs_single(self, info, args_kwargs, device): # type all kernels should also work without differentiating between the two. Thus, we go with 2 here as # common ground. 
features.Mask: 2, + features.Video: 4, }.get(feature_type) if data_dims is None: raise pytest.UsageError( diff --git a/torchvision/prototype/features/__init__.py b/torchvision/prototype/features/__init__.py index df77e8b77b3..6fc2fb6ea94 100644 --- a/torchvision/prototype/features/__init__.py +++ b/torchvision/prototype/features/__init__.py @@ -13,3 +13,4 @@ ) from ._label import Label, OneHotLabel from ._mask import Mask +from ._video import ImageOrVideoType, ImageOrVideoTypeJIT, TensorImageOrVideoType, TensorImageOrVideoTypeJIT, Video diff --git a/torchvision/prototype/features/_video.py b/torchvision/prototype/features/_video.py new file mode 100644 index 00000000000..e19b6f7ed1c --- /dev/null +++ b/torchvision/prototype/features/_video.py @@ -0,0 +1,240 @@ +from __future__ import annotations + +import warnings +from typing import Any, cast, List, Optional, Tuple, Union + +import torch +from torchvision.transforms.functional import InterpolationMode + +from ._feature import _Feature, FillTypeJIT +from ._image import ColorSpace, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT + + +class Video(_Feature): + color_space: ColorSpace + + def __new__( + cls, + data: Any, + *, + color_space: Optional[Union[ColorSpace, str]] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[Union[torch.device, str, int]] = None, + requires_grad: bool = False, + ) -> Video: + data = torch.as_tensor(data, dtype=dtype, device=device) + if data.ndim < 4: + raise ValueError + video = super().__new__(cls, data, requires_grad=requires_grad) + + if color_space is None: + color_space = ColorSpace.from_tensor_shape(video.shape) # type: ignore[arg-type] + if color_space == ColorSpace.OTHER: + warnings.warn("Unable to guess a specific color space. Consider passing it explicitly.") + elif isinstance(color_space, str): + color_space = ColorSpace.from_str(color_space.upper()) + elif not isinstance(color_space, ColorSpace): + raise ValueError + video.color_space = color_space + + return video + + def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] + return self._make_repr(color_space=self.color_space) + + @classmethod + def new_like( + cls, other: Video, data: Any, *, color_space: Optional[Union[ColorSpace, str]] = None, **kwargs: Any + ) -> Video: + return super().new_like( + other, data, color_space=color_space if color_space is not None else other.color_space, **kwargs + ) + + # TODO: rename this (and all instances of this term to spatial size) + @property + def image_size(self) -> Tuple[int, int]: + return cast(Tuple[int, int], tuple(self.shape[-2:])) + + @property + def num_channels(self) -> int: + return self.shape[-3] + + @property + def num_frames(self) -> int: + return self.shape[-4] + + def to_color_space(self, color_space: Union[str, ColorSpace], copy: bool = True) -> Video: + if isinstance(color_space, str): + color_space = ColorSpace.from_str(color_space.upper()) + + return Video.new_like( + self, + self._F.convert_color_space_video( + self, old_color_space=self.color_space, new_color_space=color_space, copy=copy + ), + color_space=color_space, + ) + + def horizontal_flip(self) -> Video: + output = self._F.horizontal_flip_video(self) + return Video.new_like(self, output) + + def vertical_flip(self) -> Video: + output = self._F.vertical_flip_video(self) + return Video.new_like(self, output) + + def resize( # type: ignore[override] + self, + size: List[int], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, + 
antialias: bool = False, + ) -> Video: + output = self._F.resize_video(self, size, interpolation=interpolation, max_size=max_size, antialias=antialias) + return Video.new_like(self, output) + + def crop(self, top: int, left: int, height: int, width: int) -> Video: + output = self._F.crop_video(self, top, left, height, width) + return Video.new_like(self, output) + + def center_crop(self, output_size: List[int]) -> Video: + output = self._F.center_crop_video(self, output_size=output_size) + return Video.new_like(self, output) + + def resized_crop( + self, + top: int, + left: int, + height: int, + width: int, + size: List[int], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + antialias: bool = False, + ) -> Video: + output = self._F.resized_crop_video( + self, top, left, height, width, size=list(size), interpolation=interpolation, antialias=antialias + ) + return Video.new_like(self, output) + + def pad( + self, + padding: Union[int, List[int]], + fill: FillTypeJIT = None, + padding_mode: str = "constant", + ) -> Video: + output = self._F.pad_video(self, padding, fill=fill, padding_mode=padding_mode) + return Video.new_like(self, output) + + def rotate( + self, + angle: float, + interpolation: InterpolationMode = InterpolationMode.NEAREST, + expand: bool = False, + fill: FillTypeJIT = None, + center: Optional[List[float]] = None, + ) -> Video: + output = self._F._geometry.rotate_video( + self, angle, interpolation=interpolation, expand=expand, fill=fill, center=center + ) + return Video.new_like(self, output) + + def affine( + self, + angle: Union[int, float], + translate: List[float], + scale: float, + shear: List[float], + interpolation: InterpolationMode = InterpolationMode.NEAREST, + fill: FillTypeJIT = None, + center: Optional[List[float]] = None, + ) -> Video: + output = self._F._geometry.affine_video( + self, + angle, + translate=translate, + scale=scale, + shear=shear, + interpolation=interpolation, + fill=fill, + center=center, + ) + return Video.new_like(self, output) + + def perspective( + self, + perspective_coeffs: List[float], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + fill: FillTypeJIT = None, + ) -> Video: + output = self._F._geometry.perspective_video(self, perspective_coeffs, interpolation=interpolation, fill=fill) + return Video.new_like(self, output) + + def elastic( + self, + displacement: torch.Tensor, + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + fill: FillTypeJIT = None, + ) -> Video: + output = self._F._geometry.elastic_video(self, displacement, interpolation=interpolation, fill=fill) + return Video.new_like(self, output) + + def adjust_brightness(self, brightness_factor: float) -> Video: + output = self._F.adjust_brightness_video(self, brightness_factor=brightness_factor) + return Video.new_like(self, output) + + def adjust_saturation(self, saturation_factor: float) -> Video: + output = self._F.adjust_saturation_video(self, saturation_factor=saturation_factor) + return Video.new_like(self, output) + + def adjust_contrast(self, contrast_factor: float) -> Video: + output = self._F.adjust_contrast_video(self, contrast_factor=contrast_factor) + return Video.new_like(self, output) + + def adjust_sharpness(self, sharpness_factor: float) -> Video: + output = self._F.adjust_sharpness_video(self, sharpness_factor=sharpness_factor) + return Video.new_like(self, output) + + def adjust_hue(self, hue_factor: float) -> Video: + output = self._F.adjust_hue_video(self, hue_factor=hue_factor) + return 
Video.new_like(self, output) + + def adjust_gamma(self, gamma: float, gain: float = 1) -> Video: + output = self._F.adjust_gamma_video(self, gamma=gamma, gain=gain) + return Video.new_like(self, output) + + def posterize(self, bits: int) -> Video: + output = self._F.posterize_video(self, bits=bits) + return Video.new_like(self, output) + + def solarize(self, threshold: float) -> Video: + output = self._F.solarize_video(self, threshold=threshold) + return Video.new_like(self, output) + + def autocontrast(self) -> Video: + output = self._F.autocontrast_video(self) + return Video.new_like(self, output) + + def equalize(self) -> Video: + output = self._F.equalize_video(self) + return Video.new_like(self, output) + + def invert(self) -> Video: + output = self._F.invert_video(self) + return Video.new_like(self, output) + + def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Video: + output = self._F.gaussian_blur_video(self, kernel_size=kernel_size, sigma=sigma) + return Video.new_like(self, output) + + +VideoType = Union[torch.Tensor, Video] +VideoTypeJIT = torch.Tensor +LegacyVideoType = torch.Tensor +LegacyVideoTypeJIT = torch.Tensor +TensorVideoType = Union[torch.Tensor, Video] +TensorVideoTypeJIT = torch.Tensor + +ImageOrVideoType = Union[ImageType, VideoType] +ImageOrVideoTypeJIT = Union[ImageTypeJIT, VideoTypeJIT] +TensorImageOrVideoType = Union[TensorImageType, TensorVideoType] +TensorImageOrVideoTypeJIT = Union[TensorImageTypeJIT, TensorVideoTypeJIT] diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index 3cd925fd996..311ad6d5aa4 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -15,7 +15,7 @@ class RandomErasing(_RandomApplyTransform): - _transformed_types = (features.is_simple_tensor, features.Image, PIL.Image.Image) + _transformed_types = (features.is_simple_tensor, features.Image, PIL.Image.Image, features.Video) def __init__( self, @@ -92,7 +92,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(i=i, j=j, h=h, w=w, v=v) - def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType: + def _transform(self, inpt: features.ImageOrVideoType, params: Dict[str, Any]) -> features.ImageOrVideoType: if params["v"] is not None: inpt = F.erase(inpt, **params, inplace=self.inplace) diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index c98e5c36e4a..4732f88d4f2 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -31,40 +31,41 @@ def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]: key = keys[int(torch.randint(len(keys), ()))] return key, dct[key] - def _extract_image( + def _extract_image_or_video( self, sample: Any, unsupported_types: Tuple[Type, ...] 
= (features.BoundingBox, features.Mask), - ) -> Tuple[int, features.ImageType]: + ) -> Tuple[int, features.ImageOrVideoType]: sample_flat, _ = tree_flatten(sample) - images = [] + image_or_videos = [] for id, inpt in enumerate(sample_flat): - if _isinstance(inpt, (features.Image, PIL.Image.Image, features.is_simple_tensor)): - images.append((id, inpt)) + if _isinstance(inpt, (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video)): + image_or_videos.append((id, inpt)) elif isinstance(inpt, unsupported_types): raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()") - if not images: + if not image_or_videos: raise TypeError("Found no image in the sample.") - if len(images) > 1: + if len(image_or_videos) > 1: raise TypeError( - f"Auto augment transformations are only properly defined for a single image, but found {len(images)}." + f"Auto augment transformations are only properly defined for a single image or video, " + f"but found {len(image_or_videos)}." ) - return images[0] + return image_or_videos[0] def _put_into_sample(self, sample: Any, id: int, item: Any) -> Any: sample_flat, spec = tree_flatten(sample) sample_flat[id] = item return tree_unflatten(sample_flat, spec) - def _apply_image_transform( + def _apply_image_or_video_transform( self, - image: features.ImageType, + image: features.ImageOrVideoType, transform_id: str, magnitude: float, interpolation: InterpolationMode, fill: Dict[Type, features.FillType], - ) -> features.ImageType: + ) -> features.ImageOrVideoType: fill_ = fill[type(image)] fill_ = F._geometry._convert_fill_arg(fill_) @@ -276,8 +277,8 @@ def _get_policies( def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] - id, image = self._extract_image(sample) - _, height, width = get_chw(image) + id, image_or_video = self._extract_image_or_video(sample) + _, height, width = get_chw(image_or_video) policy = self._policies[int(torch.randint(len(self._policies), ()))] @@ -295,11 +296,11 @@ def forward(self, *inputs: Any) -> Any: else: magnitude = 0.0 - image = self._apply_image_transform( - image, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill + image_or_video = self._apply_image_or_video_transform( + image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill ) - return self._put_into_sample(sample, id, image) + return self._put_into_sample(sample, id, image_or_video) class RandAugment(_AutoAugmentBase): @@ -347,8 +348,8 @@ def __init__( def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] - id, image = self._extract_image(sample) - _, height, width = get_chw(image) + id, image_or_video = self._extract_image_or_video(sample) + _, height, width = get_chw(image_or_video) for _ in range(self.num_ops): transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE) @@ -359,11 +360,11 @@ def forward(self, *inputs: Any) -> Any: magnitude *= -1 else: magnitude = 0.0 - image = self._apply_image_transform( - image, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill + image_or_video = self._apply_image_or_video_transform( + image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill ) - return self._put_into_sample(sample, id, image) + return self._put_into_sample(sample, id, image_or_video) class TrivialAugmentWide(_AutoAugmentBase): @@ -401,8 +402,8 @@ def __init__( def forward(self, *inputs: Any) -> Any: sample = 
inputs if len(inputs) > 1 else inputs[0] - id, image = self._extract_image(sample) - _, height, width = get_chw(image) + id, image_or_video = self._extract_image_or_video(sample) + _, height, width = get_chw(image_or_video) transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE) @@ -414,10 +415,10 @@ def forward(self, *inputs: Any) -> Any: else: magnitude = 0.0 - image = self._apply_image_transform( - image, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill + image_or_video = self._apply_image_or_video_transform( + image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill ) - return self._put_into_sample(sample, id, image) + return self._put_into_sample(sample, id, image_or_video) class AugMix(_AutoAugmentBase): @@ -471,27 +472,28 @@ def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor: def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] - id, orig_image = self._extract_image(sample) - _, height, width = get_chw(orig_image) + id, orig_image_or_video = self._extract_image_or_video(sample) + _, height, width = get_chw(orig_image_or_video) - if isinstance(orig_image, torch.Tensor): - image = orig_image + if isinstance(orig_image_or_video, torch.Tensor): + image_or_video = orig_image_or_video else: # isinstance(inpt, PIL.Image.Image): - image = F.pil_to_tensor(orig_image) + image_or_video = F.pil_to_tensor(orig_image_or_video) augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE - orig_dims = list(image.shape) - batch = image.view([1] * max(4 - image.ndim, 0) + orig_dims) + orig_dims = list(image_or_video.shape) + batch = image_or_video.view([1] * max(4 - image_or_video.ndim, 0) + orig_dims) batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1) - # Sample the beta weights for combining the original and augmented image. To get Beta, we use a Dirichlet - # with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of augmented image. + # Sample the beta weights for combining the original and augmented image or video. To get Beta, we use a + # Dirichlet with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of + # augmented image or video. m = self._sample_dirichlet( torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1) ) - # Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images. + # Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images or videos. 
combined_weights = self._sample_dirichlet( torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1) ) * m[:, 1].view([batch_dims[0], -1]) @@ -511,15 +513,15 @@ def forward(self, *inputs: Any) -> Any: else: magnitude = 0.0 - aug = self._apply_image_transform( + aug = self._apply_image_or_video_transform( aug, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill ) mix.add_(combined_weights[:, i].view(batch_dims) * aug) - mix = mix.view(orig_dims).to(dtype=image.dtype) + mix = mix.view(orig_dims).to(dtype=image_or_video.dtype) - if isinstance(orig_image, features.Image): - mix = features.Image.new_like(orig_image, mix) - elif isinstance(orig_image, PIL.Image.Image): + if isinstance(orig_image_or_video, (features.Image, features.Video)): + mix = type(orig_image_or_video).new_like(orig_image_or_video, mix) # type: ignore[arg-type] + elif isinstance(orig_image_or_video, PIL.Image.Image): mix = F.to_image_pil(mix) return self._put_into_sample(sample, id, mix) diff --git a/torchvision/prototype/transforms/_color.py b/torchvision/prototype/transforms/_color.py index e0ee8d1b96a..451b57b66c0 100644 --- a/torchvision/prototype/transforms/_color.py +++ b/torchvision/prototype/transforms/_color.py @@ -82,7 +82,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: class RandomPhotometricDistort(Transform): - _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor) + _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video) def __init__( self, @@ -110,20 +110,22 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: channel_permutation=torch.randperm(num_channels) if torch.rand(()) < self.p else None, ) - def _permute_channels(self, inpt: features.ImageType, permutation: torch.Tensor) -> features.ImageType: + def _permute_channels( + self, inpt: features.ImageOrVideoType, permutation: torch.Tensor + ) -> features.ImageOrVideoType: if isinstance(inpt, PIL.Image.Image): inpt = F.pil_to_tensor(inpt) output = inpt[..., permutation, :, :] - if isinstance(inpt, features.Image): - output = features.Image.new_like(inpt, output, color_space=features.ColorSpace.OTHER) + if isinstance(inpt, (features.Image, features.Video)): + output = type(inpt).new_like(inpt, output, color_space=features.ColorSpace.OTHER) # type: ignore[arg-type] elif isinstance(inpt, PIL.Image.Image): output = F.to_image_pil(output) return output - def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType: + def _transform(self, inpt: features.ImageOrVideoType, params: Dict[str, Any]) -> features.ImageOrVideoType: if params["brightness"]: inpt = F.adjust_brightness( inpt, brightness_factor=ColorJitter._generate_value(self.brightness[0], self.brightness[1]) diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 008d4d195cb..1f132ec9238 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -855,8 +855,10 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return inpt def forward(self, *inputs: Any) -> Any: - if not has_any(inputs, PIL.Image.Image, features.Image, features.is_simple_tensor): - raise TypeError(f"{type(self).__name__}() requires input sample to contain an tensor or PIL image.") + if not has_any(inputs, PIL.Image.Image, features.Image, features.is_simple_tensor, features.Video): + raise TypeError( + f"{type(self).__name__}() 
requires input sample to contain an tensor or PIL image or a Video." + ) if has_any(inputs, features.BoundingBox) and not has_any(inputs, features.Label, features.OneHotLabel): raise TypeError( diff --git a/torchvision/prototype/transforms/_meta.py b/torchvision/prototype/transforms/_meta.py index 2ea3014aa6c..cb090492a48 100644 --- a/torchvision/prototype/transforms/_meta.py +++ b/torchvision/prototype/transforms/_meta.py @@ -34,7 +34,7 @@ def _transform(self, inpt: features.TensorImageType, params: Dict[str, Any]) -> class ConvertColorSpace(Transform): - _transformed_types = (features.is_simple_tensor, features.Image, PIL.Image.Image) + _transformed_types = (features.is_simple_tensor, features.Image, PIL.Image.Image, features.Video) def __init__( self, @@ -54,7 +54,7 @@ def __init__( self.copy = copy - def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType: + def _transform(self, inpt: features.ImageOrVideoType, params: Dict[str, Any]) -> features.ImageOrVideoType: return F.convert_color_space( inpt, color_space=self.color_space, old_color_space=self.old_color_space, copy=self.copy ) diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index 976e9f8b5ff..2531bf8f6fa 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -38,7 +38,7 @@ def extra_repr(self) -> str: class LinearTransformation(Transform): - _transformed_types = (features.is_simple_tensor, features.Image) + _transformed_types = (features.is_simple_tensor, features.Image, features.Video) def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tensor): super().__init__() @@ -68,7 +68,7 @@ def forward(self, *inputs: Any) -> Any: return super().forward(*inputs) - def _transform(self, inpt: features.TensorImageType, params: Dict[str, Any]) -> torch.Tensor: + def _transform(self, inpt: features.TensorImageOrVideoType, params: Dict[str, Any]) -> torch.Tensor: # Image instance after linear transformation is not Image anymore due to unknown data range # Thus we will return Tensor for input Image @@ -93,7 +93,7 @@ def _transform(self, inpt: features.TensorImageType, params: Dict[str, Any]) -> class Normalize(Transform): - _transformed_types = (features.Image, features.is_simple_tensor) + _transformed_types = (features.Image, features.is_simple_tensor, features.Video) def __init__(self, mean: Sequence[float], std: Sequence[float], inplace: bool = False): super().__init__() @@ -101,7 +101,7 @@ def __init__(self, mean: Sequence[float], std: Sequence[float], inplace: bool = self.std = list(std) self.inplace = inplace - def _transform(self, inpt: features.TensorImageType, params: Dict[str, Any]) -> torch.Tensor: + def _transform(self, inpt: features.TensorImageOrVideoType, params: Dict[str, Any]) -> torch.Tensor: return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace) def forward(self, *inpts: Any) -> Any: diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index 219e6e50586..a76891a348a 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -82,10 +82,10 @@ def query_chw(sample: Any) -> Tuple[int, int, int]: chws = { get_chw(item) for item in flat_sample - if isinstance(item, (features.Image, PIL.Image.Image)) or features.is_simple_tensor(item) + if isinstance(item, (features.Image, PIL.Image.Image, features.Video)) or features.is_simple_tensor(item) } if not chws: - 
raise TypeError("No image was found in the sample") + raise TypeError("No image or video was found in the sample") elif len(chws) > 1: raise ValueError(f"Found multiple CxHxW dimensions in the sample: {sequence_to_str(sorted(chws))}") return chws.pop() diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index f081d101dff..cb801df73c7 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -6,6 +6,7 @@ convert_format_bounding_box, convert_color_space_image_tensor, convert_color_space_image_pil, + convert_color_space_video, convert_color_space, get_dimensions, get_image_num_channels, @@ -13,41 +14,52 @@ get_spatial_size, ) # usort: skip -from ._augment import erase, erase_image_pil, erase_image_tensor +from ._augment import erase, erase_image_pil, erase_image_tensor, erase_video from ._color import ( adjust_brightness, adjust_brightness_image_pil, adjust_brightness_image_tensor, + adjust_brightness_video, adjust_contrast, adjust_contrast_image_pil, adjust_contrast_image_tensor, + adjust_contrast_video, adjust_gamma, adjust_gamma_image_pil, adjust_gamma_image_tensor, + adjust_gamma_video, adjust_hue, adjust_hue_image_pil, adjust_hue_image_tensor, + adjust_hue_video, adjust_saturation, adjust_saturation_image_pil, adjust_saturation_image_tensor, + adjust_saturation_video, adjust_sharpness, adjust_sharpness_image_pil, adjust_sharpness_image_tensor, + adjust_sharpness_video, autocontrast, autocontrast_image_pil, autocontrast_image_tensor, + autocontrast_video, equalize, equalize_image_pil, equalize_image_tensor, + equalize_video, invert, invert_image_pil, invert_image_tensor, + invert_video, posterize, posterize_image_pil, posterize_image_tensor, + posterize_video, solarize, solarize_image_pil, solarize_image_tensor, + solarize_video, ) from ._geometry import ( affine, @@ -55,22 +67,26 @@ affine_image_pil, affine_image_tensor, affine_mask, + affine_video, center_crop, center_crop_bounding_box, center_crop_image_pil, center_crop_image_tensor, center_crop_mask, + center_crop_video, crop, crop_bounding_box, crop_image_pil, crop_image_tensor, crop_mask, + crop_video, elastic, elastic_bounding_box, elastic_image_pil, elastic_image_tensor, elastic_mask, elastic_transform, + elastic_video, five_crop, five_crop_image_pil, five_crop_image_tensor, @@ -80,31 +96,37 @@ horizontal_flip_image_pil, horizontal_flip_image_tensor, horizontal_flip_mask, + horizontal_flip_video, pad, pad_bounding_box, pad_image_pil, pad_image_tensor, pad_mask, + pad_video, perspective, perspective_bounding_box, perspective_image_pil, perspective_image_tensor, perspective_mask, + perspective_video, resize, resize_bounding_box, resize_image_pil, resize_image_tensor, resize_mask, + resize_video, resized_crop, resized_crop_bounding_box, resized_crop_image_pil, resized_crop_image_tensor, resized_crop_mask, + resized_crop_video, rotate, rotate_bounding_box, rotate_image_pil, rotate_image_tensor, rotate_mask, + rotate_video, ten_crop, ten_crop_image_pil, ten_crop_image_tensor, @@ -113,9 +135,18 @@ vertical_flip_image_pil, vertical_flip_image_tensor, vertical_flip_mask, + vertical_flip_video, vflip, ) -from ._misc import gaussian_blur, gaussian_blur_image_pil, gaussian_blur_image_tensor, normalize, normalize_image_tensor +from ._misc import ( + gaussian_blur, + gaussian_blur_image_pil, + gaussian_blur_image_tensor, + gaussian_blur_video, + normalize, + normalize_image_tensor, + 
normalize_video, +) from ._type_conversion import ( convert_image_dtype, decode_image_with_pil, diff --git a/torchvision/prototype/transforms/functional/_augment.py b/torchvision/prototype/transforms/functional/_augment.py index fb48c35888d..976feb99ea2 100644 --- a/torchvision/prototype/transforms/functional/_augment.py +++ b/torchvision/prototype/transforms/functional/_augment.py @@ -17,19 +17,25 @@ def erase_image_pil( return to_pil_image(output, mode=image.mode) +def erase_video( + video: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False +) -> torch.Tensor: + return erase_image_tensor(video, i=i, j=j, h=h, w=w, v=v, inplace=inplace) + + def erase( - inpt: features.ImageTypeJIT, + inpt: features.ImageOrVideoTypeJIT, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False, -) -> features.ImageTypeJIT: +) -> features.ImageOrVideoTypeJIT: if isinstance(inpt, torch.Tensor): output = erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace) - if not torch.jit.is_scripting() and isinstance(inpt, features.Image): - output = features.Image.new_like(inpt, output) + if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)): + output = type(inpt).new_like(inpt, output) # type: ignore[arg-type] return output else: # isinstance(inpt, PIL.Image.Image): return erase_image_pil(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index f375cb048c6..d11dd3c3b9f 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -2,10 +2,16 @@ from torchvision.prototype import features from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT +from ._meta import get_dimensions_image_tensor + adjust_brightness_image_tensor = _FT.adjust_brightness adjust_brightness_image_pil = _FP.adjust_brightness +def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> torch.Tensor: + return adjust_brightness_image_tensor(video, brightness_factor=brightness_factor) + + def adjust_brightness(inpt: features.InputTypeJIT, brightness_factor: float) -> features.InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor) @@ -19,6 +25,10 @@ def adjust_brightness(inpt: features.InputTypeJIT, brightness_factor: float) -> adjust_saturation_image_pil = _FP.adjust_saturation +def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> torch.Tensor: + return adjust_saturation_image_tensor(video, saturation_factor=saturation_factor) + + def adjust_saturation(inpt: features.InputTypeJIT, saturation_factor: float) -> features.InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor) @@ -32,6 +42,10 @@ def adjust_saturation(inpt: features.InputTypeJIT, saturation_factor: float) -> adjust_contrast_image_pil = _FP.adjust_contrast +def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.Tensor: + return adjust_contrast_image_tensor(video, contrast_factor=contrast_factor) + + def adjust_contrast(inpt: features.InputTypeJIT, contrast_factor: float) -> features.InputTypeJIT: if isinstance(inpt, torch.Tensor) 
and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor) @@ -41,10 +55,40 @@ def adjust_contrast(inpt: features.InputTypeJIT, contrast_factor: float) -> feat return adjust_contrast_image_pil(inpt, contrast_factor=contrast_factor) -adjust_sharpness_image_tensor = _FT.adjust_sharpness +def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float) -> torch.Tensor: + num_channels, height, width = get_dimensions_image_tensor(image) + if num_channels not in (1, 3): + raise TypeError(f"Input image tensor can have 1 or 3 channels, but found {num_channels}") + + if sharpness_factor < 0: + raise ValueError(f"sharpness_factor ({sharpness_factor}) is not non-negative.") + + if image.numel() == 0 or height <= 2 or width <= 2: + return image + + shape = image.shape + + if image.ndim > 4: + image = image.view(-1, num_channels, height, width) + needs_unsquash = True + else: + needs_unsquash = False + + output = _FT._blend(image, _FT._blurred_degenerate_image(image), sharpness_factor) + + if needs_unsquash: + output = output.view(shape) + + return output + + adjust_sharpness_image_pil = _FP.adjust_sharpness +def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torch.Tensor: + return adjust_sharpness_image_tensor(video, sharpness_factor=sharpness_factor) + + def adjust_sharpness(inpt: features.InputTypeJIT, sharpness_factor: float) -> features.InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor) @@ -58,6 +102,10 @@ def adjust_sharpness(inpt: features.InputTypeJIT, sharpness_factor: float) -> fe adjust_hue_image_pil = _FP.adjust_hue +def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor: + return adjust_hue_image_tensor(video, hue_factor=hue_factor) + + def adjust_hue(inpt: features.InputTypeJIT, hue_factor: float) -> features.InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return adjust_hue_image_tensor(inpt, hue_factor=hue_factor) @@ -71,6 +119,10 @@ def adjust_hue(inpt: features.InputTypeJIT, hue_factor: float) -> features.Input adjust_gamma_image_pil = _FP.adjust_gamma +def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor: + return adjust_gamma_image_tensor(video, gamma=gamma, gain=gain) + + def adjust_gamma(inpt: features.InputTypeJIT, gamma: float, gain: float = 1) -> features.InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain) @@ -84,6 +136,10 @@ def adjust_gamma(inpt: features.InputTypeJIT, gamma: float, gain: float = 1) -> posterize_image_pil = _FP.posterize +def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor: + return posterize_image_tensor(video, bits=bits) + + def posterize(inpt: features.InputTypeJIT, bits: int) -> features.InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return posterize_image_tensor(inpt, bits=bits) @@ -97,6 +153,10 @@ def posterize(inpt: features.InputTypeJIT, bits: int) -> features.InputTypeJIT: solarize_image_pil = _FP.solarize +def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor: + return solarize_image_tensor(video, 
threshold=threshold) + + def solarize(inpt: features.InputTypeJIT, threshold: float) -> features.InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return solarize_image_tensor(inpt, threshold=threshold) @@ -110,6 +170,10 @@ def solarize(inpt: features.InputTypeJIT, threshold: float) -> features.InputTyp autocontrast_image_pil = _FP.autocontrast +def autocontrast_video(video: torch.Tensor) -> torch.Tensor: + return autocontrast_image_tensor(video) + + def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return autocontrast_image_tensor(inpt) @@ -119,10 +183,35 @@ def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT: return autocontrast_image_pil(inpt) -equalize_image_tensor = _FT.equalize +def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor: + if image.dtype != torch.uint8: + raise TypeError(f"Only torch.uint8 image tensors are supported, but found {image.dtype}") + + num_channels, height, width = get_dimensions_image_tensor(image) + if num_channels not in (1, 3): + raise TypeError(f"Input image tensor can have 1 or 3 channels, but found {num_channels}") + + if image.numel() == 0: + return image + elif image.ndim == 2: + return _FT._scale_channel(image) + else: + return torch.stack( + [ + # TODO: when merging transforms v1 and v2, we can inline this function call + _FT._equalize_single_image(single_image) + for single_image in image.view(-1, num_channels, height, width) + ] + ).view(image.shape) + + equalize_image_pil = _FP.equalize +def equalize_video(video: torch.Tensor) -> torch.Tensor: + return equalize_image_tensor(video) + + def equalize(inpt: features.InputTypeJIT) -> features.InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return equalize_image_tensor(inpt) @@ -136,6 +225,10 @@ def equalize(inpt: features.InputTypeJIT) -> features.InputTypeJIT: invert_image_pil = _FP.invert +def invert_video(video: torch.Tensor) -> torch.Tensor: + return invert_image_tensor(video) + + def invert(inpt: features.InputTypeJIT) -> features.InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return invert_image_tensor(inpt) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 7a291967bfd..f205b5aeabe 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -47,6 +47,10 @@ def horizontal_flip_bounding_box( ).view(shape) +def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor: + return horizontal_flip_image_tensor(video) + + def horizontal_flip(inpt: features.InputTypeJIT) -> features.InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return horizontal_flip_image_tensor(inpt) @@ -80,6 +84,10 @@ def vertical_flip_bounding_box( ).view(shape) +def vertical_flip_video(video: torch.Tensor) -> torch.Tensor: + return vertical_flip_image_tensor(video) + + def vertical_flip(inpt: features.InputTypeJIT) -> features.InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return vertical_flip_image_tensor(inpt) @@ -185,6 +193,16 @@ def 
resize_bounding_box( ) +def resize_video( + video: torch.Tensor, + size: List[int], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, + antialias: bool = False, +) -> torch.Tensor: + return resize_image_tensor(video, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias) + + def resize( inpt: features.InputTypeJIT, size: List[int], @@ -441,6 +459,28 @@ def affine_mask( return output +def affine_video( + video: torch.Tensor, + angle: Union[int, float], + translate: List[float], + scale: float, + shear: List[float], + interpolation: InterpolationMode = InterpolationMode.NEAREST, + fill: features.FillTypeJIT = None, + center: Optional[List[float]] = None, +) -> torch.Tensor: + return affine_image_tensor( + video, + angle=angle, + translate=translate, + scale=scale, + shear=shear, + interpolation=interpolation, + fill=fill, + center=center, + ) + + def _convert_fill_arg(fill: features.FillType) -> features.FillTypeJIT: # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517 # So, we can't reassign fill to 0 @@ -614,6 +654,17 @@ def rotate_mask( return output +def rotate_video( + video: torch.Tensor, + angle: float, + interpolation: InterpolationMode = InterpolationMode.NEAREST, + expand: bool = False, + fill: features.FillTypeJIT = None, + center: Optional[List[float]] = None, +) -> torch.Tensor: + return rotate_image_tensor(video, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) + + def rotate( inpt: features.InputTypeJIT, angle: float, @@ -751,6 +802,15 @@ def pad_bounding_box( return bounding_box, (height, width) +def pad_video( + video: torch.Tensor, + padding: Union[int, List[int]], + fill: features.FillTypeJIT = None, + padding_mode: str = "constant", +) -> torch.Tensor: + return pad_image_tensor(video, padding, fill=fill, padding_mode=padding_mode) + + def pad( inpt: features.InputTypeJIT, padding: Union[int, List[int]], @@ -798,6 +858,10 @@ def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) return crop_image_tensor(mask, top, left, height, width) +def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor: + return crop_image_tensor(video, top, left, height, width) + + def crop(inpt: features.InputTypeJIT, top: int, left: int, height: int, width: int) -> features.InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return crop_image_tensor(inpt, top, left, height, width) @@ -932,6 +996,33 @@ def perspective_mask( return output +def perspective_video( + video: torch.Tensor, + perspective_coeffs: List[float], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + fill: features.FillTypeJIT = None, +) -> torch.Tensor: + # TODO: this is a temporary workaround until the image kernel supports arbitrary batch sizes. Remove this when + # https://github.com/pytorch/vision/issues/6670 is resolved. 
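+    # Until then, leading dimensions are flattened into a single batch dimension, i.e.
+    # (..., T, C, H, W) -> (N, C, H, W), the image kernel is applied, and the original shape is
+    # restored afterwards; degenerate (empty) inputs are returned as-is.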
+ if video.numel() == 0: + return video + + shape = video.shape + + if video.ndim > 4: + video = video.view((-1,) + shape[-3:]) + needs_unsquash = True + else: + needs_unsquash = False + + output = perspective_image_tensor(video, perspective_coeffs, interpolation=interpolation, fill=fill) + + if needs_unsquash: + output = output.view(shape) + + return output + + def perspective( inpt: features.InputTypeJIT, perspective_coeffs: List[float], @@ -1026,6 +1117,33 @@ def elastic_mask( return output +def elastic_video( + video: torch.Tensor, + displacement: torch.Tensor, + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + fill: features.FillTypeJIT = None, +) -> torch.Tensor: + # TODO: this is a temporary workaround until the image kernel supports arbitrary batch sizes. Remove this when + # https://github.com/pytorch/vision/issues/6670 is resolved. + if video.numel() == 0: + return video + + shape = video.shape + + if video.ndim > 4: + video = video.view((-1,) + shape[-3:]) + needs_unsquash = True + else: + needs_unsquash = False + + output = elastic_image_tensor(video, displacement, interpolation=interpolation, fill=fill) + + if needs_unsquash: + output = output.view(shape) + + return output + + def elastic( inpt: features.InputTypeJIT, displacement: torch.Tensor, @@ -1128,6 +1246,10 @@ def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor return output +def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tensor: + return center_crop_image_tensor(video, output_size) + + def center_crop(inpt: features.InputTypeJIT, output_size: List[int]) -> features.InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return center_crop_image_tensor(inpt, output_size) @@ -1190,6 +1312,21 @@ def resized_crop_mask( return resize_mask(mask, size) +def resized_crop_video( + video: torch.Tensor, + top: int, + left: int, + height: int, + width: int, + size: List[int], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + antialias: bool = False, +) -> torch.Tensor: + return resized_crop_image_tensor( + video, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation + ) + + def resized_crop( inpt: features.InputTypeJIT, top: int, diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py index 90cfffcf276..1e53edf3940 100644 --- a/torchvision/prototype/transforms/functional/_meta.py +++ b/torchvision/prototype/transforms/functional/_meta.py @@ -11,10 +11,12 @@ # TODO: Should this be prefixed with `_` similar to other methods that don't get exposed by init? -def get_chw(image: features.ImageTypeJIT) -> Tuple[int, int, int]: - if isinstance(image, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(image, features.Image)): +def get_chw(image: features.ImageOrVideoTypeJIT) -> Tuple[int, int, int]: + if isinstance(image, torch.Tensor) and ( + torch.jit.is_scripting() or not isinstance(image, (features.Image, features.Video)) + ): channels, height, width = get_dimensions_image_tensor(image) - elif isinstance(image, features.Image): + elif isinstance(image, (features.Image, features.Video)): channels = image.num_channels height, width = image.image_size else: # isinstance(image, PIL.Image.Image) @@ -29,11 +31,11 @@ def get_chw(image: features.ImageTypeJIT) -> Tuple[int, int, int]: # detailed above. 
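# A minimal sketch of what the extended get_chw gives us, assuming the (..., T, C, H, W) video
# layout used throughout this patch:
#
#   video = features.Video(torch.rand(4, 3, 32, 32), color_space=ColorSpace.RGB)
#   get_dimensions(video)  # -> [3, 32, 32], read from .num_channels and .image_size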
-def get_dimensions(image: features.ImageTypeJIT) -> List[int]: +def get_dimensions(image: features.ImageOrVideoTypeJIT) -> List[int]: return list(get_chw(image)) -def get_num_channels(image: features.ImageTypeJIT) -> int: +def get_num_channels(image: features.ImageOrVideoTypeJIT) -> int: num_channels, *_ = get_chw(image) return num_channels @@ -43,7 +45,7 @@ def get_num_channels(image: features.ImageTypeJIT) -> int: get_image_num_channels = get_num_channels -def get_spatial_size(image: features.ImageTypeJIT) -> List[int]: +def get_spatial_size(image: features.ImageOrVideoTypeJIT) -> List[int]: _, *size = get_chw(image) return size @@ -207,13 +209,23 @@ def convert_color_space_image_pil( return image.convert(new_mode) +def convert_color_space_video( + video: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace, copy: bool = True +) -> torch.Tensor: + return convert_color_space_image_tensor( + video, old_color_space=old_color_space, new_color_space=new_color_space, copy=copy + ) + + def convert_color_space( - inpt: features.ImageTypeJIT, + inpt: features.ImageOrVideoTypeJIT, color_space: ColorSpace, old_color_space: Optional[ColorSpace] = None, copy: bool = True, -) -> features.ImageTypeJIT: - if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features.Image)): +) -> features.ImageOrVideoTypeJIT: + if isinstance(inpt, torch.Tensor) and ( + torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video)) + ): if old_color_space is None: raise RuntimeError( "In order to convert the color space of simple tensor images, " @@ -222,7 +234,7 @@ def convert_color_space( return convert_color_space_image_tensor( inpt, old_color_space=old_color_space, new_color_space=color_space, copy=copy ) - elif isinstance(inpt, features.Image): + elif isinstance(inpt, (features.Image, features.Video)): return inpt.to_color_space(color_space, copy=copy) else: - return cast(features.ImageTypeJIT, convert_color_space_image_pil(inpt, color_space, copy=copy)) + return cast(features.ImageOrVideoTypeJIT, convert_color_space_image_pil(inpt, color_space, copy=copy)) diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index 6f35781d4a9..7b3773e63a1 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -9,18 +9,22 @@ normalize_image_tensor = _FT.normalize +def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor: + return normalize_image_tensor(video, mean, std, inplace=inplace) + + def normalize( - inpt: features.TensorImageTypeJIT, mean: List[float], std: List[float], inplace: bool = False + inpt: features.TensorImageOrVideoTypeJIT, mean: List[float], std: List[float], inplace: bool = False ) -> torch.Tensor: if torch.jit.is_scripting(): correct_type = isinstance(inpt, torch.Tensor) else: - correct_type = features.is_simple_tensor(inpt) or isinstance(inpt, features.Image) + correct_type = features.is_simple_tensor(inpt) or isinstance(inpt, (features.Image, features.Video)) inpt = inpt.as_subclass(torch.Tensor) if not correct_type: raise TypeError(f"img should be Tensor Image. 
Got {type(inpt)}") - # Image instance after normalization is not Image anymore due to unknown data range + # Image or Video type should not be retained after normalization due to unknown data range # Thus we return Tensor for input Image return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace) @@ -64,6 +68,30 @@ def gaussian_blur_image_pil( return to_pil_image(output, mode=image.mode) +def gaussian_blur_video( + video: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None +) -> torch.Tensor: + # TODO: this is a temporary workaround until the image kernel supports arbitrary batch sizes. Remove this when + # https://github.com/pytorch/vision/issues/6670 is resolved. + if video.numel() == 0: + return video + + shape = video.shape + + if video.ndim > 4: + video = video.view((-1,) + shape[-3:]) + needs_unsquash = True + else: + needs_unsquash = False + + output = gaussian_blur_image_tensor(video, kernel_size, sigma) + + if needs_unsquash: + output = output.view(shape) + + return output + + def gaussian_blur( inpt: features.InputTypeJIT, kernel_size: List[int], sigma: Optional[List[float]] = None ) -> features.InputTypeJIT:
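# A minimal usage sketch of the video kernels introduced in this patch (illustrative only; shapes
# follow the (..., T, C, H, W) convention used throughout, and the prototype API may still change):
import torch
from torchvision.prototype.transforms import functional as F

video = torch.randint(0, 256, (8, 3, 64, 64), dtype=torch.uint8)  # 8 RGB frames of 64x64

flipped = F.horizontal_flip_video(video)                 # same shape, flipped along the width axis
resized = F.resize_video(video, size=[32, 32])           # -> (8, 3, 32, 32)
normalized = F.normalize_video(resized.float() / 255.0, mean=[0.5] * 3, std=[0.5] * 3)
blurred = F.gaussian_blur_video(normalized, kernel_size=[3, 3])  # per-frame Gaussian blur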