 import numbers
 import random
 import warnings
-from collections.abc import Sequence, Iterable
+from collections.abc import Sequence
 from typing import Tuple, List, Optional
 
 import numpy as np
@@ -209,31 +209,38 @@ def __repr__(self):
         return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
 
 
-class Resize(object):
-    """Resize the input PIL Image to the given size.
+class Resize(torch.nn.Module):
+    """Resize the input image to the given size.
+    The image can be a PIL Image or a torch Tensor, in which case it is expected
+    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
 
     Args:
         size (sequence or int): Desired output size. If size is a sequence like
             (h, w), output size will be matched to this. If size is an int,
             smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
-            (size * height / width, size)
-        interpolation (int, optional): Desired interpolation. Default is
-            ``PIL.Image.BILINEAR``
+            (size * height / width, size).
+            In torchscript mode size as a single int is not supported, use a tuple or
+            list of length 1: ``[size, ]``.
+        interpolation (int, optional): Desired interpolation. Default is ``PIL.Image.BILINEAR``
     """
 
     def __init__(self, size, interpolation=Image.BILINEAR):
-        assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
+        super().__init__()
+        if not isinstance(size, (int, Sequence)):
+            raise TypeError("Size should be int or sequence. Got {}".format(type(size)))
+        if isinstance(size, Sequence) and len(size) not in (1, 2):
+            raise ValueError("If size is a sequence, it should have 1 or 2 values")
         self.size = size
         self.interpolation = interpolation
 
-    def __call__(self, img):
+    def forward(self, img):
         """
         Args:
-            img (PIL Image): Image to be scaled.
+            img (PIL Image or Tensor): Image to be scaled.
 
         Returns:
-            PIL Image: Rescaled image.
+            PIL Image or Tensor: Rescaled image.
         """
         return F.resize(img, self.size, self.interpolation)
 
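A minimal usage sketch of the rewritten transform, assuming a torchvision build that includes this change; the concrete image sizes, tensor shapes, and the [256] size value below are illustrative only. It shows the two points the new docstring makes: the module accepts both PIL Images and [..., H, W] tensors, and it can be scripted when size is given as a sequence rather than a bare int.

import torch
from PIL import Image
from torchvision import transforms

# Resize is now an nn.Module: it works eagerly on PIL Images and on tensors, and
# can be passed to torch.jit.script. For torchscript, size must be a sequence of
# length 1 or 2 (e.g. [256]), not a bare int.
resize = transforms.Resize([256])

pil_img = Image.new("RGB", (640, 480))   # illustrative PIL input (W=640, H=480)
tensor_img = torch.rand(3, 480, 640)     # illustrative tensor input with [C, H, W] shape

out_pil = resize(pil_img)                # PIL Image, smaller edge rescaled to 256
out_tensor = resize(tensor_img)          # tensor of shape [3, 256, 341]

scripted = torch.jit.script(resize)      # scripting works for the tensor code path
out_scripted = scripted(tensor_img)

Because forward() just delegates to F.resize, eager behaviour for PIL inputs is unchanged; the nn.Module base class is what makes the transform scriptable and composable inside nn.Sequential.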