Skip to content

simplify dispatcher if-elif #7084

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits
Jan 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ no_implicit_optional = True

; warnings
warn_unused_ignores = True
warn_return_any = True

; miscellaneous strictness flags
allow_redefinition = True
Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/_type_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class ToImageTensor(Transform):
def _transform(
self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any]
) -> datapoints.Image:
return F.to_image_tensor(inpt) # type: ignore[no-any-return]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Due to the relaxed strictness, this one (and the ones below) became obsolete.

return F.to_image_tensor(inpt)


class ToImagePIL(Transform):
Expand Down
3 changes: 3 additions & 0 deletions torchvision/prototype/transforms/functional/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# TODO: Add _log_api_usage_once() in all mid-level kernels. If they remain not jit-scriptable we can use decorators

from torchvision.transforms import InterpolationMode # usort: skip

from ._utils import is_simple_tensor # usort: skip

from ._meta import (
clamp_bounding_box,
convert_format_bounding_box,
Expand Down
6 changes: 3 additions & 3 deletions torchvision/prototype/transforms/functional/_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
from torchvision.utils import _log_api_usage_once

from ._utils import is_simple_tensor


def erase_image_tensor(
image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
Expand Down Expand Up @@ -45,9 +47,7 @@ def erase(
if not torch.jit.is_scripting():
_log_api_usage_once(erase)

if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
):
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
elif isinstance(inpt, datapoints.Image):
output = erase_image_tensor(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace)
Expand Down
39 changes: 11 additions & 28 deletions torchvision/prototype/transforms/functional/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torchvision.utils import _log_api_usage_once

from ._meta import _num_value_bits, _rgb_to_gray, convert_dtype_image_tensor
from ._utils import is_simple_tensor


def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
Expand Down Expand Up @@ -43,9 +44,7 @@ def adjust_brightness(inpt: datapoints.InputTypeJIT, brightness_factor: float) -
if not torch.jit.is_scripting():
_log_api_usage_once(adjust_brightness)

if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.adjust_brightness(brightness_factor=brightness_factor)
Expand Down Expand Up @@ -131,9 +130,7 @@ def adjust_contrast(inpt: datapoints.InputTypeJIT, contrast_factor: float) -> da
if not torch.jit.is_scripting():
_log_api_usage_once(adjust_contrast)

if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.adjust_contrast(contrast_factor=contrast_factor)
Expand Down Expand Up @@ -326,9 +323,7 @@ def adjust_hue(inpt: datapoints.InputTypeJIT, hue_factor: float) -> datapoints.I
if not torch.jit.is_scripting():
_log_api_usage_once(adjust_hue)

if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.adjust_hue(hue_factor=hue_factor)
Expand Down Expand Up @@ -371,9 +366,7 @@ def adjust_gamma(inpt: datapoints.InputTypeJIT, gamma: float, gain: float = 1) -
if not torch.jit.is_scripting():
_log_api_usage_once(adjust_gamma)

if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.adjust_gamma(gamma=gamma, gain=gain)
Expand Down Expand Up @@ -410,9 +403,7 @@ def posterize(inpt: datapoints.InputTypeJIT, bits: int) -> datapoints.InputTypeJ
if not torch.jit.is_scripting():
_log_api_usage_once(posterize)

if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return posterize_image_tensor(inpt, bits=bits)
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.posterize(bits=bits)
Expand Down Expand Up @@ -443,9 +434,7 @@ def solarize(inpt: datapoints.InputTypeJIT, threshold: float) -> datapoints.Inpu
if not torch.jit.is_scripting():
_log_api_usage_once(solarize)

if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return solarize_image_tensor(inpt, threshold=threshold)
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.solarize(threshold=threshold)
Expand Down Expand Up @@ -498,9 +487,7 @@ def autocontrast(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(autocontrast)

if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return autocontrast_image_tensor(inpt)
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.autocontrast()
Expand Down Expand Up @@ -593,9 +580,7 @@ def equalize(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(equalize)

if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return equalize_image_tensor(inpt)
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.equalize()
Expand All @@ -610,7 +595,7 @@ def equalize(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:

def invert_image_tensor(image: torch.Tensor) -> torch.Tensor:
if image.is_floating_point():
return 1.0 - image # type: ignore[no-any-return]
return 1.0 - image
elif image.dtype == torch.uint8:
return image.bitwise_not()
else: # signed integer dtypes
Expand All @@ -629,9 +614,7 @@ def invert(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(invert)

if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return invert_image_tensor(inpt)
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.invert()
Expand Down
10 changes: 6 additions & 4 deletions torchvision/prototype/transforms/functional/_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from torchvision.prototype import datapoints
from torchvision.transforms import functional as _F

from ._utils import is_simple_tensor


@torch.jit.unused
def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image:
Expand All @@ -25,14 +27,14 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima
def rgb_to_grayscale(
inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], num_output_channels: int = 1
) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]:
if not torch.jit.is_scripting() and isinstance(inpt, (datapoints.Image, datapoints.Video)):
inpt = inpt.as_subclass(torch.Tensor)
old_color_space = None
elif isinstance(inpt, torch.Tensor):
if torch.jit.is_scripting() or is_simple_tensor(inpt):
old_color_space = datapoints._image._from_tensor_shape(inpt.shape) # type: ignore[arg-type]
else:
old_color_space = None

if isinstance(inpt, (datapoints.Image, datapoints.Video)):
inpt = inpt.as_subclass(torch.Tensor)

call = ", num_output_channels=3" if num_output_channels == 3 else ""
replacement = (
f"convert_color_space(..., color_space=datapoints.ColorSpace.GRAY"
Expand Down
54 changes: 15 additions & 39 deletions torchvision/prototype/transforms/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

from ._meta import convert_format_bounding_box, get_spatial_size_image_pil

from ._utils import is_simple_tensor


def horizontal_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
return image.flip(-1)
Expand Down Expand Up @@ -60,9 +62,7 @@ def horizontal_flip(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(horizontal_flip)

if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return horizontal_flip_image_tensor(inpt)
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.horizontal_flip()
Expand Down Expand Up @@ -111,9 +111,7 @@ def vertical_flip(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(vertical_flip)

if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return vertical_flip_image_tensor(inpt)
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.vertical_flip()
Expand Down Expand Up @@ -241,9 +239,7 @@ def resize(
) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(resize)
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.resize(size, interpolation=interpolation, max_size=max_size, antialias=antialias)
Expand Down Expand Up @@ -744,9 +740,7 @@ def affine(
_log_api_usage_once(affine)

# TODO: consider deprecating integers from angle and shear on the future
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return affine_image_tensor(
inpt,
angle,
Expand Down Expand Up @@ -929,9 +923,7 @@ def rotate(
if not torch.jit.is_scripting():
_log_api_usage_once(rotate)

if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.rotate(angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
Expand Down Expand Up @@ -1139,9 +1131,7 @@ def pad(
if not torch.jit.is_scripting():
_log_api_usage_once(pad)

if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode)

elif isinstance(inpt, datapoints._datapoint.Datapoint):
Expand Down Expand Up @@ -1219,9 +1209,7 @@ def crop(inpt: datapoints.InputTypeJIT, top: int, left: int, height: int, width:
if not torch.jit.is_scripting():
_log_api_usage_once(crop)

if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return crop_image_tensor(inpt, top, left, height, width)
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.crop(top, left, height, width)
Expand Down Expand Up @@ -1476,9 +1464,7 @@ def perspective(
) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(perspective)
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return perspective_image_tensor(
inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients
)
Expand Down Expand Up @@ -1639,9 +1625,7 @@ def elastic(
if not torch.jit.is_scripting():
_log_api_usage_once(elastic)

if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill)
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.elastic(displacement, interpolation=interpolation, fill=fill)
Expand Down Expand Up @@ -1754,9 +1738,7 @@ def center_crop(inpt: datapoints.InputTypeJIT, output_size: List[int]) -> datapo
if not torch.jit.is_scripting():
_log_api_usage_once(center_crop)

if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return center_crop_image_tensor(inpt, output_size)
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.center_crop(output_size)
Expand Down Expand Up @@ -1850,9 +1832,7 @@ def resized_crop(
if not torch.jit.is_scripting():
_log_api_usage_once(resized_crop)

if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return resized_crop_image_tensor(
inpt, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation
)
Expand Down Expand Up @@ -1935,9 +1915,7 @@ def five_crop(

# TODO: consider breaking BC here to return List[datapoints.ImageTypeJIT/VideoTypeJIT] to align this op with
# `ten_crop`
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
):
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return five_crop_image_tensor(inpt, size)
elif isinstance(inpt, datapoints.Image):
output = five_crop_image_tensor(inpt.as_subclass(torch.Tensor), size)
Expand Down Expand Up @@ -1991,9 +1969,7 @@ def ten_crop(
if not torch.jit.is_scripting():
_log_api_usage_once(ten_crop)

if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
):
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
elif isinstance(inpt, datapoints.Image):
output = ten_crop_image_tensor(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip)
Expand Down
Loading