From cb7d27aaf0e4ec6193cfda19a514442d95dcaf3c Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Sun, 28 Nov 2021 10:38:51 +0000 Subject: [PATCH 1/2] Fix bug on autocontrast when `min==max` --- test/test_functional_tensor.py | 12 ++++++++++++ torchvision/transforms/functional_tensor.py | 6 +++--- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 24a7523b62a..badf4c5365d 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -820,6 +820,18 @@ def test_autocontrast(device, dtype, channels): ) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64)) +@pytest.mark.parametrize("channels", [1, 3]) +def test_autocontrast_equal_minmax(device, dtype, channels): + a = _create_data_batch(32, 32, num_samples=1, channels=channels, device=device) + a = a / 2.0 + 0.3 + assert (F.autocontrast(a)[0] == F.autocontrast(a[0])).all() + + a[0, 0] = 0.7 + assert (F.autocontrast(a)[0] == F.autocontrast(a[0])).all() + + @pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("channels", [1, 3]) def test_equalize(device, channels): diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 09ae726931c..866ac047298 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -937,10 +937,10 @@ def autocontrast(img: Tensor) -> Tensor: minimum = img.amin(dim=(-2, -1), keepdim=True).to(dtype) maximum = img.amax(dim=(-2, -1), keepdim=True).to(dtype) - eq_idxs = torch.where(minimum == maximum)[0] - minimum[eq_idxs] = 0 - maximum[eq_idxs] = bound scale = bound / (maximum - minimum) + eq_idxs = torch.isfinite(scale).logical_not() + minimum[eq_idxs] = 0 + scale[eq_idxs] = 1 return ((img - minimum) * scale).clamp(0, bound).to(img.dtype) From 35de04cd87d4b6eac361548093faed0644b6fece Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Sun, 28 Nov 2021 11:13:45 +0000 Subject: [PATCH 2/2] Adding PIL vs TorchVision test for min==max --- test/test_transforms.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/test_transforms.py b/test/test_transforms.py index 876452abab4..512a343ee59 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -573,6 +573,15 @@ def test_randomness(fn, trans, kwargs, seed, p): trans(**kwargs).__repr__() +def test_autocontrast_equal_minmax(): + img_tensor = torch.tensor([[[10]], [[128]], [[245]]], dtype=torch.uint8).expand(3, 32, 32) + img_pil = F.to_pil_image(img_tensor) + + img_tensor = F.autocontrast(img_tensor) + img_pil = F.autocontrast(img_pil) + torch.testing.assert_close(img_tensor, F.pil_to_tensor(img_pil)) + + class TestToPil: def _get_1_channel_tensor_various_types(): img_data_float = torch.Tensor(1, 4, 4).uniform_()