From 3025a36af7c0d72afd3f101496bd42e430c9a9b2 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Tue, 23 Jun 2020 11:13:21 +0200 Subject: [PATCH 01/11] [WIP] Unified Tensor/PIL crop --- test/test_transforms_tensor.py | 31 ++++++--- torchvision/transforms/functional.py | 53 ++++++++------- torchvision/transforms/functional_pil.py | 31 ++++++++- torchvision/transforms/functional_tensor.py | 14 +++- torchvision/transforms/transforms.py | 74 ++++++++++----------- 5 files changed, 128 insertions(+), 75 deletions(-) diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 7791dd8b4f9..8996717ffa9 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -18,26 +18,30 @@ def compareTensorToPIL(self, tensor, pil_image): pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1))) self.assertTrue(tensor.equal(pil_tensor)) - def _test_flip(self, func, method): - tensor, pil_img = self._create_data() - flip_tensor = getattr(F, func)(tensor) - flip_pil_img = getattr(F, func)(pil_img) - self.compareTensorToPIL(flip_tensor, flip_pil_img) + def _test_geom_op(self, func, method, fn_kwargs=None, meth_kwargs=None): + if fn_kwargs is None: + fn_kwargs = {} + if meth_kwargs is None: + meth_kwargs = {} + tensor, pil_img = self._create_data(height=10, width=10) + transformed_tensor = getattr(F, func)(tensor, **fn_kwargs) + transformed_pil_img = getattr(F, func)(pil_img, **fn_kwargs) + self.compareTensorToPIL(transformed_tensor, transformed_pil_img) scripted_fn = torch.jit.script(getattr(F, func)) - flip_tensor_script = scripted_fn(tensor) - self.assertTrue(flip_tensor.equal(flip_tensor_script)) + transformed_tensor_script = scripted_fn(tensor, **fn_kwargs) + self.assertTrue(transformed_tensor.equal(transformed_tensor_script)) # test for class interface - f = getattr(T, method)() + f = getattr(T, method)(**meth_kwargs) scripted_fn = torch.jit.script(f) scripted_fn(tensor) def test_random_horizontal_flip(self): - self._test_flip('hflip', 'RandomHorizontalFlip') + self._test_geom_op('hflip', 'RandomHorizontalFlip') def test_random_vertical_flip(self): - self._test_flip('vflip', 'RandomVerticalFlip') + self._test_geom_op('vflip', 'RandomVerticalFlip') def test_adjustments(self): fns = ['adjust_brightness', 'adjust_contrast', 'adjust_saturation'] @@ -65,6 +69,13 @@ def test_adjustments(self): self.assertLess(max_diff, 5 / 255 + 1e-5) self.assertLess(max_diff_scripted, 5 / 255 + 1e-5) + def test_crop(self): + fn_kwargs = {"top": 2, "left": 3, "height": 4, "width": 5} + meth_kwargs = {"size": (4, 5), "padding": 4, "pad_if_needed": True, } + self._test_geom_op( + 'crop', 'RandomCrop', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs + ) + if __name__ == '__main__': unittest.main() diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 5d8549ea883..d5caf97db8e 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -16,18 +16,24 @@ from . 
import functional_tensor as F_t -def _is_pil_image(img): - if accimage is not None: - return isinstance(img, (Image.Image, accimage.Image)) - else: - return isinstance(img, Image.Image) +@torch.jit.export +def _get_image_size(img): + # type: (Tensor) -> List[int] + if isinstance(img, torch.Tensor): + return F_t._get_image_size(img) + + return F_pil._get_image_size(img) +@torch.jit.ignore def _is_numpy(img): + # type: (Any) -> bool return isinstance(img, np.ndarray) +@torch.jit.ignore def _is_numpy_image(img): + # type: (Any) -> bool return img.ndim in {2, 3} @@ -42,7 +48,7 @@ def to_tensor(pic): Returns: Tensor: Converted image. """ - if not(_is_pil_image(pic) or _is_numpy(pic)): + if not(F_pil._is_pil_image(pic) or _is_numpy(pic)): raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic))) if _is_numpy(pic) and not _is_numpy_image(pic): @@ -97,7 +103,7 @@ def pil_to_tensor(pic): Returns: Tensor: Converted image. """ - if not(_is_pil_image(pic)): + if not(F_pil._is_pil_image(pic)): raise TypeError('pic should be PIL Image. Got {}'.format(type(pic))) if accimage is not None and isinstance(pic, accimage.Image): @@ -315,7 +321,7 @@ def resize(img, size, interpolation=Image.BILINEAR): Returns: PIL Image: Resized image. """ - if not _is_pil_image(img): + if not F_pil._is_pil_image(img): raise TypeError('img should be PIL Image. Got {}'.format(type(img))) if not (isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)): raise TypeError('Got inappropriate size arg: {}'.format(size)) @@ -374,7 +380,7 @@ def pad(img, padding, fill=0, padding_mode='constant'): Returns: PIL Image: Padded image. """ - if not _is_pil_image(img): + if not F_pil._is_pil_image(img): raise TypeError('img should be PIL Image. Got {}'.format(type(img))) if not isinstance(padding, (numbers.Number, tuple)): @@ -436,23 +442,24 @@ def pad(img, padding, fill=0, padding_mode='constant'): return Image.fromarray(img) -def crop(img, top, left, height, width): +def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor: """Crop the given PIL Image. Args: - img (PIL Image): Image to be cropped. (0,0) denotes the top left corner of the image. + img (PIL Image or torch.Tensor): Image to be cropped. (0,0) denotes the top left corner of the image. top (int): Vertical component of the top left corner of the crop box. left (int): Horizontal component of the top left corner of the crop box. height (int): Height of the crop box. width (int): Width of the crop box. Returns: - PIL Image: Cropped image. + PIL Image or torch.Tensor: Cropped image. """ - if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) - return img.crop((left, top, left + width, top + height)) + if not isinstance(img, torch.Tensor): + return F_pil.crop(img, top, left, height, width) + + return F_t.crop(img, top, left, height, width) def center_crop(img, output_size): @@ -491,7 +498,7 @@ def resized_crop(img, top, left, height, width, size, interpolation=Image.BILINE Returns: PIL Image: Cropped image. """ - assert _is_pil_image(img), 'img should be PIL Image' + assert F_pil._is_pil_image(img), 'img should be PIL Image' img = crop(img, top, left, height, width) img = resize(img, size, interpolation) return img @@ -501,13 +508,13 @@ def hflip(img: Tensor) -> Tensor: """Horizontally flip the given PIL Image or torch Tensor. Args: - img (PIL Image or Torch Tensor): Image to be flipped. If img + img (PIL Image or torch.Tensor): Image to be flipped. 
If img is a Tensor, it is expected to be in [..., H, W] format, where ... means it can have an arbitrary number of trailing dimensions. Returns: - PIL Image: Horizontally flipped image. + PIL Image or torch.Tensor: Horizontally flipped image. """ if not isinstance(img, torch.Tensor): return F_pil.hflip(img) @@ -593,7 +600,7 @@ def perspective(img, startpoints, endpoints, interpolation=Image.BICUBIC, fill=N PIL Image: Perspectively transformed Image. """ - if not _is_pil_image(img): + if not F_pil._is_pil_image(img): raise TypeError('img should be PIL Image. Got {}'.format(type(img))) opts = _parse_fill(fill, img, '5.0.0') @@ -797,7 +804,7 @@ def adjust_gamma(img, gamma, gain=1): while gamma smaller than 1 make dark regions lighter. gain (float): The constant multiplier. """ - if not _is_pil_image(img): + if not F_pil._is_pil_image(img): raise TypeError('img should be PIL Image. Got {}'.format(type(img))) if gamma < 0: @@ -837,7 +844,7 @@ def rotate(img, angle, resample=False, expand=False, center=None, fill=None): .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters """ - if not _is_pil_image(img): + if not F_pil._is_pil_image(img): raise TypeError('img should be PIL Image. Got {}'.format(type(img))) opts = _parse_fill(fill, img, '5.2.0') @@ -918,7 +925,7 @@ def affine(img, angle, translate, scale, shear, resample=0, fillcolor=None): If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``. fillcolor (int): Optional fill color for the area outside the transform in the output image. (Pillow>=5.0.0) """ - if not _is_pil_image(img): + if not F_pil._is_pil_image(img): raise TypeError('img should be PIL Image. Got {}'.format(type(img))) assert isinstance(translate, (tuple, list)) and len(translate) == 2, \ @@ -945,7 +952,7 @@ def to_grayscale(img, num_output_channels=1): if num_output_channels = 3 : returned image is 3 channel with r = g = b """ - if not _is_pil_image(img): + if not F_pil._is_pil_image(img): raise TypeError('img should be PIL Image. Got {}'.format(type(img))) if num_output_channels == 1: diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py index 84e27e79040..c848bc13fca 100644 --- a/torchvision/transforms/functional_pil.py +++ b/torchvision/transforms/functional_pil.py @@ -3,18 +3,27 @@ import accimage except ImportError: accimage = None -from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION +from PIL import Image, ImageOps, ImageEnhance import numpy as np @torch.jit.unused def _is_pil_image(img): + # type: (Any) -> bool if accimage is not None: return isinstance(img, (Image.Image, accimage.Image)) else: return isinstance(img, Image.Image) +@torch.jit.unused +def _get_image_size(img): + # type: (Any) -> List[int] + if _is_pil_image(img): + return img.size + raise TypeError("Unexpected type {}".format(type(img))) + + @torch.jit.unused def hflip(img): """Horizontally flip the given PIL Image. @@ -152,3 +161,23 @@ def adjust_hue(img, hue_factor): img = Image.merge('HSV', (h, s, v)).convert(input_mode) return img + + +@torch.jit.unused +def crop(img: Image.Image, top: int, left: int, height: int, width: int) -> Image.Image: + """Crop the given PIL Image. + + Args: + img (PIL Image): Image to be cropped. (0,0) denotes the top left corner of the image. + top (int): Vertical component of the top left corner of the crop box. + left (int): Horizontal component of the top left corner of the crop box. + height (int): Height of the crop box. 
+ width (int): Width of the crop box. + + Returns: + PIL Image: Cropped image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + return img.crop((left, top, left + width, top + height)) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 89440701d17..c9841bc95cd 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -3,8 +3,18 @@ from torch.jit.annotations import Optional, List, BroadcastingList2, Tuple -def _is_tensor_a_torch_image(input): - return input.ndim >= 2 +@torch.jit.export +def _is_tensor_a_torch_image(x): + # type: (Tensor) -> bool + return x.ndim >= 2 + + +@torch.jit.export +def _get_image_size(img): + # type: (Tensor) -> List[int] + if _is_tensor_a_torch_image(img): + return [img.shape[-1], img.shape[-2]] + raise TypeError("Unexpected type {}".format(type(img))) def vflip(img): diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index ba92bbccf6c..a2ac946bcf8 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -31,15 +31,6 @@ } -def _get_image_size(img): - if F._is_pil_image(img): - return img.size - elif isinstance(img, torch.Tensor) and img.dim() > 2: - return img.shape[-2:][::-1] - else: - raise TypeError("Unexpected type {}".format(type(img))) - - class Compose(object): """Composes several transforms together. @@ -433,8 +424,11 @@ def __call__(self, img): return t(img) -class RandomCrop(object): - """Crop the given PIL Image at a random location. +class RandomCrop(torch.nn.Module): + """Crop the given image at a random location. + The image can be a PIL Image or a torch Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading + dimensions Args: size (sequence or int): Desired output size of the crop. If size is an @@ -469,16 +463,6 @@ class RandomCrop(object): """ - def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'): - if isinstance(size, numbers.Number): - self.size = (int(size), int(size)) - else: - self.size = size - self.padding = padding - self.pad_if_needed = pad_if_needed - self.fill = fill - self.padding_mode = padding_mode - @staticmethod def get_params(img, output_size): """Get parameters for ``crop`` for a random crop. @@ -490,32 +474,44 @@ def get_params(img, output_size): Returns: tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. """ - w, h = _get_image_size(img) + # type: (Tensor, Tuple[int, int]) -> Tuple[int, int, int, int] + w, h = F._get_image_size(img) th, tw = output_size if w == tw and h == th: return 0, 0, h, w - i = random.randint(0, h - th) - j = random.randint(0, w - tw) + i = torch.randint(0, h - th, size=(1, )).item() + j = torch.randint(0, w - tw, size=(1, )).item() return i, j, th, tw - def __call__(self, img): + def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'): + super().__init__() + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + self.padding = padding + self.pad_if_needed = pad_if_needed + self.fill = fill + self.padding_mode = padding_mode + + def forward(self, img): """ Args: - img (PIL Image): Image to be cropped. + img (PIL Image or torch.Tensor): Image to be cropped. Returns: - PIL Image: Cropped image. + PIL Image or torch.Tensor: Cropped image. 
""" - if self.padding is not None: - img = F.pad(img, self.padding, self.fill, self.padding_mode) + # if self.padding is not None: + # img = F.pad(img, self.padding, self.fill, self.padding_mode) - # pad the width if needed - if self.pad_if_needed and img.size[0] < self.size[1]: - img = F.pad(img, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode) - # pad the height if needed - if self.pad_if_needed and img.size[1] < self.size[0]: - img = F.pad(img, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode) + # # pad the width if needed + # if self.pad_if_needed and img.size[0] < self.size[1]: + # img = F.pad(img, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode) + # # pad the height if needed + # if self.pad_if_needed and img.size[1] < self.size[0]: + # img = F.pad(img, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode) i, j, h, w = self.get_params(img, self.size) @@ -542,10 +538,10 @@ def __init__(self, p=0.5): def forward(self, img): """ Args: - img (PIL Image or Tensor): Image to be flipped. + img (PIL Image or torch.Tensor): Image to be flipped. Returns: - PIL Image or Tensor: Randomly flipped image. + PIL Image or torch.Tensor: Randomly flipped image. """ if torch.rand(1) < self.p: return F.hflip(img) @@ -572,10 +568,10 @@ def __init__(self, p=0.5): def forward(self, img): """ Args: - img (PIL Image or Tensor): Image to be flipped. + img (PIL Image or torch.Tensor): Image to be flipped. Returns: - PIL Image or Tensor: Randomly flipped image. + PIL Image or torch.Tensor: Randomly flipped image. """ if torch.rand(1) < self.p: return F.vflip(img) From 2eea2cdb2e15b9e94cbb1030364faf3a3f623f68 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Tue, 23 Jun 2020 11:32:38 +0200 Subject: [PATCH 02/11] Fixed misplaced type annotation --- torchvision/transforms/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index a2ac946bcf8..cb361bdf2aa 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -465,6 +465,7 @@ class RandomCrop(torch.nn.Module): @staticmethod def get_params(img, output_size): + # type: (Tensor, Tuple[int, int]) -> Tuple[int, int, int, int] """Get parameters for ``crop`` for a random crop. Args: @@ -474,7 +475,6 @@ def get_params(img, output_size): Returns: tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. """ - # type: (Tensor, Tuple[int, int]) -> Tuple[int, int, int, int] w, h = F._get_image_size(img) th, tw = output_size if w == tw and h == th: From c32c8bf4c2d6acafaaee2040b37c0488f3c32e19 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 24 Jun 2020 15:03:24 +0200 Subject: [PATCH 03/11] Fixed tests - crop with padding - other tests using mising private functions: _is_pil_image, _get_image_size --- torchvision/transforms/functional.py | 8 +++--- torchvision/transforms/transforms.py | 40 +++++++++++++++------------- 2 files changed, 26 insertions(+), 22 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 783f85e8107..31531e61486 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -398,14 +398,14 @@ def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor: """Crop the given PIL Image. Args: - img (PIL Image or torch.Tensor): Image to be cropped. (0,0) denotes the top left corner of the image. + img (PIL Image or Tensor): Image to be cropped. 
(0,0) denotes the top left corner of the image. top (int): Vertical component of the top left corner of the crop box. left (int): Horizontal component of the top left corner of the crop box. height (int): Height of the crop box. width (int): Width of the crop box. Returns: - PIL Image or torch.Tensor: Cropped image. + PIL Image or Tensor: Cropped image. """ if not isinstance(img, torch.Tensor): @@ -460,13 +460,13 @@ def hflip(img: Tensor) -> Tensor: """Horizontally flip the given PIL Image or torch Tensor. Args: - img (PIL Image or torch.Tensor): Image to be flipped. If img + img (PIL Image or Tensor): Image to be flipped. If img is a Tensor, it is expected to be in [..., H, W] format, where ... means it can have an arbitrary number of trailing dimensions. Returns: - PIL Image or torch.Tensor: Horizontally flipped image. + PIL Image or Tensor: Horizontally flipped image. """ if not isinstance(img, torch.Tensor): return F_pil.hflip(img) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index f86931302de..b3790ec2168 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -89,7 +89,7 @@ def __repr__(self): class PILToTensor(object): """Convert a ``PIL Image`` to a tensor of the same type. - Converts a PIL Image (H x W x C) to a torch.Tensor of shape (C x H x W). + Converts a PIL Image (H x W x C) to a Tensor of shape (C x H x W). """ def __call__(self, pic): @@ -490,6 +490,7 @@ def get_params(img, output_size): if w == tw and h == th: return 0, 0, h, w + print(h, th, w, tw) i = torch.randint(0, h - th, size=(1, )).item() j = torch.randint(0, w - tw, size=(1, )).item() return i, j, th, tw @@ -508,20 +509,23 @@ def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode def forward(self, img): """ Args: - img (PIL Image or torch.Tensor): Image to be cropped. + img (PIL Image or Tensor): Image to be cropped. Returns: - PIL Image or torch.Tensor: Cropped image. + PIL Image or Tensor: Cropped image. """ - # if self.padding is not None: - # img = F.pad(img, self.padding, self.fill, self.padding_mode) - - # # pad the width if needed - # if self.pad_if_needed and img.size[0] < self.size[1]: - # img = F.pad(img, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode) - # # pad the height if needed - # if self.pad_if_needed and img.size[1] < self.size[0]: - # img = F.pad(img, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode) + if self.padding is not None: + img = F.pad(img, self.padding, self.fill, self.padding_mode) + + width, height = F._get_image_size(img) + # pad the width if needed + if self.pad_if_needed and height < self.size[1]: + padding = [self.size[1] - height, 0] + img = F.pad(img, padding, self.fill, self.padding_mode) + # pad the height if needed + if self.pad_if_needed and width < self.size[0]: + padding = [0, self.size[0] - width] + img = F.pad(img, padding, self.fill, self.padding_mode) i, j, h, w = self.get_params(img, self.size) @@ -548,10 +552,10 @@ def __init__(self, p=0.5): def forward(self, img): """ Args: - img (PIL Image or torch.Tensor): Image to be flipped. + img (PIL Image or Tensor): Image to be flipped. Returns: - PIL Image or torch.Tensor: Randomly flipped image. + PIL Image or Tensor: Randomly flipped image. """ if torch.rand(1) < self.p: return F.hflip(img) @@ -578,10 +582,10 @@ def __init__(self, p=0.5): def forward(self, img): """ Args: - img (PIL Image or torch.Tensor): Image to be flipped. + img (PIL Image or Tensor): Image to be flipped. 
Returns: - PIL Image or torch.Tensor: Randomly flipped image. + PIL Image or Tensor: Randomly flipped image. """ if torch.rand(1) < self.p: return F.vflip(img) @@ -619,7 +623,7 @@ def __call__(self, img): Returns: PIL Image: Random perspectivley transformed image. """ - if not F._is_pil_image(img): + if not F.F_pil._is_pil_image(img): raise TypeError('img should be PIL Image. Got {}'.format(type(img))) if random.random() < self.p: @@ -698,7 +702,7 @@ def get_params(img, scale, ratio): tuple: params (i, j, h, w) to be passed to ``crop`` for a random sized crop. """ - width, height = _get_image_size(img) + width, height = F._get_image_size(img) area = height * width for _ in range(10): From 69cb9c5bdb7d7b068427a33ec079f5e92cd7766b Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 24 Jun 2020 16:05:20 +0200 Subject: [PATCH 04/11] Unified CenterCrop and F.center_crop - sorted includes in transforms.py - used py3 annotations --- test/test_transforms_tensor.py | 14 ++++++- torchvision/transforms/functional.py | 38 +++++++++++------- torchvision/transforms/functional_tensor.py | 17 +++----- torchvision/transforms/transforms.py | 44 ++++++++++++--------- 4 files changed, 68 insertions(+), 45 deletions(-) diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 0d5eebbb43c..73e5a114820 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -101,11 +101,23 @@ def test_pad(self): def test_crop(self): fn_kwargs = {"top": 2, "left": 3, "height": 4, "width": 5} - meth_kwargs = {"size": (4, 5), "padding": 4, "pad_if_needed": True, } + meth_kwargs = {"size": (4, 5), "padding": [4, ], "pad_if_needed": True, } self._test_geom_op( 'crop', 'RandomCrop', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs ) + def test_center_crop(self): + fn_kwargs = {"output_size": (4, 5)} + meth_kwargs = {"size": (4, 5), } + self._test_geom_op( + "center_crop", "CenterCrop", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs + ) + fn_kwargs = {"output_size": (5,)} + meth_kwargs = {"size": (5, )} + self._test_geom_op( + "center_crop", "CenterCrop", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs + ) + if __name__ == '__main__': unittest.main() diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 31531e61486..d4c3a211eef 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -2,6 +2,7 @@ import numbers import warnings from collections.abc import Iterable +from typing import Any import numpy as np from numpy import sin, cos, tan @@ -21,8 +22,9 @@ @torch.jit.export -def _get_image_size(img): - # type: (Tensor) -> List[int] +def _get_image_size(img: Tensor) -> List[int]: + """Returns image sizea as (w, h) + """ if isinstance(img, torch.Tensor): return F_t._get_image_size(img) @@ -30,14 +32,12 @@ def _get_image_size(img): @torch.jit.ignore -def _is_numpy(img): - # type: (Any) -> bool +def _is_numpy(img: Any) -> bool: return isinstance(img, np.ndarray) @torch.jit.ignore -def _is_numpy_image(img): - # type: (Any) -> bool +def _is_numpy_image(img: Any) -> bool: return img.ndim in {2, 3} @@ -414,22 +414,32 @@ def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor: return F_t.crop(img, top, left, height, width) -def center_crop(img, output_size): - """Crop the given PIL Image and resize it to desired size. +def center_crop(img: Tensor, output_size: List[int]) -> Tensor: + """Crops the given image at the center. 
+ The image can be a PIL Image or a torch Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions Args: - img (PIL Image): Image to be cropped. (0,0) denotes the top left corner of the image. - output_size (sequence or int): (height, width) of the crop box. If int, + img (PIL Image or Tensor): Image to be cropped. + output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int it is used for both directions Returns: - PIL Image: Cropped image. + PIL Image or Tensor: Cropped image. """ if isinstance(output_size, numbers.Number): output_size = (int(output_size), int(output_size)) - image_width, image_height = img.size + elif isinstance(output_size, (tuple, list)) and len(output_size) == 1: + output_size = (output_size[0], output_size[0]) + image_width, image_height = _get_image_size(img) crop_height, crop_width = output_size - crop_top = int(round((image_height - crop_height) / 2.)) - crop_left = int(round((image_width - crop_width) / 2.)) + + # crop_top = int(round((image_height - crop_height) / 2.)) + # Result can be different between python func and scripted func + crop_top = int((image_height - crop_height + 1) * 0.5) + # crop_left = int(round((image_width - crop_width) / 2.)) + # Result can be different between python func and scripted func + crop_left = int((image_width - crop_width + 1) * 0.5) + return crop(img, crop_top, crop_left, crop_height, crop_width) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 4c7df3b2902..999a0f6eb79 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -4,21 +4,18 @@ @torch.jit.export -def _is_tensor_a_torch_image(x): - # type: (Tensor) -> bool +def _is_tensor_a_torch_image(x: Tensor) -> bool: return x.ndim >= 2 @torch.jit.export -def _get_image_size(img): - # type: (Tensor) -> List[int] +def _get_image_size(img: Tensor) -> List[int]: if _is_tensor_a_torch_image(img): return [img.shape[-1], img.shape[-2]] raise TypeError("Unexpected type {}".format(type(img))) -def vflip(img): - # type: (Tensor) -> Tensor +def vflip(img: Tensor) -> Tensor: """Vertically flip the given the Image Tensor. Args: @@ -33,8 +30,7 @@ def vflip(img): return img.flip(-2) -def hflip(img): - # type: (Tensor) -> Tensor +def hflip(img: Tensor) -> Tensor: """Horizontally flip the given the Image Tensor. Args: @@ -49,8 +45,7 @@ def hflip(img): return img.flip(-1) -def crop(img, top, left, height, width): - # type: (Tensor, int, int, int, int) -> Tensor +def crop(img: Tensor, top: int, left: int, height: int, width: int): """Crop the given Image Tensor. Args: @@ -64,7 +59,7 @@ def crop(img, top, left, height, width): Tensor: Cropped image. 
""" if not _is_tensor_a_torch_image(img): - raise TypeError('tensor is not a torch image.') + raise TypeError("tensor is not a torch image.") return img[..., top:top + height, left:left + width] diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index b3790ec2168..342341ed00c 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -1,16 +1,19 @@ -import torch import math +import numbers import random +import warnings +from collections.abc import Sequence, Iterable +from typing import Tuple + +import numpy as np +import torch from PIL import Image +from torch import Tensor + try: import accimage except ImportError: accimage = None -import numpy as np -import numbers -import types -from collections.abc import Sequence, Iterable -import warnings from . import functional as F @@ -249,28 +252,33 @@ def __init__(self, *args, **kwargs): super(Scale, self).__init__(*args, **kwargs) -class CenterCrop(object): - """Crops the given PIL Image at the center. +class CenterCrop(torch.nn.Module): + """Crops the given image at the center. + The image can be a PIL Image or a torch Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions Args: size (sequence or int): Desired output size of the crop. If size is an int instead of sequence like (h, w), a square crop (size, size) is - made. + made. For scripted operation please use a list: (size, ) or (size_x, size_y) """ def __init__(self, size): + super().__init__() if isinstance(size, numbers.Number): self.size = (int(size), int(size)) + elif isinstance(size, (tuple, list)) and len(size) == 1: + self.size = (size[0], size[0]) else: self.size = size - def __call__(self, img): + def forward(self, img): """ Args: - img (PIL Image): Image to be cropped. + img (PIL Image or Tensor): Image to be cropped. Returns: - PIL Image: Cropped image. + PIL Image or Tensor: Cropped image. """ return F.center_crop(img, self.size) @@ -474,8 +482,7 @@ class RandomCrop(torch.nn.Module): """ @staticmethod - def get_params(img, output_size): - # type: (Tensor, Tuple[int, int]) -> Tuple[int, int, int, int] + def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int, int]: """Get parameters for ``crop`` for a random crop. 
Args: @@ -490,7 +497,6 @@ def get_params(img, output_size): if w == tw and h == th: return 0, 0, h, w - print(h, th, w, tw) i = torch.randint(0, h - th, size=(1, )).item() j = torch.randint(0, w - tw, size=(1, )).item() return i, j, th, tw @@ -519,12 +525,12 @@ def forward(self, img): width, height = F._get_image_size(img) # pad the width if needed - if self.pad_if_needed and height < self.size[1]: - padding = [self.size[1] - height, 0] + if self.pad_if_needed and width < self.size[1]: + padding = [self.size[1] - width, 0] img = F.pad(img, padding, self.fill, self.padding_mode) # pad the height if needed - if self.pad_if_needed and width < self.size[0]: - padding = [0, self.size[0] - width] + if self.pad_if_needed and height < self.size[0]: + padding = [0, self.size[0] - height] img = F.pad(img, padding, self.fill, self.padding_mode) i, j, h, w = self.get_params(img, self.size) From 01175db1c2b1ced3515f9aeeb4cf7dc4f001e78d Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 24 Jun 2020 17:40:46 +0200 Subject: [PATCH 05/11] Unified FiveCrop and F.five_crop --- test/test_transforms_tensor.py | 39 ++++++++++++++++++++++++++++ torchvision/transforms/functional.py | 34 ++++++++++++++---------- torchvision/transforms/transforms.py | 31 +++++++++++++++++----- 3 files changed, 85 insertions(+), 19 deletions(-) diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 73e5a114820..081087088ce 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -118,6 +118,45 @@ def test_center_crop(self): "center_crop", "CenterCrop", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs ) + def _test_geom_op_list_output(self, func, method, out_length, fn_kwargs=None, meth_kwargs=None): + if fn_kwargs is None: + fn_kwargs = {} + if meth_kwargs is None: + meth_kwargs = {} + tensor, pil_img = self._create_data(height=20, width=20) + transformed_t_list = getattr(F, func)(tensor, **fn_kwargs) + transformed_p_list = getattr(F, func)(pil_img, **fn_kwargs) + self.assertEqual(len(transformed_t_list), len(transformed_p_list)) + self.assertEqual(len(transformed_t_list), out_length) + for transformed_tensor, transformed_pil_img in zip(transformed_t_list, transformed_p_list): + self.compareTensorToPIL(transformed_tensor, transformed_pil_img) + + scripted_fn = torch.jit.script(getattr(F, func)) + transformed_t_list_script = scripted_fn(tensor.detach().clone(), **fn_kwargs) + self.assertEqual(len(transformed_t_list), len(transformed_t_list_script)) + self.assertEqual(len(transformed_t_list_script), out_length) + for transformed_tensor, transformed_tensor_script in zip(transformed_t_list, transformed_t_list_script): + self.assertTrue(transformed_tensor.equal(transformed_tensor_script), + msg="{} vs {}".format(transformed_tensor, transformed_tensor_script)) + + # test for class interface + f = getattr(T, method)(**meth_kwargs) + scripted_fn = torch.jit.script(f) + output = scripted_fn(tensor) + self.assertEqual(len(output), len(transformed_t_list_script)) + + def test_five_crop(self): + fn_kwargs = {"size": (5,)} + meth_kwargs = {"size": (5, )} + self._test_geom_op_list_output( + "five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs + ) + fn_kwargs = {"size": (4, 5)} + meth_kwargs = {"size": (4, 5)} + self._test_geom_op_list_output( + "five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs + ) + if __name__ == '__main__': unittest.main() diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py 
index d4c3a211eef..591cfe2f0d0 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -10,7 +10,7 @@ import torch from torch import Tensor -from torch.jit.annotations import List +from torch.jit.annotations import List, Tuple try: import accimage @@ -423,6 +423,7 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor: img (PIL Image or Tensor): Image to be cropped. output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int it is used for both directions + Returns: PIL Image or Tensor: Cropped image. """ @@ -430,6 +431,7 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor: output_size = (int(output_size), int(output_size)) elif isinstance(output_size, (tuple, list)) and len(output_size) == 1: output_size = (output_size[0], output_size[0]) + image_width, image_height = _get_image_size(img) crop_height, crop_width = output_size @@ -589,8 +591,10 @@ def vflip(img: Tensor) -> Tensor: return F_t.vflip(img) -def five_crop(img, size): - """Crop the given PIL Image into four corners and the central crop. +def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + """Crop the given image into four corners and the central crop. + The image can be a PIL Image or a torch Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions .. Note:: This transform returns a tuple of images and there may be a @@ -607,22 +611,26 @@ def five_crop(img, size): """ if isinstance(size, numbers.Number): size = (int(size), int(size)) - else: - assert len(size) == 2, "Please provide only two dimensions (h, w) for size." + elif isinstance(size, (tuple, list)) and len(size) == 1: + size = (size[0], size[0]) + + if len(size) != 2: + raise ValueError("Please provide only two dimensions (h, w) for size.") - image_width, image_height = img.size + image_width, image_height = _get_image_size(img) crop_height, crop_width = size if crop_width > image_width or crop_height > image_height: msg = "Requested crop size {} is bigger than input size {}" raise ValueError(msg.format(size, (image_height, image_width))) - tl = img.crop((0, 0, crop_width, crop_height)) - tr = img.crop((image_width - crop_width, 0, image_width, crop_height)) - bl = img.crop((0, image_height - crop_height, crop_width, image_height)) - br = img.crop((image_width - crop_width, image_height - crop_height, - image_width, image_height)) - center = center_crop(img, (crop_height, crop_width)) - return (tl, tr, bl, br, center) + tl = crop(img, 0, 0, crop_height, crop_width) + tr = crop(img, 0, image_width - crop_width, crop_height, crop_width) + bl = crop(img, image_height - crop_height, 0, crop_height, crop_width) + br = crop(img, image_height - crop_height, image_width - crop_width, crop_height, crop_width) + + center = center_crop(img, [crop_height, crop_width]) + + return tl, tr, bl, br, center def ten_crop(img, size, vertical_flip=False): diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 342341ed00c..c53c177f3b5 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -260,7 +260,7 @@ class CenterCrop(torch.nn.Module): Args: size (sequence or int): Desired output size of the crop. If size is an int instead of sequence like (h, w), a square crop (size, size) is - made. For scripted operation please use a list: (size, ) or (size_x, size_y) + made. 
For scripted operation, please use a list: (size, ) or (size_x, size_y) """ def __init__(self, size): @@ -270,6 +270,9 @@ def __init__(self, size): elif isinstance(size, (tuple, list)) and len(size) == 1: self.size = (size[0], size[0]) else: + if len(size) != 2: + raise ValueError("Please provide only two dimensions (h, w) for size.") + self.size = size def forward(self, img): @@ -572,7 +575,7 @@ def __repr__(self): class RandomVerticalFlip(torch.nn.Module): - """Vertically flip the given PIL Image randomly with a given probability. + """Vertically flip the given image randomly with a given probability. The image can be a PIL Image or a torch Tensor, in which case it is expected to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions @@ -769,8 +772,11 @@ def __init__(self, *args, **kwargs): super(RandomSizedCrop, self).__init__(*args, **kwargs) -class FiveCrop(object): - """Crop the given PIL Image into four corners and the central crop +class FiveCrop(torch.nn.Module): + """Crop the given image into four corners and the central crop. + The image can be a PIL Image or a torch Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading + dimensions .. Note:: This transform returns a tuple of images and there may be a mismatch in the number of @@ -780,6 +786,7 @@ class FiveCrop(object): Args: size (sequence or int): Desired output size of the crop. If size is an ``int`` instead of sequence like (h, w), a square crop of size (size, size) is made. + For scripted operation, please use a list: (size, ) or (size_x, size_y) Example: >>> transform = Compose([ @@ -794,14 +801,26 @@ class FiveCrop(object): """ def __init__(self, size): + super().__init__() self.size = size if isinstance(size, numbers.Number): self.size = (int(size), int(size)) + elif isinstance(size, (tuple, list)) and len(size) == 1: + self.size = (size[0], size[0]) else: - assert len(size) == 2, "Please provide only two dimensions (h, w) for size." + if len(size) != 2: + raise ValueError("Please provide only two dimensions (h, w) for size.") + self.size = size - def __call__(self, img): + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be cropped. + + Returns: + PIL Image or Tensor: Cropped image. 
+ """ return F.five_crop(img, self.size) def __repr__(self): From b356e8b72664149e4dc8158057d6772be233377c Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 26 Jun 2020 13:14:37 +0200 Subject: [PATCH 06/11] Improved tests and docs --- test/test_transforms_tensor.py | 48 ++++++++++++++++++--- torchvision/transforms/functional.py | 37 ++++++++-------- torchvision/transforms/functional_pil.py | 11 +++-- torchvision/transforms/functional_tensor.py | 41 ++++++++---------- torchvision/transforms/transforms.py | 43 ++++++++++-------- 5 files changed, 111 insertions(+), 69 deletions(-) diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 081087088ce..bc4814ce8dd 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -101,11 +101,28 @@ def test_pad(self): def test_crop(self): fn_kwargs = {"top": 2, "left": 3, "height": 4, "width": 5} - meth_kwargs = {"size": (4, 5), "padding": [4, ], "pad_if_needed": True, } + # Test transforms.RandomCrop with size and padding as tuple + meth_kwargs = {"size": (4, 5), "padding": (4, 4), "pad_if_needed": True, } self._test_geom_op( 'crop', 'RandomCrop', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs ) + tensor = torch.randint(0, 255, (3, 10, 10), dtype=torch.uint8) + # Test torchscript of transforms.RandomCrop with size as int + f = T.RandomCrop(size=5) + scripted_fn = torch.jit.script(f) + scripted_fn(tensor) + + # Test torchscript of transforms.RandomCrop with size as [int, ] + f = T.RandomCrop(size=[5, ], padding=[2, ]) + scripted_fn = torch.jit.script(f) + scripted_fn(tensor) + + # Test torchscript of transforms.RandomCrop with size as list + f = T.RandomCrop(size=[6, 6]) + scripted_fn = torch.jit.script(f) + scripted_fn(tensor) + def test_center_crop(self): fn_kwargs = {"output_size": (4, 5)} meth_kwargs = {"size": (4, 5), } @@ -117,6 +134,21 @@ def test_center_crop(self): self._test_geom_op( "center_crop", "CenterCrop", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs ) + tensor = torch.randint(0, 255, (3, 10, 10), dtype=torch.uint8) + # Test torchscript of transforms.CenterCrop with size as int + f = T.CenterCrop(size=5) + scripted_fn = torch.jit.script(f) + scripted_fn(tensor) + + # Test torchscript of transforms.CenterCrop with size as [int, ] + f = T.CenterCrop(size=[5, ]) + scripted_fn = torch.jit.script(f) + scripted_fn(tensor) + + # Test torchscript of transforms.CenterCrop with size as tuple + f = T.CenterCrop(size=(6, 6)) + scripted_fn = torch.jit.script(f) + scripted_fn(tensor) def _test_geom_op_list_output(self, func, method, out_length, fn_kwargs=None, meth_kwargs=None): if fn_kwargs is None: @@ -146,13 +178,19 @@ def _test_geom_op_list_output(self, func, method, out_length, fn_kwargs=None, me self.assertEqual(len(output), len(transformed_t_list_script)) def test_five_crop(self): - fn_kwargs = {"size": (5,)} - meth_kwargs = {"size": (5, )} + fn_kwargs = meth_kwargs = {"size": (5,)} + self._test_geom_op_list_output( + "five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs + ) + fn_kwargs = meth_kwargs = {"size": [5, ]} + self._test_geom_op_list_output( + "five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs + ) + fn_kwargs = meth_kwargs = {"size": (4, 5)} self._test_geom_op_list_output( "five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs ) - fn_kwargs = {"size": (4, 5)} - meth_kwargs = {"size": (4, 5)} + fn_kwargs = meth_kwargs = {"size": [4, 5]} self._test_geom_op_list_output( "five_crop", "FiveCrop", 
out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs ) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 591cfe2f0d0..a046079e0e6 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -395,7 +395,10 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor: - """Crop the given PIL Image. + """Crop the given image at specified location and output size. + The image can be a PIL Image or a Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading + dimensions Args: img (PIL Image or Tensor): Image to be cropped. (0,0) denotes the top left corner of the image. @@ -416,13 +419,13 @@ def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor: def center_crop(img: Tensor, output_size: List[int]) -> Tensor: """Crops the given image at the center. - The image can be a PIL Image or a torch Tensor, in which case it is expected + The image can be a PIL Image or a Tensor, in which case it is expected to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions Args: img (PIL Image or Tensor): Image to be cropped. output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int - it is used for both directions + it is used for both directions. Returns: PIL Image or Tensor: Cropped image. @@ -469,7 +472,7 @@ def resized_crop(img, top, left, height, width, size, interpolation=Image.BILINE def hflip(img: Tensor) -> Tensor: - """Horizontally flip the given PIL Image or torch Tensor. + """Horizontally flip the given PIL Image or Tensor. Args: img (PIL Image or Tensor): Image to be flipped. If img @@ -531,8 +534,7 @@ def _get_perspective_coeffs(startpoints, endpoints): Args: List containing [top-left, top-right, bottom-right, bottom-left] of the original image, - List containing [top-left, top-right, bottom-right, bottom-left] of the transformed - image + List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image Returns: octuple (a, b, c, d, e, f, g, h) for transforming each pixel. """ @@ -577,7 +579,7 @@ def vflip(img: Tensor) -> Tensor: """Vertically flip the given PIL Image or torch Tensor. Args: - img (PIL Image or Torch Tensor): Image to be flipped. If img + img (PIL Image or Tensor): Image to be flipped. If img is a Tensor, it is expected to be in [..., H, W] format, where ... means it can have an arbitrary number of trailing dimensions. @@ -593,7 +595,7 @@ def vflip(img: Tensor) -> Tensor: def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: """Crop the given image into four corners and the central crop. - The image can be a PIL Image or a torch Tensor, in which case it is expected + The image can be a PIL Image or a Tensor, in which case it is expected to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions .. Note:: @@ -601,9 +603,10 @@ def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Ten mismatch in the number of inputs and targets your ``Dataset`` returns. Args: - size (sequence or int): Desired output size of the crop. If size is an - int instead of sequence like (h, w), a square crop (size, size) is - made. + img (PIL Image or Tensor): Image to be cropped. + size (sequence or int): Desired output size of the crop. 
If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]). Returns: tuple: tuple (tl, tr, bl, br, center) @@ -673,13 +676,13 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor: """Adjust brightness of an Image. Args: - img (PIL Image or Torch Tensor): Image to be adjusted. + img (PIL Image or Tensor): Image to be adjusted. brightness_factor (float): How much to adjust the brightness. Can be any non negative number. 0 gives a black image, 1 gives the original image while 2 increases the brightness by a factor of 2. Returns: - PIL Image or Torch Tensor: Brightness adjusted image. + PIL Image or Tensor: Brightness adjusted image. """ if not isinstance(img, torch.Tensor): return F_pil.adjust_brightness(img, brightness_factor) @@ -691,13 +694,13 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor: """Adjust contrast of an Image. Args: - img (PIL Image or Torch Tensor): Image to be adjusted. + img (PIL Image or Tensor): Image to be adjusted. contrast_factor (float): How much to adjust the contrast. Can be any non negative number. 0 gives a solid gray image, 1 gives the original image while 2 increases the contrast by a factor of 2. Returns: - PIL Image or Torch Tensor: Contrast adjusted image. + PIL Image or Tensor: Contrast adjusted image. """ if not isinstance(img, torch.Tensor): return F_pil.adjust_contrast(img, contrast_factor) @@ -709,13 +712,13 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor: """Adjust color saturation of an image. Args: - img (PIL Image or Torch Tensor): Image to be adjusted. + img (PIL Image or Tensor): Image to be adjusted. saturation_factor (float): How much to adjust the saturation. 0 will give a black and white image, 1 will give the original image while 2 will enhance the saturation by a factor of 2. Returns: - PIL Image or Torch Tensor: Saturation adjusted image. + PIL Image or Tensor: Saturation adjusted image. """ if not isinstance(img, torch.Tensor): return F_pil.adjust_saturation(img, saturation_factor) diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py index 9fbadd4ef8f..04129976d6c 100644 --- a/torchvision/transforms/functional_pil.py +++ b/torchvision/transforms/functional_pil.py @@ -1,4 +1,5 @@ import numbers +from typing import Any, List import torch try: @@ -10,8 +11,7 @@ @torch.jit.unused -def _is_pil_image(img): - # type: (Any) -> bool +def _is_pil_image(img: Any) -> bool: if accimage is not None: return isinstance(img, (Image.Image, accimage.Image)) else: @@ -19,15 +19,14 @@ def _is_pil_image(img): @torch.jit.unused -def _get_image_size(img): - # type: (Any) -> List[int] +def _get_image_size(img: Any) -> List[int]: if _is_pil_image(img): return img.size raise TypeError("Unexpected type {}".format(type(img))) @torch.jit.unused -def hflip(img): +def hflip(img: Any): """Horizontally flip the given PIL Image. Args: @@ -43,7 +42,7 @@ def hflip(img): @torch.jit.unused -def vflip(img): +def vflip(img: Any): """Vertically flip the given PIL Image. Args: diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 999a0f6eb79..7e0da179654 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -49,7 +49,7 @@ def crop(img: Tensor, top: int, left: int, height: int, width: int): """Crop the given Image Tensor. 
Args: - img (Tensor): Image to be cropped in the form [C, H, W]. (0,0) denotes the top left corner of the image. + img (Tensor): Image to be cropped in the form [..., H, W]. (0,0) denotes the top left corner of the image. top (int): Vertical component of the top left corner of the crop box. left (int): Horizontal component of the top left corner of the crop box. height (int): Height of the crop box. @@ -64,8 +64,7 @@ def crop(img: Tensor, top: int, left: int, height: int, width: int): return img[..., top:top + height, left:left + width] -def rgb_to_grayscale(img): - # type: (Tensor) -> Tensor +def rgb_to_grayscale(img: Tensor) -> Tensor: """Convert the given RGB Image Tensor to Grayscale. For RGB to Grayscale conversion, ITU-R 601-2 luma transform is performed which is L = R * 0.2989 + G * 0.5870 + B * 0.1140 @@ -83,8 +82,7 @@ def rgb_to_grayscale(img): return (0.2989 * img[0] + 0.5870 * img[1] + 0.1140 * img[2]).to(img.dtype) -def adjust_brightness(img, brightness_factor): - # type: (Tensor, float) -> Tensor +def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor: """Adjust brightness of an RGB image. Args: @@ -102,8 +100,7 @@ def adjust_brightness(img, brightness_factor): return _blend(img, torch.zeros_like(img), brightness_factor) -def adjust_contrast(img, contrast_factor): - # type: (Tensor, float) -> Tensor +def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor: """Adjust contrast of an RGB image. Args: @@ -171,8 +168,7 @@ def adjust_hue(img, hue_factor): return img_hue_adj -def adjust_saturation(img, saturation_factor): - # type: (Tensor, float) -> Tensor +def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor: """Adjust color saturation of an RGB image. Args: @@ -190,12 +186,11 @@ def adjust_saturation(img, saturation_factor): return _blend(img, rgb_to_grayscale(img), saturation_factor) -def center_crop(img, output_size): - # type: (Tensor, BroadcastingList2[int]) -> Tensor +def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor: """Crop the Image Tensor and resize it to desired size. Args: - img (Tensor): Image to be cropped. (0,0) denotes the top left corner of the image. + img (Tensor): Image to be cropped. output_size (sequence or int): (height, width) of the crop box. If int, it is used for both directions @@ -213,17 +208,17 @@ def center_crop(img, output_size): return crop(img, crop_top, crop_left, crop_height, crop_width) -def five_crop(img, size): - # type: (Tensor, BroadcastingList2[int]) -> List[Tensor] +def five_crop(img: Tensor, size: BroadcastingList2[int]) -> List[Tensor]: """Crop the given Image Tensor into four corners and the central crop. .. Note:: This transform returns a List of Tensors and there may be a mismatch in the number of inputs and targets your ``Dataset`` returns. Args: - size (sequence or int): Desired output size of the crop. If size is an - int instead of sequence like (h, w), a square crop (size, size) is - made. + img (Tensor): Image to be cropped. + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. 
Returns: List: List (tl, tr, bl, br, center) @@ -249,19 +244,20 @@ def five_crop(img, size): return [tl, tr, bl, br, center] -def ten_crop(img, size, vertical_flip=False): - # type: (Tensor, BroadcastingList2[int], bool) -> List[Tensor] +def ten_crop(img: Tensor, size: BroadcastingList2[int], vertical_flip: bool = False) -> List[Tensor]: """Crop the given Image Tensor into four corners and the central crop plus the flipped version of these (horizontal flipping is used by default). + .. Note:: This transform returns a List of images and there may be a mismatch in the number of inputs and targets your ``Dataset`` returns. Args: - size (sequence or int): Desired output size of the crop. If size is an + img (Tensor): Image to be cropped. + size (sequence or int): Desired output size of the crop. If size is an int instead of sequence like (h, w), a square crop (size, size) is made. - vertical_flip (bool): Use vertical flipping instead of horizontal + vertical_flip (bool): Use vertical flipping instead of horizontal Returns: List: List (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip) @@ -284,8 +280,7 @@ def ten_crop(img, size, vertical_flip=False): return first_five + second_five -def _blend(img1, img2, ratio): - # type: (Tensor, Tensor, float) -> Tensor +def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor: bound = 1 if img1.dtype in [torch.half, torch.float32, torch.float64] else 255 return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index c53c177f3b5..b6033f7e17a 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -260,14 +260,14 @@ class CenterCrop(torch.nn.Module): Args: size (sequence or int): Desired output size of the crop. If size is an int instead of sequence like (h, w), a square crop (size, size) is - made. For scripted operation, please use a list: (size, ) or (size_x, size_y) + made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]). """ def __init__(self, size): super().__init__() if isinstance(size, numbers.Number): self.size = (int(size), int(size)) - elif isinstance(size, (tuple, list)) and len(size) == 1: + elif isinstance(size, Sequence) and len(size) == 1: self.size = (size[0], size[0]) else: if len(size) != 2: @@ -447,26 +447,28 @@ def __call__(self, img): class RandomCrop(torch.nn.Module): """Crop the given image at a random location. - The image can be a PIL Image or a torch Tensor, in which case it is expected + The image can be a PIL Image or a Tensor, in which case it is expected to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions Args: size (sequence or int): Desired output size of the crop. If size is an int instead of sequence like (h, w), a square crop (size, size) is - made. + made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]). padding (int or sequence, optional): Optional padding on each border - of the image. Default is None, i.e no padding. If a sequence of length - 4 is provided, it is used to pad left, top, right, bottom borders - respectively. If a sequence of length 2 is provided, it is used to - pad left/right, top/bottom borders, respectively. + of the image. Default is None. If a single int is provided this + is used to pad all borders. If tuple of length 2 is provided this is the padding + on left/right and top/bottom respectively. 
If a tuple of length 4 is provided + this is the padding for the left, top, right and bottom borders respectively. + In torchscript mode padding as single int is not supported, use a tuple or + list of length 1: ``[padding, ]``. pad_if_needed (boolean): It will pad the image if smaller than the desired size to avoid raising an exception. Since cropping is done after padding, the padding seems to be done at a random offset. - fill: Pixel fill value for constant fill. Default is 0. If a tuple of + fill (int or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively. This value is only used when the padding_mode is constant - padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant. + padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant. - constant: pads with a constant value, this value is specified with fill @@ -489,7 +491,7 @@ def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int """Get parameters for ``crop`` for a random crop. Args: - img (PIL Image): Image to be cropped. + img (PIL Image or Tensor): Image to be cropped. output_size (tuple): Expected output size of the crop. Returns: @@ -504,12 +506,18 @@ def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int j = torch.randint(0, w - tw, size=(1, )).item() return i, j, th, tw - def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'): + def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"): super().__init__() if isinstance(size, numbers.Number): self.size = (int(size), int(size)) + elif isinstance(size, Sequence) and len(size) == 1: + self.size = (size[0], size[0]) else: - self.size = size + if len(size) != 2: + raise ValueError("Please provide only two dimensions (h, w) for size.") + + # cast to tuple for torchscript + self.size = tuple(size) self.padding = padding self.pad_if_needed = pad_if_needed self.fill = fill @@ -541,7 +549,7 @@ def forward(self, img): return F.crop(img, i, j, h, w) def __repr__(self): - return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding) + return self.__class__.__name__ + "(size={0}, padding={1})".format(self.size, self.padding) class RandomHorizontalFlip(torch.nn.Module): @@ -774,7 +782,7 @@ def __init__(self, *args, **kwargs): class FiveCrop(torch.nn.Module): """Crop the given image into four corners and the central crop. - The image can be a PIL Image or a torch Tensor, in which case it is expected + The image can be a PIL Image or a Tensor, in which case it is expected to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions @@ -786,7 +794,7 @@ class FiveCrop(torch.nn.Module): Args: size (sequence or int): Desired output size of the crop. If size is an ``int`` instead of sequence like (h, w), a square crop of size (size, size) is made. - For scripted operation, please use a list: (size, ) or (size_x, size_y) + If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]). 
Example: >>> transform = Compose([ @@ -802,10 +810,9 @@ class FiveCrop(torch.nn.Module): def __init__(self, size): super().__init__() - self.size = size if isinstance(size, numbers.Number): self.size = (int(size), int(size)) - elif isinstance(size, (tuple, list)) and len(size) == 1: + elif isinstance(size, Sequence) and len(size) == 1: self.size = (size[0], size[0]) else: if len(size) != 2: From 94e29a486f23d7cbb38bbcba4280c533f0f45e95 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 26 Jun 2020 14:20:55 +0200 Subject: [PATCH 07/11] Unified TenCrop and F.ten_crop --- test/test_transforms_tensor.py | 18 +++++++++++++++++ torchvision/transforms/functional.py | 18 +++++++++++------ torchvision/transforms/transforms.py | 30 ++++++++++++++++++++-------- 3 files changed, 52 insertions(+), 14 deletions(-) diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index bc4814ce8dd..1d2d92bb3ae 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -195,6 +195,24 @@ def test_five_crop(self): "five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs ) + def test_ten_crop(self): + fn_kwargs = meth_kwargs = {"size": (5,)} + self._test_geom_op_list_output( + "ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs + ) + fn_kwargs = meth_kwargs = {"size": [5, ]} + self._test_geom_op_list_output( + "ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs + ) + fn_kwargs = meth_kwargs = {"size": (4, 5)} + self._test_geom_op_list_output( + "ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs + ) + fn_kwargs = meth_kwargs = {"size": [4, 5]} + self._test_geom_op_list_output( + "ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs + ) + if __name__ == '__main__': unittest.main() diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index a046079e0e6..a64da1e0054 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -636,19 +636,22 @@ def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Ten return tl, tr, bl, br, center -def ten_crop(img, size, vertical_flip=False): - """Generate ten cropped images from the given PIL Image. - Crop the given PIL Image into four corners and the central crop plus the +def ten_crop(img: Tensor, size: List[int], vertical_flip: bool = False) -> List[Tensor]: + """Generate ten cropped images from the given image. + Crop the given image into four corners and the central crop plus the flipped version of these (horizontal flipping is used by default). + The image can be a PIL Image or a Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions .. Note:: This transform returns a tuple of images and there may be a mismatch in the number of inputs and targets your ``Dataset`` returns. Args: + img (PIL Image or Tensor): Image to be cropped. size (sequence or int): Desired output size of the crop. If size is an int instead of sequence like (h, w), a square crop (size, size) is - made. + made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]). 
vertical_flip (bool): Use vertical flipping instead of horizontal Returns: @@ -658,8 +661,11 @@ def ten_crop(img, size, vertical_flip=False): """ if isinstance(size, numbers.Number): size = (int(size), int(size)) - else: - assert len(size) == 2, "Please provide only two dimensions (h, w) for size." + elif isinstance(size, (tuple, list)) and len(size) == 1: + size = (size[0], size[0]) + + if len(size) != 2: + raise ValueError("Please provide only two dimensions (h, w) for size.") first_five = five_crop(img, size) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index b6033f7e17a..90b14ed7de8 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -826,7 +826,7 @@ def forward(self, img): img (PIL Image or Tensor): Image to be cropped. Returns: - PIL Image or Tensor: Cropped image. + tuple of 5 images. Image can be PIL Image or Tensor """ return F.five_crop(img, self.size) @@ -834,9 +834,12 @@ def __repr__(self): return self.__class__.__name__ + '(size={0})'.format(self.size) -class TenCrop(object): - """Crop the given PIL Image into four corners and the central crop plus the flipped version of - these (horizontal flipping is used by default) +class TenCrop(torch.nn.Module): + """Crop the given image into four corners and the central crop plus the flipped version of + these (horizontal flipping is used by default). + The image can be a PIL Image or a Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading + dimensions .. Note:: This transform returns a tuple of images and there may be a mismatch in the number of @@ -846,7 +849,7 @@ class TenCrop(object): Args: size (sequence or int): Desired output size of the crop. If size is an int instead of sequence like (h, w), a square crop (size, size) is - made. + made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]). vertical_flip (bool): Use vertical flipping instead of horizontal Example: @@ -862,15 +865,26 @@ class TenCrop(object): """ def __init__(self, size, vertical_flip=False): - self.size = size + super().__init__() if isinstance(size, numbers.Number): self.size = (int(size), int(size)) + elif isinstance(size, Sequence) and len(size) == 1: + self.size = (size[0], size[0]) else: - assert len(size) == 2, "Please provide only two dimensions (h, w) for size." + if len(size) != 2: + raise ValueError("Please provide only two dimensions (h, w) for size.") + self.size = size self.vertical_flip = vertical_flip - def __call__(self, img): + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be cropped. + + Returns: + tuple of 10 images. Image can be PIL Image or Tensor + """ return F.ten_crop(img, self.size, self.vertical_flip) def __repr__(self): From a74bfbc6a111a13c652ce41c0e08c4ded46d8439 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 26 Jun 2020 15:29:46 +0200 Subject: [PATCH 08/11] Removed useless typing in functional_pil --- torchvision/transforms/functional_pil.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py index 04129976d6c..f1bcda113aa 100644 --- a/torchvision/transforms/functional_pil.py +++ b/torchvision/transforms/functional_pil.py @@ -26,7 +26,7 @@ def _get_image_size(img: Any) -> List[int]: @torch.jit.unused -def hflip(img: Any): +def hflip(img): """Horizontally flip the given PIL Image. 
Args: @@ -42,7 +42,7 @@ def hflip(img: Any): @torch.jit.unused -def vflip(img: Any): +def vflip(img): """Vertically flip the given PIL Image. Args: From 8835cf5321d2c7dfbedc564641b7e7917a7bff6f Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 29 Jun 2020 21:51:27 +0200 Subject: [PATCH 09/11] Updated code according to the review - removed useless torch.jit.export - added missing typing return type - fixed F.F_pil._is_pil_image -> F._is_pil_image --- torchvision/transforms/functional.py | 3 +++ torchvision/transforms/functional_tensor.py | 4 +--- torchvision/transforms/transforms.py | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index a64da1e0054..7b7d4cbcb3d 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -21,6 +21,9 @@ from . import functional_tensor as F_t +_is_pil_image = F_pil._is_pil_image + + @torch.jit.export def _get_image_size(img: Tensor) -> List[int]: """Returns image sizea as (w, h) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 7e0da179654..8324abf5381 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -3,12 +3,10 @@ from torch.jit.annotations import List, BroadcastingList2 -@torch.jit.export def _is_tensor_a_torch_image(x: Tensor) -> bool: return x.ndim >= 2 -@torch.jit.export def _get_image_size(img: Tensor) -> List[int]: if _is_tensor_a_torch_image(img): return [img.shape[-1], img.shape[-2]] @@ -45,7 +43,7 @@ def hflip(img: Tensor) -> Tensor: return img.flip(-1) -def crop(img: Tensor, top: int, left: int, height: int, width: int): +def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor: """Crop the given Image Tensor. Args: diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 90b14ed7de8..6ee266d5f79 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -640,7 +640,7 @@ def __call__(self, img): Returns: PIL Image: Random perspectivley transformed image. """ - if not F.F_pil._is_pil_image(img): + if not F._is_pil_image(img): raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) if random.random() < self.p: From 749a9385954e878283770db4ebf852aa07706ef1 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 29 Jun 2020 22:00:08 +0200 Subject: [PATCH 10/11] Removed useless torch.jit.export --- torchvision/transforms/functional.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 7b7d4cbcb3d..43a07476f8b 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -24,7 +24,6 @@ _is_pil_image = F_pil._is_pil_image -@torch.jit.export def _get_image_size(img: Tensor) -> List[int]: """Returns image sizea as (w, h) """ From dbb2295b0a6d46ffc7add6e1f05e88559ee86e7c Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 29 Jun 2020 22:00:08 +0200 Subject: [PATCH 11/11] Improved code according to the review --- torchvision/transforms/functional.py | 7 ++++--- torchvision/transforms/functional_tensor.py | 10 ++++++++-- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 43a07476f8b..cda26348552 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -33,12 +33,12 @@ def _get_image_size(img: Tensor) -> List[int]: return F_pil._get_image_size(img) -@torch.jit.ignore +@torch.jit.unused def _is_numpy(img: Any) -> bool: return isinstance(img, np.ndarray) -@torch.jit.ignore +@torch.jit.unused def _is_numpy_image(img: Any) -> bool: return img.ndim in {2, 3} @@ -442,11 +442,12 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor: # crop_top = int(round((image_height - crop_height) / 2.)) # Result can be different between python func and scripted func + # Temporary workaround: crop_top = int((image_height - crop_height + 1) * 0.5) # crop_left = int(round((image_width - crop_width) / 2.)) # Result can be different between python func and scripted func + # Temporary workaround: crop_left = int((image_width - crop_width + 1) * 0.5) - return crop(img, crop_top, crop_left, crop_height, crop_width) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 8324abf5381..980e67d692f 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -200,8 +200,14 @@ def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor: _, image_width, image_height = img.size() crop_height, crop_width = output_size - crop_top = int(round((image_height - crop_height) / 2.)) - crop_left = int(round((image_width - crop_width) / 2.)) + # crop_top = int(round((image_height - crop_height) / 2.)) + # Result can be different between python func and scripted func + # Temporary workaround: + crop_top = int((image_height - crop_height + 1) * 0.5) + # crop_left = int(round((image_width - crop_width) / 2.)) + # Result can be different between python func and scripted func + # Temporary workaround: + crop_left = int((image_width - crop_width + 1) * 0.5) return crop(img, crop_top, crop_left, crop_height, crop_width)
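
A minimal usage sketch of the unified crop API (not part of any patch above; it
assumes a torchvision build that already includes this series, and the tensor
shape and crop parameters are illustrative only): the functional ops and the
class transforms now accept either a PIL Image or a Tensor, and the class
transforms can be compiled with torch.jit.script.

    >>> import torch
    >>> from torchvision import transforms as T
    >>> from torchvision.transforms import functional as F
    >>> # A CHW uint8 tensor stands in for a PIL Image here.
    >>> img = torch.randint(0, 256, (3, 10, 10), dtype=torch.uint8)
    >>> F.crop(img, top=2, left=3, height=4, width=5).shape
    torch.Size([3, 4, 5])
    >>> # Class interface: pass size as a sequence so the module also scripts cleanly.
    >>> random_crop = torch.jit.script(T.RandomCrop(size=(4, 5)))
    >>> random_crop(img).shape
    torch.Size([3, 4, 5])
    >>> ten_crop = torch.jit.script(T.TenCrop(size=[4, 5]))
    >>> crops = ten_crop(img)  # corners, center and their flipped versions
    >>> len(crops)
    10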