Skip to content

Commit 51d0361

Browse files
author
Federico Pozzi
committed
feat: add functional center crop on mask
1 parent 970ba35 commit 51d0361

File tree

3 files changed

+37
-0
lines changed

3 files changed

+37
-0
lines changed

test/test_prototype_transforms_functional.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import functools
22
import itertools
33
import math
4+
from unittest.mock import patch, Mock
45

56
import numpy as np
67
import pytest
@@ -380,6 +381,15 @@ def pad_segmentation_mask():
380381
yield SampleInput(mask, padding=padding, padding_mode=padding_mode)
381382

382383

384+
@register_kernel_info_from_sample_inputs_fn
385+
def center_crop_segmentation_mask():
386+
for mask, output_size in itertools.product(
387+
make_segmentation_masks(),
388+
[[4, 3], [42, 70], [4]], # crop sizes < image sizes, crop_sizes > image sizes, single crop size
389+
):
390+
yield SampleInput(mask, output_size)
391+
392+
383393
@pytest.mark.parametrize(
384394
"kernel",
385395
[
@@ -1085,3 +1095,25 @@ def parse_padding():
10851095

10861096
expected_mask = _compute_expected_mask()
10871097
torch.testing.assert_close(out_mask, expected_mask)
1098+
1099+
1100+
@pytest.mark.parametrize("device", cpu_and_gpu())
1101+
def test_correctness_center_crop_segmentation_mask_on_fixed_input(device):
1102+
mask = torch.ones((1, 6, 6), dtype=torch.long, device=device)
1103+
mask[:, 1:5, 2:4] = 0
1104+
1105+
out_mask = F.center_crop_segmentation_mask(mask, [2])
1106+
expected_mask = torch.zeros((1, 4, 2), dtype=torch.long, device=device)
1107+
torch.testing.assert_close(out_mask, expected_mask)
1108+
1109+
1110+
@pytest.mark.parametrize("output_size", [[4, 3], [4], [7, 7]])
1111+
@patch("torchvision.prototype.transforms.functional._geometry.center_crop_image_tensor")
1112+
def test_correctness_center_crop_segmentation_mask(center_crop_mock, output_size):
1113+
mask, expected = Mock(spec=torch.Tensor), Mock(spec=torch.Tensor)
1114+
center_crop_mock.return_value = expected
1115+
1116+
out_mask = F.center_crop_segmentation_mask(mask, output_size)
1117+
1118+
center_crop_mock.assert_called_once_with(img=mask, output_size=output_size)
1119+
assert expected is out_mask

torchvision/prototype/transforms/functional/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
resize_image_tensor,
4646
resize_image_pil,
4747
resize_segmentation_mask,
48+
center_crop_segmentation_mask,
4849
center_crop_image_tensor,
4950
center_crop_image_pil,
5051
resized_crop_bounding_box,

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,10 @@ def center_crop_image_pil(img: PIL.Image.Image, output_size: List[int]) -> PIL.I
530530
return crop_image_pil(img, crop_top, crop_left, crop_height, crop_width)
531531

532532

533+
def center_crop_segmentation_mask(segmentation_mask: torch.Tensor, output_size: List[int]) -> torch.Tensor:
534+
return center_crop_image_tensor(img=segmentation_mask, output_size=output_size)
535+
536+
533537
def resized_crop_image_tensor(
534538
img: torch.Tensor,
535539
top: int,

0 commit comments

Comments
 (0)