@@ -155,12 +155,14 @@ def _unbatch(self, batch, *, data_dims):
155
155
if batched_tensor .ndim == data_dims :
156
156
return batch
157
157
158
- return [
159
- self ._unbatch (unbatched , data_dims = data_dims )
160
- for unbatched in (
161
- batched_tensor .unbind (0 ) if not metadata else [(t , * metadata ) for t in batched_tensor .unbind (0 )]
162
- )
163
- ]
158
+ unbatcheds = []
159
+ for unbatched in (
160
+ batched_tensor .unbind (0 ) if not metadata else [(t , * metadata ) for t in batched_tensor .unbind (0 )]
161
+ ):
162
+ if isinstance (batch , datapoints ._datapoint .Datapoint ):
163
+ unbatched = type (batch ).wrap_like (batch , unbatched )
164
+ unbatcheds .append (self ._unbatch (unbatched , data_dims = data_dims ))
165
+ return unbatcheds
164
166
165
167
@sample_inputs
166
168
@pytest .mark .parametrize ("device" , cpu_and_gpu ())
@@ -558,6 +560,36 @@ def assert_samples_from_standard_normal(t):
558
560
assert_samples_from_standard_normal (F .normalize_image_tensor (image , mean , std ))
559
561
560
562
563
+ class TestClampBoundingBox :
564
+ @pytest .mark .parametrize (
565
+ "metadata" ,
566
+ [
567
+ dict (),
568
+ dict (format = datapoints .BoundingBoxFormat .XYXY ),
569
+ dict (spatial_size = (1 , 1 )),
570
+ ],
571
+ )
572
+ def test_simple_tensor_insufficient_metadata (self , metadata ):
573
+ simple_tensor = next (make_bounding_boxes ()).as_subclass (torch .Tensor )
574
+
575
+ with pytest .raises (ValueError , match = "simple tensor" ):
576
+ F .clamp_bounding_box (simple_tensor , ** metadata )
577
+
578
+ @pytest .mark .parametrize (
579
+ "metadata" ,
580
+ [
581
+ dict (format = datapoints .BoundingBoxFormat .XYXY ),
582
+ dict (spatial_size = (1 , 1 )),
583
+ dict (format = datapoints .BoundingBoxFormat .XYXY , spatial_size = (1 , 1 )),
584
+ ],
585
+ )
586
+ def test_datapoint_explicit_metadata (self , metadata ):
587
+ datapoint = next (make_bounding_boxes ())
588
+
589
+ with pytest .raises (ValueError , match = "bounding box datapoint" ):
590
+ F .clamp_bounding_box (datapoint , ** metadata )
591
+
592
+
561
593
# TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in
562
594
# `prototype_transforms_kernel_infos.py`
563
595
0 commit comments