Skip to content

Commit 802b9b4

Browse files
author
Federico Pozzi
committed
test: improvements
1 parent 9a94978 commit 802b9b4

File tree

1 file changed

+14

-4

lines changed

test/test_prototype_transforms_functional.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@
1010
from torch import jit
1111
from torch.nn.functional import one_hot
1212
from torchvision.prototype import features
13+
from torchvision.prototype.transforms.functional._geometry import _center_crop_compute_padding
1314
from torchvision.prototype.transforms.functional._meta import convert_bounding_box_format
1415
from torchvision.transforms.functional import _get_perspective_coeffs
1516
from torchvision.transforms.functional_tensor import _max_value as get_max_value
1617

17-
1818
make_tensor = functools.partial(torch.testing.make_tensor, device="cpu")
1919

2020

@@ -423,7 +423,7 @@ def center_crop_bounding_box():
423423

424424
def center_crop_segmentation_mask():
425425
for mask, output_size in itertools.product(
426-
make_segmentation_masks(),
426+
make_segmentation_masks(image_sizes=((16, 16), (7, 33), (31, 9)), extra_dims=((), (4,), (2, 3))),
427427
[[4, 3], [42, 70], [4]], # crop sizes < image sizes, crop_sizes > image sizes, single crop size
428428
):
429429
yield SampleInput(mask, output_size)
@@ -1348,10 +1348,20 @@ def _compute_expected_bbox(bbox, output_size_):
13481348

13491349

13501350
@pytest.mark.parametrize("device", cpu_and_gpu())
1351-
@pytest.mark.parametrize("output_size", [[4, 3], [4], [7, 7]])
1351+
@pytest.mark.parametrize("output_size", [[4, 2], [4], [7, 6]])
13521352
def test_correctness_center_crop_segmentation_mask(device, output_size):
13531353
def _compute_expected_segmentation_mask(mask, output_size):
1354-
return F.center_crop_image_tensor(mask, output_size)
1354+
crop_height, crop_width = output_size if len(output_size) > 1 else [output_size[0], output_size[0]]
1355+
1356+
_, image_height, image_width = mask.shape
1357+
if crop_width > image_height or crop_height > image_width:
1358+
padding = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
1359+
mask = F.pad_image_tensor(mask, padding, fill=0)
1360+
1361+
left = round((image_width - crop_width) * 0.5)
1362+
top = round((image_height - crop_height) * 0.5)
1363+
1364+
return mask[:, top : top + crop_height, left : left + crop_width]
13551365

13561366
mask = torch.randint(0, 2, size=(1, 6, 6), dtype=torch.long, device=device)
13571367
actual = F.center_crop_segmentation_mask(mask, output_size)

0 commit comments

Comments (0)