Skip to content

Commit b88c888

Browse files
NicolasHug and facebook-github-bot
authored and committed
[fbsync] Singular Sanitize BoundingBox (#7316)
Summary: Co-authored-by: Nicolas Hug <[email protected]> Reviewed By: vmoens Differential Revision: D44416560 fbshipit-source-id: db0ed0bfab6998a04650c2c2ec96dd31f2857649
1 parent ff3fbcb commit b88c888

File tree

6 files changed

+22
-22
lines changed

6 files changed

+22
-22
lines changed

gallery/plot_transforms_v2_e2e.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,13 +105,13 @@ def load_example_coco_detection_dataset(**kwargs):
105105
transforms.RandomHorizontalFlip(),
106106
transforms.ToImageTensor(),
107107
transforms.ConvertImageDtype(torch.float32),
108-
transforms.SanitizeBoundingBoxes(),
108+
transforms.SanitizeBoundingBox(),
109109
]
110110
)
111111

112112
########################################################################################################################
113113
# .. note::
114-
# Although the :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes` transform is a no-op in this example, but it
114+
# Although the :class:`~torchvision.transforms.v2.SanitizeBoundingBox` transform is a no-op in this example, it
115115
# should be placed at least once at the end of a detection pipeline to remove degenerate bounding boxes as well as
116116
# the corresponding labels and optionally masks. It is particularly critical to add it if
117117
# :class:`~torchvision.transforms.v2.RandomIoUCrop` was used.

test/test_transforms_v2.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ def test_common(self, transform, adapter, container_type, image_or_video, device
275275
boxes=datapoints.BoundingBox([[0, 0, 0, 0]], format=format, spatial_size=(224, 244)),
276276
labels=torch.tensor([3]),
277277
)
278-
assert transforms.SanitizeBoundingBoxes()(sample)["boxes"].shape == (0, 4)
278+
assert transforms.SanitizeBoundingBox()(sample)["boxes"].shape == (0, 4)
279279

280280
@parametrize(
281281
[
@@ -1876,7 +1876,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
18761876
transforms.ConvertImageDtype(torch.float),
18771877
]
18781878
if sanitize:
1879-
t += [transforms.SanitizeBoundingBoxes()]
1879+
t += [transforms.SanitizeBoundingBox()]
18801880
t = transforms.Compose(t)
18811881

18821882
num_boxes = 5
@@ -1917,7 +1917,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
19171917
# ssd and ssdlite contain RandomIoUCrop which may "remove" some bbox. It
19181918
# doesn't remove them strictly speaking, it just marks some boxes as
19191919
# degenerate and those boxes will be later removed by
1920-
# SanitizeBoundingBoxes(), which we add to the pipelines if the sanitize
1920+
# SanitizeBoundingBox(), which we add to the pipelines if the sanitize
19211921
# param is True.
19221922
# Note that the values below are probably specific to the random seed
19231923
# set above (which is fine).
@@ -1989,7 +1989,7 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
19891989
img = sample.pop("image")
19901990
sample = (img, sample)
19911991

1992-
out = transforms.SanitizeBoundingBoxes(min_size=min_size, labels_getter=labels_getter)(sample)
1992+
out = transforms.SanitizeBoundingBox(min_size=min_size, labels_getter=labels_getter)(sample)
19931993

19941994
if sample_type is tuple:
19951995
out_image = out[0]
@@ -2023,13 +2023,13 @@ def test_sanitize_bounding_boxes_default_heuristic(key, sample_type):
20232023
sample = {key: labels, "another_key": "whatever"}
20242024
if sample_type is tuple:
20252025
sample = (None, sample, "whatever_again")
2026-
assert transforms.SanitizeBoundingBoxes._find_labels_default_heuristic(sample) is labels
2026+
assert transforms.SanitizeBoundingBox._find_labels_default_heuristic(sample) is labels
20272027

20282028
if key.lower() != "labels":
20292029
# If "labels" is in the dict (case-insensitive),
20302030
# it takes precedence over other keys which would otherwise be a match
20312031
d = {key: "something_else", "labels": labels}
2032-
assert transforms.SanitizeBoundingBoxes._find_labels_default_heuristic(d) is labels
2032+
assert transforms.SanitizeBoundingBox._find_labels_default_heuristic(d) is labels
20332033

20342034

20352035
def test_sanitize_bounding_boxes_errors():
@@ -2041,25 +2041,25 @@ def test_sanitize_bounding_boxes_errors():
20412041
)
20422042

20432043
with pytest.raises(ValueError, match="min_size must be >= 1"):
2044-
transforms.SanitizeBoundingBoxes(min_size=0)
2044+
transforms.SanitizeBoundingBox(min_size=0)
20452045
with pytest.raises(ValueError, match="labels_getter should either be a str"):
2046-
transforms.SanitizeBoundingBoxes(labels_getter=12)
2046+
transforms.SanitizeBoundingBox(labels_getter=12)
20472047

20482048
with pytest.raises(ValueError, match="Could not infer where the labels are"):
20492049
bad_labels_key = {"bbox": good_bbox, "BAD_KEY": torch.arange(good_bbox.shape[0])}
2050-
transforms.SanitizeBoundingBoxes()(bad_labels_key)
2050+
transforms.SanitizeBoundingBox()(bad_labels_key)
20512051

20522052
with pytest.raises(ValueError, match="If labels_getter is a str or 'default'"):
20532053
not_a_dict = (good_bbox, torch.arange(good_bbox.shape[0]))
2054-
transforms.SanitizeBoundingBoxes()(not_a_dict)
2054+
transforms.SanitizeBoundingBox()(not_a_dict)
20552055

20562056
with pytest.raises(ValueError, match="must be a tensor"):
20572057
not_a_tensor = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0]).tolist()}
2058-
transforms.SanitizeBoundingBoxes()(not_a_tensor)
2058+
transforms.SanitizeBoundingBox()(not_a_tensor)
20592059

20602060
with pytest.raises(ValueError, match="Number of boxes"):
20612061
different_sizes = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0] + 3)}
2062-
transforms.SanitizeBoundingBoxes()(different_sizes)
2062+
transforms.SanitizeBoundingBox()(different_sizes)
20632063

20642064
with pytest.raises(ValueError, match="boxes must be of shape"):
20652065
bad_bbox = datapoints.BoundingBox( # batch with 2 elements
@@ -2071,7 +2071,7 @@ def test_sanitize_bounding_boxes_errors():
20712071
spatial_size=(20, 20),
20722072
)
20732073
different_sizes = {"bbox": bad_bbox, "labels": torch.arange(bad_bbox.shape[0])}
2074-
transforms.SanitizeBoundingBoxes()(different_sizes)
2074+
transforms.SanitizeBoundingBox()(different_sizes)
20752075

20762076

20772077
@pytest.mark.parametrize(

test/test_transforms_v2_consistency.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1099,7 +1099,7 @@ def make_label(extra_dims, categories):
10991099
v2_transforms.Compose(
11001100
[
11011101
v2_transforms.RandomIoUCrop(),
1102-
v2_transforms.SanitizeBoundingBoxes(labels_getter=lambda sample: sample[1]["labels"]),
1102+
v2_transforms.SanitizeBoundingBox(labels_getter=lambda sample: sample[1]["labels"]),
11031103
]
11041104
),
11051105
{"with_mask": False},

torchvision/transforms/v2/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
TenCrop,
4141
)
4242
from ._meta import ClampBoundingBox, ConvertBoundingBoxFormat, ConvertDtype, ConvertImageDtype
43-
from ._misc import GaussianBlur, Identity, Lambda, LinearTransformation, Normalize, SanitizeBoundingBoxes, ToDtype
43+
from ._misc import GaussianBlur, Identity, Lambda, LinearTransformation, Normalize, SanitizeBoundingBox, ToDtype
4444
from ._temporal import UniformTemporalSubsample
4545
from ._type_conversion import PILToTensor, ToImagePIL, ToImageTensor, ToPILImage
4646

torchvision/transforms/v2/_geometry.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1114,7 +1114,7 @@ class RandomIoUCrop(Transform):
11141114
11151115
.. warning::
11161116
In order to properly remove the bounding boxes below the IoU threshold, `RandomIoUCrop`
1117-
must be followed by :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes`, either immediately
1117+
must be followed by :class:`~torchvision.transforms.v2.SanitizeBoundingBox`, either immediately
11181118
after or later in the transforms pipeline.
11191119
11201120
If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
@@ -1222,7 +1222,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
12221222

12231223
if isinstance(output, datapoints.BoundingBox):
12241224
# We "mark" the invalid boxes as degenerate, and they can be
1225-
# removed by a later call to SanitizeBoundingBoxes()
1225+
# removed by a later call to SanitizeBoundingBox()
12261226
output[~params["is_within_crop_area"]] = 0
12271227

12281228
return output

torchvision/transforms/v2/_misc.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
246246
return inpt.to(dtype=dtype)
247247

248248

249-
class SanitizeBoundingBoxes(Transform):
249+
class SanitizeBoundingBox(Transform):
250250
# This removes boxes and their corresponding labels:
251251
# - small or degenerate bboxes based on min_size (this includes those where X2 <= X1 or Y2 <= Y1)
252252
# - boxes with any coordinate outside the range of the image (negative, or > spatial_size)
@@ -269,7 +269,7 @@ def __init__(
269269
elif callable(labels_getter):
270270
self._labels_getter = labels_getter
271271
elif isinstance(labels_getter, str):
272-
self._labels_getter = lambda inputs: SanitizeBoundingBoxes._get_dict_or_second_tuple_entry(inputs)[
272+
self._labels_getter = lambda inputs: SanitizeBoundingBox._get_dict_or_second_tuple_entry(inputs)[
273273
labels_getter # type: ignore[index]
274274
]
275275
elif labels_getter is None:
@@ -300,7 +300,7 @@ def _get_dict_or_second_tuple_entry(inputs: Any) -> Mapping[str, Any]:
300300
def _find_labels_default_heuristic(inputs: Dict[str, Any]) -> Optional[torch.Tensor]:
301301
# Tries to find a "labels" key, otherwise tries for the first key that contains "label" - case insensitive
302302
# Returns None if nothing is found
303-
inputs = SanitizeBoundingBoxes._get_dict_or_second_tuple_entry(inputs)
303+
inputs = SanitizeBoundingBox._get_dict_or_second_tuple_entry(inputs)
304304
candidate_key = None
305305
with suppress(StopIteration):
306306
candidate_key = next(key for key in inputs.keys() if key.lower() == "labels")

0 commit comments

Comments
 (0)