
Commit 7251769

pmeier and datumbox authored
Transforms without dispatcher (#5421)
* add prototype transforms that don't need dispatchers
* cleanup
* remove legacy_transform decorator
* remove legacy classes
* remove explicit param passing
* streamline extra_repr
* remove obsolete ._supports() method
* cleanup
* remove Query
* cleanup
* fix tests
* kernels -> functional
* move image size and num channels extraction to functional
* extend legacy function to extract image size and num channels
* implement dispatching for auto augment
* fix auto augment dispatch
* revert some naming changes
* remove ability to pass params to autoaugment
* fix legacy image size extraction
* align prototype.transforms.functional with transforms.functional
* cleanup
* fix image size and channels extraction
* fix affine and rotate
* revert image size to (width, height)
* Minor corrections

Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent f15ba56 commit 7251769

28 files changed: +938 -1076 lines
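For orientation before the per-file diffs: this commit drops the registry-style `_DISPATCHER` indirection and has each transform dispatch on the input type directly inside its `_transform` method. A minimal sketch of the new pattern, modeled on the `RandomErasing` hunks below (`ErasingSketch` is an illustrative name, not part of the commit):

from typing import Any, Dict

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, functional as F


class ErasingSketch(Transform):
    # Before: `_DISPATCHER = F.erase` routed each input through a kernel table.
    # After: the transform itself picks the functional to call for each type.
    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
        if isinstance(input, features.Image):
            output = F.erase_image_tensor(input, **params)
            return features.Image.new_like(input, output)  # preserve feature metadata
        elif isinstance(input, torch.Tensor):
            return F.erase_image_tensor(input, **params)
        else:
            return input  # unsupported inputs pass through unchanged

Types that must not silently pass through (bounding boxes, segmentation masks) raise a TypeError instead, as the _augment.py diff below shows.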

test/test_prototype_transforms.py

Lines changed: 15 additions & 48 deletions
@@ -1,9 +1,8 @@
 import itertools
 
-import PIL.Image
 import pytest
 import torch
-from test_prototype_transforms_kernels import make_images, make_bounding_boxes, make_one_hot_labels
+from test_prototype_transforms_functional import make_images, make_bounding_boxes, make_one_hot_labels
 from torchvision.prototype import transforms, features
 from torchvision.transforms.functional import to_pil_image
 
@@ -25,15 +24,6 @@ def make_vanilla_tensor_bounding_boxes(*args, **kwargs):
         yield bounding_box.data
 
 
-INPUT_CREATIONS_FNS = {
-    features.Image: make_images,
-    features.BoundingBox: make_bounding_boxes,
-    features.OneHotLabel: make_one_hot_labels,
-    torch.Tensor: make_vanilla_tensor_images,
-    PIL.Image.Image: make_pil_images,
-}
-
-
 def parametrize(transforms_with_inputs):
     return pytest.mark.parametrize(
         ("transform", "input"),
@@ -52,15 +42,21 @@ def parametrize(transforms_with_inputs):
 def parametrize_from_transforms(*transforms):
     transforms_with_inputs = []
     for transform in transforms:
-        dispatcher = transform._DISPATCHER
-        if dispatcher is None:
-            continue
-
-        for type_ in dispatcher._kernels:
+        for creation_fn in [
+            make_images,
+            make_bounding_boxes,
+            make_one_hot_labels,
+            make_vanilla_tensor_images,
+            make_pil_images,
+        ]:
+            inputs = list(creation_fn())
             try:
-                inputs = INPUT_CREATIONS_FNS[type_]()
-            except KeyError:
+                output = transform(inputs[0])
+            except Exception:
                 continue
+            else:
+                if output is inputs[0]:
+                    continue
 
             transforms_with_inputs.append((transform, inputs))
 
@@ -69,7 +65,7 @@ def parametrize_from_transforms(*transforms):
 
 class TestSmoke:
     @parametrize_from_transforms(
-        transforms.RandomErasing(),
+        transforms.RandomErasing(p=1.0),
         transforms.HorizontalFlip(),
         transforms.Resize([16, 16]),
         transforms.CenterCrop([16, 16]),
@@ -141,35 +137,6 @@ def test_auto_augment(self, transform, input):
     def test_normalize(self, transform, input):
         transform(input)
 
-    @parametrize(
-        [
-            (
-                transforms.ConvertColorSpace("grayscale"),
-                itertools.chain(
-                    make_images(),
-                    make_vanilla_tensor_images(color_spaces=["rgb"]),
-                    make_pil_images(color_spaces=["rgb"]),
-                ),
-            )
-        ]
-    )
-    def test_convert_bounding_color_space(self, transform, input):
-        transform(input)
-
-    @parametrize(
-        [
-            (
-                transforms.ConvertBoundingBoxFormat("xyxy", old_format="xywh"),
-                itertools.chain(
-                    make_bounding_boxes(),
-                    make_vanilla_tensor_bounding_boxes(formats=["xywh"]),
-                ),
-            )
-        ]
-    )
-    def test_convert_bounding_box_format(self, transform, input):
-        transform(input)
-
     @parametrize(
         [
             (
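A subtlety in the rewritten `parametrize_from_transforms`: it probes each transform by calling it on one sample per creation function, and treats an output that `is` the input as "transform does not apply". That is why the smoke test now constructs `transforms.RandomErasing(p=1.0)`; with a smaller `p` the transform could randomly return its input untouched and be skipped. A standalone sketch of the probing idea (`collect_cases` and `creation_fns` are illustrative names, not test-suite helpers):

def collect_cases(transform, creation_fns):
    cases = []
    for creation_fn in creation_fns:
        inputs = list(creation_fn())
        try:
            output = transform(inputs[0])
        except Exception:
            continue  # the transform rejects this input type
        if output is inputs[0]:
            continue  # the transform passed the input through untouched
        cases.append((transform, inputs))
    return cases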

test/test_prototype_transforms_kernels.py renamed to test/test_prototype_transforms_functional.py

Lines changed: 21 additions & 21 deletions
@@ -3,7 +3,7 @@
 
 import pytest
 import torch.testing
-import torchvision.prototype.transforms.kernels as K
+import torchvision.prototype.transforms.functional as F
 from torch import jit
 from torch.nn.functional import one_hot
 from torchvision.prototype import features
@@ -134,10 +134,10 @@ def __init__(self, *args, **kwargs):
         self.kwargs = kwargs
 
 
-class KernelInfo:
+class FunctionalInfo:
     def __init__(self, name, *, sample_inputs_fn):
         self.name = name
-        self.kernel = getattr(K, name)
+        self.functional = getattr(F, name)
         self._sample_inputs_fn = sample_inputs_fn
 
     def sample_inputs(self):
@@ -146,21 +146,21 @@ def sample_inputs(self):
     def __call__(self, *args, **kwargs):
         if len(args) == 1 and not kwargs and isinstance(args[0], SampleInput):
             sample_input = args[0]
-            return self.kernel(*sample_input.args, **sample_input.kwargs)
+            return self.functional(*sample_input.args, **sample_input.kwargs)
 
-        return self.kernel(*args, **kwargs)
+        return self.functional(*args, **kwargs)
 
 
-KERNEL_INFOS = []
+FUNCTIONAL_INFOS = []
 
 
 def register_kernel_info_from_sample_inputs_fn(sample_inputs_fn):
-    KERNEL_INFOS.append(KernelInfo(sample_inputs_fn.__name__, sample_inputs_fn=sample_inputs_fn))
+    FUNCTIONAL_INFOS.append(FunctionalInfo(sample_inputs_fn.__name__, sample_inputs_fn=sample_inputs_fn))
     return sample_inputs_fn
 
 
 @register_kernel_info_from_sample_inputs_fn
-def horizontal_flip_image():
+def horizontal_flip_image_tensor():
     for image in make_images():
         yield SampleInput(image)
 
@@ -172,12 +172,12 @@ def horizontal_flip_bounding_box():
 
 
 @register_kernel_info_from_sample_inputs_fn
-def resize_image():
+def resize_image_tensor():
     for image, interpolation in itertools.product(
         make_images(),
         [
-            K.InterpolationMode.BILINEAR,
-            K.InterpolationMode.NEAREST,
+            F.InterpolationMode.BILINEAR,
+            F.InterpolationMode.NEAREST,
         ],
     ):
         height, width = image.shape[-2:]
@@ -200,20 +200,20 @@ def resize_bounding_box():
 
 
 class TestKernelsCommon:
-    @pytest.mark.parametrize("kernel_info", KERNEL_INFOS, ids=lambda kernel_info: kernel_info.name)
-    def test_scriptable(self, kernel_info):
-        jit.script(kernel_info.kernel)
+    @pytest.mark.parametrize("functional_info", FUNCTIONAL_INFOS, ids=lambda functional_info: functional_info.name)
+    def test_scriptable(self, functional_info):
+        jit.script(functional_info.functional)
 
     @pytest.mark.parametrize(
-        ("kernel_info", "sample_input"),
+        ("functional_info", "sample_input"),
         [
-            pytest.param(kernel_info, sample_input, id=f"{kernel_info.name}-{idx}")
-            for kernel_info in KERNEL_INFOS
-            for idx, sample_input in enumerate(kernel_info.sample_inputs())
+            pytest.param(functional_info, sample_input, id=f"{functional_info.name}-{idx}")
+            for functional_info in FUNCTIONAL_INFOS
+            for idx, sample_input in enumerate(functional_info.sample_inputs())
         ],
     )
-    def test_eager_vs_scripted(self, kernel_info, sample_input):
-        eager = kernel_info(sample_input)
-        scripted = jit.script(kernel_info.kernel)(*sample_input.args, **sample_input.kwargs)
+    def test_eager_vs_scripted(self, functional_info, sample_input):
+        eager = functional_info(sample_input)
+        scripted = jit.script(functional_info.functional)(*sample_input.args, **sample_input.kwargs)
 
         torch.testing.assert_close(eager, scripted)
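Note that `register_kernel_info_from_sample_inputs_fn` resolves the functional via the sample function's `__name__`, which is why the sample generators were renamed along with the kernels (`horizontal_flip_image` -> `horizontal_flip_image_tensor`). A hedged sketch of adding an entry, assuming the module's own `SampleInput` and `make_images` helpers; `some_functional` is a hypothetical name that must match an attribute of `F`:

@register_kernel_info_from_sample_inputs_fn
def some_functional():
    # getattr(F, "some_functional") must exist, or FunctionalInfo.__init__ fails
    for image in make_images():
        yield SampleInput(image)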

torchvision/prototype/features/_bounding_box.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox:
         # promote this out of the prototype state
 
         # import at runtime to avoid cyclic imports
-        from torchvision.prototype.transforms.kernels import convert_bounding_box_format
+        from torchvision.prototype.transforms.functional import convert_bounding_box_format
 
         if isinstance(format, str):
             format = BoundingBoxFormat[format]

torchvision/prototype/features/_encoded.py

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ def decode(self) -> Image:
         # promote this out of the prototype state
 
         # import at runtime to avoid cyclic imports
-        from torchvision.prototype.transforms.kernels import decode_image_with_pil
+        from torchvision.prototype.transforms.functional import decode_image_with_pil
 
         return Image(decode_image_with_pil(self))
 

torchvision/prototype/transforms/__init__.py

Lines changed: 4 additions & 3 deletions
@@ -1,13 +1,14 @@
-from torchvision.transforms import AutoAugmentPolicy, InterpolationMode  # usort: skip
-from . import kernels  # usort: skip
+from torchvision.transforms import InterpolationMode, AutoAugmentPolicy  # usort: skip
+
 from . import functional  # usort: skip
+
 from ._transform import Transform  # usort: skip
 
 from ._augment import RandomErasing, RandomMixup, RandomCutmix
 from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment
 from ._container import Compose, RandomApply, RandomChoice, RandomOrder
 from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop
-from ._meta_conversion import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertColorSpace
+from ._meta_conversion import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
 from ._misc import Identity, Normalize, ToDtype, Lambda
 from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval
 from ._type_conversion import DecodeImage, LabelToOneHot
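For downstream code, the visible effects of this file are the renamed namespace and transform; a minimal import sketch grounded in the diff above:

# The functional namespace replaces `kernels`:
from torchvision.prototype.transforms import functional as F  # was: from ... import kernels

# The color-space transform was renamed:
from torchvision.prototype.transforms import ConvertImageColorSpace  # was: ConvertColorSpace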

torchvision/prototype/transforms/_augment.py

Lines changed: 39 additions & 14 deletions
@@ -3,7 +3,6 @@
 import warnings
 from typing import Any, Dict, Tuple
 
-import PIL.Image
 import torch
 from torchvision.prototype import features
 from torchvision.prototype.transforms import Transform, functional as F
@@ -12,9 +11,6 @@
 
 
 class RandomErasing(Transform):
-    _DISPATCHER = F.erase
-    _FAIL_TYPES = {PIL.Image.Image, features.BoundingBox, features.SegmentationMask}
-
     def __init__(
         self,
         p: float = 0.5,
@@ -45,8 +41,8 @@ def __init__(
 
     def _get_params(self, sample: Any) -> Dict[str, Any]:
         image = query_image(sample)
-        img_h, img_w = F.get_image_size(image)
         img_c = F.get_image_num_channels(image)
+        img_w, img_h = F.get_image_size(image)
 
         if isinstance(self.value, (int, float)):
             value = [self.value]
@@ -93,16 +89,24 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
         return dict(zip("ijhwv", (i, j, h, w, v)))
 
     def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if torch.rand(1) >= self.p:
+        if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
+            raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
+        elif isinstance(input, features.Image):
+            output = F.erase_image_tensor(input, **params)
+            return features.Image.new_like(input, output)
+        elif isinstance(input, torch.Tensor):
+            return F.erase_image_tensor(input, **params)
+        else:
             return input
 
-        return super()._transform(input, params)
+    def forward(self, *inputs: Any) -> Any:
+        if torch.rand(1) >= self.p:
+            return inputs if len(inputs) > 1 else inputs[0]
+
+        return super().forward(*inputs)
 
 
 class RandomMixup(Transform):
-    _DISPATCHER = F.mixup
-    _FAIL_TYPES = {features.BoundingBox, features.SegmentationMask}
-
     def __init__(self, *, alpha: float) -> None:
         super().__init__()
         self.alpha = alpha
@@ -111,11 +115,20 @@ def __init__(self, *, alpha: float) -> None:
     def _get_params(self, sample: Any) -> Dict[str, Any]:
         return dict(lam=float(self._dist.sample(())))
 
+    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
+        if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
+            raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
+        elif isinstance(input, features.Image):
+            output = F.mixup_image_tensor(input, **params)
+            return features.Image.new_like(input, output)
+        elif isinstance(input, features.OneHotLabel):
+            output = F.mixup_one_hot_label(input, **params)
+            return features.OneHotLabel.new_like(input, output)
+        else:
+            return input
 
-class RandomCutmix(Transform):
-    _DISPATCHER = F.cutmix
-    _FAIL_TYPES = {features.BoundingBox, features.SegmentationMask}
 
+class RandomCutmix(Transform):
     def __init__(self, *, alpha: float) -> None:
         super().__init__()
         self.alpha = alpha
@@ -125,7 +138,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
         lam = float(self._dist.sample(()))
 
         image = query_image(sample)
-        H, W = F.get_image_size(image)
+        W, H = F.get_image_size(image)
 
         r_x = torch.randint(W, ())
         r_y = torch.randint(H, ())
@@ -143,3 +156,15 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
         lam_adjusted = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
 
         return dict(box=box, lam_adjusted=lam_adjusted)
+
+    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
+        if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
+            raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
+        elif isinstance(input, features.Image):
+            output = F.cutmix_image_tensor(input, box=params["box"])
+            return features.Image.new_like(input, output)
+        elif isinstance(input, features.OneHotLabel):
+            output = F.cutmix_one_hot_label(input, lam_adjusted=params["lam_adjusted"])
+            return features.OneHotLabel.new_like(input, output)
+        else:
+            return input
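Two behavioral points in this file are easy to miss: `F.get_image_size` now returns `(width, height)` (hence `img_w, img_h = ...` and `W, H = ...` in `_get_params`), and `RandomErasing` rolls its probability once per call in `forward`, short-circuiting before any per-input dispatch. A usage sketch under the assumption that `features.Image` wraps a CHW tensor directly:

import torch
from torchvision.prototype import features, transforms
from torchvision.prototype.transforms import functional as F

image = features.Image(torch.rand(3, 32, 48))  # CHW: height 32, width 48

# get_image_size follows the (width, height) convention after this commit.
w, h = F.get_image_size(image)
assert (w, h) == (48, 32)

# With p=1.0 the erase branch always runs and wraps the result via new_like;
# with p=0.0 forward() returns the input object untouched.
assert isinstance(transforms.RandomErasing(p=1.0)(image), features.Image)
assert transforms.RandomErasing(p=0.0)(image) is image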
