diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py
index 36d1677ede5..dac43717d30 100644
--- a/test/test_prototype_transforms_functional.py
+++ b/test/test_prototype_transforms_functional.py
@@ -370,6 +370,16 @@ def resized_crop_segmentation_mask():
         yield SampleInput(mask, top=top, left=left, height=height, width=width, size=size)
 
 
+@register_kernel_info_from_sample_inputs_fn
+def pad_segmentation_mask():
+    for mask, padding, padding_mode in itertools.product(
+        make_segmentation_masks(),
+        [[1], [1, 1], [1, 1, 2, 2]],  # padding
+        ["constant", "symmetric", "edge", "reflect"],  # padding mode
+    ):
+        yield SampleInput(mask, padding=padding, padding_mode=padding_mode)
+
+
 @pytest.mark.parametrize(
     "kernel",
     [
@@ -1031,3 +1041,47 @@ def _compute_expected(mask, top_, left_, height_, width_, size_):
     expected_mask = _compute_expected(in_mask, top, left, height, width, size)
     output_mask = F.resized_crop_segmentation_mask(in_mask, top, left, height, width, size)
     torch.testing.assert_close(output_mask, expected_mask)
+
+
+@pytest.mark.parametrize("device", cpu_and_gpu())
+def test_correctness_pad_segmentation_mask_on_fixed_input(device):
+    mask = torch.ones((1, 3, 3), dtype=torch.long, device=device)
+
+    out_mask = F.pad_segmentation_mask(mask, padding=[1, 1, 1, 1])
+
+    expected_mask = torch.zeros((1, 5, 5), dtype=torch.long, device=device)
+    expected_mask[:, 1:-1, 1:-1] = 1
+    torch.testing.assert_close(out_mask, expected_mask)
+
+
+@pytest.mark.parametrize("padding", [[1, 2, 3, 4], [1], 1, [1, 2]])
+def test_correctness_pad_segmentation_mask(padding):
+    def _compute_expected_mask():
+        def parse_padding():
+            # Normalize every accepted padding spec to [left, top, right, bottom].
+            if isinstance(padding, int):
+                return [padding] * 4
+            if isinstance(padding, list):
+                if len(padding) == 1:
+                    return padding * 4
+                if len(padding) == 2:
+                    return padding * 2  # [left, top] -> [left, top, left, top]
+
+            return padding
+
+        h, w = mask.shape[-2], mask.shape[-1]
+        pad_left, pad_up, pad_right, pad_down = parse_padding()
+
+        new_h = h + pad_up + pad_down
+        new_w = w + pad_left + pad_right
+
+        new_shape = (*mask.shape[:-2], new_h, new_w) if len(mask.shape) > 2 else (new_h, new_w)
+        expected_mask = torch.zeros(new_shape, dtype=torch.long)
+        expected_mask[..., pad_up : pad_up + h, pad_left : pad_left + w] = mask  # length-based bounds stay correct for zero pads
+
+        return expected_mask
+
+    for mask in make_segmentation_masks():
+        out_mask = F.pad_segmentation_mask(mask, padding, "constant")
+
+        expected_mask = _compute_expected_mask()
+        torch.testing.assert_close(out_mask, expected_mask)
diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py
index dfbc81baea3..c13a94035ea 100644
--- a/torchvision/prototype/transforms/functional/__init__.py
+++ b/torchvision/prototype/transforms/functional/__init__.py
@@ -62,6 +62,7 @@
     pad_bounding_box,
     pad_image_tensor,
     pad_image_pil,
+    pad_segmentation_mask,
     crop_bounding_box,
     crop_image_tensor,
     crop_image_pil,
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index 5f9e77fdbf4..602f865f724 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -396,6 +396,20 @@ def rotate_segmentation_mask(
 pad_image_pil = _FP.pad
 
 
+def pad_segmentation_mask(
+    segmentation_mask: torch.Tensor, padding: List[int], padding_mode: str = "constant"
+) -> torch.Tensor:
+    num_masks, height, width = segmentation_mask.shape[-3:]
+    extra_dims = segmentation_mask.shape[:-3]  # leading batch dims, restored after padding
+
+    padded_mask = pad_image_tensor(
+        img=segmentation_mask.view(-1, num_masks, height, width), padding=padding, fill=0, padding_mode=padding_mode
+    )
+
+    new_height, new_width = padded_mask.shape[-2:]
+    return padded_mask.view(extra_dims + (num_masks, new_height, new_width))
+
+
 def pad_bounding_box(
     bounding_box: torch.Tensor, padding: List[int], format: features.BoundingBoxFormat
 ) -> torch.Tensor: