Skip to content

Commit 9e71fda

Browse files
authored
Plural to singular name change. (#3055)
1 parent 0c44513 commit 9e71fda

File tree

5 files changed

+103
-103
lines changed

5 files changed

+103
-103
lines changed

test/test_functional_tensor.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@
99
import torchvision.transforms.functional_tensor as F_t
1010
import torchvision.transforms.functional_pil as F_pil
1111
import torchvision.transforms.functional as F
12-
from torchvision.transforms import InterpolationModes
12+
from torchvision.transforms import InterpolationMode
1313

1414
from common_utils import TransformsTester
1515

1616

17-
NEAREST, BILINEAR, BICUBIC = InterpolationModes.NEAREST, InterpolationModes.BILINEAR, InterpolationModes.BICUBIC
17+
NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC
1818

1919

2020
class Tester(TransformsTester):
@@ -419,7 +419,7 @@ def test_resize(self):
419419
)
420420

421421
# assert changed type warning
422-
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationModes"):
422+
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
423423
res1 = F.resize(tensor, size=32, interpolation=2)
424424
res2 = F.resize(tensor, size=32, interpolation=BILINEAR)
425425
self.assertTrue(res1.equal(res2))
@@ -626,7 +626,7 @@ def test_affine(self):
626626
self.assertTrue(res1.equal(res2))
627627

628628
# assert changed type warning
629-
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationModes"):
629+
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
630630
res1 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=2)
631631
res2 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=BILINEAR)
632632
self.assertTrue(res1.equal(res2))
@@ -714,7 +714,7 @@ def test_rotate(self):
714714
self.assertTrue(res1.equal(res2))
715715

716716
# assert changed type warning
717-
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationModes"):
717+
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
718718
res1 = F.rotate(tensor, 45, interpolation=2)
719719
res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
720720
self.assertTrue(res1.equal(res2))
@@ -788,7 +788,7 @@ def test_perspective(self):
788788
# assert changed type warning
789789
spoints = [[0, 0], [33, 0], [33, 25], [0, 25]]
790790
epoints = [[3, 2], [32, 3], [30, 24], [2, 25]]
791-
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationModes"):
791+
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
792792
res1 = F.perspective(tensor, startpoints=spoints, endpoints=epoints, interpolation=2)
793793
res2 = F.perspective(tensor, startpoints=spoints, endpoints=epoints, interpolation=BILINEAR)
794794
self.assertTrue(res1.equal(res2))

test/test_transforms.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1500,12 +1500,12 @@ def test_random_rotation(self):
15001500
# assert deprecation warning and non-BC
15011501
with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"):
15021502
t = transforms.RandomRotation((-10, 10), resample=2)
1503-
self.assertEqual(t.interpolation, transforms.InterpolationModes.BILINEAR)
1503+
self.assertEqual(t.interpolation, transforms.InterpolationMode.BILINEAR)
15041504

15051505
# assert changed type warning
1506-
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationModes"):
1506+
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
15071507
t = transforms.RandomRotation((-10, 10), interpolation=2)
1508-
self.assertEqual(t.interpolation, transforms.InterpolationModes.BILINEAR)
1508+
self.assertEqual(t.interpolation, transforms.InterpolationMode.BILINEAR)
15091509

15101510
def test_random_affine(self):
15111511

@@ -1547,22 +1547,22 @@ def test_random_affine(self):
15471547
# Checking if RandomAffine can be printed as string
15481548
t.__repr__()
15491549

1550-
t = transforms.RandomAffine(10, interpolation=transforms.InterpolationModes.BILINEAR)
1550+
t = transforms.RandomAffine(10, interpolation=transforms.InterpolationMode.BILINEAR)
15511551
self.assertIn("bilinear", t.__repr__())
15521552

15531553
# assert deprecation warning and non-BC
15541554
with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"):
15551555
t = transforms.RandomAffine(10, resample=2)
1556-
self.assertEqual(t.interpolation, transforms.InterpolationModes.BILINEAR)
1556+
self.assertEqual(t.interpolation, transforms.InterpolationMode.BILINEAR)
15571557

15581558
with self.assertWarnsRegex(UserWarning, r"Argument fillcolor is deprecated and will be removed"):
15591559
t = transforms.RandomAffine(10, fillcolor=10)
15601560
self.assertEqual(t.fill, 10)
15611561

15621562
# assert changed type warning
1563-
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationModes"):
1563+
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
15641564
t = transforms.RandomAffine(10, interpolation=2)
1565-
self.assertEqual(t.interpolation, transforms.InterpolationModes.BILINEAR)
1565+
self.assertEqual(t.interpolation, transforms.InterpolationMode.BILINEAR)
15661566

15671567
def test_to_grayscale(self):
15681568
"""Unit tests for grayscale transform"""

test/test_transforms_tensor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torch
33
from torchvision import transforms as T
44
from torchvision.transforms import functional as F
5-
from torchvision.transforms import InterpolationModes
5+
from torchvision.transforms import InterpolationMode
66

77
import numpy as np
88

@@ -11,7 +11,7 @@
1111
from common_utils import TransformsTester, get_tmp_dir, int_dtypes, float_dtypes
1212

1313

14-
NEAREST, BILINEAR, BICUBIC = InterpolationModes.NEAREST, InterpolationModes.BILINEAR, InterpolationModes.BICUBIC
14+
NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC
1515

1616

1717
class Tester(TransformsTester):

torchvision/transforms/functional.py

Lines changed: 53 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from . import functional_tensor as F_t
2121

2222

23-
class InterpolationModes(Enum):
23+
class InterpolationMode(Enum):
2424
"""Interpolation modes
2525
"""
2626
NEAREST = "nearest"
@@ -33,26 +33,26 @@ class InterpolationModes(Enum):
3333

3434

3535
# TODO: Once torchscript supports Enums with staticmethod
36-
# this can be put into InterpolationModes as staticmethod
37-
def _interpolation_modes_from_int(i: int) -> InterpolationModes:
36+
# this can be put into InterpolationMode as staticmethod
37+
def _interpolation_modes_from_int(i: int) -> InterpolationMode:
3838
inverse_modes_mapping = {
39-
0: InterpolationModes.NEAREST,
40-
2: InterpolationModes.BILINEAR,
41-
3: InterpolationModes.BICUBIC,
42-
4: InterpolationModes.BOX,
43-
5: InterpolationModes.HAMMING,
44-
1: InterpolationModes.LANCZOS,
39+
0: InterpolationMode.NEAREST,
40+
2: InterpolationMode.BILINEAR,
41+
3: InterpolationMode.BICUBIC,
42+
4: InterpolationMode.BOX,
43+
5: InterpolationMode.HAMMING,
44+
1: InterpolationMode.LANCZOS,
4545
}
4646
return inverse_modes_mapping[i]
4747

4848

4949
pil_modes_mapping = {
50-
InterpolationModes.NEAREST: 0,
51-
InterpolationModes.BILINEAR: 2,
52-
InterpolationModes.BICUBIC: 3,
53-
InterpolationModes.BOX: 4,
54-
InterpolationModes.HAMMING: 5,
55-
InterpolationModes.LANCZOS: 1,
50+
InterpolationMode.NEAREST: 0,
51+
InterpolationMode.BILINEAR: 2,
52+
InterpolationMode.BICUBIC: 3,
53+
InterpolationMode.BOX: 4,
54+
InterpolationMode.HAMMING: 5,
55+
InterpolationMode.LANCZOS: 1,
5656
}
5757

5858
_is_pil_image = F_pil._is_pil_image
@@ -329,7 +329,7 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
329329
return tensor
330330

331331

332-
def resize(img: Tensor, size: List[int], interpolation: InterpolationModes = InterpolationModes.BILINEAR) -> Tensor:
332+
def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR) -> Tensor:
333333
r"""Resize the input image to the given size.
334334
The image can be a PIL Image or a torch Tensor, in which case it is expected
335335
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
@@ -343,10 +343,10 @@ def resize(img: Tensor, size: List[int], interpolation: InterpolationModes = Int
343343
:math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`.
344344
In torchscript mode size as single int is not supported, use a tuple or
345345
list of length 1: ``[size, ]``.
346-
interpolation (InterpolationModes): Desired interpolation enum defined by
347-
:class:`torchvision.transforms.InterpolationModes`.
348-
Default is ``InterpolationModes.BILINEAR``. If input is Tensor, only ``InterpolationModes.NEAREST``,
349-
``InterpolationModes.BILINEAR`` and ``InterpolationModes.BICUBIC`` are supported.
346+
interpolation (InterpolationMode): Desired interpolation enum defined by
347+
:class:`torchvision.transforms.InterpolationMode`.
348+
Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
349+
``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
350350
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
351351
352352
Returns:
@@ -355,13 +355,13 @@ def resize(img: Tensor, size: List[int], interpolation: InterpolationModes = Int
355355
# Backward compatibility with integer value
356356
if isinstance(interpolation, int):
357357
warnings.warn(
358-
"Argument interpolation should be of type InterpolationModes instead of int. "
359-
"Please, use InterpolationModes enum."
358+
"Argument interpolation should be of type InterpolationMode instead of int. "
359+
"Please, use InterpolationMode enum."
360360
)
361361
interpolation = _interpolation_modes_from_int(interpolation)
362362

363-
if not isinstance(interpolation, InterpolationModes):
364-
raise TypeError("Argument interpolation should be a InterpolationModes")
363+
if not isinstance(interpolation, InterpolationMode):
364+
raise TypeError("Argument interpolation should be a InterpolationMode")
365365

366366
if not isinstance(img, torch.Tensor):
367367
pil_interpolation = pil_modes_mapping[interpolation]
@@ -475,7 +475,7 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
475475

476476
def resized_crop(
477477
img: Tensor, top: int, left: int, height: int, width: int, size: List[int],
478-
interpolation: InterpolationModes = InterpolationModes.BILINEAR
478+
interpolation: InterpolationMode = InterpolationMode.BILINEAR
479479
) -> Tensor:
480480
"""Crop the given image and resize it to desired size.
481481
The image can be a PIL Image or a Tensor, in which case it is expected
@@ -490,10 +490,10 @@ def resized_crop(
490490
height (int): Height of the crop box.
491491
width (int): Width of the crop box.
492492
size (sequence or int): Desired output size. Same semantics as ``resize``.
493-
interpolation (InterpolationModes): Desired interpolation enum defined by
494-
:class:`torchvision.transforms.InterpolationModes`.
495-
Default is ``InterpolationModes.BILINEAR``. If input is Tensor, only ``InterpolationModes.NEAREST``,
496-
``InterpolationModes.BILINEAR`` and ``InterpolationModes.BICUBIC`` are supported.
493+
interpolation (InterpolationMode): Desired interpolation enum defined by
494+
:class:`torchvision.transforms.InterpolationMode`.
495+
Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
496+
``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
497497
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
498498
499499
Returns:
@@ -556,7 +556,7 @@ def perspective(
556556
img: Tensor,
557557
startpoints: List[List[int]],
558558
endpoints: List[List[int]],
559-
interpolation: InterpolationModes = InterpolationModes.BILINEAR,
559+
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
560560
fill: Optional[int] = None
561561
) -> Tensor:
562562
"""Perform perspective transform of the given image.
@@ -569,9 +569,9 @@ def perspective(
569569
``[top-left, top-right, bottom-right, bottom-left]`` of the original image.
570570
endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
571571
``[top-left, top-right, bottom-right, bottom-left]`` of the transformed image.
572-
interpolation (InterpolationModes): Desired interpolation enum defined by
573-
:class:`torchvision.transforms.InterpolationModes`. Default is ``InterpolationModes.BILINEAR``.
574-
If input is Tensor, only ``InterpolationModes.NEAREST``, ``InterpolationModes.BILINEAR`` are supported.
572+
interpolation (InterpolationMode): Desired interpolation enum defined by
573+
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
574+
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
575575
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
576576
fill (n-tuple or int or float): Pixel fill value for area outside the rotated
577577
image. If int or float, the value is used for all bands respectively.
@@ -587,13 +587,13 @@ def perspective(
587587
# Backward compatibility with integer value
588588
if isinstance(interpolation, int):
589589
warnings.warn(
590-
"Argument interpolation should be of type InterpolationModes instead of int. "
591-
"Please, use InterpolationModes enum."
590+
"Argument interpolation should be of type InterpolationMode instead of int. "
591+
"Please, use InterpolationMode enum."
592592
)
593593
interpolation = _interpolation_modes_from_int(interpolation)
594594

595-
if not isinstance(interpolation, InterpolationModes):
596-
raise TypeError("Argument interpolation should be a InterpolationModes")
595+
if not isinstance(interpolation, InterpolationMode):
596+
raise TypeError("Argument interpolation should be a InterpolationMode")
597597

598598
if not isinstance(img, torch.Tensor):
599599
pil_interpolation = pil_modes_mapping[interpolation]
@@ -869,7 +869,7 @@ def _get_inverse_affine_matrix(
869869

870870

871871
def rotate(
872-
img: Tensor, angle: float, interpolation: InterpolationModes = InterpolationModes.NEAREST,
872+
img: Tensor, angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST,
873873
expand: bool = False, center: Optional[List[int]] = None,
874874
fill: Optional[int] = None, resample: Optional[int] = None
875875
) -> Tensor:
@@ -880,9 +880,9 @@ def rotate(
880880
Args:
881881
img (PIL Image or Tensor): image to be rotated.
882882
angle (float or int): rotation angle value in degrees, counter-clockwise.
883-
interpolation (InterpolationModes): Desired interpolation enum defined by
884-
:class:`torchvision.transforms.InterpolationModes`. Default is ``InterpolationModes.NEAREST``.
885-
If input is Tensor, only ``InterpolationModes.NEAREST``, ``InterpolationModes.BILINEAR`` are supported.
883+
interpolation (InterpolationMode): Desired interpolation enum defined by
884+
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
885+
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
886886
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
887887
expand (bool, optional): Optional expansion flag.
888888
If true, expands the output image to make it large enough to hold the entire rotated image.
@@ -913,8 +913,8 @@ def rotate(
913913
# Backward compatibility with integer value
914914
if isinstance(interpolation, int):
915915
warnings.warn(
916-
"Argument interpolation should be of type InterpolationModes instead of int. "
917-
"Please, use InterpolationModes enum."
916+
"Argument interpolation should be of type InterpolationMode instead of int. "
917+
"Please, use InterpolationMode enum."
918918
)
919919
interpolation = _interpolation_modes_from_int(interpolation)
920920

@@ -924,8 +924,8 @@ def rotate(
924924
if center is not None and not isinstance(center, (list, tuple)):
925925
raise TypeError("Argument center should be a sequence")
926926

927-
if not isinstance(interpolation, InterpolationModes):
928-
raise TypeError("Argument interpolation should be a InterpolationModes")
927+
if not isinstance(interpolation, InterpolationMode):
928+
raise TypeError("Argument interpolation should be a InterpolationMode")
929929

930930
if not isinstance(img, torch.Tensor):
931931
pil_interpolation = pil_modes_mapping[interpolation]
@@ -945,7 +945,7 @@ def rotate(
945945

946946
def affine(
947947
img: Tensor, angle: float, translate: List[int], scale: float, shear: List[float],
948-
interpolation: InterpolationModes = InterpolationModes.NEAREST, fill: Optional[int] = None,
948+
interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[int] = None,
949949
resample: Optional[int] = None, fillcolor: Optional[int] = None
950950
) -> Tensor:
951951
"""Apply affine transformation on the image keeping image center invariant.
@@ -960,9 +960,9 @@ def affine(
960960
shear (float or tuple or list): shear angle value in degrees between -180 to 180, clockwise direction.
961961
If a tuple of list is specified, the first value corresponds to a shear parallel to the x axis, while
962962
the second value corresponds to a shear parallel to the y axis.
963-
interpolation (InterpolationModes): Desired interpolation enum defined by
964-
:class:`torchvision.transforms.InterpolationModes`. Default is ``InterpolationModes.NEAREST``.
965-
If input is Tensor, only ``InterpolationModes.NEAREST``, ``InterpolationModes.BILINEAR`` are supported.
963+
interpolation (InterpolationMode): Desired interpolation enum defined by
964+
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
965+
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
966966
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
967967
fill (int): Optional fill color for the area outside the transform in the output image (Pillow>=5.0.0).
968968
This option is not supported for Tensor input. Fill value for the area outside the transform in the output
@@ -984,8 +984,8 @@ def affine(
984984
# Backward compatibility with integer value
985985
if isinstance(interpolation, int):
986986
warnings.warn(
987-
"Argument interpolation should be of type InterpolationModes instead of int. "
988-
"Please, use InterpolationModes enum."
987+
"Argument interpolation should be of type InterpolationMode instead of int. "
988+
"Please, use InterpolationMode enum."
989989
)
990990
interpolation = _interpolation_modes_from_int(interpolation)
991991

@@ -1010,8 +1010,8 @@ def affine(
10101010
if not isinstance(shear, (numbers.Number, (list, tuple))):
10111011
raise TypeError("Shear should be either a single value or a sequence of two values")
10121012

1013-
if not isinstance(interpolation, InterpolationModes):
1014-
raise TypeError("Argument interpolation should be a InterpolationModes")
1013+
if not isinstance(interpolation, InterpolationMode):
1014+
raise TypeError("Argument interpolation should be a InterpolationMode")
10151015

10161016
if isinstance(angle, int):
10171017
angle = float(angle)

0 commit comments

Comments
 (0)