Skip to content

Commit 2e70ee1

Browse files
authored
[proto] Fix for handling numpy arrays by Transform (#6385)
* [proto] Fix for handling Numpy arrays by Transform * transformed_types -> _transformed_types
1 parent 1b44be3 commit 2e70ee1

File tree

5 files changed

+119
-12
lines changed

5 files changed

+119
-12
lines changed

test/test_prototype_transforms.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import itertools
22

3+
import numpy as np
4+
35
import PIL.Image
46

57
import pytest
@@ -991,3 +993,94 @@ def test__transform(self, p, inpt_type, mocker):
991993
fn.assert_called_once_with(erase_image_tensor_inpt, **params)
992994
else:
993995
fn.call_count == 0
996+
997+
998+
class TestTransform:
999+
@pytest.mark.parametrize(
1000+
"inpt_type",
1001+
[torch.Tensor, PIL.Image.Image, features.Image, np.ndarray, features.BoundingBox, str, int],
1002+
)
1003+
def test_check_transformed_types(self, inpt_type, mocker):
1004+
# This test ensures that we correctly handle which types to transform and which to bypass
1005+
t = transforms.Transform()
1006+
inpt = mocker.MagicMock(spec=inpt_type)
1007+
1008+
if inpt_type in (np.ndarray, str, int):
1009+
output = t(inpt)
1010+
assert output is inpt
1011+
else:
1012+
with pytest.raises(NotImplementedError):
1013+
t(inpt)
1014+
1015+
1016+
class TestToImageTensor:
1017+
@pytest.mark.parametrize(
1018+
"inpt_type",
1019+
[torch.Tensor, PIL.Image.Image, features.Image, np.ndarray, features.BoundingBox, str, int],
1020+
)
1021+
def test__transform(self, inpt_type, mocker):
1022+
fn = mocker.patch(
1023+
"torchvision.prototype.transforms.functional.to_image_tensor",
1024+
return_value=torch.rand(1, 3, 8, 8),
1025+
)
1026+
1027+
inpt = mocker.MagicMock(spec=inpt_type)
1028+
transform = transforms.ToImageTensor()
1029+
transform(inpt)
1030+
if inpt_type in (features.BoundingBox, str, int):
1031+
fn.call_count == 0
1032+
else:
1033+
fn.assert_called_once_with(inpt, copy=transform.copy)
1034+
1035+
1036+
class TestToImagePIL:
1037+
@pytest.mark.parametrize(
1038+
"inpt_type",
1039+
[torch.Tensor, PIL.Image.Image, features.Image, np.ndarray, features.BoundingBox, str, int],
1040+
)
1041+
def test__transform(self, inpt_type, mocker):
1042+
fn = mocker.patch("torchvision.prototype.transforms.functional.to_image_pil")
1043+
1044+
inpt = mocker.MagicMock(spec=inpt_type)
1045+
transform = transforms.ToImagePIL()
1046+
transform(inpt)
1047+
if inpt_type in (features.BoundingBox, str, int):
1048+
fn.call_count == 0
1049+
else:
1050+
fn.assert_called_once_with(inpt, copy=transform.copy)
1051+
1052+
1053+
class TestToPILImage:
1054+
@pytest.mark.parametrize(
1055+
"inpt_type",
1056+
[torch.Tensor, PIL.Image.Image, features.Image, np.ndarray, features.BoundingBox, str, int],
1057+
)
1058+
def test__transform(self, inpt_type, mocker):
1059+
fn = mocker.patch("torchvision.transforms.functional.to_pil_image")
1060+
1061+
inpt = mocker.MagicMock(spec=inpt_type)
1062+
with pytest.warns(UserWarning, match="deprecated and will be removed"):
1063+
transform = transforms.ToPILImage()
1064+
transform(inpt)
1065+
if inpt_type in (PIL.Image.Image, features.BoundingBox, str, int):
1066+
fn.call_count == 0
1067+
else:
1068+
fn.assert_called_once_with(inpt, mode=transform.mode)
1069+
1070+
1071+
class TestToTensor:
1072+
@pytest.mark.parametrize(
1073+
"inpt_type",
1074+
[torch.Tensor, PIL.Image.Image, features.Image, np.ndarray, features.BoundingBox, str, int],
1075+
)
1076+
def test__transform(self, inpt_type, mocker):
1077+
fn = mocker.patch("torchvision.transforms.functional.to_tensor")
1078+
1079+
inpt = mocker.MagicMock(spec=inpt_type)
1080+
with pytest.warns(UserWarning, match="deprecated and will be removed"):
1081+
transform = transforms.ToTensor()
1082+
transform(inpt)
1083+
if inpt_type in (features.Image, torch.Tensor, features.BoundingBox, str, int):
1084+
fn.call_count == 0
1085+
else:
1086+
fn.assert_called_once_with(inpt)

torchvision/prototype/transforms/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,6 @@
3434
)
3535
from ._meta import ConvertBoundingBoxFormat, ConvertImageColorSpace, ConvertImageDtype
3636
from ._misc import GaussianBlur, Identity, Lambda, Normalize, ToDtype
37-
from ._type_conversion import DecodeImage, LabelToOneHot
37+
from ._type_conversion import DecodeImage, LabelToOneHot, ToImagePIL, ToImageTensor
3838

3939
from ._deprecated import Grayscale, RandomGrayscale, ToTensor, ToPILImage, PILToTensor # usort: skip

torchvision/prototype/transforms/_deprecated.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import numpy as np
55
import PIL.Image
6+
import torch
67
from torchvision.prototype import features
78
from torchvision.prototype.features import ColorSpace
89
from torchvision.prototype.transforms import Transform
@@ -15,6 +16,10 @@
1516

1617

1718
class ToTensor(Transform):
19+
20+
# Updated transformed types for ToTensor
21+
_transformed_types = (torch.Tensor, features._Feature, PIL.Image.Image, np.ndarray)
22+
1823
def __init__(self) -> None:
1924
warnings.warn(
2025
"The transform `ToTensor()` is deprecated and will be removed in a future release. "
@@ -23,8 +28,6 @@ def __init__(self) -> None:
2328
super().__init__()
2429

2530
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
26-
# TODO: Transforms allows to pass only (torch.Tensor, _Feature, PIL.Image.Image)
27-
# so input as np.ndarray is not possible. We need to make it possible
2831
if isinstance(inpt, (PIL.Image.Image, np.ndarray)):
2932
return _F.to_tensor(inpt)
3033
else:
@@ -47,6 +50,10 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
4750

4851

4952
class ToPILImage(Transform):
53+
54+
# Updated transformed types for ToPILImage
55+
_transformed_types = (torch.Tensor, features._Feature, PIL.Image.Image, np.ndarray)
56+
5057
def __init__(self, mode: Optional[str] = None) -> None:
5158
warnings.warn(
5259
"The transform `ToPILImage()` is deprecated and will be removed in a future release. "
@@ -56,8 +63,6 @@ def __init__(self, mode: Optional[str] = None) -> None:
5663
self.mode = mode
5764

5865
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
59-
# TODO: Transforms allows to pass only (torch.Tensor, _Feature, PIL.Image.Image)
60-
# so input as np.ndarray is not possible. We need to make it possible
6166
if is_simple_tensor(inpt) or isinstance(inpt, (features.Image, np.ndarray)):
6267
return _F.to_pil_image(inpt, mode=self.mode)
6368
else:

torchvision/prototype/transforms/_transform.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import enum
2-
from typing import Any, Dict
2+
from typing import Any, Dict, Tuple, Type
33

44
import PIL.Image
55
import torch
@@ -10,6 +10,10 @@
1010

1111

1212
class Transform(nn.Module):
13+
14+
# Class attribute defining transformed types. Other types are passed-through without any transformation
15+
_transformed_types: Tuple[Type, ...] = (torch.Tensor, _Feature, PIL.Image.Image)
16+
1317
def __init__(self) -> None:
1418
super().__init__()
1519
_log_api_usage_once(self)
@@ -26,9 +30,8 @@ def forward(self, *inputs: Any) -> Any:
2630
params = self._get_params(sample)
2731

2832
flat_inputs, spec = tree_flatten(sample)
29-
transformed_types = (torch.Tensor, _Feature, PIL.Image.Image)
3033
flat_outputs = [
31-
self._transform(inpt, params) if isinstance(inpt, transformed_types) else inpt for inpt in flat_inputs
34+
self._transform(inpt, params) if isinstance(inpt, self._transformed_types) else inpt for inpt in flat_inputs
3235
]
3336
return tree_unflatten(flat_outputs, spec)
3437

torchvision/prototype/transforms/_type_conversion.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import numpy as np
44
import PIL.Image
5+
6+
import torch
57
from torchvision.prototype import features
68
from torchvision.prototype.transforms import functional as F, Transform
79

@@ -40,13 +42,15 @@ def extra_repr(self) -> str:
4042

4143

4244
class ToImageTensor(Transform):
45+
46+
# Updated transformed types for ToImageTensor
47+
_transformed_types = (torch.Tensor, features._Feature, PIL.Image.Image, np.ndarray)
48+
4349
def __init__(self, *, copy: bool = False) -> None:
4450
super().__init__()
4551
self.copy = copy
4652

4753
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
48-
# TODO: Transforms allows to pass only (torch.Tensor, _Feature, PIL.Image.Image)
49-
# so input as np.ndarray is not possible. We need to make it possible
5054
if isinstance(inpt, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(inpt):
5155
output = F.to_image_tensor(inpt, copy=self.copy)
5256
return features.Image(output)
@@ -55,13 +59,15 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
5559

5660

5761
class ToImagePIL(Transform):
62+
63+
# Updated transformed types for ToImagePIL
64+
_transformed_types = (torch.Tensor, features._Feature, PIL.Image.Image, np.ndarray)
65+
5866
def __init__(self, *, copy: bool = False) -> None:
5967
super().__init__()
6068
self.copy = copy
6169

6270
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
63-
# TODO: Transforms allows to pass only (torch.Tensor, _Feature, PIL.Image.Image)
64-
# so input as np.ndarray is not possible. We need to make it possible
6571
if isinstance(inpt, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(inpt):
6672
return F.to_image_pil(inpt, copy=self.copy)
6773
else:

0 commit comments

Comments
 (0)