|
1 | 1 | import functools
|
2 | 2 | import itertools
|
3 | 3 | import math
|
| 4 | +from unittest.mock import patch, Mock |
4 | 5 |
|
5 | 6 | import numpy as np
|
6 | 7 | import pytest
|
@@ -380,6 +381,15 @@ def pad_segmentation_mask():
|
380 | 381 | yield SampleInput(mask, padding=padding, padding_mode=padding_mode)
|
381 | 382 |
|
382 | 383 |
|
| 384 | +@register_kernel_info_from_sample_inputs_fn |
| 385 | +def center_crop_segmentation_mask(): |
| 386 | + for mask, output_size in itertools.product( |
| 387 | + make_segmentation_masks(), |
| 388 | + [[4, 3], [42, 70], [4]], # crop sizes < image sizes, crop_sizes > image sizes, single crop size |
| 389 | + ): |
| 390 | + yield SampleInput(mask, output_size) |
| 391 | + |
| 392 | + |
383 | 393 | @pytest.mark.parametrize(
|
384 | 394 | "kernel",
|
385 | 395 | [
|
@@ -1085,3 +1095,25 @@ def parse_padding():
|
1085 | 1095 |
|
1086 | 1096 | expected_mask = _compute_expected_mask()
|
1087 | 1097 | torch.testing.assert_close(out_mask, expected_mask)
|
| 1098 | + |
| 1099 | + |
| 1100 | +@pytest.mark.parametrize("device", cpu_and_gpu()) |
| 1101 | +def test_correctness_center_crop_segmentation_mask_on_fixed_input(device): |
| 1102 | + mask = torch.ones((1, 6, 6), dtype=torch.long, device=device) |
| 1103 | + mask[:, 1:5, 2:4] = 0 |
| 1104 | + |
| 1105 | + out_mask = F.center_crop_segmentation_mask(mask, [2]) |
| 1106 | + expected_mask = torch.zeros((1, 4, 2), dtype=torch.long, device=device) |
| 1107 | + torch.testing.assert_close(out_mask, expected_mask) |
| 1108 | + |
| 1109 | + |
| 1110 | +@pytest.mark.parametrize("output_size", [[4, 3], [4], [7, 7]]) |
| 1111 | +@patch("torchvision.prototype.transforms.functional._geometry.center_crop_image_tensor") |
| 1112 | +def test_correctness_center_crop_segmentation_mask(center_crop_mock, output_size): |
| 1113 | + mask, expected = Mock(spec=torch.Tensor), Mock(spec=torch.Tensor) |
| 1114 | + center_crop_mock.return_value = expected |
| 1115 | + |
| 1116 | + out_mask = F.center_crop_segmentation_mask(mask, output_size) |
| 1117 | + |
| 1118 | + center_crop_mock.assert_called_once_with(img=mask, output_size=output_size) |
| 1119 | + assert expected is out_mask |
0 commit comments