Skip to content

Commit e93cd31

Browse files
author
Federico Pozzi
committed
test: add correctness center crop with random segmentation mask
1 parent c9e4a74 commit e93cd31

File tree

1 file changed

+23
-1
lines changed

1 file changed

+23
-1
lines changed

test/test_prototype_transforms_functional.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,8 @@ def center_crop_bounding_box():
420420
yield SampleInput(
421421
bounding_box, format=bounding_box.format, output_size=output_size, image_size=bounding_box.image_size
422422
)
423+
424+
423425
def center_crop_segmentation_mask():
424426
for mask, output_size in itertools.product(
425427
make_segmentation_masks(),
@@ -1344,6 +1346,8 @@ def _compute_expected_bbox(bbox, output_size_):
13441346
else:
13451347
expected_bboxes = expected_bboxes[0]
13461348
torch.testing.assert_close(output_boxes, expected_bboxes)
1349+
1350+
13471351
def test_correctness_center_crop_segmentation_mask_on_fixed_input(device):
13481352
mask = torch.ones((1, 6, 6), dtype=torch.long, device=device)
13491353
mask[:, 1:5, 2:4] = 0
@@ -1353,9 +1357,27 @@ def test_correctness_center_crop_segmentation_mask_on_fixed_input(device):
13531357
torch.testing.assert_close(out_mask, expected_mask)
13541358

13551359

1360+
@pytest.mark.parametrize("output_size", [[4, 3], [4]])
1361+
def test_correctness_center_crop_segmentation_mask(output_size):
1362+
def _compute_expected_segmentation_mask():
1363+
_output_size = output_size if isinstance(output_size, tuple) else (output_size, output_size)
1364+
1365+
_, h, w = mask.shape
1366+
left = w - _output_size[0]
1367+
top = h - _output_size[1]
1368+
1369+
return mask[:, top : _output_size[1], left : _output_size[0]]
1370+
1371+
mask = torch.randint(0, 2, shape=(1, 6, 6))
1372+
actual = F.center_crop_segmentation_mask(mask, output_size)
1373+
1374+
expected = _compute_expected_segmentation_mask()
1375+
assert expected == actual
1376+
1377+
13561378
@pytest.mark.parametrize("output_size", [[4, 3], [4], [7, 7]])
13571379
@patch("torchvision.prototype.transforms.functional._geometry.center_crop_image_tensor")
1358-
def test_correctness_center_crop_segmentation_mask(center_crop_mock, output_size):
1380+
def test_correctness_center_crop_segmentation_mask_mock(center_crop_mock, output_size):
13591381
mask, expected = Mock(spec=torch.Tensor), Mock(spec=torch.Tensor)
13601382
center_crop_mock.return_value = expected
13611383

0 commit comments

Comments
 (0)