Skip to content

Commit c206a47

Browse files
authored
add reference test for normalize_image_tensor (#7119)
1 parent d2d448c commit c206a47

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

test/prototype_transforms_kernel_infos.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2232,6 +2232,22 @@ def sample_inputs_normalize_image_tensor():
22322232
yield ArgsKwargs(image_loader, mean=mean, std=std)
22332233

22342234

2235+
def reference_normalize_image_tensor(image, mean, std, inplace=False):
2236+
mean = torch.tensor(mean).view(-1, 1, 1)
2237+
std = torch.tensor(std).view(-1, 1, 1)
2238+
2239+
sub = torch.Tensor.sub_ if inplace else torch.Tensor.sub
2240+
return sub(image, mean).div_(std)
2241+
2242+
2243+
def reference_inputs_normalize_image_tensor():
2244+
yield ArgsKwargs(
2245+
make_image_loader(size=(32, 32), color_space=datapoints.ColorSpace.RGB, extra_dims=[1]),
2246+
mean=[0.5, 0.5, 0.5],
2247+
std=[1.0, 1.0, 1.0],
2248+
)
2249+
2250+
22352251
def sample_inputs_normalize_video():
22362252
mean, std = _NORMALIZE_MEANS_STDS[0]
22372253
for video_loader in make_video_loaders(
@@ -2246,6 +2262,8 @@ def sample_inputs_normalize_video():
22462262
F.normalize_image_tensor,
22472263
kernel_name="normalize_image_tensor",
22482264
sample_inputs_fn=sample_inputs_normalize_image_tensor,
2265+
reference_fn=reference_normalize_image_tensor,
2266+
reference_inputs_fn=reference_inputs_normalize_image_tensor,
22492267
test_marks=[
22502268
xfail_jit_python_scalar_arg("mean"),
22512269
xfail_jit_python_scalar_arg("std"),

test/test_prototype_transforms_functional.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@
1313

1414
import torchvision.prototype.transforms.utils
1515
from common_utils import cache, cpu_and_gpu, needs_cuda, set_rng_seed
16-
from prototype_common_utils import assert_close, make_bounding_boxes, parametrized_error_message
16+
from prototype_common_utils import (
17+
assert_close,
18+
DEFAULT_SQUARE_SPATIAL_SIZE,
19+
make_bounding_boxes,
20+
parametrized_error_message,
21+
)
1722
from prototype_transforms_dispatcher_infos import DISPATCHER_INFOS
1823
from prototype_transforms_kernel_infos import KERNEL_INFOS
1924
from torch.utils._pytree import tree_map
@@ -538,6 +543,22 @@ def test_convert_dtype_image_tensor_dtype_and_device(info, args_kwargs, device):
538543
assert output.device == input.device
539544

540545

546+
@pytest.mark.parametrize("device", cpu_and_gpu())
547+
@pytest.mark.parametrize("num_channels", [1, 3])
548+
def test_normalize_image_tensor_stats(device, num_channels):
549+
stats = pytest.importorskip("scipy.stats", reason="SciPy is not available")
550+
551+
def assert_samples_from_standard_normal(t):
552+
p_value = stats.kstest(t.flatten(), cdf="norm", args=(0, 1)).pvalue
553+
return p_value > 1e-4
554+
555+
image = torch.rand(num_channels, DEFAULT_SQUARE_SPATIAL_SIZE, DEFAULT_SQUARE_SPATIAL_SIZE)
556+
mean = image.mean(dim=(1, 2)).tolist()
557+
std = image.std(dim=(1, 2)).tolist()
558+
559+
assert_samples_from_standard_normal(F.normalize_image_tensor(image, mean, std))
560+
561+
541562
# TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in
542563
# `prototype_transforms_kernel_infos.py`
543564

0 commit comments

Comments
 (0)