|
13 | 13 |
|
14 | 14 | import torchvision.prototype.transforms.utils
|
15 | 15 | from common_utils import cache, cpu_and_gpu, needs_cuda, set_rng_seed
|
16 |
| -from prototype_common_utils import assert_close, make_bounding_boxes, parametrized_error_message |
| 16 | +from prototype_common_utils import ( |
| 17 | + assert_close, |
| 18 | + DEFAULT_SQUARE_SPATIAL_SIZE, |
| 19 | + make_bounding_boxes, |
| 20 | + parametrized_error_message, |
| 21 | +) |
17 | 22 | from prototype_transforms_dispatcher_infos import DISPATCHER_INFOS
|
18 | 23 | from prototype_transforms_kernel_infos import KERNEL_INFOS
|
19 | 24 | from torch.utils._pytree import tree_map
|
@@ -538,6 +543,22 @@ def test_convert_dtype_image_tensor_dtype_and_device(info, args_kwargs, device):
|
538 | 543 | assert output.device == input.device
|
539 | 544 |
|
540 | 545 |
|
| 546 | +@pytest.mark.parametrize("device", cpu_and_gpu()) |
| 547 | +@pytest.mark.parametrize("num_channels", [1, 3]) |
| 548 | +def test_normalize_image_tensor_stats(device, num_channels): |
| 549 | + stats = pytest.importorskip("scipy.stats", reason="SciPy is not available") |
| 550 | + |
| 551 | + def assert_samples_from_standard_normal(t): |
| 552 | + p_value = stats.kstest(t.flatten(), cdf="norm", args=(0, 1)).pvalue |
| 553 | + return p_value > 1e-4 |
| 554 | + |
| 555 | + image = torch.rand(num_channels, DEFAULT_SQUARE_SPATIAL_SIZE, DEFAULT_SQUARE_SPATIAL_SIZE) |
| 556 | + mean = image.mean(dim=(1, 2)).tolist() |
| 557 | + std = image.std(dim=(1, 2)).tolist() |
| 558 | + |
| 559 | + assert_samples_from_standard_normal(F.normalize_image_tensor(image, mean, std)) |
| 560 | + |
| 561 | + |
541 | 562 | # TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in
|
542 | 563 | # `prototype_transforms_kernel_infos.py`
|
543 | 564 |
|
|
0 commit comments