From eefa6fe602a6743f090d9b1465751e5cd19ea2ad Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 24 Aug 2022 18:14:02 +0200 Subject: [PATCH 1/3] Fixes unexpected behaviour with Transform._transformed_types and torch.Tensor --- test/test_prototype_transforms.py | 11 ++++++++++- torchvision/prototype/transforms/_transform.py | 12 +++++++++++- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 2a17c0ffbe9..ce44a7f53d7 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -225,9 +225,18 @@ def test_random_resized_crop(self, transform, input): ) ] ) - def test_convertolor_space(self, transform, input): + def test_convert_color_space(self, transform, input): transform(input) + def test_convert_color_space_unsupported_types(self): + transform = transforms.ConvertColorSpace( + color_space=features.ColorSpace.RGB, old_color_space=features.ColorSpace.GRAY + ) + + for inpt in [make_bounding_box(format="XYXY"), make_segmentation_mask()]: + output = transform(inpt) + assert output is inpt + @pytest.mark.parametrize("p", [0.0, 1.0]) class TestRandomHorizontalFlip: diff --git a/torchvision/prototype/transforms/_transform.py b/torchvision/prototype/transforms/_transform.py index 9a12b53f355..e12dabc42b8 100644 --- a/torchvision/prototype/transforms/_transform.py +++ b/torchvision/prototype/transforms/_transform.py @@ -6,9 +6,18 @@ from torch import nn from torch.utils._pytree import tree_flatten, tree_unflatten from torchvision.prototype.features import _Feature +from torchvision.prototype.transforms._utils import is_simple_tensor from torchvision.utils import _log_api_usage_once +def _isinstance(obj: Any, types: Tuple[Type, ...]) -> bool: + has_tensor = torch.Tensor in types + if not has_tensor: + return isinstance(obj, types) + types_ = tuple(t for t in types if t != torch.Tensor) + return isinstance(obj, types_) or is_simple_tensor(obj) + + class Transform(nn.Module): # Class attribute defining transformed types. Other types are passed-through without any transformation @@ -31,7 +40,8 @@ def forward(self, *inputs: Any) -> Any: flat_inputs, spec = tree_flatten(sample) flat_outputs = [ - self._transform(inpt, params) if isinstance(inpt, self._transformed_types) else inpt for inpt in flat_inputs + self._transform(inpt, params) if _isinstance(inpt, self._transformed_types) else inpt + for inpt in flat_inputs ] return tree_unflatten(flat_outputs, spec) From 847a4a69092172650553b8b5f441bbbb91390fc5 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Thu, 25 Aug 2022 12:05:02 +0200 Subject: [PATCH 2/3] Make code consistent to has_any, has_all implementation --- torchvision/prototype/transforms/_deprecated.py | 4 ++-- torchvision/prototype/transforms/_meta.py | 2 +- torchvision/prototype/transforms/_transform.py | 14 +++----------- .../prototype/transforms/_type_conversion.py | 4 ++-- torchvision/prototype/transforms/_utils.py | 14 ++++++++++---- 5 files changed, 18 insertions(+), 20 deletions(-) diff --git a/torchvision/prototype/transforms/_deprecated.py b/torchvision/prototype/transforms/_deprecated.py index 02e827916ce..db84f2295d4 100644 --- a/torchvision/prototype/transforms/_deprecated.py +++ b/torchvision/prototype/transforms/_deprecated.py @@ -18,7 +18,7 @@ class ToTensor(Transform): # Updated transformed types for ToTensor - _transformed_types = (torch.Tensor, features._Feature, PIL.Image.Image, np.ndarray) + _transformed_types = (is_simple_tensor, features._Feature, PIL.Image.Image, np.ndarray) def __init__(self) -> None: warnings.warn( @@ -52,7 +52,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: class ToPILImage(Transform): # Updated transformed types for ToPILImage - _transformed_types = (torch.Tensor, features._Feature, PIL.Image.Image, np.ndarray) + _transformed_types = (is_simple_tensor, features._Feature, PIL.Image.Image, np.ndarray) def __init__(self, mode: Optional[str] = None) -> None: warnings.warn( diff --git a/torchvision/prototype/transforms/_meta.py b/torchvision/prototype/transforms/_meta.py index b3b87b7cb09..abcc671b157 100644 --- a/torchvision/prototype/transforms/_meta.py +++ b/torchvision/prototype/transforms/_meta.py @@ -42,7 +42,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: class ConvertColorSpace(Transform): # F.convert_color_space does NOT handle `_Feature`'s in general - _transformed_types = (torch.Tensor, features.Image, PIL.Image.Image) + _transformed_types = (is_simple_tensor, features.Image, PIL.Image.Image) def __init__( self, diff --git a/torchvision/prototype/transforms/_transform.py b/torchvision/prototype/transforms/_transform.py index e12dabc42b8..d99c3277c78 100644 --- a/torchvision/prototype/transforms/_transform.py +++ b/torchvision/prototype/transforms/_transform.py @@ -1,27 +1,19 @@ import enum -from typing import Any, Dict, Tuple, Type +from typing import Any, Callable, Dict, Tuple, Type, Union import PIL.Image import torch from torch import nn from torch.utils._pytree import tree_flatten, tree_unflatten from torchvision.prototype.features import _Feature -from torchvision.prototype.transforms._utils import is_simple_tensor +from torchvision.prototype.transforms._utils import _isinstance, is_simple_tensor from torchvision.utils import _log_api_usage_once -def _isinstance(obj: Any, types: Tuple[Type, ...]) -> bool: - has_tensor = torch.Tensor in types - if not has_tensor: - return isinstance(obj, types) - types_ = tuple(t for t in types if t != torch.Tensor) - return isinstance(obj, types_) or is_simple_tensor(obj) - - class Transform(nn.Module): # Class attribute defining transformed types. Other types are passed-through without any transformation - _transformed_types: Tuple[Type, ...] = (torch.Tensor, _Feature, PIL.Image.Image) + _transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = (is_simple_tensor, _Feature, PIL.Image.Image) def __init__(self) -> None: super().__init__() diff --git a/torchvision/prototype/transforms/_type_conversion.py b/torchvision/prototype/transforms/_type_conversion.py index 88464c09436..8aee04f6db1 100644 --- a/torchvision/prototype/transforms/_type_conversion.py +++ b/torchvision/prototype/transforms/_type_conversion.py @@ -44,7 +44,7 @@ def extra_repr(self) -> str: class ToImageTensor(Transform): # Updated transformed types for ToImageTensor - _transformed_types = (torch.Tensor, features._Feature, PIL.Image.Image, np.ndarray) + _transformed_types = (is_simple_tensor, features._Feature, PIL.Image.Image, np.ndarray) def __init__(self, *, copy: bool = False) -> None: super().__init__() @@ -61,7 +61,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: class ToImagePIL(Transform): # Updated transformed types for ToImagePIL - _transformed_types = (torch.Tensor, features._Feature, PIL.Image.Image, np.ndarray) + _transformed_types = (is_simple_tensor, features._Feature, PIL.Image.Image, np.ndarray) def __init__(self, *, mode: Optional[str] = None) -> None: super().__init__() diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index fe06132ca1c..6c2fe04192f 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -39,12 +39,18 @@ def get_image_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Im return channels, height, width +def _isinstance(obj: Any, types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...]) -> bool: + for type_or_check in types_or_checks: + if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj): + return True + return False + + def has_any(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool: flat_sample, _ = tree_flatten(sample) - for type_or_check in types_or_checks: - for obj in flat_sample: - if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj): - return True + for obj in flat_sample: + if _isinstance(obj, types_or_checks): + return True return False From 8766032f331ea1334df5c39e2753a9d5ecc71a01 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Thu, 25 Aug 2022 12:23:53 +0200 Subject: [PATCH 3/3] Fixed failing flake8 check --- torchvision/prototype/transforms/_deprecated.py | 1 - torchvision/prototype/transforms/_type_conversion.py | 1 - 2 files changed, 2 deletions(-) diff --git a/torchvision/prototype/transforms/_deprecated.py b/torchvision/prototype/transforms/_deprecated.py index db84f2295d4..fd2af16ac31 100644 --- a/torchvision/prototype/transforms/_deprecated.py +++ b/torchvision/prototype/transforms/_deprecated.py @@ -3,7 +3,6 @@ import numpy as np import PIL.Image -import torch import torchvision.prototype.transforms.functional as F from torchvision.prototype import features from torchvision.prototype.features import ColorSpace diff --git a/torchvision/prototype/transforms/_type_conversion.py b/torchvision/prototype/transforms/_type_conversion.py index 59ac7948d43..b677ccc9d9c 100644 --- a/torchvision/prototype/transforms/_type_conversion.py +++ b/torchvision/prototype/transforms/_type_conversion.py @@ -3,7 +3,6 @@ import numpy as np import PIL.Image -import torch from torch.nn.functional import one_hot from torchvision.prototype import features from torchvision.prototype.transforms import functional as F, Transform