diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index ded888a4a00..e1420d1cc7b 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -2232,6 +2232,22 @@ def sample_inputs_normalize_image_tensor(): yield ArgsKwargs(image_loader, mean=mean, std=std) +def reference_normalize_image_tensor(image, mean, std, inplace=False): + mean = torch.tensor(mean).view(-1, 1, 1) + std = torch.tensor(std).view(-1, 1, 1) + + sub = torch.Tensor.sub_ if inplace else torch.Tensor.sub + return sub(image, mean).div_(std) + + +def reference_inputs_normalize_image_tensor(): + yield ArgsKwargs( + make_image_loader(size=(32, 32), color_space=datapoints.ColorSpace.RGB, extra_dims=[1]), + mean=[0.5, 0.5, 0.5], + std=[1.0, 1.0, 1.0], + ) + + def sample_inputs_normalize_video(): mean, std = _NORMALIZE_MEANS_STDS[0] for video_loader in make_video_loaders( @@ -2246,6 +2262,8 @@ def sample_inputs_normalize_video(): F.normalize_image_tensor, kernel_name="normalize_image_tensor", sample_inputs_fn=sample_inputs_normalize_image_tensor, + reference_fn=reference_normalize_image_tensor, + reference_inputs_fn=reference_inputs_normalize_image_tensor, test_marks=[ xfail_jit_python_scalar_arg("mean"), xfail_jit_python_scalar_arg("std"), diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 102f78e6e11..7f0781fb010 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -13,7 +13,12 @@ import torchvision.prototype.transforms.utils from common_utils import cache, cpu_and_gpu, needs_cuda, set_rng_seed -from prototype_common_utils import assert_close, make_bounding_boxes, parametrized_error_message +from prototype_common_utils import ( + assert_close, + DEFAULT_SQUARE_SPATIAL_SIZE, + make_bounding_boxes, + parametrized_error_message, +) from prototype_transforms_dispatcher_infos import DISPATCHER_INFOS from prototype_transforms_kernel_infos import KERNEL_INFOS from torch.utils._pytree import tree_map @@ -538,6 +543,22 @@ def test_convert_dtype_image_tensor_dtype_and_device(info, args_kwargs, device): assert output.device == input.device +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("num_channels", [1, 3]) +def test_normalize_image_tensor_stats(device, num_channels): + stats = pytest.importorskip("scipy.stats", reason="SciPy is not available") + + def assert_samples_from_standard_normal(t): + p_value = stats.kstest(t.flatten(), cdf="norm", args=(0, 1)).pvalue + return p_value > 1e-4 + + image = torch.rand(num_channels, DEFAULT_SQUARE_SPATIAL_SIZE, DEFAULT_SQUARE_SPATIAL_SIZE) + mean = image.mean(dim=(1, 2)).tolist() + std = image.std(dim=(1, 2)).tolist() + + assert_samples_from_standard_normal(F.normalize_image_tensor(image, mean, std)) + + # TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in # `prototype_transforms_kernel_infos.py`