diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index b9c89b2b76a..784394d2955 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -799,7 +799,9 @@ def test_assertions(self): with pytest.raises(ValueError, match="Kernel size value should be an odd and positive number"): transforms.GaussianBlur(4) - with pytest.raises(TypeError, match="sigma should be a single float or a list/tuple with length 2"): + with pytest.raises( + TypeError, match="sigma should be a single int or float or a list/tuple with length 2 floats." + ): transforms.GaussianBlur(3, sigma=[1, 2, 3]) with pytest.raises(ValueError, match="If sigma is a single number, it must be positive"): @@ -833,7 +835,7 @@ def test__transform(self, kernel_size, sigma, mocker): if isinstance(sigma, (tuple, list)): assert transform.sigma == sigma else: - assert transform.sigma == (sigma, sigma) + assert transform.sigma == [sigma, sigma] fn = mocker.patch("torchvision.prototype.transforms.functional.gaussian_blur") inpt = mocker.MagicMock(spec=features.Image) diff --git a/torchvision/prototype/features/_bounding_box.py b/torchvision/prototype/features/_bounding_box.py index 0b61439d10c..33d89011ed7 100644 --- a/torchvision/prototype/features/_bounding_box.py +++ b/torchvision/prototype/features/_bounding_box.py @@ -84,6 +84,8 @@ def resize( # type: ignore[override] antialias: bool = False, ) -> BoundingBox: output = self._F.resize_bounding_box(self, size, image_size=self.image_size, max_size=max_size) + if isinstance(size, int): + size = [size] image_size = (size[0], size[0]) if len(size) == 1 else (size[0], size[1]) return BoundingBox.new_like(self, output, image_size=image_size, dtype=output.dtype) @@ -95,6 +97,8 @@ def center_crop(self, output_size: List[int]) -> BoundingBox: output = self._F.center_crop_bounding_box( self, format=self.format, output_size=output_size, image_size=self.image_size ) + if isinstance(output_size, int): + output_size = [output_size] image_size = (output_size[0], output_size[0]) if len(output_size) == 1 else (output_size[0], output_size[1]) return BoundingBox.new_like(self, output, image_size=image_size) @@ -160,7 +164,7 @@ def rotate( def affine( self, - angle: float, + angle: Union[int, float], translate: List[float], scale: float, shear: List[float], diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py index 3d4357b9a99..b3f2172895d 100644 --- a/torchvision/prototype/features/_feature.py +++ b/torchvision/prototype/features/_feature.py @@ -169,7 +169,7 @@ def rotate( def affine( self, - angle: float, + angle: Union[int, float], translate: List[float], scale: float, shear: List[float], diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py index 0b832ae0270..21126c7f254 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/features/_image.py @@ -198,7 +198,7 @@ def rotate( def affine( self, - angle: float, + angle: Union[int, float], translate: List[float], scale: float, shear: List[float], diff --git a/torchvision/prototype/features/_mask.py b/torchvision/prototype/features/_mask.py index a0c3395dbe7..9dd614752a6 100644 --- a/torchvision/prototype/features/_mask.py +++ b/torchvision/prototype/features/_mask.py @@ -70,7 +70,7 @@ def rotate( def affine( self, - angle: float, + angle: Union[int, float], translate: List[float], scale: float, shear: List[float], diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index babcb83af04..df1f09fc192 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -19,6 +19,7 @@ _check_sequence_input, _setup_angle, _setup_fill_arg, + _setup_float_or_seq, _setup_size, has_all, has_any, @@ -67,9 +68,9 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: class CenterCrop(Transform): - def __init__(self, size: List[int]): + def __init__(self, size: Union[int, Sequence[int]]): super().__init__() - self.size = size + self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return F.center_crop(inpt, output_size=self.size) @@ -320,7 +321,7 @@ def __init__( degrees: Union[numbers.Number, Sequence], translate: Optional[Sequence[float]] = None, scale: Optional[Sequence[float]] = None, - shear: Optional[Union[float, Sequence[float]]] = None, + shear: Optional[Union[int, float, Sequence[float]]] = None, interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Union[features.FillType, Dict[Type, features.FillType]] = 0, center: Optional[List[float]] = None, @@ -545,23 +546,6 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: ) -def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size: int = 2) -> Sequence[float]: - if not isinstance(arg, (float, Sequence)): - raise TypeError(f"{name} should be float or a sequence of floats. Got {type(arg)}") - if isinstance(arg, Sequence) and len(arg) != req_size: - raise ValueError(f"If {name} is a sequence its length should be one of {req_size}. Got {len(arg)}") - if isinstance(arg, Sequence): - for element in arg: - if not isinstance(element, float): - raise ValueError(f"{name} should be a sequence of floats. Got {type(element)}") - - if isinstance(arg, float): - arg = [float(arg), float(arg)] - if isinstance(arg, (list, tuple)) and len(arg) == 1: - arg = [arg[0], arg[0]] - return arg - - class ElasticTransform(Transform): def __init__( self, diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index db93378312f..976e9f8b5ff 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -8,7 +8,7 @@ from torchvision.prototype import features from torchvision.prototype.transforms import functional as F, Transform -from ._utils import _setup_size, has_any, query_bounding_box +from ._utils import _setup_float_or_seq, _setup_size, has_any, query_bounding_box class Identity(Transform): @@ -112,7 +112,7 @@ def forward(self, *inpts: Any) -> Any: class GaussianBlur(Transform): def __init__( - self, kernel_size: Union[int, Sequence[int]], sigma: Union[float, Sequence[float]] = (0.1, 2.0) + self, kernel_size: Union[int, Sequence[int]], sigma: Union[int, float, Sequence[float]] = (0.1, 2.0) ) -> None: super().__init__() self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers") @@ -120,17 +120,17 @@ def __init__( if ks <= 0 or ks % 2 == 0: raise ValueError("Kernel size value should be an odd and positive number.") - if isinstance(sigma, float): + if isinstance(sigma, (int, float)): if sigma <= 0: raise ValueError("If sigma is a single number, it must be positive.") - sigma = (sigma, sigma) + sigma = float(sigma) elif isinstance(sigma, Sequence) and len(sigma) == 2: if not 0.0 < sigma[0] <= sigma[1]: raise ValueError("sigma values should be positive and of the form (min, max).") else: - raise TypeError("sigma should be a single float or a list/tuple with length 2 floats.") + raise TypeError("sigma should be a single int or float or a list/tuple with length 2 floats.") - self.sigma = sigma + self.sigma = _setup_float_or_seq(sigma, "sigma", 2) def _get_params(self, sample: Any) -> Dict[str, Any]: sigma = torch.empty(1).uniform_(self.sigma[0], self.sigma[1]).item() diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index 7107b14b3e0..219e6e50586 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -16,6 +16,23 @@ from typing_extensions import Literal +def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size: int = 2) -> Sequence[float]: + if not isinstance(arg, (float, Sequence)): + raise TypeError(f"{name} should be float or a sequence of floats. Got {type(arg)}") + if isinstance(arg, Sequence) and len(arg) != req_size: + raise ValueError(f"If {name} is a sequence its length should be one of {req_size}. Got {len(arg)}") + if isinstance(arg, Sequence): + for element in arg: + if not isinstance(element, float): + raise ValueError(f"{name} should be a sequence of floats. Got {type(element)}") + + if isinstance(arg, float): + arg = [float(arg), float(arg)] + if isinstance(arg, (list, tuple)) and len(arg) == 1: + arg = [arg[0], arg[0]] + return arg + + def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None: if isinstance(fill, dict): for key, value in fill.items(): diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 87b65868bf9..1df5df2a3ef 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -97,6 +97,8 @@ def resize_image_tensor( max_size: Optional[int] = None, antialias: bool = False, ) -> torch.Tensor: + if isinstance(size, int): + size = [size] num_channels, old_height, old_width = get_dimensions_image_tensor(image) new_height, new_width = _compute_resized_output_size((old_height, old_width), size=size, max_size=max_size) extra_dims = image.shape[:-3] @@ -145,6 +147,8 @@ def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = N def resize_bounding_box( bounding_box: torch.Tensor, size: List[int], image_size: Tuple[int, int], max_size: Optional[int] = None ) -> torch.Tensor: + if isinstance(size, int): + size = [size] old_height, old_width = image_size new_height, new_width = _compute_resized_output_size(image_size, size=size, max_size=max_size) ratios = torch.tensor((new_width / old_width, new_height / old_height), device=bounding_box.device) @@ -171,7 +175,7 @@ def resize( def _affine_parse_args( - angle: float, + angle: Union[int, float], translate: List[float], scale: float, shear: List[float], @@ -214,15 +218,18 @@ def _affine_parse_args( if len(shear) != 2: raise ValueError(f"Shear should be a sequence containing two values. Got {shear}") - if center is not None and not isinstance(center, (list, tuple)): - raise TypeError("Argument center should be a sequence") + if center is not None: + if not isinstance(center, (list, tuple)): + raise TypeError("Argument center should be a sequence") + else: + center = [float(c) for c in center] return angle, translate, shear, center def affine_image_tensor( img: torch.Tensor, - angle: float, + angle: Union[int, float], translate: List[float], scale: float, shear: List[float], @@ -254,7 +261,7 @@ def affine_image_tensor( @torch.jit.unused def affine_image_pil( img: PIL.Image.Image, - angle: float, + angle: Union[int, float], translate: List[float], scale: float, shear: List[float], @@ -278,34 +285,26 @@ def affine_image_pil( def _affine_bounding_box_xyxy( bounding_box: torch.Tensor, image_size: Tuple[int, int], - angle: float, - translate: Optional[List[float]] = None, - scale: Optional[float] = None, - shear: Optional[List[float]] = None, + angle: Union[int, float], + translate: List[float], + scale: float, + shear: List[float], center: Optional[List[float]] = None, expand: bool = False, ) -> torch.Tensor: - dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32 - device = bounding_box.device - - if translate is None: - translate = [0.0, 0.0] - - if scale is None: - scale = 1.0 - - if shear is None: - shear = [0.0, 0.0] + angle, translate, shear, center = _affine_parse_args( + angle, translate, scale, shear, InterpolationMode.NEAREST, center + ) if center is None: height, width = image_size - center_f = [width * 0.5, height * 0.5] - else: - center_f = [float(c) for c in center] + center = [width * 0.5, height * 0.5] + + dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32 + device = bounding_box.device - translate_f = [float(t) for t in translate] affine_matrix = torch.tensor( - _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear, inverted=False), + _get_inverse_affine_matrix(center, angle, translate, scale, shear, inverted=False), dtype=dtype, device=device, ).view(2, 3) @@ -351,7 +350,7 @@ def affine_bounding_box( bounding_box: torch.Tensor, format: features.BoundingBoxFormat, image_size: Tuple[int, int], - angle: float, + angle: Union[int, float], translate: List[float], scale: float, shear: List[float], @@ -373,7 +372,7 @@ def affine_bounding_box( def affine_mask( mask: torch.Tensor, - angle: float, + angle: Union[int, float], translate: List[float], scale: float, shear: List[float], @@ -419,7 +418,7 @@ def _convert_fill_arg(fill: features.FillType) -> features.FillTypeJIT: def affine( inpt: features.InputTypeJIT, - angle: float, + angle: Union[int, float], translate: List[float], scale: float, shear: List[float], @@ -427,6 +426,7 @@ def affine( fill: features.FillTypeJIT = None, center: Optional[List[float]] = None, ) -> features.InputTypeJIT: + # TODO: consider deprecating integers from angle and shear on the future if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return affine_image_tensor( inpt, @@ -528,7 +528,16 @@ def rotate_bounding_box( bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY ).view(-1, 4) - out_bboxes = _affine_bounding_box_xyxy(bounding_box, image_size, angle=-angle, center=center, expand=expand) + out_bboxes = _affine_bounding_box_xyxy( + bounding_box, + image_size, + angle=-angle, + translate=[0.0, 0.0], + scale=1.0, + shear=[0.0, 0.0], + center=center, + expand=expand, + ) return convert_format_bounding_box( out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index 952dc0d9e0d..052c79b6c45 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -23,6 +23,7 @@ def normalize( def gaussian_blur_image_tensor( img: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None ) -> torch.Tensor: + # TODO: consider deprecating integers from sigma on the future if isinstance(kernel_size, int): kernel_size = [kernel_size, kernel_size] if len(kernel_size) != 2: