Skip to content

Commit b8af91a

Browse files
datumbox and pmeier authored
[prototype] Fix BC-breakages on input params of F (#6636)
* Fix `size` in resize. * Update torchvision/prototype/features/_bounding_box.py Co-authored-by: Philip Meier <[email protected]> * Address some of the comments. * Fix `output_size` in center_crop. * Fix `CenterCrop` transform * Fix `size` in five_crop. * Fix `size` in ten_crop. * Fix `kernel_size` and `sigma` in gaussian_blur. * Fix `angle` and `shear` in affine. * Fixing JIT-scriptability issues. * Update TODOs. * Restore fake types for `Union[int, List[int]]` and `Union[int, float, List[float]]` * Fixing tests * Fix linter * revert unnecessary JIT mitigations. * Cherrypick Philip's 6dfc965 * Linter fix * Adding center float casting Co-authored-by: Philip Meier <[email protected]>
1 parent c2ca691 commit b8af91a

File tree

10 files changed

+78
-61
lines changed

10 files changed

+78
-61
lines changed

test/test_prototype_transforms.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -799,7 +799,9 @@ def test_assertions(self):
799799
with pytest.raises(ValueError, match="Kernel size value should be an odd and positive number"):
800800
transforms.GaussianBlur(4)
801801

802-
with pytest.raises(TypeError, match="sigma should be a single float or a list/tuple with length 2"):
802+
with pytest.raises(
803+
TypeError, match="sigma should be a single int or float or a list/tuple with length 2 floats."
804+
):
803805
transforms.GaussianBlur(3, sigma=[1, 2, 3])
804806

805807
with pytest.raises(ValueError, match="If sigma is a single number, it must be positive"):
@@ -833,7 +835,7 @@ def test__transform(self, kernel_size, sigma, mocker):
833835
if isinstance(sigma, (tuple, list)):
834836
assert transform.sigma == sigma
835837
else:
836-
assert transform.sigma == (sigma, sigma)
838+
assert transform.sigma == [sigma, sigma]
837839

838840
fn = mocker.patch("torchvision.prototype.transforms.functional.gaussian_blur")
839841
inpt = mocker.MagicMock(spec=features.Image)

torchvision/prototype/features/_bounding_box.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ def resize( # type: ignore[override]
8484
antialias: bool = False,
8585
) -> BoundingBox:
8686
output = self._F.resize_bounding_box(self, size, image_size=self.image_size, max_size=max_size)
87+
if isinstance(size, int):
88+
size = [size]
8789
image_size = (size[0], size[0]) if len(size) == 1 else (size[0], size[1])
8890
return BoundingBox.new_like(self, output, image_size=image_size, dtype=output.dtype)
8991

@@ -95,6 +97,8 @@ def center_crop(self, output_size: List[int]) -> BoundingBox:
9597
output = self._F.center_crop_bounding_box(
9698
self, format=self.format, output_size=output_size, image_size=self.image_size
9799
)
100+
if isinstance(output_size, int):
101+
output_size = [output_size]
98102
image_size = (output_size[0], output_size[0]) if len(output_size) == 1 else (output_size[0], output_size[1])
99103
return BoundingBox.new_like(self, output, image_size=image_size)
100104

@@ -160,7 +164,7 @@ def rotate(
160164

161165
def affine(
162166
self,
163-
angle: float,
167+
angle: Union[int, float],
164168
translate: List[float],
165169
scale: float,
166170
shear: List[float],

torchvision/prototype/features/_feature.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def rotate(
169169

170170
def affine(
171171
self,
172-
angle: float,
172+
angle: Union[int, float],
173173
translate: List[float],
174174
scale: float,
175175
shear: List[float],

torchvision/prototype/features/_image.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def rotate(
198198

199199
def affine(
200200
self,
201-
angle: float,
201+
angle: Union[int, float],
202202
translate: List[float],
203203
scale: float,
204204
shear: List[float],

torchvision/prototype/features/_mask.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def rotate(
7070

7171
def affine(
7272
self,
73-
angle: float,
73+
angle: Union[int, float],
7474
translate: List[float],
7575
scale: float,
7676
shear: List[float],

torchvision/prototype/transforms/_geometry.py

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
_check_sequence_input,
2020
_setup_angle,
2121
_setup_fill_arg,
22+
_setup_float_or_seq,
2223
_setup_size,
2324
has_all,
2425
has_any,
@@ -67,9 +68,9 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
6768

6869

6970
class CenterCrop(Transform):
70-
def __init__(self, size: List[int]):
71+
def __init__(self, size: Union[int, Sequence[int]]):
7172
super().__init__()
72-
self.size = size
73+
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
7374

7475
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
7576
return F.center_crop(inpt, output_size=self.size)
@@ -320,7 +321,7 @@ def __init__(
320321
degrees: Union[numbers.Number, Sequence],
321322
translate: Optional[Sequence[float]] = None,
322323
scale: Optional[Sequence[float]] = None,
323-
shear: Optional[Union[float, Sequence[float]]] = None,
324+
shear: Optional[Union[int, float, Sequence[float]]] = None,
324325
interpolation: InterpolationMode = InterpolationMode.NEAREST,
325326
fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
326327
center: Optional[List[float]] = None,
@@ -545,23 +546,6 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
545546
)
546547

547548

548-
def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size: int = 2) -> Sequence[float]:
549-
if not isinstance(arg, (float, Sequence)):
550-
raise TypeError(f"{name} should be float or a sequence of floats. Got {type(arg)}")
551-
if isinstance(arg, Sequence) and len(arg) != req_size:
552-
raise ValueError(f"If {name} is a sequence its length should be one of {req_size}. Got {len(arg)}")
553-
if isinstance(arg, Sequence):
554-
for element in arg:
555-
if not isinstance(element, float):
556-
raise ValueError(f"{name} should be a sequence of floats. Got {type(element)}")
557-
558-
if isinstance(arg, float):
559-
arg = [float(arg), float(arg)]
560-
if isinstance(arg, (list, tuple)) and len(arg) == 1:
561-
arg = [arg[0], arg[0]]
562-
return arg
563-
564-
565549
class ElasticTransform(Transform):
566550
def __init__(
567551
self,

torchvision/prototype/transforms/_misc.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from torchvision.prototype import features
99
from torchvision.prototype.transforms import functional as F, Transform
1010

11-
from ._utils import _setup_size, has_any, query_bounding_box
11+
from ._utils import _setup_float_or_seq, _setup_size, has_any, query_bounding_box
1212

1313

1414
class Identity(Transform):
@@ -112,25 +112,25 @@ def forward(self, *inpts: Any) -> Any:
112112

113113
class GaussianBlur(Transform):
114114
def __init__(
115-
self, kernel_size: Union[int, Sequence[int]], sigma: Union[float, Sequence[float]] = (0.1, 2.0)
115+
self, kernel_size: Union[int, Sequence[int]], sigma: Union[int, float, Sequence[float]] = (0.1, 2.0)
116116
) -> None:
117117
super().__init__()
118118
self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers")
119119
for ks in self.kernel_size:
120120
if ks <= 0 or ks % 2 == 0:
121121
raise ValueError("Kernel size value should be an odd and positive number.")
122122

123-
if isinstance(sigma, float):
123+
if isinstance(sigma, (int, float)):
124124
if sigma <= 0:
125125
raise ValueError("If sigma is a single number, it must be positive.")
126-
sigma = (sigma, sigma)
126+
sigma = float(sigma)
127127
elif isinstance(sigma, Sequence) and len(sigma) == 2:
128128
if not 0.0 < sigma[0] <= sigma[1]:
129129
raise ValueError("sigma values should be positive and of the form (min, max).")
130130
else:
131-
raise TypeError("sigma should be a single float or a list/tuple with length 2 floats.")
131+
raise TypeError("sigma should be a single int or float or a list/tuple with length 2 floats.")
132132

133-
self.sigma = sigma
133+
self.sigma = _setup_float_or_seq(sigma, "sigma", 2)
134134

135135
def _get_params(self, sample: Any) -> Dict[str, Any]:
136136
sigma = torch.empty(1).uniform_(self.sigma[0], self.sigma[1]).item()

torchvision/prototype/transforms/_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,23 @@
1616
from typing_extensions import Literal
1717

1818

19+
def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size: int = 2) -> Sequence[float]:
20+
if not isinstance(arg, (float, Sequence)):
21+
raise TypeError(f"{name} should be float or a sequence of floats. Got {type(arg)}")
22+
if isinstance(arg, Sequence) and len(arg) != req_size:
23+
raise ValueError(f"If {name} is a sequence its length should be one of {req_size}. Got {len(arg)}")
24+
if isinstance(arg, Sequence):
25+
for element in arg:
26+
if not isinstance(element, float):
27+
raise ValueError(f"{name} should be a sequence of floats. Got {type(element)}")
28+
29+
if isinstance(arg, float):
30+
arg = [float(arg), float(arg)]
31+
if isinstance(arg, (list, tuple)) and len(arg) == 1:
32+
arg = [arg[0], arg[0]]
33+
return arg
34+
35+
1936
def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None:
2037
if isinstance(fill, dict):
2138
for key, value in fill.items():

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 38 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ def resize_image_tensor(
9797
max_size: Optional[int] = None,
9898
antialias: bool = False,
9999
) -> torch.Tensor:
100+
if isinstance(size, int):
101+
size = [size]
100102
num_channels, old_height, old_width = get_dimensions_image_tensor(image)
101103
new_height, new_width = _compute_resized_output_size((old_height, old_width), size=size, max_size=max_size)
102104
extra_dims = image.shape[:-3]
@@ -145,6 +147,8 @@ def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = N
145147
def resize_bounding_box(
146148
bounding_box: torch.Tensor, size: List[int], image_size: Tuple[int, int], max_size: Optional[int] = None
147149
) -> torch.Tensor:
150+
if isinstance(size, int):
151+
size = [size]
148152
old_height, old_width = image_size
149153
new_height, new_width = _compute_resized_output_size(image_size, size=size, max_size=max_size)
150154
ratios = torch.tensor((new_width / old_width, new_height / old_height), device=bounding_box.device)
@@ -171,7 +175,7 @@ def resize(
171175

172176

173177
def _affine_parse_args(
174-
angle: float,
178+
angle: Union[int, float],
175179
translate: List[float],
176180
scale: float,
177181
shear: List[float],
@@ -214,15 +218,18 @@ def _affine_parse_args(
214218
if len(shear) != 2:
215219
raise ValueError(f"Shear should be a sequence containing two values. Got {shear}")
216220

217-
if center is not None and not isinstance(center, (list, tuple)):
218-
raise TypeError("Argument center should be a sequence")
221+
if center is not None:
222+
if not isinstance(center, (list, tuple)):
223+
raise TypeError("Argument center should be a sequence")
224+
else:
225+
center = [float(c) for c in center]
219226

220227
return angle, translate, shear, center
221228

222229

223230
def affine_image_tensor(
224231
img: torch.Tensor,
225-
angle: float,
232+
angle: Union[int, float],
226233
translate: List[float],
227234
scale: float,
228235
shear: List[float],
@@ -254,7 +261,7 @@ def affine_image_tensor(
254261
@torch.jit.unused
255262
def affine_image_pil(
256263
img: PIL.Image.Image,
257-
angle: float,
264+
angle: Union[int, float],
258265
translate: List[float],
259266
scale: float,
260267
shear: List[float],
@@ -278,34 +285,26 @@ def affine_image_pil(
278285
def _affine_bounding_box_xyxy(
279286
bounding_box: torch.Tensor,
280287
image_size: Tuple[int, int],
281-
angle: float,
282-
translate: Optional[List[float]] = None,
283-
scale: Optional[float] = None,
284-
shear: Optional[List[float]] = None,
288+
angle: Union[int, float],
289+
translate: List[float],
290+
scale: float,
291+
shear: List[float],
285292
center: Optional[List[float]] = None,
286293
expand: bool = False,
287294
) -> torch.Tensor:
288-
dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32
289-
device = bounding_box.device
290-
291-
if translate is None:
292-
translate = [0.0, 0.0]
293-
294-
if scale is None:
295-
scale = 1.0
296-
297-
if shear is None:
298-
shear = [0.0, 0.0]
295+
angle, translate, shear, center = _affine_parse_args(
296+
angle, translate, scale, shear, InterpolationMode.NEAREST, center
297+
)
299298

300299
if center is None:
301300
height, width = image_size
302-
center_f = [width * 0.5, height * 0.5]
303-
else:
304-
center_f = [float(c) for c in center]
301+
center = [width * 0.5, height * 0.5]
302+
303+
dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32
304+
device = bounding_box.device
305305

306-
translate_f = [float(t) for t in translate]
307306
affine_matrix = torch.tensor(
308-
_get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear, inverted=False),
307+
_get_inverse_affine_matrix(center, angle, translate, scale, shear, inverted=False),
309308
dtype=dtype,
310309
device=device,
311310
).view(2, 3)
@@ -351,7 +350,7 @@ def affine_bounding_box(
351350
bounding_box: torch.Tensor,
352351
format: features.BoundingBoxFormat,
353352
image_size: Tuple[int, int],
354-
angle: float,
353+
angle: Union[int, float],
355354
translate: List[float],
356355
scale: float,
357356
shear: List[float],
@@ -373,7 +372,7 @@ def affine_bounding_box(
373372

374373
def affine_mask(
375374
mask: torch.Tensor,
376-
angle: float,
375+
angle: Union[int, float],
377376
translate: List[float],
378377
scale: float,
379378
shear: List[float],
@@ -419,14 +418,15 @@ def _convert_fill_arg(fill: features.FillType) -> features.FillTypeJIT:
419418

420419
def affine(
421420
inpt: features.InputTypeJIT,
422-
angle: float,
421+
angle: Union[int, float],
423422
translate: List[float],
424423
scale: float,
425424
shear: List[float],
426425
interpolation: InterpolationMode = InterpolationMode.NEAREST,
427426
fill: features.FillTypeJIT = None,
428427
center: Optional[List[float]] = None,
429428
) -> features.InputTypeJIT:
429+
# TODO: consider deprecating integers from angle and shear on the future
430430
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
431431
return affine_image_tensor(
432432
inpt,
@@ -528,7 +528,16 @@ def rotate_bounding_box(
528528
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
529529
).view(-1, 4)
530530

531-
out_bboxes = _affine_bounding_box_xyxy(bounding_box, image_size, angle=-angle, center=center, expand=expand)
531+
out_bboxes = _affine_bounding_box_xyxy(
532+
bounding_box,
533+
image_size,
534+
angle=-angle,
535+
translate=[0.0, 0.0],
536+
scale=1.0,
537+
shear=[0.0, 0.0],
538+
center=center,
539+
expand=expand,
540+
)
532541

533542
return convert_format_bounding_box(
534543
out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False

torchvision/prototype/transforms/functional/_misc.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def normalize(
2323
def gaussian_blur_image_tensor(
2424
img: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
2525
) -> torch.Tensor:
26+
# TODO: consider deprecating integers from sigma on the future
2627
if isinstance(kernel_size, int):
2728
kernel_size = [kernel_size, kernel_size]
2829
if len(kernel_size) != 2:

0 commit comments

Comments (0)