Skip to content

Commit 6dfc965

Browse files
committed
revert kernel fixes
1 parent fe8e1bf commit 6dfc965

File tree

1 file changed

+20
-22
lines changed

1 file changed

+20
-22
lines changed

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -279,26 +279,33 @@ def _affine_bounding_box_xyxy(
279279
bounding_box: torch.Tensor,
280280
image_size: Tuple[int, int],
281281
angle: float,
282-
translate: List[float],
283-
scale: float,
284-
shear: List[float],
282+
translate: Optional[List[float]] = None,
283+
scale: Optional[float] = None,
284+
shear: Optional[List[float]] = None,
285285
center: Optional[List[float]] = None,
286286
expand: bool = False,
287287
) -> torch.Tensor:
288-
# This is just a dummy value to avoid raising an error in `_affine_parse_args` although we don't have an
289-
# interpolation mode for bounding boxes.
290-
interpolation = InterpolationMode.NEAREST
291-
angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)
288+
dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32
289+
device = bounding_box.device
290+
291+
if translate is None:
292+
translate = [0.0, 0.0]
293+
294+
if scale is None:
295+
scale = 1.0
296+
297+
if shear is None:
298+
shear = [0.0, 0.0]
292299

293300
if center is None:
294301
height, width = image_size
295-
center = [width * 0.5, height * 0.5]
296-
297-
dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32
298-
device = bounding_box.device
302+
center_f = [width * 0.5, height * 0.5]
303+
else:
304+
center_f = [float(c) for c in center]
299305

306+
translate_f = [float(t) for t in translate]
300307
affine_matrix = torch.tensor(
301-
_get_inverse_affine_matrix(center, angle, translate, scale, shear, inverted=False),
308+
_get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear, inverted=False),
302309
dtype=dtype,
303310
device=device,
304311
).view(2, 3)
@@ -521,16 +528,7 @@ def rotate_bounding_box(
521528
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
522529
).view(-1, 4)
523530

524-
out_bboxes = _affine_bounding_box_xyxy(
525-
bounding_box,
526-
image_size,
527-
angle=-angle,
528-
translate=[0.0, 0.0],
529-
scale=1.0,
530-
shear=[0.0, 0.0],
531-
center=center,
532-
expand=expand,
533-
)
531+
out_bboxes = _affine_bounding_box_xyxy(bounding_box, image_size, angle=-angle, center=center, expand=expand)
534532

535533
return convert_format_bounding_box(
536534
out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False

0 commit comments

Comments
 (0)