Skip to content

Commit 56b0497

Browse files
authored
Add more tests for bounding boxes (#7276)
1 parent dfa81ce commit 56b0497

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

test/test_prototype_transforms.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,21 @@ def test_common(self, transform, adapter, container_type, image_or_video, device
269269
else:
270270
assert output_item is input_item
271271

272+
if isinstance(input_item, datapoints.BoundingBox) and not isinstance(
273+
transform, transforms.ConvertBoundingBoxFormat
274+
):
275+
assert output_item.format == input_item.format
276+
277+
# Enforce that the transform does not turn a degenerate box marked by RandomIoUCrop (or any other future
278+
# transform that does this), back into a valid one.
279+
# TODO: we should test that against all degenerate boxes above
280+
for format in list(datapoints.BoundingBoxFormat):
281+
sample = dict(
282+
boxes=datapoints.BoundingBox([[0, 0, 0, 0]], format=format, spatial_size=(224, 244)),
283+
labels=torch.tensor([3]),
284+
)
285+
assert transforms.SanitizeBoundingBoxes()(sample)["boxes"].shape == (0, 4)
286+
272287
@parametrize(
273288
[
274289
(

test/test_prototype_transforms_functional.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,22 @@ def test_unkown_type(self, info):
508508
with pytest.raises(TypeError, match=re.escape(str(type(unkown_input)))):
509509
info.dispatcher(unkown_input, *other_args, **kwargs)
510510

511+
@make_info_args_kwargs_parametrization(
512+
[
513+
info
514+
for info in DISPATCHER_INFOS
515+
if datapoints.BoundingBox in info.kernels and info.dispatcher is not F.convert_format_bounding_box
516+
],
517+
args_kwargs_fn=lambda info: info.sample_inputs(datapoints.BoundingBox),
518+
)
519+
def test_bounding_box_format_consistency(self, info, args_kwargs):
520+
(bounding_box, *other_args), kwargs = args_kwargs.load()
521+
format = bounding_box.format
522+
523+
output = info.dispatcher(bounding_box, *other_args, **kwargs)
524+
525+
assert output.format == format
526+
511527

512528
@pytest.mark.parametrize(
513529
("alias", "target"),

0 commit comments

Comments
 (0)