26
26
make_video_loader ,
27
27
make_video_loaders ,
28
28
mark_framework_limitation ,
29
- TensorLoader ,
30
29
TestMark ,
31
30
)
32
31
from torch .utils ._pytree import tree_map
@@ -660,7 +659,8 @@ def sample_inputs_affine_video():
660
659
def sample_inputs_convert_format_bounding_box ():
661
660
formats = list (datapoints .BoundingBoxFormat )
662
661
for bounding_box_loader , new_format in itertools .product (make_bounding_box_loaders (formats = formats ), formats ):
663
- yield ArgsKwargs (bounding_box_loader , old_format = bounding_box_loader .format , new_format = new_format )
662
+ yield ArgsKwargs (bounding_box_loader , new_format = new_format )
663
+ yield ArgsKwargs (bounding_box_loader .unwrap (), old_format = bounding_box_loader .format , new_format = new_format )
664
664
665
665
666
666
def reference_convert_format_bounding_box (bounding_box , old_format , new_format ):
@@ -671,8 +671,14 @@ def reference_convert_format_bounding_box(bounding_box, old_format, new_format):
671
671
672
672
def reference_inputs_convert_format_bounding_box ():
673
673
for args_kwargs in sample_inputs_convert_format_bounding_box ():
674
- if len (args_kwargs .args [0 ].shape ) == 2 :
675
- yield args_kwargs
674
+ if len (args_kwargs .args [0 ].shape ) != 2 :
675
+ continue
676
+
677
+ (loader , * other_args ), kwargs = args_kwargs
678
+ if isinstance (loader , BoundingBoxLoader ):
679
+ kwargs ["old_format" ] = loader .format
680
+ loader = loader .unwrap ()
681
+ yield ArgsKwargs (loader , * other_args , ** kwargs )
676
682
677
683
678
684
KERNEL_INFOS .append (
@@ -682,6 +688,18 @@ def reference_inputs_convert_format_bounding_box():
682
688
reference_fn = reference_convert_format_bounding_box ,
683
689
reference_inputs_fn = reference_inputs_convert_format_bounding_box ,
684
690
logs_usage = True ,
691
+ test_marks = [
692
+ mark_framework_limitation (
693
+ ("TestKernels" , "test_scripted_vs_eager" ),
694
+ reason = (
695
+ "The function is hybrid kernel / dispatcher. JIT unwraps a `datapoints.BoundingBox` into a "
696
+ "`torch.Tensor`, but then the kernel (rightfully) complains that neither `format` nor "
697
+ "`spatial_size` was passed"
698
+ ),
699
+ condition = lambda arg_kwargs : isinstance (arg_kwargs .args [0 ], BoundingBoxLoader )
700
+ and arg_kwargs .kwargs .get ("old_format" ) is None ,
701
+ )
702
+ ],
685
703
),
686
704
)
687
705
@@ -2014,13 +2032,10 @@ def sample_inputs_clamp_bounding_box():
2014
2032
for bounding_box_loader in make_bounding_box_loaders ():
2015
2033
yield ArgsKwargs (bounding_box_loader )
2016
2034
2017
- simple_tensor_loader = TensorLoader (
2018
- fn = lambda shape , dtype , device : bounding_box_loader .fn (shape , dtype , device ).as_subclass (torch .Tensor ),
2019
- shape = bounding_box_loader .shape ,
2020
- dtype = bounding_box_loader .dtype ,
2021
- )
2022
2035
yield ArgsKwargs (
2023
- simple_tensor_loader , format = bounding_box_loader .format , spatial_size = bounding_box_loader .spatial_size
2036
+ bounding_box_loader .unwrap (),
2037
+ format = bounding_box_loader .format ,
2038
+ spatial_size = bounding_box_loader .spatial_size ,
2024
2039
)
2025
2040
2026
2041
0 commit comments