Skip to content

Commit 6fcf0a2

Browse files
authored
Fix resize when size == small_edge_size and max_size isn't None (#5409)
* Fix resize when size == small_edge_size and max_size isn't None * Better test name
1 parent 26fe8fa commit 6fcf0a2

File tree

3 files changed

+21
-7
lines changed

3 files changed

+21
-7
lines changed

test/test_transforms.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,19 @@ def test_resize_antialias_error():
440440
t(img)
441441

442442

443+
@pytest.mark.parametrize("height, width", ((32, 64), (64, 32)))
444+
def test_resize_size_equals_small_edge_size(height, width):
445+
# Non-regression test for https://github.com/pytorch/vision/issues/5405
446+
# max_size used to be ignored if size == small_edge_size
447+
max_size = 40
448+
img = Image.new("RGB", size=(width, height), color=127)
449+
450+
small_edge = min(height, width)
451+
t = transforms.Resize(small_edge, max_size=max_size)
452+
result = t(img)
453+
assert max(result.size) == max_size
454+
455+
443456
class TestPad:
444457
def test_pad(self):
445458
height = random.randint(10, 32) * 2

torchvision/transforms/functional_pil.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -240,9 +240,6 @@ def resize(
240240
w, h = img.size
241241

242242
short, long = (w, h) if w <= h else (h, w)
243-
if short == size:
244-
return img
245-
246243
new_short, new_long = size, int(size * long / short)
247244

248245
if max_size is not None:
@@ -255,7 +252,11 @@ def resize(
255252
new_short, new_long = int(max_size * new_short / new_long), max_size
256253

257254
new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
258-
return img.resize((new_w, new_h), interpolation)
255+
256+
if (w, h) == (new_w, new_h):
257+
return img
258+
else:
259+
return img.resize((new_w, new_h), interpolation)
259260
else:
260261
if max_size is not None:
261262
raise ValueError(

torchvision/transforms/functional_tensor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -457,9 +457,6 @@ def resize(
457457
short, long = (w, h) if w <= h else (h, w)
458458
requested_new_short = size if isinstance(size, int) else size[0]
459459

460-
if short == requested_new_short:
461-
return img
462-
463460
new_short, new_long = requested_new_short, int(requested_new_short * long / short)
464461

465462
if max_size is not None:
@@ -473,6 +470,9 @@ def resize(
473470

474471
new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
475472

473+
if (w, h) == (new_w, new_h):
474+
return img
475+
476476
else: # specified both h and w
477477
new_w, new_h = size[1], size[0]
478478

0 commit comments

Comments
 (0)