diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 6f10945feaf..3876beea5c4 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -266,7 +266,7 @@ def affine_bounding_box(): @register_kernel_info_from_sample_inputs_fn def affine_segmentation_mask(): - for image, angle, translate, scale, shear in itertools.product( + for mask, angle, translate, scale, shear in itertools.product( make_segmentation_masks(extra_dims=((), (4,))), [-87, 15, 90], # angle [5, -5], # translate @@ -274,7 +274,7 @@ def affine_segmentation_mask(): [0, 12], # shear ): yield SampleInput( - image, + mask, angle=angle, translate=(translate, translate), scale=scale, @@ -285,8 +285,12 @@ def affine_segmentation_mask(): @register_kernel_info_from_sample_inputs_fn def rotate_bounding_box(): for bounding_box, angle, expand, center in itertools.product( - make_bounding_boxes(), [-87, 15, 90], [True, False], [None, [12, 23]] # angle # expand # center + make_bounding_boxes(), [-87, 15, 90], [True, False], [None, [12, 23]] ): + if center is not None and expand: + # Skip warning: The provided center argument has no effect on the result if expand is True + continue + yield SampleInput( bounding_box, format=bounding_box.format, @@ -297,6 +301,26 @@ def rotate_bounding_box(): ) +@register_kernel_info_from_sample_inputs_fn +def rotate_segmentation_mask(): + for mask, angle, expand, center in itertools.product( + make_segmentation_masks(extra_dims=((), (4,))), + [-87, 15, 90], # angle + [True, False], # expand + [None, [12, 23]], # center + ): + if center is not None and expand: + # Skip warning: The provided center argument has no effect on the result if expand is True + continue + + yield SampleInput( + mask, + angle=angle, + expand=expand, + center=center, + ) + + @pytest.mark.parametrize( "kernel", [ @@ -411,8 +435,9 @@ def _compute_expected_bbox(bbox, angle_, translate_, scale_, shear_, center_): center=center, ) - if center 
is None: - center = [s // 2 for s in bboxes_image_size[::-1]] + center_ = center + if center_ is None: + center_ = [s * 0.5 for s in bboxes_image_size[::-1]] if bboxes.ndim < 2: bboxes = [bboxes] @@ -421,7 +446,7 @@ def _compute_expected_bbox(bbox, angle_, translate_, scale_, shear_, center_): for bbox in bboxes: bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size) expected_bboxes.append( - _compute_expected_bbox(bbox, angle, (translate, translate), scale, (shear, shear), center) + _compute_expected_bbox(bbox, angle, (translate, translate), scale, (shear, shear), center_) ) if len(expected_bboxes) > 1: expected_bboxes = torch.stack(expected_bboxes) @@ -510,8 +535,10 @@ def _compute_expected_mask(mask, angle_, translate_, scale_, shear_, center_): shear=(shear, shear), center=center, ) - if center is None: - center = [s // 2 for s in mask.shape[-2:][::-1]] + + center_ = center + if center_ is None: + center_ = [s * 0.5 for s in mask.shape[-2:][::-1]] if mask.ndim < 4: masks = [mask] @@ -520,7 +547,7 @@ def _compute_expected_mask(mask, angle_, translate_, scale_, shear_, center_): expected_masks = [] for mask in masks: - expected_mask = _compute_expected_mask(mask, angle, (translate, translate), scale, (shear, shear), center) + expected_mask = _compute_expected_mask(mask, angle, (translate, translate), scale, (shear, shear), center_) expected_masks.append(expected_mask) if len(expected_masks) > 1: expected_masks = torch.stack(expected_masks) @@ -550,8 +577,7 @@ def test_correctness_affine_segmentation_mask_on_fixed_input(device): @pytest.mark.parametrize("angle", range(-90, 90, 56)) -@pytest.mark.parametrize("expand", [True, False]) -@pytest.mark.parametrize("center", [None, (12, 14)]) +@pytest.mark.parametrize("expand, center", [(True, None), (False, None), (False, (12, 14))]) def test_correctness_rotate_bounding_box(angle, expand, center): def _compute_expected_bbox(bbox, angle_, expand_, center_): affine_matrix = 
_compute_affine_matrix(angle_, [0.0, 0.0], 1.0, [0.0, 0.0], center_) @@ -620,8 +646,9 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_): center=center, ) - if center is None: - center = [s // 2 for s in bboxes_image_size[::-1]] + center_ = center + if center_ is None: + center_ = [s * 0.5 for s in bboxes_image_size[::-1]] if bboxes.ndim < 2: bboxes = [bboxes] @@ -629,7 +656,7 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_): expected_bboxes = [] for bbox in bboxes: bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size) - expected_bboxes.append(_compute_expected_bbox(bbox, -angle, expand, center)) + expected_bboxes.append(_compute_expected_bbox(bbox, -angle, expand, center_)) if len(expected_bboxes) > 1: expected_bboxes = torch.stack(expected_bboxes) else: @@ -638,7 +665,7 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_): @pytest.mark.parametrize("device", cpu_and_gpu()) -@pytest.mark.parametrize("expand", [False]) # expand=True does not match D2, analysis in progress +@pytest.mark.parametrize("expand", [False]) # expand=True does not match D2 def test_correctness_rotate_bounding_box_on_fixed_input(device, expand): # Check transformation against known expected output image_size = (64, 64) @@ -689,3 +716,91 @@ def test_correctness_rotate_bounding_box_on_fixed_input(device, expand): ) torch.testing.assert_close(output_boxes.tolist(), expected_bboxes) + + +@pytest.mark.parametrize("angle", range(-90, 90, 37)) +@pytest.mark.parametrize("expand, center", [(True, None), (False, None), (False, (12, 14))]) +def test_correctness_rotate_segmentation_mask(angle, expand, center): + def _compute_expected_mask(mask, angle_, expand_, center_): + assert mask.ndim == 3 and mask.shape[0] == 1 + image_size = mask.shape[-2:] + affine_matrix = _compute_affine_matrix(angle_, [0.0, 0.0], 1.0, [0.0, 0.0], center_) + inv_affine_matrix = np.linalg.inv(affine_matrix) + + if expand_: + # Pillow implementation on how to 
perform expand: + # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054-L2069 + height, width = image_size + points = np.array( + [ + [0.0, 0.0, 1.0], + [0.0, 1.0 * height, 1.0], + [1.0 * width, 1.0 * height, 1.0], + [1.0 * width, 0.0, 1.0], + ] + ) + new_points = points @ inv_affine_matrix.T + min_vals = np.min(new_points, axis=0)[:2] + max_vals = np.max(new_points, axis=0)[:2] + cmax = np.ceil(np.trunc(max_vals * 1e4) * 1e-4) + cmin = np.floor(np.trunc((min_vals + 1e-8) * 1e4) * 1e-4) + new_width, new_height = (cmax - cmin).astype("int32").tolist() + tr = np.array([-(new_width - width) / 2.0, -(new_height - height) / 2.0, 1.0]) @ inv_affine_matrix.T + + inv_affine_matrix[:2, 2] = tr[:2] + image_size = [new_height, new_width] + + inv_affine_matrix = inv_affine_matrix[:2, :] + expected_mask = torch.zeros(1, *image_size, dtype=mask.dtype) + + for out_y in range(expected_mask.shape[1]): + for out_x in range(expected_mask.shape[2]): + output_pt = np.array([out_x + 0.5, out_y + 0.5, 1.0]) + input_pt = np.floor(np.dot(inv_affine_matrix, output_pt)).astype(np.int32) + in_x, in_y = input_pt[:2] + if 0 <= in_x < mask.shape[2] and 0 <= in_y < mask.shape[1]: + expected_mask[0, out_y, out_x] = mask[0, in_y, in_x] + return expected_mask.to(mask.device) + + for mask in make_segmentation_masks(extra_dims=((), (4,))): + output_mask = F.rotate_segmentation_mask( + mask, + angle=angle, + expand=expand, + center=center, + ) + + center_ = center + if center_ is None: + center_ = [s * 0.5 for s in mask.shape[-2:][::-1]] + + if mask.ndim < 4: + masks = [mask] + else: + masks = [m for m in mask] + + expected_masks = [] + for mask in masks: + expected_mask = _compute_expected_mask(mask, -angle, expand, center_) + expected_masks.append(expected_mask) + if len(expected_masks) > 1: + expected_masks = torch.stack(expected_masks) + else: + expected_masks = expected_masks[0] + torch.testing.assert_close(output_mask, expected_masks) 
+ + +@pytest.mark.parametrize("device", cpu_and_gpu()) +def test_correctness_rotate_segmentation_mask_on_fixed_input(device): + # Check transformation against known expected output and CPU/CUDA devices + + # Create a fixed input segmentation mask with 2 square masks + # in top-left, bottom-left corners + mask = torch.zeros(1, 32, 32, dtype=torch.long, device=device) + mask[0, 2:10, 2:10] = 1 + mask[0, 32 - 9 : 32 - 3, 3:9] = 2 + + # Rotate 90 degrees + expected_mask = torch.rot90(mask, k=1, dims=(-2, -1)) + out_mask = F.rotate_segmentation_mask(mask, 90, expand=False) + torch.testing.assert_close(out_mask, expected_mask) diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index 51bf73a18f7..e8f25342a18 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -56,6 +56,7 @@ rotate_bounding_box, rotate_image_tensor, rotate_image_pil, + rotate_segmentation_mask, pad_image_tensor, pad_image_pil, pad_bounding_box, diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 71882f06270..7629766c0e2 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -324,7 +324,7 @@ def rotate_image_tensor( center_f = [0.0, 0.0] if center is not None: if expand: - warnings.warn("The provided center argument is ignored if expand is True") + warnings.warn("The provided center argument has no effect on the result if expand is True") else: _, height, width = get_dimensions_image_tensor(img) # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. 
@@ -345,7 +345,7 @@ def rotate_image_pil( center: Optional[List[float]] = None, ) -> PIL.Image.Image: if center is not None and expand: - warnings.warn("The provided center argument is ignored if expand is True") + warnings.warn("The provided center argument has no effect on the result if expand is True") center = None return _FP.rotate( @@ -361,6 +361,10 @@ def rotate_bounding_box( expand: bool = False, center: Optional[List[float]] = None, ) -> torch.Tensor: + if center is not None and expand: + warnings.warn("The provided center argument has no effect on the result if expand is True") + center = None + original_shape = bounding_box.shape bounding_box = convert_bounding_box_format( bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY @@ -373,6 +377,21 @@ def rotate_bounding_box( ).view(original_shape) +def rotate_segmentation_mask( + img: torch.Tensor, + angle: float, + expand: bool = False, + center: Optional[List[float]] = None, +) -> torch.Tensor: + return rotate_image_tensor( + img, + angle=angle, + expand=expand, + interpolation=InterpolationMode.NEAREST, + center=center, + ) + + pad_image_tensor = _FT.pad pad_image_pil = _FP.pad