Skip to content

Commit fb66035

Browse files
datumboxfacebook-github-bot
authored andcommitted
[fbsync] Fix bug on autocontrast when min==max (#4999)
Summary: * Fix bug on autocontrast when `min==max` * Adding PIL vs TorchVision test for min==max Reviewed By: NicolasHug Differential Revision: D32759202 fbshipit-source-id: 609b591236fce73ef74fa9f00707af111f8154e7
1 parent 1c175d8 commit fb66035

File tree

3 files changed

+24
-3
lines changed

3 files changed

+24
-3
lines changed

test/test_functional_tensor.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -850,6 +850,18 @@ def test_autocontrast(device, dtype, channels):
850850
)
851851

852852

853+
@pytest.mark.parametrize("device", cpu_and_gpu())
854+
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
855+
@pytest.mark.parametrize("channels", [1, 3])
856+
def test_autocontrast_equal_minmax(device, dtype, channels):
857+
a = _create_data_batch(32, 32, num_samples=1, channels=channels, device=device)
858+
a = a / 2.0 + 0.3
859+
assert (F.autocontrast(a)[0] == F.autocontrast(a[0])).all()
860+
861+
a[0, 0] = 0.7
862+
assert (F.autocontrast(a)[0] == F.autocontrast(a[0])).all()
863+
864+
853865
@pytest.mark.parametrize("device", cpu_and_gpu())
854866
@pytest.mark.parametrize("channels", [1, 3])
855867
def test_equalize(device, channels):

test/test_transforms.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,15 @@ def test_randomness(fn, trans, kwargs, seed, p):
573573
trans(**kwargs).__repr__()
574574

575575

576+
def test_autocontrast_equal_minmax():
577+
img_tensor = torch.tensor([[[10]], [[128]], [[245]]], dtype=torch.uint8).expand(3, 32, 32)
578+
img_pil = F.to_pil_image(img_tensor)
579+
580+
img_tensor = F.autocontrast(img_tensor)
581+
img_pil = F.autocontrast(img_pil)
582+
torch.testing.assert_close(img_tensor, F.pil_to_tensor(img_pil))
583+
584+
576585
class TestToPil:
577586
def _get_1_channel_tensor_various_types():
578587
img_data_float = torch.Tensor(1, 4, 4).uniform_()

torchvision/transforms/functional_tensor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -945,10 +945,10 @@ def autocontrast(img: Tensor) -> Tensor:
945945

946946
minimum = img.amin(dim=(-2, -1), keepdim=True).to(dtype)
947947
maximum = img.amax(dim=(-2, -1), keepdim=True).to(dtype)
948-
eq_idxs = torch.where(minimum == maximum)[0]
949-
minimum[eq_idxs] = 0
950-
maximum[eq_idxs] = bound
951948
scale = bound / (maximum - minimum)
949+
eq_idxs = torch.isfinite(scale).logical_not()
950+
minimum[eq_idxs] = 0
951+
scale[eq_idxs] = 1
952952

953953
return ((img - minimum) * scale).clamp(0, bound).to(img.dtype)
954954

0 commit comments

Comments
 (0)