Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions test/test_prototype_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1534,7 +1534,7 @@ def test__get_params(self, mocker):
assert int(spatial_size[1] * r_min) <= width <= int(spatial_size[1] * r_max)

def test__transform(self, mocker):
interpolation_sentinel = mocker.MagicMock()
interpolation_sentinel = mocker.MagicMock(spec=InterpolationMode)
antialias_sentinel = mocker.MagicMock()

transform = transforms.ScaleJitter(
Expand Down Expand Up @@ -1581,7 +1581,7 @@ def test__get_params(self, min_size, max_size, mocker):
assert shorter in min_size

def test__transform(self, mocker):
interpolation_sentinel = mocker.MagicMock()
interpolation_sentinel = mocker.MagicMock(spec=InterpolationMode)
antialias_sentinel = mocker.MagicMock()

transform = transforms.RandomShortestSize(
Expand Down Expand Up @@ -1945,7 +1945,7 @@ def test__get_params(self):
assert min_size <= size < max_size

def test__transform(self, mocker):
interpolation_sentinel = mocker.MagicMock()
interpolation_sentinel = mocker.MagicMock(spec=InterpolationMode)
antialias_sentinel = mocker.MagicMock()

transform = transforms.RandomResize(
Expand Down
46 changes: 36 additions & 10 deletions test/test_prototype_transforms_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ def __init__(
ArgsKwargs((32, 29)),
ArgsKwargs((31, 28), interpolation=prototype_transforms.InterpolationMode.NEAREST),
ArgsKwargs((33, 26), interpolation=prototype_transforms.InterpolationMode.BICUBIC),
ArgsKwargs((30, 27), interpolation=PIL.Image.NEAREST),
ArgsKwargs((35, 29), interpolation=PIL.Image.BILINEAR),
ArgsKwargs((34, 25), interpolation=PIL.Image.BICUBIC),
NotScriptableArgsKwargs(31, max_size=32),
ArgsKwargs([31], max_size=32),
NotScriptableArgsKwargs(30, max_size=100),
Expand Down Expand Up @@ -305,6 +308,8 @@ def __init__(
ArgsKwargs(25, ratio=(0.5, 1.5)),
ArgsKwargs((31, 28), interpolation=prototype_transforms.InterpolationMode.NEAREST),
ArgsKwargs((33, 26), interpolation=prototype_transforms.InterpolationMode.BICUBIC),
ArgsKwargs((31, 28), interpolation=PIL.Image.NEAREST),
ArgsKwargs((33, 26), interpolation=PIL.Image.BICUBIC),
ArgsKwargs((29, 32), antialias=False),
ArgsKwargs((28, 31), antialias=True),
],
Expand Down Expand Up @@ -352,6 +357,8 @@ def __init__(
ArgsKwargs(sigma=(2.5, 3.9)),
ArgsKwargs(interpolation=prototype_transforms.InterpolationMode.NEAREST),
ArgsKwargs(interpolation=prototype_transforms.InterpolationMode.BICUBIC),
ArgsKwargs(interpolation=PIL.Image.NEAREST),
ArgsKwargs(interpolation=PIL.Image.BICUBIC),
ArgsKwargs(fill=1),
],
# ElasticTransform needs larger images to avoid the needed internal padding being larger than the actual image
Expand Down Expand Up @@ -386,6 +393,7 @@ def __init__(
ArgsKwargs(degrees=0.0, shear=(4, 5, 4, 13)),
ArgsKwargs(degrees=(-20.0, 10.0), translate=(0.4, 0.6), scale=(0.3, 0.8), shear=(4, 5, 4, 13)),
ArgsKwargs(degrees=30.0, interpolation=prototype_transforms.InterpolationMode.NEAREST),
ArgsKwargs(degrees=30.0, interpolation=PIL.Image.NEAREST),
ArgsKwargs(degrees=30.0, fill=1),
ArgsKwargs(degrees=30.0, fill=(2, 3, 4)),
ArgsKwargs(degrees=30.0, center=(0, 0)),
Expand Down Expand Up @@ -420,6 +428,7 @@ def __init__(
ArgsKwargs(p=1),
ArgsKwargs(p=1, distortion_scale=0.3),
ArgsKwargs(p=1, distortion_scale=0.2, interpolation=prototype_transforms.InterpolationMode.NEAREST),
ArgsKwargs(p=1, distortion_scale=0.2, interpolation=PIL.Image.NEAREST),
ArgsKwargs(p=1, distortion_scale=0.1, fill=1),
ArgsKwargs(p=1, distortion_scale=0.4, fill=(1, 2, 3)),
],
Expand All @@ -432,6 +441,7 @@ def __init__(
ArgsKwargs(degrees=30.0),
ArgsKwargs(degrees=(-20.0, 10.0)),
ArgsKwargs(degrees=30.0, interpolation=prototype_transforms.InterpolationMode.BILINEAR),
ArgsKwargs(degrees=30.0, interpolation=PIL.Image.BILINEAR),
ArgsKwargs(degrees=30.0, expand=True),
ArgsKwargs(degrees=30.0, center=(0, 0)),
ArgsKwargs(degrees=30.0, fill=1),
Expand Down Expand Up @@ -851,7 +861,11 @@ class TestAATransforms:
)
@pytest.mark.parametrize(
"interpolation",
[prototype_transforms.InterpolationMode.NEAREST, prototype_transforms.InterpolationMode.BILINEAR],
[
prototype_transforms.InterpolationMode.NEAREST,
prototype_transforms.InterpolationMode.BILINEAR,
PIL.Image.NEAREST,
],
)
def test_randaug(self, inpt, interpolation, mocker):
t_ref = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1)
Expand Down Expand Up @@ -889,7 +903,11 @@ def test_randaug(self, inpt, interpolation, mocker):
)
@pytest.mark.parametrize(
"interpolation",
[prototype_transforms.InterpolationMode.NEAREST, prototype_transforms.InterpolationMode.BILINEAR],
[
prototype_transforms.InterpolationMode.NEAREST,
prototype_transforms.InterpolationMode.BILINEAR,
PIL.Image.NEAREST,
],
)
def test_trivial_aug(self, inpt, interpolation, mocker):
t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation)
Expand Down Expand Up @@ -937,7 +955,11 @@ def test_trivial_aug(self, inpt, interpolation, mocker):
)
@pytest.mark.parametrize(
"interpolation",
[prototype_transforms.InterpolationMode.NEAREST, prototype_transforms.InterpolationMode.BILINEAR],
[
prototype_transforms.InterpolationMode.NEAREST,
prototype_transforms.InterpolationMode.BILINEAR,
PIL.Image.NEAREST,
],
)
def test_augmix(self, inpt, interpolation, mocker):
t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
Expand Down Expand Up @@ -986,7 +1008,11 @@ def test_augmix(self, inpt, interpolation, mocker):
)
@pytest.mark.parametrize(
"interpolation",
[prototype_transforms.InterpolationMode.NEAREST, prototype_transforms.InterpolationMode.BILINEAR],
[
prototype_transforms.InterpolationMode.NEAREST,
prototype_transforms.InterpolationMode.BILINEAR,
PIL.Image.NEAREST,
],
)
def test_aa(self, inpt, interpolation):
aa_policy = legacy_transforms.AutoAugmentPolicy("imagenet")
Expand Down Expand Up @@ -1264,13 +1290,13 @@ def test_random_resize_eval(self, mocker):
(legacy_F.convert_image_dtype, {}),
(legacy_F.to_pil_image, {}),
(legacy_F.normalize, {}),
(legacy_F.resize, {}),
(legacy_F.resize, {"interpolation"}),
(legacy_F.pad, {"padding", "fill"}),
(legacy_F.crop, {}),
(legacy_F.center_crop, {}),
(legacy_F.resized_crop, {}),
(legacy_F.resized_crop, {"interpolation"}),
(legacy_F.hflip, {}),
(legacy_F.perspective, {"startpoints", "endpoints", "fill"}),
(legacy_F.perspective, {"startpoints", "endpoints", "fill", "interpolation"}),
(legacy_F.vflip, {}),
(legacy_F.five_crop, {}),
(legacy_F.ten_crop, {}),
Expand All @@ -1279,8 +1305,8 @@ def test_random_resize_eval(self, mocker):
(legacy_F.adjust_saturation, {}),
(legacy_F.adjust_hue, {}),
(legacy_F.adjust_gamma, {}),
(legacy_F.rotate, {"center", "fill"}),
(legacy_F.affine, {"angle", "translate", "center", "fill"}),
(legacy_F.rotate, {"center", "fill", "interpolation"}),
(legacy_F.affine, {"angle", "translate", "center", "fill", "interpolation"}),
(legacy_F.to_grayscale, {}),
(legacy_F.rgb_to_grayscale, {}),
(legacy_F.to_tensor, {}),
Expand All @@ -1292,7 +1318,7 @@ def test_random_resize_eval(self, mocker):
(legacy_F.adjust_sharpness, {}),
(legacy_F.autocontrast, {}),
(legacy_F.equalize, {}),
(legacy_F.elastic_transform, {"fill"}),
(legacy_F.elastic_transform, {"fill", "interpolation"}),
],
)
def test_dispatcher_signature_consistency(legacy_dispatcher, name_only_params):
Expand Down
12 changes: 6 additions & 6 deletions torchvision/prototype/datapoints/_bounding_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def vertical_flip(self) -> BoundingBox:
def resize( # type: ignore[override]
self,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn",
) -> BoundingBox:
Expand Down Expand Up @@ -107,7 +107,7 @@ def resized_crop(
height: int,
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
) -> BoundingBox:
output, spatial_size = self._F.resized_crop_bounding_box(
Expand All @@ -133,7 +133,7 @@ def pad(
def rotate(
self,
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False,
center: Optional[List[float]] = None,
fill: FillTypeJIT = None,
Expand All @@ -154,7 +154,7 @@ def affine(
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> BoundingBox:
Expand All @@ -174,7 +174,7 @@ def perspective(
self,
startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
coefficients: Optional[List[float]] = None,
) -> BoundingBox:
Expand All @@ -191,7 +191,7 @@ def perspective(
def elastic(
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
) -> BoundingBox:
output = self._F.elastic_bounding_box(
Expand Down
12 changes: 6 additions & 6 deletions torchvision/prototype/datapoints/_datapoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def vertical_flip(self) -> Datapoint:
def resize( # type: ignore[override]
self,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn",
) -> Datapoint:
Expand All @@ -162,7 +162,7 @@ def resized_crop(
height: int,
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
) -> Datapoint:
return self
Expand All @@ -178,7 +178,7 @@ def pad(
def rotate(
self,
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False,
center: Optional[List[float]] = None,
fill: FillTypeJIT = None,
Expand All @@ -191,7 +191,7 @@ def affine(
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> Datapoint:
Expand All @@ -201,7 +201,7 @@ def perspective(
self,
startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
coefficients: Optional[List[float]] = None,
) -> Datapoint:
Expand All @@ -210,7 +210,7 @@ def perspective(
def elastic(
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
) -> Datapoint:
return self
Expand Down
12 changes: 6 additions & 6 deletions torchvision/prototype/datapoints/_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def vertical_flip(self) -> Image:
def resize( # type: ignore[override]
self,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn",
) -> Image:
Expand All @@ -86,7 +86,7 @@ def resized_crop(
height: int,
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
) -> Image:
output = self._F.resized_crop_image_tensor(
Expand All @@ -113,7 +113,7 @@ def pad(
def rotate(
self,
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False,
center: Optional[List[float]] = None,
fill: FillTypeJIT = None,
Expand All @@ -129,7 +129,7 @@ def affine(
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> Image:
Expand All @@ -149,7 +149,7 @@ def perspective(
self,
startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
coefficients: Optional[List[float]] = None,
) -> Image:
Expand All @@ -166,7 +166,7 @@ def perspective(
def elastic(
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
) -> Image:
output = self._F.elastic_image_tensor(
Expand Down
12 changes: 6 additions & 6 deletions torchvision/prototype/datapoints/_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def vertical_flip(self) -> Mask:
def resize( # type: ignore[override]
self,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn",
) -> Mask:
Expand All @@ -75,7 +75,7 @@ def resized_crop(
height: int,
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
antialias: Optional[Union[str, bool]] = "warn",
) -> Mask:
output = self._F.resized_crop_mask(self.as_subclass(torch.Tensor), top, left, height, width, size=size)
Expand All @@ -93,7 +93,7 @@ def pad(
def rotate(
self,
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False,
center: Optional[List[float]] = None,
fill: FillTypeJIT = None,
Expand All @@ -107,7 +107,7 @@ def affine(
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> Mask:
Expand All @@ -126,7 +126,7 @@ def perspective(
self,
startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: FillTypeJIT = None,
coefficients: Optional[List[float]] = None,
) -> Mask:
Expand All @@ -138,7 +138,7 @@ def perspective(
def elastic(
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: FillTypeJIT = None,
) -> Mask:
output = self._F.elastic_mask(self.as_subclass(torch.Tensor), displacement, fill=fill)
Expand Down
Loading