From 1b0f88d8b37e200ce8f37c41e6ab75faec8da14e Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Fri, 26 Nov 2021 02:55:29 -0500 Subject: [PATCH 01/51] Make operations differential Make operations differential w.r.t. hyper-parameters, which is extremely helpful for AutoAugment search. --- torchvision/transforms/functional_tensor.py | 35 +++++++++++---------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 09ae726931c..5e060c81a8f 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -1,5 +1,5 @@ import warnings -from typing import Optional, Tuple, List +from typing import Optional, Tuple, List, Union import torch from torch import Tensor @@ -152,7 +152,7 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor: return l_img -def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor: +def adjust_brightness(img: Tensor, brightness_factor: Union[float, Tensor]) -> Tensor: if brightness_factor < 0: raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.") @@ -163,7 +163,7 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor: return _blend(img, torch.zeros_like(img), brightness_factor) -def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor: +def adjust_contrast(img: Tensor, contrast_factor: Union[float, Tensor]) -> Tensor: if contrast_factor < 0: raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.") @@ -180,7 +180,7 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor: return _blend(img, mean, contrast_factor) -def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: +def adjust_hue(img: Tensor, hue_factor: Union[float, Tensor]) -> Tensor: if not (-0.5 <= hue_factor <= 0.5): raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].") @@ -209,7 +209,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: return img_hue_adj -def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor: +def adjust_saturation(img: Tensor, saturation_factor: Union[float, Tensor]) -> Tensor: if saturation_factor < 0: raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.") @@ -223,7 +223,7 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor: return _blend(img, rgb_to_grayscale(img), saturation_factor) -def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: +def adjust_gamma(img: Tensor, gamma: Union[float, Tensor], gain: float = 1) -> Tensor: if not isinstance(img, torch.Tensor): raise TypeError("Input img should be a Tensor.") @@ -311,8 +311,7 @@ def ten_crop(img: Tensor, size: BroadcastingList2[int], vertical_flip: bool = Fa return first_five + second_five -def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor: - ratio = float(ratio) +def _blend(img1: Tensor, img2: Tensor, ratio: Union[float, Tensor]) -> Tensor: bound = 1.0 if img1.is_floating_point() else 255.0 return (ratio * img1 + (1.0 - ratio) * img2).clamp(0, bound).to(img1.dtype) @@ -562,7 +561,7 @@ def resize( def _assert_grid_transform_inputs( img: Tensor, - matrix: Optional[List[float]], + matrix: Optional[Union[List[float], Tensor]], interpolation: str, fill: Optional[List[float]], supported_interpolation_modes: List[str], @@ -574,8 +573,8 @@ def _assert_grid_transform_inputs( _assert_image_tensor(img) - if matrix is not None and not isinstance(matrix, list): - raise TypeError("Argument matrix should be a list") + if matrix is not None and not isinstance(matrix, (list, Tensor)): + raise TypeError("Argument matrix should be a list or Tensor") if matrix is not None and len(matrix) != 6: raise ValueError("Argument matrix should have 6 float values") @@ -692,12 +691,13 @@ def _gen_affine_grid( def affine( - img: Tensor, matrix: List[float], interpolation: str = "nearest", fill: Optional[List[float]] = None + img: Tensor, matrix: Union[List[float], torch.Tensor], interpolation: str = "nearest", fill: Optional[List[float]] = None ) -> Tensor: _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"]) dtype = img.dtype if torch.is_floating_point(img) else torch.float32 - theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3) + theta = matrix.to(dtype=dtype, device=img.device).reshape(1, 2, 3) if isinstance(matrix, Tensor) \ + else torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3) shape = img.shape # grid will be generated on the same device as theta and img grid = _gen_affine_grid(theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2]) @@ -733,7 +733,7 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int] def rotate( img: Tensor, - matrix: List[float], + matrix: Union[List[float], torch.Tensor], interpolation: str = "nearest", expand: bool = False, fill: Optional[List[float]] = None, @@ -742,7 +742,8 @@ def rotate( w, h = img.shape[-1], img.shape[-2] ow, oh = _compute_output_size(matrix, w, h) if expand else (w, h) dtype = img.dtype if torch.is_floating_point(img) else torch.float32 - theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3) + theta = matrix.to(dtype=dtype, device=img.device).reshape(1, 2, 3) if isinstance(matrix, Tensor) \ + else torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3) # grid will be generated on the same device as theta and img grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh) @@ -873,7 +874,7 @@ def posterize(img: Tensor, bits: int) -> Tensor: return img & mask -def solarize(img: Tensor, threshold: float) -> Tensor: +def solarize(img: Tensor, threshold: Union[float, torch.Tensor]) -> Tensor: _assert_image_tensor(img) @@ -909,7 +910,7 @@ def _blurred_degenerate_image(img: Tensor) -> Tensor: return result -def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor: +def adjust_sharpness(img: Tensor, sharpness_factor: Union[float, torch.Tensor]) -> Tensor: if sharpness_factor < 0: raise ValueError(f"sharpness_factor ({sharpness_factor}) is not non-negative.") From c1fb61408f9333a0a8bf0d352532dea834584789 Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Fri, 26 Nov 2021 04:04:40 -0500 Subject: [PATCH 02/51] Update functional_tensor.py --- torchvision/transforms/functional_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 5e060c81a8f..a7a0025b573 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -704,7 +704,7 @@ def affine( return _apply_grid_transform(img, grid, interpolation, fill=fill) -def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]: +def _compute_output_size(matrix: Union[List[float], Tensor], w: int, h: int) -> Tuple[int, int]: # Inspired of PIL implementation: # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054 From f968b7d77b2cd82c35ef44314ecb75860bbf28b8 Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Sun, 28 Nov 2021 04:07:39 -0500 Subject: [PATCH 03/51] update --- torchvision/transforms/functional.py | 138 +++++++++++++------- torchvision/transforms/functional_tensor.py | 48 ++++--- 2 files changed, 121 insertions(+), 65 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index bd5b170626e..d025dc6028f 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -2,7 +2,7 @@ import numbers import warnings from enum import Enum -from typing import List, Tuple, Any, Optional +from typing import List, Tuple, Any, Optional, Union import numpy as np import torch @@ -772,14 +772,14 @@ def ten_crop(img: Tensor, size: List[int], vertical_flip: bool = False) -> List[ return first_five + second_five -def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor: +def adjust_brightness(img: Tensor, brightness_factor: Union[float, Tensor]) -> Tensor: """Adjust brightness of an image. Args: img (PIL Image or Tensor): Image to be adjusted. If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format, where ... means it can have an arbitrary number of leading dimensions. - brightness_factor (float): How much to adjust the brightness. Can be + brightness_factor (float or Tensor): How much to adjust the brightness. Can be any non negative number. 0 gives a black image, 1 gives the original image while 2 increases the brightness by a factor of 2. @@ -792,14 +792,14 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor: return F_t.adjust_brightness(img, brightness_factor) -def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor: +def adjust_contrast(img: Tensor, contrast_factor: Union[float, Tensor]) -> Tensor: """Adjust contrast of an image. Args: img (PIL Image or Tensor): Image to be adjusted. If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format, where ... means it can have an arbitrary number of leading dimensions. - contrast_factor (float): How much to adjust the contrast. Can be any + contrast_factor (float or Tensor): How much to adjust the contrast. Can be any non negative number. 0 gives a solid gray image, 1 gives the original image while 2 increases the contrast by a factor of 2. @@ -812,14 +812,14 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor: return F_t.adjust_contrast(img, contrast_factor) -def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor: +def adjust_saturation(img: Tensor, saturation_factor: Union[float, Tensor]) -> Tensor: """Adjust color saturation of an image. Args: img (PIL Image or Tensor): Image to be adjusted. If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format, where ... means it can have an arbitrary number of leading dimensions. - saturation_factor (float): How much to adjust the saturation. 0 will + saturation_factor (float or Tensor): How much to adjust the saturation. 0 will give a black and white image, 1 will give the original image while 2 will enhance the saturation by a factor of 2. @@ -832,7 +832,7 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor: return F_t.adjust_saturation(img, saturation_factor) -def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: +def adjust_hue(img: Tensor, hue_factor: Union[float, Tensor]) -> Tensor: """Adjust hue of an image. The image hue is adjusted by converting the image to HSV and @@ -851,7 +851,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format, where ... means it can have an arbitrary number of leading dimensions. If img is PIL Image mode "1", "I", "F" and modes with transparency (alpha channel) are not supported. - hue_factor (float): How much to shift the hue channel. Should be in + hue_factor (float or Tensor): How much to shift the hue channel. Should be in [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in HSV space in positive and negative direction respectively. 0 means no shift. Therefore, both -0.5 and 0.5 will give an image @@ -866,7 +866,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: return F_t.adjust_hue(img, hue_factor) -def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: +def adjust_gamma(img: Tensor, gamma: Union[float, Tensor], gain: Union[float, Tensor] = 1) -> Tensor: r"""Perform gamma correction on an image. Also known as Power Law Transform. Intensities in RGB mode are adjusted @@ -884,10 +884,10 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format, where ... means it can have an arbitrary number of leading dimensions. If img is PIL Image, modes with transparency (alpha channel) are not supported. - gamma (float): Non negative real number, same as :math:`\gamma` in the equation. + gamma (float or Tensor): Non negative real number, same as :math:`\gamma` in the equation. gamma larger than 1 make the shadows darker, while gamma smaller than 1 make dark regions lighter. - gain (float): The constant multiplier. + gain (float or Tensor): The constant multiplier. Returns: PIL Image or Tensor: Gamma correction adjusted image. """ @@ -948,12 +948,61 @@ def _get_inverse_affine_matrix( return matrix +def _get_inverse_affine_matrix_tensor( + center: Union[List[float], Tensor], + angle: Union[float, Tensor], + translate: Union[List[float], Tensor], + scale: Union[float, Tensor], + shear: Union[List[float], Tensor], +) -> Tensor: + device: Optional[torch.device] = None + for element in [center, angle, translate, scale, shear]: + if isinstance(element, Tensor): + if device is None: + device = element.device + + center = center.to(device=device) if isinstance(center, Tensor) else torch.tensor(center, device=device) + angle = angle.to(device=device) if isinstance(angle, Tensor) else torch.tensor(angle, device=device) + translate = translate.to(device=device) if isinstance(translate, Tensor) else torch.tensor(translate, device=device) + scale = scale.to(device=device) if isinstance(scale, Tensor) else torch.tensor(scale, device=device) + shear = shear.to(device=device) if isinstance(shear, Tensor) else torch.tensor(shear, device=device) + + rot = angle * math.pi / 180 + sx = shear[0] * math.pi / 180 + sy = shear[1] * math.pi / 180 + + cx, cy = center + tx, ty = translate + + # RSS without scaling + a = torch.cos(rot - sy) / torch.cos(sy) + b = -torch.cos(rot - sy) * torch.tan(sx) / torch.cos(sy) - torch.sin(rot) + c = torch.sin(rot - sy) / torch.cos(sy) + d = -torch.sin(rot - sy) * torch.tan(sx) / torch.cos(sy) + torch.cos(rot) + + # Inverted rotation matrix with scale and shear + # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1 + zero = torch.zeros(1, device=device) + matrix = torch.cat([d, -b, zero, -c, a, zero]) + matrix = [x / scale for x in matrix] + + # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1 + matrix[2] += matrix[0] * (-cx - tx) + matrix[1] * (-cy - ty) + matrix[5] += matrix[3] * (-cx - tx) + matrix[4] * (-cy - ty) + + # Apply center translation: C * RSS^-1 * C^-1 * T^-1 + matrix[2] += cx + matrix[5] += cy + + return matrix + + def rotate( img: Tensor, - angle: float, + angle: Union[float, Tensor], interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, - center: Optional[List[int]] = None, + center: Optional[Union[List[int], Tensor]] = None, fill: Optional[List[float]] = None, resample: Optional[int] = None, ) -> Tensor: @@ -963,7 +1012,7 @@ def rotate( Args: img (PIL Image or Tensor): image to be rotated. - angle (number): rotation angle value in degrees, counter-clockwise. + angle (number or Tensor): rotation angle value in degrees, counter-clockwise. interpolation (InterpolationMode): Desired interpolation enum defined by :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. @@ -972,7 +1021,7 @@ def rotate( If true, expands the output image to make it large enough to hold the entire rotated image. If false or omitted, make the output image the same size as the input image. Note that the expand flag assumes rotation around the center and no translation. - center (sequence, optional): Optional center of rotation. Origin is the upper left corner. + center (sequence or Tensor, optional): Optional center of rotation. Origin is the upper left corner. Default is the center of the image. fill (sequence or number, optional): Pixel fill value for the area outside the transformed image. If given a number, the value is used for all bands respectively. @@ -1001,10 +1050,10 @@ def rotate( ) interpolation = _interpolation_modes_from_int(interpolation) - if not isinstance(angle, (int, float)): - raise TypeError("Argument angle should be int or float") + if not isinstance(angle, (int, float, Tensor)): + raise TypeError("Argument angle should be int, float or Tensor") - if center is not None and not isinstance(center, (list, tuple)): + if center is not None and not isinstance(center, (list, tuple, Tensor)): raise TypeError("Argument center should be a sequence") if not isinstance(interpolation, InterpolationMode): @@ -1014,24 +1063,25 @@ def rotate( pil_interpolation = pil_modes_mapping[interpolation] return F_pil.rotate(img, angle=angle, interpolation=pil_interpolation, expand=expand, center=center, fill=fill) - center_f = [0.0, 0.0] + center_f = torch.zeros(2, device=img.device, dtype=torch.float) if center is not None: img_size = get_image_size(img) # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. - center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, img_size)] + center_f[0] = 1.0 * (center[0] - img_size[0] * 0.5) + center_f[1] = 1.0 * (center[1] - img_size[1] * 0.5) # due to current incoherence of rotation angle direction between affine and rotate implementations # we need to set -angle. - matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0]) + matrix = _get_inverse_affine_matrix_tensor(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0]) return F_t.rotate(img, matrix=matrix, interpolation=interpolation.value, expand=expand, fill=fill) def affine( img: Tensor, - angle: float, - translate: List[int], - scale: float, - shear: List[float], + angle: Union[float, Tensor], + translate: Union[List[int], Tensor], + scale: Union[float, Tensor], + shear: Union[List[float], Tensor], interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None, resample: Optional[int] = None, @@ -1043,10 +1093,10 @@ def affine( Args: img (PIL Image or Tensor): image to transform. - angle (number): rotation angle in degrees between -180 and 180, clockwise direction. - translate (sequence of integers): horizontal and vertical translations (post-rotation translation) - scale (float): overall scale - shear (float or sequence): shear angle value in degrees between -180 to 180, clockwise direction. + angle (number or Tensor): rotation angle in degrees between -180 and 180, clockwise direction. + translate (sequence of integers or Tensor): horizontal and vertical translations (post-rotation translation) + scale (float or Tensor): overall scale + shear (float or sequence or Tensor): shear angle value in degrees between -180 to 180, clockwise direction. If a sequence is specified, the first value corresponds to a shear parallel to the x axis, while the second value corresponds to a shear parallel to the y axis. interpolation (InterpolationMode): Desired interpolation enum defined by @@ -1085,11 +1135,11 @@ def affine( warnings.warn("Argument fillcolor is deprecated and will be removed since v0.10.0. Please, use fill instead") fill = fillcolor - if not isinstance(angle, (int, float)): - raise TypeError("Argument angle should be int or float") + if not isinstance(angle, (int, float, Tensor)): + raise TypeError("Argument angle should be int, float or Tensor") - if not isinstance(translate, (list, tuple)): - raise TypeError("Argument translate should be a sequence") + if not isinstance(translate, (list, tuple, Tensor)): + raise TypeError("Argument translate should be a sequence or Tensor") if len(translate) != 2: raise ValueError("Argument translate should be a sequence of length 2") @@ -1097,7 +1147,7 @@ def affine( if scale <= 0.0: raise ValueError("Argument scale should be positive") - if not isinstance(shear, (numbers.Number, (list, tuple))): + if not isinstance(shear, (numbers.Number, (list, tuple, Tensor))): raise TypeError("Shear should be either a single value or a sequence of two values") if not isinstance(interpolation, InterpolationMode): @@ -1127,12 +1177,12 @@ def affine( # it is visually better to estimate the center without 0.5 offset # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine center = [img_size[0] * 0.5, img_size[1] * 0.5] - matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear) + matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear).tolist() pil_interpolation = pil_modes_mapping[interpolation] return F_pil.affine(img, matrix=matrix, interpolation=pil_interpolation, fill=fill) translate_f = [1.0 * t for t in translate] - matrix = _get_inverse_affine_matrix([0.0, 0.0], angle, translate_f, scale, shear) + matrix = _get_inverse_affine_matrix_tensor([0.0, 0.0], angle, translate_f, scale, shear) return F_t.affine(img, matrix=matrix, interpolation=interpolation.value, fill=fill) @@ -1204,7 +1254,7 @@ def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool if not inplace: img = img.clone() - img[..., i : i + h, j : j + w] = v + img[..., i: i + h, j: j + w] = v return img @@ -1291,7 +1341,7 @@ def invert(img: Tensor) -> Tensor: return F_t.invert(img) -def posterize(img: Tensor, bits: int) -> Tensor: +def posterize(img: Tensor, bits: Union[int, Tensor]) -> Tensor: """Posterize an image by reducing the number of bits for each color channel. Args: @@ -1300,7 +1350,7 @@ def posterize(img: Tensor, bits: int) -> Tensor: it is expected to be in [..., 1 or 3, H, W] format, where ... means it can have an arbitrary number of leading dimensions. If img is PIL Image, it is expected to be in mode "L" or "RGB". - bits (int): The number of bits to keep for each channel (0-8). + bits (int or Tensor): The number of bits to keep for each channel (0-8). Returns: PIL Image or Tensor: Posterized image. """ @@ -1313,7 +1363,7 @@ def posterize(img: Tensor, bits: int) -> Tensor: return F_t.posterize(img, bits) -def solarize(img: Tensor, threshold: float) -> Tensor: +def solarize(img: Tensor, threshold: Union[float, Tensor]) -> Tensor: """Solarize an RGB/grayscale image by inverting all pixel values above a threshold. Args: @@ -1321,7 +1371,7 @@ def solarize(img: Tensor, threshold: float) -> Tensor: If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format, where ... means it can have an arbitrary number of leading dimensions. If img is PIL Image, it is expected to be in mode "L" or "RGB". - threshold (float): All pixels equal or above this value are inverted. + threshold (float or Tensor): All pixels equal or above this value are inverted. Returns: PIL Image or Tensor: Solarized image. """ @@ -1331,14 +1381,14 @@ def solarize(img: Tensor, threshold: float) -> Tensor: return F_t.solarize(img, threshold) -def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor: +def adjust_sharpness(img: Tensor, sharpness_factor: Union[float, Tensor]) -> Tensor: """Adjust the sharpness of an image. Args: img (PIL Image or Tensor): Image to be adjusted. If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format, where ... means it can have an arbitrary number of leading dimensions. - sharpness_factor (float): How much to adjust the sharpness. Can be + sharpness_factor (float or Tensor): How much to adjust the sharpness. Can be any non negative number. 0 gives a blurred image, 1 gives the original image while 2 increases the sharpness by a factor of 2. diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 09ae726931c..589e4b38e4f 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -1,5 +1,5 @@ import warnings -from typing import Optional, Tuple, List +from typing import Optional, Tuple, List, Union import torch from torch import Tensor @@ -128,7 +128,7 @@ def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor: if left < 0 or top < 0 or right > w or bottom > h: padding_ltrb = [max(-left, 0), max(-top, 0), max(right - w, 0), max(bottom - h, 0)] - return pad(img[..., max(top, 0) : bottom, max(left, 0) : right], padding_ltrb, fill=0) + return pad(img[..., max(top, 0): bottom, max(left, 0): right], padding_ltrb, fill=0) return img[..., top:bottom, left:right] @@ -152,7 +152,7 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor: return l_img -def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor: +def adjust_brightness(img: Tensor, brightness_factor: Union[float, Tensor]) -> Tensor: if brightness_factor < 0: raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.") @@ -163,7 +163,7 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor: return _blend(img, torch.zeros_like(img), brightness_factor) -def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor: +def adjust_contrast(img: Tensor, contrast_factor: Union[float, Tensor]) -> Tensor: if contrast_factor < 0: raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.") @@ -180,7 +180,7 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor: return _blend(img, mean, contrast_factor) -def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: +def adjust_hue(img: Tensor, hue_factor: Union[float, Tensor]) -> Tensor: if not (-0.5 <= hue_factor <= 0.5): raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].") @@ -209,7 +209,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: return img_hue_adj -def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor: +def adjust_saturation(img: Tensor, saturation_factor: Union[float, Tensor]) -> Tensor: if saturation_factor < 0: raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.") @@ -223,7 +223,7 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor: return _blend(img, rgb_to_grayscale(img), saturation_factor) -def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: +def adjust_gamma(img: Tensor, gamma: Union[float, Tensor], gain: Union[float, Tensor] = 1) -> Tensor: if not isinstance(img, torch.Tensor): raise TypeError("Input img should be a Tensor.") @@ -311,8 +311,7 @@ def ten_crop(img: Tensor, size: BroadcastingList2[int], vertical_flip: bool = Fa return first_five + second_five -def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor: - ratio = float(ratio) +def _blend(img1: Tensor, img2: Tensor, ratio: Union[float, Tensor]) -> Tensor: bound = 1.0 if img1.is_floating_point() else 255.0 return (ratio * img1 + (1.0 - ratio) * img2).clamp(0, bound).to(img1.dtype) @@ -384,7 +383,7 @@ def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor: if padding[0] < 0 or padding[1] < 0 or padding[2] < 0 or padding[3] < 0: neg_min_padding = [-min(x, 0) for x in padding] crop_left, crop_right, crop_top, crop_bottom = neg_min_padding - img = img[..., crop_top : img.shape[-2] - crop_bottom, crop_left : img.shape[-1] - crop_right] + img = img[..., crop_top: img.shape[-2] - crop_bottom, crop_left: img.shape[-1] - crop_right] padding = [max(x, 0) for x in padding] in_sizes = img.size() @@ -562,7 +561,7 @@ def resize( def _assert_grid_transform_inputs( img: Tensor, - matrix: Optional[List[float]], + matrix: Optional[Union[List[float], Tensor]], interpolation: str, fill: Optional[List[float]], supported_interpolation_modes: List[str], @@ -574,8 +573,8 @@ def _assert_grid_transform_inputs( _assert_image_tensor(img) - if matrix is not None and not isinstance(matrix, list): - raise TypeError("Argument matrix should be a list") + if matrix is not None and not isinstance(matrix, (list, Tensor)): + raise TypeError("Argument matrix should be a list or Tensor") if matrix is not None and len(matrix) != 6: raise ValueError("Argument matrix should have 6 float values") @@ -692,19 +691,23 @@ def _gen_affine_grid( def affine( - img: Tensor, matrix: List[float], interpolation: str = "nearest", fill: Optional[List[float]] = None + img: Tensor, + matrix: Union[List[float], torch.Tensor], + interpolation: str = "nearest", + fill: Optional[List[float]] = None, ) -> Tensor: _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"]) dtype = img.dtype if torch.is_floating_point(img) else torch.float32 - theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3) + theta = matrix.to(dtype=dtype, device=img.device).reshape(1, 2, 3) if isinstance(matrix, Tensor) \ + else torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3) shape = img.shape # grid will be generated on the same device as theta and img grid = _gen_affine_grid(theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2]) return _apply_grid_transform(img, grid, interpolation, fill=fill) -def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]: +def _compute_output_size(matrix: Union[List[float], Tensor], w: int, h: int) -> Tuple[int, int]: # Inspired of PIL implementation: # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054 @@ -733,7 +736,7 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int] def rotate( img: Tensor, - matrix: List[float], + matrix: Union[List[float], torch.Tensor], interpolation: str = "nearest", expand: bool = False, fill: Optional[List[float]] = None, @@ -742,7 +745,8 @@ def rotate( w, h = img.shape[-1], img.shape[-2] ow, oh = _compute_output_size(matrix, w, h) if expand else (w, h) dtype = img.dtype if torch.is_floating_point(img) else torch.float32 - theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3) + theta = matrix.to(dtype=dtype, device=img.device).reshape(1, 2, 3) if isinstance(matrix, Tensor) \ + else torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3) # grid will be generated on the same device as theta and img grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh) @@ -859,7 +863,7 @@ def invert(img: Tensor) -> Tensor: return bound - img -def posterize(img: Tensor, bits: int) -> Tensor: +def posterize(img: Tensor, bits: Union[int, Tensor]) -> Tensor: _assert_image_tensor(img) @@ -873,7 +877,7 @@ def posterize(img: Tensor, bits: int) -> Tensor: return img & mask -def solarize(img: Tensor, threshold: float) -> Tensor: +def solarize(img: Tensor, threshold: Union[float, Tensor]) -> Tensor: _assert_image_tensor(img) @@ -909,7 +913,7 @@ def _blurred_degenerate_image(img: Tensor) -> Tensor: return result -def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor: +def adjust_sharpness(img: Tensor, sharpness_factor: Union[float, Tensor]) -> Tensor: if sharpness_factor < 0: raise ValueError(f"sharpness_factor ({sharpness_factor}) is not non-negative.") @@ -929,6 +933,8 @@ def autocontrast(img: Tensor) -> Tensor: if img.ndim < 3: raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}") + elif img.ndim == 3: + img = img.unsqueeze(0) _assert_channels(img, [1, 3]) From 14436ca8c0100be7e50aeef9e4259cdfcd000f9e Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Sun, 28 Nov 2021 14:29:10 -0500 Subject: [PATCH 04/51] update --- torchvision/transforms/functional.py | 38 +++++++++++---------- torchvision/transforms/functional_tensor.py | 10 +++--- 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index d025dc6028f..4b62149359e 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -961,11 +961,16 @@ def _get_inverse_affine_matrix_tensor( if device is None: device = element.device - center = center.to(device=device) if isinstance(center, Tensor) else torch.tensor(center, device=device) - angle = angle.to(device=device) if isinstance(angle, Tensor) else torch.tensor(angle, device=device) - translate = translate.to(device=device) if isinstance(translate, Tensor) else torch.tensor(translate, device=device) - scale = scale.to(device=device) if isinstance(scale, Tensor) else torch.tensor(scale, device=device) - shear = shear.to(device=device) if isinstance(shear, Tensor) else torch.tensor(shear, device=device) + center = center.to(device=device).flatten() if isinstance(center, Tensor) \ + else torch.tensor(center, device=device).flatten() + angle = angle.to(device=device).flatten() if isinstance(angle, Tensor) \ + else torch.tensor(angle, device=device).flatten() + translate = translate.to(device=device).flatten() if isinstance(translate, Tensor) \ + else torch.tensor(translate, device=device).flatten() + scale = scale.to(device=device).flatten() if isinstance(scale, Tensor) \ + else torch.tensor(scale, device=device).flatten() + shear = shear.to(device=device).flatten() if isinstance(shear, Tensor) \ + else torch.tensor(shear, device=device).flatten() rot = angle * math.pi / 180 sx = shear[0] * math.pi / 180 @@ -983,18 +988,15 @@ def _get_inverse_affine_matrix_tensor( # Inverted rotation matrix with scale and shear # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1 zero = torch.zeros(1, device=device) - matrix = torch.cat([d, -b, zero, -c, a, zero]) - matrix = [x / scale for x in matrix] + matrix = torch.cat([d, -b, zero, -c, a, zero]) / scale # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1 - matrix[2] += matrix[0] * (-cx - tx) + matrix[1] * (-cy - ty) - matrix[5] += matrix[3] * (-cx - tx) + matrix[4] * (-cy - ty) - # Apply center translation: C * RSS^-1 * C^-1 * T^-1 - matrix[2] += cx - matrix[5] += cy + new_matrix = matrix.clone() + new_matrix[2] = matrix[2] + cx + matrix[0] * (-cx - tx) + matrix[1] * (-cy - ty) + new_matrix[5] = matrix[5] + cy + matrix[3] * (-cx - tx) + matrix[4] * (-cy - ty) - return matrix + return new_matrix def rotate( @@ -1078,10 +1080,10 @@ def rotate( def affine( img: Tensor, - angle: Union[float, Tensor], - translate: Union[List[int], Tensor], - scale: Union[float, Tensor], - shear: Union[List[float], Tensor], + angle: Union[float, Tensor] = 0.0, + translate: Union[List[int], Tensor] = [0.0, 0.0], + scale: Union[float, Tensor] = 1.0, + shear: Union[List[float], Tensor] = [0.0, 0.0], interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None, resample: Optional[int] = None, @@ -1181,7 +1183,7 @@ def affine( pil_interpolation = pil_modes_mapping[interpolation] return F_pil.affine(img, matrix=matrix, interpolation=pil_interpolation, fill=fill) - translate_f = [1.0 * t for t in translate] + translate_f = translate if isinstance(translate, Tensor) else [1.0 * t for t in translate] matrix = _get_inverse_affine_matrix_tensor([0.0, 0.0], angle, translate_f, scale, shear) return F_t.affine(img, matrix=matrix, interpolation=interpolation.value, fill=fill) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 589e4b38e4f..288546e7333 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -941,12 +941,12 @@ def autocontrast(img: Tensor) -> Tensor: bound = 1.0 if img.is_floating_point() else 255.0 dtype = img.dtype if torch.is_floating_point(img) else torch.float32 - minimum = img.amin(dim=(-2, -1), keepdim=True).to(dtype) - maximum = img.amax(dim=(-2, -1), keepdim=True).to(dtype) - eq_idxs = torch.where(minimum == maximum)[0] - minimum[eq_idxs] = 0 - maximum[eq_idxs] = bound + minimum = img.amin(dim=(-2, -1), keepdim=True).to(dtype).clone() + maximum = img.amax(dim=(-2, -1), keepdim=True).to(dtype).clone() scale = bound / (maximum - minimum) + eq_idxs = torch.isfinite(scale).logical_not() + minimum[eq_idxs] = 0 + scale[eq_idxs] = 1 return ((img - minimum) * scale).clamp(0, bound).to(img.dtype) From 7675a541f6e999f9725a8ef564a23d839484ba21 Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Sun, 28 Nov 2021 14:47:12 -0500 Subject: [PATCH 05/51] fix a bug --- torchvision/transforms/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 4b62149359e..14630819381 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1179,7 +1179,7 @@ def affine( # it is visually better to estimate the center without 0.5 offset # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine center = [img_size[0] * 0.5, img_size[1] * 0.5] - matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear).tolist() + matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear) pil_interpolation = pil_modes_mapping[interpolation] return F_pil.affine(img, matrix=matrix, interpolation=pil_interpolation, fill=fill) From bb3221b3a328ee09104fdddaef2831c9a515f467 Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Sun, 28 Nov 2021 14:54:55 -0500 Subject: [PATCH 06/51] format code with ufmt --- torchvision/transforms/functional.py | 27 ++++++++++++++------- torchvision/transforms/functional_tensor.py | 14 ++++++++--- 2 files changed, 28 insertions(+), 13 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 14630819381..e63aaef3bfc 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -961,16 +961,25 @@ def _get_inverse_affine_matrix_tensor( if device is None: device = element.device - center = center.to(device=device).flatten() if isinstance(center, Tensor) \ + center = ( + center.to(device=device).flatten() + if isinstance(center, Tensor) else torch.tensor(center, device=device).flatten() - angle = angle.to(device=device).flatten() if isinstance(angle, Tensor) \ - else torch.tensor(angle, device=device).flatten() - translate = translate.to(device=device).flatten() if isinstance(translate, Tensor) \ + ) + angle = ( + angle.to(device=device).flatten() if isinstance(angle, Tensor) else torch.tensor(angle, device=device).flatten() + ) + translate = ( + translate.to(device=device).flatten() + if isinstance(translate, Tensor) else torch.tensor(translate, device=device).flatten() - scale = scale.to(device=device).flatten() if isinstance(scale, Tensor) \ - else torch.tensor(scale, device=device).flatten() - shear = shear.to(device=device).flatten() if isinstance(shear, Tensor) \ - else torch.tensor(shear, device=device).flatten() + ) + scale = ( + scale.to(device=device).flatten() if isinstance(scale, Tensor) else torch.tensor(scale, device=device).flatten() + ) + shear = ( + shear.to(device=device).flatten() if isinstance(shear, Tensor) else torch.tensor(shear, device=device).flatten() + ) rot = angle * math.pi / 180 sx = shear[0] * math.pi / 180 @@ -1256,7 +1265,7 @@ def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool if not inplace: img = img.clone() - img[..., i: i + h, j: j + w] = v + img[..., i : i + h, j : j + w] = v return img diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 288546e7333..3635a1d580e 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -128,7 +128,7 @@ def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor: if left < 0 or top < 0 or right > w or bottom > h: padding_ltrb = [max(-left, 0), max(-top, 0), max(right - w, 0), max(bottom - h, 0)] - return pad(img[..., max(top, 0): bottom, max(left, 0): right], padding_ltrb, fill=0) + return pad(img[..., max(top, 0) : bottom, max(left, 0) : right], padding_ltrb, fill=0) return img[..., top:bottom, left:right] @@ -383,7 +383,7 @@ def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor: if padding[0] < 0 or padding[1] < 0 or padding[2] < 0 or padding[3] < 0: neg_min_padding = [-min(x, 0) for x in padding] crop_left, crop_right, crop_top, crop_bottom = neg_min_padding - img = img[..., crop_top: img.shape[-2] - crop_bottom, crop_left: img.shape[-1] - crop_right] + img = img[..., crop_top : img.shape[-2] - crop_bottom, crop_left : img.shape[-1] - crop_right] padding = [max(x, 0) for x in padding] in_sizes = img.size() @@ -699,8 +699,11 @@ def affine( _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"]) dtype = img.dtype if torch.is_floating_point(img) else torch.float32 - theta = matrix.to(dtype=dtype, device=img.device).reshape(1, 2, 3) if isinstance(matrix, Tensor) \ + theta = ( + matrix.to(dtype=dtype, device=img.device).reshape(1, 2, 3) + if isinstance(matrix, Tensor) else torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3) + ) shape = img.shape # grid will be generated on the same device as theta and img grid = _gen_affine_grid(theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2]) @@ -745,8 +748,11 @@ def rotate( w, h = img.shape[-1], img.shape[-2] ow, oh = _compute_output_size(matrix, w, h) if expand else (w, h) dtype = img.dtype if torch.is_floating_point(img) else torch.float32 - theta = matrix.to(dtype=dtype, device=img.device).reshape(1, 2, 3) if isinstance(matrix, Tensor) \ + theta = ( + matrix.to(dtype=dtype, device=img.device).reshape(1, 2, 3) + if isinstance(matrix, Tensor) else torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3) + ) # grid will be generated on the same device as theta and img grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh) From 26ea403d07ae33de5dc7f288c41feb4895949948 Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Sun, 28 Nov 2021 16:39:06 -0500 Subject: [PATCH 07/51] change default interpolation mode --- torchvision/transforms/functional.py | 8 ++++---- torchvision/transforms/functional_tensor.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index e63aaef3bfc..5e3f6238178 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1011,7 +1011,7 @@ def _get_inverse_affine_matrix_tensor( def rotate( img: Tensor, angle: Union[float, Tensor], - interpolation: InterpolationMode = InterpolationMode.NEAREST, + interpolation: InterpolationMode = InterpolationMode.BILINEAR, expand: bool = False, center: Optional[Union[List[int], Tensor]] = None, fill: Optional[List[float]] = None, @@ -1025,7 +1025,7 @@ def rotate( img (PIL Image or Tensor): image to be rotated. angle (number or Tensor): rotation angle value in degrees, counter-clockwise. interpolation (InterpolationMode): Desired interpolation enum defined by - :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. expand (bool, optional): Optional expansion flag. @@ -1093,7 +1093,7 @@ def affine( translate: Union[List[int], Tensor] = [0.0, 0.0], scale: Union[float, Tensor] = 1.0, shear: Union[List[float], Tensor] = [0.0, 0.0], - interpolation: InterpolationMode = InterpolationMode.NEAREST, + interpolation: InterpolationMode = InterpolationMode.BILINEAR, fill: Optional[List[float]] = None, resample: Optional[int] = None, fillcolor: Optional[List[float]] = None, @@ -1111,7 +1111,7 @@ def affine( If a sequence is specified, the first value corresponds to a shear parallel to the x axis, while the second value corresponds to a shear parallel to the y axis. interpolation (InterpolationMode): Desired interpolation enum defined by - :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. fill (sequence or number, optional): Pixel fill value for the area outside the transformed diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 3635a1d580e..4e32c48b44d 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -693,7 +693,7 @@ def _gen_affine_grid( def affine( img: Tensor, matrix: Union[List[float], torch.Tensor], - interpolation: str = "nearest", + interpolation: str = "bilinear", fill: Optional[List[float]] = None, ) -> Tensor: _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"]) @@ -740,7 +740,7 @@ def _compute_output_size(matrix: Union[List[float], Tensor], w: int, h: int) -> def rotate( img: Tensor, matrix: Union[List[float], torch.Tensor], - interpolation: str = "nearest", + interpolation: str = "bilinear", expand: bool = False, fill: Optional[List[float]] = None, ) -> Tensor: From 418338bdfc6529daecb85881b6465c4a7711637a Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Sun, 28 Nov 2021 16:41:36 -0500 Subject: [PATCH 08/51] minor update --- torchvision/transforms/functional_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 4e32c48b44d..506b7276699 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -692,7 +692,7 @@ def _gen_affine_grid( def affine( img: Tensor, - matrix: Union[List[float], torch.Tensor], + matrix: Union[List[float], Tensor], interpolation: str = "bilinear", fill: Optional[List[float]] = None, ) -> Tensor: From dff78cfa05b404281a337d8b9f7a915d9ed9856b Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Sun, 28 Nov 2021 16:42:42 -0500 Subject: [PATCH 09/51] minor update 2 --- torchvision/transforms/functional_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 506b7276699..99b9910d80b 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -739,7 +739,7 @@ def _compute_output_size(matrix: Union[List[float], Tensor], w: int, h: int) -> def rotate( img: Tensor, - matrix: Union[List[float], torch.Tensor], + matrix: Union[List[float], Tensor], interpolation: str = "bilinear", expand: bool = False, fill: Optional[List[float]] = None, From d5d0aa74b486b3209148b42d2fa512ea069bf131 Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Tue, 7 Dec 2021 17:33:46 -0500 Subject: [PATCH 10/51] update type --- torchvision/transforms/functional_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index ad3f17b9e7f..ab2ed1860e9 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -16,7 +16,7 @@ def _assert_image_tensor(img: Tensor) -> None: raise TypeError("Tensor is not a torch image.") -def _assert_threshold(img: Tensor, threshold: float) -> None: +def _assert_threshold(img: Tensor, threshold: Union[float, Tensor]) -> None: bound = 1 if img.is_floating_point() else 255 if threshold > bound: raise TypeError("Threshold should be less than bound of img.") From 3187f9760c10b5c4dfb9701f16141748186ca6f7 Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Tue, 7 Dec 2021 17:53:58 -0500 Subject: [PATCH 11/51] try to fix JIT --- test.py | 5 +++++ torchvision/transforms/functional.py | 10 ++++++---- 2 files changed, 11 insertions(+), 4 deletions(-) create mode 100644 test.py diff --git a/test.py b/test.py new file mode 100644 index 00000000000..df08088fec9 --- /dev/null +++ b/test.py @@ -0,0 +1,5 @@ +import torch +import torchvision.transforms.functional as F + + +torch.jit.script(F.rotate) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 5e3f6238178..4986b12cb09 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1013,7 +1013,7 @@ def rotate( angle: Union[float, Tensor], interpolation: InterpolationMode = InterpolationMode.BILINEAR, expand: bool = False, - center: Optional[Union[List[int], Tensor]] = None, + center: Optional[Union[List[float], Tensor]] = None, fill: Optional[List[float]] = None, resample: Optional[int] = None, ) -> Tensor: @@ -1074,12 +1074,14 @@ def rotate( pil_interpolation = pil_modes_mapping[interpolation] return F_pil.rotate(img, angle=angle, interpolation=pil_interpolation, expand=expand, center=center, fill=fill) - center_f = torch.zeros(2, device=img.device, dtype=torch.float) + center_f: Union[List[float], Tensor] = [0.0, 0.0] if center is not None: img_size = get_image_size(img) # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. - center_f[0] = 1.0 * (center[0] - img_size[0] * 0.5) - center_f[1] = 1.0 * (center[1] - img_size[1] * 0.5) + if isinstance(center, Tensor): + center_f = 1.0 * (center - img_size * 0.5) + elif center is not None: + center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, img_size)] # due to current incoherence of rotation angle direction between affine and rotate implementations # we need to set -angle. From 0f963f080889a3e9b02700cd4ca490eb93fa19dd Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Tue, 7 Dec 2021 17:57:57 -0500 Subject: [PATCH 12/51] fix a bug --- torchvision/transforms/functional.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 4986b12cb09..96577912c1d 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1079,6 +1079,7 @@ def rotate( img_size = get_image_size(img) # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. if isinstance(center, Tensor): + img_size = torch.tensor(img_size) center_f = 1.0 * (center - img_size * 0.5) elif center is not None: center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, img_size)] From 5163c1f8538a08abfd9e89c1f6f2d4699e01d3b8 Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Tue, 7 Dec 2021 18:02:47 -0500 Subject: [PATCH 13/51] Delete test.py --- test.py | 5 ----- 1 file changed, 5 deletions(-) delete mode 100644 test.py diff --git a/test.py b/test.py deleted file mode 100644 index df08088fec9..00000000000 --- a/test.py +++ /dev/null @@ -1,5 +0,0 @@ -import torch -import torchvision.transforms.functional as F - - -torch.jit.script(F.rotate) From d13cc80e31ee813d8a2c752d296d36a68652df19 Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Tue, 7 Dec 2021 18:05:09 -0500 Subject: [PATCH 14/51] fix default interpolation mode --- torchvision/transforms/functional.py | 8 ++++---- torchvision/transforms/functional_tensor.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 96577912c1d..3ad41a865fa 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1011,7 +1011,7 @@ def _get_inverse_affine_matrix_tensor( def rotate( img: Tensor, angle: Union[float, Tensor], - interpolation: InterpolationMode = InterpolationMode.BILINEAR, + interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, center: Optional[Union[List[float], Tensor]] = None, fill: Optional[List[float]] = None, @@ -1025,7 +1025,7 @@ def rotate( img (PIL Image or Tensor): image to be rotated. angle (number or Tensor): rotation angle value in degrees, counter-clockwise. interpolation (InterpolationMode): Desired interpolation enum defined by - :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. expand (bool, optional): Optional expansion flag. @@ -1096,7 +1096,7 @@ def affine( translate: Union[List[int], Tensor] = [0.0, 0.0], scale: Union[float, Tensor] = 1.0, shear: Union[List[float], Tensor] = [0.0, 0.0], - interpolation: InterpolationMode = InterpolationMode.BILINEAR, + interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None, resample: Optional[int] = None, fillcolor: Optional[List[float]] = None, @@ -1114,7 +1114,7 @@ def affine( If a sequence is specified, the first value corresponds to a shear parallel to the x axis, while the second value corresponds to a shear parallel to the y axis. interpolation (InterpolationMode): Desired interpolation enum defined by - :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. fill (sequence or number, optional): Pixel fill value for the area outside the transformed diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index ab2ed1860e9..ef574b75481 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -699,7 +699,7 @@ def _gen_affine_grid( def affine( img: Tensor, matrix: Union[List[float], Tensor], - interpolation: str = "bilinear", + interpolation: str = "nearest", fill: Optional[List[float]] = None, ) -> Tensor: _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"]) @@ -746,7 +746,7 @@ def _compute_output_size(matrix: Union[List[float], Tensor], w: int, h: int) -> def rotate( img: Tensor, matrix: Union[List[float], Tensor], - interpolation: str = "bilinear", + interpolation: str = "nearest", expand: bool = False, fill: Optional[List[float]] = None, ) -> Tensor: From 04761e30ee21536a917ff213a95b56c287f89afb Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Tue, 7 Dec 2021 18:10:03 -0500 Subject: [PATCH 15/51] add test --- test/test_functional_tensor.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 8f923475664..790a0bcb2ec 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -154,6 +154,19 @@ def test_rotate_interpolation_type(self): res2 = F.rotate(tensor, 45, interpolation=BILINEAR) assert_equal(res1, res2) + @pytest.mark.parametrize("fn", [F.rotate, scripted_rotate]) + @pytest.mark.parametrize("center", [None, torch.tensor([0.1, 0.2], requires_grad=True)]) + def test_differentiable_rotate(self, fn, center): + alpha = torch.tensor(1.0, requires_grad=True) + x = torch.zeros(1, 3, 10, 10) + + y = fn(x, alpha, interpolation=BILINEAR, center=center) + assert y.requires_grad + y.mean().backward() + assert alpha.grad is not None + if center is not None: + assert center.grad is not None + class TestAffine: From 342b83b64f6df5008ce1ed2fcdc2f0e5389e55d4 Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Tue, 7 Dec 2021 18:10:17 -0500 Subject: [PATCH 16/51] fix --- torchvision/transforms/functional.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 3ad41a865fa..2b82481efaf 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -981,9 +981,9 @@ def _get_inverse_affine_matrix_tensor( shear.to(device=device).flatten() if isinstance(shear, Tensor) else torch.tensor(shear, device=device).flatten() ) - rot = angle * math.pi / 180 - sx = shear[0] * math.pi / 180 - sy = shear[1] * math.pi / 180 + rot = angle * torch.pi / 180.0 + sx = shear[0] * torch.pi / 180.0 + sy = shear[1] * torch.pi / 180.0 cx, cy = center tx, ty = translate @@ -1076,7 +1076,7 @@ def rotate( center_f: Union[List[float], Tensor] = [0.0, 0.0] if center is not None: - img_size = get_image_size(img) + img_size: Union[List[int], Tensor] = get_image_size(img) # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. if isinstance(center, Tensor): img_size = torch.tensor(img_size) From d837f80e722e5854a8c495e5e1d1f37f93d05416 Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Tue, 7 Dec 2021 18:16:52 -0500 Subject: [PATCH 17/51] fix --- torchvision/transforms/functional.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 2b82481efaf..0b783be27f6 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1076,11 +1076,11 @@ def rotate( center_f: Union[List[float], Tensor] = [0.0, 0.0] if center is not None: - img_size: Union[List[int], Tensor] = get_image_size(img) + img_size = get_image_size(img) # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. if isinstance(center, Tensor): - img_size = torch.tensor(img_size) - center_f = 1.0 * (center - img_size * 0.5) + img_size_t = torch.tensor(img_size) + center_f = 1.0 * (center - img_size_t * 0.5) elif center is not None: center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, img_size)] From 5f8d940641199a3f696ecc3ca7ddbe3a23d4949b Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Tue, 7 Dec 2021 18:29:22 -0500 Subject: [PATCH 18/51] fix --- torchvision/transforms/functional.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 0b783be27f6..50bf18729e5 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -985,8 +985,8 @@ def _get_inverse_affine_matrix_tensor( sx = shear[0] * torch.pi / 180.0 sy = shear[1] * torch.pi / 180.0 - cx, cy = center - tx, ty = translate + cx, cy = center[0], center[1] + tx, ty = translate[0], translate[1] # RSS without scaling a = torch.cos(rot - sy) / torch.cos(sy) From 595a529442522960b6bffbd54c14bb2933027db8 Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Tue, 7 Dec 2021 18:56:49 -0500 Subject: [PATCH 19/51] temporary fix --- torchvision/transforms/functional.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 50bf18729e5..9dea664546e 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1086,7 +1086,8 @@ def rotate( # due to current incoherence of rotation angle direction between affine and rotate implementations # we need to set -angle. - matrix = _get_inverse_affine_matrix_tensor(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0]) + matrix = _get_inverse_affine_matrix_tensor(center_f, -angle if isinstance(angle, Tensor) else -angle, + [0.0, 0.0], 1.0, [0.0, 0.0]) return F_t.rotate(img, matrix=matrix, interpolation=interpolation.value, expand=expand, fill=fill) From 17370fc323b5ac173ba0c0f7406396370d1cf2f9 Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Tue, 7 Dec 2021 19:02:57 -0500 Subject: [PATCH 20/51] temporary fix --- torchvision/transforms/functional.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 9dea664546e..0fd1aa5f42b 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1086,8 +1086,11 @@ def rotate( # due to current incoherence of rotation angle direction between affine and rotate implementations # we need to set -angle. - matrix = _get_inverse_affine_matrix_tensor(center_f, -angle if isinstance(angle, Tensor) else -angle, - [0.0, 0.0], 1.0, [0.0, 0.0]) + if isinstance(angle, Tensor): + angle = -angle + else: + angle = -angle + matrix = _get_inverse_affine_matrix_tensor(center_f, angle, [0.0, 0.0], 1.0, [0.0, 0.0]) return F_t.rotate(img, matrix=matrix, interpolation=interpolation.value, expand=expand, fill=fill) From 049c42a2f394330cfd4595e6c36709654f307395 Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Tue, 7 Dec 2021 19:35:04 -0500 Subject: [PATCH 21/51] fix --- torchvision/transforms/functional_tensor.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index ef574b75481..2ec6f33b6b5 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -730,7 +730,9 @@ def _compute_output_size(matrix: Union[List[float], Tensor], w: int, h: int) -> [0.5 * w, -0.5 * h, 1.0], ] ) - theta = torch.tensor(matrix, dtype=torch.float).reshape(1, 2, 3) + if isinstance(matrix, list): + theta = torch.tensor(matrix, dtype=torch.float) + theta = theta.reshape(1, 2, 3) new_pts = pts.view(1, 4, 3).bmm(theta.transpose(1, 2)).view(4, 2) min_vals, _ = new_pts.min(dim=0) max_vals, _ = new_pts.max(dim=0) From 27fb6430cb6f38d519153a5998a53266f5e27ab9 Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Tue, 7 Dec 2021 19:38:37 -0500 Subject: [PATCH 22/51] fix --- torchvision/transforms/functional_tensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 2ec6f33b6b5..8025d61e6f7 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -731,8 +731,8 @@ def _compute_output_size(matrix: Union[List[float], Tensor], w: int, h: int) -> ] ) if isinstance(matrix, list): - theta = torch.tensor(matrix, dtype=torch.float) - theta = theta.reshape(1, 2, 3) + matrix = torch.tensor(matrix, dtype=torch.float) + theta = matrix.reshape(1, 2, 3) new_pts = pts.view(1, 4, 3).bmm(theta.transpose(1, 2)).view(4, 2) min_vals, _ = new_pts.min(dim=0) max_vals, _ = new_pts.max(dim=0) From a1f83851e46d5f02d3898baa4a1598e3f1980e69 Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Tue, 7 Dec 2021 19:42:42 -0500 Subject: [PATCH 23/51] fix --- torchvision/transforms/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 0fd1aa5f42b..b369a327c94 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1162,7 +1162,7 @@ def affine( if len(translate) != 2: raise ValueError("Argument translate should be a sequence of length 2") - if scale <= 0.0: + if float(scale) <= 0.0: raise ValueError("Argument scale should be positive") if not isinstance(shear, (numbers.Number, (list, tuple, Tensor))): From 18fa9713ea2db02901087aa825b07e7e1ca7f434 Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Tue, 7 Dec 2021 19:49:18 -0500 Subject: [PATCH 24/51] fix --- torchvision/transforms/functional.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index b369a327c94..feaad0274aa 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1162,7 +1162,8 @@ def affine( if len(translate) != 2: raise ValueError("Argument translate should be a sequence of length 2") - if float(scale) <= 0.0: + scale_float = scale if isinstance(scale, float) else scale.item() + if scale_float <= 0.0: raise ValueError("Argument scale should be positive") if not isinstance(shear, (numbers.Number, (list, tuple, Tensor))): From 33a7ec74ccaea0e68f1be170a646b916325539ea Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Tue, 7 Dec 2021 19:59:50 -0500 Subject: [PATCH 25/51] fix --- torchvision/transforms/functional.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index feaad0274aa..d8dd820c1bf 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1185,7 +1185,10 @@ def affine( shear = list(shear) if len(shear) == 1: - shear = [shear[0], shear[0]] + if isinstance(shear, list): + shear = [shear[0], shear[0]] + else: + shear = shear.flatten().repeat(2) if len(shear) != 2: raise ValueError(f"Shear should be a sequence containing two values. Got {shear}") From fc2f6741ea778d88e6bfbfd381090a24462b149e Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Tue, 7 Dec 2021 20:03:14 -0500 Subject: [PATCH 26/51] fix --- torchvision/transforms/functional.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index d8dd820c1bf..d7eff62349d 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1203,7 +1203,9 @@ def affine( pil_interpolation = pil_modes_mapping[interpolation] return F_pil.affine(img, matrix=matrix, interpolation=pil_interpolation, fill=fill) - translate_f = translate if isinstance(translate, Tensor) else [1.0 * t for t in translate] + translate_f = translate + if isinstance(translate, list): + translate_f = [1.0 * t for t in translate] matrix = _get_inverse_affine_matrix_tensor([0.0, 0.0], angle, translate_f, scale, shear) return F_t.affine(img, matrix=matrix, interpolation=interpolation.value, fill=fill) From 7122f7136f70967e19a8ed43af826074c94b0a8f Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Tue, 7 Dec 2021 20:12:46 -0500 Subject: [PATCH 27/51] fix type --- torchvision/transforms/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index d7eff62349d..7c1c6dc8913 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1097,7 +1097,7 @@ def rotate( def affine( img: Tensor, angle: Union[float, Tensor] = 0.0, - translate: Union[List[int], Tensor] = [0.0, 0.0], + translate: Union[List[float], Tensor] = [0.0, 0.0], scale: Union[float, Tensor] = 1.0, shear: Union[List[float], Tensor] = [0.0, 0.0], interpolation: InterpolationMode = InterpolationMode.NEAREST, From 723f99a8924f42ac7cf3beec55773cdcb712b94b Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Tue, 7 Dec 2021 20:32:51 -0500 Subject: [PATCH 28/51] fix --- torchvision/transforms/functional.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 7c1c6dc8913..dbc9043b2a3 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1097,9 +1097,9 @@ def rotate( def affine( img: Tensor, angle: Union[float, Tensor] = 0.0, - translate: Union[List[float], Tensor] = [0.0, 0.0], + translate: Union[List[float], Tensor] = None, scale: Union[float, Tensor] = 1.0, - shear: Union[List[float], Tensor] = [0.0, 0.0], + shear: Union[List[float], Tensor] = None, interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None, resample: Optional[int] = None, @@ -1135,6 +1135,12 @@ def affine( Returns: PIL Image or Tensor: Transformed image. """ + if translate is None: + translate = [0.0, 0.0] + + if shear is None: + shear = [0.0, 0.0] + if resample is not None: warnings.warn( "Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead" From 52fae98c823f04f3e9d5d4101254ba75c9f4c195 Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Tue, 7 Dec 2021 20:37:48 -0500 Subject: [PATCH 29/51] fix --- torchvision/transforms/functional.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index dbc9043b2a3..03f580c4745 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1097,9 +1097,9 @@ def rotate( def affine( img: Tensor, angle: Union[float, Tensor] = 0.0, - translate: Union[List[float], Tensor] = None, + translate: Optional[Union[List[float], Tensor]] = None, scale: Union[float, Tensor] = 1.0, - shear: Union[List[float], Tensor] = None, + shear: Optional[Union[List[float], Tensor]] = None, interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None, resample: Optional[int] = None, From da93544d273b28dad1a04618c2d0badb9457c0e7 Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Tue, 7 Dec 2021 23:12:44 -0500 Subject: [PATCH 30/51] fix --- torchvision/transforms/functional.py | 4 ++-- torchvision/transforms/functional_tensor.py | 18 ++++++++++++------ 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 03f580c4745..e742ecd8b30 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1010,7 +1010,7 @@ def _get_inverse_affine_matrix_tensor( def rotate( img: Tensor, - angle: Union[float, Tensor], + angle: Union[int, float, Tensor], interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, center: Optional[Union[List[float], Tensor]] = None, @@ -1089,7 +1089,7 @@ def rotate( if isinstance(angle, Tensor): angle = -angle else: - angle = -angle + angle = -1.0 * angle matrix = _get_inverse_affine_matrix_tensor(center_f, angle, [0.0, 0.0], 1.0, [0.0, 0.0]) return F_t.rotate(img, matrix=matrix, interpolation=interpolation.value, expand=expand, fill=fill) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 8025d61e6f7..15807304c39 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -18,7 +18,8 @@ def _assert_image_tensor(img: Tensor) -> None: def _assert_threshold(img: Tensor, threshold: Union[float, Tensor]) -> None: bound = 1 if img.is_floating_point() else 255 - if threshold > bound: + threshold_f = threshold if isinstance(threshold, float) else threshold.item() + if threshold_f > bound: raise TypeError("Threshold should be less than bound of img.") @@ -159,7 +160,8 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor: def adjust_brightness(img: Tensor, brightness_factor: Union[float, Tensor]) -> Tensor: - if brightness_factor < 0: + brightness_factor_f = brightness_factor if isinstance(brightness_factor, float) else brightness_factor.item() + if brightness_factor_f < 0: raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.") _assert_image_tensor(img) @@ -170,7 +172,8 @@ def adjust_brightness(img: Tensor, brightness_factor: Union[float, Tensor]) -> T def adjust_contrast(img: Tensor, contrast_factor: Union[float, Tensor]) -> Tensor: - if contrast_factor < 0: + contrast_factor_f = contrast_factor if isinstance(contrast_factor, float) else contrast_factor.item() + if contrast_factor_f < 0: raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.") _assert_image_tensor(img) @@ -187,7 +190,8 @@ def adjust_contrast(img: Tensor, contrast_factor: Union[float, Tensor]) -> Tenso def adjust_hue(img: Tensor, hue_factor: Union[float, Tensor]) -> Tensor: - if not (-0.5 <= hue_factor <= 0.5): + hue_factor_f = hue_factor if isinstance(hue_factor, float) else hue_factor.item() + if not (-0.5 <= hue_factor_f <= 0.5): raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].") if not (isinstance(img, torch.Tensor)): @@ -216,7 +220,8 @@ def adjust_hue(img: Tensor, hue_factor: Union[float, Tensor]) -> Tensor: def adjust_saturation(img: Tensor, saturation_factor: Union[float, Tensor]) -> Tensor: - if saturation_factor < 0: + saturation_factor_f = saturation_factor if isinstance(saturation_factor, float) else saturation_factor.item() + if saturation_factor_f < 0: raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.") _assert_image_tensor(img) @@ -235,7 +240,8 @@ def adjust_gamma(img: Tensor, gamma: Union[float, Tensor], gain: Union[float, Te _assert_channels(img, [1, 3]) - if gamma < 0: + gamma_f = gamma if isinstance(gamma, float) else gamma.item() + if gamma_f < 0: raise ValueError("Gamma should be a non-negative real number") result = img From d07ee5fa0a410556143035612120bca179ff9346 Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Tue, 7 Dec 2021 23:43:14 -0500 Subject: [PATCH 31/51] fix --- torchvision/transforms/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index e742ecd8b30..140a84db0b2 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1089,7 +1089,7 @@ def rotate( if isinstance(angle, Tensor): angle = -angle else: - angle = -1.0 * angle + angle = float(-angle) matrix = _get_inverse_affine_matrix_tensor(center_f, angle, [0.0, 0.0], 1.0, [0.0, 0.0]) return F_t.rotate(img, matrix=matrix, interpolation=interpolation.value, expand=expand, fill=fill) From 7891ac1f9c6c533a9d81727359f234be2b9a6a22 Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Wed, 8 Dec 2021 00:17:53 -0500 Subject: [PATCH 32/51] debug --- torchvision/transforms/functional.py | 9 ++++++--- torchvision/transforms/functional_tensor.py | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 140a84db0b2..8d8a2a5836d 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1013,7 +1013,7 @@ def rotate( angle: Union[int, float, Tensor], interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, - center: Optional[Union[List[float], Tensor]] = None, + center: Optional[Union[List[float], Tuple[float, ...], Tensor]] = None, fill: Optional[List[float]] = None, resample: Optional[int] = None, ) -> Tensor: @@ -1086,17 +1086,20 @@ def rotate( # due to current incoherence of rotation angle direction between affine and rotate implementations # we need to set -angle. + if isinstance(angle, int): + angle = float(angle) + if isinstance(angle, Tensor): angle = -angle else: - angle = float(-angle) + angle = -angle matrix = _get_inverse_affine_matrix_tensor(center_f, angle, [0.0, 0.0], 1.0, [0.0, 0.0]) return F_t.rotate(img, matrix=matrix, interpolation=interpolation.value, expand=expand, fill=fill) def affine( img: Tensor, - angle: Union[float, Tensor] = 0.0, + angle: Union[int, float, Tensor] = 0.0, translate: Optional[Union[List[float], Tensor]] = None, scale: Union[float, Tensor] = 1.0, shear: Optional[Union[List[float], Tensor]] = None, diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 15807304c39..a521f9e045b 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -589,7 +589,7 @@ def _assert_grid_transform_inputs( raise TypeError("Argument matrix should be a list or Tensor") if matrix is not None and len(matrix) != 6: - raise ValueError("Argument matrix should have 6 float values") + raise ValueError("Argument matrix should have 6 float values", str(matrix)) if coeffs is not None and len(coeffs) != 8: raise ValueError("Argument coeffs should have 8 float values") From 8e03fbf456047e8917aaa955ba7673594ba67328 Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Wed, 8 Dec 2021 00:21:42 -0500 Subject: [PATCH 33/51] fix --- torchvision/transforms/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 8d8a2a5836d..a0b1c7ab487 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1013,7 +1013,7 @@ def rotate( angle: Union[int, float, Tensor], interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, - center: Optional[Union[List[float], Tuple[float, ...], Tensor]] = None, + center: Optional[Union[List[float], Tuple[float, float], Tensor]] = None, fill: Optional[List[float]] = None, resample: Optional[int] = None, ) -> Tensor: From 6c56f6754e8e67742b253b757b6bc38941dbbf7d Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Wed, 8 Dec 2021 00:30:56 -0500 Subject: [PATCH 34/51] debug --- torchvision/transforms/functional.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index a0b1c7ab487..66126d913a3 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1067,6 +1067,9 @@ def rotate( if center is not None and not isinstance(center, (list, tuple, Tensor)): raise TypeError("Argument center should be a sequence") + if isinstance(center, tuple): + center = list(center) + if not isinstance(interpolation, InterpolationMode): raise TypeError("Argument interpolation should be a InterpolationMode") From 0f064745fa15561f9f5bbc181a77977fdc65e2c0 Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Wed, 8 Dec 2021 01:52:22 -0500 Subject: [PATCH 35/51] test --- torchvision/transforms/functional_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index a521f9e045b..15807304c39 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -589,7 +589,7 @@ def _assert_grid_transform_inputs( raise TypeError("Argument matrix should be a list or Tensor") if matrix is not None and len(matrix) != 6: - raise ValueError("Argument matrix should have 6 float values", str(matrix)) + raise ValueError("Argument matrix should have 6 float values") if coeffs is not None and len(coeffs) != 8: raise ValueError("Argument coeffs should have 8 float values") From 016c0858536b273e79a34962d232e7aaa3785426 Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Wed, 8 Dec 2021 01:59:09 -0500 Subject: [PATCH 36/51] debug core dumped --- torchvision/transforms/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 66126d913a3..a9e74081603 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1013,7 +1013,7 @@ def rotate( angle: Union[int, float, Tensor], interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, - center: Optional[Union[List[float], Tuple[float, float], Tensor]] = None, + center: Optional[Union[List[float], Tensor]] = None, fill: Optional[List[float]] = None, resample: Optional[int] = None, ) -> Tensor: From 4a5231780bd6e7e9155bf0bead5d4c4b2ddf328c Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Wed, 8 Dec 2021 04:02:14 -0500 Subject: [PATCH 37/51] debug --- torchvision/transforms/functional.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index a9e74081603..a0b1c7ab487 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1013,7 +1013,7 @@ def rotate( angle: Union[int, float, Tensor], interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, - center: Optional[Union[List[float], Tensor]] = None, + center: Optional[Union[List[float], Tuple[float, float], Tensor]] = None, fill: Optional[List[float]] = None, resample: Optional[int] = None, ) -> Tensor: @@ -1067,9 +1067,6 @@ def rotate( if center is not None and not isinstance(center, (list, tuple, Tensor)): raise TypeError("Argument center should be a sequence") - if isinstance(center, tuple): - center = list(center) - if not isinstance(interpolation, InterpolationMode): raise TypeError("Argument interpolation should be a InterpolationMode") From 083249a94c3ba48146293fa87679e5670bc36e64 Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Wed, 8 Dec 2021 13:47:54 -0500 Subject: [PATCH 38/51] debug --- torchvision/transforms/functional.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index a0b1c7ab487..c0dd1611502 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1082,7 +1082,8 @@ def rotate( img_size_t = torch.tensor(img_size) center_f = 1.0 * (center - img_size_t * 0.5) elif center is not None: - center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, img_size)] + center_l = list(center) + center_f = [1.0 * (c - s * 0.5) for c, s in zip(center_l, img_size)] # due to current incoherence of rotation angle direction between affine and rotate implementations # we need to set -angle. From 22da16dfde20d972468f560fc6ab932ddd9775de Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Wed, 8 Dec 2021 15:02:18 -0500 Subject: [PATCH 39/51] debug --- torchvision/transforms/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index c0dd1611502..1b750474e68 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1082,7 +1082,7 @@ def rotate( img_size_t = torch.tensor(img_size) center_f = 1.0 * (center - img_size_t * 0.5) elif center is not None: - center_l = list(center) + center_l = center if isinstance(center, list) else list(center) center_f = [1.0 * (c - s * 0.5) for c, s in zip(center_l, img_size)] # due to current incoherence of rotation angle direction between affine and rotate implementations From 9c63a144605052fd06683cb188ecc9ac27a5c8dc Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Wed, 8 Dec 2021 16:34:26 -0500 Subject: [PATCH 40/51] update --- torchvision/transforms/functional.py | 112 ++++++++------------ torchvision/transforms/functional_tensor.py | 66 +++++------- 2 files changed, 69 insertions(+), 109 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 1b750474e68..8cba53dfaf0 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -789,7 +789,8 @@ def adjust_brightness(img: Tensor, brightness_factor: Union[float, Tensor]) -> T if not isinstance(img, torch.Tensor): return F_pil.adjust_brightness(img, brightness_factor) - return F_t.adjust_brightness(img, brightness_factor) + brightness_factor_t = torch.as_tensor(brightness_factor, device=img.device) + return F_t.adjust_brightness(img, brightness_factor_t) def adjust_contrast(img: Tensor, contrast_factor: Union[float, Tensor]) -> Tensor: @@ -809,7 +810,8 @@ def adjust_contrast(img: Tensor, contrast_factor: Union[float, Tensor]) -> Tenso if not isinstance(img, torch.Tensor): return F_pil.adjust_contrast(img, contrast_factor) - return F_t.adjust_contrast(img, contrast_factor) + contrast_factor_t = torch.as_tensor(contrast_factor, device=img.device) + return F_t.adjust_contrast(img, contrast_factor_t) def adjust_saturation(img: Tensor, saturation_factor: Union[float, Tensor]) -> Tensor: @@ -829,7 +831,8 @@ def adjust_saturation(img: Tensor, saturation_factor: Union[float, Tensor]) -> T if not isinstance(img, torch.Tensor): return F_pil.adjust_saturation(img, saturation_factor) - return F_t.adjust_saturation(img, saturation_factor) + saturation_factor_t = torch.as_tensor(saturation_factor, device=img.device) + return F_t.adjust_saturation(img, saturation_factor_t) def adjust_hue(img: Tensor, hue_factor: Union[float, Tensor]) -> Tensor: @@ -863,7 +866,8 @@ def adjust_hue(img: Tensor, hue_factor: Union[float, Tensor]) -> Tensor: if not isinstance(img, torch.Tensor): return F_pil.adjust_hue(img, hue_factor) - return F_t.adjust_hue(img, hue_factor) + hue_factor_t = torch.as_tensor(hue_factor, device=img.device) + return F_t.adjust_hue(img, hue_factor_t) def adjust_gamma(img: Tensor, gamma: Union[float, Tensor], gain: Union[float, Tensor] = 1) -> Tensor: @@ -894,7 +898,9 @@ def adjust_gamma(img: Tensor, gamma: Union[float, Tensor], gain: Union[float, Te if not isinstance(img, torch.Tensor): return F_pil.adjust_gamma(img, gamma, gain) - return F_t.adjust_gamma(img, gamma, gain) + gamma_t = torch.as_tensor(gamma, device=img.device) + gain_t = torch.as_tensor(gain, device=img.device) + return F_t.adjust_gamma(img, gamma_t, gain_t) def _get_inverse_affine_matrix( @@ -949,42 +955,15 @@ def _get_inverse_affine_matrix( def _get_inverse_affine_matrix_tensor( - center: Union[List[float], Tensor], - angle: Union[float, Tensor], - translate: Union[List[float], Tensor], - scale: Union[float, Tensor], - shear: Union[List[float], Tensor], + center: Tensor, + angle: Tensor, + translate: Tensor, + scale: Tensor, + shear: Tensor, ) -> Tensor: - device: Optional[torch.device] = None - for element in [center, angle, translate, scale, shear]: - if isinstance(element, Tensor): - if device is None: - device = element.device - - center = ( - center.to(device=device).flatten() - if isinstance(center, Tensor) - else torch.tensor(center, device=device).flatten() - ) - angle = ( - angle.to(device=device).flatten() if isinstance(angle, Tensor) else torch.tensor(angle, device=device).flatten() - ) - translate = ( - translate.to(device=device).flatten() - if isinstance(translate, Tensor) - else torch.tensor(translate, device=device).flatten() - ) - scale = ( - scale.to(device=device).flatten() if isinstance(scale, Tensor) else torch.tensor(scale, device=device).flatten() - ) - shear = ( - shear.to(device=device).flatten() if isinstance(shear, Tensor) else torch.tensor(shear, device=device).flatten() - ) - rot = angle * torch.pi / 180.0 - sx = shear[0] * torch.pi / 180.0 - sy = shear[1] * torch.pi / 180.0 - + shear_rad = shear[0] * torch.pi / 180.0 + sx, sy = shear_rad[0], shear_rad[1] cx, cy = center[0], center[1] tx, ty = translate[0], translate[1] @@ -996,7 +975,7 @@ def _get_inverse_affine_matrix_tensor( # Inverted rotation matrix with scale and shear # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1 - zero = torch.zeros(1, device=device) + zero = torch.zeros(1, device=a.device) matrix = torch.cat([d, -b, zero, -c, a, zero]) / scale # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1 @@ -1074,27 +1053,21 @@ def rotate( pil_interpolation = pil_modes_mapping[interpolation] return F_pil.rotate(img, angle=angle, interpolation=pil_interpolation, expand=expand, center=center, fill=fill) - center_f: Union[List[float], Tensor] = [0.0, 0.0] + center_t = torch.zeros(2, device=img.device) if center is not None: img_size = get_image_size(img) # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. - if isinstance(center, Tensor): - img_size_t = torch.tensor(img_size) - center_f = 1.0 * (center - img_size_t * 0.5) - elif center is not None: - center_l = center if isinstance(center, list) else list(center) - center_f = [1.0 * (c - s * 0.5) for c, s in zip(center_l, img_size)] + center_org = torch.as_tensor(center, device=img.device) + img_size_t = torch.as_tensor(img_size, device=img.device) + center_t = 1.0 * (center_org - img_size_t * 0.5) # due to current incoherence of rotation angle direction between affine and rotate implementations # we need to set -angle. - if isinstance(angle, int): - angle = float(angle) - - if isinstance(angle, Tensor): - angle = -angle - else: - angle = -angle - matrix = _get_inverse_affine_matrix_tensor(center_f, angle, [0.0, 0.0], 1.0, [0.0, 0.0]) + angle = -torch.as_tensor(angle, dtype=torch.float, device=img.device) + translate = torch.zeros(2, dtype=torch.float, device=img.device) + scale = torch.zeros(1, dtype=torch.float, device=img.device) + shear = torch.zeros(2, dtype=torch.float, device=img.device) + matrix = _get_inverse_affine_matrix_tensor(center_t, angle, translate, scale, shear) return F_t.rotate(img, matrix=matrix, interpolation=interpolation.value, expand=expand, fill=fill) @@ -1172,7 +1145,7 @@ def affine( if len(translate) != 2: raise ValueError("Argument translate should be a sequence of length 2") - scale_float = scale if isinstance(scale, float) else scale.item() + scale_float = scale.item() if isinstance(scale, Tensor) else scale if scale_float <= 0.0: raise ValueError("Argument scale should be positive") @@ -1195,10 +1168,10 @@ def affine( shear = list(shear) if len(shear) == 1: - if isinstance(shear, list): - shear = [shear[0], shear[0]] - else: + if isinstance(shear, Tensor): shear = shear.flatten().repeat(2) + else: + shear = [shear[0], shear[0]] if len(shear) != 2: raise ValueError(f"Shear should be a sequence containing two values. Got {shear}") @@ -1213,10 +1186,12 @@ def affine( pil_interpolation = pil_modes_mapping[interpolation] return F_pil.affine(img, matrix=matrix, interpolation=pil_interpolation, fill=fill) - translate_f = translate - if isinstance(translate, list): - translate_f = [1.0 * t for t in translate] - matrix = _get_inverse_affine_matrix_tensor([0.0, 0.0], angle, translate_f, scale, shear) + center_t = torch.zeros(2, device=img.device, dtype=torch.float) + angle_t = torch.as_tensor(angle, device=img.device, dtype=torch.float) + translate_t = torch.as_tensor(translate, device=img.device, dtype=torch.float) + scale_t = torch.as_tensor(scale, device=img.device, dtype=torch.float) + shear_t = torch.as_tensor(shear, device=img.device, dtype=torch.float) + matrix = _get_inverse_affine_matrix_tensor(center_t, angle_t, translate_t, scale_t, shear_t) return F_t.affine(img, matrix=matrix, interpolation=interpolation.value, fill=fill) @@ -1375,7 +1350,7 @@ def invert(img: Tensor) -> Tensor: return F_t.invert(img) -def posterize(img: Tensor, bits: Union[int, Tensor]) -> Tensor: +def posterize(img: Tensor, bits: int) -> Tensor: """Posterize an image by reducing the number of bits for each color channel. Args: @@ -1384,7 +1359,7 @@ def posterize(img: Tensor, bits: Union[int, Tensor]) -> Tensor: it is expected to be in [..., 1 or 3, H, W] format, where ... means it can have an arbitrary number of leading dimensions. If img is PIL Image, it is expected to be in mode "L" or "RGB". - bits (int or Tensor): The number of bits to keep for each channel (0-8). + bits (int): The number of bits to keep for each channel (0-8). Returns: PIL Image or Tensor: Posterized image. """ @@ -1397,7 +1372,7 @@ def posterize(img: Tensor, bits: Union[int, Tensor]) -> Tensor: return F_t.posterize(img, bits) -def solarize(img: Tensor, threshold: Union[float, Tensor]) -> Tensor: +def solarize(img: Tensor, threshold: float) -> Tensor: """Solarize an RGB/grayscale image by inverting all pixel values above a threshold. Args: @@ -1405,7 +1380,7 @@ def solarize(img: Tensor, threshold: Union[float, Tensor]) -> Tensor: If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format, where ... means it can have an arbitrary number of leading dimensions. If img is PIL Image, it is expected to be in mode "L" or "RGB". - threshold (float or Tensor): All pixels equal or above this value are inverted. + threshold (float): All pixels equal or above this value are inverted. Returns: PIL Image or Tensor: Solarized image. """ @@ -1432,7 +1407,8 @@ def adjust_sharpness(img: Tensor, sharpness_factor: Union[float, Tensor]) -> Ten if not isinstance(img, torch.Tensor): return F_pil.adjust_sharpness(img, sharpness_factor) - return F_t.adjust_sharpness(img, sharpness_factor) + sharpness_factor_t = torch.as_tensor(sharpness_factor, device=img.device) + return F_t.adjust_sharpness(img, sharpness_factor_t) def autocontrast(img: Tensor) -> Tensor: diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 15807304c39..3e2ff6f8ed1 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -1,5 +1,5 @@ import warnings -from typing import Optional, Tuple, List, Union +from typing import Optional, Tuple, List import torch from torch import Tensor @@ -16,10 +16,9 @@ def _assert_image_tensor(img: Tensor) -> None: raise TypeError("Tensor is not a torch image.") -def _assert_threshold(img: Tensor, threshold: Union[float, Tensor]) -> None: +def _assert_threshold(img: Tensor, threshold: Tensor) -> None: bound = 1 if img.is_floating_point() else 255 - threshold_f = threshold if isinstance(threshold, float) else threshold.item() - if threshold_f > bound: + if threshold > bound: raise TypeError("Threshold should be less than bound of img.") @@ -159,9 +158,8 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor: return l_img -def adjust_brightness(img: Tensor, brightness_factor: Union[float, Tensor]) -> Tensor: - brightness_factor_f = brightness_factor if isinstance(brightness_factor, float) else brightness_factor.item() - if brightness_factor_f < 0: +def adjust_brightness(img: Tensor, brightness_factor: Tensor) -> Tensor: + if brightness_factor < 0: raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.") _assert_image_tensor(img) @@ -171,9 +169,8 @@ def adjust_brightness(img: Tensor, brightness_factor: Union[float, Tensor]) -> T return _blend(img, torch.zeros_like(img), brightness_factor) -def adjust_contrast(img: Tensor, contrast_factor: Union[float, Tensor]) -> Tensor: - contrast_factor_f = contrast_factor if isinstance(contrast_factor, float) else contrast_factor.item() - if contrast_factor_f < 0: +def adjust_contrast(img: Tensor, contrast_factor: Tensor) -> Tensor: + if contrast_factor < 0: raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.") _assert_image_tensor(img) @@ -189,9 +186,8 @@ def adjust_contrast(img: Tensor, contrast_factor: Union[float, Tensor]) -> Tenso return _blend(img, mean, contrast_factor) -def adjust_hue(img: Tensor, hue_factor: Union[float, Tensor]) -> Tensor: - hue_factor_f = hue_factor if isinstance(hue_factor, float) else hue_factor.item() - if not (-0.5 <= hue_factor_f <= 0.5): +def adjust_hue(img: Tensor, hue_factor: Tensor) -> Tensor: + if not (-0.5 <= hue_factor <= 0.5): raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].") if not (isinstance(img, torch.Tensor)): @@ -219,9 +215,8 @@ def adjust_hue(img: Tensor, hue_factor: Union[float, Tensor]) -> Tensor: return img_hue_adj -def adjust_saturation(img: Tensor, saturation_factor: Union[float, Tensor]) -> Tensor: - saturation_factor_f = saturation_factor if isinstance(saturation_factor, float) else saturation_factor.item() - if saturation_factor_f < 0: +def adjust_saturation(img: Tensor, saturation_factor: Tensor) -> Tensor: + if saturation_factor < 0: raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.") _assert_image_tensor(img) @@ -234,14 +229,13 @@ def adjust_saturation(img: Tensor, saturation_factor: Union[float, Tensor]) -> T return _blend(img, rgb_to_grayscale(img), saturation_factor) -def adjust_gamma(img: Tensor, gamma: Union[float, Tensor], gain: Union[float, Tensor] = 1) -> Tensor: +def adjust_gamma(img: Tensor, gamma: Tensor, gain: Tensor = 1) -> Tensor: if not isinstance(img, torch.Tensor): raise TypeError("Input img should be a Tensor.") _assert_channels(img, [1, 3]) - gamma_f = gamma if isinstance(gamma, float) else gamma.item() - if gamma_f < 0: + if gamma < 0: raise ValueError("Gamma should be a non-negative real number") result = img @@ -323,7 +317,7 @@ def ten_crop(img: Tensor, size: BroadcastingList2[int], vertical_flip: bool = Fa return first_five + second_five -def _blend(img1: Tensor, img2: Tensor, ratio: Union[float, Tensor]) -> Tensor: +def _blend(img1: Tensor, img2: Tensor, ratio: Tensor) -> Tensor: bound = 1.0 if img1.is_floating_point() else 255.0 return (ratio * img1 + (1.0 - ratio) * img2).clamp(0, bound).to(img1.dtype) @@ -573,7 +567,7 @@ def resize( def _assert_grid_transform_inputs( img: Tensor, - matrix: Optional[Union[List[float], Tensor]], + matrix: Optional[Tensor], interpolation: str, fill: Optional[List[float]], supported_interpolation_modes: List[str], @@ -585,8 +579,8 @@ def _assert_grid_transform_inputs( _assert_image_tensor(img) - if matrix is not None and not isinstance(matrix, (list, Tensor)): - raise TypeError("Argument matrix should be a list or Tensor") + if matrix is not None and not isinstance(matrix, Tensor): + raise TypeError("Argument matrix should be a Tensor") if matrix is not None and len(matrix) != 6: raise ValueError("Argument matrix should have 6 float values") @@ -704,25 +698,21 @@ def _gen_affine_grid( def affine( img: Tensor, - matrix: Union[List[float], Tensor], + matrix: Tensor, interpolation: str = "nearest", fill: Optional[List[float]] = None, ) -> Tensor: _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"]) dtype = img.dtype if torch.is_floating_point(img) else torch.float32 - theta = ( - matrix.to(dtype=dtype, device=img.device).reshape(1, 2, 3) - if isinstance(matrix, Tensor) - else torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3) - ) + theta = matrix.to(dtype=dtype, device=img.device).reshape(1, 2, 3) shape = img.shape # grid will be generated on the same device as theta and img grid = _gen_affine_grid(theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2]) return _apply_grid_transform(img, grid, interpolation, fill=fill) -def _compute_output_size(matrix: Union[List[float], Tensor], w: int, h: int) -> Tuple[int, int]: +def _compute_output_size(matrix: Tensor, w: int, h: int) -> Tuple[int, int]: # Inspired of PIL implementation: # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054 @@ -736,8 +726,6 @@ def _compute_output_size(matrix: Union[List[float], Tensor], w: int, h: int) -> [0.5 * w, -0.5 * h, 1.0], ] ) - if isinstance(matrix, list): - matrix = torch.tensor(matrix, dtype=torch.float) theta = matrix.reshape(1, 2, 3) new_pts = pts.view(1, 4, 3).bmm(theta.transpose(1, 2)).view(4, 2) min_vals, _ = new_pts.min(dim=0) @@ -753,7 +741,7 @@ def _compute_output_size(matrix: Union[List[float], Tensor], w: int, h: int) -> def rotate( img: Tensor, - matrix: Union[List[float], Tensor], + matrix: Tensor, interpolation: str = "nearest", expand: bool = False, fill: Optional[List[float]] = None, @@ -762,11 +750,7 @@ def rotate( w, h = img.shape[-1], img.shape[-2] ow, oh = _compute_output_size(matrix, w, h) if expand else (w, h) dtype = img.dtype if torch.is_floating_point(img) else torch.float32 - theta = ( - matrix.to(dtype=dtype, device=img.device).reshape(1, 2, 3) - if isinstance(matrix, Tensor) - else torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3) - ) + theta = matrix.to(dtype=dtype, device=img.device).reshape(1, 2, 3) # grid will be generated on the same device as theta and img grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh) @@ -883,7 +867,7 @@ def invert(img: Tensor) -> Tensor: return bound - img -def posterize(img: Tensor, bits: Union[int, Tensor]) -> Tensor: +def posterize(img: Tensor, bits: Tensor) -> Tensor: _assert_image_tensor(img) @@ -897,7 +881,7 @@ def posterize(img: Tensor, bits: Union[int, Tensor]) -> Tensor: return img & mask -def solarize(img: Tensor, threshold: Union[float, Tensor]) -> Tensor: +def solarize(img: Tensor, threshold: Tensor) -> Tensor: _assert_image_tensor(img) @@ -935,7 +919,7 @@ def _blurred_degenerate_image(img: Tensor) -> Tensor: return result -def adjust_sharpness(img: Tensor, sharpness_factor: Union[float, Tensor]) -> Tensor: +def adjust_sharpness(img: Tensor, sharpness_factor: Tensor) -> Tensor: if sharpness_factor < 0: raise ValueError(f"sharpness_factor ({sharpness_factor}) is not non-negative.") From 9988e067187fe0501b399249aa58063185f6dc30 Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Wed, 8 Dec 2021 16:42:00 -0500 Subject: [PATCH 41/51] update --- torchvision/transforms/functional.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 8cba53dfaf0..93d854bd8c9 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -2,7 +2,7 @@ import numbers import warnings from enum import Enum -from typing import List, Tuple, Any, Optional, Union +from typing import List, Tuple, Any, Optional import numpy as np import torch @@ -772,7 +772,7 @@ def ten_crop(img: Tensor, size: List[int], vertical_flip: bool = False) -> List[ return first_five + second_five -def adjust_brightness(img: Tensor, brightness_factor: Union[float, Tensor]) -> Tensor: +def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor: """Adjust brightness of an image. Args: @@ -793,7 +793,7 @@ def adjust_brightness(img: Tensor, brightness_factor: Union[float, Tensor]) -> T return F_t.adjust_brightness(img, brightness_factor_t) -def adjust_contrast(img: Tensor, contrast_factor: Union[float, Tensor]) -> Tensor: +def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor: """Adjust contrast of an image. Args: @@ -814,7 +814,7 @@ def adjust_contrast(img: Tensor, contrast_factor: Union[float, Tensor]) -> Tenso return F_t.adjust_contrast(img, contrast_factor_t) -def adjust_saturation(img: Tensor, saturation_factor: Union[float, Tensor]) -> Tensor: +def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor: """Adjust color saturation of an image. Args: @@ -835,7 +835,7 @@ def adjust_saturation(img: Tensor, saturation_factor: Union[float, Tensor]) -> T return F_t.adjust_saturation(img, saturation_factor_t) -def adjust_hue(img: Tensor, hue_factor: Union[float, Tensor]) -> Tensor: +def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: """Adjust hue of an image. The image hue is adjusted by converting the image to HSV and @@ -870,7 +870,7 @@ def adjust_hue(img: Tensor, hue_factor: Union[float, Tensor]) -> Tensor: return F_t.adjust_hue(img, hue_factor_t) -def adjust_gamma(img: Tensor, gamma: Union[float, Tensor], gain: Union[float, Tensor] = 1) -> Tensor: +def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: r"""Perform gamma correction on an image. Also known as Power Law Transform. Intensities in RGB mode are adjusted @@ -989,10 +989,10 @@ def _get_inverse_affine_matrix_tensor( def rotate( img: Tensor, - angle: Union[int, float, Tensor], + angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, - center: Optional[Union[List[float], Tuple[float, float], Tensor]] = None, + center: Optional[List[float]] = None, fill: Optional[List[float]] = None, resample: Optional[int] = None, ) -> Tensor: @@ -1073,10 +1073,10 @@ def rotate( def affine( img: Tensor, - angle: Union[int, float, Tensor] = 0.0, - translate: Optional[Union[List[float], Tensor]] = None, - scale: Union[float, Tensor] = 1.0, - shear: Optional[Union[List[float], Tensor]] = None, + angle: float = 0.0, + translate: Optional[List[float]] = None, + scale: float = 1.0, + shear: Optional[List[float]] = None, interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None, resample: Optional[int] = None, @@ -1390,7 +1390,7 @@ def solarize(img: Tensor, threshold: float) -> Tensor: return F_t.solarize(img, threshold) -def adjust_sharpness(img: Tensor, sharpness_factor: Union[float, Tensor]) -> Tensor: +def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor: """Adjust the sharpness of an image. Args: From 58db491ed49463e09b8ff8d60fe619020fe83ec0 Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Wed, 8 Dec 2021 16:50:29 -0500 Subject: [PATCH 42/51] fix type --- torchvision/transforms/functional_tensor.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 3e2ff6f8ed1..839f5912e83 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -229,7 +229,10 @@ def adjust_saturation(img: Tensor, saturation_factor: Tensor) -> Tensor: return _blend(img, rgb_to_grayscale(img), saturation_factor) -def adjust_gamma(img: Tensor, gamma: Tensor, gain: Tensor = 1) -> Tensor: +def adjust_gamma(img: Tensor, gamma: Tensor, gain: Optional[Tensor] = None) -> Tensor: + if gain is None: + gain = torch.ones(1, device=img.device, dtype=torch.float) + if not isinstance(img, torch.Tensor): raise TypeError("Input img should be a Tensor.") From a1a13e61082c4e79f1dc169e0b05414dfe6a1868 Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Thu, 9 Dec 2021 00:33:15 -0500 Subject: [PATCH 43/51] fix a bug --- torchvision/transforms/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 93d854bd8c9..bb5977ffce5 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -962,7 +962,7 @@ def _get_inverse_affine_matrix_tensor( shear: Tensor, ) -> Tensor: rot = angle * torch.pi / 180.0 - shear_rad = shear[0] * torch.pi / 180.0 + shear_rad = shear * torch.pi / 180.0 sx, sy = shear_rad[0], shear_rad[1] cx, cy = center[0], center[1] tx, ty = translate[0], translate[1] From 87468c45a20fb4d4e4a5d7068340b0014d900bb4 Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Thu, 9 Dec 2021 02:49:05 -0500 Subject: [PATCH 44/51] fix --- torchvision/transforms/functional.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index bb5977ffce5..7fcd59162fc 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -975,8 +975,8 @@ def _get_inverse_affine_matrix_tensor( # Inverted rotation matrix with scale and shear # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1 - zero = torch.zeros(1, device=a.device) - matrix = torch.cat([d, -b, zero, -c, a, zero]) / scale + zero = torch.zeros([], device=a.device) + matrix = torch.stack([d, -b, zero, -c, a, zero]) / scale # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1 # Apply center translation: C * RSS^-1 * C^-1 * T^-1 @@ -1167,9 +1167,11 @@ def affine( if isinstance(shear, tuple): shear = list(shear) + if isinstance(shear, Tensor): + shear = shear.flatten() if len(shear) == 1: if isinstance(shear, Tensor): - shear = shear.flatten().repeat(2) + shear = shear.repeat(2) else: shear = [shear[0], shear[0]] From 7c01367065f6645d0a9709b3f66ae0ed904122d7 Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Thu, 9 Dec 2021 04:28:32 -0500 Subject: [PATCH 45/51] fix --- torchvision/transforms/functional.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 7fcd59162fc..f859bd49172 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -975,7 +975,8 @@ def _get_inverse_affine_matrix_tensor( # Inverted rotation matrix with scale and shear # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1 - zero = torch.zeros([], device=a.device) + empty_list: List[int] = [] + zero = torch.zeros(empty_list, device=a.device) matrix = torch.stack([d, -b, zero, -c, a, zero]) / scale # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1 From dd499f85074fb9be239702286d13af5f476194d2 Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Thu, 9 Dec 2021 04:59:48 -0500 Subject: [PATCH 46/51] fix device --- torchvision/transforms/functional_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 839f5912e83..ab0bcaad083 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -727,7 +727,7 @@ def _compute_output_size(matrix: Tensor, w: int, h: int) -> Tuple[int, int]: [-0.5 * w, 0.5 * h, 1.0], [0.5 * w, 0.5 * h, 1.0], [0.5 * w, -0.5 * h, 1.0], - ] + ], device=matrix.device ) theta = matrix.reshape(1, 2, 3) new_pts = pts.view(1, 4, 3).bmm(theta.transpose(1, 2)).view(4, 2) From dc93a7619b9e5a48edac9d5e52be7d58a5212b1c Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Thu, 9 Dec 2021 08:04:16 -0500 Subject: [PATCH 47/51] fix a typo --- torchvision/transforms/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index f859bd49172..5177114d4ef 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1066,7 +1066,7 @@ def rotate( # we need to set -angle. angle = -torch.as_tensor(angle, dtype=torch.float, device=img.device) translate = torch.zeros(2, dtype=torch.float, device=img.device) - scale = torch.zeros(1, dtype=torch.float, device=img.device) + scale = torch.ones(1, dtype=torch.float, device=img.device) shear = torch.zeros(2, dtype=torch.float, device=img.device) matrix = _get_inverse_affine_matrix_tensor(center_t, angle, translate, scale, shear) return F_t.rotate(img, matrix=matrix, interpolation=interpolation.value, expand=expand, fill=fill) From 68232e7f0beb58edeae62a40c82aed68f65b84b6 Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Thu, 9 Dec 2021 08:40:44 -0500 Subject: [PATCH 48/51] ufmt format fix --- torchvision/transforms/functional_tensor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index ab0bcaad083..88b8cf5e89f 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -727,7 +727,8 @@ def _compute_output_size(matrix: Tensor, w: int, h: int) -> Tuple[int, int]: [-0.5 * w, 0.5 * h, 1.0], [0.5 * w, 0.5 * h, 1.0], [0.5 * w, -0.5 * h, 1.0], - ], device=matrix.device + ], + device=matrix.device, ) theta = matrix.reshape(1, 2, 3) new_pts = pts.view(1, 4, 3).bmm(theta.transpose(1, 2)).view(4, 2) From 85185984191a718154eb133215a3a17599195a9e Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Thu, 9 Dec 2021 09:21:27 -0500 Subject: [PATCH 49/51] fix bugs --- torchvision/transforms/functional_tensor.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 88b8cf5e89f..285c248a80c 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -16,7 +16,7 @@ def _assert_image_tensor(img: Tensor) -> None: raise TypeError("Tensor is not a torch image.") -def _assert_threshold(img: Tensor, threshold: Tensor) -> None: +def _assert_threshold(img: Tensor, threshold: float) -> None: bound = 1 if img.is_floating_point() else 255 if threshold > bound: raise TypeError("Threshold should be less than bound of img.") @@ -871,7 +871,7 @@ def invert(img: Tensor) -> Tensor: return bound - img -def posterize(img: Tensor, bits: Tensor) -> Tensor: +def posterize(img: Tensor, bits: int) -> Tensor: _assert_image_tensor(img) @@ -885,7 +885,7 @@ def posterize(img: Tensor, bits: Tensor) -> Tensor: return img & mask -def solarize(img: Tensor, threshold: Tensor) -> Tensor: +def solarize(img: Tensor, threshold: float) -> Tensor: _assert_image_tensor(img) @@ -943,8 +943,6 @@ def autocontrast(img: Tensor) -> Tensor: if img.ndim < 3: raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}") - elif img.ndim == 3: - img = img.unsqueeze(0) _assert_channels(img, [1, 3]) From 9fce5390fc2a7821081546a82aeb87598fcc2539 Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Thu, 9 Dec 2021 09:42:04 -0500 Subject: [PATCH 50/51] Merge branch 'patch-1' --- torchvision/transforms/functional.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 5177114d4ef..783731ef880 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1056,11 +1056,10 @@ def rotate( center_t = torch.zeros(2, device=img.device) if center is not None: - img_size = get_image_size(img) + img_size = torch.as_tensor(get_image_size(img), device=img.device) # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. center_org = torch.as_tensor(center, device=img.device) - img_size_t = torch.as_tensor(img_size, device=img.device) - center_t = 1.0 * (center_org - img_size_t * 0.5) + center_t = 1.0 * (center_org - img_size * 0.5) # due to current incoherence of rotation angle direction between affine and rotate implementations # we need to set -angle. @@ -1074,10 +1073,10 @@ def rotate( def affine( img: Tensor, - angle: float = 0.0, - translate: Optional[List[float]] = None, - scale: float = 1.0, - shear: Optional[List[float]] = None, + angle: float, + translate: List[float], + scale: float, + shear: List[float], interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None, resample: Optional[int] = None, @@ -1113,11 +1112,6 @@ def affine( Returns: PIL Image or Tensor: Transformed image. """ - if translate is None: - translate = [0.0, 0.0] - - if shear is None: - shear = [0.0, 0.0] if resample is not None: warnings.warn( From d4011fc81cdf32cbd4aa3394ea4429b4e4f9b9d3 Mon Sep 17 00:00:00 2001 From: Ren Pang Date: Thu, 9 Dec 2021 14:03:40 -0500 Subject: [PATCH 51/51] revert List[float] back to List[int] --- torchvision/transforms/functional.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 783731ef880..07563978581 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -993,7 +993,7 @@ def rotate( angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, - center: Optional[List[float]] = None, + center: Optional[List[int]] = None, fill: Optional[List[float]] = None, resample: Optional[int] = None, ) -> Tensor: @@ -1074,7 +1074,7 @@ def rotate( def affine( img: Tensor, angle: float, - translate: List[float], + translate: List[int], scale: float, shear: List[float], interpolation: InterpolationMode = InterpolationMode.NEAREST,