
Commit 9be983c

NicolasHug authored and facebook-github-bot committed
[fbsync] make clamp_bounding_box a kernel / dispatcher hybrid (#7227)
Reviewed By: vmoens

Differential Revision: D44416267

fbshipit-source-id: b3688a4c9d767b5e91c71dac77f44dedde65261b
1 parent 8897252
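For context before the diffs: `F.clamp_bounding_box` now accepts either a `datapoints.BoundingBox`, which carries its own metadata, or a plain tensor plus explicit `format` and `spatial_size`. A minimal usage sketch, assuming the prototype API at this commit (the sample values are illustrative):

    import torch
    from torchvision.prototype import datapoints
    from torchvision.prototype.transforms import functional as F

    # A bounding box datapoint carries `format` and `spatial_size` itself,
    # so the dispatcher path forwards that metadata to the kernel.
    box = datapoints.BoundingBox(
        torch.tensor([[-5.0, 10.0, 120.0, 60.0]]),
        format=datapoints.BoundingBoxFormat.XYXY,
        spatial_size=(50, 100),  # (height, width)
    )
    clamped = F.clamp_bounding_box(box)  # metadata arguments must be omitted

    # A simple tensor has no attached metadata, so both arguments are required.
    clamped_tensor = F.clamp_bounding_box(
        box.as_subclass(torch.Tensor),
        format=datapoints.BoundingBoxFormat.XYXY,
        spatial_size=(50, 100),
    )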

5 files changed (+89 lines, −20 lines)


test/prototype_common_utils.py

Lines changed: 2 additions & 2 deletions

@@ -640,14 +640,14 @@ def __init__(
         self.condition = condition or (lambda args_kwargs: True)
 
 
-def mark_framework_limitation(test_id, reason):
+def mark_framework_limitation(test_id, reason, condition=None):
     # The purpose of this function is to have a single entry point for skip marks that are only there, because the test
     # framework cannot handle the kernel in general or a specific parameter combination.
     # As development progresses, we can change the `mark.skip` to `mark.xfail` from time to time to see if the skip is
     # still justified.
     # We don't want to use `mark.xfail` all the time, because that actually runs the test until an error happens. Thus,
     # we are wasting CI resources for no reason for most of the time
-    return TestMark(test_id, pytest.mark.skip(reason=reason))
+    return TestMark(test_id, pytest.mark.skip(reason=reason), condition=condition)
 
 
 class InfoBase:
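A hedged sketch of how the new `condition` parameter is meant to be used (mirroring the call added in test/prototype_transforms_kernel_infos.py below): when a condition is given, the skip mark applies only to the sample inputs for which it returns True, rather than to every sample.

    # Illustrative only; the reason string and condition are placeholders.
    mark = mark_framework_limitation(
        ("TestKernels", "test_scripted_vs_eager"),
        reason="JIT cannot handle this parameter combination",
        condition=lambda args_kwargs: args_kwargs.kwargs.get("format") is None,
    )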

test/prototype_transforms_kernel_infos.py

Lines changed: 23 additions & 1 deletion

@@ -12,6 +12,7 @@
 from datasets_utils import combinations_grid
 from prototype_common_utils import (
     ArgsKwargs,
+    BoundingBoxLoader,
     get_num_channels,
     ImageLoader,
     InfoBase,
@@ -25,6 +26,7 @@
     make_video_loader,
     make_video_loaders,
     mark_framework_limitation,
+    TensorLoader,
     TestMark,
 )
 from torch.utils._pytree import tree_map
@@ -2010,8 +2012,15 @@ def sample_inputs_adjust_saturation_video():
 
 def sample_inputs_clamp_bounding_box():
     for bounding_box_loader in make_bounding_box_loaders():
+        yield ArgsKwargs(bounding_box_loader)
+
+        simple_tensor_loader = TensorLoader(
+            fn=lambda shape, dtype, device: bounding_box_loader.fn(shape, dtype, device).as_subclass(torch.Tensor),
+            shape=bounding_box_loader.shape,
+            dtype=bounding_box_loader.dtype,
+        )
         yield ArgsKwargs(
-            bounding_box_loader, format=bounding_box_loader.format, spatial_size=bounding_box_loader.spatial_size
+            simple_tensor_loader, format=bounding_box_loader.format, spatial_size=bounding_box_loader.spatial_size
         )
 
 
@@ -2020,6 +2029,19 @@ def sample_inputs_clamp_bounding_box():
         F.clamp_bounding_box,
         sample_inputs_fn=sample_inputs_clamp_bounding_box,
         logs_usage=True,
+        test_marks=[
+            mark_framework_limitation(
+                ("TestKernels", "test_scripted_vs_eager"),
+                reason=(
+                    "The function is hybrid kernel / dispatcher. JIT unwraps a `datapoints.BoundingBox` into a "
+                    "`torch.Tensor`, but then the kernel (rightfully) complains that neither `format` nor "
+                    "`spatial_size` was passed"
+                ),
+                condition=lambda arg_kwargs: isinstance(arg_kwargs.args[0], BoundingBoxLoader)
+                and arg_kwargs.kwargs.get("format") is None
+                and arg_kwargs.kwargs.get("spatial_size") is None,
+            )
+        ],
     )
 )
 

test/test_prototype_transforms_functional.py

Lines changed: 38 additions & 6 deletions

@@ -155,12 +155,14 @@ def _unbatch(self, batch, *, data_dims):
         if batched_tensor.ndim == data_dims:
             return batch
 
-        return [
-            self._unbatch(unbatched, data_dims=data_dims)
-            for unbatched in (
-                batched_tensor.unbind(0) if not metadata else [(t, *metadata) for t in batched_tensor.unbind(0)]
-            )
-        ]
+        unbatcheds = []
+        for unbatched in (
+            batched_tensor.unbind(0) if not metadata else [(t, *metadata) for t in batched_tensor.unbind(0)]
+        ):
+            if isinstance(batch, datapoints._datapoint.Datapoint):
+                unbatched = type(batch).wrap_like(batch, unbatched)
+            unbatcheds.append(self._unbatch(unbatched, data_dims=data_dims))
+        return unbatcheds
 
     @sample_inputs
     @pytest.mark.parametrize("device", cpu_and_gpu())
@@ -558,6 +560,36 @@ def assert_samples_from_standard_normal(t):
     assert_samples_from_standard_normal(F.normalize_image_tensor(image, mean, std))
 
 
+class TestClampBoundingBox:
+    @pytest.mark.parametrize(
+        "metadata",
+        [
+            dict(),
+            dict(format=datapoints.BoundingBoxFormat.XYXY),
+            dict(spatial_size=(1, 1)),
+        ],
+    )
+    def test_simple_tensor_insufficient_metadata(self, metadata):
+        simple_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor)
+
+        with pytest.raises(ValueError, match="simple tensor"):
+            F.clamp_bounding_box(simple_tensor, **metadata)
+
+    @pytest.mark.parametrize(
+        "metadata",
+        [
+            dict(format=datapoints.BoundingBoxFormat.XYXY),
+            dict(spatial_size=(1, 1)),
+            dict(format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(1, 1)),
+        ],
+    )
+    def test_datapoint_explicit_metadata(self, metadata):
+        datapoint = next(make_bounding_boxes())
+
+        with pytest.raises(ValueError, match="bounding box datapoint"):
+            F.clamp_bounding_box(datapoint, **metadata)
+
+
 # TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in
 # `prototype_transforms_kernel_infos.py`

torchvision/prototype/transforms/_meta.py

Lines changed: 1 addition & 6 deletions

@@ -51,9 +51,4 @@ class ClampBoundingBoxes(Transform):
     _transformed_types = (datapoints.BoundingBox,)
 
     def _transform(self, inpt: datapoints.BoundingBox, params: Dict[str, Any]) -> datapoints.BoundingBox:
-        # We need to unwrap here to avoid unnecessary `__torch_function__` calls,
-        # since `clamp_bounding_box` does not have a dispatcher function that would do that for us
-        output = F.clamp_bounding_box(
-            inpt.as_subclass(torch.Tensor), format=inpt.format, spatial_size=inpt.spatial_size
-        )
-        return datapoints.BoundingBox.wrap_like(inpt, output)
+        return F.clamp_bounding_box(inpt)  # type: ignore[return-value]
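The transform can pass the datapoint straight through because the hybrid dispatcher now unwraps it, calls the kernel with the datapoint's own metadata, and re-wraps the result via `BoundingBox.wrap_like`. A hedged sketch, assuming `ClampBoundingBoxes` is re-exported at `torchvision.prototype.transforms` and `box` is the datapoint from the first sketch:

    from torchvision.prototype import transforms

    # The transform delegates directly to the hybrid dispatcher; the output
    # is still a `datapoints.BoundingBox` because the dispatcher re-wraps it.
    clamped = transforms.ClampBoundingBoxes()(box)
    assert isinstance(clamped, datapoints.BoundingBox)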

torchvision/prototype/transforms/functional/_meta.py

Lines changed: 25 additions & 5 deletions

@@ -1,4 +1,4 @@
-from typing import List, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import PIL.Image
 import torch
@@ -209,12 +209,9 @@ def convert_format_bounding_box(
     return bounding_box
 
 
-def clamp_bounding_box(
+def _clamp_bounding_box(
     bounding_box: torch.Tensor, format: BoundingBoxFormat, spatial_size: Tuple[int, int]
 ) -> torch.Tensor:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(clamp_bounding_box)
-
     # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
     # BoundingBoxFormat instead of converting back and forth
     xyxy_boxes = convert_format_bounding_box(
@@ -225,6 +222,29 @@
     return convert_format_bounding_box(xyxy_boxes, old_format=BoundingBoxFormat.XYXY, new_format=format, inplace=True)
 
 
+def clamp_bounding_box(
+    inpt: datapoints.InputTypeJIT,
+    format: Optional[BoundingBoxFormat] = None,
+    spatial_size: Optional[Tuple[int, int]] = None,
+) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(clamp_bounding_box)
+
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
+        if format is None or spatial_size is None:
+            raise ValueError("For simple tensor inputs, `format` and `spatial_size` has to be passed.")
+        return _clamp_bounding_box(inpt, format=format, spatial_size=spatial_size)
+    elif isinstance(inpt, datapoints.BoundingBox):
+        if format is not None or spatial_size is not None:
+            raise ValueError("For bounding box datapoint inputs, `format` and `spatial_size` must not be passed.")
+        output = _clamp_bounding_box(inpt, format=inpt.format, spatial_size=inpt.spatial_size)
+        return datapoints.BoundingBox.wrap_like(inpt, output)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead."
+        )
+
+
 def _num_value_bits(dtype: torch.dtype) -> int:
     if dtype == torch.uint8:
         return 8
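To make the two failure modes concrete, a sketch mirroring the `TestClampBoundingBox` tests above, with `box` as in the first sketch:

    # Simple tensor without full metadata: the kernel path refuses to guess.
    # Raises ValueError("For simple tensor inputs, `format` and `spatial_size` has to be passed.")
    F.clamp_bounding_box(box.as_subclass(torch.Tensor), spatial_size=(50, 100))

    # Datapoint with redundant metadata: the dispatcher path rejects overrides.
    # Raises ValueError("For bounding box datapoint inputs, `format` and `spatial_size` must not be passed.")
    F.clamp_bounding_box(box, format=datapoints.BoundingBoxFormat.XYXY)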
