|
10 | 10 | from torch import jit
|
11 | 11 | from torch.nn.functional import one_hot
|
12 | 12 | from torchvision.prototype import features
|
| 13 | +from torchvision.prototype.transforms.functional._geometry import _center_crop_compute_padding
13 | 14 | from torchvision.prototype.transforms.functional._meta import convert_bounding_box_format
|
14 | 15 | from torchvision.transforms.functional import _get_perspective_coeffs
|
15 | 16 | from torchvision.transforms.functional_tensor import _max_value as get_max_value
|
16 | 17 |
|
17 | | -
18 | 18 | make_tensor = functools.partial(torch.testing.make_tensor, device="cpu")
|
19 | 19 |
|
20 | 20 |
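The newly imported `_center_crop_compute_padding` is used further down to pad masks that are smaller than the requested crop. A minimal sketch of the padding it is assumed to produce, in the `[left, top, right, bottom]` order that `pad_image_tensor` accepts (the function name below is hypothetical, for illustration only):

```python
# Illustrative sketch, not the torchvision implementation: each side is padded
# only when the crop exceeds the image in that dimension, and any odd remainder
# goes to the right/bottom edge.
def center_crop_padding_sketch(crop_h: int, crop_w: int, image_h: int, image_w: int) -> list:
    pad_w = max(crop_w - image_w, 0)
    pad_h = max(crop_h - image_h, 0)
    return [pad_w // 2, pad_h // 2, (pad_w + 1) // 2, (pad_h + 1) // 2]
```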
|
@@ -423,7 +423,7 @@ def center_crop_bounding_box():
|
423 | 423 |
|
424 | 424 | def center_crop_segmentation_mask():
|
425 | 425 |     for mask, output_size in itertools.product(

426 | | -        make_segmentation_masks(),

| 426 | +        make_segmentation_masks(image_sizes=((16, 16), (7, 33), (31, 9)), extra_dims=((), (4,), (2, 3))),

427 | 427 |         [[4, 3], [42, 70], [4]],  # crop sizes < image sizes, crop sizes > image sizes, single crop size

428 | 428 |     ):

429 | 429 |         yield SampleInput(mask, output_size)
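The explicit `image_sizes` and `extra_dims` ensure the samples exercise both the plain crop and the pad-then-crop paths. A small sketch, assuming the tuple layouts shown above, that enumerates which image/crop combinations need padding first:

```python
# Illustration only: mirrors the itertools.product above and reports which
# combinations fall into the pad-then-crop branch (crop larger than the image).
import itertools

image_sizes = [(16, 16), (7, 33), (31, 9)]
crop_sizes = [[4, 3], [42, 70], [4]]

for (height, width), size in itertools.product(image_sizes, crop_sizes):
    crop_h, crop_w = size if len(size) > 1 else (size[0], size[0])
    needs_pad = crop_h > height or crop_w > width
    print(f"image {height}x{width}, crop {crop_h}x{crop_w}, pad first: {needs_pad}")
```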
|
@@ -1348,10 +1348,20 @@ def _compute_expected_bbox(bbox, output_size_):
|
1348 | 1348 |
|
1349 | 1349 |
|
1350 | 1350 | @pytest.mark.parametrize("device", cpu_and_gpu())
|
1351 | | -@pytest.mark.parametrize("output_size", [[4, 3], [4], [7, 7]])

| 1351 | +@pytest.mark.parametrize("output_size", [[4, 2], [4], [7, 6]])
1352 | 1352 | def test_correctness_center_crop_segmentation_mask(device, output_size):
|
1353 | 1353 |     def _compute_expected_segmentation_mask(mask, output_size):
|
1354 | | -        return F.center_crop_image_tensor(mask, output_size)

| 1354 | +        crop_height, crop_width = output_size if len(output_size) > 1 else [output_size[0], output_size[0]]

| 1355 | +

| 1356 | +        _, image_height, image_width = mask.shape

| 1357 | +        if crop_width > image_width or crop_height > image_height:

| 1358 | +            padding = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)

| 1359 | +            mask = F.pad_image_tensor(mask, padding, fill=0)

| 1360 | +            _, image_height, image_width = mask.shape

| 1361 | +

| 1362 | +        left = round((image_width - crop_width) * 0.5)

| 1363 | +        top = round((image_height - crop_height) * 0.5)

| 1364 | +        return mask[:, top : top + crop_height, left : left + crop_width]
1355 | 1365 |
|
1356 | 1366 |     mask = torch.randint(0, 2, size=(1, 6, 6), dtype=torch.long, device=device)

1357 | 1367 |     actual = F.center_crop_segmentation_mask(mask, output_size)
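As a sanity check on the slicing arithmetic in `_compute_expected_segmentation_mask`: for the 6x6 mask used here with `output_size=[4, 2]`, no padding is needed, `top = round((6 - 4) * 0.5) = 1` and `left = round((6 - 2) * 0.5) = 2`, so the expected window is rows 1..4 and columns 2..3:

```python
# Standalone check of the worked example above (assumes a 6x6 single-channel mask).
import torch

mask = torch.arange(36).reshape(1, 6, 6)
expected = mask[:, 1 : 1 + 4, 2 : 2 + 2]  # top=1, left=2, 4x2 crop
assert expected.shape == (1, 4, 2)
```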
|
|