Skip to content

Commit ed48bb1

Browse files
NicolasHug and pmeier authored
Extend default heuristic of SanitizeBoundingBoxes to support tuples (#7304)
Co-authored-by: Philip Meier <[email protected]>
1 parent a46d97c commit ed48bb1

File tree

2 files changed

+63
-22
lines changed

2 files changed

+63
-22
lines changed

test/test_transforms_v2.py

Lines changed: 41 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -1935,7 +1935,14 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
19351935
@pytest.mark.parametrize(
19361936
"labels_getter", ("default", "labels", lambda inputs: inputs["labels"], None, lambda inputs: None)
19371937
)
1938-
def test_sanitize_bounding_boxes(min_size, labels_getter):
1938+
@pytest.mark.parametrize("sample_type", (tuple, dict))
1939+
def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
1940+
1941+
if sample_type is tuple and not isinstance(labels_getter, str):
1942+
# The "lambda inputs: inputs["labels"]" labels_getter used in this test
1943+
# doesn't work if the input is a tuple.
1944+
return
1945+
19391946
H, W = 256, 128
19401947

19411948
boxes_and_validity = [
@@ -1970,35 +1977,56 @@ def test_sanitize_bounding_boxes(min_size, labels_getter):
19701977
)
19711978

19721979
masks = datapoints.Mask(torch.randint(0, 2, size=(boxes.shape[0], H, W)))
1973-
1980+
whatever = torch.rand(10)
1981+
input_img = torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8)
19741982
sample = {
1975-
"image": torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8),
1983+
"image": input_img,
19761984
"labels": labels,
19771985
"boxes": boxes,
1978-
"whatever": torch.rand(10),
1986+
"whatever": whatever,
19791987
"None": None,
19801988
"masks": masks,
19811989
}
19821990

1991+
if sample_type is tuple:
1992+
img = sample.pop("image")
1993+
sample = (img, sample)
1994+
19831995
out = transforms.SanitizeBoundingBoxes(min_size=min_size, labels_getter=labels_getter)(sample)
19841996

1985-
assert out["image"] is sample["image"]
1986-
assert out["whatever"] is sample["whatever"]
1997+
if sample_type is tuple:
1998+
out_image = out[0]
1999+
out_labels = out[1]["labels"]
2000+
out_boxes = out[1]["boxes"]
2001+
out_masks = out[1]["masks"]
2002+
out_whatever = out[1]["whatever"]
2003+
else:
2004+
out_image = out["image"]
2005+
out_labels = out["labels"]
2006+
out_boxes = out["boxes"]
2007+
out_masks = out["masks"]
2008+
out_whatever = out["whatever"]
2009+
2010+
assert out_image is input_img
2011+
assert out_whatever is whatever
19872012

19882013
if labels_getter is None or (callable(labels_getter) and labels_getter({"labels": "blah"}) is None):
1989-
assert out["labels"] is sample["labels"]
2014+
assert out_labels is labels
19902015
else:
1991-
assert isinstance(out["labels"], torch.Tensor)
1992-
assert out["boxes"].shape[0] == out["labels"].shape[0] == out["masks"].shape[0]
2016+
assert isinstance(out_labels, torch.Tensor)
2017+
assert out_boxes.shape[0] == out_labels.shape[0] == out_masks.shape[0]
19932018
# This works because we conveniently set labels to arange(num_boxes)
1994-
assert out["labels"].tolist() == valid_indices
2019+
assert out_labels.tolist() == valid_indices
19952020

19962021

19972022
@pytest.mark.parametrize("key", ("labels", "LABELS", "LaBeL", "SOME_WEIRD_KEY_THAT_HAS_LABeL_IN_IT"))
1998-
def test_sanitize_bounding_boxes_default_heuristic(key):
2023+
@pytest.mark.parametrize("sample_type", (tuple, dict))
2024+
def test_sanitize_bounding_boxes_default_heuristic(key, sample_type):
19992025
labels = torch.arange(10)
2000-
d = {key: labels}
2001-
assert transforms.SanitizeBoundingBoxes._find_labels_default_heuristic(d) is labels
2026+
sample = {key: labels, "another_key": "whatever"}
2027+
if sample_type is tuple:
2028+
sample = (None, sample, "whatever_again")
2029+
assert transforms.SanitizeBoundingBoxes._find_labels_default_heuristic(sample) is labels
20022030

20032031
if key.lower() != "labels":
20042032
# If "labels" is in the dict (case-insensitive),

torchvision/transforms/v2/_misc.py

Lines changed: 22 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
import collections
22
import warnings
33
from contextlib import suppress
4-
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Type, Union
4+
from typing import Any, Callable, cast, Dict, List, Mapping, Optional, Sequence, Type, Union
55

66
import PIL.Image
77

@@ -269,7 +269,9 @@ def __init__(
269269
elif callable(labels_getter):
270270
self._labels_getter = labels_getter
271271
elif isinstance(labels_getter, str):
272-
self._labels_getter = lambda inputs: inputs[labels_getter]
272+
self._labels_getter = lambda inputs: SanitizeBoundingBoxes._get_dict_or_second_tuple_entry(inputs)[
273+
labels_getter # type: ignore[index]
274+
]
273275
elif labels_getter is None:
274276
self._labels_getter = None
275277
else:
@@ -278,10 +280,27 @@ def __init__(
278280
f"Got {labels_getter} of type {type(labels_getter)}."
279281
)
280282

283+
@staticmethod
284+
def _get_dict_or_second_tuple_entry(inputs: Any) -> Mapping[str, Any]:
285+
# datasets outputs may be plain dicts like {"img": ..., "labels": ..., "bbox": ...}
286+
# or tuples like (img, {"labels":..., "bbox": ...})
287+
# This hacky helper accounts for both structures.
288+
if isinstance(inputs, tuple):
289+
inputs = inputs[1]
290+
291+
if not isinstance(inputs, collections.abc.Mapping):
292+
raise ValueError(
293+
f"If labels_getter is a str or 'default', "
294+
f"then the input to forward() must be a dict or a tuple whose second element is a dict."
295+
f" Got {type(inputs)} instead."
296+
)
297+
return inputs
298+
281299
@staticmethod
282300
def _find_labels_default_heuristic(inputs: Dict[str, Any]) -> Optional[torch.Tensor]:
283-
# Tries to find a "label" key, otherwise tries for the first key that contains "label" - case insensitive
301+
# Tries to find a "labels" key, otherwise tries for the first key that contains "label" - case insensitive
284302
# Returns None if nothing is found
303+
inputs = SanitizeBoundingBoxes._get_dict_or_second_tuple_entry(inputs)
285304
candidate_key = None
286305
with suppress(StopIteration):
287306
candidate_key = next(key for key in inputs.keys() if key.lower() == "labels")
@@ -298,12 +317,6 @@ def _find_labels_default_heuristic(inputs: Dict[str, Any]) -> Optional[torch.Ten
298317
def forward(self, *inputs: Any) -> Any:
299318
inputs = inputs if len(inputs) > 1 else inputs[0]
300319

301-
if isinstance(self.labels_getter, str) and not isinstance(inputs, collections.abc.Mapping):
302-
raise ValueError(
303-
f"If labels_getter is a str or 'default' (got {self.labels_getter}), "
304-
f"then the input to forward() must be a dict. Got {type(inputs)} instead."
305-
)
306-
307320
if self._labels_getter is None:
308321
labels = None
309322
else:

0 commit comments

Comments
 (0)