Skip to content

Commit a7d17ec

Browse files
committed
Added tests for rotation, affine and zoom transforms
1 parent ec67676 commit a7d17ec

File tree

3 files changed

+21
-8
lines changed

3 files changed

+21
-8
lines changed

test/test_prototype_transforms.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ class TestSmoke:
7373
transforms.RandomHorizontalFlip(),
7474
transforms.Pad(5),
7575
transforms.RandomZoomOut(),
76+
transforms.RandomRotation(degrees=(-45, 45)),
77+
transforms.RandomAffine(degrees=(-45, 45)),
7678
)
7779
def test_common(self, transform, input):
7880
transform(input)

torchvision/prototype/transforms/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
RandomVerticalFlip,
1818
Pad,
1919
RandomZoomOut,
20+
RandomRotation,
21+
RandomAffine,
2022
)
2123
from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
2224
from ._misc import Identity, Normalize, ToDtype, Lambda

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -210,18 +210,22 @@ def affine_image_tensor(
210210
fill: Optional[List[float]] = None,
211211
center: Optional[List[float]] = None,
212212
) -> torch.Tensor:
213+
num_channels, height, width = img.shape[-3:]
214+
extra_dims = img.shape[:-3]
215+
img = img.view(-1, num_channels, height, width)
216+
213217
angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)
214218

215219
center_f = [0.0, 0.0]
216220
if center is not None:
217-
_, height, width = get_dimensions_image_tensor(img)
218221
# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
219222
center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]
220223

221224
translate_f = [1.0 * t for t in translate]
222225
matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)
223226

224-
return _FT.affine(img, matrix, interpolation=interpolation.value, fill=fill)
227+
output = _FT.affine(img, matrix, interpolation=interpolation.value, fill=fill)
228+
return output.view(extra_dims + (num_channels, height, width))
225229

226230

227231
def affine_image_pil(
@@ -344,15 +348,15 @@ def affine_bounding_box(
344348

345349

346350
def affine_segmentation_mask(
347-
img: torch.Tensor,
351+
mask: torch.Tensor,
348352
angle: float,
349353
translate: List[float],
350354
scale: float,
351355
shear: List[float],
352356
center: Optional[List[float]] = None,
353357
) -> torch.Tensor:
354358
return affine_image_tensor(
355-
img,
359+
mask,
356360
angle=angle,
357361
translate=translate,
358362
scale=scale,
@@ -423,6 +427,10 @@ def rotate_image_tensor(
423427
fill: Optional[List[float]] = None,
424428
center: Optional[List[float]] = None,
425429
) -> torch.Tensor:
430+
num_channels, height, width = img.shape[-3:]
431+
extra_dims = img.shape[:-3]
432+
img = img.view(-1, num_channels, height, width)
433+
426434
center_f = [0.0, 0.0]
427435
if center is not None:
428436
if expand:
@@ -435,7 +443,8 @@ def rotate_image_tensor(
435443
# due to current incoherence of rotation angle direction between affine and rotate implementations
436444
# we need to set -angle.
437445
matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])
438-
return _FT.rotate(img, matrix, interpolation=interpolation.value, expand=expand, fill=fill)
446+
output = _FT.rotate(img, matrix, interpolation=interpolation.value, expand=expand, fill=fill)
447+
return output.view(extra_dims + (num_channels, height, width))
439448

440449

441450
def rotate_image_pil(
@@ -518,15 +527,15 @@ def rotate(
518527
def pad_image_tensor(
519528
img: torch.Tensor, padding: Union[int, List[int]], fill: Union[int, float] = 0, padding_mode: str = "constant"
520529
) -> torch.Tensor:
521-
num_masks, height, width = img.shape[-3:]
530+
num_channels, height, width = img.shape[-3:]
522531
extra_dims = img.shape[:-3]
523532

524533
padded_image = _FT.pad(
525-
img=img.view(-1, num_masks, height, width), padding=padding, fill=fill, padding_mode=padding_mode
534+
img=img.view(-1, num_channels, height, width), padding=padding, fill=fill, padding_mode=padding_mode
526535
)
527536

528537
new_height, new_width = padded_image.shape[-2:]
529-
return padded_image.view(extra_dims + (num_masks, new_height, new_width))
538+
return padded_image.view(extra_dims + (num_channels, new_height, new_width))
530539

531540

532541
# TODO: This should be removed once pytorch pad supports non-scalar padding values

0 commit comments

Comments
 (0)