diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 1479602b534..1d2d92bb3ae 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -99,6 +99,120 @@ def test_pad(self): "pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs ) + def test_crop(self): + fn_kwargs = {"top": 2, "left": 3, "height": 4, "width": 5} + # Test transforms.RandomCrop with size and padding as tuple + meth_kwargs = {"size": (4, 5), "padding": (4, 4), "pad_if_needed": True, } + self._test_geom_op( + 'crop', 'RandomCrop', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs + ) + + tensor = torch.randint(0, 255, (3, 10, 10), dtype=torch.uint8) + # Test torchscript of transforms.RandomCrop with size as int + f = T.RandomCrop(size=5) + scripted_fn = torch.jit.script(f) + scripted_fn(tensor) + + # Test torchscript of transforms.RandomCrop with size as [int, ] + f = T.RandomCrop(size=[5, ], padding=[2, ]) + scripted_fn = torch.jit.script(f) + scripted_fn(tensor) + + # Test torchscript of transforms.RandomCrop with size as list + f = T.RandomCrop(size=[6, 6]) + scripted_fn = torch.jit.script(f) + scripted_fn(tensor) + + def test_center_crop(self): + fn_kwargs = {"output_size": (4, 5)} + meth_kwargs = {"size": (4, 5), } + self._test_geom_op( + "center_crop", "CenterCrop", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs + ) + fn_kwargs = {"output_size": (5,)} + meth_kwargs = {"size": (5, )} + self._test_geom_op( + "center_crop", "CenterCrop", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs + ) + tensor = torch.randint(0, 255, (3, 10, 10), dtype=torch.uint8) + # Test torchscript of transforms.CenterCrop with size as int + f = T.CenterCrop(size=5) + scripted_fn = torch.jit.script(f) + scripted_fn(tensor) + + # Test torchscript of transforms.CenterCrop with size as [int, ] + f = T.CenterCrop(size=[5, ]) + scripted_fn = torch.jit.script(f) + scripted_fn(tensor) + + # Test torchscript of transforms.CenterCrop with size as tuple + f = T.CenterCrop(size=(6, 6)) + scripted_fn = torch.jit.script(f) + scripted_fn(tensor) + + def _test_geom_op_list_output(self, func, method, out_length, fn_kwargs=None, meth_kwargs=None): + if fn_kwargs is None: + fn_kwargs = {} + if meth_kwargs is None: + meth_kwargs = {} + tensor, pil_img = self._create_data(height=20, width=20) + transformed_t_list = getattr(F, func)(tensor, **fn_kwargs) + transformed_p_list = getattr(F, func)(pil_img, **fn_kwargs) + self.assertEqual(len(transformed_t_list), len(transformed_p_list)) + self.assertEqual(len(transformed_t_list), out_length) + for transformed_tensor, transformed_pil_img in zip(transformed_t_list, transformed_p_list): + self.compareTensorToPIL(transformed_tensor, transformed_pil_img) + + scripted_fn = torch.jit.script(getattr(F, func)) + transformed_t_list_script = scripted_fn(tensor.detach().clone(), **fn_kwargs) + self.assertEqual(len(transformed_t_list), len(transformed_t_list_script)) + self.assertEqual(len(transformed_t_list_script), out_length) + for transformed_tensor, transformed_tensor_script in zip(transformed_t_list, transformed_t_list_script): + self.assertTrue(transformed_tensor.equal(transformed_tensor_script), + msg="{} vs {}".format(transformed_tensor, transformed_tensor_script)) + + # test for class interface + f = getattr(T, method)(**meth_kwargs) + scripted_fn = torch.jit.script(f) + output = scripted_fn(tensor) + self.assertEqual(len(output), len(transformed_t_list_script)) + + def test_five_crop(self): + fn_kwargs = meth_kwargs = {"size": (5,)} + self._test_geom_op_list_output( + 
"five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs + ) + fn_kwargs = meth_kwargs = {"size": [5, ]} + self._test_geom_op_list_output( + "five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs + ) + fn_kwargs = meth_kwargs = {"size": (4, 5)} + self._test_geom_op_list_output( + "five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs + ) + fn_kwargs = meth_kwargs = {"size": [4, 5]} + self._test_geom_op_list_output( + "five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs + ) + + def test_ten_crop(self): + fn_kwargs = meth_kwargs = {"size": (5,)} + self._test_geom_op_list_output( + "ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs + ) + fn_kwargs = meth_kwargs = {"size": [5, ]} + self._test_geom_op_list_output( + "ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs + ) + fn_kwargs = meth_kwargs = {"size": (4, 5)} + self._test_geom_op_list_output( + "ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs + ) + fn_kwargs = meth_kwargs = {"size": [4, 5]} + self._test_geom_op_list_output( + "ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs + ) + if __name__ == '__main__': unittest.main() diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 06a54c6aa5f..cda26348552 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -2,6 +2,7 @@ import numbers import warnings from collections.abc import Iterable +from typing import Any import numpy as np from numpy import sin, cos, tan @@ -9,7 +10,7 @@ import torch from torch import Tensor -from torch.jit.annotations import List +from torch.jit.annotations import List, Tuple try: import accimage @@ -20,18 +21,25 @@ from . import functional_tensor as F_t -def _is_pil_image(img): - if accimage is not None: - return isinstance(img, (Image.Image, accimage.Image)) - else: - return isinstance(img, Image.Image) +_is_pil_image = F_pil._is_pil_image + + +def _get_image_size(img: Tensor) -> List[int]: + """Returns image sizea as (w, h) + """ + if isinstance(img, torch.Tensor): + return F_t._get_image_size(img) + return F_pil._get_image_size(img) -def _is_numpy(img): + +@torch.jit.unused +def _is_numpy(img: Any) -> bool: return isinstance(img, np.ndarray) -def _is_numpy_image(img): +@torch.jit.unused +def _is_numpy_image(img: Any) -> bool: return img.ndim in {2, 3} @@ -46,7 +54,7 @@ def to_tensor(pic): Returns: Tensor: Converted image. """ - if not(_is_pil_image(pic) or _is_numpy(pic)): + if not(F_pil._is_pil_image(pic) or _is_numpy(pic)): raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic))) if _is_numpy(pic) and not _is_numpy_image(pic): @@ -101,7 +109,7 @@ def pil_to_tensor(pic): Returns: Tensor: Converted image. """ - if not(_is_pil_image(pic)): + if not(F_pil._is_pil_image(pic)): raise TypeError('pic should be PIL Image. Got {}'.format(type(pic))) if accimage is not None and isinstance(pic, accimage.Image): @@ -319,7 +327,7 @@ def resize(img, size, interpolation=Image.BILINEAR): Returns: PIL Image: Resized image. """ - if not _is_pil_image(img): + if not F_pil._is_pil_image(img): raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) if not (isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)): raise TypeError('Got inappropriate size arg: {}'.format(size)) @@ -388,41 +396,58 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con return F_t.pad(img, padding=padding, fill=fill, padding_mode=padding_mode) -def crop(img, top, left, height, width): - """Crop the given PIL Image. +def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor: + """Crop the given image at specified location and output size. + The image can be a PIL Image or a Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading + dimensions Args: - img (PIL Image): Image to be cropped. (0,0) denotes the top left corner of the image. + img (PIL Image or Tensor): Image to be cropped. (0,0) denotes the top left corner of the image. top (int): Vertical component of the top left corner of the crop box. left (int): Horizontal component of the top left corner of the crop box. height (int): Height of the crop box. width (int): Width of the crop box. Returns: - PIL Image: Cropped image. + PIL Image or Tensor: Cropped image. """ - if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) - return img.crop((left, top, left + width, top + height)) + if not isinstance(img, torch.Tensor): + return F_pil.crop(img, top, left, height, width) + return F_t.crop(img, top, left, height, width) -def center_crop(img, output_size): - """Crop the given PIL Image and resize it to desired size. + +def center_crop(img: Tensor, output_size: List[int]) -> Tensor: + """Crops the given image at the center. + The image can be a PIL Image or a Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions Args: - img (PIL Image): Image to be cropped. (0,0) denotes the top left corner of the image. - output_size (sequence or int): (height, width) of the crop box. If int, - it is used for both directions + img (PIL Image or Tensor): Image to be cropped. + output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int + it is used for both directions. + Returns: - PIL Image: Cropped image. + PIL Image or Tensor: Cropped image. """ if isinstance(output_size, numbers.Number): output_size = (int(output_size), int(output_size)) - image_width, image_height = img.size + elif isinstance(output_size, (tuple, list)) and len(output_size) == 1: + output_size = (output_size[0], output_size[0]) + + image_width, image_height = _get_image_size(img) crop_height, crop_width = output_size - crop_top = int(round((image_height - crop_height) / 2.)) - crop_left = int(round((image_width - crop_width) / 2.)) + + # crop_top = int(round((image_height - crop_height) / 2.)) + # Result can be different between python func and scripted func + # Temporary workaround: + crop_top = int((image_height - crop_height + 1) * 0.5) + # crop_left = int(round((image_width - crop_width) / 2.)) + # Result can be different between python func and scripted func + # Temporary workaround: + crop_left = int((image_width - crop_width + 1) * 0.5) return crop(img, crop_top, crop_left, crop_height, crop_width) @@ -443,23 +468,23 @@ def resized_crop(img, top, left, height, width, size, interpolation=Image.BILINE Returns: PIL Image: Cropped image. 
""" - assert _is_pil_image(img), 'img should be PIL Image' + assert F_pil._is_pil_image(img), 'img should be PIL Image' img = crop(img, top, left, height, width) img = resize(img, size, interpolation) return img def hflip(img: Tensor) -> Tensor: - """Horizontally flip the given PIL Image or torch Tensor. + """Horizontally flip the given PIL Image or Tensor. Args: - img (PIL Image or Torch Tensor): Image to be flipped. If img + img (PIL Image or Tensor): Image to be flipped. If img is a Tensor, it is expected to be in [..., H, W] format, where ... means it can have an arbitrary number of trailing dimensions. Returns: - PIL Image: Horizontally flipped image. + PIL Image or Tensor: Horizontally flipped image. """ if not isinstance(img, torch.Tensor): return F_pil.hflip(img) @@ -512,8 +537,7 @@ def _get_perspective_coeffs(startpoints, endpoints): Args: List containing [top-left, top-right, bottom-right, bottom-left] of the original image, - List containing [top-left, top-right, bottom-right, bottom-left] of the transformed - image + List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image Returns: octuple (a, b, c, d, e, f, g, h) for transforming each pixel. """ @@ -545,7 +569,7 @@ def perspective(img, startpoints, endpoints, interpolation=Image.BICUBIC, fill=N PIL Image: Perspectively transformed Image. """ - if not _is_pil_image(img): + if not F_pil._is_pil_image(img): raise TypeError('img should be PIL Image. Got {}'.format(type(img))) opts = _parse_fill(fill, img, '5.0.0') @@ -558,7 +582,7 @@ def vflip(img: Tensor) -> Tensor: """Vertically flip the given PIL Image or torch Tensor. Args: - img (PIL Image or Torch Tensor): Image to be flipped. If img + img (PIL Image or Tensor): Image to be flipped. If img is a Tensor, it is expected to be in [..., H, W] format, where ... means it can have an arbitrary number of trailing dimensions. @@ -572,17 +596,20 @@ def vflip(img: Tensor) -> Tensor: return F_t.vflip(img) -def five_crop(img, size): - """Crop the given PIL Image into four corners and the central crop. +def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + """Crop the given image into four corners and the central crop. + The image can be a PIL Image or a Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions .. Note:: This transform returns a tuple of images and there may be a mismatch in the number of inputs and targets your ``Dataset`` returns. Args: - size (sequence or int): Desired output size of the crop. If size is an - int instead of sequence like (h, w), a square crop (size, size) is - made. + img (PIL Image or Tensor): Image to be cropped. + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]). Returns: tuple: tuple (tl, tr, bl, br, center) @@ -590,37 +617,44 @@ def five_crop(img, size): """ if isinstance(size, numbers.Number): size = (int(size), int(size)) - else: - assert len(size) == 2, "Please provide only two dimensions (h, w) for size." 
+ elif isinstance(size, (tuple, list)) and len(size) == 1: + size = (size[0], size[0]) - image_width, image_height = img.size + if len(size) != 2: + raise ValueError("Please provide only two dimensions (h, w) for size.") + + image_width, image_height = _get_image_size(img) crop_height, crop_width = size if crop_width > image_width or crop_height > image_height: msg = "Requested crop size {} is bigger than input size {}" raise ValueError(msg.format(size, (image_height, image_width))) - tl = img.crop((0, 0, crop_width, crop_height)) - tr = img.crop((image_width - crop_width, 0, image_width, crop_height)) - bl = img.crop((0, image_height - crop_height, crop_width, image_height)) - br = img.crop((image_width - crop_width, image_height - crop_height, - image_width, image_height)) - center = center_crop(img, (crop_height, crop_width)) - return (tl, tr, bl, br, center) + tl = crop(img, 0, 0, crop_height, crop_width) + tr = crop(img, 0, image_width - crop_width, crop_height, crop_width) + bl = crop(img, image_height - crop_height, 0, crop_height, crop_width) + br = crop(img, image_height - crop_height, image_width - crop_width, crop_height, crop_width) + + center = center_crop(img, [crop_height, crop_width]) + + return tl, tr, bl, br, center -def ten_crop(img, size, vertical_flip=False): - """Generate ten cropped images from the given PIL Image. - Crop the given PIL Image into four corners and the central crop plus the +def ten_crop(img: Tensor, size: List[int], vertical_flip: bool = False) -> List[Tensor]: + """Generate ten cropped images from the given image. + Crop the given image into four corners and the central crop plus the flipped version of these (horizontal flipping is used by default). + The image can be a PIL Image or a Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions .. Note:: This transform returns a tuple of images and there may be a mismatch in the number of inputs and targets your ``Dataset`` returns. Args: + img (PIL Image or Tensor): Image to be cropped. size (sequence or int): Desired output size of the crop. If size is an int instead of sequence like (h, w), a square crop (size, size) is - made. + made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]). vertical_flip (bool): Use vertical flipping instead of horizontal Returns: @@ -630,8 +664,11 @@ def ten_crop(img, size, vertical_flip=False): """ if isinstance(size, numbers.Number): size = (int(size), int(size)) - else: - assert len(size) == 2, "Please provide only two dimensions (h, w) for size." + elif isinstance(size, (tuple, list)) and len(size) == 1: + size = (size[0], size[0]) + + if len(size) != 2: + raise ValueError("Please provide only two dimensions (h, w) for size.") first_five = five_crop(img, size) @@ -648,13 +685,13 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor: """Adjust brightness of an Image. Args: - img (PIL Image or Torch Tensor): Image to be adjusted. + img (PIL Image or Tensor): Image to be adjusted. brightness_factor (float): How much to adjust the brightness. Can be any non negative number. 0 gives a black image, 1 gives the original image while 2 increases the brightness by a factor of 2. Returns: - PIL Image or Torch Tensor: Brightness adjusted image. + PIL Image or Tensor: Brightness adjusted image. 
""" if not isinstance(img, torch.Tensor): return F_pil.adjust_brightness(img, brightness_factor) @@ -666,13 +703,13 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor: """Adjust contrast of an Image. Args: - img (PIL Image or Torch Tensor): Image to be adjusted. + img (PIL Image or Tensor): Image to be adjusted. contrast_factor (float): How much to adjust the contrast. Can be any non negative number. 0 gives a solid gray image, 1 gives the original image while 2 increases the contrast by a factor of 2. Returns: - PIL Image or Torch Tensor: Contrast adjusted image. + PIL Image or Tensor: Contrast adjusted image. """ if not isinstance(img, torch.Tensor): return F_pil.adjust_contrast(img, contrast_factor) @@ -684,13 +721,13 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor: """Adjust color saturation of an image. Args: - img (PIL Image or Torch Tensor): Image to be adjusted. + img (PIL Image or Tensor): Image to be adjusted. saturation_factor (float): How much to adjust the saturation. 0 will give a black and white image, 1 will give the original image while 2 will enhance the saturation by a factor of 2. Returns: - PIL Image or Torch Tensor: Saturation adjusted image. + PIL Image or Tensor: Saturation adjusted image. """ if not isinstance(img, torch.Tensor): return F_pil.adjust_saturation(img, saturation_factor) @@ -749,7 +786,7 @@ def adjust_gamma(img, gamma, gain=1): while gamma smaller than 1 make dark regions lighter. gain (float): The constant multiplier. """ - if not _is_pil_image(img): + if not F_pil._is_pil_image(img): raise TypeError('img should be PIL Image. Got {}'.format(type(img))) if gamma < 0: @@ -789,7 +826,7 @@ def rotate(img, angle, resample=False, expand=False, center=None, fill=None): .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters """ - if not _is_pil_image(img): + if not F_pil._is_pil_image(img): raise TypeError('img should be PIL Image. Got {}'.format(type(img))) opts = _parse_fill(fill, img, '5.2.0') @@ -870,7 +907,7 @@ def affine(img, angle, translate, scale, shear, resample=0, fillcolor=None): If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``. fillcolor (int): Optional fill color for the area outside the transform in the output image. (Pillow>=5.0.0) """ - if not _is_pil_image(img): + if not F_pil._is_pil_image(img): raise TypeError('img should be PIL Image. Got {}'.format(type(img))) assert isinstance(translate, (tuple, list)) and len(translate) == 2, \ @@ -897,7 +934,7 @@ def to_grayscale(img, num_output_channels=1): if num_output_channels = 3 : returned image is 3 channel with r = g = b """ - if not _is_pil_image(img): + if not F_pil._is_pil_image(img): raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) if num_output_channels == 1: diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py index 3786d0e31a7..f1bcda113aa 100644 --- a/torchvision/transforms/functional_pil.py +++ b/torchvision/transforms/functional_pil.py @@ -1,4 +1,5 @@ import numbers +from typing import Any, List import torch try: @@ -10,13 +11,20 @@ @torch.jit.unused -def _is_pil_image(img): +def _is_pil_image(img: Any) -> bool: if accimage is not None: return isinstance(img, (Image.Image, accimage.Image)) else: return isinstance(img, Image.Image) +@torch.jit.unused +def _get_image_size(img: Any) -> List[int]: + if _is_pil_image(img): + return img.size + raise TypeError("Unexpected type {}".format(type(img))) + + @torch.jit.unused def hflip(img): """Horizontally flip the given PIL Image. @@ -258,3 +266,23 @@ def pad(img, padding, fill=0, padding_mode="constant"): img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode) return Image.fromarray(img) + + +@torch.jit.unused +def crop(img: Image.Image, top: int, left: int, height: int, width: int) -> Image.Image: + """Crop the given PIL Image. + + Args: + img (PIL Image): Image to be cropped. (0,0) denotes the top left corner of the image. + top (int): Vertical component of the top left corner of the crop box. + left (int): Horizontal component of the top left corner of the crop box. + height (int): Height of the crop box. + width (int): Width of the crop box. + + Returns: + PIL Image: Cropped image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + return img.crop((left, top, left + width, top + height)) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 56703d0a1fd..980e67d692f 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -3,12 +3,17 @@ from torch.jit.annotations import List, BroadcastingList2 -def _is_tensor_a_torch_image(input): - return input.ndim >= 2 +def _is_tensor_a_torch_image(x: Tensor) -> bool: + return x.ndim >= 2 -def vflip(img): - # type: (Tensor) -> Tensor +def _get_image_size(img: Tensor) -> List[int]: + if _is_tensor_a_torch_image(img): + return [img.shape[-1], img.shape[-2]] + raise TypeError("Unexpected type {}".format(type(img))) + + +def vflip(img: Tensor) -> Tensor: """Vertically flip the given the Image Tensor. Args: @@ -23,8 +28,7 @@ def vflip(img): return img.flip(-2) -def hflip(img): - # type: (Tensor) -> Tensor +def hflip(img: Tensor) -> Tensor: """Horizontally flip the given the Image Tensor. Args: @@ -39,12 +43,11 @@ def hflip(img): return img.flip(-1) -def crop(img, top, left, height, width): - # type: (Tensor, int, int, int, int) -> Tensor +def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor: """Crop the given Image Tensor. Args: - img (Tensor): Image to be cropped in the form [C, H, W]. (0,0) denotes the top left corner of the image. + img (Tensor): Image to be cropped in the form [..., H, W]. (0,0) denotes the top left corner of the image. top (int): Vertical component of the top left corner of the crop box. left (int): Horizontal component of the top left corner of the crop box. height (int): Height of the crop box. @@ -54,13 +57,12 @@ def crop(img, top, left, height, width): Tensor: Cropped image. 
""" if not _is_tensor_a_torch_image(img): - raise TypeError('tensor is not a torch image.') + raise TypeError("tensor is not a torch image.") return img[..., top:top + height, left:left + width] -def rgb_to_grayscale(img): - # type: (Tensor) -> Tensor +def rgb_to_grayscale(img: Tensor) -> Tensor: """Convert the given RGB Image Tensor to Grayscale. For RGB to Grayscale conversion, ITU-R 601-2 luma transform is performed which is L = R * 0.2989 + G * 0.5870 + B * 0.1140 @@ -78,8 +80,7 @@ def rgb_to_grayscale(img): return (0.2989 * img[0] + 0.5870 * img[1] + 0.1140 * img[2]).to(img.dtype) -def adjust_brightness(img, brightness_factor): - # type: (Tensor, float) -> Tensor +def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor: """Adjust brightness of an RGB image. Args: @@ -97,8 +98,7 @@ def adjust_brightness(img, brightness_factor): return _blend(img, torch.zeros_like(img), brightness_factor) -def adjust_contrast(img, contrast_factor): - # type: (Tensor, float) -> Tensor +def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor: """Adjust contrast of an RGB image. Args: @@ -166,8 +166,7 @@ def adjust_hue(img, hue_factor): return img_hue_adj -def adjust_saturation(img, saturation_factor): - # type: (Tensor, float) -> Tensor +def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor: """Adjust color saturation of an RGB image. Args: @@ -185,12 +184,11 @@ def adjust_saturation(img, saturation_factor): return _blend(img, rgb_to_grayscale(img), saturation_factor) -def center_crop(img, output_size): - # type: (Tensor, BroadcastingList2[int]) -> Tensor +def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor: """Crop the Image Tensor and resize it to desired size. Args: - img (Tensor): Image to be cropped. (0,0) denotes the top left corner of the image. + img (Tensor): Image to be cropped. output_size (sequence or int): (height, width) of the crop box. If int, it is used for both directions @@ -202,23 +200,29 @@ def center_crop(img, output_size): _, image_width, image_height = img.size() crop_height, crop_width = output_size - crop_top = int(round((image_height - crop_height) / 2.)) - crop_left = int(round((image_width - crop_width) / 2.)) + # crop_top = int(round((image_height - crop_height) / 2.)) + # Result can be different between python func and scripted func + # Temporary workaround: + crop_top = int((image_height - crop_height + 1) * 0.5) + # crop_left = int(round((image_width - crop_width) / 2.)) + # Result can be different between python func and scripted func + # Temporary workaround: + crop_left = int((image_width - crop_width + 1) * 0.5) return crop(img, crop_top, crop_left, crop_height, crop_width) -def five_crop(img, size): - # type: (Tensor, BroadcastingList2[int]) -> List[Tensor] +def five_crop(img: Tensor, size: BroadcastingList2[int]) -> List[Tensor]: """Crop the given Image Tensor into four corners and the central crop. .. Note:: This transform returns a List of Tensors and there may be a mismatch in the number of inputs and targets your ``Dataset`` returns. Args: - size (sequence or int): Desired output size of the crop. If size is an - int instead of sequence like (h, w), a square crop (size, size) is - made. + img (Tensor): Image to be cropped. + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. 
Returns: List: List (tl, tr, bl, br, center) @@ -244,19 +248,20 @@ def five_crop(img, size): return [tl, tr, bl, br, center] -def ten_crop(img, size, vertical_flip=False): - # type: (Tensor, BroadcastingList2[int], bool) -> List[Tensor] +def ten_crop(img: Tensor, size: BroadcastingList2[int], vertical_flip: bool = False) -> List[Tensor]: """Crop the given Image Tensor into four corners and the central crop plus the flipped version of these (horizontal flipping is used by default). + .. Note:: This transform returns a List of images and there may be a mismatch in the number of inputs and targets your ``Dataset`` returns. Args: - size (sequence or int): Desired output size of the crop. If size is an + img (Tensor): Image to be cropped. + size (sequence or int): Desired output size of the crop. If size is an int instead of sequence like (h, w), a square crop (size, size) is made. - vertical_flip (bool): Use vertical flipping instead of horizontal + vertical_flip (bool): Use vertical flipping instead of horizontal Returns: List: List (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip) @@ -279,8 +284,7 @@ def ten_crop(img, size, vertical_flip=False): return first_five + second_five -def _blend(img1, img2, ratio): - # type: (Tensor, Tensor, float) -> Tensor +def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor: bound = 1 if img1.dtype in [torch.half, torch.float32, torch.float64] else 255 return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index edf68f63127..6ee266d5f79 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -1,16 +1,19 @@ -import torch import math +import numbers import random +import warnings +from collections.abc import Sequence, Iterable +from typing import Tuple + +import numpy as np +import torch from PIL import Image +from torch import Tensor + try: import accimage except ImportError: accimage = None -import numpy as np -import numbers -import types -from collections.abc import Sequence, Iterable -import warnings from . import functional as F @@ -31,15 +34,6 @@ } -def _get_image_size(img): - if F._is_pil_image(img): - return img.size - elif isinstance(img, torch.Tensor) and img.dim() > 2: - return img.shape[-2:][::-1] - else: - raise TypeError("Unexpected type {}".format(type(img))) - - class Compose(object): """Composes several transforms together. @@ -98,7 +92,7 @@ def __repr__(self): class PILToTensor(object): """Convert a ``PIL Image`` to a tensor of the same type. - Converts a PIL Image (H x W x C) to a torch.Tensor of shape (C x H x W). + Converts a PIL Image (H x W x C) to a Tensor of shape (C x H x W). """ def __call__(self, pic): @@ -258,28 +252,36 @@ def __init__(self, *args, **kwargs): super(Scale, self).__init__(*args, **kwargs) -class CenterCrop(object): - """Crops the given PIL Image at the center. +class CenterCrop(torch.nn.Module): + """Crops the given image at the center. + The image can be a PIL Image or a torch Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions Args: size (sequence or int): Desired output size of the crop. If size is an int instead of sequence like (h, w), a square crop (size, size) is - made. + made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]). 
""" def __init__(self, size): + super().__init__() if isinstance(size, numbers.Number): self.size = (int(size), int(size)) + elif isinstance(size, Sequence) and len(size) == 1: + self.size = (size[0], size[0]) else: + if len(size) != 2: + raise ValueError("Please provide only two dimensions (h, w) for size.") + self.size = size - def __call__(self, img): + def forward(self, img): """ Args: - img (PIL Image): Image to be cropped. + img (PIL Image or Tensor): Image to be cropped. Returns: - PIL Image: Cropped image. + PIL Image or Tensor: Cropped image. """ return F.center_crop(img, self.size) @@ -443,25 +445,30 @@ def __call__(self, img): return t(img) -class RandomCrop(object): - """Crop the given PIL Image at a random location. +class RandomCrop(torch.nn.Module): + """Crop the given image at a random location. + The image can be a PIL Image or a Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading + dimensions Args: size (sequence or int): Desired output size of the crop. If size is an int instead of sequence like (h, w), a square crop (size, size) is - made. + made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]). padding (int or sequence, optional): Optional padding on each border - of the image. Default is None, i.e no padding. If a sequence of length - 4 is provided, it is used to pad left, top, right, bottom borders - respectively. If a sequence of length 2 is provided, it is used to - pad left/right, top/bottom borders, respectively. + of the image. Default is None. If a single int is provided this + is used to pad all borders. If tuple of length 2 is provided this is the padding + on left/right and top/bottom respectively. If a tuple of length 4 is provided + this is the padding for the left, top, right and bottom borders respectively. + In torchscript mode padding as single int is not supported, use a tuple or + list of length 1: ``[padding, ]``. pad_if_needed (boolean): It will pad the image if smaller than the desired size to avoid raising an exception. Since cropping is done after padding, the padding seems to be done at a random offset. - fill: Pixel fill value for constant fill. Default is 0. If a tuple of + fill (int or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively. This value is only used when the padding_mode is constant - padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant. + padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant. - constant: pads with a constant value, this value is specified with fill @@ -479,60 +486,70 @@ class RandomCrop(object): """ - def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'): - if isinstance(size, numbers.Number): - self.size = (int(size), int(size)) - else: - self.size = size - self.padding = padding - self.pad_if_needed = pad_if_needed - self.fill = fill - self.padding_mode = padding_mode - @staticmethod - def get_params(img, output_size): + def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int, int]: """Get parameters for ``crop`` for a random crop. Args: - img (PIL Image): Image to be cropped. + img (PIL Image or Tensor): Image to be cropped. output_size (tuple): Expected output size of the crop. Returns: tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. 
""" - w, h = _get_image_size(img) + w, h = F._get_image_size(img) th, tw = output_size if w == tw and h == th: return 0, 0, h, w - i = random.randint(0, h - th) - j = random.randint(0, w - tw) + i = torch.randint(0, h - th, size=(1, )).item() + j = torch.randint(0, w - tw, size=(1, )).item() return i, j, th, tw - def __call__(self, img): + def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"): + super().__init__() + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + elif isinstance(size, Sequence) and len(size) == 1: + self.size = (size[0], size[0]) + else: + if len(size) != 2: + raise ValueError("Please provide only two dimensions (h, w) for size.") + + # cast to tuple for torchscript + self.size = tuple(size) + self.padding = padding + self.pad_if_needed = pad_if_needed + self.fill = fill + self.padding_mode = padding_mode + + def forward(self, img): """ Args: - img (PIL Image): Image to be cropped. + img (PIL Image or Tensor): Image to be cropped. Returns: - PIL Image: Cropped image. + PIL Image or Tensor: Cropped image. """ if self.padding is not None: img = F.pad(img, self.padding, self.fill, self.padding_mode) + width, height = F._get_image_size(img) # pad the width if needed - if self.pad_if_needed and img.size[0] < self.size[1]: - img = F.pad(img, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode) + if self.pad_if_needed and width < self.size[1]: + padding = [self.size[1] - width, 0] + img = F.pad(img, padding, self.fill, self.padding_mode) # pad the height if needed - if self.pad_if_needed and img.size[1] < self.size[0]: - img = F.pad(img, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode) + if self.pad_if_needed and height < self.size[0]: + padding = [0, self.size[0] - height] + img = F.pad(img, padding, self.fill, self.padding_mode) i, j, h, w = self.get_params(img, self.size) return F.crop(img, i, j, h, w) def __repr__(self): - return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding) + return self.__class__.__name__ + "(size={0}, padding={1})".format(self.size, self.padding) class RandomHorizontalFlip(torch.nn.Module): @@ -566,7 +583,7 @@ def __repr__(self): class RandomVerticalFlip(torch.nn.Module): - """Vertically flip the given PIL Image randomly with a given probability. + """Vertically flip the given image randomly with a given probability. The image can be a PIL Image or a torch Tensor, in which case it is expected to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions @@ -702,7 +719,7 @@ def get_params(img, scale, ratio): tuple: params (i, j, h, w) to be passed to ``crop`` for a random sized crop. """ - width, height = _get_image_size(img) + width, height = F._get_image_size(img) area = height * width for _ in range(10): @@ -763,8 +780,11 @@ def __init__(self, *args, **kwargs): super(RandomSizedCrop, self).__init__(*args, **kwargs) -class FiveCrop(object): - """Crop the given PIL Image into four corners and the central crop +class FiveCrop(torch.nn.Module): + """Crop the given image into four corners and the central crop. + The image can be a PIL Image or a Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading + dimensions .. Note:: This transform returns a tuple of images and there may be a mismatch in the number of @@ -774,6 +794,7 @@ class FiveCrop(object): Args: size (sequence or int): Desired output size of the crop. 
If size is an ``int`` instead of sequence like (h, w), a square crop of size (size, size) is made. + If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]). Example: >>> transform = Compose([ @@ -788,23 +809,37 @@ class FiveCrop(object): """ def __init__(self, size): - self.size = size + super().__init__() if isinstance(size, numbers.Number): self.size = (int(size), int(size)) + elif isinstance(size, Sequence) and len(size) == 1: + self.size = (size[0], size[0]) else: - assert len(size) == 2, "Please provide only two dimensions (h, w) for size." + if len(size) != 2: + raise ValueError("Please provide only two dimensions (h, w) for size.") + self.size = size - def __call__(self, img): + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be cropped. + + Returns: + tuple of 5 images. Image can be PIL Image or Tensor + """ return F.five_crop(img, self.size) def __repr__(self): return self.__class__.__name__ + '(size={0})'.format(self.size) -class TenCrop(object): - """Crop the given PIL Image into four corners and the central crop plus the flipped version of - these (horizontal flipping is used by default) +class TenCrop(torch.nn.Module): + """Crop the given image into four corners and the central crop plus the flipped version of + these (horizontal flipping is used by default). + The image can be a PIL Image or a Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading + dimensions .. Note:: This transform returns a tuple of images and there may be a mismatch in the number of @@ -814,7 +849,7 @@ class TenCrop(object): Args: size (sequence or int): Desired output size of the crop. If size is an int instead of sequence like (h, w), a square crop (size, size) is - made. + made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]). vertical_flip (bool): Use vertical flipping instead of horizontal Example: @@ -830,15 +865,26 @@ class TenCrop(object): """ def __init__(self, size, vertical_flip=False): - self.size = size + super().__init__() if isinstance(size, numbers.Number): self.size = (int(size), int(size)) + elif isinstance(size, Sequence) and len(size) == 1: + self.size = (size[0], size[0]) else: - assert len(size) == 2, "Please provide only two dimensions (h, w) for size." + if len(size) != 2: + raise ValueError("Please provide only two dimensions (h, w) for size.") + self.size = size self.vertical_flip = vertical_flip - def __call__(self, img): + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be cropped. + + Returns: + tuple of 10 images. Image can be PIL Image or Tensor + """ return F.ten_crop(img, self.size, self.vertical_flip) def __repr__(self):
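
Usage sketch (not part of the diff above): with these changes the crop-family functionals dispatch on the input type, and the corresponding transforms derive from torch.nn.Module, so they accept tensor images and can be compiled with torch.jit.script. A minimal sketch of the new behaviour, mirroring the tests in test_transforms_tensor.py and assuming a torchvision build with this patch applied (T and F are the usual aliases for torchvision.transforms and torchvision.transforms.functional):

    import torch
    import torchvision.transforms as T
    import torchvision.transforms.functional as F

    img = torch.randint(0, 255, (3, 10, 10), dtype=torch.uint8)  # tensor image in [..., H, W] layout

    # Functional interface: the same calls accept PIL Images or tensors.
    cropped = F.crop(img, top=2, left=3, height=4, width=5)
    centered = F.center_crop(img, output_size=[4, 5])
    five = F.five_crop(img, size=[5, 5])   # (tl, tr, bl, br, center)
    ten = F.ten_crop(img, size=[5, 5])     # the five crops plus their flipped versions

    # Class interface: the transforms are nn.Modules and can be scripted.
    # In torchscript mode, pass size/padding as a sequence (e.g. [5, ]) rather than a bare int.
    scripted_center = torch.jit.script(T.CenterCrop(size=[5, ]))
    scripted_random = torch.jit.script(T.RandomCrop(size=[5, ], padding=[2, ]))
    scripted_five = torch.jit.script(T.FiveCrop(size=(4, 5)))

    out_center = scripted_center(img)   # 3 x 5 x 5 tensor
    out_random = scripted_random(img)   # 3 x 5 x 5 tensor, random location after padding
    out_five = scripted_five(img)       # five 3 x 4 x 5 tensors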