Skip to content

Commit 42b4bf3

Browse files
committed
Random erase bypass boxes and masks
Go back with if-return/elif-return/else-return
1 parent 014b8c7 commit 42b4bf3

File tree

3 files changed

+100
-90
lines changed

3 files changed

+100
-90
lines changed

torchvision/prototype/transforms/_augment.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,8 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         elif isinstance(inpt, PIL.Image.Image):
             # TODO: We should implement a fallback to tensor, like gaussian_blur etc
             raise RuntimeError("Not implemented")
-        elif isinstance(inpt, torch.Tensor):
-            return F.erase_image_tensor(inpt, **params)
-        raise TypeError(
-            "RandomErasing transformation does not support bounding boxes, segmentation masks and plain labels"
-        )
+        else:
+            return inpt
 
 
 class _BaseMixupCutmix(Transform):
@@ -133,12 +130,10 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
             output = inpt.clone()
             output = output.roll(1, -4).mul_(1 - lam).add_(output.mul_(lam))
             return features.Image.new_like(inpt, output)
-        if isinstance(inpt, features.OneHotLabel):
+        elif isinstance(inpt, features.OneHotLabel):
             return self._mixup_onehotlabel(inpt, lam)
-
-        raise TypeError(
-            "RandomMixup transformation does not support bounding boxes, segmentation masks and plain labels"
-        )
+        else:
+            return inpt
 
 
 class RandomCutmix(_BaseMixupCutmix):
@@ -175,10 +170,8 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
             output = inpt.clone()
             output[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2]
             return features.Image.new_like(inpt, output)
-        if isinstance(inpt, features.OneHotLabel):
+        elif isinstance(inpt, features.OneHotLabel):
             lam_adjusted = params["lam_adjusted"]
             return self._mixup_onehotlabel(inpt, lam_adjusted)
-
-        raise TypeError(
-            "RandomCutmix transformation does not support bounding boxes, segmentation masks and plain labels"
-        )
+        else:
+            return inpt

torchvision/prototype/transforms/functional/_color.py

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@
 def adjust_brightness(inpt: DType, brightness_factor: float) -> DType:
     if isinstance(inpt, features._Feature):
         return inpt.adjust_brightness(brightness_factor=brightness_factor)
-    if isinstance(inpt, PIL.Image.Image):
+    elif isinstance(inpt, PIL.Image.Image):
         return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor)
-    return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
+    else:
+        return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
 
 
 adjust_saturation_image_tensor = _FT.adjust_saturation
@@ -28,9 +29,10 @@ def adjust_brightness(inpt: DType, brightness_factor: float) -> DType:
 def adjust_saturation(inpt: DType, saturation_factor: float) -> DType:
     if isinstance(inpt, features._Feature):
         return inpt.adjust_saturation(saturation_factor=saturation_factor)
-    if isinstance(inpt, PIL.Image.Image):
+    elif isinstance(inpt, PIL.Image.Image):
         return adjust_saturation_image_pil(inpt, saturation_factor=saturation_factor)
-    return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor)
+    else:
+        return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor)
 
 
 adjust_contrast_image_tensor = _FT.adjust_contrast
@@ -40,9 +42,10 @@ def adjust_saturation(inpt: DType, saturation_factor: float) -> DType:
 def adjust_contrast(inpt: DType, contrast_factor: float) -> DType:
     if isinstance(inpt, features._Feature):
         return inpt.adjust_contrast(contrast_factor=contrast_factor)
-    if isinstance(inpt, PIL.Image.Image):
+    elif isinstance(inpt, PIL.Image.Image):
         return adjust_contrast_image_pil(inpt, contrast_factor=contrast_factor)
-    return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
+    else:
+        return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
 
 
 adjust_sharpness_image_tensor = _FT.adjust_sharpness
@@ -52,9 +55,10 @@ def adjust_contrast(inpt: DType, contrast_factor: float) -> DType:
 def adjust_sharpness(inpt: DType, sharpness_factor: float) -> DType:
     if isinstance(inpt, features._Feature):
         return inpt.adjust_sharpness(sharpness_factor=sharpness_factor)
-    if isinstance(inpt, PIL.Image.Image):
+    elif isinstance(inpt, PIL.Image.Image):
         return adjust_sharpness_image_pil(inpt, sharpness_factor=sharpness_factor)
-    return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor)
+    else:
+        return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor)
 
 
 adjust_hue_image_tensor = _FT.adjust_hue
@@ -64,9 +68,10 @@ def adjust_sharpness(inpt: DType, sharpness_factor: float) -> DType:
 def adjust_hue(inpt: DType, hue_factor: float) -> DType:
     if isinstance(inpt, features._Feature):
         return inpt.adjust_hue(hue_factor=hue_factor)
-    if isinstance(inpt, PIL.Image.Image):
+    elif isinstance(inpt, PIL.Image.Image):
         return adjust_hue_image_pil(inpt, hue_factor=hue_factor)
-    return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
+    else:
+        return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
 
 
 adjust_gamma_image_tensor = _FT.adjust_gamma
@@ -76,9 +81,10 @@ def adjust_hue(inpt: DType, hue_factor: float) -> DType:
 def adjust_gamma(inpt: DType, gamma: float, gain: float = 1) -> DType:
     if isinstance(inpt, features._Feature):
         return inpt.adjust_gamma(gamma=gamma, gain=gain)
-    if isinstance(inpt, PIL.Image.Image):
+    elif isinstance(inpt, PIL.Image.Image):
         return adjust_gamma_image_pil(inpt, gamma=gamma, gain=gain)
-    return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
+    else:
+        return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
 
 
 posterize_image_tensor = _FT.posterize
@@ -88,9 +94,10 @@ def adjust_gamma(inpt: DType, gamma: float, gain: float = 1) -> DType:
 def posterize(inpt: DType, bits: int) -> DType:
     if isinstance(inpt, features._Feature):
         return inpt.posterize(bits=bits)
-    if isinstance(inpt, PIL.Image.Image):
+    elif isinstance(inpt, PIL.Image.Image):
         return posterize_image_pil(inpt, bits=bits)
-    return posterize_image_tensor(inpt, bits=bits)
+    else:
+        return posterize_image_tensor(inpt, bits=bits)
 
 
 solarize_image_tensor = _FT.solarize
@@ -100,9 +107,10 @@ def posterize(inpt: DType, bits: int) -> DType:
 def solarize(inpt: DType, threshold: float) -> DType:
     if isinstance(inpt, features._Feature):
         return inpt.solarize(threshold=threshold)
-    if isinstance(inpt, PIL.Image.Image):
+    elif isinstance(inpt, PIL.Image.Image):
         return solarize_image_pil(inpt, threshold=threshold)
-    return solarize_image_tensor(inpt, threshold=threshold)
+    else:
+        return solarize_image_tensor(inpt, threshold=threshold)
 
 
 autocontrast_image_tensor = _FT.autocontrast
@@ -112,9 +120,10 @@ def solarize(inpt: DType, threshold: float) -> DType:
 def autocontrast(inpt: DType) -> DType:
     if isinstance(inpt, features._Feature):
         return inpt.autocontrast()
-    if isinstance(inpt, PIL.Image.Image):
+    elif isinstance(inpt, PIL.Image.Image):
         return autocontrast_image_pil(inpt)
-    return autocontrast_image_tensor(inpt)
+    else:
+        return autocontrast_image_tensor(inpt)
 
 
 equalize_image_tensor = _FT.equalize
@@ -124,9 +133,10 @@ def autocontrast(inpt: DType) -> DType:
 def equalize(inpt: DType) -> DType:
     if isinstance(inpt, features._Feature):
         return inpt.equalize()
-    if isinstance(inpt, PIL.Image.Image):
+    elif isinstance(inpt, PIL.Image.Image):
         return equalize_image_pil(inpt)
-    return equalize_image_tensor(inpt)
+    else:
+        return equalize_image_tensor(inpt)
 
 
 invert_image_tensor = _FT.invert
@@ -136,6 +146,7 @@ def equalize(inpt: DType) -> DType:
 def invert(inpt: DType) -> DType:
     if isinstance(inpt, features._Feature):
         return inpt.invert()
-    if isinstance(inpt, PIL.Image.Image):
+    elif isinstance(inpt, PIL.Image.Image):
         return invert_image_pil(inpt)
-    return invert_image_tensor(inpt)
+    else:
+        return invert_image_tensor(inpt)

0 commit comments

Comments
 (0)