Skip to content

Commit 016bac9

Browse files
jdsgomesfacebook-github-bot
authored andcommitted
[fbsync] [proto] Added functional affine_segmentation_mask op (#5613)
Summary: * Added functional affine_bounding_box op with tests * Updated comments and added another test case * Update _geometry.py * Added affine_segmentation_mask with tests * Fixed device mismatch issue Added a cude/cpu test Reduced the number of test samples * Added test_correctness_affine_segmentation_mask_on_fixed_input * Updates according to the review * Replaced [None, ...] by [None, :] * Adressed review comments * Fixed formatting and more updates according to the review * Fixed bad merge (Note: this ignores all push blocking failures!) Reviewed By: datumbox Differential Revision: D35216766 fbshipit-source-id: d0ff4779f109bfcb0f6b52ba114e5104e200f242
1 parent 63b6086 commit 016bac9

File tree

3 files changed

+125
-1
lines changed

3 files changed

+125
-1
lines changed

test/test_prototype_transforms_functional.py

Lines changed: 105 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,22 @@ def make_one_hot_labels(
138138
yield make_one_hot_label(extra_dims_)
139139

140140

141+
def make_segmentation_mask(size=None, *, num_categories=80, extra_dims=(), dtype=torch.long):
142+
size = size or torch.randint(16, 33, (2,)).tolist()
143+
shape = (*extra_dims, 1, *size)
144+
data = make_tensor(shape, low=0, high=num_categories, dtype=dtype)
145+
return features.SegmentationMask(data)
146+
147+
148+
def make_segmentation_masks(
149+
image_sizes=((16, 16), (7, 33), (31, 9)),
150+
dtypes=(torch.long,),
151+
extra_dims=((), (4,), (2, 3)),
152+
):
153+
for image_size, dtype, extra_dims_ in itertools.product(image_sizes, dtypes, extra_dims):
154+
yield make_segmentation_mask(size=image_size, dtype=dtype, extra_dims=extra_dims_)
155+
156+
141157
class SampleInput:
142158
def __init__(self, *args, **kwargs):
143159
self.args = args
@@ -212,7 +228,7 @@ def resize_bounding_box():
212228
@register_kernel_info_from_sample_inputs_fn
213229
def affine_image_tensor():
214230
for image, angle, translate, scale, shear in itertools.product(
215-
make_images(extra_dims=()),
231+
make_images(extra_dims=((), (4,))),
216232
[-87, 15, 90], # angle
217233
[5, -5], # translate
218234
[0.77, 1.27], # scale
@@ -248,6 +264,24 @@ def affine_bounding_box():
248264
)
249265

250266

267+
@register_kernel_info_from_sample_inputs_fn
268+
def affine_segmentation_mask():
269+
for image, angle, translate, scale, shear in itertools.product(
270+
make_segmentation_masks(extra_dims=((), (4,))),
271+
[-87, 15, 90], # angle
272+
[5, -5], # translate
273+
[0.77, 1.27], # scale
274+
[0, 12], # shear
275+
):
276+
yield SampleInput(
277+
image,
278+
angle=angle,
279+
translate=(translate, translate),
280+
scale=scale,
281+
shear=(shear, shear),
282+
)
283+
284+
251285
@register_kernel_info_from_sample_inputs_fn
252286
def rotate_bounding_box():
253287
for bounding_box, angle, expand, center in itertools.product(
@@ -444,6 +478,76 @@ def test_correctness_affine_bounding_box_on_fixed_input(device):
444478
np.testing.assert_allclose(out_box.cpu().numpy(), a_out_box)
445479

446480

481+
@pytest.mark.parametrize("angle", [-54, 56])
482+
@pytest.mark.parametrize("translate", [-7, 8])
483+
@pytest.mark.parametrize("scale", [0.89, 1.12])
484+
@pytest.mark.parametrize("shear", [4])
485+
@pytest.mark.parametrize("center", [None, (12, 14)])
486+
def test_correctness_affine_segmentation_mask(angle, translate, scale, shear, center):
487+
def _compute_expected_mask(mask, angle_, translate_, scale_, shear_, center_):
488+
assert mask.ndim == 3 and mask.shape[0] == 1
489+
affine_matrix = _compute_affine_matrix(angle_, translate_, scale_, shear_, center_)
490+
inv_affine_matrix = np.linalg.inv(affine_matrix)
491+
inv_affine_matrix = inv_affine_matrix[:2, :]
492+
493+
expected_mask = torch.zeros_like(mask.cpu())
494+
for out_y in range(expected_mask.shape[1]):
495+
for out_x in range(expected_mask.shape[2]):
496+
output_pt = np.array([out_x + 0.5, out_y + 0.5, 1.0])
497+
input_pt = np.floor(np.dot(inv_affine_matrix, output_pt)).astype(np.int32)
498+
in_x, in_y = input_pt[:2]
499+
if 0 <= in_x < mask.shape[2] and 0 <= in_y < mask.shape[1]:
500+
expected_mask[0, out_y, out_x] = mask[0, in_y, in_x]
501+
return expected_mask.to(mask.device)
502+
503+
for mask in make_segmentation_masks(extra_dims=((), (4,))):
504+
output_mask = F.affine_segmentation_mask(
505+
mask,
506+
angle=angle,
507+
translate=(translate, translate),
508+
scale=scale,
509+
shear=(shear, shear),
510+
center=center,
511+
)
512+
if center is None:
513+
center = [s // 2 for s in mask.shape[-2:][::-1]]
514+
515+
if mask.ndim < 4:
516+
masks = [mask]
517+
else:
518+
masks = [m for m in mask]
519+
520+
expected_masks = []
521+
for mask in masks:
522+
expected_mask = _compute_expected_mask(mask, angle, (translate, translate), scale, (shear, shear), center)
523+
expected_masks.append(expected_mask)
524+
if len(expected_masks) > 1:
525+
expected_masks = torch.stack(expected_masks)
526+
else:
527+
expected_masks = expected_masks[0]
528+
torch.testing.assert_close(output_mask, expected_masks)
529+
530+
531+
@pytest.mark.parametrize("device", cpu_and_gpu())
532+
def test_correctness_affine_segmentation_mask_on_fixed_input(device):
533+
# Check transformation against known expected output and CPU/CUDA devices
534+
535+
# Create a fixed input segmentation mask with 2 square masks
536+
# in top-left, bottom-left corners
537+
mask = torch.zeros(1, 32, 32, dtype=torch.long, device=device)
538+
mask[0, 2:10, 2:10] = 1
539+
mask[0, 32 - 9 : 32 - 3, 3:9] = 2
540+
541+
# Rotate 90 degrees and scale
542+
expected_mask = torch.rot90(mask, k=-1, dims=(-2, -1))
543+
expected_mask = torch.nn.functional.interpolate(expected_mask[None, :].float(), size=(64, 64), mode="nearest")
544+
expected_mask = expected_mask[0, :, 16 : 64 - 16, 16 : 64 - 16].long()
545+
546+
out_mask = F.affine_segmentation_mask(mask, 90, [0.0, 0.0], 64.0 / 32.0, [0.0, 0.0])
547+
548+
torch.testing.assert_close(out_mask, expected_mask)
549+
550+
447551
@pytest.mark.parametrize("angle", range(-90, 90, 56))
448552
@pytest.mark.parametrize("expand", [True, False])
449553
@pytest.mark.parametrize("center", [None, (12, 14)])

torchvision/prototype/transforms/functional/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
affine_bounding_box,
5353
affine_image_tensor,
5454
affine_image_pil,
55+
affine_segmentation_mask,
5556
rotate_bounding_box,
5657
rotate_image_tensor,
5758
rotate_image_pil,

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,25 @@ def affine_bounding_box(
294294
).view(original_shape)
295295

296296

297+
def affine_segmentation_mask(
298+
img: torch.Tensor,
299+
angle: float,
300+
translate: List[float],
301+
scale: float,
302+
shear: List[float],
303+
center: Optional[List[float]] = None,
304+
) -> torch.Tensor:
305+
return affine_image_tensor(
306+
img,
307+
angle=angle,
308+
translate=translate,
309+
scale=scale,
310+
shear=shear,
311+
interpolation=InterpolationMode.NEAREST,
312+
center=center,
313+
)
314+
315+
297316
def rotate_image_tensor(
298317
img: torch.Tensor,
299318
angle: float,

0 commit comments

Comments
 (0)