diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index 40a51775a09..8f923475664 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -850,6 +850,18 @@ def test_autocontrast(device, dtype, channels):
     )
 
 
+@pytest.mark.parametrize("device", cpu_and_gpu())
+@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
+@pytest.mark.parametrize("channels", [1, 3])
+def test_autocontrast_equal_minmax(device, dtype, channels):
+    a = _create_data_batch(32, 32, num_samples=1, channels=channels, device=device)
+    a = a / 2.0 + 0.3
+    assert (F.autocontrast(a)[0] == F.autocontrast(a[0])).all()
+
+    a[0, 0] = 0.7
+    assert (F.autocontrast(a)[0] == F.autocontrast(a[0])).all()
+
+
 @pytest.mark.parametrize("device", cpu_and_gpu())
 @pytest.mark.parametrize("channels", [1, 3])
 def test_equalize(device, channels):
diff --git a/test/test_transforms.py b/test/test_transforms.py
index 876452abab4..512a343ee59 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -573,6 +573,15 @@ def test_randomness(fn, trans, kwargs, seed, p):
     trans(**kwargs).__repr__()
 
 
+def test_autocontrast_equal_minmax():
+    img_tensor = torch.tensor([[[10]], [[128]], [[245]]], dtype=torch.uint8).expand(3, 32, 32)
+    img_pil = F.to_pil_image(img_tensor)
+
+    img_tensor = F.autocontrast(img_tensor)
+    img_pil = F.autocontrast(img_pil)
+    torch.testing.assert_close(img_tensor, F.pil_to_tensor(img_pil))
+
+
 class TestToPil:
     def _get_1_channel_tensor_various_types():
         img_data_float = torch.Tensor(1, 4, 4).uniform_()
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index dfb32a41adf..4e20c19e45f 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -945,10 +945,10 @@ def autocontrast(img: Tensor) -> Tensor:
 
     minimum = img.amin(dim=(-2, -1), keepdim=True).to(dtype)
     maximum = img.amax(dim=(-2, -1), keepdim=True).to(dtype)
-    eq_idxs = torch.where(minimum == maximum)[0]
-    minimum[eq_idxs] = 0
-    maximum[eq_idxs] = bound
     scale = bound / (maximum - minimum)
+    eq_idxs = torch.isfinite(scale).logical_not()
+    minimum[eq_idxs] = 0
+    scale[eq_idxs] = 1
 
     return ((img - minimum) * scale).clamp(0, bound).to(img.dtype)
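
A note on the `functional_tensor.py` change: the old `torch.where(minimum == maximum)[0]` returns indices along the first dimension only, so on a batched `(B, C, H, W)` input a single constant channel caused the whole sample's `minimum`/`maximum` to be reset, whereas the new elementwise non-finite mask on `scale` touches exactly the degenerate channels. Below is a minimal standalone sketch of the difference (not the torchvision source verbatim: it is simplified to float images with `bound = 1.0`, and the real function's `dtype` handling is omitted):

```python
import torch

def autocontrast_old(img):
    # Previous logic: index-based reset along dim 0 only.
    minimum = img.amin(dim=(-2, -1), keepdim=True)
    maximum = img.amax(dim=(-2, -1), keepdim=True)
    eq_idxs = torch.where(minimum == maximum)[0]  # first-dim indices
    minimum[eq_idxs] = 0
    maximum[eq_idxs] = 1.0
    scale = 1.0 / (maximum - minimum)
    return ((img - minimum) * scale).clamp(0, 1.0)

def autocontrast_new(img):
    # Patched logic: elementwise mask over non-finite scale entries.
    minimum = img.amin(dim=(-2, -1), keepdim=True)
    maximum = img.amax(dim=(-2, -1), keepdim=True)
    scale = 1.0 / (maximum - minimum)  # inf where min == max
    eq_idxs = torch.isfinite(scale).logical_not()
    minimum[eq_idxs] = 0
    scale[eq_idxs] = 1
    return ((img - minimum) * scale).clamp(0, 1.0)

# One-sample batch, three channels; channel 0 is constant (min == max).
img = torch.rand(1, 3, 32, 32) / 2.0 + 0.3
img[0, 0] = 0.7

# Old: min/max have shape (1, 3, 1, 1), so torch.where(...)[0] yields the
# batch index 0 and all three channels of the sample are reset; the
# batched result diverges from the unbatched one.
print((autocontrast_old(img)[0] == autocontrast_old(img[0])).all())  # tensor(False)
print((autocontrast_new(img)[0] == autocontrast_new(img[0])).all())  # tensor(True)
```

The two new tests encode these invariants directly: batched vs. unbatched agreement on the tensor path, and tensor vs. PIL agreement on a uint8 image whose channels are each constant.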