
Commit 4a0cec5

Restored fill default value to None
Updated code according to the review
1 parent 02ec95a commit 4a0cec5

File tree: 10 files changed, +69 -92 lines changed

test/test_prototype_transforms.py

Lines changed: 1 addition & 0 deletions

@@ -72,6 +72,7 @@ class TestSmoke:
         transforms.ConvertImageDtype(),
         transforms.RandomHorizontalFlip(),
         transforms.Pad(5),
+        transforms.RandomZoomOut(),
     )
     def test_common(self, transform, input):
         transform(input)

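The smoke test now covers RandomZoomOut() constructed with its defaults, which exercises the restored fill=None path end to end. A minimal sketch of what such a smoke check amounts to (the image shape and import paths here are assumptions for illustration, not part of this diff):

import torch
from torchvision.prototype import features, transforms

# Build the transform with defaults (fill=None after this commit) and
# check that a forward pass over a prototype Image does not raise.
image = features.Image(torch.rand(3, 32, 32))
transform = transforms.RandomZoomOut()
output = transform(image)
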
torchvision/prototype/features/_bounding_box.py

Lines changed: 4 additions & 4 deletions

@@ -130,7 +130,7 @@ def resized_crop(
     def pad(
         self,
         padding: Union[int, Sequence[int]],
-        fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
         padding_mode: str = "constant",
     ) -> BoundingBox:
         from torchvision.prototype.transforms import functional as _F
@@ -160,7 +160,7 @@ def rotate(
         angle: float,
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
         expand: bool = False,
-        fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
         center: Optional[List[float]] = None,
     ) -> BoundingBox:
         from torchvision.prototype.transforms import functional as _F
@@ -180,7 +180,7 @@ def affine(
         scale: float,
         shear: List[float],
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
-        fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
         center: Optional[List[float]] = None,
     ) -> BoundingBox:
         from torchvision.prototype.transforms import functional as _F
@@ -201,7 +201,7 @@ def perspective(
         self,
         perspective_coeffs: List[float],
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
-        fill: Optional[List[float]] = None,
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
     ) -> BoundingBox:
         from torchvision.prototype.transforms import functional as _F

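With the widened annotation, every geometric op on BoundingBox accepts a scalar, a per-channel sequence, or None for fill, matching the other feature types; for boxes the fill value has no geometric effect, so None is a natural default. A hedged sketch of the calls this signature permits (the constructor arguments follow the prototype API only as far as this diff shows it and are otherwise assumptions):

from torchvision.prototype import features

bbox = features.BoundingBox(
    [10, 10, 20, 20],
    format=features.BoundingBoxFormat.XYXY,
    image_size=(32, 32),
)
padded = bbox.pad(5)                    # fill defaults to None now
rotated = bbox.rotate(30.0, fill=None)  # passing None explicitly is valid too
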
torchvision/prototype/features/_feature.py

Lines changed: 4 additions & 4 deletions

@@ -122,7 +122,7 @@ def resized_crop(
     def pad(
         self,
         padding: Union[int, Sequence[int]],
-        fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
         padding_mode: str = "constant",
     ) -> Any:
         return self
@@ -132,7 +132,7 @@ def rotate(
         angle: float,
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
         expand: bool = False,
-        fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
         center: Optional[List[float]] = None,
     ) -> Any:
         return self
@@ -144,7 +144,7 @@ def affine(
         scale: float,
         shear: List[float],
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
-        fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
         center: Optional[List[float]] = None,
     ) -> Any:
         return self
@@ -153,7 +153,7 @@ def perspective(
         self,
         perspective_coeffs: List[float],
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
-        fill: Optional[List[float]] = None,
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
     ) -> Any:
         return self

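In the _Feature base class these methods stay identity fall-backs (return self), so feature types without a spatial interpretation flow through geometric transforms unchanged while Image, BoundingBox, and SegmentationMask override them. A stripped-down illustration of that pattern (names simplified for illustration, not the actual class):

from typing import Any, Optional, Sequence, Union

class _Feature:
    def pad(
        self,
        padding: Union[int, Sequence[int]],
        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
        padding_mode: str = "constant",
    ) -> Any:
        # Default: no-op. Spatial subclasses override this with a real kernel.
        return self

class Label(_Feature):
    pass  # padding a Label simply returns the label unchanged
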
torchvision/prototype/features/_image.py

Lines changed: 14 additions & 21 deletions

@@ -166,7 +166,7 @@ def resized_crop(
     def pad(
         self,
         padding: Union[int, Sequence[int]],
-        fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
         padding_mode: str = "constant",
     ) -> Image:
         from torchvision.prototype.transforms import functional as _F
@@ -175,6 +175,9 @@ def pad(
         if not isinstance(padding, int):
             padding = list(padding)

+        if fill is None:
+            fill = 0
+
         # PyTorch's pad supports only scalars on fill. So we need to overwrite the colour
         if isinstance(fill, (int, float)):
             output = _F.pad_image_tensor(self, padding, fill=fill, padding_mode=padding_mode)
@@ -190,18 +193,12 @@ def rotate(
         angle: float,
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
         expand: bool = False,
-        fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
         center: Optional[List[float]] = None,
     ) -> Image:
-        from torchvision.prototype.transforms import functional as _F
+        from torchvision.prototype.transforms.functional import _geometry as _F

-        # This cast does Sequence -> List[float] to please mypy and torch.jit.script
-        if not isinstance(fill, (int, float)):
-            fill = [float(v) for v in list(fill)]
-
-        if isinstance(fill, (int, float)):
-            # It is OK to cast int to float as later we use inpt.dtype
-            fill = [float(fill)]
+        fill = _F._convert_fill_arg(fill)

         output = _F.rotate_image_tensor(
             self, angle, interpolation=interpolation, expand=expand, fill=fill, center=center
@@ -215,18 +212,12 @@ def affine(
         scale: float,
         shear: List[float],
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
-        fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
         center: Optional[List[float]] = None,
     ) -> Image:
-        from torchvision.prototype.transforms import functional as _F
+        from torchvision.prototype.transforms.functional import _geometry as _F

-        # This cast does Sequence -> List[float] to please mypy and torch.jit.script
-        if not isinstance(fill, (int, float)):
-            fill = [float(v) for v in list(fill)]
-
-        if isinstance(fill, (int, float)):
-            # It is OK to cast int to float as later we use inpt.dtype
-            fill = [float(fill)]
+        fill = _F._convert_fill_arg(fill)

         output = _F.affine_image_tensor(
             self,
@@ -244,9 +235,11 @@ def perspective(
         self,
         perspective_coeffs: List[float],
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
-        fill: Optional[List[float]] = None,
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
     ) -> Image:
-        from torchvision.prototype.transforms import functional as _F
+        from torchvision.prototype.transforms.functional import _geometry as _F
+
+        fill = _F._convert_fill_arg(fill)

         output = _F.perspective_image_tensor(self, perspective_coeffs, interpolation=interpolation, fill=fill)
         return Image.new_like(self, output)

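rotate, affine, and perspective now route fill through _F._convert_fill_arg instead of repeating the cast inline. The helper's body is not part of this diff; a plausible sketch, inferred from the inlined code it replaces:

from typing import List, Optional, Sequence, Union

def _convert_fill_arg(
    fill: Optional[Union[int, float, Sequence[int], Sequence[float]]]
) -> Optional[List[float]]:
    # Sketch only: mirrors the casts this commit removed from _image.py.
    if fill is None:
        return fill
    if not isinstance(fill, (int, float)):
        # Sequence -> List[float] to please mypy and torch.jit.script
        return [float(v) for v in fill]
    # Casting int to float is fine; the kernel later uses the input's dtype
    return [float(fill)]
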
torchvision/prototype/features/_segmentation_mask.py

Lines changed: 4 additions & 4 deletions

@@ -62,7 +62,7 @@ def resized_crop(
     def pad(
         self,
         padding: Union[int, Sequence[int]],
-        fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
         padding_mode: str = "constant",
     ) -> SegmentationMask:
         from torchvision.prototype.transforms import functional as _F
@@ -79,7 +79,7 @@ def rotate(
         angle: float,
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
         expand: bool = False,
-        fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
         center: Optional[List[float]] = None,
     ) -> SegmentationMask:
         from torchvision.prototype.transforms import functional as _F
@@ -94,7 +94,7 @@ def affine(
         scale: float,
         shear: List[float],
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
-        fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
         center: Optional[List[float]] = None,
     ) -> SegmentationMask:
         from torchvision.prototype.transforms import functional as _F
@@ -113,7 +113,7 @@ def perspective(
         self,
         perspective_coeffs: List[float],
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
-        fill: Optional[List[float]] = None,
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
     ) -> SegmentationMask:
         from torchvision.prototype.transforms import functional as _F

torchvision/prototype/transforms/_augment.py

Lines changed: 8 additions & 9 deletions

@@ -113,13 +113,12 @@ def forward(self, *inpts: Any) -> Any:
             raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.")
         return super().forward(sample)

-
-def _mixup_onehotlabel(inpt: features.OneHotLabel, lam: float) -> features.OneHotLabel:
-    if inpt.ndim < 2:
-        raise ValueError("Need a batch of one hot labels")
-    output = inpt.clone()
-    output = output.roll(1, -2).mul_(1 - lam).add_(output.mul_(lam))
-    return features.OneHotLabel.new_like(inpt, output)
+    def _mixup_onehotlabel(self, inpt: features.OneHotLabel, lam: float) -> features.OneHotLabel:
+        if inpt.ndim < 2:
+            raise ValueError("Need a batch of one hot labels")
+        output = inpt.clone()
+        output = output.roll(1, -2).mul_(1 - lam).add_(output.mul_(lam))
+        return features.OneHotLabel.new_like(inpt, output)


 class RandomMixup(_BaseMixupCutmix):
@@ -135,7 +134,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
             output = output.roll(1, -4).mul_(1 - lam).add_(output.mul_(lam))
             return features.Image.new_like(inpt, output)
         if isinstance(inpt, features.OneHotLabel):
-            return _mixup_onehotlabel(inpt, lam)
+            return self._mixup_onehotlabel(inpt, lam)

         raise TypeError(
             "RandomMixup transformation does not support bounding boxes, segmentation masks and plain labels"
@@ -178,7 +177,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
             return features.Image.new_like(inpt, output)
         if isinstance(inpt, features.OneHotLabel):
             lam_adjusted = params["lam_adjusted"]
-            return _mixup_onehotlabel(inpt, lam_adjusted)
+            return self._mixup_onehotlabel(inpt, lam_adjusted)

         raise TypeError(
             "RandomCutmix transformation does not support bounding boxes, segmentation masks and plain labels"

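Moving _mixup_onehotlabel onto the class only changes its scope; the math is unchanged: roll(1, -2) shifts the batch dimension by one, so each one-hot label is blended with its neighbour as lam * labels + (1 - lam) * rolled. A small numeric illustration of that formula (values chosen arbitrarily):

import torch

lam = 0.7
labels = torch.tensor([[1.0, 0.0], [0.0, 1.0]])  # batch of one-hot labels

# Same result as the in-place expression in _mixup_onehotlabel:
mixed = labels.roll(1, -2) * (1 - lam) + labels * lam
# tensor([[0.7000, 0.3000],
#         [0.3000, 0.7000]])
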
torchvision/prototype/transforms/_color.py

Lines changed: 1 addition & 19 deletions

@@ -1,5 +1,5 @@
 import collections.abc
-from typing import Any, Dict, Union, Tuple, Optional, Sequence, Callable, TypeVar
+from typing import Any, Dict, Union, Tuple, Optional, Sequence, TypeVar

 import PIL.Image
 import torch
@@ -52,24 +52,6 @@ def _check_input(

         return None if value[0] == value[1] == center else (float(value[0]), float(value[1]))

-    def _image_transform(
-        self,
-        inpt: T,
-        *,
-        kernel_tensor: Callable[..., torch.Tensor],
-        kernel_pil: Callable[..., PIL.Image.Image],
-        **kwargs: Any,
-    ) -> T:
-        if isinstance(inpt, features.Image):
-            output = kernel_tensor(inpt, **kwargs)
-            return features.Image.new_like(inpt, output)
-        elif is_simple_tensor(inpt):
-            return kernel_tensor(inpt, **kwargs)
-        elif isinstance(inpt, PIL.Image.Image):
-            return kernel_pil(inpt, **kwargs)  # type: ignore[no-any-return]
-        else:
-            raise RuntimeError
-
     @staticmethod
     def _generate_value(left: float, right: float) -> float:
         return float(torch.distributions.Uniform(left, right).sample())

torchvision/prototype/transforms/_geometry.py

Lines changed: 2 additions & 7 deletions

@@ -270,8 +270,6 @@ def __init__(
         if side_range[0] < 1.0 or side_range[0] > side_range[1]:
             raise ValueError(f"Invalid canvas side range provided {side_range}.")

-        self._pad_op = Pad(0, padding_mode="constant")
-
     def _get_params(self, sample: Any) -> Dict[str, Any]:
         image = query_image(sample)
         orig_c, orig_h, orig_w = get_image_dimensions(image)
@@ -293,11 +291,8 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:

         return dict(padding=padding, fill=fill)

-    def forward(self, *inputs: Any) -> Any:
-        params = self._get_params(inputs)
-        self._pad_op.padding = params["padding"]
-        self._pad_op.fill = params["fill"]
-        return self._pad_op(*inputs)
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return F.pad(inpt, **params)


 class RandomRotation(Transform):

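RandomZoomOut now follows the standard Transform contract instead of mutating a stored Pad op: parameters are sampled once in _get_params and _transform applies them to each component of the sample via F.pad. A simplified sketch of that contract (the base-class internals shown here are an assumption; the real Transform also flattens nested samples and dispatches _transform per element):

from typing import Any, Dict

class Transform:
    def _get_params(self, sample: Any) -> Dict[str, Any]:
        return {}  # subclasses sample random parameters here

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        raise NotImplementedError

    def forward(self, *inputs: Any) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]
        params = self._get_params(sample)   # sampled once per call
        return self._transform(sample, params)
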