diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py
index b24e9a41ff7..e0376eeb5c6 100644
--- a/test/test_prototype_transforms_functional.py
+++ b/test/test_prototype_transforms_functional.py
@@ -10,11 +10,11 @@
 from torch import jit
 from torch.nn.functional import one_hot
 from torchvision.prototype import features
+from torchvision.prototype.transforms.functional._geometry import _center_crop_compute_padding
 from torchvision.prototype.transforms.functional._meta import convert_bounding_box_format
 from torchvision.transforms.functional import _get_perspective_coeffs
 from torchvision.transforms.functional_tensor import _max_value as get_max_value
-
 
 make_tensor = functools.partial(torch.testing.make_tensor, device="cpu")
@@ -421,6 +421,14 @@ def center_crop_bounding_box():
     )
 
 
+def center_crop_segmentation_mask():
+    for mask, output_size in itertools.product(
+        make_segmentation_masks(image_sizes=((16, 16), (7, 33), (31, 9))),
+        [[4, 3], [42, 70], [4]],  # crop sizes < image sizes, crop sizes > image sizes, single crop size
+    ):
+        yield SampleInput(mask, output_size)
+
+
 @pytest.mark.parametrize(
     "kernel",
     [
@@ -1337,3 +1345,28 @@ def _compute_expected_bbox(bbox, output_size_):
     else:
         expected_bboxes = expected_bboxes[0]
     torch.testing.assert_close(output_boxes, expected_bboxes)
+
+
+@pytest.mark.parametrize("device", cpu_and_gpu())
+@pytest.mark.parametrize("output_size", [[4, 2], [4], [7, 6]])
+def test_correctness_center_crop_segmentation_mask(device, output_size):
+    def _compute_expected_segmentation_mask(mask, output_size):
+        crop_height, crop_width = output_size if len(output_size) > 1 else [output_size[0], output_size[0]]
+
+        _, image_height, image_width = mask.shape
+        if crop_width > image_width or crop_height > image_height:
+            padding = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
+            mask = F.pad_image_tensor(mask, padding, fill=0)
+            # recompute the size after padding so the crop anchor refers to the padded mask
+            _, image_height, image_width = mask.shape
+
+        left = round((image_width - crop_width) * 0.5)
+        top = round((image_height - crop_height) * 0.5)
+
+        return mask[:, top : top + crop_height, left : left + crop_width]
+
+    mask = torch.randint(0, 2, size=(1, 6, 6), dtype=torch.long, device=device)
+    actual = F.center_crop_segmentation_mask(mask, output_size)
+
+    expected = _compute_expected_segmentation_mask(mask, output_size)
+    torch.testing.assert_close(actual, expected)
diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py
index cac2946b46e..2a6c7dce516 100644
--- a/torchvision/prototype/transforms/functional/__init__.py
+++ b/torchvision/prototype/transforms/functional/__init__.py
@@ -46,6 +46,7 @@
     resize_image_pil,
     resize_segmentation_mask,
     center_crop_bounding_box,
+    center_crop_segmentation_mask,
     center_crop_image_tensor,
     center_crop_image_pil,
     resized_crop_bounding_box,
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index 7d6e26451c9..00c8a59e395 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -630,6 +630,10 @@ def center_crop_bounding_box(
     return crop_bounding_box(bounding_box, format, top=crop_top, left=crop_left)
 
 
+def center_crop_segmentation_mask(segmentation_mask: torch.Tensor, output_size: List[int]) -> torch.Tensor:
+    return center_crop_image_tensor(img=segmentation_mask, output_size=output_size)
+
+
 def resized_crop_image_tensor(
     img: torch.Tensor,
     top: int,
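
For reviewers, a minimal usage sketch of the new kernel (not part of the patch; it assumes the prototype functional namespace is imported as in the test file). Because `center_crop_segmentation_mask` simply forwards to `center_crop_image_tensor`, a `(1, H, W)` mask behaves like a single-channel image: crops smaller than the mask are sliced out around the center, and crops larger than the mask are zero-padded first, which is exactly what the expected-value helper in the test mirrors.

```python
import torch
from torchvision.prototype.transforms import functional as F

# A (1, H, W) integer mask; values are class ids, not intensities.
mask = torch.randint(0, 2, size=(1, 6, 6), dtype=torch.long)

# Crop smaller than the mask: a plain center slice.
small = F.center_crop_segmentation_mask(mask, output_size=[4, 2])
print(small.shape)  # torch.Size([1, 4, 2])

# Crop larger than the mask: the image kernel pads first, then slices,
# so the output still matches output_size exactly.
large = F.center_crop_segmentation_mask(mask, output_size=[7, 6])
print(large.shape)  # torch.Size([1, 7, 6])

# A single-element output_size means a square crop, matching the
# [output_size[0], output_size[0]] expansion in the test helper.
square = F.center_crop_segmentation_mask(mask, output_size=[4])
print(square.shape)  # torch.Size([1, 4, 4])
```

Delegating to the image kernel rather than reimplementing the crop keeps the mask path bit-identical to the image path, which is presumably why the op is a one-line wrapper.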