diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py
index eb90508fa5c..aab01904026 100644
--- a/test/prototype_transforms_kernel_infos.py
+++ b/test/prototype_transforms_kernel_infos.py
@@ -461,9 +461,7 @@ def transform(bbox):
             ],
             dtype=bbox.dtype,
         )
-        return F.convert_format_bounding_box(
-            out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
-        )
+        return F.convert_format_bounding_box(out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=format)
 
     if bounding_box.ndim < 2:
         bounding_box = [bounding_box]
@@ -556,17 +554,12 @@ def sample_inputs_affine_video():
 
 
 def sample_inputs_convert_format_bounding_box():
-    formats = set(features.BoundingBoxFormat)
-    for bounding_box_loader in make_bounding_box_loaders(formats=formats):
-        old_format = bounding_box_loader.format
-        for params in combinations_grid(new_format=formats - {old_format}, copy=(True, False)):
-            yield ArgsKwargs(bounding_box_loader, old_format=old_format, **params)
-
+    formats = list(features.BoundingBoxFormat)
+    for bounding_box_loader, new_format in itertools.product(make_bounding_box_loaders(formats=formats), formats):
+        yield ArgsKwargs(bounding_box_loader, old_format=bounding_box_loader.format, new_format=new_format)
 
-def reference_convert_format_bounding_box(bounding_box, old_format, new_format, copy):
-    if not copy:
-        raise pytest.UsageError("Reference for `convert_format_bounding_box` only supports `copy=True`")
 
+def reference_convert_format_bounding_box(bounding_box, old_format, new_format):
     return torchvision.ops.box_convert(
         bounding_box, in_fmt=old_format.kernel_name.lower(), out_fmt=new_format.kernel_name.lower()
     )
@@ -574,8 +567,7 @@ def reference_convert_format_bounding_box(bounding_box, old_format, new_format,
 
 def reference_inputs_convert_format_bounding_box():
     for args_kwargs in sample_inputs_convert_color_space_image_tensor():
-        (image_loader, *other_args), kwargs = args_kwargs
-        if len(image_loader.shape) == 2 and kwargs.setdefault("copy", True):
+        if len(args_kwargs.args[0].shape) == 2:
             yield args_kwargs
 
 
@@ -600,11 +592,11 @@ def sample_inputs_convert_color_space_image_tensor():
         for image_loader in make_image_loaders(
             sizes=["random"], color_spaces=[color_space], dtypes=[torch.float32], constant_alpha=True
         ):
-            yield ArgsKwargs(image_loader, old_color_space=color_space, new_color_space=color_space, copy=False)
+            yield ArgsKwargs(image_loader, old_color_space=color_space, new_color_space=color_space)
 
 
 @pil_reference_wrapper
-def reference_convert_color_space_image_tensor(image_pil, old_color_space, new_color_space, copy=True):
+def reference_convert_color_space_image_tensor(image_pil, old_color_space, new_color_space):
     color_space_pil = features.ColorSpace.from_pil_mode(image_pil.mode)
     if color_space_pil != old_color_space:
         raise pytest.UsageError(
@@ -612,7 +604,7 @@ def reference_convert_color_space_image_tensor(image_pil, old_color_space, new_c
             f"from {old_color_space} to {color_space_pil}"
         )
 
-    return F.convert_color_space_image_pil(image_pil, color_space=new_color_space, copy=copy)
+    return F.convert_color_space_image_pil(image_pil, color_space=new_color_space)
 
 
 def reference_inputs_convert_color_space_image_tensor():
diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py
index 3423006e2eb..dc867f8ffa4 100644
--- a/test/test_prototype_transforms_functional.py
+++ b/test/test_prototype_transforms_functional.py
@@ -478,9 +478,7 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_):
             device=bbox.device,
         )
         return (
-            convert_format_bounding_box(
-                out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format, copy=False
-            ),
+            convert_format_bounding_box(out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format),
             (height, width),
         )
 
@@ -733,14 +731,16 @@ def _compute_expected_bbox(bbox, padding_):
 
         bbox_format = bbox.format
        bbox_dtype = bbox.dtype
-        bbox = convert_format_bounding_box(bbox, old_format=bbox_format, new_format=features.BoundingBoxFormat.XYXY)
+        bbox = (
+            bbox.clone()
+            if bbox_format == features.BoundingBoxFormat.XYXY
+            else convert_format_bounding_box(bbox, bbox_format, features.BoundingBoxFormat.XYXY)
+        )
 
         bbox[0::2] += pad_left
         bbox[1::2] += pad_up
 
-        bbox = convert_format_bounding_box(
-            bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format, copy=False
-        )
+        bbox = convert_format_bounding_box(bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format)
         if bbox.dtype != bbox_dtype:
             # Temporary cast to original dtype
             # e.g. float32 -> int
@@ -840,9 +840,7 @@ def _compute_expected_bbox(bbox, pcoeffs_):
             dtype=bbox.dtype,
             device=bbox.device,
         )
-        return convert_format_bounding_box(
-            out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format, copy=False
-        )
+        return convert_format_bounding_box(out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format)
 
     spatial_size = (32, 38)
 
@@ -903,7 +901,7 @@ def _compute_expected_bbox(bbox, output_size_):
             dtype=bbox.dtype,
             device=bbox.device,
         )
-        return convert_format_bounding_box(out_bbox, features.BoundingBoxFormat.XYWH, format_, copy=False)
+        return convert_format_bounding_box(out_bbox, features.BoundingBoxFormat.XYWH, format_)
 
     for bboxes in make_bounding_boxes(extra_dims=((4,),)):
         bboxes = bboxes.to(device)
diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py
index 48f4b0950bb..d52989641a5 100644
--- a/torchvision/prototype/features/_image.py
+++ b/torchvision/prototype/features/_image.py
@@ -110,18 +110,6 @@ def spatial_size(self) -> Tuple[int, int]:
     def num_channels(self) -> int:
         return self.shape[-3]
 
-    def to_color_space(self, color_space: Union[str, ColorSpace], copy: bool = True) -> Image:
-        if isinstance(color_space, str):
-            color_space = ColorSpace.from_str(color_space.upper())
-
-        return Image.wrap_like(
-            self,
-            self._F.convert_color_space_image_tensor(
-                self.as_subclass(torch.Tensor), old_color_space=self.color_space, new_color_space=color_space, copy=copy
-            ),
-            color_space=color_space,
-        )
-
     def horizontal_flip(self) -> Image:
         output = self._F.horizontal_flip_image_tensor(self.as_subclass(torch.Tensor))
         return Image.wrap_like(self, output)
diff --git a/torchvision/prototype/features/_video.py b/torchvision/prototype/features/_video.py
index 6351ad5aa43..a4d30a49c7a 100644
--- a/torchvision/prototype/features/_video.py
+++ b/torchvision/prototype/features/_video.py
@@ -66,18 +66,6 @@ def num_channels(self) -> int:
     def num_frames(self) -> int:
         return self.shape[-4]
 
-    def to_color_space(self, color_space: Union[str, ColorSpace], copy: bool = True) -> Video:
-        if isinstance(color_space, str):
-            color_space = ColorSpace.from_str(color_space.upper())
-
-        return Video.wrap_like(
-            self,
-            self._F.convert_color_space_video(
-                self.as_subclass(torch.Tensor), old_color_space=self.color_space, new_color_space=color_space, copy=copy
-            ),
-            color_space=color_space,
-        )
-
     def horizontal_flip(self) -> Video:
         output = self._F.horizontal_flip_video(self.as_subclass(torch.Tensor))
         return Video.wrap_like(self, output)
diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py
index 99b77eb4005..4a45f9f5788 100644
--- a/torchvision/prototype/transforms/_augment.py
+++ b/torchvision/prototype/transforms/_augment.py
@@ -265,7 +265,7 @@ def _copy_paste(
         # https://github.com/pytorch/vision/blob/b6feccbc4387766b76a3e22b13815dbbbfa87c0f/torchvision/models/detection/roi_heads.py#L418-L422
         xyxy_boxes[:, 2:] += 1
         boxes = F.convert_format_bounding_box(
-            xyxy_boxes, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format, copy=False
+            xyxy_boxes, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format
         )
         out_target["boxes"] = torch.cat([boxes, paste_boxes])
 
diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py
index 5c67bf0ec78..440e23ab631 100644
--- a/torchvision/prototype/transforms/_geometry.py
+++ b/torchvision/prototype/transforms/_geometry.py
@@ -655,9 +655,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
                 continue
 
             # check for any valid boxes with centers within the crop area
-            xyxy_bboxes = F.convert_format_bounding_box(
-                bboxes, old_format=bboxes.format, new_format=features.BoundingBoxFormat.XYXY, copy=True
-            )
+            xyxy_bboxes = F.convert_format_bounding_box(bboxes, bboxes.format, features.BoundingBoxFormat.XYXY)
             cx = 0.5 * (xyxy_bboxes[..., 0] + xyxy_bboxes[..., 2])
             cy = 0.5 * (xyxy_bboxes[..., 1] + xyxy_bboxes[..., 3])
             is_within_crop_area = (left < cx) & (cx < right) & (top < cy) & (cy < bottom)
@@ -801,22 +799,21 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         top = int(offset_height * r)
         left = int(offset_width * r)
 
+        bounding_boxes: Optional[torch.Tensor]
         try:
             bounding_boxes = query_bounding_box(flat_inputs)
         except ValueError:
             bounding_boxes = None
 
         if needs_crop and bounding_boxes is not None:
-            bounding_boxes = cast(
-                features.BoundingBox, F.crop(bounding_boxes, top=top, left=left, height=new_height, width=new_width)
-            )
-            bounding_boxes = features.BoundingBox.wrap_like(
-                bounding_boxes,
-                F.clamp_bounding_box(
-                    bounding_boxes, format=bounding_boxes.format, spatial_size=bounding_boxes.spatial_size
-                ),
+            format = bounding_boxes.format
+            bounding_boxes, spatial_size = F.crop_bounding_box(
+                bounding_boxes, format=format, top=top, left=left, height=new_height, width=new_width
             )
-            height_and_width = bounding_boxes.to_format(features.BoundingBoxFormat.XYWH)[..., 2:]
+            bounding_boxes = F.clamp_bounding_box(bounding_boxes, format=format, spatial_size=spatial_size)
+            height_and_width = F.convert_format_bounding_box(
+                bounding_boxes, old_format=format, new_format=features.BoundingBoxFormat.XYWH
+            )[..., 2:]
             is_valid = torch.all(height_and_width > 0, dim=-1)
         else:
             is_valid = None
diff --git a/torchvision/prototype/transforms/_meta.py b/torchvision/prototype/transforms/_meta.py
index 829efeab1a8..6e5a8139704 100644
--- a/torchvision/prototype/transforms/_meta.py
+++ b/torchvision/prototype/transforms/_meta.py
@@ -50,7 +50,6 @@ def __init__(
         self,
         color_space: Union[str, features.ColorSpace],
         old_color_space: Optional[Union[str, features.ColorSpace]] = None,
-        copy: bool = True,
     ) -> None:
         super().__init__()
 
@@ -62,14 +61,10 @@ def __init__(
             old_color_space = features.ColorSpace.from_str(old_color_space)
         self.old_color_space = old_color_space
 
-        self.copy = copy
-
     def _transform(
         self, inpt: Union[features.ImageType, features.VideoType], params: Dict[str, Any]
     ) -> Union[features.ImageType, features.VideoType]:
-        return F.convert_color_space(
-            inpt, color_space=self.color_space, old_color_space=self.old_color_space, copy=self.copy
-        )
+        return F.convert_color_space(inpt, color_space=self.color_space, old_color_space=self.old_color_space)
 
 
 class ClampBoundingBoxes(Transform):
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index 1c897700c57..1451b83cf26 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -36,14 +36,18 @@ def horizontal_flip_bounding_box(
 ) -> torch.Tensor:
     shape = bounding_box.shape
 
-    bounding_box = convert_format_bounding_box(
-        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
+    # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
+    # BoundingBoxFormat instead of converting back and forth
+    bounding_box = (
+        bounding_box.clone()
+        if format == features.BoundingBoxFormat.XYXY
+        else convert_format_bounding_box(bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY)
     ).reshape(-1, 4)
 
     bounding_box[:, [0, 2]] = spatial_size[1] - bounding_box[:, [2, 0]]
 
     return convert_format_bounding_box(
-        bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
+        bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format
     ).reshape(shape)
 
 
@@ -73,14 +77,18 @@ def vertical_flip_bounding_box(
 ) -> torch.Tensor:
     shape = bounding_box.shape
 
-    bounding_box = convert_format_bounding_box(
-        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
+    # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
+    # BoundingBoxFormat instead of converting back and forth
+    bounding_box = (
+        bounding_box.clone()
+        if format == features.BoundingBoxFormat.XYXY
+        else convert_format_bounding_box(bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY)
     ).reshape(-1, 4)
 
     bounding_box[:, [1, 3]] = spatial_size[0] - bounding_box[:, [3, 1]]
 
     return convert_format_bounding_box(
-        bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
+        bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format
     ).reshape(shape)
 
 
@@ -394,8 +402,9 @@ def affine_bounding_box(
     center: Optional[List[float]] = None,
 ) -> torch.Tensor:
     original_shape = bounding_box.shape
-    bounding_box = convert_format_bounding_box(
-        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
+
+    bounding_box = (
+        convert_format_bounding_box(bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY)
     ).reshape(-1, 4)
 
     out_bboxes, _ = _affine_bounding_box_xyxy(bounding_box, spatial_size, angle, translate, scale, shear, center)
@@ -403,7 +412,7 @@
 
     # out_bboxes should be of shape [N boxes, 4]
     return convert_format_bounding_box(
-        out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
+        out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format
     ).reshape(original_shape)
 
 
@@ -583,8 +592,8 @@ def rotate_bounding_box(
         center = None
 
     original_shape = bounding_box.shape
-    bounding_box = convert_format_bounding_box(
-        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
+    bounding_box = (
+        convert_format_bounding_box(bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY)
     ).reshape(-1, 4)
 
     out_bboxes, spatial_size = _affine_bounding_box_xyxy(
@@ -599,9 +608,9 @@
     )
 
     return (
-        convert_format_bounding_box(
-            out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
-        ).reshape(original_shape),
+        convert_format_bounding_box(out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format).reshape(
+            original_shape
+        ),
         spatial_size,
     )
 
@@ -818,8 +827,12 @@ def crop_bounding_box(
     height: int,
     width: int,
 ) -> Tuple[torch.Tensor, Tuple[int, int]]:
-    bounding_box = convert_format_bounding_box(
-        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
+    # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
+    # BoundingBoxFormat instead of converting back and forth
+    bounding_box = (
+        bounding_box.clone()
+        if format == features.BoundingBoxFormat.XYXY
+        else convert_format_bounding_box(bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY)
     )
 
     # Crop or implicit pad if left and/or top have negative values:
@@ -827,9 +840,7 @@
     bounding_box[..., 0::2] -= left
     bounding_box[..., 1::2] -= top
 
     return (
-        convert_format_bounding_box(
-            bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
-        ),
+        convert_format_bounding_box(bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format),
         (height, width),
     )
@@ -896,8 +907,8 @@ def perspective_bounding_box(
         raise ValueError("Argument perspective_coeffs should have 8 float values")
 
     original_shape = bounding_box.shape
-    bounding_box = convert_format_bounding_box(
-        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
+    bounding_box = (
+        convert_format_bounding_box(bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY)
     ).reshape(-1, 4)
 
     dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32
@@ -967,7 +978,7 @@
 
     # out_bboxes should be of shape [N boxes, 4]
     return convert_format_bounding_box(
-        out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
+        out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format
     ).reshape(original_shape)
 
 
@@ -1061,8 +1072,8 @@ def elastic_bounding_box(
         displacement = displacement.to(bounding_box.device)
 
     original_shape = bounding_box.shape
-    bounding_box = convert_format_bounding_box(
-        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
+    bounding_box = (
+        convert_format_bounding_box(bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY)
     ).reshape(-1, 4)
 
     # Question (vfdev-5): should we rely on good displacement shape and fetch image size from it
@@ -1088,7 +1099,7 @@
     out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype)
 
     return convert_format_bounding_box(
-        out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
+        out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format
     ).reshape(original_shape)
 
 
diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py
index 61a54f01cc9..57155656212 100644
--- a/torchvision/prototype/transforms/functional/_meta.py
+++ b/torchvision/prototype/transforms/functional/_meta.py
@@ -125,13 +125,10 @@ def _xyxy_to_cxcywh(xyxy: torch.Tensor) -> torch.Tensor:
 
 
 def convert_format_bounding_box(
-    bounding_box: torch.Tensor, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat, copy: bool = True
+    bounding_box: torch.Tensor, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat
 ) -> torch.Tensor:
     if new_format == old_format:
-        if copy:
-            return bounding_box.clone()
-        else:
-            return bounding_box
+        return bounding_box
 
     if old_format == BoundingBoxFormat.XYWH:
         bounding_box = _xywh_to_xyxy(bounding_box)
@@ -149,12 +146,16 @@ def convert_format_bounding_box(
 def clamp_bounding_box(
     bounding_box: torch.Tensor, format: BoundingBoxFormat, spatial_size: Tuple[int, int]
 ) -> torch.Tensor:
-    # TODO: (PERF) Possible speed up clamping if we have different implementations for each bbox format.
-    # Not sure if they yield equivalent results.
-    xyxy_boxes = convert_format_bounding_box(bounding_box, format, BoundingBoxFormat.XYXY)
+    # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
+    # BoundingBoxFormat instead of converting back and forth
+    xyxy_boxes = (
+        bounding_box.clone()
+        if format == BoundingBoxFormat.XYXY
+        else convert_format_bounding_box(bounding_box, format, BoundingBoxFormat.XYXY)
+    )
     xyxy_boxes[..., 0::2].clamp_(min=0, max=spatial_size[1])
     xyxy_boxes[..., 1::2].clamp_(min=0, max=spatial_size[0])
-    return convert_format_bounding_box(xyxy_boxes, BoundingBoxFormat.XYXY, format, copy=False)
+    return convert_format_bounding_box(xyxy_boxes, BoundingBoxFormat.XYXY, format)
 
 
 def _split_alpha(image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -192,13 +193,10 @@ def _rgb_to_gray(image: torch.Tensor) -> torch.Tensor:
 
 
 def convert_color_space_image_tensor(
-    image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace, copy: bool = True
+    image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace
 ) -> torch.Tensor:
     if new_color_space == old_color_space:
-        if copy:
-            return image.clone()
-        else:
-            return image
+        return image
 
     if old_color_space == ColorSpace.OTHER or new_color_space == ColorSpace.OTHER:
         raise RuntimeError(f"Conversion to or from {ColorSpace.OTHER} is not supported.")
@@ -242,34 +240,29 @@ def convert_color_space_image_tensor(
 
 
 @torch.jit.unused
-def convert_color_space_image_pil(
-    image: PIL.Image.Image, color_space: ColorSpace, copy: bool = True
-) -> PIL.Image.Image:
+def convert_color_space_image_pil(image: PIL.Image.Image, color_space: ColorSpace) -> PIL.Image.Image:
     old_mode = image.mode
     try:
         new_mode = _COLOR_SPACE_TO_PIL_MODE[color_space]
     except KeyError:
         raise ValueError(f"Conversion from {ColorSpace.from_pil_mode(old_mode)} to {color_space} is not supported.")
 
-    if not copy and image.mode == new_mode:
+    if image.mode == new_mode:
         return image
 
     return image.convert(new_mode)
 
 
 def convert_color_space_video(
-    video: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace, copy: bool = True
+    video: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace
 ) -> torch.Tensor:
-    return convert_color_space_image_tensor(
-        video, old_color_space=old_color_space, new_color_space=new_color_space, copy=copy
-    )
+    return convert_color_space_image_tensor(video, old_color_space=old_color_space, new_color_space=new_color_space)
 
 
 def convert_color_space(
     inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT],
     color_space: ColorSpace,
     old_color_space: Optional[ColorSpace] = None,
-    copy: bool = True,
 ) -> Union[features.ImageTypeJIT, features.VideoTypeJIT]:
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
@@ -279,10 +272,16 @@ def convert_color_space(
                 "In order to convert the color space of simple tensors, "
                 "the `old_color_space=...` parameter needs to be passed."
             )
-        return convert_color_space_image_tensor(
-            inpt, old_color_space=old_color_space, new_color_space=color_space, copy=copy
+        return convert_color_space_image_tensor(inpt, old_color_space=old_color_space, new_color_space=color_space)
+    elif isinstance(inpt, features.Image):
+        output = convert_color_space_image_tensor(
+            inpt.as_subclass(torch.Tensor), old_color_space=inpt.color_space, new_color_space=color_space
+        )
+        return features.Image.wrap_like(inpt, output, color_space=color_space)
+    elif isinstance(inpt, features.Video):
+        output = convert_color_space_video(
+            inpt.as_subclass(torch.Tensor), old_color_space=inpt.color_space, new_color_space=color_space
         )
-    elif isinstance(inpt, (features.Image, features.Video)):
-        return inpt.to_color_space(color_space, copy=copy)
+        return features.Video.wrap_like(inpt, output, color_space=color_space)
     else:
-        return convert_color_space_image_pil(inpt, color_space, copy=copy)
+        return convert_color_space_image_pil(inpt, color_space)