diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index 8f923475664..790a0bcb2ec 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -154,6 +154,19 @@ def test_rotate_interpolation_type(self):
         res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
         assert_equal(res1, res2)
 
+    @pytest.mark.parametrize("fn", [F.rotate, scripted_rotate])
+    @pytest.mark.parametrize("center", [None, torch.tensor([0.1, 0.2], requires_grad=True)])
+    def test_differentiable_rotate(self, fn, center):
+        alpha = torch.tensor(1.0, requires_grad=True)
+        x = torch.zeros(1, 3, 10, 10)
+
+        y = fn(x, alpha, interpolation=BILINEAR, center=center)
+        assert y.requires_grad
+        y.mean().backward()
+        assert alpha.grad is not None
+        if center is not None:
+            assert center.grad is not None
+
 
 class TestAffine:
diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index bd5b170626e..07563978581 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -779,7 +779,7 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
         img (PIL Image or Tensor): Image to be adjusted.
             If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
             where ... means it can have an arbitrary number of leading dimensions.
-        brightness_factor (float): How much to adjust the brightness. Can be
+        brightness_factor (float or Tensor): How much to adjust the brightness. Can be
             any non negative number. 0 gives a black image, 1 gives the
             original image while 2 increases the brightness by a factor of 2.
 
@@ -789,7 +789,8 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
     if not isinstance(img, torch.Tensor):
         return F_pil.adjust_brightness(img, brightness_factor)
 
-    return F_t.adjust_brightness(img, brightness_factor)
+    brightness_factor_t = torch.as_tensor(brightness_factor, device=img.device)
+    return F_t.adjust_brightness(img, brightness_factor_t)
 
 
 def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
@@ -799,7 +800,7 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
         img (PIL Image or Tensor): Image to be adjusted.
             If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
             where ... means it can have an arbitrary number of leading dimensions.
-        contrast_factor (float): How much to adjust the contrast. Can be any
+        contrast_factor (float or Tensor): How much to adjust the contrast. Can be any
             non negative number. 0 gives a solid gray image, 1 gives the
             original image while 2 increases the contrast by a factor of 2.
 
@@ -809,7 +810,8 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
     if not isinstance(img, torch.Tensor):
         return F_pil.adjust_contrast(img, contrast_factor)
 
-    return F_t.adjust_contrast(img, contrast_factor)
+    contrast_factor_t = torch.as_tensor(contrast_factor, device=img.device)
+    return F_t.adjust_contrast(img, contrast_factor_t)
 
 
 def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
@@ -819,7 +821,7 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
         img (PIL Image or Tensor): Image to be adjusted.
             If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
             where ... means it can have an arbitrary number of leading dimensions.
-        saturation_factor (float): How much to adjust the saturation. 0 will
+        saturation_factor (float or Tensor): How much to adjust the saturation. 0 will
             give a black and white image, 1 will give the original image while
             2 will enhance the saturation by a factor of 2.
 
@@ -829,7 +831,8 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
     if not isinstance(img, torch.Tensor):
         return F_pil.adjust_saturation(img, saturation_factor)
 
-    return F_t.adjust_saturation(img, saturation_factor)
+    saturation_factor_t = torch.as_tensor(saturation_factor, device=img.device)
+    return F_t.adjust_saturation(img, saturation_factor_t)
 
 
 def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
@@ -851,7 +854,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
             If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
             where ... means it can have an arbitrary number of leading dimensions.
             If img is PIL Image mode "1", "I", "F" and modes with transparency (alpha channel) are not supported.
-        hue_factor (float): How much to shift the hue channel. Should be in
+        hue_factor (float or Tensor): How much to shift the hue channel. Should be in
             [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
             HSV space in positive and negative direction respectively.
             0 means no shift. Therefore, both -0.5 and 0.5 will give an image
@@ -863,7 +866,8 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
     if not isinstance(img, torch.Tensor):
         return F_pil.adjust_hue(img, hue_factor)
 
-    return F_t.adjust_hue(img, hue_factor)
+    hue_factor_t = torch.as_tensor(hue_factor, device=img.device)
+    return F_t.adjust_hue(img, hue_factor_t)
 
 
 def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
@@ -884,17 +888,19 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            If img is PIL Image, modes with transparency (alpha channel) are not supported.
-        gamma (float): Non negative real number, same as :math:`\gamma` in the equation.
+        gamma (float or Tensor): Non negative real number, same as :math:`\gamma` in the equation.
             gamma larger than 1 make the shadows darker,
             while gamma smaller than 1 make dark regions lighter.
-        gain (float): The constant multiplier.
+        gain (float or Tensor): The constant multiplier.
 
     Returns:
         PIL Image or Tensor: Gamma correction adjusted image.
""" if not isinstance(img, torch.Tensor): return F_pil.adjust_gamma(img, gamma, gain) - return F_t.adjust_gamma(img, gamma, gain) + gamma_t = torch.as_tensor(gamma, device=img.device) + gain_t = torch.as_tensor(gain, device=img.device) + return F_t.adjust_gamma(img, gamma_t, gain_t) def _get_inverse_affine_matrix( @@ -948,6 +954,40 @@ def _get_inverse_affine_matrix( return matrix +def _get_inverse_affine_matrix_tensor( + center: Tensor, + angle: Tensor, + translate: Tensor, + scale: Tensor, + shear: Tensor, +) -> Tensor: + rot = angle * torch.pi / 180.0 + shear_rad = shear * torch.pi / 180.0 + sx, sy = shear_rad[0], shear_rad[1] + cx, cy = center[0], center[1] + tx, ty = translate[0], translate[1] + + # RSS without scaling + a = torch.cos(rot - sy) / torch.cos(sy) + b = -torch.cos(rot - sy) * torch.tan(sx) / torch.cos(sy) - torch.sin(rot) + c = torch.sin(rot - sy) / torch.cos(sy) + d = -torch.sin(rot - sy) * torch.tan(sx) / torch.cos(sy) + torch.cos(rot) + + # Inverted rotation matrix with scale and shear + # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1 + empty_list: List[int] = [] + zero = torch.zeros(empty_list, device=a.device) + matrix = torch.stack([d, -b, zero, -c, a, zero]) / scale + + # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1 + # Apply center translation: C * RSS^-1 * C^-1 * T^-1 + new_matrix = matrix.clone() + new_matrix[2] = matrix[2] + cx + matrix[0] * (-cx - tx) + matrix[1] * (-cy - ty) + new_matrix[5] = matrix[5] + cy + matrix[3] * (-cx - tx) + matrix[4] * (-cy - ty) + + return new_matrix + + def rotate( img: Tensor, angle: float, @@ -963,7 +1003,7 @@ def rotate( Args: img (PIL Image or Tensor): image to be rotated. - angle (number): rotation angle value in degrees, counter-clockwise. + angle (number or Tensor): rotation angle value in degrees, counter-clockwise. interpolation (InterpolationMode): Desired interpolation enum defined by :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. @@ -972,7 +1012,7 @@ def rotate( If true, expands the output image to make it large enough to hold the entire rotated image. If false or omitted, make the output image the same size as the input image. Note that the expand flag assumes rotation around the center and no translation. - center (sequence, optional): Optional center of rotation. Origin is the upper left corner. + center (sequence or Tensor, optional): Optional center of rotation. Origin is the upper left corner. Default is the center of the image. fill (sequence or number, optional): Pixel fill value for the area outside the transformed image. If given a number, the value is used for all bands respectively. 
@@ -1001,10 +1041,10 @@ def rotate(
         )
         interpolation = _interpolation_modes_from_int(interpolation)
 
-    if not isinstance(angle, (int, float)):
-        raise TypeError("Argument angle should be int or float")
+    if not isinstance(angle, (int, float, Tensor)):
+        raise TypeError("Argument angle should be int, float or Tensor")
 
-    if center is not None and not isinstance(center, (list, tuple)):
+    if center is not None and not isinstance(center, (list, tuple, Tensor)):
         raise TypeError("Argument center should be a sequence")
 
     if not isinstance(interpolation, InterpolationMode):
@@ -1014,15 +1054,20 @@ def rotate(
         pil_interpolation = pil_modes_mapping[interpolation]
         return F_pil.rotate(img, angle=angle, interpolation=pil_interpolation, expand=expand, center=center, fill=fill)
 
-    center_f = [0.0, 0.0]
+    center_t = torch.zeros(2, device=img.device)
     if center is not None:
-        img_size = get_image_size(img)
+        img_size = torch.as_tensor(get_image_size(img), device=img.device)
         # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
-        center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, img_size)]
+        center_org = torch.as_tensor(center, device=img.device)
+        center_t = 1.0 * (center_org - img_size * 0.5)
 
     # due to current incoherence of rotation angle direction between affine and rotate implementations
     # we need to set -angle.
-    matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])
+    angle = -torch.as_tensor(angle, dtype=torch.float, device=img.device)
+    translate = torch.zeros(2, dtype=torch.float, device=img.device)
+    scale = torch.ones(1, dtype=torch.float, device=img.device)
+    shear = torch.zeros(2, dtype=torch.float, device=img.device)
+    matrix = _get_inverse_affine_matrix_tensor(center_t, angle, translate, scale, shear)
     return F_t.rotate(img, matrix=matrix, interpolation=interpolation.value, expand=expand, fill=fill)
 
 
@@ -1043,10 +1088,10 @@ def affine(
 
     Args:
         img (PIL Image or Tensor): image to transform.
-        angle (number): rotation angle in degrees between -180 and 180, clockwise direction.
-        translate (sequence of integers): horizontal and vertical translations (post-rotation translation)
-        scale (float): overall scale
-        shear (float or sequence): shear angle value in degrees between -180 to 180, clockwise direction.
+        angle (number or Tensor): rotation angle in degrees between -180 and 180, clockwise direction.
+        translate (sequence of integers or Tensor): horizontal and vertical translations (post-rotation translation)
+        scale (float or Tensor): overall scale
+        shear (float or sequence or Tensor): shear angle value in degrees between -180 to 180, clockwise direction.
             If a sequence is specified, the first value corresponds to a shear parallel to the x axis, while
             the second value corresponds to a shear parallel to the y axis.
         interpolation (InterpolationMode): Desired interpolation enum defined by
@@ -1067,6 +1112,7 @@ def affine(
     Returns:
         PIL Image or Tensor: Transformed image.
     """
+
     if resample is not None:
         warnings.warn(
             "Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead"
         )
@@ -1085,19 +1131,20 @@ def affine(
         warnings.warn("Argument fillcolor is deprecated and will be removed since v0.10.0. Please, use fill instead")
         fill = fillcolor
Please, use fill instead") fill = fillcolor - if not isinstance(angle, (int, float)): - raise TypeError("Argument angle should be int or float") + if not isinstance(angle, (int, float, Tensor)): + raise TypeError("Argument angle should be int, float or Tensor") - if not isinstance(translate, (list, tuple)): - raise TypeError("Argument translate should be a sequence") + if not isinstance(translate, (list, tuple, Tensor)): + raise TypeError("Argument translate should be a sequence or Tensor") if len(translate) != 2: raise ValueError("Argument translate should be a sequence of length 2") - if scale <= 0.0: + scale_float = scale.item() if isinstance(scale, Tensor) else scale + if scale_float <= 0.0: raise ValueError("Argument scale should be positive") - if not isinstance(shear, (numbers.Number, (list, tuple))): + if not isinstance(shear, (numbers.Number, (list, tuple, Tensor))): raise TypeError("Shear should be either a single value or a sequence of two values") if not isinstance(interpolation, InterpolationMode): @@ -1115,8 +1162,13 @@ def affine( if isinstance(shear, tuple): shear = list(shear) + if isinstance(shear, Tensor): + shear = shear.flatten() if len(shear) == 1: - shear = [shear[0], shear[0]] + if isinstance(shear, Tensor): + shear = shear.repeat(2) + else: + shear = [shear[0], shear[0]] if len(shear) != 2: raise ValueError(f"Shear should be a sequence containing two values. Got {shear}") @@ -1131,8 +1183,12 @@ def affine( pil_interpolation = pil_modes_mapping[interpolation] return F_pil.affine(img, matrix=matrix, interpolation=pil_interpolation, fill=fill) - translate_f = [1.0 * t for t in translate] - matrix = _get_inverse_affine_matrix([0.0, 0.0], angle, translate_f, scale, shear) + center_t = torch.zeros(2, device=img.device, dtype=torch.float) + angle_t = torch.as_tensor(angle, device=img.device, dtype=torch.float) + translate_t = torch.as_tensor(translate, device=img.device, dtype=torch.float) + scale_t = torch.as_tensor(scale, device=img.device, dtype=torch.float) + shear_t = torch.as_tensor(shear, device=img.device, dtype=torch.float) + matrix = _get_inverse_affine_matrix_tensor(center_t, angle_t, translate_t, scale_t, shear_t) return F_t.affine(img, matrix=matrix, interpolation=interpolation.value, fill=fill) @@ -1338,7 +1394,7 @@ def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor: img (PIL Image or Tensor): Image to be adjusted. If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format, where ... means it can have an arbitrary number of leading dimensions. - sharpness_factor (float): How much to adjust the sharpness. Can be + sharpness_factor (float or Tensor): How much to adjust the sharpness. Can be any non negative number. 0 gives a blurred image, 1 gives the original image while 2 increases the sharpness by a factor of 2. 
@@ -1348,7 +1404,8 @@ def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
     if not isinstance(img, torch.Tensor):
         return F_pil.adjust_sharpness(img, sharpness_factor)
 
-    return F_t.adjust_sharpness(img, sharpness_factor)
+    sharpness_factor_t = torch.as_tensor(sharpness_factor, device=img.device)
+    return F_t.adjust_sharpness(img, sharpness_factor_t)
 
 
 def autocontrast(img: Tensor) -> Tensor:
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index 4e20c19e45f..285c248a80c 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -158,7 +158,7 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
     return l_img
 
 
-def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
+def adjust_brightness(img: Tensor, brightness_factor: Tensor) -> Tensor:
     if brightness_factor < 0:
         raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.")
 
@@ -169,7 +169,7 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
     return _blend(img, torch.zeros_like(img), brightness_factor)
 
 
-def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
+def adjust_contrast(img: Tensor, contrast_factor: Tensor) -> Tensor:
     if contrast_factor < 0:
         raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.")
 
@@ -186,7 +186,7 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
     return _blend(img, mean, contrast_factor)
 
 
-def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
+def adjust_hue(img: Tensor, hue_factor: Tensor) -> Tensor:
     if not (-0.5 <= hue_factor <= 0.5):
         raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")
 
@@ -215,7 +215,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
     return img_hue_adj
 
 
-def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
+def adjust_saturation(img: Tensor, saturation_factor: Tensor) -> Tensor:
     if saturation_factor < 0:
         raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.")
 
@@ -229,7 +229,10 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
     return _blend(img, rgb_to_grayscale(img), saturation_factor)
 
 
-def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
+def adjust_gamma(img: Tensor, gamma: Tensor, gain: Optional[Tensor] = None) -> Tensor:
     if not isinstance(img, torch.Tensor):
         raise TypeError("Input img should be a Tensor.")
 
+    if gain is None:
+        gain = torch.ones(1, device=img.device, dtype=torch.float)
+
@@ -317,8 +320,7 @@ def ten_crop(img: Tensor, size: BroadcastingList2[int], vertical_flip: bool = Fa
     return first_five + second_five
 
 
-def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
-    ratio = float(ratio)
+def _blend(img1: Tensor, img2: Tensor, ratio: Tensor) -> Tensor:
     bound = 1.0 if img1.is_floating_point() else 255.0
     return (ratio * img1 + (1.0 - ratio) * img2).clamp(0, bound).to(img1.dtype)
 
@@ -568,7 +570,7 @@ def resize(
 
 def _assert_grid_transform_inputs(
     img: Tensor,
-    matrix: Optional[List[float]],
+    matrix: Optional[Tensor],
     interpolation: str,
     fill: Optional[List[float]],
     supported_interpolation_modes: List[str],
@@ -580,8 +582,8 @@ def _assert_grid_transform_inputs(
 
     _assert_image_tensor(img)
 
-    if matrix is not None and not isinstance(matrix, list):
-        raise TypeError("Argument matrix should be a list")
+    if matrix is not None and not isinstance(matrix, Tensor):
+        raise TypeError("Argument matrix should be a Tensor")
 
     if matrix is not None and len(matrix) != 6:
         raise ValueError("Argument matrix should have 6 float values")
 
@@ -698,19 +700,22 @@ def _gen_affine_grid(
 
 
 def affine(
-    img: Tensor, matrix: List[float], interpolation: str = "nearest", fill: Optional[List[float]] = None
+    img: Tensor,
+    matrix: Tensor,
+    interpolation: str = "nearest",
+    fill: Optional[List[float]] = None,
 ) -> Tensor:
     _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])
 
     dtype = img.dtype if torch.is_floating_point(img) else torch.float32
-    theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
+    theta = matrix.to(dtype=dtype, device=img.device).reshape(1, 2, 3)
     shape = img.shape
     # grid will be generated on the same device as theta and img
     grid = _gen_affine_grid(theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2])
     return _apply_grid_transform(img, grid, interpolation, fill=fill)
 
 
-def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:
+def _compute_output_size(matrix: Tensor, w: int, h: int) -> Tuple[int, int]:
     # Inspired of PIL implementation:
     # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054
@@ -722,9 +727,10 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]
             [-0.5 * w, 0.5 * h, 1.0],
             [0.5 * w, 0.5 * h, 1.0],
             [0.5 * w, -0.5 * h, 1.0],
-        ]
+        ],
+        device=matrix.device,
     )
-    theta = torch.tensor(matrix, dtype=torch.float).reshape(1, 2, 3)
+    theta = matrix.reshape(1, 2, 3)
     new_pts = pts.view(1, 4, 3).bmm(theta.transpose(1, 2)).view(4, 2)
     min_vals, _ = new_pts.min(dim=0)
     max_vals, _ = new_pts.max(dim=0)
@@ -739,7 +745,7 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]
 
 def rotate(
     img: Tensor,
-    matrix: List[float],
+    matrix: Tensor,
     interpolation: str = "nearest",
     expand: bool = False,
     fill: Optional[List[float]] = None,
@@ -748,7 +754,7 @@ def rotate(
     w, h = img.shape[-1], img.shape[-2]
     ow, oh = _compute_output_size(matrix, w, h) if expand else (w, h)
     dtype = img.dtype if torch.is_floating_point(img) else torch.float32
-    theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
+    theta = matrix.to(dtype=dtype, device=img.device).reshape(1, 2, 3)
     # grid will be generated on the same device as theta and img
     grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh)
 
@@ -917,7 +923,7 @@ def _blurred_degenerate_image(img: Tensor) -> Tensor:
     return result
 
 
-def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
+def adjust_sharpness(img: Tensor, sharpness_factor: Tensor) -> Tensor:
     if sharpness_factor < 0:
         raise ValueError(f"sharpness_factor ({sharpness_factor}) is not non-negative.")
 
@@ -943,8 +949,8 @@ def autocontrast(img: Tensor) -> Tensor:
     bound = 1.0 if img.is_floating_point() else 255.0
     dtype = img.dtype if torch.is_floating_point(img) else torch.float32
 
-    minimum = img.amin(dim=(-2, -1), keepdim=True).to(dtype)
-    maximum = img.amax(dim=(-2, -1), keepdim=True).to(dtype)
+    minimum = img.amin(dim=(-2, -1), keepdim=True).to(dtype).clone()
+    maximum = img.amax(dim=(-2, -1), keepdim=True).to(dtype).clone()
     scale = bound / (maximum - minimum)
     eq_idxs = torch.isfinite(scale).logical_not()
     minimum[eq_idxs] = 0
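
Note: on the `functional_tensor` side, `_blend` accepting a Tensor `ratio` is what makes the colour-adjustment factors differentiable, and the added `.clone()` calls in `autocontrast` keep the subsequent in-place writes (`minimum[eq_idxs] = 0`) from modifying tensors that autograd still needs for the backward of `amin`/`amax`. A sketch of a numerical check (assumes this branch is installed; `gradcheck` wants float64, and a factor of 0.7 keeps the final clamp inactive so the op is differentiable at the probed points):

    import torch
    import torchvision.transforms.functional_tensor as F_t

    img = torch.rand(1, 3, 8, 8, dtype=torch.float64)
    factor = torch.tensor(0.7, dtype=torch.float64, requires_grad=True)

    # adjust_brightness reduces to _blend(img, zeros_like(img), factor),
    # i.e. factor * img, so d(out)/d(factor) is just img where unclamped.
    assert torch.autograd.gradcheck(lambda f: F_t.adjust_brightness(img, f), (factor,))
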