Skip to content

Commit 8088cc9

Browse files
authored
Fixed bug with Resize.size if input is integer (#2869)
1 parent d1e134c commit 8088cc9

File tree

3 files changed

+21
-1
lines changed

3 files changed

+21
-1
lines changed

test/test_transforms.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,11 @@ def test_resize(self):
219219
width = random.randint(24, 32) * 2
220220
osize = random.randint(5, 12) * 2
221221

222+
# TODO: Check output size check for bug-fix, improve this later
223+
t = transforms.Resize(osize)
224+
self.assertTrue(isinstance(t.size, int))
225+
self.assertEqual(t.size, osize)
226+
222227
img = torch.ones(3, height, width)
223228
result = transforms.Compose([
224229
transforms.ToPILImage(),

test/test_transforms_tensor.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,17 @@ def test_ten_crop(self):
280280
)
281281

282282
def test_resize(self):
283+
284+
# TODO: Minimal check for bug-fix, improve this later
285+
x = torch.rand(3, 32, 46)
286+
t = T.Resize(size=38)
287+
y = t(x)
288+
# If size is an int, smaller edge of the image will be matched to this number.
289+
# i.e, if height > width, then image will be rescaled to (size * height / width, size).
290+
self.assertTrue(isinstance(y, torch.Tensor))
291+
self.assertEqual(y.shape[1], 38)
292+
self.assertEqual(y.shape[2], int(38 * 46 / 32))
293+
283294
tensor, _ = self._create_data(height=34, width=36, device=self.device)
284295
batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
285296
script_fn = torch.jit.script(F.resize)

torchvision/transforms/transforms.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,11 @@ class Resize(torch.nn.Module):
249249

250250
def __init__(self, size, interpolation=Image.BILINEAR):
251251
super().__init__()
252-
self.size = _setup_size(size, error_msg="If size is a sequence, it should have 2 values")
252+
if not isinstance(size, (int, Sequence)):
253+
raise TypeError("Size should be int or sequence. Got {}".format(type(size)))
254+
if isinstance(size, Sequence) and len(size) not in (1, 2):
255+
raise ValueError("If size is a sequence, it should have 1 or 2 values")
256+
self.size = size
253257
self.interpolation = interpolation
254258

255259
def forward(self, img):

0 commit comments

Comments
 (0)