diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
index 167b839eef9..b881ebc502a 100644
--- a/test/test_prototype_transforms.py
+++ b/test/test_prototype_transforms.py
@@ -1534,7 +1534,7 @@ def test__get_params(self, mocker):
         assert int(spatial_size[1] * r_min) <= width <= int(spatial_size[1] * r_max)
 
     def test__transform(self, mocker):
-        interpolation_sentinel = mocker.MagicMock()
+        interpolation_sentinel = mocker.MagicMock(spec=InterpolationMode)
         antialias_sentinel = mocker.MagicMock()
 
         transform = transforms.ScaleJitter(
@@ -1581,7 +1581,7 @@ def test__get_params(self, min_size, max_size, mocker):
         assert shorter in min_size
 
     def test__transform(self, mocker):
-        interpolation_sentinel = mocker.MagicMock()
+        interpolation_sentinel = mocker.MagicMock(spec=InterpolationMode)
         antialias_sentinel = mocker.MagicMock()
 
         transform = transforms.RandomShortestSize(
@@ -1945,7 +1945,7 @@ def test__get_params(self):
         assert min_size <= size < max_size
 
     def test__transform(self, mocker):
-        interpolation_sentinel = mocker.MagicMock()
+        interpolation_sentinel = mocker.MagicMock(spec=InterpolationMode)
         antialias_sentinel = mocker.MagicMock()
 
         transform = transforms.RandomResize(
diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py
index f0a7b44db3b..bb20f8a7b3a 100644
--- a/test/test_prototype_transforms_consistency.py
+++ b/test/test_prototype_transforms_consistency.py
@@ -88,6 +88,9 @@ def __init__(
             ArgsKwargs((32, 29)),
             ArgsKwargs((31, 28), interpolation=prototype_transforms.InterpolationMode.NEAREST),
             ArgsKwargs((33, 26), interpolation=prototype_transforms.InterpolationMode.BICUBIC),
+            ArgsKwargs((30, 27), interpolation=PIL.Image.NEAREST),
+            ArgsKwargs((35, 29), interpolation=PIL.Image.BILINEAR),
+            ArgsKwargs((34, 25), interpolation=PIL.Image.BICUBIC),
             NotScriptableArgsKwargs(31, max_size=32),
             ArgsKwargs([31], max_size=32),
             NotScriptableArgsKwargs(30, max_size=100),
@@ -305,6 +308,8 @@ def __init__(
             ArgsKwargs(25, ratio=(0.5, 1.5)),
             ArgsKwargs((31, 28), interpolation=prototype_transforms.InterpolationMode.NEAREST),
             ArgsKwargs((33, 26), interpolation=prototype_transforms.InterpolationMode.BICUBIC),
+            ArgsKwargs((31, 28), interpolation=PIL.Image.NEAREST),
+            ArgsKwargs((33, 26), interpolation=PIL.Image.BICUBIC),
             ArgsKwargs((29, 32), antialias=False),
             ArgsKwargs((28, 31), antialias=True),
         ],
@@ -352,6 +357,8 @@ def __init__(
             ArgsKwargs(sigma=(2.5, 3.9)),
             ArgsKwargs(interpolation=prototype_transforms.InterpolationMode.NEAREST),
             ArgsKwargs(interpolation=prototype_transforms.InterpolationMode.BICUBIC),
+            ArgsKwargs(interpolation=PIL.Image.NEAREST),
+            ArgsKwargs(interpolation=PIL.Image.BICUBIC),
             ArgsKwargs(fill=1),
         ],
         # ElasticTransform needs larger images to avoid the needed internal padding being larger than the actual image
@@ -386,6 +393,7 @@ def __init__(
             ArgsKwargs(degrees=0.0, shear=(4, 5, 4, 13)),
             ArgsKwargs(degrees=(-20.0, 10.0), translate=(0.4, 0.6), scale=(0.3, 0.8), shear=(4, 5, 4, 13)),
             ArgsKwargs(degrees=30.0, interpolation=prototype_transforms.InterpolationMode.NEAREST),
+            ArgsKwargs(degrees=30.0, interpolation=PIL.Image.NEAREST),
             ArgsKwargs(degrees=30.0, fill=1),
             ArgsKwargs(degrees=30.0, fill=(2, 3, 4)),
             ArgsKwargs(degrees=30.0, center=(0, 0)),
@@ -420,6 +428,7 @@ def __init__(
             ArgsKwargs(p=1),
             ArgsKwargs(p=1, distortion_scale=0.3),
             ArgsKwargs(p=1, distortion_scale=0.2, interpolation=prototype_transforms.InterpolationMode.NEAREST),
+            ArgsKwargs(p=1, distortion_scale=0.2, interpolation=PIL.Image.NEAREST),
             ArgsKwargs(p=1, distortion_scale=0.1, fill=1),
             ArgsKwargs(p=1, distortion_scale=0.4, fill=(1, 2, 3)),
         ],
@@ -432,6 +441,7 @@ def __init__(
             ArgsKwargs(degrees=30.0),
             ArgsKwargs(degrees=(-20.0, 10.0)),
             ArgsKwargs(degrees=30.0, interpolation=prototype_transforms.InterpolationMode.BILINEAR),
+            ArgsKwargs(degrees=30.0, interpolation=PIL.Image.BILINEAR),
             ArgsKwargs(degrees=30.0, expand=True),
             ArgsKwargs(degrees=30.0, center=(0, 0)),
             ArgsKwargs(degrees=30.0, fill=1),
@@ -851,7 +861,7 @@ class TestAATransforms:
     )
     @pytest.mark.parametrize(
         "interpolation",
-        [prototype_transforms.InterpolationMode.NEAREST, prototype_transforms.InterpolationMode.BILINEAR],
+        [
+            prototype_transforms.InterpolationMode.NEAREST,
+            prototype_transforms.InterpolationMode.BILINEAR,
+            PIL.Image.NEAREST,
+        ],
     )
     def test_randaug(self, inpt, interpolation, mocker):
         t_ref = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1)
@@ -889,7 +903,11 @@ def test_randaug(self, inpt, interpolation, mocker):
     )
     @pytest.mark.parametrize(
         "interpolation",
-        [prototype_transforms.InterpolationMode.NEAREST, prototype_transforms.InterpolationMode.BILINEAR],
+        [
+            prototype_transforms.InterpolationMode.NEAREST,
+            prototype_transforms.InterpolationMode.BILINEAR,
+            PIL.Image.NEAREST,
+        ],
     )
     def test_trivial_aug(self, inpt, interpolation, mocker):
         t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation)
@@ -937,7 +955,11 @@ def test_trivial_aug(self, inpt, interpolation, mocker):
     )
     @pytest.mark.parametrize(
         "interpolation",
-        [prototype_transforms.InterpolationMode.NEAREST, prototype_transforms.InterpolationMode.BILINEAR],
+        [
+            prototype_transforms.InterpolationMode.NEAREST,
+            prototype_transforms.InterpolationMode.BILINEAR,
+            PIL.Image.NEAREST,
+        ],
    )
     def test_augmix(self, inpt, interpolation, mocker):
         t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
@@ -986,7 +1008,11 @@ def test_augmix(self, inpt, interpolation, mocker):
     )
     @pytest.mark.parametrize(
         "interpolation",
-        [prototype_transforms.InterpolationMode.NEAREST, prototype_transforms.InterpolationMode.BILINEAR],
+        [
+            prototype_transforms.InterpolationMode.NEAREST,
+            prototype_transforms.InterpolationMode.BILINEAR,
+            PIL.Image.NEAREST,
+        ],
     )
     def test_aa(self, inpt, interpolation):
         aa_policy = legacy_transforms.AutoAugmentPolicy("imagenet")
@@ -1264,13 +1290,13 @@ def test_random_resize_eval(self, mocker):
         (legacy_F.convert_image_dtype, {}),
         (legacy_F.to_pil_image, {}),
         (legacy_F.normalize, {}),
-        (legacy_F.resize, {}),
+        (legacy_F.resize, {"interpolation"}),
         (legacy_F.pad, {"padding", "fill"}),
         (legacy_F.crop, {}),
         (legacy_F.center_crop, {}),
-        (legacy_F.resized_crop, {}),
+        (legacy_F.resized_crop, {"interpolation"}),
         (legacy_F.hflip, {}),
-        (legacy_F.perspective, {"startpoints", "endpoints", "fill"}),
+        (legacy_F.perspective, {"startpoints", "endpoints", "fill", "interpolation"}),
         (legacy_F.vflip, {}),
         (legacy_F.five_crop, {}),
         (legacy_F.ten_crop, {}),
@@ -1279,8 +1305,8 @@
         (legacy_F.adjust_saturation, {}),
         (legacy_F.adjust_hue, {}),
         (legacy_F.adjust_gamma, {}),
-        (legacy_F.rotate, {"center", "fill"}),
-        (legacy_F.affine, {"angle", "translate", "center", "fill"}),
+        (legacy_F.rotate, {"center", "fill", "interpolation"}),
+        (legacy_F.affine, {"angle", "translate", "center", "fill", "interpolation"}),
         (legacy_F.to_grayscale, {}),
         (legacy_F.rgb_to_grayscale, {}),
         (legacy_F.to_tensor, {}),
@@ -1292,7 +1318,7 @@
         (legacy_F.adjust_sharpness, {}),
         (legacy_F.autocontrast, {}),
         (legacy_F.equalize, {}),
-        (legacy_F.elastic_transform, {"fill"}),
+        (legacy_F.elastic_transform, {"fill", "interpolation"}),
     ],
 )
 def test_dispatcher_signature_consistency(legacy_dispatcher, name_only_params):
diff --git a/torchvision/prototype/datapoints/_bounding_box.py b/torchvision/prototype/datapoints/_bounding_box.py
index e04a965d9fc..b904dd5e5aa 100644
--- a/torchvision/prototype/datapoints/_bounding_box.py
+++ b/torchvision/prototype/datapoints/_bounding_box.py
@@ -76,7 +76,7 @@ def vertical_flip(self) -> BoundingBox:
     def resize(  # type: ignore[override]
         self,
         size: List[int],
-        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
         max_size: Optional[int] = None,
         antialias: Optional[Union[str, bool]] = "warn",
     ) -> BoundingBox:
@@ -107,7 +107,7 @@ def resized_crop(
         height: int,
         width: int,
         size: List[int],
-        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
         antialias: Optional[Union[str, bool]] = "warn",
     ) -> BoundingBox:
         output, spatial_size = self._F.resized_crop_bounding_box(
@@ -133,7 +133,7 @@ def pad(
     def rotate(
         self,
         angle: float,
-        interpolation: InterpolationMode = InterpolationMode.NEAREST,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
         expand: bool = False,
         center: Optional[List[float]] = None,
         fill: FillTypeJIT = None,
@@ -154,7 +154,7 @@ def affine(
         translate: List[float],
         scale: float,
         shear: List[float],
-        interpolation: InterpolationMode = InterpolationMode.NEAREST,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
         fill: FillTypeJIT = None,
         center: Optional[List[float]] = None,
     ) -> BoundingBox:
@@ -174,7 +174,7 @@ def perspective(
         self,
         startpoints: Optional[List[List[int]]],
         endpoints: Optional[List[List[int]]],
-        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
         fill: FillTypeJIT = None,
         coefficients: Optional[List[float]] = None,
     ) -> BoundingBox:
@@ -191,7 +191,7 @@ def perspective(
     def elastic(
         self,
         displacement: torch.Tensor,
-        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
         fill: FillTypeJIT = None,
     ) -> BoundingBox:
         output = self._F.elastic_bounding_box(
diff --git a/torchvision/prototype/datapoints/_datapoint.py b/torchvision/prototype/datapoints/_datapoint.py
index 3738d2a8124..5f4a0d96ea2 100644
--- a/torchvision/prototype/datapoints/_datapoint.py
+++ b/torchvision/prototype/datapoints/_datapoint.py
@@ -143,7 +143,7 @@ def vertical_flip(self) -> Datapoint:
     def resize(  # type: ignore[override]
         self,
         size: List[int],
-        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
         max_size: Optional[int] = None,
         antialias: Optional[Union[str, bool]] = "warn",
     ) -> Datapoint:
@@ -162,7 +162,7 @@ def resized_crop(
         height: int,
         width: int,
         size: List[int],
-        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
         antialias: Optional[Union[str, bool]] = "warn",
     ) -> Datapoint:
         return self
@@ -178,7 +178,7 @@ def pad(
     def rotate(
         self,
         angle: float,
-        interpolation: InterpolationMode = InterpolationMode.NEAREST,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
         expand: bool = False,
         center: Optional[List[float]] = None,
         fill: FillTypeJIT = None,
@@ -191,7 +191,7 @@ def affine(
         translate: List[float],
         scale: float,
         shear: List[float],
-        interpolation: InterpolationMode = InterpolationMode.NEAREST,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
         fill: FillTypeJIT = None,
         center: Optional[List[float]] = None,
     ) -> Datapoint:
@@ -201,7 +201,7 @@ def perspective(
         self,
         startpoints: Optional[List[List[int]]],
         endpoints: Optional[List[List[int]]],
-        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
         fill: FillTypeJIT = None,
         coefficients: Optional[List[float]] = None,
     ) -> Datapoint:
@@ -210,7 +210,7 @@ def perspective(
     def elastic(
         self,
         displacement: torch.Tensor,
-        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
         fill: FillTypeJIT = None,
     ) -> Datapoint:
         return self
diff --git a/torchvision/prototype/datapoints/_image.py b/torchvision/prototype/datapoints/_image.py
index 8f3092fa1e7..4fc14323abe 100644
--- a/torchvision/prototype/datapoints/_image.py
+++ b/torchvision/prototype/datapoints/_image.py
@@ -62,7 +62,7 @@ def vertical_flip(self) -> Image:
     def resize(  # type: ignore[override]
         self,
         size: List[int],
-        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
         max_size: Optional[int] = None,
         antialias: Optional[Union[str, bool]] = "warn",
     ) -> Image:
@@ -86,7 +86,7 @@ def resized_crop(
         height: int,
         width: int,
         size: List[int],
-        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
         antialias: Optional[Union[str, bool]] = "warn",
     ) -> Image:
         output = self._F.resized_crop_image_tensor(
@@ -113,7 +113,7 @@ def pad(
     def rotate(
         self,
         angle: float,
-        interpolation: InterpolationMode = InterpolationMode.NEAREST,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
         expand: bool = False,
         center: Optional[List[float]] = None,
         fill: FillTypeJIT = None,
@@ -129,7 +129,7 @@ def affine(
         translate: List[float],
         scale: float,
         shear: List[float],
-        interpolation: InterpolationMode = InterpolationMode.NEAREST,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
         fill: FillTypeJIT = None,
         center: Optional[List[float]] = None,
     ) -> Image:
@@ -149,7 +149,7 @@ def perspective(
         self,
         startpoints: Optional[List[List[int]]],
         endpoints: Optional[List[List[int]]],
-        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
         fill: FillTypeJIT = None,
         coefficients: Optional[List[float]] = None,
     ) -> Image:
@@ -166,7 +166,7 @@ def perspective(
     def elastic(
         self,
         displacement: torch.Tensor,
-        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
         fill: FillTypeJIT = None,
     ) -> Image:
         output = self._F.elastic_image_tensor(
diff --git a/torchvision/prototype/datapoints/_mask.py b/torchvision/prototype/datapoints/_mask.py
index a1870fa4b20..41dce097c6c 100644
--- a/torchvision/prototype/datapoints/_mask.py
+++ b/torchvision/prototype/datapoints/_mask.py
@@ -53,7 +53,7 @@ def vertical_flip(self) -> Mask:
     def resize(  # type: ignore[override]
         self,
         size: List[int],
-        interpolation: InterpolationMode = InterpolationMode.NEAREST,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
         max_size: Optional[int] = None,
         antialias: Optional[Union[str, bool]] = "warn",
     ) -> Mask:
@@ -75,7 +75,7 @@ def resized_crop(
         height: int,
         width: int,
         size: List[int],
-        interpolation: InterpolationMode = InterpolationMode.NEAREST,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
         antialias: Optional[Union[str, bool]] = "warn",
     ) -> Mask:
         output = self._F.resized_crop_mask(self.as_subclass(torch.Tensor), top, left, height, width, size=size)
@@ -93,7 +93,7 @@ def pad(
     def rotate(
         self,
         angle: float,
-        interpolation: InterpolationMode = InterpolationMode.NEAREST,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
         expand: bool = False,
         center: Optional[List[float]] = None,
         fill: FillTypeJIT = None,
@@ -107,7 +107,7 @@ def affine(
         translate: List[float],
         scale: float,
         shear: List[float],
-        interpolation: InterpolationMode = InterpolationMode.NEAREST,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
         fill: FillTypeJIT = None,
         center: Optional[List[float]] = None,
     ) -> Mask:
@@ -126,7 +126,7 @@ def perspective(
         self,
         startpoints: Optional[List[List[int]]],
         endpoints: Optional[List[List[int]]],
-        interpolation: InterpolationMode = InterpolationMode.NEAREST,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
         fill: FillTypeJIT = None,
         coefficients: Optional[List[float]] = None,
     ) -> Mask:
@@ -138,7 +138,7 @@ def perspective(
     def elastic(
         self,
         displacement: torch.Tensor,
-        interpolation: InterpolationMode = InterpolationMode.NEAREST,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
         fill: FillTypeJIT = None,
     ) -> Mask:
         output = self._F.elastic_mask(self.as_subclass(torch.Tensor), displacement, fill=fill)
diff --git a/torchvision/prototype/datapoints/_video.py b/torchvision/prototype/datapoints/_video.py
index 0e5ff7a17b8..f62edd68eaf 100644
--- a/torchvision/prototype/datapoints/_video.py
+++ b/torchvision/prototype/datapoints/_video.py
@@ -57,7 +57,7 @@ def vertical_flip(self) -> Video:
     def resize(  # type: ignore[override]
         self,
         size: List[int],
-        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
         max_size: Optional[int] = None,
         antialias: Optional[Union[str, bool]] = "warn",
     ) -> Video:
@@ -85,7 +85,7 @@ def resized_crop(
         height: int,
         width: int,
         size: List[int],
-        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
         antialias: Optional[Union[str, bool]] = "warn",
     ) -> Video:
         output = self._F.resized_crop_video(
@@ -112,7 +112,7 @@ def pad(
     def rotate(
         self,
         angle: float,
-        interpolation: InterpolationMode = InterpolationMode.NEAREST,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
         expand: bool = False,
         center: Optional[List[float]] = None,
         fill: FillTypeJIT = None,
@@ -128,7 +128,7 @@ def affine(
         translate: List[float],
         scale: float,
         shear: List[float],
-        interpolation: InterpolationMode = InterpolationMode.NEAREST,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
         fill: FillTypeJIT = None,
         center: Optional[List[float]] = None,
     ) -> Video:
@@ -148,7 +148,7 @@ def perspective(
         self,
         startpoints: Optional[List[List[int]]],
         endpoints: Optional[List[List[int]]],
-        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
         fill: FillTypeJIT = None,
         coefficients: Optional[List[float]] = None,
     ) -> Video:
@@ -165,7 +165,7 @@ def perspective(
     def elastic(
         self,
         displacement: torch.Tensor,
-        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
         fill: FillTypeJIT = None,
     ) -> Video:
         output = self._F.elastic_video(
diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py
index 65b672b7edc..3ceabba5e42 100644
--- a/torchvision/prototype/transforms/_augment.py
+++ b/torchvision/prototype/transforms/_augment.py
@@ -10,6 +10,7 @@
 from torchvision.ops import masks_to_boxes
 from torchvision.prototype import datapoints
 from torchvision.prototype.transforms import functional as F, InterpolationMode, Transform
+from torchvision.prototype.transforms.functional._geometry import _check_interpolation
 
 from ._transform import _RandomApplyTransform
 from .utils import has_any, is_simple_tensor, query_chw, query_spatial_size
@@ -203,11 +204,11 @@ class SimpleCopyPaste(Transform):
     def __init__(
         self,
         blending: bool = True,
-        resize_interpolation: InterpolationMode = F.InterpolationMode.BILINEAR,
+        resize_interpolation: Union[int, InterpolationMode] = F.InterpolationMode.BILINEAR,
         antialias: Optional[bool] = None,
     ) -> None:
         super().__init__()
-        self.resize_interpolation = resize_interpolation
+        self.resize_interpolation = _check_interpolation(resize_interpolation)
         self.blending = blending
         self.antialias = antialias
 
diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py
index 89bead236b2..67afecf5df1 100644
--- a/torchvision/prototype/transforms/_auto_augment.py
+++ b/torchvision/prototype/transforms/_auto_augment.py
@@ -8,6 +8,7 @@
 from torchvision import transforms as _transforms
 from torchvision.prototype import datapoints
 from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
+from torchvision.prototype.transforms.functional._geometry import _check_interpolation
 from torchvision.prototype.transforms.functional._meta import get_spatial_size
 from torchvision.transforms import functional_tensor as _FT
 
@@ -19,11 +20,11 @@ class _AutoAugmentBase(Transform):
     def __init__(
         self,
         *,
-        interpolation: InterpolationMode = InterpolationMode.NEAREST,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
         fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None,
     ) -> None:
         super().__init__()
-        self.interpolation = interpolation
+        self.interpolation = _check_interpolation(interpolation)
         self.fill = _setup_fill_arg(fill)
 
     def _get_random_item(self, dct: Dict[str, Tuple[Callable, bool]]) -> Tuple[str, Tuple[Callable, bool]]:
@@ -79,7 +80,7 @@ def _apply_image_or_video_transform(
         image: Union[datapoints.ImageType, datapoints.VideoType],
         transform_id: str,
         magnitude: float,
-        interpolation: InterpolationMode,
+        interpolation: Union[InterpolationMode, int],
         fill: Dict[Type, datapoints.FillTypeJIT],
     ) -> Union[datapoints.ImageType, datapoints.VideoType]:
         fill_ = fill[type(image)]
@@ -193,7 +194,7 @@ class AutoAugment(_AutoAugmentBase):
     def __init__(
         self,
         policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
-        interpolation: InterpolationMode = InterpolationMode.NEAREST,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
         fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None,
     ) -> None:
         super().__init__(interpolation=interpolation, fill=fill)
@@ -350,7 +351,7 @@ def __init__(
         num_ops: int = 2,
         magnitude: int = 9,
         num_magnitude_bins: int = 31,
-        interpolation: InterpolationMode = InterpolationMode.NEAREST,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
         fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None,
     ) -> None:
         super().__init__(interpolation=interpolation, fill=fill)
@@ -403,7 +404,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
     def __init__(
         self,
         num_magnitude_bins: int = 31,
-        interpolation: InterpolationMode = InterpolationMode.NEAREST,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
         fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None,
     ):
         super().__init__(interpolation=interpolation, fill=fill)
@@ -461,7 +462,7 @@ def __init__(
         chain_depth: int = -1,
         alpha: float = 1.0,
         all_ops: bool = True,
-        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
         fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None,
     ) -> None:
         super().__init__(interpolation=interpolation, fill=fill)
diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py
index c4708cc57bd..ffabb91471c 100644
--- a/torchvision/prototype/transforms/_geometry.py
+++ b/torchvision/prototype/transforms/_geometry.py
@@ -10,6 +10,7 @@
 from torchvision.ops.boxes import box_iou
 from torchvision.prototype import datapoints
 from torchvision.prototype.transforms import functional as F, InterpolationMode, Transform
+from torchvision.prototype.transforms.functional._geometry import _check_interpolation
 from torchvision.transforms.functional import _get_perspective_coeffs
 
 from ._transform import _RandomApplyTransform
@@ -45,7 +46,7 @@ class Resize(Transform):
     def __init__(
         self,
         size: Union[int, Sequence[int]],
-        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
         max_size: Optional[int] = None,
         antialias: Optional[Union[str, bool]] = "warn",
     ) -> None:
@@ -61,7 +62,7 @@ def __init__(
             )
 
         self.size = size
-        self.interpolation = interpolation
+        self.interpolation = _check_interpolation(interpolation)
         self.max_size = max_size
         self.antialias = antialias
 
@@ -94,7 +95,7 @@ def __init__(
         size: Union[int, Sequence[int]],
         scale: Tuple[float, float] = (0.08, 1.0),
         ratio: Tuple[float, float] = (3.0 / 4.0, 4.0 / 3.0),
-        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
         antialias: Optional[Union[str, bool]] = "warn",
     ) -> None:
         super().__init__()
@@ -111,7 +112,7 @@ def __init__(
 
         self.scale = scale
         self.ratio = ratio
-        self.interpolation = interpolation
+        self.interpolation = _check_interpolation(interpolation)
         self.antialias = antialias
 
         self._log_ratio = torch.log(torch.tensor(self.ratio))
@@ -317,14 +318,14 @@ class RandomRotation(Transform):
     def __init__(
         self,
         degrees: Union[numbers.Number, Sequence],
-        interpolation: InterpolationMode = InterpolationMode.NEAREST,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
         expand: bool = False,
         fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
         center: Optional[List[float]] = None,
     ) -> None:
         super().__init__()
         self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
-        self.interpolation = interpolation
+        self.interpolation = _check_interpolation(interpolation)
         self.expand = expand
 
         self.fill = _setup_fill_arg(fill)
@@ -359,7 +360,7 @@ def __init__(
         translate: Optional[Sequence[float]] = None,
         scale: Optional[Sequence[float]] = None,
         shear: Optional[Union[int, float, Sequence[float]]] = None,
-        interpolation: InterpolationMode = InterpolationMode.NEAREST,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
         fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
         center: Optional[List[float]] = None,
     ) -> None:
@@ -383,7 +384,7 @@ def __init__(
         else:
             self.shear = shear
 
-        self.interpolation = interpolation
+        self.interpolation = _check_interpolation(interpolation)
         self.fill = _setup_fill_arg(fill)
 
         if center is not None:
@@ -546,7 +547,7 @@ def __init__(
         self,
         distortion_scale: float = 0.5,
         fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
-        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
         p: float = 0.5,
     ) -> None:
         super().__init__(p=p)
@@ -555,7 +556,7 @@ def __init__(
             raise ValueError("Argument distortion_scale value should be between 0 and 1")
 
         self.distortion_scale = distortion_scale
-        self.interpolation = interpolation
+        self.interpolation = _check_interpolation(interpolation)
         self.fill = _setup_fill_arg(fill)
 
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
@@ -608,13 +609,13 @@ def __init__(
         alpha: Union[float, Sequence[float]] = 50.0,
         sigma: Union[float, Sequence[float]] = 5.0,
         fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
-        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
     ) -> None:
         super().__init__()
         self.alpha = _setup_float_or_seq(alpha, "alpha", 2)
         self.sigma = _setup_float_or_seq(sigma, "sigma", 2)
 
-        self.interpolation = interpolation
+        self.interpolation = _check_interpolation(interpolation)
         self.fill = _setup_fill_arg(fill)
 
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
@@ -760,13 +761,13 @@ def __init__(
         self,
         target_size: Tuple[int, int],
         scale_range: Tuple[float, float] = (0.1, 2.0),
-        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
         antialias: Optional[Union[str, bool]] = "warn",
     ):
         super().__init__()
         self.target_size = target_size
         self.scale_range = scale_range
-        self.interpolation = interpolation
+        self.interpolation = _check_interpolation(interpolation)
         self.antialias = antialias
 
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
@@ -788,13 +789,13 @@ def __init__(
         self,
         min_size: Union[List[int], Tuple[int], int],
         max_size: Optional[int] = None,
-        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
         antialias: Optional[Union[str, bool]] = "warn",
     ):
         super().__init__()
         self.min_size = [min_size] if isinstance(min_size, int) else list(min_size)
         self.max_size = max_size
-        self.interpolation = interpolation
+        self.interpolation = _check_interpolation(interpolation)
         self.antialias = antialias
 
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
@@ -935,13 +936,13 @@ def __init__(
         self,
         min_size: int,
         max_size: int,
-        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
         antialias: Optional[Union[str, bool]] = "warn",
     ) -> None:
         super().__init__()
         self.min_size = min_size
         self.max_size = max_size
-        self.interpolation = interpolation
+        self.interpolation = _check_interpolation(interpolation)
         self.antialias = antialias
 
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
diff --git a/torchvision/prototype/transforms/_presets.py b/torchvision/prototype/transforms/_presets.py
index 86300b0494b..7f18e885c39 100644
--- a/torchvision/prototype/transforms/_presets.py
+++ b/torchvision/prototype/transforms/_presets.py
@@ -9,6 +9,8 @@
 import torch
 from torch import Tensor
 
+from torchvision.prototype.transforms.functional._geometry import _check_interpolation
+
 from . import functional as F, InterpolationMode
 
 __all__ = ["StereoMatching"]
@@ -22,7 +24,7 @@ def __init__(
         resize_size: Optional[Tuple[int, ...]],
         mean: Tuple[float, ...] = (0.5, 0.5, 0.5),
         std: Tuple[float, ...] = (0.5, 0.5, 0.5),
-        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
     ) -> None:
         super().__init__()
 
@@ -36,7 +38,7 @@ def __init__(
 
         self.mean = list(mean)
         self.std = list(std)
-        self.interpolation = interpolation
+        self.interpolation = _check_interpolation(interpolation)
         self.use_gray_scale = use_gray_scale
 
     def forward(self, left_image: Tensor, right_image: Tensor) -> Tuple[Tensor, Tensor]:
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index 7fa0736ccb6..814697f03a3 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -13,6 +13,7 @@
     _check_antialias,
     _compute_resized_output_size as __compute_resized_output_size,
     _get_perspective_coeffs,
+    _interpolation_modes_from_int,
     InterpolationMode,
     pil_modes_mapping,
     pil_to_tensor,
@@ -27,6 +28,17 @@
 from ._utils import is_simple_tensor
 
 
+def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode:
+    if isinstance(interpolation, int):
+        interpolation = _interpolation_modes_from_int(interpolation)
+    elif not isinstance(interpolation, InterpolationMode):
+        raise ValueError(
+            f"Argument interpolation should be an `InterpolationMode` or a corresponding Pillow integer constant, "
+            f"but got {interpolation}."
+        )
+    return interpolation
+
+
 def horizontal_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
     return image.flip(-1)
 
@@ -142,10 +154,11 @@ def _compute_resized_output_size(
 def resize_image_tensor(
     image: torch.Tensor,
     size: List[int],
-    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
     max_size: Optional[int] = None,
     antialias: Optional[Union[str, bool]] = "warn",
 ) -> torch.Tensor:
+    interpolation = _check_interpolation(interpolation)
     antialias = _check_antialias(img=image, antialias=antialias, interpolation=interpolation)
     assert not isinstance(antialias, str)
     antialias = False if antialias is None else antialias
@@ -189,9 +202,10 @@ def resize_image_tensor(
 def resize_image_pil(
     image: PIL.Image.Image,
     size: Union[Sequence[int], int],
-    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
     max_size: Optional[int] = None,
 ) -> PIL.Image.Image:
+    interpolation = _check_interpolation(interpolation)
     size = _compute_resized_output_size(image.size[::-1], size=size, max_size=max_size)  # type: ignore[arg-type]
     return _FP.resize(image, size, interpolation=pil_modes_mapping[interpolation])
 
@@ -228,7 +242,7 @@ def resize_bounding_box(
 def resize_video(
     video: torch.Tensor,
     size: List[int],
-    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
     max_size: Optional[int] = None,
     antialias: Optional[Union[str, bool]] = "warn",
 ) -> torch.Tensor:
@@ -238,7 +252,7 @@ def resize_video(
 def resize(
     inpt: datapoints.InputTypeJIT,
     size: List[int],
-    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
     max_size: Optional[int] = None,
     antialias: Optional[Union[str, bool]] = "warn",
 ) -> datapoints.InputTypeJIT:
@@ -513,10 +527,12 @@ def affine_image_tensor(
     translate: List[float],
     scale: float,
     shear: List[float],
-    interpolation: InterpolationMode = InterpolationMode.NEAREST,
+    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
     fill: datapoints.FillTypeJIT = None,
     center: Optional[List[float]] = None,
 ) -> torch.Tensor:
+    interpolation = _check_interpolation(interpolation)
+
     if image.numel() == 0:
         return image
 
@@ -563,10 +579,11 @@ def affine_image_pil(
     translate: List[float],
     scale: float,
     shear: List[float],
-    interpolation: InterpolationMode = InterpolationMode.NEAREST,
+    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
     fill: datapoints.FillTypeJIT = None,
     center: Optional[List[float]] = None,
 ) -> PIL.Image.Image:
+    interpolation = _check_interpolation(interpolation)
     angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)
 
     # center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5)
@@ -731,7 +748,7 @@ def affine_video(
     translate: List[float],
     scale: float,
     shear: List[float],
-    interpolation: InterpolationMode = InterpolationMode.NEAREST,
+    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
     fill: datapoints.FillTypeJIT = None,
     center: Optional[List[float]] = None,
 ) -> torch.Tensor:
@@ -753,7 +770,7 @@ def affine(
     translate: List[float],
     scale: float,
     shear: List[float],
-    interpolation: InterpolationMode = InterpolationMode.NEAREST,
+    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
     fill: datapoints.FillTypeJIT = None,
     center: Optional[List[float]] = None,
 ) -> datapoints.InputTypeJIT:
@@ -797,11 +814,13 @@ def affine(
 def rotate_image_tensor(
     image: torch.Tensor,
     angle: float,
-    interpolation: InterpolationMode = InterpolationMode.NEAREST,
+    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
     expand: bool = False,
     center: Optional[List[float]] = None,
     fill: datapoints.FillTypeJIT = None,
 ) -> torch.Tensor:
+    interpolation = _check_interpolation(interpolation)
+
     shape = image.shape
     num_channels, height, width = shape[-3:]
 
@@ -840,11 +859,13 @@ def rotate_image_tensor(
 def rotate_image_pil(
     image: PIL.Image.Image,
     angle: float,
-    interpolation: InterpolationMode = InterpolationMode.NEAREST,
+    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
     expand: bool = False,
     center: Optional[List[float]] = None,
     fill: datapoints.FillTypeJIT = None,
 ) -> PIL.Image.Image:
+    interpolation = _check_interpolation(interpolation)
+
     if center is not None and expand:
         warnings.warn("The provided center argument has no effect on the result if expand is True")
         center = None
@@ -910,7 +931,7 @@ def rotate_mask(
 def rotate_video(
     video: torch.Tensor,
     angle: float,
-    interpolation: InterpolationMode = InterpolationMode.NEAREST,
+    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
     expand: bool = False,
     center: Optional[List[float]] = None,
     fill: datapoints.FillTypeJIT = None,
@@ -921,7 +942,7 @@ def rotate_video(
 def rotate(
     inpt: datapoints.InputTypeJIT,
     angle: float,
-    interpolation: InterpolationMode = InterpolationMode.NEAREST,
+    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
     expand: bool = False,
     center: Optional[List[float]] = None,
     fill: datapoints.FillTypeJIT = None,
@@ -1281,11 +1302,13 @@ def perspective_image_tensor(
     image: torch.Tensor,
     startpoints: Optional[List[List[int]]],
     endpoints: Optional[List[List[int]]],
-    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
     fill: datapoints.FillTypeJIT = None,
     coefficients: Optional[List[float]] = None,
 ) -> torch.Tensor:
     perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
+    interpolation = _check_interpolation(interpolation)
+
     if image.numel() == 0:
         return image
 
@@ -1326,11 +1349,12 @@ def perspective_image_pil(
     image: PIL.Image.Image,
     startpoints: Optional[List[List[int]]],
     endpoints: Optional[List[List[int]]],
-    interpolation: InterpolationMode = InterpolationMode.BICUBIC,
+    interpolation: Union[InterpolationMode, int] = InterpolationMode.BICUBIC,
     fill: datapoints.FillTypeJIT = None,
     coefficients: Optional[List[float]] = None,
 ) -> PIL.Image.Image:
     perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
+    interpolation = _check_interpolation(interpolation)
     return _FP.perspective(image, perspective_coeffs, interpolation=pil_modes_mapping[interpolation], fill=fill)
 
 
@@ -1455,7 +1479,7 @@ def perspective_video(
     video: torch.Tensor,
     startpoints: Optional[List[List[int]]],
     endpoints: Optional[List[List[int]]],
-    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
     fill: datapoints.FillTypeJIT = None,
     coefficients: Optional[List[float]] = None,
 ) -> torch.Tensor:
@@ -1468,7 +1492,7 @@ def perspective(
     inpt: datapoints.InputTypeJIT,
     startpoints: Optional[List[List[int]]],
     endpoints: Optional[List[List[int]]],
-    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
     fill: datapoints.FillTypeJIT = None,
     coefficients: Optional[List[float]] = None,
 ) -> datapoints.InputTypeJIT:
@@ -1496,9 +1520,11 @@ def perspective(
 def elastic_image_tensor(
     image: torch.Tensor,
     displacement: torch.Tensor,
-    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
     fill: datapoints.FillTypeJIT = None,
 ) -> torch.Tensor:
+    interpolation = _check_interpolation(interpolation)
+
     if image.numel() == 0:
         return image
 
@@ -1537,7 +1563,7 @@ def elastic_image_tensor(
 def elastic_image_pil(
     image: PIL.Image.Image,
     displacement: torch.Tensor,
-    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
     fill: datapoints.FillTypeJIT = None,
 ) -> PIL.Image.Image:
     t_img = pil_to_tensor(image)
@@ -1630,7 +1656,7 @@ def elastic_mask(
 def elastic_video(
     video: torch.Tensor,
     displacement: torch.Tensor,
-    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
     fill: datapoints.FillTypeJIT = None,
 ) -> torch.Tensor:
     return elastic_image_tensor(video, displacement, interpolation=interpolation, fill=fill)
@@ -1639,7 +1665,7 @@ def elastic_video(
 def elastic(
     inpt: datapoints.InputTypeJIT,
     displacement: torch.Tensor,
-    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
     fill: datapoints.FillTypeJIT = None,
 ) -> datapoints.InputTypeJIT:
     if not torch.jit.is_scripting():
@@ -1778,7 +1804,7 @@ def resized_crop_image_tensor(
     height: int,
     width: int,
     size: List[int],
-    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
     antialias: Optional[Union[str, bool]] = "warn",
 ) -> torch.Tensor:
     image = crop_image_tensor(image, top, left, height, width)
@@ -1793,7 +1819,7 @@ def resized_crop_image_pil(
     height: int,
     width: int,
     size: List[int],
-    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
 ) -> PIL.Image.Image:
     image = crop_image_pil(image, top, left, height, width)
     return resize_image_pil(image, size, interpolation=interpolation)
@@ -1831,7 +1857,7 @@ def resized_crop_video(
     height: int,
     width: int,
     size: List[int],
-    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
     antialias: Optional[Union[str, bool]] = "warn",
 ) -> torch.Tensor:
     return resized_crop_image_tensor(
@@ -1846,7 +1872,7 @@ def resized_crop(
     height: int,
     width: int,
     size: List[int],
-    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
     antialias: Optional[Union[str, bool]] = "warn",
 ) -> datapoints.InputTypeJIT:
     if not torch.jit.is_scripting():
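
For context, a minimal usage sketch of what this patch enables, assuming a torchvision build that ships the prototype transforms namespace (the snippet below is illustrative and not part of the diff):

    # Sketch only: Pillow integer constants are now accepted where an
    # InterpolationMode used to be required, because the constructors and kernels
    # normalize them through the private _check_interpolation helper added above.
    import PIL.Image
    from torchvision.prototype import transforms
    from torchvision.transforms import InterpolationMode

    # PIL.Image.BILINEAR is the integer 2, which _interpolation_modes_from_int
    # maps to InterpolationMode.BILINEAR.
    t = transforms.Resize((224, 224), interpolation=PIL.Image.BILINEAR, antialias=True)
    assert t.interpolation is InterpolationMode.BILINEAR

    # Anything that is neither an int nor an InterpolationMode raises a ValueError.

Converting at the boundary keeps the stored attribute a proper InterpolationMode, so the downstream kernels never have to special-case raw integers themselves.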