Skip to content

Commit cb7d27a

Browse files
committed
Fix bug on autocontrast when min==max
1 parent 47281bb commit cb7d27a

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

test/test_functional_tensor.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -820,6 +820,18 @@ def test_autocontrast(device, dtype, channels):
820820
)
821821

822822

823+
@pytest.mark.parametrize("device", cpu_and_gpu())
824+
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
825+
@pytest.mark.parametrize("channels", [1, 3])
826+
def test_autocontrast_equal_minmax(device, dtype, channels):
827+
a = _create_data_batch(32, 32, num_samples=1, channels=channels, device=device)
828+
a = a / 2.0 + 0.3
829+
assert (F.autocontrast(a)[0] == F.autocontrast(a[0])).all()
830+
831+
a[0, 0] = 0.7
832+
assert (F.autocontrast(a)[0] == F.autocontrast(a[0])).all()
833+
834+
823835
@pytest.mark.parametrize("device", cpu_and_gpu())
824836
@pytest.mark.parametrize("channels", [1, 3])
825837
def test_equalize(device, channels):

torchvision/transforms/functional_tensor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -937,10 +937,10 @@ def autocontrast(img: Tensor) -> Tensor:
937937

938938
minimum = img.amin(dim=(-2, -1), keepdim=True).to(dtype)
939939
maximum = img.amax(dim=(-2, -1), keepdim=True).to(dtype)
940-
eq_idxs = torch.where(minimum == maximum)[0]
941-
minimum[eq_idxs] = 0
942-
maximum[eq_idxs] = bound
943940
scale = bound / (maximum - minimum)
941+
eq_idxs = torch.isfinite(scale).logical_not()
942+
minimum[eq_idxs] = 0
943+
scale[eq_idxs] = 1
944944

945945
return ((img - minimum) * scale).clamp(0, bound).to(img.dtype)
946946

0 commit comments

Comments
 (0)