
Commit 31d5f10

datumbox authored and facebook-github-bot committed
[fbsync] fix vanilla tensor image detection (#5518)
Summary:
* fix vanilla tensor image detection
* fix naming

Reviewed By: vmoens
Differential Revision: D34878991
fbshipit-source-id: cbe7f908bde0b2cfb90a7174ebc65daebe6fa09e
Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 1de9f17 commit 31d5f10

6 files changed, 22 insertions(+), 14 deletions(-)
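Why the change was needed: in the prototype transforms API every features._Feature (Image, BoundingBox, Label, ...) subclasses torch.Tensor, so a bare isinstance(input, torch.Tensor) check also matches feature-wrapped inputs and cannot tell a "vanilla" tensor image apart from them. The new is_simple_tensor helper encodes exactly that distinction. A minimal sketch of the idea, using hypothetical stand-in classes rather than the real prototype features:

    import torch

    class _Feature(torch.Tensor):  # stand-in for features._Feature
        pass

    class Image(_Feature):  # stand-in for features.Image
        pass

    def is_simple_tensor(input) -> bool:
        # Same logic as the helper added to _utils.py in this commit.
        return isinstance(input, torch.Tensor) and not isinstance(input, _Feature)

    vanilla = torch.rand(3, 32, 32)
    wrapped = torch.rand(3, 32, 32).as_subclass(Image)

    print(isinstance(wrapped, torch.Tensor))  # True -- the old check was too broad
    print(is_simple_tensor(vanilla))          # True
    print(is_simple_tensor(wrapped))          # False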

torchvision/prototype/transforms/_augment.py (2 additions & 2 deletions)

@@ -7,7 +7,7 @@
 from torchvision.prototype import features
 from torchvision.prototype.transforms import Transform, functional as F
 
-from ._utils import query_image, get_image_dimensions, has_all, has_any
+from ._utils import query_image, get_image_dimensions, has_all, has_any, is_simple_tensor
 
 
 class RandomErasing(Transform):
@@ -90,7 +90,7 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
         if isinstance(input, features.Image):
             output = F.erase_image_tensor(input, **params)
             return features.Image.new_like(input, output)
-        elif isinstance(input, torch.Tensor):
+        elif is_simple_tensor(input):
             return F.erase_image_tensor(input, **params)
         else:
             return input
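Concretely for RandomErasing: since features.BoundingBox is also a torch.Tensor subclass, a bounding box reaching the old elif isinstance(input, torch.Tensor) branch would have been run through the image erase kernel instead of falling through to return input. A small sketch with stand-in classes (not the real prototype types):

    import torch

    class _Feature(torch.Tensor):  # stand-in for features._Feature
        pass

    class BoundingBox(_Feature):  # stand-in for features.BoundingBox
        pass

    def is_simple_tensor(input) -> bool:
        return isinstance(input, torch.Tensor) and not isinstance(input, _Feature)

    box = torch.tensor([0.0, 0.0, 10.0, 10.0]).as_subclass(BoundingBox)

    print(isinstance(box, torch.Tensor))  # True: old check routed the box to the image kernel
    print(is_simple_tensor(box))          # False: new check lets it pass through untouched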

torchvision/prototype/transforms/_auto_augment.py (2 additions & 2 deletions)

@@ -8,7 +8,7 @@
 from torchvision.prototype.utils._internal import query_recursively
 from torchvision.transforms.functional import pil_to_tensor, to_pil_image
 
-from ._utils import get_image_dimensions
+from ._utils import get_image_dimensions, is_simple_tensor
 
 K = TypeVar("K")
 V = TypeVar("V")
@@ -89,7 +89,7 @@ def _dispatch_image_kernels(
         if isinstance(input, features.Image):
             output = image_tensor_kernel(input, *args, **kwargs)
             return features.Image.new_like(input, output)
-        elif isinstance(input, torch.Tensor):
+        elif is_simple_tensor(input):
             return image_tensor_kernel(input, *args, **kwargs)
         else:  # isinstance(input, PIL.Image.Image):
             return image_pil_kernel(input, *args, **kwargs)

torchvision/prototype/transforms/_geometry.py (5 additions & 5 deletions)

@@ -8,7 +8,7 @@
 from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F
 from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int
 
-from ._utils import query_image, get_image_dimensions, has_any
+from ._utils import query_image, get_image_dimensions, has_any, is_simple_tensor
 
 
 class HorizontalFlip(Transform):
@@ -21,7 +21,7 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
             return features.BoundingBox.new_like(input, output)
         elif isinstance(input, PIL.Image.Image):
             return F.horizontal_flip_image_pil(input)
-        elif isinstance(input, torch.Tensor):
+        elif is_simple_tensor(input):
             return F.horizontal_flip_image_tensor(input)
         else:
             return input
@@ -49,7 +49,7 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
             return features.BoundingBox.new_like(input, output, image_size=cast(Tuple[int, int], tuple(self.size)))
         elif isinstance(input, PIL.Image.Image):
             return F.resize_image_pil(input, self.size, interpolation=self.interpolation)
-        elif isinstance(input, torch.Tensor):
+        elif is_simple_tensor(input):
             return F.resize_image_tensor(input, self.size, interpolation=self.interpolation)
         else:
             return input
@@ -64,7 +64,7 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
         if isinstance(input, features.Image):
            output = F.center_crop_image_tensor(input, self.output_size)
            return features.Image.new_like(input, output)
-        elif isinstance(input, torch.Tensor):
+        elif is_simple_tensor(input):
            return F.center_crop_image_tensor(input, self.output_size)
        elif isinstance(input, PIL.Image.Image):
            return F.center_crop_image_pil(input, self.output_size)
@@ -156,7 +156,7 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
                input, **params, size=list(self.size), interpolation=self.interpolation
            )
            return features.Image.new_like(input, output)
-        elif isinstance(input, torch.Tensor):
+        elif is_simple_tensor(input):
            return F.resized_crop_image_tensor(input, **params, size=list(self.size), interpolation=self.interpolation)
        elif isinstance(input, PIL.Image.Image):
            return F.resized_crop_image_pil(input, **params, size=list(self.size), interpolation=self.interpolation)

torchvision/prototype/transforms/_meta.py (7 additions & 3 deletions)

@@ -6,6 +6,8 @@
 from torchvision.prototype.transforms import Transform, functional as F
 from torchvision.transforms.functional import convert_image_dtype
 
+from ._utils import is_simple_tensor
+
 
 class ConvertBoundingBoxFormat(Transform):
     def __init__(self, format: Union[str, features.BoundingBoxFormat]) -> None:
@@ -15,7 +17,7 @@ def __init__(self, format: Union[str, features.BoundingBoxFormat]) -> None:
         self.format = format
 
     def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if type(input) is features.BoundingBox:
+        if isinstance(input, features.BoundingBox):
             output = F.convert_bounding_box_format(input, old_format=input.format, new_format=params["format"])
             return features.BoundingBox.new_like(input, output, format=params["format"])
         else:
@@ -28,9 +30,11 @@ def __init__(self, dtype: torch.dtype = torch.float32) -> None:
         self.dtype = dtype
 
     def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if type(input) is features.Image:
+        if isinstance(input, features.Image):
             output = convert_image_dtype(input, dtype=self.dtype)
             return features.Image.new_like(input, output, dtype=self.dtype)
+        elif is_simple_tensor(input):
+            return convert_image_dtype(input, dtype=self.dtype)
         else:
             return input
 
@@ -57,7 +61,7 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
                 input, old_color_space=input.color_space, new_color_space=self.color_space
             )
             return features.Image.new_like(input, output, color_space=self.color_space)
-        elif isinstance(input, torch.Tensor):
+        elif is_simple_tensor(input):
             if self.old_color_space is None:
                 raise RuntimeError(
                     f"In order to convert vanilla tensor images, `{type(self).__name__}(...)` "

torchvision/prototype/transforms/_type_conversion.py (2 additions & 2 deletions)

@@ -6,7 +6,7 @@
 
 class DecodeImage(Transform):
     def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if type(input) is features.EncodedImage:
+        if isinstance(input, features.EncodedImage):
             output = F.decode_image_with_pil(input)
             return features.Image(output)
         else:
@@ -19,7 +19,7 @@ def __init__(self, num_categories: int = -1):
         self.num_categories = num_categories
 
     def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if type(input) is features.Label:
+        if isinstance(input, features.Label):
             num_categories = self.num_categories
             if num_categories == -1 and input.categories is not None:
                 num_categories = len(input.categories)
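The type(...) is comparisons replaced throughout this commit are exact-type checks: they reject subclasses, so a user-defined subclass of features.Label or features.EncodedImage would silently skip the transform. isinstance dispatches subclasses too. A minimal illustration (both classes below are hypothetical stand-ins):

    class Label:  # stand-in for features.Label
        pass

    class MultiLabel(Label):  # hypothetical user subclass
        pass

    x = MultiLabel()
    print(type(x) is Label)      # False -- old check: subclass not dispatched
    print(isinstance(x, Label))  # True  -- new check: subclass handled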

torchvision/prototype/transforms/_utils.py (4 additions & 0 deletions)

@@ -46,3 +46,7 @@ def has_any(sample: Any, *types: Type) -> bool:
 
 def has_all(sample: Any, *types: Type) -> bool:
     return not bool(set(types) - set(_extract_types(sample)))
+
+
+def is_simple_tensor(input: Any) -> bool:
+    return isinstance(input, torch.Tensor) and not isinstance(input, features._Feature)
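Assuming the prototype modules as of this commit (the prototype namespace is unstable, so both import paths below are assumptions), the new helper behaves as follows:

    import torch
    from torchvision.prototype import features
    from torchvision.prototype.transforms._utils import is_simple_tensor

    tensor_image = torch.rand(3, 16, 16)
    feature_image = features.Image(torch.rand(3, 16, 16))

    print(is_simple_tensor(tensor_image))   # True: plain torch.Tensor
    print(is_simple_tensor(feature_image))  # False: features.Image is a _Feature
    print(is_simple_tensor([1, 2, 3]))      # False: not a tensor at all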
