diff --git a/test/prototype_transforms_dispatcher_infos.py b/test/prototype_transforms_dispatcher_infos.py
index b92278fef56..90e2e7f570f 100644
--- a/test/prototype_transforms_dispatcher_infos.py
+++ b/test/prototype_transforms_dispatcher_infos.py
@@ -426,7 +426,6 @@ def fill_sequence_needs_broadcast(args_kwargs):
             datapoints.Video: F.normalize_video,
         },
         test_marks=[
-            skip_dispatch_feature,
             xfail_jit_python_scalar_arg("mean"),
             xfail_jit_python_scalar_arg("std"),
         ],
diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py
index d199625df0f..a80e0f4570d 100644
--- a/test/test_prototype_transforms_functional.py
+++ b/test/test_prototype_transforms_functional.py
@@ -13,7 +13,7 @@
 
 import torchvision.prototype.transforms.utils
 from common_utils import cache, cpu_and_gpu, needs_cuda, set_rng_seed
-from prototype_common_utils import assert_close, make_bounding_boxes, make_image, parametrized_error_message
+from prototype_common_utils import assert_close, make_bounding_boxes, parametrized_error_message
 from prototype_transforms_dispatcher_infos import DISPATCHER_INFOS
 from prototype_transforms_kernel_infos import KERNEL_INFOS
 from torch.utils._pytree import tree_map
@@ -1185,18 +1185,6 @@ def test_correctness_gaussian_blur_image_tensor(device, spatial_size, dt, ksize, sigma):
     torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}")
 
 
-def test_normalize_output_type():
-    inpt = torch.rand(1, 3, 32, 32)
-    output = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0])
-    assert type(output) is torch.Tensor
-    torch.testing.assert_close(inpt - 0.5, output)
-
-    inpt = make_image(color_space=datapoints.ColorSpace.RGB)
-    output = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0])
-    assert type(output) is torch.Tensor
-    torch.testing.assert_close(inpt - 0.5, output)
-
-
 @pytest.mark.parametrize(
     "inpt",
     [
diff --git a/torchvision/prototype/datapoints/_image.py b/torchvision/prototype/datapoints/_image.py
index fc20691100f..d674745a716 100644
--- a/torchvision/prototype/datapoints/_image.py
+++ b/torchvision/prototype/datapoints/_image.py
@@ -289,6 +289,10 @@ def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Image:
         )
         return Image.wrap_like(self, output)
 
+    def normalize(self, mean: List[float], std: List[float], inplace: bool = False) -> Image:
+        output = self._F.normalize_image_tensor(self.as_subclass(torch.Tensor), mean=mean, std=std, inplace=inplace)
+        return Image.wrap_like(self, output)
+
 
 ImageType = Union[torch.Tensor, PIL.Image.Image, Image]
 ImageTypeJIT = torch.Tensor
diff --git a/torchvision/prototype/datapoints/_video.py b/torchvision/prototype/datapoints/_video.py
index 5c55d23a149..c7273874655 100644
--- a/torchvision/prototype/datapoints/_video.py
+++ b/torchvision/prototype/datapoints/_video.py
@@ -241,6 +241,10 @@ def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Video:
         output = self._F.gaussian_blur_video(self.as_subclass(torch.Tensor), kernel_size=kernel_size, sigma=sigma)
         return Video.wrap_like(self, output)
 
+    def normalize(self, mean: List[float], std: List[float], inplace: bool = False) -> Video:
+        output = self._F.normalize_video(self.as_subclass(torch.Tensor), mean=mean, std=std, inplace=inplace)
+        return Video.wrap_like(self, output)
+
 
 VideoType = Union[torch.Tensor, Video]
 VideoTypeJIT = torch.Tensor
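With the two datapoint methods above in place, `F.normalize` can route through the regular dispatch mechanism instead of special-casing its output type (see the `functional/_misc.py` hunk below). A minimal sketch of the resulting behavior, assuming the prototype API as of this diff:

```python
import torch
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F

# Plain tensors still go straight to the kernel and come back as tensors.
plain = torch.rand(3, 32, 32)
out = F.normalize(plain, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0])
assert type(out) is torch.Tensor

# Image datapoints now dispatch to the new Image.normalize() method, which
# re-wraps the kernel output via Image.wrap_like(), so the datapoint type
# is retained instead of decaying to a plain tensor (the old behavior that
# test_normalize_output_type, removed above, used to assert).
image = datapoints.Image(torch.rand(3, 32, 32))
out = F.normalize(image, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0])
assert type(out) is datapoints.Image
```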
diff --git a/torchvision/prototype/transforms/_color.py b/torchvision/prototype/transforms/_color.py
index 0254dd7c225..0eb20e57764 100644
--- a/torchvision/prototype/transforms/_color.py
+++ b/torchvision/prototype/transforms/_color.py
@@ -82,6 +82,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return output
 
 
+# TODO: This class seems to be untested
 class RandomPhotometricDistort(Transform):
     _transformed_types = (
         datapoints.Image,
@@ -119,15 +120,14 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
     def _permute_channels(
         self, inpt: Union[datapoints.ImageType, datapoints.VideoType], permutation: torch.Tensor
     ) -> Union[datapoints.ImageType, datapoints.VideoType]:
-        if isinstance(inpt, PIL.Image.Image):
+
+        orig_inpt = inpt
+        if isinstance(orig_inpt, PIL.Image.Image):
             inpt = F.pil_to_tensor(inpt)
 
         output = inpt[..., permutation, :, :]
 
-        if isinstance(inpt, (datapoints.Image, datapoints.Video)):
-            output = inpt.wrap_like(inpt, output, color_space=datapoints.ColorSpace.OTHER)  # type: ignore[arg-type]
-
-        elif isinstance(inpt, PIL.Image.Image):
+        if isinstance(orig_inpt, PIL.Image.Image):
             output = F.to_image_pil(output)
 
         return output
diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py
index 59570768160..9d0a00f88c3 100644
--- a/torchvision/prototype/transforms/functional/_misc.py
+++ b/torchvision/prototype/transforms/functional/_misc.py
@@ -60,18 +60,14 @@ def normalize(
 ) -> torch.Tensor:
     if not torch.jit.is_scripting():
         _log_api_usage_once(normalize)
-
-    if isinstance(inpt, (datapoints.Image, datapoints.Video)):
-        inpt = inpt.as_subclass(torch.Tensor)
-    elif not is_simple_tensor(inpt):
-        raise TypeError(
-            f"Input can either be a plain tensor or an `Image` or `Video` datapoint, "
-            f"but got {type(inpt)} instead."
-        )
-
-    # Image or Video type should not be retained after normalization due to unknown data range
-    # Thus we return Tensor for input Image
-    return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
+        return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
+    elif isinstance(inpt, (datapoints.Image, datapoints.Video)):
+        return inpt.normalize(mean=mean, std=std, inplace=inplace)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor or an `Image` or `Video` datapoint, " f"but got {type(inpt)} instead."
+        )
 
 
 def _get_gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
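For reference, the rewritten `normalize` dispatcher above follows the tensor-first pattern used by the other prototype dispatchers. Below is a self-contained sketch of that pattern; `is_simple_tensor` and `normalize_kernel` here are simplified, hypothetical stand-ins for the real torchvision helpers, not the library code itself:

```python
from typing import List

import torch


def is_simple_tensor(inpt) -> bool:
    # Simplified stand-in: a plain torch.Tensor, i.e. not a subclass such
    # as the Image/Video datapoints.
    return type(inpt) is torch.Tensor


def normalize_kernel(t: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
    # Illustrative kernel: channel-wise (t - mean) / std on (..., C, H, W).
    mean_t = torch.tensor(mean, dtype=t.dtype, device=t.device).view(-1, 1, 1)
    std_t = torch.tensor(std, dtype=t.dtype, device=t.device).view(-1, 1, 1)
    return t.sub_(mean_t).div_(std_t) if inplace else (t - mean_t) / std_t


def dispatch_normalize(inpt, mean: List[float], std: List[float], inplace: bool = False):
    # Scripted code cannot rely on subclass dispatch, so the real dispatcher
    # short-circuits to the kernel; eager plain tensors take the same path.
    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return normalize_kernel(inpt, mean=mean, std=std, inplace=inplace)
    # Datapoints route through their own method (Image.normalize /
    # Video.normalize above), which re-wraps the result via wrap_like().
    elif hasattr(inpt, "normalize"):
        return inpt.normalize(mean=mean, std=std, inplace=inplace)
    else:
        raise TypeError(f"Unsupported input type {type(inpt)}")
```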