From c700e9af15179b2c3a21418832d205b9cc979582 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Mon, 13 Feb 2023 15:18:45 +0100
Subject: [PATCH 1/3] only use plain tensors in kernel tests

---
 test/prototype_common_utils.py               |   7 -
 test/prototype_transforms_kernel_infos.py    |  24 ++--
 test/test_prototype_transforms_functional.py | 139 ++++++++-----------
 3 files changed, 69 insertions(+), 101 deletions(-)

diff --git a/test/prototype_common_utils.py b/test/prototype_common_utils.py
index c53fecaef7e..f2ae8d2b9e5 100644
--- a/test/prototype_common_utils.py
+++ b/test/prototype_common_utils.py
@@ -237,13 +237,6 @@ class TensorLoader:
     def load(self, device):
         return self.fn(self.shape, self.dtype, device)
 
-    def unwrap(self):
-        return TensorLoader(
-            fn=lambda shape, dtype, device: self.fn(shape, dtype, device).as_subclass(torch.Tensor),
-            shape=self.shape,
-            dtype=self.dtype,
-        )
-
 
 @dataclasses.dataclass
 class ImageLoader(TensorLoader):
diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py
index eddf76440c5..3111582706d 100644
--- a/test/prototype_transforms_kernel_infos.py
+++ b/test/prototype_transforms_kernel_infos.py
@@ -337,7 +337,6 @@ def sample_inputs_resize_video():
 
 
 def reference_resize_bounding_box(bounding_box, *, spatial_size, size, max_size=None):
-
     old_height, old_width = spatial_size
     new_height, new_width = F._geometry._compute_resized_output_size(spatial_size, size=size, max_size=max_size)
 
@@ -350,13 +349,15 @@ def reference_resize_bounding_box(bounding_box, *, spatial_size, size, max_size=
     )
 
     expected_bboxes = reference_affine_bounding_box_helper(
-        bounding_box, format=bounding_box.format, affine_matrix=affine_matrix
+        bounding_box, format=datapoints.BoundingBoxFormat.XYXY, affine_matrix=affine_matrix
     )
     return expected_bboxes, (new_height, new_width)
 
 
 def reference_inputs_resize_bounding_box():
-    for bounding_box_loader in make_bounding_box_loaders(extra_dims=((), (4,))):
+    for bounding_box_loader in make_bounding_box_loaders(
+        formats=[datapoints.BoundingBoxFormat.XYXY], extra_dims=((), (4,))
+    ):
         for size in _get_resize_sizes(bounding_box_loader.spatial_size):
             yield ArgsKwargs(bounding_box_loader, size=size, spatial_size=bounding_box_loader.spatial_size)
 
@@ -668,8 +669,7 @@ def sample_inputs_affine_video():
 def sample_inputs_convert_format_bounding_box():
     formats = list(datapoints.BoundingBoxFormat)
     for bounding_box_loader, new_format in itertools.product(make_bounding_box_loaders(formats=formats), formats):
-        yield ArgsKwargs(bounding_box_loader, new_format=new_format)
-        yield ArgsKwargs(bounding_box_loader.unwrap(), old_format=bounding_box_loader.format, new_format=new_format)
+        yield ArgsKwargs(bounding_box_loader, old_format=bounding_box_loader.format, new_format=new_format)
 
 
 def reference_convert_format_bounding_box(bounding_box, old_format, new_format):
@@ -680,14 +680,8 @@ def reference_convert_format_bounding_box(bounding_box, old_format, new_format):
 
 def reference_inputs_convert_format_bounding_box():
     for args_kwargs in sample_inputs_convert_format_bounding_box():
-        if len(args_kwargs.args[0].shape) != 2:
-            continue
-
-        (loader, *other_args), kwargs = args_kwargs
-        if isinstance(loader, BoundingBoxLoader):
-            kwargs["old_format"] = loader.format
-            loader = loader.unwrap()
-        yield ArgsKwargs(loader, *other_args, **kwargs)
+        if len(args_kwargs.args[0].shape) == 2:
+            yield args_kwargs
 
 
 KERNEL_INFOS.append(
@@ -2049,10 +2043,8 @@ def sample_inputs_adjust_saturation_video():
 
 def sample_inputs_clamp_bounding_box():
     for bounding_box_loader in make_bounding_box_loaders():
-        yield ArgsKwargs(bounding_box_loader)
-        yield ArgsKwargs(
-            bounding_box_loader.unwrap(),
+        yield ArgsKwargs(
+            bounding_box_loader,
             format=bounding_box_loader.format,
             spatial_size=bounding_box_loader.spatial_size,
         )
diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py
index 539cbce7787..d371c2343e1 100644
--- a/test/test_prototype_transforms_functional.py
+++ b/test/test_prototype_transforms_functional.py
@@ -121,8 +121,8 @@ class TestKernels:
     def test_logging(self, spy_on, info, args_kwargs, device):
         spy = spy_on(torch._C._log_api_usage_once)
 
-        args, kwargs = args_kwargs.load(device)
-        info.kernel(*args, **kwargs)
+        (input, *other_args), kwargs = args_kwargs.load(device)
+        info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)
 
         spy.assert_any_call(f"{info.kernel.__module__}.{info.id}")
 
@@ -134,6 +134,7 @@ def test_scripted_vs_eager(self, test_id, info, args_kwargs, device):
         kernel_scripted = script(kernel_eager)
 
         (input, *other_args), kwargs = args_kwargs.load(device)
+        input = input.as_subclass(torch.Tensor)
 
         actual = kernel_scripted(input, *other_args, **kwargs)
         expected = kernel_eager(input, *other_args, **kwargs)
@@ -155,14 +156,12 @@ def _unbatch(self, batch, *, data_dims):
         if batched_tensor.ndim == data_dims:
             return batch
 
-        unbatcheds = []
-        for unbatched in (
-            batched_tensor.unbind(0) if not metadata else [(t, *metadata) for t in batched_tensor.unbind(0)]
-        ):
-            if isinstance(batch, datapoints._datapoint.Datapoint):
-                unbatched = type(batch).wrap_like(batch, unbatched)
-            unbatcheds.append(self._unbatch(unbatched, data_dims=data_dims))
-        return unbatcheds
+        return [
+            self._unbatch(unbatched, data_dims=data_dims)
+            for unbatched in (
+                batched_tensor.unbind(0) if not metadata else [(t, *metadata) for t in batched_tensor.unbind(0)]
+            )
+        ]
 
     @sample_inputs
     @pytest.mark.parametrize("device", cpu_and_gpu())
@@ -195,6 +194,7 @@ def test_batched_vs_single(self, test_id, info, args_kwargs, device):
         elif not all(batched_input.shape[:-data_dims]):
             pytest.skip("Input has a degenerate batch shape.")
 
+        batched_input = batched_input.as_subclass(torch.Tensor)
         batched_output = info.kernel(batched_input, *other_args, **kwargs)
         actual = self._unbatch(batched_output, data_dims=data_dims)
 
@@ -212,6 +212,7 @@ def test_batched_vs_single(self, test_id, info, args_kwargs, device):
     @pytest.mark.parametrize("device", cpu_and_gpu())
     def test_no_inplace(self, info, args_kwargs, device):
         (input, *other_args), kwargs = args_kwargs.load(device)
+        input = input.as_subclass(torch.Tensor)
 
         if input.numel() == 0:
             pytest.skip("The input has a degenerate shape.")
@@ -225,6 +226,7 @@ def test_no_inplace(self, info, args_kwargs, device):
     @needs_cuda
     def test_cuda_vs_cpu(self, test_id, info, args_kwargs):
         (input_cpu, *other_args), kwargs = args_kwargs.load("cpu")
+        input_cpu = input_cpu.as_subclass(torch.Tensor)
         input_cuda = input_cpu.to("cuda")
 
         output_cpu = info.kernel(input_cpu, *other_args, **kwargs)
@@ -242,6 +244,7 @@ def test_cuda_vs_cpu(self, test_id, info, args_kwargs):
     @pytest.mark.parametrize("device", cpu_and_gpu())
     def test_dtype_and_device_consistency(self, info, args_kwargs, device):
         (input, *other_args), kwargs = args_kwargs.load(device)
+        input = input.as_subclass(torch.Tensor)
 
         output = info.kernel(input, *other_args, **kwargs)
         # Most kernels just return a tensor, but some also return some additional metadata
@@ -254,6 +257,7 @@ def test_dtype_and_device_consistency(self, info, args_kwargs, device):
     @reference_inputs
     def test_against_reference(self, test_id, info, args_kwargs):
         (input, *other_args), kwargs = args_kwargs.load("cpu")
+        input = input.as_subclass(torch.Tensor)
 
         actual = info.kernel(input, *other_args, **kwargs)
         expected = info.reference_fn(input, *other_args, **kwargs)
@@ -271,6 +275,7 @@ def test_against_reference(self, test_id, info, args_kwargs):
     )
     def test_float32_vs_uint8(self, test_id, info, args_kwargs):
         (input, *other_args), kwargs = args_kwargs.load("cpu")
+        input = input.as_subclass(torch.Tensor)
 
         if input.dtype != torch.uint8:
             pytest.skip(f"Input dtype is {input.dtype}.")
@@ -647,21 +652,15 @@ def _compute_affine_matrix(angle_, translate_, scale_, shear_, center_):
 @pytest.mark.parametrize("device", cpu_and_gpu())
 def test_correctness_affine_bounding_box_on_fixed_input(device):
     # Check transformation against known expected output
+    format = datapoints.BoundingBoxFormat.XYXY
     spatial_size = (64, 64)
-    # xyxy format
     in_boxes = [
         [20, 25, 35, 45],
         [50, 5, 70, 22],
         [spatial_size[1] // 2 - 10, spatial_size[0] // 2 - 10, spatial_size[1] // 2 + 10, spatial_size[0] // 2 + 10],
         [1, 1, 5, 5],
     ]
-    in_boxes = datapoints.BoundingBox(
-        in_boxes,
-        format=datapoints.BoundingBoxFormat.XYXY,
-        spatial_size=spatial_size,
-        dtype=torch.float64,
-        device=device,
-    )
+    in_boxes = torch.tensor(in_boxes, dtype=torch.float64, device=device)
     # Tested parameters
     angle = 63
     scale = 0.89
@@ -686,11 +685,11 @@ def test_correctness_affine_bounding_box_on_fixed_input(device):
 
     output_boxes = F.affine_bounding_box(
         in_boxes,
-        in_boxes.format,
-        in_boxes.spatial_size,
-        angle,
-        (dx * spatial_size[1], dy * spatial_size[0]),
-        scale,
+        format=format,
+        spatial_size=spatial_size,
+        angle=angle,
+        translate=(dx * spatial_size[1], dy * spatial_size[0]),
+        scale=scale,
         shear=(0, 0),
     )
 
@@ -725,9 +724,7 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_):
         affine_matrix = affine_matrix[:2, :]
 
         height, width = bbox.spatial_size
-        bbox_xyxy = convert_format_bounding_box(
-            bbox, old_format=bbox.format, new_format=datapoints.BoundingBoxFormat.XYXY
-        )
+        bbox_xyxy = convert_format_bounding_box(bbox, new_format=datapoints.BoundingBoxFormat.XYXY)
         points = np.array(
             [
                 [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
@@ -766,10 +763,7 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_):
             dtype=bbox.dtype,
             device=bbox.device,
         )
-        return (
-            convert_format_bounding_box(out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=bbox.format),
-            (height, width),
-        )
+        return convert_format_bounding_box(out_bbox, new_format=bbox.format), (height, width)
 
     spatial_size = (32, 38)
 
@@ -778,8 +772,8 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_):
         bboxes_spatial_size = bboxes.spatial_size
 
         output_bboxes, output_spatial_size = F.rotate_bounding_box(
-            bboxes,
-            bboxes_format,
+            bboxes.as_subclass(torch.Tensor),
+            format=bboxes_format,
             spatial_size=bboxes_spatial_size,
             angle=angle,
             expand=expand,
@@ -810,6 +804,7 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_):
 @pytest.mark.parametrize("expand", [False])  # expand=True does not match D2
 def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
     # Check transformation against known expected output
+    format = datapoints.BoundingBoxFormat.XYXY
     spatial_size = (64, 64)
     # xyxy format
     in_boxes = [
@@ -818,13 +813,7 @@ def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
         [2, 2, 6, 6],
         [spatial_size[1] - 6, spatial_size[0] - 6, spatial_size[1] - 2, spatial_size[0] - 2],
         [spatial_size[1] // 2 - 10, spatial_size[0] // 2 - 10, spatial_size[1] // 2 + 10, spatial_size[0] // 2 + 10],
     ]
-    in_boxes = datapoints.BoundingBox(
-        in_boxes,
-        format=datapoints.BoundingBoxFormat.XYXY,
-        spatial_size=spatial_size,
-        dtype=torch.float64,
-        device=device,
-    )
+    in_boxes = torch.tensor(in_boxes, dtype=torch.float64, device=device)
     # Tested parameters
     angle = 45
     center = None if expand else [12, 23]
@@ -854,9 +843,9 @@ def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
 
     output_boxes, _ = F.rotate_bounding_box(
         in_boxes,
-        in_boxes.format,
-        in_boxes.spatial_size,
-        angle,
+        format=format,
+        spatial_size=spatial_size,
+        angle=angle,
         expand=expand,
         center=center,
     )
@@ -906,16 +895,14 @@ def test_correctness_crop_bounding_box(device, format, top, left, height, width,
     #     out_box = denormalize_bbox(n_out_box, height, width)
     #     expected_bboxes.append(out_box)
 
-    size = (64, 76)
-    # xyxy format
+    format = datapoints.BoundingBoxFormat.XYXY
+    spatial_size = (64, 76)
     in_boxes = [
         [10.0, 15.0, 25.0, 35.0],
         [50.0, 5.0, 70.0, 22.0],
         [45.0, 46.0, 56.0, 62.0],
     ]
-    in_boxes = datapoints.BoundingBox(
-        in_boxes, format=datapoints.BoundingBoxFormat.XYXY, spatial_size=size, device=device
-    )
+    in_boxes = torch.tensor(in_boxes, device=device)
     if format != datapoints.BoundingBoxFormat.XYXY:
         in_boxes = convert_format_bounding_box(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)
 
@@ -924,15 +911,15 @@ def test_correctness_crop_bounding_box(device, format, top, left, height, width,
         format,
         top,
         left,
-        size[0],
-        size[1],
+        spatial_size[0],
+        spatial_size[1],
     )
 
     if format != datapoints.BoundingBoxFormat.XYXY:
         output_boxes = convert_format_bounding_box(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)
 
     torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)
-    torch.testing.assert_close(output_spatial_size, size)
+    torch.testing.assert_close(output_spatial_size, spatial_size)
 
 
 @pytest.mark.parametrize("device", cpu_and_gpu())
@@ -980,8 +967,8 @@ def _compute_expected_bbox(bbox, top_, left_, height_, width_, size_):
         bbox[3] = (bbox[3] - top_) * size_[0] / height_
         return bbox
 
+    format = datapoints.BoundingBoxFormat.XYXY
     spatial_size = (100, 100)
-    # xyxy format
     in_boxes = [
         [10.0, 10.0, 20.0, 20.0],
         [5.0, 10.0, 15.0, 20.0],
@@ -1024,22 +1011,22 @@ def test_correctness_pad_bounding_box(device, padding):
     def _compute_expected_bbox(bbox, padding_):
         pad_left, pad_up, _, _ = _parse_padding(padding_)
 
-        bbox_format = bbox.format
-        bbox_dtype = bbox.dtype
+        dtype = bbox.dtype
+        format = bbox.format
         bbox = (
             bbox.clone()
-            if bbox_format == datapoints.BoundingBoxFormat.XYXY
-            else convert_format_bounding_box(bbox, bbox_format, datapoints.BoundingBoxFormat.XYXY)
+            if format == datapoints.BoundingBoxFormat.XYXY
+            else convert_format_bounding_box(bbox, new_format=datapoints.BoundingBoxFormat.XYXY)
         )
 
         bbox[0::2] += pad_left
         bbox[1::2] += pad_up
 
-        bbox = convert_format_bounding_box(bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=bbox_format)
-        if bbox.dtype != bbox_dtype:
+        bbox = convert_format_bounding_box(bbox, new_format=format)
+        if bbox.dtype != dtype:
             # Temporary cast to original dtype
             # e.g. float32 -> int
-            bbox = bbox.to(bbox_dtype)
+            bbox = bbox.to(dtype)
         return bbox
 
     def _compute_expected_spatial_size(bbox, padding_):
@@ -1108,9 +1095,7 @@ def _compute_expected_bbox(bbox, pcoeffs_):
             ]
         )
 
-        bbox_xyxy = convert_format_bounding_box(
-            bbox, old_format=bbox.format, new_format=datapoints.BoundingBoxFormat.XYXY
-        )
+        bbox_xyxy = convert_format_bounding_box(bbox, new_format=datapoints.BoundingBoxFormat.XYXY)
         points = np.array(
             [
                 [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
@@ -1122,22 +1107,22 @@ def _compute_expected_bbox(bbox, pcoeffs_):
         numer = np.matmul(points, m1.T)
         denom = np.matmul(points, m2.T)
         transformed_points = numer / denom
-        out_bbox = [
-            np.min(transformed_points[:, 0]),
-            np.min(transformed_points[:, 1]),
-            np.max(transformed_points[:, 0]),
-            np.max(transformed_points[:, 1]),
-        ]
+        out_bbox = np.array(
+            [
+                np.min(transformed_points[:, 0]),
+                np.min(transformed_points[:, 1]),
+                np.max(transformed_points[:, 0]),
+                np.max(transformed_points[:, 1]),
+            ]
+        )
         out_bbox = datapoints.BoundingBox(
-            np.array(out_bbox),
+            out_bbox,
             format=datapoints.BoundingBoxFormat.XYXY,
             spatial_size=bbox.spatial_size,
             dtype=bbox.dtype,
             device=bbox.device,
         )
-        return convert_format_bounding_box(
-            out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=bbox.format
-        )
+        return convert_format_bounding_box(out_bbox, new_format=bbox.format)
 
     spatial_size = (32, 38)
 
@@ -1146,14 +1131,12 @@ def _compute_expected_bbox(bbox, pcoeffs_):
 
     for bboxes in make_bounding_boxes(spatial_size=spatial_size, extra_dims=((4,),)):
         bboxes = bboxes.to(device)
-        bboxes_format = bboxes.format
-        bboxes_spatial_size = bboxes.spatial_size
 
         output_bboxes = F.perspective_bounding_box(
-            bboxes,
-            bboxes_format,
-            None,
-            None,
+            bboxes.as_subclass(torch.Tensor),
+            format=bboxes.format,
+            startpoints=None,
+            endpoints=None,
             coefficients=pcoeffs,
         )
 
@@ -1162,7 +1145,7 @@ def _compute_expected_bbox(bbox, pcoeffs_):
 
         expected_bboxes = []
         for bbox in bboxes:
-            bbox = datapoints.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size)
+            bbox = datapoints.BoundingBox(bbox, format=bboxes.format, spatial_size=bboxes.spatial_size)
             expected_bboxes.append(_compute_expected_bbox(bbox, inv_pcoeffs))
         if len(expected_bboxes) > 1:
             expected_bboxes = torch.stack(expected_bboxes)

From 61383293b25bc430208cdd300f3c9e5faac996ea Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Mon, 13 Feb 2023 15:43:53 +0100
Subject: [PATCH 2/3] remove skips hybrids

---
 test/prototype_transforms_kernel_infos.py | 26 -----------------------
 1 file changed, 26 deletions(-)

diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py
index 3111582706d..e65fb2d6a6a 100644
--- a/test/prototype_transforms_kernel_infos.py
+++ b/test/prototype_transforms_kernel_infos.py
@@ -12,7 +12,6 @@
 from datasets_utils import combinations_grid
 from prototype_common_utils import (
     ArgsKwargs,
-    BoundingBoxLoader,
     get_num_channels,
     ImageLoader,
     InfoBase,
@@ -691,18 +690,6 @@ def reference_inputs_convert_format_bounding_box():
         reference_fn=reference_convert_format_bounding_box,
         reference_inputs_fn=reference_inputs_convert_format_bounding_box,
         logs_usage=True,
-        test_marks=[
-            mark_framework_limitation(
-                ("TestKernels", "test_scripted_vs_eager"),
-                reason=(
-                    "The function is hybrid kernel / dispatcher. JIT unwraps a `datapoints.BoundingBox` into a "
-                    "`torch.Tensor`, but then the kernel (rightfully) complains that neither `format` nor "
-                    "`spatial_size` was passed"
-                ),
-                condition=lambda arg_kwargs: isinstance(arg_kwargs.args[0], BoundingBoxLoader)
-                and arg_kwargs.kwargs.get("old_format") is None,
-            )
-        ],
     ),
 )
@@ -2055,19 +2042,6 @@ def sample_inputs_clamp_bounding_box():
         F.clamp_bounding_box,
         sample_inputs_fn=sample_inputs_clamp_bounding_box,
         logs_usage=True,
-        test_marks=[
-            mark_framework_limitation(
-                ("TestKernels", "test_scripted_vs_eager"),
-                reason=(
-                    "The function is hybrid kernel / dispatcher. JIT unwraps a `datapoints.BoundingBox` into a "
-                    "`torch.Tensor`, but then the kernel (rightfully) complains that neither `format` nor "
-                    "`spatial_size` was passed"
-                ),
-                condition=lambda arg_kwargs: isinstance(arg_kwargs.args[0], BoundingBoxLoader)
-                and arg_kwargs.kwargs.get("format") is None
-                and arg_kwargs.kwargs.get("spatial_size") is None,
-            )
-        ],
     )
 )

From 99dd341a83e6d970d3e473586f3c871fde733d85 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Mon, 13 Feb 2023 15:44:16 +0100
Subject: [PATCH 3/3] add DispatcherInfos for hybrids

---
 test/prototype_transforms_dispatcher_infos.py | 34 +++++++++++++++----
 test/test_prototype_transforms_functional.py  |  1 -
 2 files changed, 28 insertions(+), 7 deletions(-)

diff --git a/test/prototype_transforms_dispatcher_infos.py b/test/prototype_transforms_dispatcher_infos.py
index f6b8786570c..8fe5333aa51 100644
--- a/test/prototype_transforms_dispatcher_infos.py
+++ b/test/prototype_transforms_dispatcher_infos.py
@@ -64,13 +64,21 @@ def sample_inputs(self, *datapoint_types, filter_metadata=True):
 
         if not filter_metadata:
             yield from sample_inputs
-        else:
-            for args_kwargs in sample_inputs:
-                for attribute in datapoint_type.__annotations__.keys():
-                    if attribute in args_kwargs.kwargs:
-                        del args_kwargs.kwargs[attribute]
+            return
 
-                yield args_kwargs
+        import itertools
+
+        for args_kwargs in sample_inputs:
+            for name in itertools.chain(
+                datapoint_type.__annotations__.keys(),
+                # FIXME: this seems ok for conversion dispatchers, but we should probably handle this on a
+                # per-dispatcher level. However, so far there is no option for that.
+                (f"old_{name}" for name in datapoint_type.__annotations__.keys()),
+            ):
+                if name in args_kwargs.kwargs:
+                    del args_kwargs.kwargs[name]
+
+            yield args_kwargs
 
 
 def xfail_jit(reason, *, condition=None):
@@ -458,4 +466,18 @@ def fill_sequence_needs_broadcast(args_kwargs):
             skip_dispatch_datapoint,
         ],
     ),
+    DispatcherInfo(
+        F.clamp_bounding_box,
+        kernels={datapoints.BoundingBox: F.clamp_bounding_box},
+        test_marks=[
+            skip_dispatch_datapoint,
+        ],
+    ),
+    DispatcherInfo(
+        F.convert_format_bounding_box,
+        kernels={datapoints.BoundingBox: F.convert_format_bounding_box},
+        test_marks=[
+            skip_dispatch_datapoint,
+        ],
+    ),
 ]
diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py
index d371c2343e1..1650d03de73 100644
--- a/test/test_prototype_transforms_functional.py
+++ b/test/test_prototype_transforms_functional.py
@@ -346,7 +346,6 @@ def test_scripted_smoke(self, info, args_kwargs, device):
     @pytest.mark.parametrize(
         "dispatcher",
         [
-            F.clamp_bounding_box,
             F.get_dimensions,
             F.get_image_num_channels,
             F.get_image_size,