From 69246ef4d6649783a04f67dce1b14628611c5087 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 20 Jan 2023 10:43:06 +0100 Subject: [PATCH 1/3] add reference test for normalize_image_tensor --- test/prototype_transforms_kernel_infos.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index ded888a4a00..e1420d1cc7b 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -2232,6 +2232,22 @@ def sample_inputs_normalize_image_tensor(): yield ArgsKwargs(image_loader, mean=mean, std=std) +def reference_normalize_image_tensor(image, mean, std, inplace=False): + mean = torch.tensor(mean).view(-1, 1, 1) + std = torch.tensor(std).view(-1, 1, 1) + + sub = torch.Tensor.sub_ if inplace else torch.Tensor.sub + return sub(image, mean).div_(std) + + +def reference_inputs_normalize_image_tensor(): + yield ArgsKwargs( + make_image_loader(size=(32, 32), color_space=datapoints.ColorSpace.RGB, extra_dims=[1]), + mean=[0.5, 0.5, 0.5], + std=[1.0, 1.0, 1.0], + ) + + def sample_inputs_normalize_video(): mean, std = _NORMALIZE_MEANS_STDS[0] for video_loader in make_video_loaders( @@ -2246,6 +2262,8 @@ def sample_inputs_normalize_video(): F.normalize_image_tensor, kernel_name="normalize_image_tensor", sample_inputs_fn=sample_inputs_normalize_image_tensor, + reference_fn=reference_normalize_image_tensor, + reference_inputs_fn=reference_inputs_normalize_image_tensor, test_marks=[ xfail_jit_python_scalar_arg("mean"), xfail_jit_python_scalar_arg("std"), From 09b7fb30672e006c9e8758a9b6de8b8224415d70 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 23 Jan 2023 10:37:39 +0100 Subject: [PATCH 2/3] port stats test --- test/test_prototype_transforms_functional.py | 23 +++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index bc299fd1f50..3387810eff8 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -13,7 +13,12 @@ import torchvision.prototype.transforms.utils from common_utils import cache, cpu_and_gpu, needs_cuda, set_rng_seed -from prototype_common_utils import assert_close, make_bounding_boxes, parametrized_error_message +from prototype_common_utils import ( + assert_close, + DEFAULT_SQUARE_SPATIAL_SIZE, + make_bounding_boxes, + parametrized_error_message, +) from prototype_transforms_dispatcher_infos import DISPATCHER_INFOS from prototype_transforms_kernel_infos import KERNEL_INFOS from torch.utils._pytree import tree_map @@ -501,6 +506,22 @@ def test_convert_dtype_image_tensor_dtype_and_device(info, args_kwargs, device): assert output.device == input.device +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("num_channels", [3]) +def test_normalize_image_tensor_stats(device, num_channels): + stats = pytest.importorskip("scipy.stats", reason="SciPy is not available") + + def assert_samples_from_standard_normal(t): + p_value = stats.kstest(t.flatten(), cdf="norm", args=(0, 1)).pvalue + return p_value > 1e-4 + + image = torch.rand(num_channels, DEFAULT_SQUARE_SPATIAL_SIZE, DEFAULT_SQUARE_SPATIAL_SIZE) + mean = image.mean(dim=(1, 2)).tolist() + std = image.std(dim=(1, 2)).tolist() + + assert_samples_from_standard_normal(F.normalize_image_tensor(image, mean, std)) + + # TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in # `prototype_transforms_kernel_infos.py` From c2e0b6c854d30c8aaf65d4e37898f2ddca572833 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 23 Jan 2023 11:31:33 +0100 Subject: [PATCH 3/3] Update test/test_prototype_transforms_functional.py --- test/test_prototype_transforms_functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 6bb70c44e04..7f0781fb010 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -544,7 +544,7 @@ def test_convert_dtype_image_tensor_dtype_and_device(info, args_kwargs, device): @pytest.mark.parametrize("device", cpu_and_gpu()) -@pytest.mark.parametrize("num_channels", [3]) +@pytest.mark.parametrize("num_channels", [1, 3]) def test_normalize_image_tensor_stats(device, num_channels): stats = pytest.importorskip("scipy.stats", reason="SciPy is not available")