@@ -420,6 +420,8 @@ def center_crop_bounding_box():
420
420
yield SampleInput (
421
421
bounding_box , format = bounding_box .format , output_size = output_size , image_size = bounding_box .image_size
422
422
)
423
+
424
+
423
425
def center_crop_segmentation_mask ():
424
426
for mask , output_size in itertools .product (
425
427
make_segmentation_masks (),
@@ -1344,6 +1346,8 @@ def _compute_expected_bbox(bbox, output_size_):
1344
1346
else :
1345
1347
expected_bboxes = expected_bboxes [0 ]
1346
1348
torch .testing .assert_close (output_boxes , expected_bboxes )
1349
+
1350
+
1347
1351
def test_correctness_center_crop_segmentation_mask_on_fixed_input (device ):
1348
1352
mask = torch .ones ((1 , 6 , 6 ), dtype = torch .long , device = device )
1349
1353
mask [:, 1 :5 , 2 :4 ] = 0
@@ -1353,9 +1357,27 @@ def test_correctness_center_crop_segmentation_mask_on_fixed_input(device):
1353
1357
torch .testing .assert_close (out_mask , expected_mask )
1354
1358
1355
1359
1360
+ @pytest .mark .parametrize ("output_size" , [[4 , 3 ], [4 ]])
1361
+ def test_correctness_center_crop_segmentation_mask (output_size ):
1362
+ def _compute_expected_segmentation_mask ():
1363
+ _output_size = output_size if isinstance (output_size , tuple ) else (output_size , output_size )
1364
+
1365
+ _ , h , w = mask .shape
1366
+ left = w - _output_size [0 ]
1367
+ top = h - _output_size [1 ]
1368
+
1369
+ return mask [:, top : _output_size [1 ], left : _output_size [0 ]]
1370
+
1371
+ mask = torch .randint (0 , 2 , shape = (1 , 6 , 6 ))
1372
+ actual = F .center_crop_segmentation_mask (mask , output_size )
1373
+
1374
+ expected = _compute_expected_segmentation_mask ()
1375
+ assert expected == actual
1376
+
1377
+
1356
1378
@pytest .mark .parametrize ("output_size" , [[4 , 3 ], [4 ], [7 , 7 ]])
1357
1379
@patch ("torchvision.prototype.transforms.functional._geometry.center_crop_image_tensor" )
1358
- def test_correctness_center_crop_segmentation_mask (center_crop_mock , output_size ):
1380
+ def test_correctness_center_crop_segmentation_mask_mock (center_crop_mock , output_size ):
1359
1381
mask , expected = Mock (spec = torch .Tensor ), Mock (spec = torch .Tensor )
1360
1382
center_crop_mock .return_value = expected
1361
1383
0 commit comments