feat: add functional pad on segmentation mask #5866

Merged: 6 commits merged on May 2, 2022
54 changes: 54 additions & 0 deletions test/test_prototype_transforms_functional.py
@@ -370,6 +370,16 @@ def resized_crop_segmentation_mask():
        yield SampleInput(mask, top=top, left=left, height=height, width=width, size=size)


@register_kernel_info_from_sample_inputs_fn
def pad_segmentation_mask():
    for mask, padding, padding_mode in itertools.product(
        make_segmentation_masks(),
        [[1], [1, 1], [1, 1, 2, 2]],  # padding
        ["constant", "symmetric", "edge", "reflect"],  # padding mode
    ):
        yield SampleInput(mask, padding=padding, padding_mode=padding_mode)


@pytest.mark.parametrize(
    "kernel",
    [
@@ -1031,3 +1041,47 @@ def _compute_expected(mask, top_, left_, height_, width_, size_):
    expected_mask = _compute_expected(in_mask, top, left, height, width, size)
    output_mask = F.resized_crop_segmentation_mask(in_mask, top, left, height, width, size)
    torch.testing.assert_close(output_mask, expected_mask)


@pytest.mark.parametrize("device", cpu_and_gpu())
def test_correctness_pad_segmentation_mask_on_fixed_input(device):
    mask = torch.ones((1, 3, 3), dtype=torch.long, device=device)

    out_mask = F.pad_segmentation_mask(mask, padding=[1, 1, 1, 1])

    expected_mask = torch.zeros((1, 5, 5), dtype=torch.long, device=device)
    expected_mask[:, 1:-1, 1:-1] = 1
    torch.testing.assert_close(out_mask, expected_mask)


@pytest.mark.parametrize("padding", [[1, 2, 3, 4], [1], 1, [1, 2]])
def test_correctness_pad_segmentation_mask(padding):
    def _compute_expected_mask():
        def parse_padding():
            if isinstance(padding, int):
                return [padding] * 4
            if isinstance(padding, list):
                if len(padding) == 1:
                    return padding * 4
                if len(padding) == 2:
                    return padding * 2  # [left, up, right, down]

            return padding

        h, w = mask.shape[-2], mask.shape[-1]
        pad_left, pad_up, pad_right, pad_down = parse_padding()

        new_h = h + pad_up + pad_down
        new_w = w + pad_left + pad_right

        new_shape = (*mask.shape[:-2], new_h, new_w) if len(mask.shape) > 2 else (new_h, new_w)
        expected_mask = torch.zeros(new_shape, dtype=torch.long)
        expected_mask[..., pad_up:-pad_down, pad_left:-pad_right] = mask

        return expected_mask

    for mask in make_segmentation_masks():
        out_mask = F.pad_segmentation_mask(mask, padding, "constant")

        expected_mask = _compute_expected_mask()
        torch.testing.assert_close(out_mask, expected_mask)
Comment on lines +1084 to +1087
@vfdev-5 (Collaborator) commented on Apr 26, 2022:
Thinking of how to test all the other padding modes. Maybe we could just check output values from out_mask along two lines, one horizontal and one vertical, instead of constructing the full expected mask.
While checking the lines, we would still need to route the checks according to the padding mode.
@federicopozzi33 What do you think?
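The line-check idea can be sketched roughly like this, calling torch.nn.functional.pad directly rather than the prototype kernel; the mask values and the specific "reflect" assertions are made up for the illustration:

```python
import torch
import torch.nn.functional as F

# Pad a small mask with a non-constant mode and check only single lines of
# the output instead of building a full expected mask.
mask = torch.arange(9, dtype=torch.long).reshape(1, 3, 3)
padded = F.pad(mask.float(), [1, 1, 1, 1], mode="reflect").long()

# For "reflect" with pad=1, the border mirrors across the original edge row,
# so output row 0 equals output row 2; the same holds for columns.
assert torch.equal(padded[:, 0, :], padded[:, 2, :])
assert torch.equal(padded[:, :, 0], padded[:, :, 2])
```

Each padding mode would need its own pair of line assertions like the ones above, which is the routing the comment refers to.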

@federicopozzi33 (Contributor, Author) commented on Apr 27, 2022:
Seems pretty reasonable.

Anyway, IMO the code needed to construct the two lines still "mimics" (or re-implements) the padding operation, which is what I wanted to avoid, but it seems there are no other options.

@pmeier (Collaborator) commented:

> Thinking of how to test all the other padding modes.

I don't think we need a correctness check for them. Internally, torch.nn.functional.pad does the heavy lifting, so if we rely on it behaving correctly, there is no need to check that the padded values are correct.
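For illustration, a minimal check (hypothetical values, calling torch.nn.functional.pad directly) of the mode and value semantics the kernel delegates to:

```python
import torch
import torch.nn.functional as F

# torch.nn.functional.pad implements the actual padding semantics; the
# segmentation-mask kernel only needs to forward to it.
mask = torch.ones((1, 2, 2), dtype=torch.long)
out = F.pad(mask, [1, 1, 1, 1], mode="constant", value=0)

assert out.shape == (1, 4, 4)          # 2x2 grows to 4x4 with pad=1 on each side
assert bool(out[:, 1:-1, 1:-1].eq(1).all())  # original values preserved inside
assert bool(out[:, 0, :].eq(0).all())        # border filled with the constant
```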

@vfdev-5 (Collaborator) commented on May 2, 2022:

@pmeier if we do not check correctness of the output, another option is to mock torch_pad and ensure that it is called with the correct configuration; otherwise the code is not covered. I'm talking about all the other non-tested padding options.
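The mocking approach could look roughly like this. This is a sketch: pad_mask is a hypothetical stand-in for pad_segmentation_mask, and torch.nn.functional.pad is patched in place of the internal torch_pad alias:

```python
import torch
from unittest import mock

# Hypothetical stand-in for pad_segmentation_mask, which ultimately forwards
# to torch.nn.functional.pad.
def pad_mask(mask, padding, padding_mode="constant"):
    return torch.nn.functional.pad(mask, padding, mode=padding_mode)

# Patch the low-level pad and assert it receives the expected configuration,
# without checking any output values.
with mock.patch("torch.nn.functional.pad") as torch_pad:
    pad_mask(torch.ones(1, 3, 3), [1, 1, 2, 2], padding_mode="symmetric")
    torch_pad.assert_called_once_with(mock.ANY, [1, 1, 2, 2], mode="symmetric")
```

This covers the argument-forwarding branch for each mode while leaving value correctness to torch itself.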

1 change: 1 addition & 0 deletions torchvision/prototype/transforms/functional/__init__.py
@@ -62,6 +62,7 @@
pad_bounding_box,
pad_image_tensor,
pad_image_pil,
pad_segmentation_mask,
crop_bounding_box,
crop_image_tensor,
crop_image_pil,
14 changes: 14 additions & 0 deletions torchvision/prototype/transforms/functional/_geometry.py
@@ -396,6 +396,20 @@ def rotate_segmentation_mask(
pad_image_pil = _FP.pad


def pad_segmentation_mask(
    segmentation_mask: torch.Tensor, padding: List[int], padding_mode: str = "constant"
) -> torch.Tensor:
    num_masks, height, width = segmentation_mask.shape[-3:]
    extra_dims = segmentation_mask.shape[:-3]

    padded_mask = pad_image_tensor(
        img=segmentation_mask.view(-1, num_masks, height, width), padding=padding, fill=0, padding_mode=padding_mode
    )

    new_height, new_width = padded_mask.shape[-2:]
    return padded_mask.view(extra_dims + (num_masks, new_height, new_width))


def pad_bounding_box(
    bounding_box: torch.Tensor, padding: List[int], format: features.BoundingBoxFormat
) -> torch.Tensor:
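The shape handling in pad_segmentation_mask can be illustrated with a self-contained sketch that substitutes torch.nn.functional.pad for the prototype pad_image_tensor; the flatten/view round trip over extra leading dimensions is the point, not the pad call itself:

```python
import torch
import torch.nn.functional as F

# Sketch of pad_segmentation_mask's shape handling, with F.pad standing in
# for the prototype pad_image_tensor: arbitrary leading dims are flattened
# into a batch, padded, then restored.
def pad_mask_sketch(mask, padding, padding_mode="constant"):
    num_masks, height, width = mask.shape[-3:]
    extra_dims = mask.shape[:-3]
    flat = mask.view(-1, num_masks, height, width)
    padded = F.pad(flat, padding, mode=padding_mode)
    new_h, new_w = padded.shape[-2:]
    return padded.view(extra_dims + (num_masks, new_h, new_w))

# A batch of two single-object 3x3 masks keeps its leading batch dimension
# while the spatial dims grow by the padding.
mask = torch.ones((2, 1, 3, 3), dtype=torch.long)
out = pad_mask_sketch(mask, [1, 1, 1, 1])
assert out.shape == (2, 1, 5, 5)
```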