Skip to content

Commit 46e07ad

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
[fbsync] make convert_format_bounding_box a hybrid kernel dispatcher (#7228)
Reviewed By: vmoens Differential Revision: D44416273 fbshipit-source-id: de42739e883dc6aa60601156b2052fbbf2c52290
1 parent 9be983c commit 46e07ad

File tree

5 files changed

+94
-21
lines changed

5 files changed

+94
-21
lines changed

test/prototype_common_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,13 @@ class TensorLoader:
237237
def load(self, device):
238238
return self.fn(self.shape, self.dtype, device)
239239

240+
def unwrap(self):
241+
return TensorLoader(
242+
fn=lambda shape, dtype, device: self.fn(shape, dtype, device).as_subclass(torch.Tensor),
243+
shape=self.shape,
244+
dtype=self.dtype,
245+
)
246+
240247

241248
@dataclasses.dataclass
242249
class ImageLoader(TensorLoader):

test/prototype_transforms_kernel_infos.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
make_video_loader,
2727
make_video_loaders,
2828
mark_framework_limitation,
29-
TensorLoader,
3029
TestMark,
3130
)
3231
from torch.utils._pytree import tree_map
@@ -660,7 +659,8 @@ def sample_inputs_affine_video():
660659
def sample_inputs_convert_format_bounding_box():
661660
formats = list(datapoints.BoundingBoxFormat)
662661
for bounding_box_loader, new_format in itertools.product(make_bounding_box_loaders(formats=formats), formats):
663-
yield ArgsKwargs(bounding_box_loader, old_format=bounding_box_loader.format, new_format=new_format)
662+
yield ArgsKwargs(bounding_box_loader, new_format=new_format)
663+
yield ArgsKwargs(bounding_box_loader.unwrap(), old_format=bounding_box_loader.format, new_format=new_format)
664664

665665

666666
def reference_convert_format_bounding_box(bounding_box, old_format, new_format):
@@ -671,8 +671,14 @@ def reference_convert_format_bounding_box(bounding_box, old_format, new_format):
671671

672672
def reference_inputs_convert_format_bounding_box():
673673
for args_kwargs in sample_inputs_convert_format_bounding_box():
674-
if len(args_kwargs.args[0].shape) == 2:
675-
yield args_kwargs
674+
if len(args_kwargs.args[0].shape) != 2:
675+
continue
676+
677+
(loader, *other_args), kwargs = args_kwargs
678+
if isinstance(loader, BoundingBoxLoader):
679+
kwargs["old_format"] = loader.format
680+
loader = loader.unwrap()
681+
yield ArgsKwargs(loader, *other_args, **kwargs)
676682

677683

678684
KERNEL_INFOS.append(
@@ -682,6 +688,18 @@ def reference_inputs_convert_format_bounding_box():
682688
reference_fn=reference_convert_format_bounding_box,
683689
reference_inputs_fn=reference_inputs_convert_format_bounding_box,
684690
logs_usage=True,
691+
test_marks=[
692+
mark_framework_limitation(
693+
("TestKernels", "test_scripted_vs_eager"),
694+
reason=(
695+
"The function is hybrid kernel / dispatcher. JIT unwraps a `datapoints.BoundingBox` into a "
696+
"`torch.Tensor`, but then the kernel (rightfully) complains that neither `format` nor "
697+
"`spatial_size` was passed"
698+
),
699+
condition=lambda arg_kwargs: isinstance(arg_kwargs.args[0], BoundingBoxLoader)
700+
and arg_kwargs.kwargs.get("old_format") is None,
701+
)
702+
],
685703
),
686704
)
687705

@@ -2014,13 +2032,10 @@ def sample_inputs_clamp_bounding_box():
20142032
for bounding_box_loader in make_bounding_box_loaders():
20152033
yield ArgsKwargs(bounding_box_loader)
20162034

2017-
simple_tensor_loader = TensorLoader(
2018-
fn=lambda shape, dtype, device: bounding_box_loader.fn(shape, dtype, device).as_subclass(torch.Tensor),
2019-
shape=bounding_box_loader.shape,
2020-
dtype=bounding_box_loader.dtype,
2021-
)
20222035
yield ArgsKwargs(
2023-
simple_tensor_loader, format=bounding_box_loader.format, spatial_size=bounding_box_loader.spatial_size
2036+
bounding_box_loader.unwrap(),
2037+
format=bounding_box_loader.format,
2038+
spatial_size=bounding_box_loader.spatial_size,
20242039
)
20252040

20262041

test/test_prototype_transforms_functional.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -572,7 +572,7 @@ class TestClampBoundingBox:
572572
def test_simple_tensor_insufficient_metadata(self, metadata):
573573
simple_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor)
574574

575-
with pytest.raises(ValueError, match="simple tensor"):
575+
with pytest.raises(ValueError, match=re.escape("`format` and `spatial_size` has to be passed")):
576576
F.clamp_bounding_box(simple_tensor, **metadata)
577577

578578
@pytest.mark.parametrize(
@@ -586,10 +586,37 @@ def test_simple_tensor_insufficient_metadata(self, metadata):
586586
def test_datapoint_explicit_metadata(self, metadata):
587587
datapoint = next(make_bounding_boxes())
588588

589-
with pytest.raises(ValueError, match="bounding box datapoint"):
589+
with pytest.raises(ValueError, match=re.escape("`format` and `spatial_size` must not be passed")):
590590
F.clamp_bounding_box(datapoint, **metadata)
591591

592592

593+
class TestConvertFormatBoundingBox:
594+
@pytest.mark.parametrize(
595+
("inpt", "old_format"),
596+
[
597+
(next(make_bounding_boxes()), None),
598+
(next(make_bounding_boxes()).as_subclass(torch.Tensor), datapoints.BoundingBoxFormat.XYXY),
599+
],
600+
)
601+
def test_missing_new_format(self, inpt, old_format):
602+
with pytest.raises(TypeError, match=re.escape("missing 1 required argument: 'new_format'")):
603+
F.convert_format_bounding_box(inpt, old_format)
604+
605+
def test_simple_tensor_insufficient_metadata(self):
606+
simple_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor)
607+
608+
with pytest.raises(ValueError, match=re.escape("`old_format` has to be passed")):
609+
F.convert_format_bounding_box(simple_tensor, new_format=datapoints.BoundingBoxFormat.CXCYWH)
610+
611+
def test_datapoint_explicit_metadata(self):
612+
datapoint = next(make_bounding_boxes())
613+
614+
with pytest.raises(ValueError, match=re.escape("`old_format` must not be passed")):
615+
F.convert_format_bounding_box(
616+
datapoint, old_format=datapoint.format, new_format=datapoints.BoundingBoxFormat.CXCYWH
617+
)
618+
619+
593620
# TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in
594621
# `prototype_transforms_kernel_infos.py`
595622

torchvision/prototype/transforms/_meta.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,7 @@ def __init__(self, format: Union[str, datapoints.BoundingBoxFormat]) -> None:
1919
self.format = format
2020

2121
def _transform(self, inpt: datapoints.BoundingBox, params: Dict[str, Any]) -> datapoints.BoundingBox:
22-
# We need to unwrap here to avoid unnecessary `__torch_function__` calls,
23-
# since `convert_format_bounding_box` does not have a dispatcher function that would do that for us
24-
output = F.convert_format_bounding_box(
25-
inpt.as_subclass(torch.Tensor), old_format=inpt.format, new_format=params["format"]
26-
)
27-
return datapoints.BoundingBox.wrap_like(inpt, output, format=params["format"])
22+
return F.convert_format_bounding_box(inpt, new_format=self.format) # type: ignore[return-value]
2823

2924

3025
class ConvertDtype(Transform):

torchvision/prototype/transforms/functional/_meta.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,11 +186,9 @@ def _xyxy_to_cxcywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
186186
return xyxy
187187

188188

189-
def convert_format_bounding_box(
189+
def _convert_format_bounding_box(
190190
bounding_box: torch.Tensor, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat, inplace: bool = False
191191
) -> torch.Tensor:
192-
if not torch.jit.is_scripting():
193-
_log_api_usage_once(convert_format_bounding_box)
194192

195193
if new_format == old_format:
196194
return bounding_box
@@ -209,6 +207,37 @@ def convert_format_bounding_box(
209207
return bounding_box
210208

211209

210+
def convert_format_bounding_box(
211+
inpt: datapoints.InputTypeJIT,
212+
old_format: Optional[BoundingBoxFormat] = None,
213+
new_format: Optional[BoundingBoxFormat] = None,
214+
inplace: bool = False,
215+
) -> datapoints.InputTypeJIT:
216+
# This being a kernel / dispatcher hybrid, we need an option to pass `old_format` explicitly for simple tensor
217+
# inputs as well as extract it from `datapoints.BoundingBox` inputs. However, putting a default value on
218+
# `old_format` means we also need to put one on `new_format` to have syntactically correct Python. Here we mimic the
219+
# default error that would be thrown if `new_format` had no default value.
220+
if new_format is None:
221+
raise TypeError("convert_format_bounding_box() missing 1 required argument: 'new_format'")
222+
223+
if not torch.jit.is_scripting():
224+
_log_api_usage_once(convert_format_bounding_box)
225+
226+
if torch.jit.is_scripting() or is_simple_tensor(inpt):
227+
if old_format is None:
228+
raise ValueError("For simple tensor inputs, `old_format` has to be passed.")
229+
return _convert_format_bounding_box(inpt, old_format=old_format, new_format=new_format, inplace=inplace)
230+
elif isinstance(inpt, datapoints.BoundingBox):
231+
if old_format is not None:
232+
raise ValueError("For bounding box datapoint inputs, `old_format` must not be passed.")
233+
output = _convert_format_bounding_box(inpt, old_format=inpt.format, new_format=new_format, inplace=inplace)
234+
return datapoints.BoundingBox.wrap_like(inpt, output)
235+
else:
236+
raise TypeError(
237+
f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead."
238+
)
239+
240+
212241
def _clamp_bounding_box(
213242
bounding_box: torch.Tensor, format: BoundingBoxFormat, spatial_size: Tuple[int, int]
214243
) -> torch.Tensor:

0 commit comments

Comments
 (0)