Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,19 @@ def test_resize_antialias_error():
t(img)


@pytest.mark.parametrize("height, width", ((32, 64), (64, 32)))
def test_resize_size_equals_small_edge_size(height, width):
# Non-regression test for https://github.com/pytorch/vision/issues/5405
# max_size used to be ignored if size == small_edge_size
max_size = 40
img = Image.new("RGB", size=(width, height), color=127)

small_edge = min(height, width)
t = transforms.Resize(small_edge, max_size=max_size)
result = t(img)
assert max(result.size) == max_size


class TestPad:
def test_pad(self):
height = random.randint(10, 32) * 2
Expand Down
9 changes: 5 additions & 4 deletions torchvision/transforms/functional_pil.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,9 +240,6 @@ def resize(
w, h = img.size

short, long = (w, h) if w <= h else (h, w)
if short == size:
return img

new_short, new_long = size, int(size * long / short)

if max_size is not None:
Expand All @@ -255,7 +252,11 @@ def resize(
new_short, new_long = int(max_size * new_short / new_long), max_size

new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
return img.resize((new_w, new_h), interpolation)

if (w, h) == (new_w, new_h):
return img
else:
return img.resize((new_w, new_h), interpolation)
else:
if max_size is not None:
raise ValueError(
Expand Down
6 changes: 3 additions & 3 deletions torchvision/transforms/functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,9 +457,6 @@ def resize(
short, long = (w, h) if w <= h else (h, w)
requested_new_short = size if isinstance(size, int) else size[0]

if short == requested_new_short:
return img

new_short, new_long = requested_new_short, int(requested_new_short * long / short)

if max_size is not None:
Expand All @@ -473,6 +470,9 @@ def resize(

new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)

if (w, h) == (new_w, new_h):
return img

else: # specified both h and w
new_w, new_h = size[1], size[0]

Expand Down