diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index 7c6aee1f376..03eefd5aaee 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -1,8 +1,9 @@
+import functools
 import math
 import numbers
 import warnings
 from enum import Enum
-from typing import List, Tuple, Any, Optional
+from typing import List, Tuple, Any, Optional, TypeVar, Callable, cast
 
 import numpy as np
 import torch
@@ -18,6 +19,17 @@ from . import functional_pil as F_pil
 from . import functional_tensor as F_t
 
 
+F = TypeVar("F", bound=Callable[..., Any])
+
+
+def log_api_usage_once(fn: F) -> F:
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        _log_api_usage_once(fn)
+        return fn(*args, **kwargs)
+
+    return cast(F, wrapper)
+
+
 class InterpolationMode(Enum):
     """Interpolation modes
@@ -103,6 +115,7 @@ def _is_numpy_image(img: Any) -> bool:
     return img.ndim in {2, 3}
 
 
+@log_api_usage_once
 def to_tensor(pic):
     """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
     This function does not support torchscript.
@@ -115,8 +128,6 @@ def to_tensor(pic):
     Returns:
         Tensor: Converted image.
     """
-    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
-        _log_api_usage_once(to_tensor)
     if not (F_pil._is_pil_image(pic) or _is_numpy(pic)):
         raise TypeError(f"pic should be PIL Image or ndarray. Got {type(pic)}")
@@ -157,6 +168,7 @@ def to_tensor(pic):
     return img
 
 
+@log_api_usage_once
 def pil_to_tensor(pic):
     """Convert a ``PIL Image`` to a tensor of the same type.
     This function does not support torchscript.
@@ -173,8 +185,6 @@ def pil_to_tensor(pic):
     Returns:
         Tensor: Converted image.
     """
-    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
-        _log_api_usage_once(pil_to_tensor)
     if not F_pil._is_pil_image(pic):
         raise TypeError(f"pic should be PIL Image. Got {type(pic)}")
@@ -192,6 +202,7 @@ def pil_to_tensor(pic):
     return img
 
 
+@log_api_usage_once
 def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
     """Convert a tensor image to the given ``dtype`` and scale the values accordingly
     This function does not support PIL Image.
@@ -214,14 +225,13 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -
             overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
             of the integer ``dtype``.
     """
-    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
-        _log_api_usage_once(convert_image_dtype)
     if not isinstance(image, torch.Tensor):
         raise TypeError("Input img should be Tensor Image")
 
     return F_t.convert_image_dtype(image, dtype)
 
 
+@log_api_usage_once
 def to_pil_image(pic, mode=None):
     """Convert a tensor or an ndarray to PIL Image. This function does not support torchscript.
@@ -236,8 +246,6 @@ def to_pil_image(pic, mode=None):
     Returns:
         PIL Image: Image converted to PIL Image.
     """
-    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
-        _log_api_usage_once(to_pil_image)
     if not (isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)):
         raise TypeError(f"pic should be Tensor or ndarray. Got {type(pic)}.")
@@ -317,6 +325,7 @@ def to_pil_image(pic, mode=None):
     return Image.fromarray(npimg, mode=mode)
 
 
+@log_api_usage_once
 def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor:
     """Normalize a float tensor image with mean and standard deviation.
     This transform does not support PIL Image.
@@ -335,8 +344,6 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
     Returns:
         Tensor: Normalized Tensor image.
     """
-    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
-        _log_api_usage_once(normalize)
     if not isinstance(tensor, torch.Tensor):
         raise TypeError(f"Input tensor should be a torch tensor. Got {type(tensor)}.")
@@ -364,6 +371,7 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
     return tensor
 
 
+@log_api_usage_once
 def resize(
     img: Tensor,
     size: List[int],
@@ -416,8 +424,6 @@ def resize(
     Returns:
         PIL Image or Tensor: Resized image.
     """
-    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
-        _log_api_usage_once(resize)
     # Backward compatibility with integer value
     if isinstance(interpolation, int):
         warnings.warn(
@@ -438,6 +444,7 @@ def resize(
     return F_t.resize(img, size=size, interpolation=interpolation.value, max_size=max_size, antialias=antialias)
 
 
+@log_api_usage_once
 def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor:
     r"""Pad the given image on all sides with the given "pad" value.
     If the image is torch Tensor, it is expected
@@ -479,14 +486,13 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
     Returns:
         PIL Image or Tensor: Padded image.
     """
-    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
-        _log_api_usage_once(pad)
     if not isinstance(img, torch.Tensor):
         return F_pil.pad(img, padding=padding, fill=fill, padding_mode=padding_mode)
 
     return F_t.pad(img, padding=padding, fill=fill, padding_mode=padding_mode)
 
 
+@log_api_usage_once
 def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
     """Crop the given image at specified location and output size.
     If the image is torch Tensor, it is expected
@@ -503,15 +509,13 @@ def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
     Returns:
         PIL Image or Tensor: Cropped image.
     """
-
-    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
-        _log_api_usage_once(crop)
     if not isinstance(img, torch.Tensor):
         return F_pil.crop(img, top, left, height, width)
 
     return F_t.crop(img, top, left, height, width)
 
 
+@log_api_usage_once
 def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
     """Crops the given image at the center.
     If the image is torch Tensor, it is expected
@@ -526,8 +530,6 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
     Returns:
         PIL Image or Tensor: Cropped image.
     """
-    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
-        _log_api_usage_once(center_crop)
     if isinstance(output_size, numbers.Number):
         output_size = (int(output_size), int(output_size))
     elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
@@ -553,6 +555,7 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
     return crop(img, crop_top, crop_left, crop_height, crop_width)
 
 
+@log_api_usage_once
 def resized_crop(
     img: Tensor,
     top: int,
@@ -584,13 +587,12 @@ def resized_crop(
     Returns:
         PIL Image or Tensor: Cropped image.
     """
-    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
-        _log_api_usage_once(resized_crop)
     img = crop(img, top, left, height, width)
     img = resize(img, size, interpolation)
     return img
 
 
+@log_api_usage_once
 def hflip(img: Tensor) -> Tensor:
     """Horizontally flip the given image.
@@ -603,8 +605,6 @@ def hflip(img: Tensor) -> Tensor:
     Returns:
         PIL Image or Tensor: Horizontally flipped image.
     """
-    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
-        _log_api_usage_once(hflip)
     if not isinstance(img, torch.Tensor):
         return F_pil.hflip(img)
@@ -639,6 +639,7 @@ def _get_perspective_coeffs(startpoints: List[List[int]], endpoints: List[List[i
     return output
 
 
+@log_api_usage_once
 def perspective(
     img: Tensor,
     startpoints: List[List[int]],
@@ -670,9 +671,6 @@ def perspective(
     Returns:
         PIL Image or Tensor: transformed Image.
     """
-    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
-        _log_api_usage_once(perspective)
-
     coeffs = _get_perspective_coeffs(startpoints, endpoints)
 
     # Backward compatibility with integer value
@@ -693,6 +691,7 @@ def perspective(
     return F_t.perspective(img, coeffs, interpolation=interpolation.value, fill=fill)
 
 
+@log_api_usage_once
 def vflip(img: Tensor) -> Tensor:
     """Vertically flip the given image.
@@ -705,14 +704,13 @@ def vflip(img: Tensor) -> Tensor:
     Returns:
         PIL Image or Tensor: Vertically flipped image.
     """
-    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
-        _log_api_usage_once(vflip)
     if not isinstance(img, torch.Tensor):
         return F_pil.vflip(img)
 
     return F_t.vflip(img)
 
 
+@log_api_usage_once
 def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
     """Crop the given image into four corners and the central crop.
     If the image is torch Tensor, it is expected
@@ -732,8 +730,6 @@ def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Ten
         tuple: tuple (tl, tr, bl, br, center)
         Corresponding top left, top right, bottom left, bottom right and center crop.
     """
-    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
-        _log_api_usage_once(five_crop)
     if isinstance(size, numbers.Number):
         size = (int(size), int(size))
     elif isinstance(size, (tuple, list)) and len(size) == 1:
@@ -758,6 +754,7 @@ def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Ten
     return tl, tr, bl, br, center
 
 
+@log_api_usage_once
 def ten_crop(img: Tensor, size: List[int], vertical_flip: bool = False) -> List[Tensor]:
     """Generate ten cropped images from the given image.
     Crop the given image into four corners and the central crop plus the
@@ -781,8 +778,6 @@ def ten_crop(img: Tensor, size: List[int], vertical_flip: bool = False) -> List[
         Corresponding top left, top right, bottom left, bottom right and center crop
         and same for the flipped image.
     """
-    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
-        _log_api_usage_once(ten_crop)
     if isinstance(size, numbers.Number):
         size = (int(size), int(size))
     elif isinstance(size, (tuple, list)) and len(size) == 1:
@@ -802,6 +797,7 @@ def ten_crop(img: Tensor, size: List[int], vertical_flip: bool = False) -> List[
     return first_five + second_five
 
 
+@log_api_usage_once
 def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
     """Adjust brightness of an image.
@@ -816,14 +812,13 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
     Returns:
         PIL Image or Tensor: Brightness adjusted image.
     """
-    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
-        _log_api_usage_once(adjust_brightness)
     if not isinstance(img, torch.Tensor):
         return F_pil.adjust_brightness(img, brightness_factor)
 
     return F_t.adjust_brightness(img, brightness_factor)
 
 
+@log_api_usage_once
 def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
     """Adjust contrast of an image.
@@ -838,14 +833,13 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
     Returns:
         PIL Image or Tensor: Contrast adjusted image.
     """
-    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
-        _log_api_usage_once(adjust_contrast)
     if not isinstance(img, torch.Tensor):
         return F_pil.adjust_contrast(img, contrast_factor)
 
     return F_t.adjust_contrast(img, contrast_factor)
 
 
+@log_api_usage_once
 def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
     """Adjust color saturation of an image.
@@ -860,14 +854,13 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
     Returns:
         PIL Image or Tensor: Saturation adjusted image.
     """
-    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
-        _log_api_usage_once(adjust_saturation)
     if not isinstance(img, torch.Tensor):
         return F_pil.adjust_saturation(img, saturation_factor)
 
     return F_t.adjust_saturation(img, saturation_factor)
 
 
+@log_api_usage_once
 def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
     """Adjust hue of an image.
@@ -896,14 +889,13 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
     Returns:
         PIL Image or Tensor: Hue adjusted image.
     """
-    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
-        _log_api_usage_once(adjust_hue)
     if not isinstance(img, torch.Tensor):
         return F_pil.adjust_hue(img, hue_factor)
 
     return F_t.adjust_hue(img, hue_factor)
 
 
+@log_api_usage_once
 def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
     r"""Perform gamma correction on an image.
@@ -929,8 +921,6 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
     Returns:
         PIL Image or Tensor: Gamma correction adjusted image.
     """
-    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
-        _log_api_usage_once(adjust_gamma)
     if not isinstance(img, torch.Tensor):
         return F_pil.adjust_gamma(img, gamma, gain)
@@ -993,6 +983,7 @@ def _get_inverse_affine_matrix(
     return matrix
 
 
+@log_api_usage_once
 def rotate(
     img: Tensor,
     angle: float,
@@ -1036,8 +1027,6 @@ def rotate(
     .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
 
     """
-    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
-        _log_api_usage_once(rotate)
     if resample is not None:
         warnings.warn(
             "The parameter 'resample' is deprecated since 0.12 and will be removed 0.14. "
@@ -1078,6 +1067,7 @@ def rotate(
     return F_t.rotate(img, matrix=matrix, interpolation=interpolation.value, expand=expand, fill=fill)
 
 
+@log_api_usage_once
 def affine(
     img: Tensor,
     angle: float,
@@ -1125,8 +1115,6 @@ def affine(
     Returns:
         PIL Image or Tensor: Transformed image.
     """
-    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
-        _log_api_usage_once(affine)
     if resample is not None:
         warnings.warn(
             "The parameter 'resample' is deprecated since 0.12 and will be removed in 0.14. "
@@ -1211,6 +1199,7 @@ def affine(
 
 
 @torch.jit.unused
+@log_api_usage_once
 def to_grayscale(img, num_output_channels=1):
     """Convert PIL image of any mode (RGB, HSV, LAB, etc) to grayscale version of image.
     This transform does not support torch Tensor.
@@ -1225,14 +1214,13 @@ def to_grayscale(img, num_output_channels=1):
         - if num_output_channels = 1 : returned image is single channel
         - if num_output_channels = 3 : returned image is 3 channel with r = g = b
     """
-    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
-        _log_api_usage_once(to_grayscale)
     if isinstance(img, Image.Image):
         return F_pil.to_grayscale(img, num_output_channels)
 
     raise TypeError("Input should be PIL Image")
 
 
+@log_api_usage_once
 def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
     """Convert RGB image to grayscale version of image.
     If the image is torch Tensor, it is expected
@@ -1252,14 +1240,13 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
         - if num_output_channels = 1 : returned image is single channel
         - if num_output_channels = 3 : returned image is 3 channel with r = g = b
     """
-    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
-        _log_api_usage_once(rgb_to_grayscale)
     if not isinstance(img, torch.Tensor):
         return F_pil.to_grayscale(img, num_output_channels)
 
     return F_t.rgb_to_grayscale(img, num_output_channels)
 
 
+@log_api_usage_once
 def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
     """Erase the input Tensor Image with given value.
     This transform does not support PIL Image.
@@ -1276,8 +1263,6 @@ def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool
     Returns:
         Tensor Image: Erased image.
     """
-    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
-        _log_api_usage_once(erase)
     if not isinstance(img, torch.Tensor):
         raise TypeError(f"img should be Tensor Image. Got {type(img)}")
@@ -1288,6 +1273,7 @@ def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool
     return img
 
 
+@log_api_usage_once
 def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Tensor:
     """Performs Gaussian blurring on the image by given kernel.
     If the image is torch Tensor, it is expected
@@ -1314,8 +1300,6 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[floa
     Returns:
         PIL Image or Tensor: Gaussian Blurred version of the image.
     """
-    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
-        _log_api_usage_once(gaussian_blur)
     if not isinstance(kernel_size, (int, list, tuple)):
         raise TypeError(f"kernel_size should be int or a sequence of integers. Got {type(kernel_size)}")
     if isinstance(kernel_size, int):
@@ -1355,6 +1339,7 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[floa
     return output
 
 
+@log_api_usage_once
 def invert(img: Tensor) -> Tensor:
     """Invert the colors of an RGB/grayscale image.
@@ -1367,14 +1352,13 @@ def invert(img: Tensor) -> Tensor:
     Returns:
         PIL Image or Tensor: Color inverted image.
    """
-    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
-        _log_api_usage_once(invert)
     if not isinstance(img, torch.Tensor):
         return F_pil.invert(img)
 
     return F_t.invert(img)
 
 
+@log_api_usage_once
 def posterize(img: Tensor, bits: int) -> Tensor:
     """Posterize an image by reducing the number of bits for each color channel.
@@ -1388,8 +1372,6 @@ def posterize(img: Tensor, bits: int) -> Tensor:
     Returns:
         PIL Image or Tensor: Posterized image.
     """
-    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
-        _log_api_usage_once(posterize)
     if not (0 <= bits <= 8):
         raise ValueError(f"The number if bits should be between 0 and 8. Got {bits}")
@@ -1399,6 +1381,7 @@ def posterize(img: Tensor, bits: int) -> Tensor:
     return F_t.posterize(img, bits)
 
 
+@log_api_usage_once
 def solarize(img: Tensor, threshold: float) -> Tensor:
     """Solarize an RGB/grayscale image by inverting all pixel values above a threshold.
@@ -1411,14 +1394,13 @@ def solarize(img: Tensor, threshold: float) -> Tensor:
     Returns:
         PIL Image or Tensor: Solarized image.
     """
-    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
-        _log_api_usage_once(solarize)
     if not isinstance(img, torch.Tensor):
         return F_pil.solarize(img, threshold)
 
     return F_t.solarize(img, threshold)
 
 
+@log_api_usage_once
 def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
     """Adjust the sharpness of an image.
@@ -1433,14 +1415,13 @@ def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
     Returns:
         PIL Image or Tensor: Sharpness adjusted image.
     """
-    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
-        _log_api_usage_once(adjust_sharpness)
     if not isinstance(img, torch.Tensor):
         return F_pil.adjust_sharpness(img, sharpness_factor)
 
     return F_t.adjust_sharpness(img, sharpness_factor)
 
 
+@log_api_usage_once
 def autocontrast(img: Tensor) -> Tensor:
     """Maximize contrast of an image by remapping its
     pixels per channel so that the lowest becomes black and the lightest
@@ -1455,14 +1436,13 @@ def autocontrast(img: Tensor) -> Tensor:
     Returns:
         PIL Image or Tensor: An image that was autocontrasted.
     """
-    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
-        _log_api_usage_once(autocontrast)
     if not isinstance(img, torch.Tensor):
         return F_pil.autocontrast(img)
 
     return F_t.autocontrast(img)
 
 
+@log_api_usage_once
 def equalize(img: Tensor) -> Tensor:
     """Equalize the histogram of an image by applying
     a non-linear mapping to the input
     in order to create a uniform
@@ -1478,8 +1458,6 @@ def equalize(img: Tensor) -> Tensor:
     Returns:
         PIL Image or Tensor: An image that was equalized.
     """
-    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
-        _log_api_usage_once(equalize)
     if not isinstance(img, torch.Tensor):
         return F_pil.equalize(img)