Skip to content

Commit 9b80465

Browse files
authored
Unified input for resized crop op (#2396)
* [WIP] Unify random resized crop
* Unify input for RandomResizedCrop
* Fixed bugs and updated test
* Added resized crop functional test — fixed bug with size convention
* Fixed incoherent sampling
* Fixed torch randint review remark
1 parent b572d5e commit 9b80465

File tree

5 files changed

+87
-31
lines changed

5 files changed

+87
-31
lines changed

test/test_functional_tensor.py

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -331,6 +331,23 @@ def test_resize(self):
331331
pad_tensor_script = script_fn(tensor, size=script_size, interpolation=interpolation)
332332
self.assertTrue(resized_tensor.equal(pad_tensor_script), msg="{}, {}".format(size, interpolation))
333333

334+
def test_resized_crop(self):
335+
# test values of F.resized_crop in several cases:
336+
# 1) resize to the same size, crop to the same size => should be identity
337+
tensor, _ = self._create_data(26, 36)
338+
for i in [0, 2, 3]:
339+
out_tensor = F.resized_crop(tensor, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=i)
340+
self.assertTrue(tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]))
341+
342+
# 2) resize by half and crop a TL corner
343+
tensor, _ = self._create_data(26, 36)
344+
out_tensor = F.resized_crop(tensor, top=0, left=0, height=20, width=30, size=[10, 15], interpolation=0)
345+
expected_out_tensor = tensor[:, :20:2, :30:2]
346+
self.assertTrue(
347+
expected_out_tensor.equal(out_tensor),
348+
msg="{} vs {}".format(expected_out_tensor[0, :10, :10], out_tensor[0, :10, :10])
349+
)
350+
334351

335352
if __name__ == '__main__':
336353
unittest.main()

test/test_transforms_tensor.py

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -245,6 +245,25 @@ def test_resize(self):
245245
s_resized_tensor = script_transform(tensor)
246246
self.assertTrue(s_resized_tensor.equal(resized_tensor))
247247

248+
def test_resized_crop(self):
249+
tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8)
250+
251+
scale = (0.7, 1.2)
252+
ratio = (0.75, 1.333)
253+
254+
for size in [(32, ), [32, ], [32, 32], (32, 32)]:
255+
for interpolation in [NEAREST, BILINEAR, BICUBIC]:
256+
transform = T.RandomResizedCrop(
257+
size=size, scale=scale, ratio=ratio, interpolation=interpolation
258+
)
259+
s_transform = torch.jit.script(transform)
260+
261+
torch.manual_seed(12)
262+
out1 = transform(tensor)
263+
torch.manual_seed(12)
264+
out2 = s_transform(tensor)
265+
self.assertTrue(out1.equal(out2))
266+
248267

249268
if __name__ == '__main__':
250269
unittest.main()

torchvision/transforms/functional.py

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -439,24 +439,26 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
439439
return crop(img, crop_top, crop_left, crop_height, crop_width)
440440

441441

442-
def resized_crop(img, top, left, height, width, size, interpolation=Image.BILINEAR):
443-
"""Crop the given PIL Image and resize it to desired size.
442+
def resized_crop(
443+
img: Tensor, top: int, left: int, height: int, width: int, size: List[int], interpolation: int = Image.BILINEAR
444+
) -> Tensor:
445+
"""Crop the given image and resize it to desired size.
446+
The image can be a PIL Image or a Tensor, in which case it is expected
447+
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
444448
445449
Notably used in :class:`~torchvision.transforms.RandomResizedCrop`.
446450
447451
Args:
448-
img (PIL Image): Image to be cropped. (0,0) denotes the top left corner of the image.
452+
img (PIL Image or Tensor): Image to be cropped. (0,0) denotes the top left corner of the image.
449453
top (int): Vertical component of the top left corner of the crop box.
450454
left (int): Horizontal component of the top left corner of the crop box.
451455
height (int): Height of the crop box.
452456
width (int): Width of the crop box.
453457
size (sequence or int): Desired output size. Same semantics as ``resize``.
454-
interpolation (int, optional): Desired interpolation. Default is
455-
``PIL.Image.BILINEAR``.
458+
interpolation (int, optional): Desired interpolation. Default is ``PIL.Image.BILINEAR``.
456459
Returns:
457-
PIL Image: Cropped image.
460+
PIL Image or Tensor: Cropped image.
458461
"""
459-
assert F_pil._is_pil_image(img), 'img should be PIL Image'
460462
img = crop(img, top, left, height, width)
461463
img = resize(img, size, interpolation)
462464
return img

torchvision/transforms/functional_tensor.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -532,7 +532,7 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor:
532532
elif len(size) < 2:
533533
size_w, size_h = size[0], size[0]
534534
else:
535-
size_w, size_h = size[0], size[1]
535+
size_w, size_h = size[1], size[0] # Convention (h, w)
536536

537537
if isinstance(size, int) or len(size) < 2:
538538
if w < h:

torchvision/transforms/transforms.py

Lines changed: 41 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -687,40 +687,56 @@ def __repr__(self):
687687
return self.__class__.__name__ + '(p={})'.format(self.p)
688688

689689

690-
class RandomResizedCrop(object):
691-
"""Crop the given PIL Image to random size and aspect ratio.
690+
class RandomResizedCrop(torch.nn.Module):
691+
"""Crop the given image to random size and aspect ratio.
692+
The image can be a PIL Image or a Tensor, in which case it is expected
693+
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
692694
693695
A crop of random size (default: of 0.08 to 1.0) of the original size and a random
694696
aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
695697
is finally resized to given size.
696698
This is popularly used to train the Inception networks.
697699
698700
Args:
699-
size: expected output size of each edge
700-
scale: range of size of the origin size cropped
701-
ratio: range of aspect ratio of the origin aspect ratio cropped
702-
interpolation: Default: PIL.Image.BILINEAR
701+
size (int or sequence): expected output size of each edge. If size is an
702+
int instead of sequence like (h, w), a square output size ``(size, size)`` is
703+
made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]).
704+
scale (tuple of float): range of size of the origin size cropped
705+
ratio (tuple of float): range of aspect ratio of the origin aspect ratio cropped.
706+
interpolation (int): Desired interpolation. Default: ``PIL.Image.BILINEAR``
703707
"""
704708

705709
def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR):
706-
if isinstance(size, (tuple, list)):
707-
self.size = size
710+
super().__init__()
711+
if isinstance(size, numbers.Number):
712+
self.size = (int(size), int(size))
713+
elif isinstance(size, Sequence) and len(size) == 1:
714+
self.size = (size[0], size[0])
708715
else:
709-
self.size = (size, size)
716+
if len(size) != 2:
717+
raise ValueError("Please provide only two dimensions (h, w) for size.")
718+
self.size = size
719+
720+
if not isinstance(scale, (tuple, list)):
721+
raise TypeError("Scale should be a sequence")
722+
if not isinstance(ratio, (tuple, list)):
723+
raise TypeError("Ratio should be a sequence")
710724
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
711-
warnings.warn("range should be of kind (min, max)")
725+
warnings.warn("Scale and ratio should be of kind (min, max)")
712726

713727
self.interpolation = interpolation
714728
self.scale = scale
715729
self.ratio = ratio
716730

717731
@staticmethod
718-
def get_params(img, scale, ratio):
732+
def get_params(
733+
img: Tensor, scale: Tuple[float, float], ratio: Tuple[float, float]
734+
) -> Tuple[int, int, int, int]:
719735
"""Get parameters for ``crop`` for a random sized crop.
720736
721737
Args:
722-
img (PIL Image): Image to be cropped.
723-
scale (tuple): range of size of the origin size cropped
738+
img (PIL Image or Tensor): Input image.
739+
scale (tuple): range of scale of the origin size cropped
724740
ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
725741
726742
Returns:
@@ -731,24 +747,26 @@ def get_params(img, scale, ratio):
731747
area = height * width
732748

733749
for _ in range(10):
734-
target_area = random.uniform(*scale) * area
735-
log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
736-
aspect_ratio = math.exp(random.uniform(*log_ratio))
750+
target_area = area * torch.empty(1).uniform_(*scale).item()
751+
log_ratio = torch.log(torch.tensor(ratio))
752+
aspect_ratio = torch.exp(
753+
torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
754+
).item()
737755

738756
w = int(round(math.sqrt(target_area * aspect_ratio)))
739757
h = int(round(math.sqrt(target_area / aspect_ratio)))
740758

741759
if 0 < w <= width and 0 < h <= height:
742-
i = random.randint(0, height - h)
743-
j = random.randint(0, width - w)
760+
i = torch.randint(0, height - h + 1, size=(1,)).item()
761+
j = torch.randint(0, width - w + 1, size=(1,)).item()
744762
return i, j, h, w
745763

746764
# Fallback to central crop
747765
in_ratio = float(width) / float(height)
748-
if (in_ratio < min(ratio)):
766+
if in_ratio < min(ratio):
749767
w = width
750768
h = int(round(w / min(ratio)))
751-
elif (in_ratio > max(ratio)):
769+
elif in_ratio > max(ratio):
752770
h = height
753771
w = int(round(h * max(ratio)))
754772
else: # whole image
@@ -758,13 +776,13 @@ def get_params(img, scale, ratio):
758776
j = (width - w) // 2
759777
return i, j, h, w
760778

761-
def __call__(self, img):
779+
def forward(self, img):
762780
"""
763781
Args:
764-
img (PIL Image): Image to be cropped and resized.
782+
img (PIL Image or Tensor): Image to be cropped and resized.
765783
766784
Returns:
767-
PIL Image: Randomly cropped and resized image.
785+
PIL Image or Tensor: Randomly cropped and resized image.
768786
"""
769787
i, j, h, w = self.get_params(img, self.scale, self.ratio)
770788
return F.resized_crop(img, i, j, h, w, self.size, self.interpolation)

0 commit comments

Comments (0)