From 30899821b57f215d276d77ad193172ed6de2db5a Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 13 Jan 2023 09:54:10 +0100 Subject: [PATCH 1/4] simplify dispatcher if-elif --- .../transforms/functional/__init__.py | 3 ++ .../transforms/functional/_augment.py | 6 +-- .../prototype/transforms/functional/_color.py | 37 ++++--------- .../transforms/functional/_deprecated.py | 10 ++-- .../transforms/functional/_geometry.py | 54 ++++++------------- .../prototype/transforms/functional/_meta.py | 24 +++------ .../prototype/transforms/functional/_misc.py | 10 ++-- .../transforms/functional/_temporal.py | 4 +- .../prototype/transforms/functional/_utils.py | 10 ++++ torchvision/prototype/transforms/utils.py | 8 +-- 10 files changed, 63 insertions(+), 103 deletions(-) create mode 100644 torchvision/prototype/transforms/functional/_utils.py diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index ec2da6ee518..30ef6e3fc99 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -1,6 +1,9 @@ # TODO: Add _log_api_usage_once() in all mid-level kernels. If they remain not jit-scriptable we can use decorators from torchvision.transforms import InterpolationMode # usort: skip + +from ._utils import is_simple_tensor # usort: skip + from ._meta import ( clamp_bounding_box, convert_format_bounding_box, diff --git a/torchvision/prototype/transforms/functional/_augment.py b/torchvision/prototype/transforms/functional/_augment.py index 12af2444ef0..0164a0b5b9b 100644 --- a/torchvision/prototype/transforms/functional/_augment.py +++ b/torchvision/prototype/transforms/functional/_augment.py @@ -7,6 +7,8 @@ from torchvision.transforms.functional import pil_to_tensor, to_pil_image from torchvision.utils import _log_api_usage_once +from ._utils import is_simple_tensor + def erase_image_tensor( image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False @@ -45,9 +47,7 @@ def erase( if not torch.jit.is_scripting(): _log_api_usage_once(erase) - if isinstance(inpt, torch.Tensor) and ( - torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video)) - ): + if torch.jit.is_scripting() or is_simple_tensor(inpt): return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace) elif isinstance(inpt, datapoints.Image): output = erase_image_tensor(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 037bb5f8413..a04e208bdbd 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -8,6 +8,7 @@ from torchvision.utils import _log_api_usage_once from ._meta import _num_value_bits, _rgb_to_gray, convert_dtype_image_tensor +from ._utils import is_simple_tensor def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor: @@ -43,9 +44,7 @@ def adjust_brightness(inpt: datapoints.InputTypeJIT, brightness_factor: float) - if not torch.jit.is_scripting(): _log_api_usage_once(adjust_brightness) - if isinstance(inpt, torch.Tensor) and ( - torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) - ): + if torch.jit.is_scripting() or is_simple_tensor(inpt): return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor) elif isinstance(inpt, datapoints._datapoint.Datapoint): return inpt.adjust_brightness(brightness_factor=brightness_factor) @@ -131,9 +130,7 @@ def adjust_contrast(inpt: datapoints.InputTypeJIT, contrast_factor: float) -> da if not torch.jit.is_scripting(): _log_api_usage_once(adjust_contrast) - if isinstance(inpt, torch.Tensor) and ( - torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) - ): + if torch.jit.is_scripting() or is_simple_tensor(inpt): return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor) elif isinstance(inpt, datapoints._datapoint.Datapoint): return inpt.adjust_contrast(contrast_factor=contrast_factor) @@ -326,9 +323,7 @@ def adjust_hue(inpt: datapoints.InputTypeJIT, hue_factor: float) -> datapoints.I if not torch.jit.is_scripting(): _log_api_usage_once(adjust_hue) - if isinstance(inpt, torch.Tensor) and ( - torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) - ): + if torch.jit.is_scripting() or is_simple_tensor(inpt): return adjust_hue_image_tensor(inpt, hue_factor=hue_factor) elif isinstance(inpt, datapoints._datapoint.Datapoint): return inpt.adjust_hue(hue_factor=hue_factor) @@ -371,9 +366,7 @@ def adjust_gamma(inpt: datapoints.InputTypeJIT, gamma: float, gain: float = 1) - if not torch.jit.is_scripting(): _log_api_usage_once(adjust_gamma) - if isinstance(inpt, torch.Tensor) and ( - torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) - ): + if torch.jit.is_scripting() or is_simple_tensor(inpt): return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain) elif isinstance(inpt, datapoints._datapoint.Datapoint): return inpt.adjust_gamma(gamma=gamma, gain=gain) @@ -410,9 +403,7 @@ def posterize(inpt: datapoints.InputTypeJIT, bits: int) -> datapoints.InputTypeJ if not torch.jit.is_scripting(): _log_api_usage_once(posterize) - if isinstance(inpt, torch.Tensor) and ( - torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) - ): + if torch.jit.is_scripting() or is_simple_tensor(inpt): return posterize_image_tensor(inpt, bits=bits) elif isinstance(inpt, datapoints._datapoint.Datapoint): return inpt.posterize(bits=bits) @@ -443,9 +434,7 @@ def solarize(inpt: datapoints.InputTypeJIT, threshold: float) -> datapoints.Inpu if not torch.jit.is_scripting(): _log_api_usage_once(solarize) - if isinstance(inpt, torch.Tensor) and ( - torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) - ): + if torch.jit.is_scripting() or is_simple_tensor(inpt): return solarize_image_tensor(inpt, threshold=threshold) elif isinstance(inpt, datapoints._datapoint.Datapoint): return inpt.solarize(threshold=threshold) @@ -498,9 +487,7 @@ def autocontrast(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT: if not torch.jit.is_scripting(): _log_api_usage_once(autocontrast) - if isinstance(inpt, torch.Tensor) and ( - torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) - ): + if torch.jit.is_scripting() or is_simple_tensor(inpt): return autocontrast_image_tensor(inpt) elif isinstance(inpt, datapoints._datapoint.Datapoint): return inpt.autocontrast() @@ -593,9 +580,7 @@ def equalize(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT: if not torch.jit.is_scripting(): _log_api_usage_once(equalize) - if isinstance(inpt, torch.Tensor) and ( - torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) - ): + if torch.jit.is_scripting() or is_simple_tensor(inpt): return equalize_image_tensor(inpt) elif isinstance(inpt, datapoints._datapoint.Datapoint): return inpt.equalize() @@ -629,9 +614,7 @@ def invert(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT: if not torch.jit.is_scripting(): _log_api_usage_once(invert) - if isinstance(inpt, torch.Tensor) and ( - torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) - ): + if torch.jit.is_scripting() or is_simple_tensor(inpt): return invert_image_tensor(inpt) elif isinstance(inpt, datapoints._datapoint.Datapoint): return inpt.invert() diff --git a/torchvision/prototype/transforms/functional/_deprecated.py b/torchvision/prototype/transforms/functional/_deprecated.py index 25b54917b33..f6fb0af0ae9 100644 --- a/torchvision/prototype/transforms/functional/_deprecated.py +++ b/torchvision/prototype/transforms/functional/_deprecated.py @@ -7,6 +7,8 @@ from torchvision.prototype import datapoints from torchvision.transforms import functional as _F +from ._utils import is_simple_tensor + @torch.jit.unused def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image: @@ -25,14 +27,14 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima def rgb_to_grayscale( inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], num_output_channels: int = 1 ) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]: - if not torch.jit.is_scripting() and isinstance(inpt, (datapoints.Image, datapoints.Video)): - inpt = inpt.as_subclass(torch.Tensor) - old_color_space = None - elif isinstance(inpt, torch.Tensor): + if torch.jit.is_scripting() or is_simple_tensor(inpt): old_color_space = datapoints._image._from_tensor_shape(inpt.shape) # type: ignore[arg-type] else: old_color_space = None + if isinstance(inpt, (datapoints.Image, datapoints.Video)): + inpt = inpt.as_subclass(torch.Tensor) + call = ", num_output_channels=3" if num_output_channels == 3 else "" replacement = ( f"convert_color_space(..., color_space=datapoints.ColorSpace.GRAY" diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index ba417a0ce84..66e777dbdcc 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -23,6 +23,8 @@ from ._meta import convert_format_bounding_box, get_spatial_size_image_pil +from ._utils import is_simple_tensor + def horizontal_flip_image_tensor(image: torch.Tensor) -> torch.Tensor: return image.flip(-1) @@ -60,9 +62,7 @@ def horizontal_flip(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT: if not torch.jit.is_scripting(): _log_api_usage_once(horizontal_flip) - if isinstance(inpt, torch.Tensor) and ( - torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) - ): + if torch.jit.is_scripting() or is_simple_tensor(inpt): return horizontal_flip_image_tensor(inpt) elif isinstance(inpt, datapoints._datapoint.Datapoint): return inpt.horizontal_flip() @@ -111,9 +111,7 @@ def vertical_flip(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT: if not torch.jit.is_scripting(): _log_api_usage_once(vertical_flip) - if isinstance(inpt, torch.Tensor) and ( - torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) - ): + if torch.jit.is_scripting() or is_simple_tensor(inpt): return vertical_flip_image_tensor(inpt) elif isinstance(inpt, datapoints._datapoint.Datapoint): return inpt.vertical_flip() @@ -241,9 +239,7 @@ def resize( ) -> datapoints.InputTypeJIT: if not torch.jit.is_scripting(): _log_api_usage_once(resize) - if isinstance(inpt, torch.Tensor) and ( - torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) - ): + if torch.jit.is_scripting() or is_simple_tensor(inpt): return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias) elif isinstance(inpt, datapoints._datapoint.Datapoint): return inpt.resize(size, interpolation=interpolation, max_size=max_size, antialias=antialias) @@ -744,9 +740,7 @@ def affine( _log_api_usage_once(affine) # TODO: consider deprecating integers from angle and shear on the future - if isinstance(inpt, torch.Tensor) and ( - torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) - ): + if torch.jit.is_scripting() or is_simple_tensor(inpt): return affine_image_tensor( inpt, angle, @@ -929,9 +923,7 @@ def rotate( if not torch.jit.is_scripting(): _log_api_usage_once(rotate) - if isinstance(inpt, torch.Tensor) and ( - torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) - ): + if torch.jit.is_scripting() or is_simple_tensor(inpt): return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) elif isinstance(inpt, datapoints._datapoint.Datapoint): return inpt.rotate(angle, interpolation=interpolation, expand=expand, fill=fill, center=center) @@ -1139,9 +1131,7 @@ def pad( if not torch.jit.is_scripting(): _log_api_usage_once(pad) - if isinstance(inpt, torch.Tensor) and ( - torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) - ): + if torch.jit.is_scripting() or is_simple_tensor(inpt): return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode) elif isinstance(inpt, datapoints._datapoint.Datapoint): @@ -1219,9 +1209,7 @@ def crop(inpt: datapoints.InputTypeJIT, top: int, left: int, height: int, width: if not torch.jit.is_scripting(): _log_api_usage_once(crop) - if isinstance(inpt, torch.Tensor) and ( - torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) - ): + if torch.jit.is_scripting() or is_simple_tensor(inpt): return crop_image_tensor(inpt, top, left, height, width) elif isinstance(inpt, datapoints._datapoint.Datapoint): return inpt.crop(top, left, height, width) @@ -1476,9 +1464,7 @@ def perspective( ) -> datapoints.InputTypeJIT: if not torch.jit.is_scripting(): _log_api_usage_once(perspective) - if isinstance(inpt, torch.Tensor) and ( - torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) - ): + if torch.jit.is_scripting() or is_simple_tensor(inpt): return perspective_image_tensor( inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients ) @@ -1639,9 +1625,7 @@ def elastic( if not torch.jit.is_scripting(): _log_api_usage_once(elastic) - if isinstance(inpt, torch.Tensor) and ( - torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) - ): + if torch.jit.is_scripting() or is_simple_tensor(inpt): return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill) elif isinstance(inpt, datapoints._datapoint.Datapoint): return inpt.elastic(displacement, interpolation=interpolation, fill=fill) @@ -1754,9 +1738,7 @@ def center_crop(inpt: datapoints.InputTypeJIT, output_size: List[int]) -> datapo if not torch.jit.is_scripting(): _log_api_usage_once(center_crop) - if isinstance(inpt, torch.Tensor) and ( - torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) - ): + if torch.jit.is_scripting() or is_simple_tensor(inpt): return center_crop_image_tensor(inpt, output_size) elif isinstance(inpt, datapoints._datapoint.Datapoint): return inpt.center_crop(output_size) @@ -1850,9 +1832,7 @@ def resized_crop( if not torch.jit.is_scripting(): _log_api_usage_once(resized_crop) - if isinstance(inpt, torch.Tensor) and ( - torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) - ): + if torch.jit.is_scripting() or is_simple_tensor(inpt): return resized_crop_image_tensor( inpt, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation ) @@ -1935,9 +1915,7 @@ def five_crop( # TODO: consider breaking BC here to return List[datapoints.ImageTypeJIT/VideoTypeJIT] to align this op with # `ten_crop` - if isinstance(inpt, torch.Tensor) and ( - torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video)) - ): + if torch.jit.is_scripting() or is_simple_tensor(inpt): return five_crop_image_tensor(inpt, size) elif isinstance(inpt, datapoints.Image): output = five_crop_image_tensor(inpt.as_subclass(torch.Tensor), size) @@ -1991,9 +1969,7 @@ def ten_crop( if not torch.jit.is_scripting(): _log_api_usage_once(ten_crop) - if isinstance(inpt, torch.Tensor) and ( - torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video)) - ): + if torch.jit.is_scripting() or is_simple_tensor(inpt): return ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip) elif isinstance(inpt, datapoints.Image): output = ten_crop_image_tensor(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip) diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py index 28de0536978..8ca2bf60521 100644 --- a/torchvision/prototype/transforms/functional/_meta.py +++ b/torchvision/prototype/transforms/functional/_meta.py @@ -9,6 +9,8 @@ from torchvision.utils import _log_api_usage_once +from ._utils import is_simple_tensor + def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]: chw = list(image.shape[-3:]) @@ -29,9 +31,7 @@ def get_dimensions(inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT] if not torch.jit.is_scripting(): _log_api_usage_once(get_dimensions) - if isinstance(inpt, torch.Tensor) and ( - torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video)) - ): + if torch.jit.is_scripting() or is_simple_tensor(inpt): return get_dimensions_image_tensor(inpt) elif isinstance(inpt, (datapoints.Image, datapoints.Video)): channels = inpt.num_channels @@ -68,9 +68,7 @@ def get_num_channels(inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJI if not torch.jit.is_scripting(): _log_api_usage_once(get_num_channels) - if isinstance(inpt, torch.Tensor) and ( - torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video)) - ): + if torch.jit.is_scripting() or is_simple_tensor(inpt): return get_num_channels_image_tensor(inpt) elif isinstance(inpt, (datapoints.Image, datapoints.Video)): return inpt.num_channels @@ -120,9 +118,7 @@ def get_spatial_size(inpt: datapoints.InputTypeJIT) -> List[int]: if not torch.jit.is_scripting(): _log_api_usage_once(get_spatial_size) - if isinstance(inpt, torch.Tensor) and ( - torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) - ): + if torch.jit.is_scripting() or is_simple_tensor(inpt): return get_spatial_size_image_tensor(inpt) elif isinstance(inpt, (datapoints.Image, datapoints.Video, datapoints.BoundingBox, datapoints.Mask)): return list(inpt.spatial_size) @@ -143,7 +139,7 @@ def get_num_frames(inpt: datapoints.VideoTypeJIT) -> int: if not torch.jit.is_scripting(): _log_api_usage_once(get_num_frames) - if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Video)): + if torch.jit.is_scripting() or is_simple_tensor(inpt): return get_num_frames_video(inpt) elif isinstance(inpt, datapoints.Video): return inpt.num_frames @@ -336,9 +332,7 @@ def convert_color_space( if not torch.jit.is_scripting(): _log_api_usage_once(convert_color_space) - if isinstance(inpt, torch.Tensor) and ( - torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video)) - ): + if torch.jit.is_scripting() or is_simple_tensor(inpt): if old_color_space is None: raise RuntimeError( "In order to convert the color space of simple tensors, " @@ -443,9 +437,7 @@ def convert_dtype( if not torch.jit.is_scripting(): _log_api_usage_once(convert_dtype) - if isinstance(inpt, torch.Tensor) and ( - torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video)) - ): + if torch.jit.is_scripting() or is_simple_tensor(inpt): return convert_dtype_image_tensor(inpt, dtype) elif isinstance(inpt, datapoints.Image): output = convert_dtype_image_tensor(inpt.as_subclass(torch.Tensor), dtype) diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index bc9408d0e2c..59570768160 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -10,7 +10,7 @@ from torchvision.utils import _log_api_usage_once -from ..utils import is_simple_tensor +from ._utils import is_simple_tensor def normalize_image_tensor( @@ -61,9 +61,9 @@ def normalize( if not torch.jit.is_scripting(): _log_api_usage_once(normalize) - if is_simple_tensor(inpt) or isinstance(inpt, (datapoints.Image, datapoints.Video)): + if isinstance(inpt, (datapoints.Image, datapoints.Video)): inpt = inpt.as_subclass(torch.Tensor) - else: + elif not is_simple_tensor(inpt): raise TypeError( f"Input can either be a plain tensor or an `Image` or `Video` datapoint, " f"but got {type(inpt)} instead." @@ -175,9 +175,7 @@ def gaussian_blur( if not torch.jit.is_scripting(): _log_api_usage_once(gaussian_blur) - if isinstance(inpt, torch.Tensor) and ( - torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) - ): + if torch.jit.is_scripting() or is_simple_tensor(inpt): return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma) elif isinstance(inpt, datapoints._datapoint.Datapoint): return inpt.gaussian_blur(kernel_size=kernel_size, sigma=sigma) diff --git a/torchvision/prototype/transforms/functional/_temporal.py b/torchvision/prototype/transforms/functional/_temporal.py index 35f4a84ce7c..d39a64534ca 100644 --- a/torchvision/prototype/transforms/functional/_temporal.py +++ b/torchvision/prototype/transforms/functional/_temporal.py @@ -4,6 +4,8 @@ from torchvision.utils import _log_api_usage_once +from ._utils import is_simple_tensor + def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int, temporal_dim: int = -4) -> torch.Tensor: # Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19 @@ -18,7 +20,7 @@ def uniform_temporal_subsample( if not torch.jit.is_scripting(): _log_api_usage_once(uniform_temporal_subsample) - if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Video)): + if torch.jit.is_scripting() or is_simple_tensor(inpt): return uniform_temporal_subsample_video(inpt, num_samples, temporal_dim=temporal_dim) elif isinstance(inpt, datapoints.Video): if temporal_dim != -4 and inpt.ndim - 4 != temporal_dim: diff --git a/torchvision/prototype/transforms/functional/_utils.py b/torchvision/prototype/transforms/functional/_utils.py new file mode 100644 index 00000000000..1a31bf66ad6 --- /dev/null +++ b/torchvision/prototype/transforms/functional/_utils.py @@ -0,0 +1,10 @@ +from typing import Any + +import torch +from torchvision.prototype.datapoints._datapoint import Datapoint + +from typing_extensions import TypeGuard + + +def is_simple_tensor(inpt: Any) -> TypeGuard[torch.Tensor]: + return isinstance(inpt, torch.Tensor) and not isinstance(inpt, Datapoint) diff --git a/torchvision/prototype/transforms/utils.py b/torchvision/prototype/transforms/utils.py index 9ab2ed2602b..ff7fff50ced 100644 --- a/torchvision/prototype/transforms/utils.py +++ b/torchvision/prototype/transforms/utils.py @@ -3,16 +3,10 @@ from typing import Any, Callable, List, Tuple, Type, Union import PIL.Image -import torch from torchvision._utils import sequence_to_str from torchvision.prototype import datapoints -from torchvision.prototype.datapoints._datapoint import Datapoint -from torchvision.prototype.transforms.functional import get_dimensions, get_spatial_size - - -def is_simple_tensor(inpt: Any) -> bool: - return isinstance(inpt, torch.Tensor) and not isinstance(inpt, Datapoint) +from torchvision.prototype.transforms.functional import get_dimensions, get_spatial_size, is_simple_tensor def query_bounding_box(flat_inputs: List[Any]) -> datapoints.BoundingBox: From 5011a1947daa8d933e7d1634cb005763eeeef6a4 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 13 Jan 2023 09:56:45 +0100 Subject: [PATCH 2/4] remove typeguard --- torchvision/prototype/transforms/functional/_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_utils.py b/torchvision/prototype/transforms/functional/_utils.py index 1a31bf66ad6..e4efeb6016f 100644 --- a/torchvision/prototype/transforms/functional/_utils.py +++ b/torchvision/prototype/transforms/functional/_utils.py @@ -3,8 +3,6 @@ import torch from torchvision.prototype.datapoints._datapoint import Datapoint -from typing_extensions import TypeGuard - -def is_simple_tensor(inpt: Any) -> TypeGuard[torch.Tensor]: +def is_simple_tensor(inpt: Any) -> bool: return isinstance(inpt, torch.Tensor) and not isinstance(inpt, Datapoint) From e991d958e5e191eb86c6474fb114a3326c4c5db1 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 16 Jan 2023 08:35:14 +0100 Subject: [PATCH 3/4] relax mypy --- mypy.ini | 2 +- torchvision/prototype/transforms/_type_conversion.py | 2 +- torchvision/prototype/transforms/functional/_color.py | 2 +- torchvision/prototype/transforms/functional/_meta.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mypy.ini b/mypy.ini index c1d174f4595..91ee4e7455b 100644 --- a/mypy.ini +++ b/mypy.ini @@ -32,7 +32,7 @@ no_implicit_optional = True ; warnings warn_unused_ignores = True -warn_return_any = True +;warn_return_any = True ; miscellaneous strictness flags allow_redefinition = True diff --git a/torchvision/prototype/transforms/_type_conversion.py b/torchvision/prototype/transforms/_type_conversion.py index 01908650fb4..c84aee62afe 100644 --- a/torchvision/prototype/transforms/_type_conversion.py +++ b/torchvision/prototype/transforms/_type_conversion.py @@ -46,7 +46,7 @@ class ToImageTensor(Transform): def _transform( self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any] ) -> datapoints.Image: - return F.to_image_tensor(inpt) # type: ignore[no-any-return] + return F.to_image_tensor(inpt) class ToImagePIL(Transform): diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index a04e208bdbd..53de1f407c8 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -595,7 +595,7 @@ def equalize(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT: def invert_image_tensor(image: torch.Tensor) -> torch.Tensor: if image.is_floating_point(): - return 1.0 - image # type: ignore[no-any-return] + return 1.0 - image elif image.dtype == torch.uint8: return image.bitwise_not() else: # signed integer dtypes diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py index 8ca2bf60521..62f9664fc47 100644 --- a/torchvision/prototype/transforms/functional/_meta.py +++ b/torchvision/prototype/transforms/functional/_meta.py @@ -123,7 +123,7 @@ def get_spatial_size(inpt: datapoints.InputTypeJIT) -> List[int]: elif isinstance(inpt, (datapoints.Image, datapoints.Video, datapoints.BoundingBox, datapoints.Mask)): return list(inpt.spatial_size) elif isinstance(inpt, PIL.Image.Image): - return get_spatial_size_image_pil(inpt) # type: ignore[no-any-return] + return get_spatial_size_image_pil(inpt) else: raise TypeError( f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " From 1b709f808e4ed2045777ffe46f939bb5f13f9bb0 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 16 Jan 2023 08:35:47 +0100 Subject: [PATCH 4/4] cleanup --- mypy.ini | 1 - 1 file changed, 1 deletion(-) diff --git a/mypy.ini b/mypy.ini index 91ee4e7455b..eb88b233fc0 100644 --- a/mypy.ini +++ b/mypy.ini @@ -32,7 +32,6 @@ no_implicit_optional = True ; warnings warn_unused_ignores = True -;warn_return_any = True ; miscellaneous strictness flags allow_redefinition = True