
Commit d69e4b1

Adapted T.Resize and F.resize with a test

1 parent 7936258

File tree

3 files changed: +48 −11 lines changed

3 files changed

+48
-11
lines changed

test/test_transforms_tensor.py

Lines changed: 28 additions & 0 deletions
@@ -2,6 +2,7 @@
 from torchvision import transforms as T
 from torchvision.transforms import functional as F
 from PIL import Image
+from PIL.Image import NEAREST, BILINEAR, BICUBIC

 import numpy as np

@@ -217,6 +218,33 @@ def test_ten_crop(self):
             "ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
         )

+    def test_resize(self):
+        tensor, _ = self._create_data(height=34, width=36)
+        script_fn = torch.jit.script(F.resize)
+
+        for dt in [None, torch.float32, torch.float64]:
+            if dt is not None:
+                # This is a trivial cast to float of uint8 data to test all cases
+                tensor = tensor.to(dt)
+            for size in [32, [32, ], [32, 32], (32, 32), ]:
+                for interpolation in [BILINEAR, BICUBIC, NEAREST]:
+
+                    resized_tensor = F.resize(tensor, size=size, interpolation=interpolation)
+
+                    if isinstance(size, int):
+                        script_size = [size, ]
+                    else:
+                        script_size = size
+
+                    s_resized_tensor = script_fn(tensor, size=script_size, interpolation=interpolation)
+                    self.assertTrue(s_resized_tensor.equal(resized_tensor))
+
+                    transform = T.Resize(size=script_size, interpolation=interpolation)
+                    resized_tensor = transform(tensor)
+                    script_transform = torch.jit.script(transform)
+                    s_resized_tensor = script_transform(tensor)
+                    self.assertTrue(s_resized_tensor.equal(resized_tensor))
+

 if __name__ == '__main__':
     unittest.main()
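The new test exercises one torchscript constraint worth calling out: F.resize is annotated size: List[int], so the scripted function rejects a bare int even though eager mode accepts one. A minimal sketch of the difference, assuming a build that includes this commit (the 3x34x36 uint8 tensor mirrors the test data and is illustrative only):

import torch
from torchvision.transforms import functional as F

# An illustrative uint8 image in [..., H, W] layout, like the test data.
img = torch.randint(0, 256, (3, 34, 36), dtype=torch.uint8)

# Eager mode accepts a bare int for size.
out_eager = F.resize(img, size=32)

# The scripted function is typed size: List[int], so the int must be
# wrapped in a length-1 list, just as the test builds script_size.
scripted_resize = torch.jit.script(F.resize)
out_scripted = scripted_resize(img, size=[32])

assert out_eager.equal(out_scripted)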

torchvision/transforms/functional.py

Lines changed: 3 additions & 1 deletion
@@ -322,7 +322,9 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor:
             (h, w), the output size will be matched to this. If size is an int,
             the smaller edge of the image will be matched to this number maintaining
             the aspect ratio. i.e, if height > width, then image will be rescaled to
-            :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`
+            :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`.
+            In torchscript mode size as a single int is not supported, use a tuple or
+            list of length 1: ``[size, ]``.
         interpolation (int, optional): Desired interpolation. Default is bilinear.

     Returns:
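As a quick worked check of the rule this docstring describes: for a 34x36 image, the height (34) is the smaller edge, so with size=32 the height is matched to 32 and the width follows the aspect ratio, int(32 * 36 / 34) = 33. A small sketch; the truncating division is an assumption about the resize implementation, not something the docstring states:

import torch
from torchvision.transforms import functional as F

img = torch.rand(3, 34, 36)     # H=34 is the smaller edge
out = F.resize(img, size=[32])  # length-1 list: valid in eager and scripted mode

print(out.shape)  # torch.Size([3, 32, 33])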

torchvision/transforms/transforms.py

Lines changed: 17 additions & 10 deletions
@@ -2,7 +2,7 @@
 import numbers
 import random
 import warnings
-from collections.abc import Sequence, Iterable
+from collections.abc import Sequence
 from typing import Tuple, List, Optional

 import numpy as np
@@ -209,31 +209,38 @@ def __repr__(self):
         return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)


-class Resize(object):
-    """Resize the input PIL Image to the given size.
+class Resize(torch.nn.Module):
+    """Resize the input image to the given size.
+    The image can be a PIL Image or a torch Tensor, in which case it is expected
+    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

     Args:
         size (sequence or int): Desired output size. If size is a sequence like
             (h, w), output size will be matched to this. If size is an int,
             smaller edge of the image will be matched to this number.
             i.e, if height > width, then image will be rescaled to
-            (size * height / width, size)
-        interpolation (int, optional): Desired interpolation. Default is
-            ``PIL.Image.BILINEAR``
+            (size * height / width, size).
+            In torchscript mode size as a single int is not supported, use a tuple or
+            list of length 1: ``[size, ]``.
+        interpolation (int, optional): Desired interpolation. Default is ``PIL.Image.BILINEAR``
     """

     def __init__(self, size, interpolation=Image.BILINEAR):
-        assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
+        super().__init__()
+        if not isinstance(size, (int, Sequence)):
+            raise TypeError("Size should be int or sequence. Got {}".format(type(size)))
+        if isinstance(size, Sequence) and len(size) not in (1, 2):
+            raise ValueError("If size is a sequence, it should have 1 or 2 values")
         self.size = size
         self.interpolation = interpolation

-    def __call__(self, img):
+    def forward(self, img):
         """
         Args:
-            img (PIL Image): Image to be scaled.
+            img (PIL Image or Tensor): Image to be scaled.

         Returns:
-            PIL Image: Rescaled image.
+            PIL Image or Tensor: Rescaled image.
         """
         return F.resize(img, self.size, self.interpolation)
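The switch from object to torch.nn.Module is what makes Resize scriptable, and it also lets the transform compose with other modules. A minimal sketch of the pattern this enables; the nn.Sequential pipeline is illustrative, not part of the commit:

import torch
import torch.nn as nn
from torchvision import transforms as T

# Because Resize is now an nn.Module, it can live inside nn.Sequential
# and the whole pipeline can be compiled with torch.jit.script.
pipeline = nn.Sequential(
    T.Resize([32, 32]),  # sequence form works in both eager and scripted mode
)
scripted_pipeline = torch.jit.script(pipeline)

img = torch.randint(0, 256, (3, 34, 36), dtype=torch.uint8)
assert scripted_pipeline(img).shape == (3, 32, 32)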