Skip to content

Singular Sanitize BoundingBox #7316

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions gallery/plot_transforms_v2_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,13 @@ def load_example_coco_detection_dataset(**kwargs):
transforms.RandomHorizontalFlip(),
transforms.ToImageTensor(),
transforms.ConvertImageDtype(torch.float32),
transforms.SanitizeBoundingBoxes(),
transforms.SanitizeBoundingBox(),
]
)

########################################################################################################################
# .. note::
# Although the :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes` transform is a no-op in this example, it
# Although the :class:`~torchvision.transforms.v2.SanitizeBoundingBox` transform is a no-op in this example, it
# should be placed at least once at the end of a detection pipeline to remove degenerate bounding boxes as well as
# the corresponding labels and optionally masks. It is particularly critical to add it if
# :class:`~torchvision.transforms.v2.RandomIoUCrop` was used.
Expand Down
26 changes: 13 additions & 13 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def test_common(self, transform, adapter, container_type, image_or_video, device
boxes=datapoints.BoundingBox([[0, 0, 0, 0]], format=format, spatial_size=(224, 244)),
labels=torch.tensor([3]),
)
assert transforms.SanitizeBoundingBoxes()(sample)["boxes"].shape == (0, 4)
assert transforms.SanitizeBoundingBox()(sample)["boxes"].shape == (0, 4)

@parametrize(
[
Expand Down Expand Up @@ -1876,7 +1876,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
transforms.ConvertImageDtype(torch.float),
]
if sanitize:
t += [transforms.SanitizeBoundingBoxes()]
t += [transforms.SanitizeBoundingBox()]
t = transforms.Compose(t)

num_boxes = 5
Expand Down Expand Up @@ -1917,7 +1917,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
# ssd and ssdlite contain RandomIoUCrop which may "remove" some bbox. It
# doesn't remove them strictly speaking, it just marks some boxes as
# degenerate and those boxes will be later removed by
# SanitizeBoundingBoxes(), which we add to the pipelines if the sanitize
# SanitizeBoundingBox(), which we add to the pipelines if the sanitize
# param is True.
# Note that the values below are probably specific to the random seed
# set above (which is fine).
Expand Down Expand Up @@ -1989,7 +1989,7 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
img = sample.pop("image")
sample = (img, sample)

out = transforms.SanitizeBoundingBoxes(min_size=min_size, labels_getter=labels_getter)(sample)
out = transforms.SanitizeBoundingBox(min_size=min_size, labels_getter=labels_getter)(sample)

if sample_type is tuple:
out_image = out[0]
Expand Down Expand Up @@ -2023,13 +2023,13 @@ def test_sanitize_bounding_boxes_default_heuristic(key, sample_type):
sample = {key: labels, "another_key": "whatever"}
if sample_type is tuple:
sample = (None, sample, "whatever_again")
assert transforms.SanitizeBoundingBoxes._find_labels_default_heuristic(sample) is labels
assert transforms.SanitizeBoundingBox._find_labels_default_heuristic(sample) is labels

if key.lower() != "labels":
# If "labels" is in the dict (case-insensitive),
# it takes precedence over other keys which would otherwise be a match
d = {key: "something_else", "labels": labels}
assert transforms.SanitizeBoundingBoxes._find_labels_default_heuristic(d) is labels
assert transforms.SanitizeBoundingBox._find_labels_default_heuristic(d) is labels


def test_sanitize_bounding_boxes_errors():
Expand All @@ -2041,25 +2041,25 @@ def test_sanitize_bounding_boxes_errors():
)

with pytest.raises(ValueError, match="min_size must be >= 1"):
transforms.SanitizeBoundingBoxes(min_size=0)
transforms.SanitizeBoundingBox(min_size=0)
with pytest.raises(ValueError, match="labels_getter should either be a str"):
transforms.SanitizeBoundingBoxes(labels_getter=12)
transforms.SanitizeBoundingBox(labels_getter=12)

with pytest.raises(ValueError, match="Could not infer where the labels are"):
bad_labels_key = {"bbox": good_bbox, "BAD_KEY": torch.arange(good_bbox.shape[0])}
transforms.SanitizeBoundingBoxes()(bad_labels_key)
transforms.SanitizeBoundingBox()(bad_labels_key)

with pytest.raises(ValueError, match="If labels_getter is a str or 'default'"):
not_a_dict = (good_bbox, torch.arange(good_bbox.shape[0]))
transforms.SanitizeBoundingBoxes()(not_a_dict)
transforms.SanitizeBoundingBox()(not_a_dict)

with pytest.raises(ValueError, match="must be a tensor"):
not_a_tensor = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0]).tolist()}
transforms.SanitizeBoundingBoxes()(not_a_tensor)
transforms.SanitizeBoundingBox()(not_a_tensor)

with pytest.raises(ValueError, match="Number of boxes"):
different_sizes = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0] + 3)}
transforms.SanitizeBoundingBoxes()(different_sizes)
transforms.SanitizeBoundingBox()(different_sizes)

with pytest.raises(ValueError, match="boxes must be of shape"):
bad_bbox = datapoints.BoundingBox( # batch with 2 elements
Expand All @@ -2071,7 +2071,7 @@ def test_sanitize_bounding_boxes_errors():
spatial_size=(20, 20),
)
different_sizes = {"bbox": bad_bbox, "labels": torch.arange(bad_bbox.shape[0])}
transforms.SanitizeBoundingBoxes()(different_sizes)
transforms.SanitizeBoundingBox()(different_sizes)


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion test/test_transforms_v2_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -1099,7 +1099,7 @@ def make_label(extra_dims, categories):
v2_transforms.Compose(
[
v2_transforms.RandomIoUCrop(),
v2_transforms.SanitizeBoundingBoxes(labels_getter=lambda sample: sample[1]["labels"]),
v2_transforms.SanitizeBoundingBox(labels_getter=lambda sample: sample[1]["labels"]),
]
),
{"with_mask": False},
Expand Down
2 changes: 1 addition & 1 deletion torchvision/transforms/v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
TenCrop,
)
from ._meta import ClampBoundingBox, ConvertBoundingBoxFormat, ConvertDtype, ConvertImageDtype
from ._misc import GaussianBlur, Identity, Lambda, LinearTransformation, Normalize, SanitizeBoundingBoxes, ToDtype
from ._misc import GaussianBlur, Identity, Lambda, LinearTransformation, Normalize, SanitizeBoundingBox, ToDtype
from ._temporal import UniformTemporalSubsample
from ._type_conversion import PILToTensor, ToImagePIL, ToImageTensor, ToPILImage

Expand Down
4 changes: 2 additions & 2 deletions torchvision/transforms/v2/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -1114,7 +1114,7 @@ class RandomIoUCrop(Transform):

.. warning::
In order to properly remove the bounding boxes below the IoU threshold, `RandomIoUCrop`
must be followed by :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes`, either immediately
must be followed by :class:`~torchvision.transforms.v2.SanitizeBoundingBox`, either immediately
after or later in the transforms pipeline.

If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
Expand Down Expand Up @@ -1222,7 +1222,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:

if isinstance(output, datapoints.BoundingBox):
# We "mark" the invalid boxes as degenerate, and they can be
# removed by a later call to SanitizeBoundingBoxes()
# removed by a later call to SanitizeBoundingBox()
output[~params["is_within_crop_area"]] = 0

return output
Expand Down
6 changes: 3 additions & 3 deletions torchvision/transforms/v2/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return inpt.to(dtype=dtype)


class SanitizeBoundingBoxes(Transform):
class SanitizeBoundingBox(Transform):
# This removes boxes and their corresponding labels:
# - small or degenerate bboxes based on min_size (this includes those where X2 <= X1 or Y2 <= Y1)
# - boxes with any coordinate outside the range of the image (negative, or > spatial_size)
Expand All @@ -269,7 +269,7 @@ def __init__(
elif callable(labels_getter):
self._labels_getter = labels_getter
elif isinstance(labels_getter, str):
self._labels_getter = lambda inputs: SanitizeBoundingBoxes._get_dict_or_second_tuple_entry(inputs)[
self._labels_getter = lambda inputs: SanitizeBoundingBox._get_dict_or_second_tuple_entry(inputs)[
labels_getter # type: ignore[index]
]
elif labels_getter is None:
Expand Down Expand Up @@ -300,7 +300,7 @@ def _get_dict_or_second_tuple_entry(inputs: Any) -> Mapping[str, Any]:
def _find_labels_default_heuristic(inputs: Dict[str, Any]) -> Optional[torch.Tensor]:
# Tries to find a "labels" key, otherwise tries for the first key that contains "label" - case insensitive
# Returns None if nothing is found
inputs = SanitizeBoundingBoxes._get_dict_or_second_tuple_entry(inputs)
inputs = SanitizeBoundingBox._get_dict_or_second_tuple_entry(inputs)
candidate_key = None
with suppress(StopIteration):
candidate_key = next(key for key in inputs.keys() if key.lower() == "labels")
Expand Down