From d1f9cc42e2391ea68597460389c2db8dbe94baec Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 15 Feb 2022 17:10:49 +0100 Subject: [PATCH 1/4] try api call logging with decorator --- torchvision/transforms/functional.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 7c6aee1f376..eb66c18b394 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1,7 +1,9 @@ +import functools import math import numbers import warnings from enum import Enum +from typing import Callable from typing import List, Tuple, Any, Optional import numpy as np @@ -19,6 +21,15 @@ from . import functional_tensor as F_t +def log_api_usage_once(fn: Callable) -> Callable: + @functools.wraps(fn) + def wrapper(*args, **kwargs): + _log_api_usage_once(fn) + return fn(*args, **kwargs) + + return wrapper + + class InterpolationMode(Enum): """Interpolation modes Available interpolation methods are ``nearest``, ``bilinear``, ``bicubic``, ``box``, ``hamming``, and ``lanczos``. @@ -364,6 +375,7 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool return tensor +@log_api_usage_once def resize( img: Tensor, size: List[int], @@ -416,8 +428,6 @@ def resize( Returns: PIL Image or Tensor: Resized image. """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(resize) # Backward compatibility with integer value if isinstance(interpolation, int): warnings.warn( From 6682f7c489079e0bbbf13d8ec0072649e7d598e3 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 15 Feb 2022 17:58:31 +0100 Subject: [PATCH 2/4] expand decorator test to all functional transforms --- torchvision/prototype/transforms/_presets.py | 6 +- .../prototype/transforms/kernels/_geometry.py | 2 +- torchvision/transforms/functional.py | 95 ++++++------------- 3 files changed, 35 insertions(+), 68 deletions(-) diff --git a/torchvision/prototype/transforms/_presets.py b/torchvision/prototype/transforms/_presets.py index d7c4ddb4684..11998f69c1a 100644 --- a/torchvision/prototype/transforms/_presets.py +++ b/torchvision/prototype/transforms/_presets.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, Tuple +from typing import Dict, Optional, Tuple, cast import torch from torch import Tensor, nn @@ -41,7 +41,7 @@ def forward(self, img: Tensor) -> Tensor: img = F.pil_to_tensor(img) img = F.convert_image_dtype(img, torch.float) img = F.normalize(img, mean=self._mean, std=self._std) - return img + return cast(Tensor, img) class Kinect400Eval(nn.Module): @@ -65,7 +65,7 @@ def forward(self, vid: Tensor) -> Tensor: vid = F.resize(vid, self._size, interpolation=self._interpolation) vid = F.center_crop(vid, self._crop_size) vid = F.convert_image_dtype(vid, torch.float) - vid = F.normalize(vid, mean=self._mean, std=self._std) + vid = cast(Tensor, F.normalize(vid, mean=self._mean, std=self._std)) return vid.permute(1, 0, 2, 3) # (T, C, H, W) => (C, T, H, W) diff --git a/torchvision/prototype/transforms/kernels/_geometry.py b/torchvision/prototype/transforms/kernels/_geometry.py index 72afc2e62a3..461e9f860ca 100644 --- a/torchvision/prototype/transforms/kernels/_geometry.py +++ b/torchvision/prototype/transforms/kernels/_geometry.py @@ -39,7 +39,7 @@ def resize_image( new_height, new_width = size num_channels, old_height, old_width = image.shape[-3:] batch_shape = image.shape[:-3] - return _F.resize( + return _F.resize( # type: ignore[no-any-return] 
image.reshape((-1, num_channels, old_height, old_width)), size=size, interpolation=interpolation, diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index eb66c18b394..5f953a470d6 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -114,6 +114,7 @@ def _is_numpy_image(img: Any) -> bool: return img.ndim in {2, 3} +@log_api_usage_once def to_tensor(pic): """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. This function does not support torchscript. @@ -126,8 +127,6 @@ def to_tensor(pic): Returns: Tensor: Converted image. """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(to_tensor) if not (F_pil._is_pil_image(pic) or _is_numpy(pic)): raise TypeError(f"pic should be PIL Image or ndarray. Got {type(pic)}") @@ -168,6 +167,7 @@ def to_tensor(pic): return img +@log_api_usage_once def pil_to_tensor(pic): """Convert a ``PIL Image`` to a tensor of the same type. This function does not support torchscript. @@ -184,8 +184,6 @@ def pil_to_tensor(pic): Returns: Tensor: Converted image. """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(pil_to_tensor) if not F_pil._is_pil_image(pic): raise TypeError(f"pic should be PIL Image. Got {type(pic)}") @@ -203,6 +201,7 @@ def pil_to_tensor(pic): return img +@log_api_usage_once def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor: """Convert a tensor image to the given ``dtype`` and scale the values accordingly This function does not support PIL Image. @@ -225,14 +224,13 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) - overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range of the integer ``dtype``. """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(convert_image_dtype) if not isinstance(image, torch.Tensor): raise TypeError("Input img should be Tensor Image") return F_t.convert_image_dtype(image, dtype) +@log_api_usage_once def to_pil_image(pic, mode=None): """Convert a tensor or an ndarray to PIL Image. This function does not support torchscript. @@ -247,8 +245,6 @@ def to_pil_image(pic, mode=None): Returns: PIL Image: Image converted to PIL Image. """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(to_pil_image) if not (isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)): raise TypeError(f"pic should be Tensor or ndarray. Got {type(pic)}.") @@ -328,6 +324,7 @@ def to_pil_image(pic, mode=None): return Image.fromarray(npimg, mode=mode) +@log_api_usage_once def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor: """Normalize a float tensor image with mean and standard deviation. This transform does not support PIL Image. @@ -346,8 +343,6 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool Returns: Tensor: Normalized Tensor image. """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(normalize) if not isinstance(tensor, torch.Tensor): raise TypeError(f"Input tensor should be a torch tensor. 
Got {type(tensor)}.") @@ -448,6 +443,7 @@ def resize( return F_t.resize(img, size=size, interpolation=interpolation.value, max_size=max_size, antialias=antialias) +@log_api_usage_once def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor: r"""Pad the given image on all sides with the given "pad" value. If the image is torch Tensor, it is expected @@ -489,14 +485,13 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con Returns: PIL Image or Tensor: Padded image. """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(pad) if not isinstance(img, torch.Tensor): return F_pil.pad(img, padding=padding, fill=fill, padding_mode=padding_mode) return F_t.pad(img, padding=padding, fill=fill, padding_mode=padding_mode) +@log_api_usage_once def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor: """Crop the given image at specified location and output size. If the image is torch Tensor, it is expected @@ -513,15 +508,13 @@ def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor: Returns: PIL Image or Tensor: Cropped image. """ - - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(crop) if not isinstance(img, torch.Tensor): return F_pil.crop(img, top, left, height, width) return F_t.crop(img, top, left, height, width) +@log_api_usage_once def center_crop(img: Tensor, output_size: List[int]) -> Tensor: """Crops the given image at the center. If the image is torch Tensor, it is expected @@ -536,8 +529,6 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor: Returns: PIL Image or Tensor: Cropped image. """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(center_crop) if isinstance(output_size, numbers.Number): output_size = (int(output_size), int(output_size)) elif isinstance(output_size, (tuple, list)) and len(output_size) == 1: @@ -563,6 +554,7 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor: return crop(img, crop_top, crop_left, crop_height, crop_width) +@log_api_usage_once def resized_crop( img: Tensor, top: int, @@ -594,13 +586,12 @@ def resized_crop( Returns: PIL Image or Tensor: Cropped image. """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(resized_crop) img = crop(img, top, left, height, width) img = resize(img, size, interpolation) return img +@log_api_usage_once def hflip(img: Tensor) -> Tensor: """Horizontally flip the given image. @@ -613,8 +604,6 @@ def hflip(img: Tensor) -> Tensor: Returns: PIL Image or Tensor: Horizontally flipped image. """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(hflip) if not isinstance(img, torch.Tensor): return F_pil.hflip(img) @@ -649,6 +638,7 @@ def _get_perspective_coeffs(startpoints: List[List[int]], endpoints: List[List[i return output +@log_api_usage_once def perspective( img: Tensor, startpoints: List[List[int]], @@ -680,9 +670,6 @@ def perspective( Returns: PIL Image or Tensor: transformed Image. """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(perspective) - coeffs = _get_perspective_coeffs(startpoints, endpoints) # Backward compatibility with integer value @@ -703,6 +690,7 @@ def perspective( return F_t.perspective(img, coeffs, interpolation=interpolation.value, fill=fill) +@log_api_usage_once def vflip(img: Tensor) -> Tensor: """Vertically flip the given image. 
@@ -715,14 +703,13 @@ def vflip(img: Tensor) -> Tensor: Returns: PIL Image or Tensor: Vertically flipped image. """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(vflip) if not isinstance(img, torch.Tensor): return F_pil.vflip(img) return F_t.vflip(img) +@log_api_usage_once def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: """Crop the given image into four corners and the central crop. If the image is torch Tensor, it is expected @@ -742,8 +729,6 @@ def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Ten tuple: tuple (tl, tr, bl, br, center) Corresponding top left, top right, bottom left, bottom right and center crop. """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(five_crop) if isinstance(size, numbers.Number): size = (int(size), int(size)) elif isinstance(size, (tuple, list)) and len(size) == 1: @@ -768,6 +753,7 @@ def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Ten return tl, tr, bl, br, center +@log_api_usage_once def ten_crop(img: Tensor, size: List[int], vertical_flip: bool = False) -> List[Tensor]: """Generate ten cropped images from the given image. Crop the given image into four corners and the central crop plus the @@ -791,8 +777,6 @@ def ten_crop(img: Tensor, size: List[int], vertical_flip: bool = False) -> List[ Corresponding top left, top right, bottom left, bottom right and center crop and same for the flipped image. """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(ten_crop) if isinstance(size, numbers.Number): size = (int(size), int(size)) elif isinstance(size, (tuple, list)) and len(size) == 1: @@ -812,6 +796,7 @@ def ten_crop(img: Tensor, size: List[int], vertical_flip: bool = False) -> List[ return first_five + second_five +@log_api_usage_once def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor: """Adjust brightness of an image. @@ -826,14 +811,13 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor: Returns: PIL Image or Tensor: Brightness adjusted image. """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(adjust_brightness) if not isinstance(img, torch.Tensor): return F_pil.adjust_brightness(img, brightness_factor) return F_t.adjust_brightness(img, brightness_factor) +@log_api_usage_once def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor: """Adjust contrast of an image. @@ -848,14 +832,13 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor: Returns: PIL Image or Tensor: Contrast adjusted image. """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(adjust_contrast) if not isinstance(img, torch.Tensor): return F_pil.adjust_contrast(img, contrast_factor) return F_t.adjust_contrast(img, contrast_factor) +@log_api_usage_once def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor: """Adjust color saturation of an image. @@ -870,14 +853,13 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor: Returns: PIL Image or Tensor: Saturation adjusted image. 
""" - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(adjust_saturation) if not isinstance(img, torch.Tensor): return F_pil.adjust_saturation(img, saturation_factor) return F_t.adjust_saturation(img, saturation_factor) +@log_api_usage_once def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: """Adjust hue of an image. @@ -906,14 +888,13 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: Returns: PIL Image or Tensor: Hue adjusted image. """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(adjust_hue) if not isinstance(img, torch.Tensor): return F_pil.adjust_hue(img, hue_factor) return F_t.adjust_hue(img, hue_factor) +@log_api_usage_once def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: r"""Perform gamma correction on an image. @@ -939,8 +920,6 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: Returns: PIL Image or Tensor: Gamma correction adjusted image. """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(adjust_gamma) if not isinstance(img, torch.Tensor): return F_pil.adjust_gamma(img, gamma, gain) @@ -1003,6 +982,7 @@ def _get_inverse_affine_matrix( return matrix +@log_api_usage_once def rotate( img: Tensor, angle: float, @@ -1046,8 +1026,6 @@ def rotate( .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(rotate) if resample is not None: warnings.warn( "The parameter 'resample' is deprecated since 0.12 and will be removed 0.14. " @@ -1088,6 +1066,7 @@ def rotate( return F_t.rotate(img, matrix=matrix, interpolation=interpolation.value, expand=expand, fill=fill) +@log_api_usage_once def affine( img: Tensor, angle: float, @@ -1135,8 +1114,6 @@ def affine( Returns: PIL Image or Tensor: Transformed image. """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(affine) if resample is not None: warnings.warn( "The parameter 'resample' is deprecated since 0.12 and will be removed in 0.14. " @@ -1221,6 +1198,7 @@ def affine( @torch.jit.unused +@log_api_usage_once def to_grayscale(img, num_output_channels=1): """Convert PIL image of any mode (RGB, HSV, LAB, etc) to grayscale version of image. This transform does not support torch Tensor. @@ -1235,14 +1213,13 @@ def to_grayscale(img, num_output_channels=1): - if num_output_channels = 1 : returned image is single channel - if num_output_channels = 3 : returned image is 3 channel with r = g = b """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(to_grayscale) if isinstance(img, Image.Image): return F_pil.to_grayscale(img, num_output_channels) raise TypeError("Input should be PIL Image") +@log_api_usage_once def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor: """Convert RGB image to grayscale version of image. 
If the image is torch Tensor, it is expected @@ -1262,14 +1239,13 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor: - if num_output_channels = 1 : returned image is single channel - if num_output_channels = 3 : returned image is 3 channel with r = g = b """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(rgb_to_grayscale) if not isinstance(img, torch.Tensor): return F_pil.to_grayscale(img, num_output_channels) return F_t.rgb_to_grayscale(img, num_output_channels) +@log_api_usage_once def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor: """Erase the input Tensor Image with given value. This transform does not support PIL Image. @@ -1286,8 +1262,6 @@ def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool Returns: Tensor Image: Erased image. """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(erase) if not isinstance(img, torch.Tensor): raise TypeError(f"img should be Tensor Image. Got {type(img)}") @@ -1298,6 +1272,7 @@ def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool return img +@log_api_usage_once def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Tensor: """Performs Gaussian blurring on the image by given kernel. If the image is torch Tensor, it is expected @@ -1324,8 +1299,6 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[floa Returns: PIL Image or Tensor: Gaussian Blurred version of the image. """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(gaussian_blur) if not isinstance(kernel_size, (int, list, tuple)): raise TypeError(f"kernel_size should be int or a sequence of integers. Got {type(kernel_size)}") if isinstance(kernel_size, int): @@ -1365,6 +1338,7 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[floa return output +@log_api_usage_once def invert(img: Tensor) -> Tensor: """Invert the colors of an RGB/grayscale image. @@ -1377,14 +1351,13 @@ def invert(img: Tensor) -> Tensor: Returns: PIL Image or Tensor: Color inverted image. """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(invert) if not isinstance(img, torch.Tensor): return F_pil.invert(img) return F_t.invert(img) +@log_api_usage_once def posterize(img: Tensor, bits: int) -> Tensor: """Posterize an image by reducing the number of bits for each color channel. @@ -1398,8 +1371,6 @@ def posterize(img: Tensor, bits: int) -> Tensor: Returns: PIL Image or Tensor: Posterized image. """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(posterize) if not (0 <= bits <= 8): raise ValueError(f"The number if bits should be between 0 and 8. Got {bits}") @@ -1409,6 +1380,7 @@ def posterize(img: Tensor, bits: int) -> Tensor: return F_t.posterize(img, bits) +@log_api_usage_once def solarize(img: Tensor, threshold: float) -> Tensor: """Solarize an RGB/grayscale image by inverting all pixel values above a threshold. @@ -1421,14 +1393,13 @@ def solarize(img: Tensor, threshold: float) -> Tensor: Returns: PIL Image or Tensor: Solarized image. 
""" - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(solarize) if not isinstance(img, torch.Tensor): return F_pil.solarize(img, threshold) return F_t.solarize(img, threshold) +@log_api_usage_once def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor: """Adjust the sharpness of an image. @@ -1443,14 +1414,13 @@ def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor: Returns: PIL Image or Tensor: Sharpness adjusted image. """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(adjust_sharpness) if not isinstance(img, torch.Tensor): return F_pil.adjust_sharpness(img, sharpness_factor) return F_t.adjust_sharpness(img, sharpness_factor) +@log_api_usage_once def autocontrast(img: Tensor) -> Tensor: """Maximize contrast of an image by remapping its pixels per channel so that the lowest becomes black and the lightest @@ -1465,14 +1435,13 @@ def autocontrast(img: Tensor) -> Tensor: Returns: PIL Image or Tensor: An image that was autocontrasted. """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(autocontrast) if not isinstance(img, torch.Tensor): return F_pil.autocontrast(img) return F_t.autocontrast(img) +@log_api_usage_once def equalize(img: Tensor) -> Tensor: """Equalize the histogram of an image by applying a non-linear mapping to the input in order to create a uniform @@ -1488,8 +1457,6 @@ def equalize(img: Tensor) -> Tensor: Returns: PIL Image or Tensor: An image that was equalized. """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(equalize) if not isinstance(img, torch.Tensor): return F_pil.equalize(img) From e2e0e3d9d45e9a1b354d5cbc28582a7e46cbb902 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 16 Feb 2022 19:49:59 +0100 Subject: [PATCH 3/4] make decorator annotations more concrete --- torchvision/prototype/features/_image.py | 3 ++- torchvision/prototype/transforms/_presets.py | 6 +++--- torchvision/prototype/transforms/kernels/_geometry.py | 2 +- torchvision/transforms/functional.py | 4 ++-- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py index 5ecc4cbedb7..2e1ae3ce975 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/features/_image.py @@ -3,6 +3,7 @@ import warnings from typing import Any, Optional, Union, Tuple, cast +import PIL.Image import torch from torchvision.prototype.utils._internal import StrEnum from torchvision.transforms.functional import to_pil_image @@ -78,7 +79,7 @@ def guess_color_space(data: torch.Tensor) -> ColorSpace: def show(self) -> None: # TODO: this is useful for developing and debugging but we should remove or at least revisit this before we # promote this out of the prototype state - to_pil_image(make_grid(self.view(-1, *self.shape[-3:]))).show() + cast(PIL.Image.Image, to_pil_image(make_grid(self.view(-1, *self.shape[-3:])))).show() def draw_bounding_box(self, bounding_box: BoundingBox, **kwargs: Any) -> Image: # TODO: this is useful for developing and debugging but we should remove or at least revisit this before we diff --git a/torchvision/prototype/transforms/_presets.py b/torchvision/prototype/transforms/_presets.py index 11998f69c1a..d7c4ddb4684 100644 --- a/torchvision/prototype/transforms/_presets.py +++ b/torchvision/prototype/transforms/_presets.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, Tuple, cast +from typing import 
Dict, Optional, Tuple import torch from torch import Tensor, nn @@ -41,7 +41,7 @@ def forward(self, img: Tensor) -> Tensor: img = F.pil_to_tensor(img) img = F.convert_image_dtype(img, torch.float) img = F.normalize(img, mean=self._mean, std=self._std) - return cast(Tensor, img) + return img class Kinect400Eval(nn.Module): @@ -65,7 +65,7 @@ def forward(self, vid: Tensor) -> Tensor: vid = F.resize(vid, self._size, interpolation=self._interpolation) vid = F.center_crop(vid, self._crop_size) vid = F.convert_image_dtype(vid, torch.float) - vid = cast(Tensor, F.normalize(vid, mean=self._mean, std=self._std)) + vid = F.normalize(vid, mean=self._mean, std=self._std) return vid.permute(1, 0, 2, 3) # (T, C, H, W) => (C, T, H, W) diff --git a/torchvision/prototype/transforms/kernels/_geometry.py b/torchvision/prototype/transforms/kernels/_geometry.py index 461e9f860ca..72afc2e62a3 100644 --- a/torchvision/prototype/transforms/kernels/_geometry.py +++ b/torchvision/prototype/transforms/kernels/_geometry.py @@ -39,7 +39,7 @@ def resize_image( new_height, new_width = size num_channels, old_height, old_width = image.shape[-3:] batch_shape = image.shape[:-3] - return _F.resize( # type: ignore[no-any-return] + return _F.resize( image.reshape((-1, num_channels, old_height, old_width)), size=size, interpolation=interpolation, diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 5f953a470d6..017882a67bc 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -21,9 +21,9 @@ from . import functional_tensor as F_t -def log_api_usage_once(fn: Callable) -> Callable: +def log_api_usage_once(fn: Callable[..., Tensor]) -> Callable[..., Tensor]: @functools.wraps(fn) - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> Tensor: _log_api_usage_once(fn) return fn(*args, **kwargs) From 1addc91c6dbcc36b59198fa3dcd1adaa6f81fa45 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 16 Feb 2022 20:25:54 +0100 Subject: [PATCH 4/4] make decorator more generic --- torchvision/prototype/features/_image.py | 3 +-- torchvision/transforms/functional.py | 11 ++++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py index 2e1ae3ce975..5ecc4cbedb7 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/features/_image.py @@ -3,7 +3,6 @@ import warnings from typing import Any, Optional, Union, Tuple, cast -import PIL.Image import torch from torchvision.prototype.utils._internal import StrEnum from torchvision.transforms.functional import to_pil_image @@ -79,7 +78,7 @@ def guess_color_space(data: torch.Tensor) -> ColorSpace: def show(self) -> None: # TODO: this is useful for developing and debugging but we should remove or at least revisit this before we # promote this out of the prototype state - cast(PIL.Image.Image, to_pil_image(make_grid(self.view(-1, *self.shape[-3:])))).show() + to_pil_image(make_grid(self.view(-1, *self.shape[-3:]))).show() def draw_bounding_box(self, bounding_box: BoundingBox, **kwargs: Any) -> Image: # TODO: this is useful for developing and debugging but we should remove or at least revisit this before we diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 017882a67bc..03eefd5aaee 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -3,8 +3,7 @@ import numbers import warnings from enum import Enum 
-from typing import Callable -from typing import List, Tuple, Any, Optional +from typing import List, Tuple, Any, Optional, TypeVar, Callable, cast import numpy as np import torch @@ -20,14 +19,16 @@ from . import functional_pil as F_pil from . import functional_tensor as F_t +F = TypeVar("F", bound=Callable[..., Any]) -def log_api_usage_once(fn: Callable[..., Tensor]) -> Callable[..., Tensor]: + +def log_api_usage_once(fn: F) -> F: @functools.wraps(fn) - def wrapper(*args: Any, **kwargs: Any) -> Tensor: + def wrapper(*args, **kwargs): _log_api_usage_once(fn) return fn(*args, **kwargs) - return wrapper + return cast(F, wrapper) class InterpolationMode(Enum):
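Note on the final design: patch 1's bare Callable annotation erased each decorated function's signature, so every decorated call returned Any, which is why patch 2 had to add cast(Tensor, ...) and "# type: ignore[no-any-return]" at the call sites. Patch 3's Callable[..., Tensor] removed those workarounds but mistyped functions that do not return tensors (to_pil_image), forcing a PIL.Image.Image cast instead. The TypeVar bound to Callable[..., Any] together with cast(F, wrapper) preserves each function's original signature exactly, which is why patch 4 can drop the last cast. Below is a minimal, self-contained sketch of that final pattern; the printing stub and the double() example are illustrative stand-ins, not torchvision code.

import functools
from typing import Any, Callable, TypeVar, cast


def _log_api_usage_once(obj: Any) -> None:
    # Stand-in for torchvision's internal _log_api_usage_once helper, which
    # forwards the qualified name to PyTorch's usage-logging hook. Printing
    # keeps this sketch runnable on its own.
    print(f"logged: {obj.__module__}.{obj.__qualname__}")


F = TypeVar("F", bound=Callable[..., Any])


def log_api_usage_once(fn: F) -> F:
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        _log_api_usage_once(fn)
        return fn(*args, **kwargs)

    # Casting the wrapper back to F tells mypy that the decorated function
    # keeps its original signature, so call sites need no casts or ignores.
    return cast(F, wrapper)


@log_api_usage_once
def double(x: int) -> int:
    return x + x


print(double(21))  # prints "logged: __main__.double", then 42

With F bound this way, mypy treats log_api_usage_once as signature-preserving for tensor and non-tensor functions alike, so the cast(Tensor, ...) workarounds in _presets.py and the PIL.Image.Image cast in _image.py both become unnecessary; reverting them is exactly what patches 3 and 4 do.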