From 36fef0d6182b6075900094095e8514b2a4cf88ad Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Tue, 21 Jul 2020 09:29:45 +0200 Subject: [PATCH 1/5] Added code for F_t.rotate with test - updated F.affine tests --- test/test_functional_tensor.py | 57 ++++++-- torchvision/transforms/functional.py | 88 +++++++----- torchvision/transforms/functional_pil.py | 31 +++++ torchvision/transforms/functional_tensor.py | 147 ++++++++++++++++---- 4 files changed, 252 insertions(+), 71 deletions(-) diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 7b4b9b490da..909629a2806 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -406,7 +406,7 @@ def test_affine(self): ) # 3) Test translation test_configs = [ - [10, 12], (12, 13) + [10, 12], (-12, -13) ] for t in test_configs: for fn in [F.affine, scripted_affine]: @@ -418,21 +418,21 @@ def test_affine(self): test_configs = [ (45, [5, 6], 1.0, [0.0, 0.0]), (33, (5, -4), 1.0, [0.0, 0.0]), - (45, [5, 4], 1.2, [0.0, 0.0]), - (33, (4, 8), 2.0, [0.0, 0.0]), + (45, [-5, 4], 1.2, [0.0, 0.0]), + (33, (-4, -8), 2.0, [0.0, 0.0]), (85, (10, -10), 0.7, [0.0, 0.0]), (0, [0, 0], 1.0, [35.0, ]), (25, [0, 0], 1.2, [0.0, 15.0]), - (45, [10, 0], 0.7, [2.0, 5.0]), - (45, [10, -10], 1.2, [4.0, 5.0]), + (45, [-10, 0], 0.7, [2.0, 5.0]), + (45, [-10, -10], 1.2, [4.0, 5.0]), ] for r in [0, ]: for a, t, s, sh in test_configs: + out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, resample=r) + out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) + for fn in [F.affine, scripted_affine]: out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, resample=r) - out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, resample=r) - out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) - num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0 ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2] # Tolerance : less than 5% of different pixels @@ -444,6 +444,47 @@ def test_affine(self): ) ) + def test_rotate(self): + # Tests on square image + tensor, pil_img = self._create_data(26, 26) + scripted_rotate = torch.jit.script(F.rotate) + + img_size = pil_img.size + + centers = [ + None, + (int(img_size[0] * 0.3), int(img_size[0] * 0.4)), + [int(img_size[0] * 0.5), int(img_size[0] * 0.6)] + ] + + for r in [0, ]: + for a in range(-120, 120, 23): + for e in [True, False]: + for c in centers: + + out_pil_img = F.rotate(pil_img, angle=a, resample=r, expand=e, center=c) + out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) + for fn in [F.rotate, scripted_rotate]: + out_tensor = fn(tensor, angle=a, resample=r, expand=e, center=c) + + self.assertEqual( + out_tensor.shape, + out_pil_tensor.shape, + msg="{}: {} vs {}".format( + (r, a, e, c), out_tensor.shape, out_pil_tensor.shape + ) + ) + num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0 + ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2] + # Tolerance : less than 1% of different pixels + self.assertLess( + ratio_diff_pixels, + 0.01, + msg="{}: {}\n{} vs \n{}".format( + (r, a, e, c), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7] + ) + ) + if __name__ == '__main__': unittest.main() diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 340592a01f0..f62ea49b382 100644 --- a/torchvision/transforms/functional.py +++ 
b/torchvision/transforms/functional.py @@ -758,40 +758,8 @@ def adjust_gamma(img, gamma, gain=1): return img -def rotate(img, angle, resample=False, expand=False, center=None, fill=None): - """Rotate the image by angle. - - - Args: - img (PIL Image): PIL Image to be rotated. - angle (float or int): In degrees degrees counter clockwise order. - resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional): - An optional resampling filter. See `filters`_ for more information. - If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``. - expand (bool, optional): Optional expansion flag. - If true, expands the output image to make it large enough to hold the entire rotated image. - If false or omitted, make the output image the same size as the input image. - Note that the expand flag assumes rotation around the center and no translation. - center (2-tuple, optional): Optional center of rotation. - Origin is the upper left corner. - Default is the center of the image. - fill (n-tuple or int or float): Pixel fill value for area outside the rotated - image. If int or float, the value is used for all bands respectively. - Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``. - - .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters - - """ - if not F_pil._is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) - - opts = _parse_fill(fill, img, '5.2.0') - - return img.rotate(angle, resample, expand, center, **opts) - - def _get_inverse_affine_matrix( - center: List[int], angle: float, translate: List[float], scale: float, shear: List[float] + center: List[float], angle: float, translate: List[float], scale: float, shear: List[float] ) -> List[float]: # Helper method to compute inverse matrix for affine transformation @@ -840,6 +808,56 @@ def _get_inverse_affine_matrix( return matrix +def rotate( + img: Tensor, angle: float, resample: int = 0, expand: bool = False, + center: Optional[List[int]] = None, fill: Optional[int] = None +) -> Tensor: + """Rotate the image by angle. + The image can be a PIL Image or a Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. + + Args: + img (PIL Image or Tensor): image to be rotated. + angle (float or int): rotation angle value in degrees, counter-clockwise. + resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional): + An optional resampling filter. See `filters`_ for more information. + If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``. + expand (bool, optional): Optional expansion flag. + If true, expands the output image to make it large enough to hold the entire rotated image. + If false or omitted, make the output image the same size as the input image. + Note that the expand flag assumes rotation around the center and no translation. + center (list or tuple, optional): Optional center of rotation. Origin is the upper left corner. + Default is the center of the image. + fill (n-tuple or int or float): Pixel fill value for area outside the rotated + image. If int or float, the value is used for all bands respectively. + Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``. + + Returns: + PIL Image or Tensor: Rotated image. + + .. 
_filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters + + """ + if not isinstance(angle, (int, float)): + raise TypeError("Argument angle should be int or float") + + if center is not None and not isinstance(center, (list, tuple)): + raise TypeError("Argument center should be a sequence") + + if not isinstance(img, torch.Tensor): + return F_pil.rotate(img, angle=angle, resample=resample, expand=expand, center=center, fill=fill) + + center_f = [0.0, 0.0] + if center is not None: + img_size = _get_image_size(img) + # Center is normalized to [-1, +1] + center_f = [2.0 * t / s - 1.0 for s, t in zip(img_size, center)] + # due to current incoherence of rotation angle direction between affine and rotate implementations + # we need to set -angle. + matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0]) + return F_t.rotate(img, matrix=matrix, resample=resample, expand=expand, fill=fill) + + def affine( img: Tensor, angle: float, translate: List[int], scale: float, shear: List[float], resample: int = 0, fillcolor: Optional[int] = None @@ -849,7 +867,7 @@ def affine( to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. Args: - img (PIL Image or Tensor): image to be rotated. + img (PIL Image or Tensor): image to transform. angle (float or int): rotation angle in degrees between -180 and 180, clockwise direction. translate (list or tuple of integers): horizontal and vertical translations (post-rotation translation) scale (float): overall scale @@ -911,7 +929,7 @@ def affine( # we need to rescale translate by image size / 2 as its values can be between -1 and 1 translate = [2.0 * t / s for s, t in zip(img_size, translate)] - matrix = _get_inverse_affine_matrix([0, 0], angle, translate, scale, shear) + matrix = _get_inverse_affine_matrix([0.0, 0.0], angle, translate, scale, shear) return F_t.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor) diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py index f165b65f8d8..fd603d83e4d 100644 --- a/torchvision/transforms/functional_pil.py +++ b/torchvision/transforms/functional_pil.py @@ -390,3 +390,34 @@ def affine(img, matrix, resample=0, fillcolor=None): output_size = img.size opts = _parse_fill(fillcolor, img, '5.0.0') return img.transform(output_size, Image.AFFINE, matrix, resample, **opts) + + +@torch.jit.unused +def rotate(img, angle, resample=0, expand=False, center=None, fill=None): + """Rotate PIL image by angle. + + Args: + img (PIL Image): image to be rotated. + angle (float or int): rotation angle value in degrees, counter-clockwise. + resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional): + An optional resampling filter. See `filters`_ for more information. + If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``. + expand (bool, optional): Optional expansion flag. + If true, expands the output image to make it large enough to hold the entire rotated image. + If false or omitted, make the output image the same size as the input image. + Note that the expand flag assumes rotation around the center and no translation. + center (2-tuple, optional): Optional center of rotation. + Origin is the upper left corner. + Default is the center of the image. + fill (n-tuple or int or float): Pixel fill value for area outside the rotated + image. If int or float, the value is used for all bands respectively. + Defaults to 0 for all bands. 
This option is only available for ``pillow>=5.2.0``. + + Returns: + PIL Image: Rotated image. + + .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters + + """ + opts = _parse_fill(fill, img, '5.2.0') + return img.rotate(angle, resample, expand, center, **opts) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 2bd4549059e..d45a65cb3bc 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -1,5 +1,5 @@ import warnings -from typing import Optional +from typing import Optional, Dict, Tuple import torch from torch import Tensor @@ -578,48 +578,32 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor: return img -def affine( - img: Tensor, matrix: List[float], resample: int = 0, fillcolor: Optional[int] = None -) -> Tensor: - """Apply affine transformation on the Tensor image keeping image center invariant. +def _assert_grid_transform_inputs( + img: Tensor, matrix: List[float], resample: int, fillcolor: Optional[int], _interpolation_modes: Dict[int, str] +): + if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)): + raise TypeError("img should be Tensor Image. Got {}".format(type(img))) - Args: - img (Tensor): image to be rotated. - matrix (list of floats): list of 6 float values representing inverse matrix for affine transformation. - resample (int, optional): An optional resampling filter. Default is nearest (=2). Other supported values: - bilinear(=2). - fillcolor (int, optional): this option is not supported for Tensor input. Fill value for the area outside the - transform in the output image is always 0. + if not isinstance(matrix, list): + raise TypeError("Argument matrix should be a list. Got {}".format(type(matrix))) - Returns: - Tensor: Transformed image. - """ - if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)): - raise TypeError('img should be Tensor Image. Got {}'.format(type(img))) + if len(matrix) != 6: + raise ValueError("Argument matrix should have 6 float values") if fillcolor is not None: - warnings.warn("Argument fillcolor is not supported for Tensor input. Fill value is zero") - - _interpolation_modes = { - 0: "nearest", - 2: "bilinear", - } + warnings.warn("Argument fill/fillcolor is not supported for Tensor input. Fill value is zero") if resample not in _interpolation_modes: raise ValueError("This resampling mode is unsupported with Tensor input") - theta = torch.tensor(matrix, dtype=torch.float).reshape(1, 2, 3) - shape = img.shape - grid = affine_grid(theta, size=(1, shape[-3], shape[-2], shape[-1]), align_corners=False) +def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str) -> Tensor: # make image NCHW need_squeeze = False if img.ndim < 4: img = img.unsqueeze(dim=0) need_squeeze = True - mode = _interpolation_modes[resample] - out_dtype = img.dtype need_cast = False if img.dtype not in (torch.float32, torch.float64): @@ -636,3 +620,110 @@ def affine( img = torch.round(img).to(out_dtype) return img + + +def affine( + img: Tensor, matrix: List[float], resample: int = 0, fillcolor: Optional[int] = None +) -> Tensor: + """Apply affine transformation on the Tensor image keeping image center invariant. + + Args: + img (Tensor): image to be rotated. + matrix (list of floats): list of 6 float values representing inverse matrix for affine transformation. + resample (int, optional): An optional resampling filter. Default is nearest (=0). 
Other supported values:
+            bilinear(=2).
+        fillcolor (int, optional): this option is not supported for Tensor input. Fill value for the area outside the
+            transform in the output image is always 0.
+
+    Returns:
+        Tensor: Transformed image.
+    """
+    _interpolation_modes = {
+        0: "nearest",
+        2: "bilinear",
+    }
+
+    _assert_grid_transform_inputs(img, matrix, resample, fillcolor, _interpolation_modes)
+
+    theta = torch.tensor(matrix, dtype=torch.float).reshape(1, 2, 3)
+    shape = img.shape
+    grid = affine_grid(theta, size=(1, shape[-3], shape[-2], shape[-1]), align_corners=False)
+    mode = _interpolation_modes[resample]
+
+    return _apply_grid_transform(img, grid, mode)
+
+
+def _compute_output_size(theta: Tensor, w: int, h: int) -> Tuple[int, int]:
+    point = torch.tensor([0.0, 0.0, 1.0])
+    pts = []
+    for i in [0.0, float(h)]:
+        for j in [0.0, float(w)]:
+            # we need to normalize coordinates according to
+            # [0, s] is mapped to [-1, +1] as theta translation parameters are
+            # normalized like that
+            point[1], point[0] = 2.0 * i / h - 1.0, 2.0 * j / w - 1.0
+            new_point = torch.matmul(theta, point)
+            # denormalize back to w, h:
+            new_point = (new_point + 1.0) * torch.tensor([w, h]) / 2.0
+            pts.append(new_point)
+    pts = torch.stack(pts)
+    min_vals, _ = pts.min(dim=0)
+    max_vals, _ = pts.max(dim=0)
+    size = torch.ceil(max_vals) - torch.floor(min_vals)
+    return int(size[0]), int(size[1])
+
+
+def _expanded_affine_grid(theta: Tensor, w: int, h: int, expand: bool = False) -> Tensor:
+    if expand:
+        ow, oh = _compute_output_size(theta, w, h)
+    else:
+        ow, oh = w, h
+    output_grid = torch.zeros(1, oh, ow, 2)
+
+    d = 0.5  # if not align_corners
+
+    point = torch.tensor([0.0, 0.0, 1.0])
+    for i in range(oh):
+        for j in range(ow):
+            point[1] = (i + d - oh * 0.5) / (0.5 * h)
+            point[0] = (j + d - ow * 0.5) / (0.5 * w)
+            output_grid[0, i, j, :] = torch.matmul(theta, point)
+    return output_grid
+
+
+def rotate(
+    img: Tensor, matrix: List[float], resample: int = 0, expand: bool = False, fill: Optional[int] = None
+) -> Tensor:
+    """Rotate the Tensor image by angle.
+
+    Args:
+        img (Tensor): image to be rotated.
+        matrix (list of floats): list of 6 float values representing inverse matrix for rotation transformation.
+        resample (int, optional): An optional resampling filter. Default is nearest (=0). Other supported values:
+            bilinear(=2).
+        expand (bool, optional): Optional expansion flag.
+            If true, expands the output image to make it large enough to hold the entire rotated image.
+            If false or omitted, make the output image the same size as the input image.
+            Note that the expand flag assumes rotation around the center and no translation.
+        fill (n-tuple or int or float): this option is not supported for Tensor input.
+            Fill value for the area outside the transform in the output image is always 0.
+
+    Returns:
+        Tensor: Rotated image.
+
+    .. 
_filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters + + """ + _interpolation_modes = { + 0: "nearest", + 2: "bilinear", + } + + _assert_grid_transform_inputs(img, matrix, resample, fill, _interpolation_modes) + + theta = torch.tensor(matrix).reshape(2, 3) + shape = img.shape + grid = _expanded_affine_grid(theta, shape[-1], shape[-2], expand=expand) + mode = _interpolation_modes[resample] + + return _apply_grid_transform(img, grid, mode) From 2b98bdc6392e21b0d3941f850b7f67c9b2187e0b Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Tue, 21 Jul 2020 09:40:44 +0200 Subject: [PATCH 2/5] Rotate test tolerance to 2% --- test/test_functional_tensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 909629a2806..5972a078db4 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -476,10 +476,10 @@ def test_rotate(self): ) num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0 ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2] - # Tolerance : less than 1% of different pixels + # Tolerance : less than 2% of different pixels self.assertLess( ratio_diff_pixels, - 0.01, + 0.02, msg="{}: {}\n{} vs \n{}".format( (r, a, e, c), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7] ) From c7231bd5137ba9d26834f07a98a816cb3ba99d83 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Tue, 21 Jul 2020 10:17:13 +0200 Subject: [PATCH 3/5] Fixes failing test --- test/test_transforms.py | 2 +- torchvision/transforms/functional_pil.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 125502a3ad5..eb9efb3347e 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1258,7 +1258,7 @@ def test_rotate(self): x = np.zeros((100, 100, 3), dtype=np.uint8) x[40, 40] = [255, 255, 255] - with self.assertRaises(TypeError): + with self.assertRaisesRegex(TypeError, r"img should be PIL Image"): F.rotate(x, 10) img = F.to_pil_image(x) diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py index fd603d83e4d..d128943c0ed 100644 --- a/torchvision/transforms/functional_pil.py +++ b/torchvision/transforms/functional_pil.py @@ -419,5 +419,8 @@ def rotate(img, angle, resample=0, expand=False, center=None, fill=None): .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters """ + if not _is_pil_image(img): + raise TypeError("img should be PIL Image. Got {}".format(type(img))) + opts = _parse_fill(fill, img, '5.2.0') return img.rotate(angle, resample, expand, center, **opts) From 65a306b682128f584cfe9f3e83eb54504cc5bfd4 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Tue, 21 Jul 2020 09:40:06 +0200 Subject: [PATCH 4/5] [WIP] RandomRotation --- torchvision/transforms/functional.py | 4 +++- torchvision/transforms/transforms.py | 15 +++++++++------ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index f62ea49b382..5c81fd62427 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -830,7 +830,9 @@ def rotate( Default is the center of the image. fill (n-tuple or int or float): Pixel fill value for area outside the rotated image. If int or float, the value is used for all bands respectively. - Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``. 
+ Defaults to 0 for all bands. This option is only available for Pillow>=5.2.0. + This option is not supported for Tensor input. Fill value for the area outside the transform in the output + image is always 0. Returns: PIL Image or Tensor: Rotated image. diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index f7d421d2b83..f19647a884a 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -1097,16 +1097,18 @@ def __repr__(self): return format_string -class RandomRotation(object): +class RandomRotation(torch.nn.Module): """Rotate the image by angle. + The image can be a PIL Image or a Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. Args: degrees (sequence or float or int): Range of degrees to select from. If degrees is a number instead of sequence like (min, max), the range of degrees will be (-degrees, +degrees). - resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional): - An optional resampling filter. See `filters`_ for more information. + resample (int, optional): An optional resampling filter. See `filters`_ for more information. If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. + If input is Tensor, only ``PIL.Image.NEAREST`` and ``PIL.Image.BILINEAR`` are supported. expand (bool, optional): Optional expansion flag. If true, expands the output to make it large enough to hold the entire rotated image. If false or omitted, make the output image the same size as the input image. @@ -1116,13 +1118,14 @@ class RandomRotation(object): Default is the center of the image. fill (n-tuple or int or float): Pixel fill value for area outside the rotated image. If int or float, the value is used for all bands respectively. - Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``. + Defaults to 0 for all bands. This option is only available for Pillow>=5.2.0. .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters """ def __init__(self, degrees, resample=False, expand=False, center=None, fill=None): + super().__init__() if isinstance(degrees, numbers.Number): if degrees < 0: raise ValueError("If degrees is a single number, it must be positive.") @@ -1148,13 +1151,13 @@ def get_params(degrees): return angle - def __call__(self, img): + def forward(self, img): """ Args: img (PIL Image): Image to be rotated. Returns: - PIL Image: Rotated image. + PIL Image or Tensor: Rotated image. 
""" angle = self.get_params(self.degrees) From 170775f22a63f75d6ad3a9479da5c146f1547d61 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Tue, 21 Jul 2020 11:09:31 +0200 Subject: [PATCH 5/5] Unified RandomRotation with tests --- test/test_transforms_tensor.py | 18 ++++++++++++++++ torchvision/transforms/functional.py | 2 +- torchvision/transforms/transforms.py | 32 +++++++++++++++++----------- 3 files changed, 39 insertions(+), 13 deletions(-) diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index fbd3331a490..ece8f03e69e 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -264,6 +264,24 @@ def test_resized_crop(self): out2 = s_transform(tensor) self.assertTrue(out1.equal(out2)) + def test_random_rotate(self): + tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8) + + for center in [(0, 0), [10, 10], None, (56, 44)]: + for expand in [True, False]: + for degrees in [45, 35.0, (-45, 45), [-90.0, 90.0]]: + for interpolation in [NEAREST, BILINEAR]: + transform = T.RandomRotation( + degrees=degrees, resample=interpolation, expand=expand, center=center + ) + s_transform = torch.jit.script(transform) + + torch.manual_seed(12) + out1 = transform(tensor) + torch.manual_seed(12) + out2 = s_transform(tensor) + self.assertTrue(out1.equal(out2)) + if __name__ == '__main__': unittest.main() diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 5c81fd62427..2c77f545efa 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -826,7 +826,7 @@ def rotate( If true, expands the output image to make it large enough to hold the entire rotated image. If false or omitted, make the output image the same size as the input image. Note that the expand flag assumes rotation around the center and no translation. - center (list or tuple, optional): Optional center of rotation. Origin is the upper left corner. + center (list or tuple, optional): Optional center of rotation, (x, y). Origin is the upper left corner. Default is the center of the image. fill (n-tuple or int or float): Pixel fill value for area outside the rotated image. If int or float, the value is used for all bands respectively. diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index f19647a884a..fc276e14f58 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -1113,12 +1113,13 @@ class RandomRotation(torch.nn.Module): If true, expands the output to make it large enough to hold the entire rotated image. If false or omitted, make the output image the same size as the input image. Note that the expand flag assumes rotation around the center and no translation. - center (2-tuple, optional): Optional center of rotation. - Origin is the upper left corner. + center (list or tuple, optional): Optional center of rotation, (x, y). Origin is the upper left corner. Default is the center of the image. fill (n-tuple or int or float): Pixel fill value for area outside the rotated image. If int or float, the value is used for all bands respectively. Defaults to 0 for all bands. This option is only available for Pillow>=5.2.0. + This option is not supported for Tensor input. Fill value for the area outside the transform in the output + image is always 0. .. 
_filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters @@ -1129,39 +1130,46 @@ def __init__(self, degrees, resample=False, expand=False, center=None, fill=None if isinstance(degrees, numbers.Number): if degrees < 0: raise ValueError("If degrees is a single number, it must be positive.") - self.degrees = (-degrees, degrees) + degrees = [-degrees, degrees] else: + if not isinstance(degrees, Sequence): + raise TypeError("degrees should be a sequence of length 2.") if len(degrees) != 2: raise ValueError("If degrees is a sequence, it must be of len 2.") - self.degrees = degrees + + self.degrees = [float(d) for d in degrees] + + if center is not None: + if not isinstance(center, Sequence): + raise TypeError("center should be a sequence of length 2.") + if len(center) != 2: + raise ValueError("center should be a sequence of length 2.") + + self.center = center self.resample = resample self.expand = expand - self.center = center self.fill = fill @staticmethod - def get_params(degrees): + def get_params(degrees: List[float]) -> float: """Get parameters for ``rotate`` for a random rotation. Returns: - sequence: params to be passed to ``rotate`` for random rotation. + float: angle parameter to be passed to ``rotate`` for random rotation. """ - angle = random.uniform(degrees[0], degrees[1]) - + angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item()) return angle def forward(self, img): """ Args: - img (PIL Image): Image to be rotated. + img (PIL Image or Tensor): Image to be rotated. Returns: PIL Image or Tensor: Rotated image. """ - angle = self.get_params(self.degrees) - return F.rotate(img, angle, self.resample, self.expand, self.center, self.fill) def __repr__(self):