diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
index 4e8496e9758..3563a28d403 100644
--- a/test/test_prototype_transforms.py
+++ b/test/test_prototype_transforms.py
@@ -225,9 +225,18 @@ def test_random_resized_crop(self, transform, input):
             )
         ]
     )
-    def test_convertolor_space(self, transform, input):
+    def test_convert_color_space(self, transform, input):
         transform(input)
 
+    def test_convert_color_space_unsupported_types(self):
+        transform = transforms.ConvertColorSpace(
+            color_space=features.ColorSpace.RGB, old_color_space=features.ColorSpace.GRAY
+        )
+
+        for inpt in [make_bounding_box(format="XYXY"), make_segmentation_mask()]:
+            output = transform(inpt)
+            assert output is inpt
+
 
 @pytest.mark.parametrize("p", [0.0, 1.0])
 class TestRandomHorizontalFlip:
diff --git a/torchvision/prototype/transforms/_deprecated.py b/torchvision/prototype/transforms/_deprecated.py
index 02e827916ce..fd2af16ac31 100644
--- a/torchvision/prototype/transforms/_deprecated.py
+++ b/torchvision/prototype/transforms/_deprecated.py
@@ -3,7 +3,6 @@
 import numpy as np
 import PIL.Image
 
-import torch
 import torchvision.prototype.transforms.functional as F
 from torchvision.prototype import features
 from torchvision.prototype.features import ColorSpace
@@ -18,7 +17,7 @@
 
 class ToTensor(Transform):
     # Updated transformed types for ToTensor
-    _transformed_types = (torch.Tensor, features._Feature, PIL.Image.Image, np.ndarray)
+    _transformed_types = (is_simple_tensor, features._Feature, PIL.Image.Image, np.ndarray)
 
     def __init__(self) -> None:
         warnings.warn(
@@ -52,7 +51,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
 
 class ToPILImage(Transform):
     # Updated transformed types for ToPILImage
-    _transformed_types = (torch.Tensor, features._Feature, PIL.Image.Image, np.ndarray)
+    _transformed_types = (is_simple_tensor, features._Feature, PIL.Image.Image, np.ndarray)
 
     def __init__(self, mode: Optional[str] = None) -> None:
         warnings.warn(
diff --git a/torchvision/prototype/transforms/_meta.py b/torchvision/prototype/transforms/_meta.py
index 7a7fcf0ff10..5b98e90aee1 100644
--- a/torchvision/prototype/transforms/_meta.py
+++ b/torchvision/prototype/transforms/_meta.py
@@ -42,7 +42,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
 
 class ConvertColorSpace(Transform):
     # F.convert_color_space does NOT handle `_Feature`'s in general
-    _transformed_types = (torch.Tensor, features.Image, PIL.Image.Image)
+    _transformed_types = (is_simple_tensor, features.Image, PIL.Image.Image)
 
     def __init__(
         self,
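The files above only swap `torch.Tensor` for the `is_simple_tensor` predicate inside their `_transformed_types` tuples; the `_transform.py` and `_utils.py` changes below are what make that legal, by widening the tuple's element type to `Union[Type, Callable[[Any], bool]]` and routing every membership test through a small `_isinstance` helper. Here is a minimal standalone sketch of that dispatch; the `Image` subclass and the body of `is_simple_tensor` are simplified assumptions for illustration (the library's predicate actually checks against `features._Feature`):

from typing import Any, Callable, Tuple, Type, Union

import torch


class Image(torch.Tensor):
    # Simplified stand-in for features.Image; the real class carries metadata.
    pass


def is_simple_tensor(obj: Any) -> bool:
    # Simplified: matches plain tensors only, so Tensor subclasses keep
    # their dedicated kernels.
    return isinstance(obj, torch.Tensor) and not isinstance(obj, Image)


def _isinstance(obj: Any, types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...]) -> bool:
    # Same dispatch as the patch: a plain class goes through isinstance(),
    # anything else is treated as a predicate and called on the object.
    for type_or_check in types_or_checks:
        if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj):
            return True
    return False


plain = torch.rand(3, 4, 4)
image = plain.as_subclass(Image)

assert _isinstance(plain, (is_simple_tensor, Image))  # matched by the predicate
assert _isinstance(image, (is_simple_tensor, Image))  # matched by the type entry
assert not _isinstance(image, (is_simple_tensor,))    # subclasses fail the predicate

Because the predicate rejects `_Feature` subclasses, a transform whose `_transformed_types` starts with `is_simple_tensor` can no longer accidentally grab feature tensors such as bounding boxes or segmentation masks, which is exactly what the remaining hunks rely on.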
diff --git a/torchvision/prototype/transforms/_transform.py b/torchvision/prototype/transforms/_transform.py
index 9a12b53f355..d99c3277c78 100644
--- a/torchvision/prototype/transforms/_transform.py
+++ b/torchvision/prototype/transforms/_transform.py
@@ -1,18 +1,19 @@
 import enum
-from typing import Any, Dict, Tuple, Type
+from typing import Any, Callable, Dict, Tuple, Type, Union
 
 import PIL.Image
 import torch
 from torch import nn
 from torch.utils._pytree import tree_flatten, tree_unflatten
 from torchvision.prototype.features import _Feature
+from torchvision.prototype.transforms._utils import _isinstance, is_simple_tensor
 from torchvision.utils import _log_api_usage_once
 
 
 class Transform(nn.Module):
 
     # Class attribute defining transformed types. Other types are passed-through without any transformation
-    _transformed_types: Tuple[Type, ...] = (torch.Tensor, _Feature, PIL.Image.Image)
+    _transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = (is_simple_tensor, _Feature, PIL.Image.Image)
 
     def __init__(self) -> None:
         super().__init__()
@@ -31,7 +32,8 @@ def forward(self, *inputs: Any) -> Any:
         flat_inputs, spec = tree_flatten(sample)
 
         flat_outputs = [
-            self._transform(inpt, params) if isinstance(inpt, self._transformed_types) else inpt for inpt in flat_inputs
+            self._transform(inpt, params) if _isinstance(inpt, self._transformed_types) else inpt
+            for inpt in flat_inputs
         ]
 
         return tree_unflatten(flat_outputs, spec)
diff --git a/torchvision/prototype/transforms/_type_conversion.py b/torchvision/prototype/transforms/_type_conversion.py
index 8e5a09d14ad..b677ccc9d9c 100644
--- a/torchvision/prototype/transforms/_type_conversion.py
+++ b/torchvision/prototype/transforms/_type_conversion.py
@@ -3,7 +3,6 @@
 import numpy as np
 import PIL.Image
 
-import torch
 from torch.nn.functional import one_hot
 from torchvision.prototype import features
 from torchvision.prototype.transforms import functional as F, Transform
@@ -44,7 +43,7 @@ def extra_repr(self) -> str:
 
 class ToImageTensor(Transform):
     # Updated transformed types for ToImageTensor
-    _transformed_types = (torch.Tensor, features._Feature, PIL.Image.Image, np.ndarray)
+    _transformed_types = (is_simple_tensor, features._Feature, PIL.Image.Image, np.ndarray)
 
     def __init__(self, *, copy: bool = False) -> None:
         super().__init__()
@@ -61,7 +60,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
 
 class ToImagePIL(Transform):
     # Updated transformed types for ToImagePIL
-    _transformed_types = (torch.Tensor, features._Feature, PIL.Image.Image, np.ndarray)
+    _transformed_types = (is_simple_tensor, features._Feature, PIL.Image.Image, np.ndarray)
 
     def __init__(self, *, mode: Optional[str] = None) -> None:
         super().__init__()
diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py
index f3e0b495078..09df2ea3a44 100644
--- a/torchvision/prototype/transforms/_utils.py
+++ b/torchvision/prototype/transforms/_utils.py
@@ -45,12 +45,18 @@ def query_chw(sample: Any) -> Tuple[int, int, int]:
     return chws.pop()
 
 
+def _isinstance(obj: Any, types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...]) -> bool:
+    for type_or_check in types_or_checks:
+        if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj):
+            return True
+    return False
+
+
 def has_any(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
     flat_sample, _ = tree_flatten(sample)
-    for type_or_check in types_or_checks:
-        for obj in flat_sample:
-            if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj):
-                return True
+    for obj in flat_sample:
+        if _isinstance(obj, types_or_checks):
+            return True
     return False
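Taken together, the user-visible effect is the pass-through behavior that the new `test_convert_color_space_unsupported_types` test pins down: a pytree leaf that matches neither a type nor a predicate in `_transformed_types` is returned untouched, identity included. A toy forward pass sketching that contract; `Double`, this `Image` subclass, and this `is_simple_tensor` body are hypothetical stand-ins for illustration, not part of the patch:

from typing import Any, Callable, Tuple, Type, Union

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten


class Image(torch.Tensor):
    # Same illustrative stand-in for features.Image as in the earlier sketch.
    pass


def is_simple_tensor(obj: Any) -> bool:
    # Simplified predicate: plain tensors only, subclasses excluded.
    return isinstance(obj, torch.Tensor) and not isinstance(obj, Image)


def _isinstance(obj: Any, types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...]) -> bool:
    return any(
        isinstance(obj, entry) if isinstance(entry, type) else entry(obj)
        for entry in types_or_checks
    )


class Double:
    # Toy transform mimicking Transform.forward() from the patch:
    # doubles plain tensors, passes every other leaf through unchanged.
    _transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = (is_simple_tensor,)

    def __call__(self, sample: Any) -> Any:
        flat_inputs, spec = tree_flatten(sample)
        flat_outputs = [
            inpt * 2 if _isinstance(inpt, self._transformed_types) else inpt
            for inpt in flat_inputs
        ]
        return tree_unflatten(flat_outputs, spec)


plain = torch.ones(2)
image = torch.ones(2).as_subclass(Image)
out = Double()({"plain": plain, "image": image})

assert torch.equal(out["plain"], torch.tensor([2.0, 2.0]))
assert out["image"] is image  # unsupported leaf returned as-is, like the new ConvertColorSpace test

The `assert output is inpt` check in the new test relies on exactly this: `forward()` never copies a leaf it does not transform, so identity, not just value equality, is preserved for unsupported types.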