Skip to content

Commit e13206d

Browse files
pmeier and datumbox authored
add option to fail a transform on certain types rather than passthrough (#5432)
* add option to fail a transform on certain types rather than passthrough
* address comments

Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 2fe0c2d commit e13206d

File tree

4 files changed

+20
-10
lines changed

4 files changed

+20
-10
lines changed

torchvision/prototype/transforms/_augment.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,17 @@
33
import warnings
44
from typing import Any, Dict, Tuple
55

6+
import PIL.Image
67
import torch
8+
from torchvision.prototype import features
79
from torchvision.prototype.transforms import Transform, functional as F
810

911
from ._utils import query_image
1012

1113

1214
class RandomErasing(Transform):
1315
_DISPATCHER = F.erase
16+
_FAIL_TYPES = {PIL.Image.Image, features.BoundingBox, features.SegmentationMask}
1417

1518
def __init__(
1619
self,
@@ -98,6 +101,7 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
98101

99102
class RandomMixup(Transform):
100103
_DISPATCHER = F.mixup
104+
_FAIL_TYPES = {features.BoundingBox, features.SegmentationMask}
101105

102106
def __init__(self, *, alpha: float) -> None:
103107
super().__init__()
@@ -110,6 +114,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
110114

111115
class RandomCutmix(Transform):
112116
_DISPATCHER = F.cutmix
117+
_FAIL_TYPES = {features.BoundingBox, features.SegmentationMask}
113118

114119
def __init__(self, *, alpha: float) -> None:
115120
super().__init__()

torchvision/prototype/transforms/_auto_augment.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,6 @@ def __init__(
7979
"Invert": lambda input, magnitude, interpolation, fill: F.invert(input),
8080
}
8181

82-
def _is_supported(self, obj: Any) -> bool:
83-
return type(obj) in {features.Image, torch.Tensor} or isinstance(obj, PIL.Image.Image)
84-
8582
def _get_params(self, sample: Any) -> Dict[str, Any]:
8683
image = query_image(sample)
8784
num_channels = F.get_image_num_channels(image)
@@ -103,11 +100,13 @@ def _apply_transform(self, sample: Any, params: Dict[str, Any], transform_id: st
103100
dispatcher = self._DISPATCHER_MAP[transform_id]
104101

105102
def transform(input: Any) -> Any:
106-
if not self._is_supported(input):
103+
if type(input) in {features.Image, torch.Tensor} or isinstance(input, PIL.Image.Image):
104+
return dispatcher(input, magnitude, params["interpolation"], params["fill"])
105+
elif type(input) in {features.BoundingBox, features.SegmentationMask}:
106+
raise TypeError(f"{type(input)} is not supported by {type(self).__name__}()")
107+
else:
107108
return input
108109

109-
return dispatcher(input, magnitude, params["interpolation"], params["fill"])
110-
111110
return apply_recursively(transform, sample)
112111

113112

torchvision/prototype/transforms/_geometry.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Any, Dict, List, Union, Sequence, Tuple, cast
44

55
import torch
6+
from torchvision.prototype import features
67
from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F
78
from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int
89

@@ -31,6 +32,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
3132

3233
class CenterCrop(Transform):
3334
_DISPATCHER = F.center_crop
35+
_FAIL_TYPES = {features.BoundingBox, features.SegmentationMask}
3436

3537
def __init__(self, output_size: List[int]):
3638
super().__init__()
@@ -42,6 +44,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
4244

4345
class RandomResizedCrop(Transform):
4446
_DISPATCHER = F.resized_crop
47+
_FAIL_TYPES = {features.BoundingBox, features.SegmentationMask}
4548

4649
def __init__(
4750
self,

torchvision/prototype/transforms/_transform.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import enum
22
import functools
3-
from typing import Any, Dict, Optional
3+
from typing import Any, Dict, Optional, Set, Type
44

55
from torch import nn
66
from torchvision.prototype.utils._internal import apply_recursively
@@ -11,6 +11,7 @@
1111

1212
class Transform(nn.Module):
1313
_DISPATCHER: Optional[Dispatcher] = None
14+
_FAIL_TYPES: Set[Type] = set()
1415

1516
def __init__(self) -> None:
1617
super().__init__()
@@ -23,11 +24,13 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
2324
if not self._DISPATCHER:
2425
raise NotImplementedError()
2526

26-
if input not in self._DISPATCHER:
27+
if input in self._DISPATCHER:
28+
return self._DISPATCHER(input, **params)
29+
elif type(input) in self._FAIL_TYPES:
30+
raise TypeError(f"{type(input)} is not supported by {type(self).__name__}()")
31+
else:
2732
return input
2833

29-
return self._DISPATCHER(input, **params)
30-
3134
def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any:
3235
sample = inputs if len(inputs) > 1 else inputs[0]
3336
return apply_recursively(functools.partial(self._transform, params=params or self._get_params(sample)), sample)

0 commit comments

Comments (0)