Skip to content

Commit 0c44513

Browse files
authored
[BC-breaking] Introduced InterpolationModes and deprecated arguments: resample and fillcolor (#2952)
* Deprecated arguments: resample and fillcolor Replaced by interpolation and fill * Updates according to the review * Added tests to check warnings and asserted BC * [WIP] Interpolation modes * Added InterpolationModes enum * Added supported for int values for interpolation for BC * Removed useless test code * Fix flake8
1 parent 240210c commit 0c44513

7 files changed

+416
-177
lines changed

test/test_functional_tensor.py

Lines changed: 77 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,19 @@
44
import math
55

66
import numpy as np
7-
from PIL.Image import NEAREST, BILINEAR, BICUBIC
87

98
import torch
109
import torchvision.transforms.functional_tensor as F_t
1110
import torchvision.transforms.functional_pil as F_pil
1211
import torchvision.transforms.functional as F
12+
from torchvision.transforms import InterpolationModes
1313

1414
from common_utils import TransformsTester
1515

1616

17+
NEAREST, BILINEAR, BICUBIC = InterpolationModes.NEAREST, InterpolationModes.BILINEAR, InterpolationModes.BICUBIC
18+
19+
1720
class Tester(TransformsTester):
1821

1922
def setUp(self):
@@ -365,7 +368,7 @@ def test_adjust_gamma(self):
365368
)
366369

367370
def test_resize(self):
368-
script_fn = torch.jit.script(F_t.resize)
371+
script_fn = torch.jit.script(F.resize)
369372
tensor, pil_img = self._create_data(26, 36, device=self.device)
370373
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
371374

@@ -382,14 +385,14 @@ def test_resize(self):
382385

383386
for size in [32, 26, [32, ], [32, 32], (32, 32), [26, 35]]:
384387
for interpolation in [BILINEAR, BICUBIC, NEAREST]:
385-
resized_tensor = F_t.resize(tensor, size=size, interpolation=interpolation)
386-
resized_pil_img = F_pil.resize(pil_img, size=size, interpolation=interpolation)
388+
resized_tensor = F.resize(tensor, size=size, interpolation=interpolation)
389+
resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation)
387390

388391
self.assertEqual(
389392
resized_tensor.size()[1:], resized_pil_img.size[::-1], msg="{}, {}".format(size, interpolation)
390393
)
391394

392-
if interpolation != NEAREST:
395+
if interpolation not in [NEAREST, ]:
393396
# We can not check values if mode = NEAREST, as results are different
394397
# E.g. resized_tensor = [[a, a, b, c, d, d, e, ...]]
395398
# E.g. resized_pil_img = [[a, b, c, c, d, e, f, ...]]
@@ -407,24 +410,32 @@ def test_resize(self):
407410
script_size = [size, ]
408411
else:
409412
script_size = size
413+
410414
resize_result = script_fn(tensor, size=script_size, interpolation=interpolation)
411415
self.assertTrue(resized_tensor.equal(resize_result), msg="{}, {}".format(size, interpolation))
412416

413417
self._test_fn_on_batch(
414418
batch_tensors, F.resize, size=script_size, interpolation=interpolation
415419
)
416420

421+
# assert changed type warning
422+
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationModes"):
423+
res1 = F.resize(tensor, size=32, interpolation=2)
424+
res2 = F.resize(tensor, size=32, interpolation=BILINEAR)
425+
self.assertTrue(res1.equal(res2))
426+
417427
def test_resized_crop(self):
418428
# test values of F.resized_crop in several cases:
419429
# 1) resize to the same size, crop to the same size => should be identity
420430
tensor, _ = self._create_data(26, 36, device=self.device)
421-
for i in [0, 2, 3]:
422-
out_tensor = F.resized_crop(tensor, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=i)
431+
432+
for mode in [NEAREST, BILINEAR, BICUBIC]:
433+
out_tensor = F.resized_crop(tensor, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=mode)
423434
self.assertTrue(tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]))
424435

425436
# 2) resize by half and crop a TL corner
426437
tensor, _ = self._create_data(26, 36, device=self.device)
427-
out_tensor = F.resized_crop(tensor, top=0, left=0, height=20, width=30, size=[10, 15], interpolation=0)
438+
out_tensor = F.resized_crop(tensor, top=0, left=0, height=20, width=30, size=[10, 15], interpolation=NEAREST)
428439
expected_out_tensor = tensor[:, :20:2, :30:2]
429440
self.assertTrue(
430441
expected_out_tensor.equal(out_tensor),
@@ -433,17 +444,19 @@ def test_resized_crop(self):
433444

434445
batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device)
435446
self._test_fn_on_batch(
436-
batch_tensors, F.resized_crop, top=1, left=2, height=20, width=30, size=[10, 15], interpolation=0
447+
batch_tensors, F.resized_crop, top=1, left=2, height=20, width=30, size=[10, 15], interpolation=NEAREST
437448
)
438449

439450
def _test_affine_identity_map(self, tensor, scripted_affine):
440451
# 1) identity map
441-
out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
452+
out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST)
442453

443454
self.assertTrue(
444455
tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
445456
)
446-
out_tensor = scripted_affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
457+
out_tensor = scripted_affine(
458+
tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST
459+
)
447460
self.assertTrue(
448461
tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
449462
)
@@ -461,13 +474,13 @@ def _test_affine_square_rotations(self, tensor, pil_img, scripted_affine):
461474
]
462475
for a, true_tensor in test_configs:
463476
out_pil_img = F.affine(
464-
pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0
477+
pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST
465478
)
466479
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))).to(self.device)
467480

468481
for fn in [F.affine, scripted_affine]:
469482
out_tensor = fn(
470-
tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0
483+
tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST
471484
)
472485
if true_tensor is not None:
473486
self.assertTrue(
@@ -496,13 +509,13 @@ def _test_affine_rect_rotations(self, tensor, pil_img, scripted_affine):
496509
for a in test_configs:
497510

498511
out_pil_img = F.affine(
499-
pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0
512+
pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST
500513
)
501514
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
502515

503516
for fn in [F.affine, scripted_affine]:
504517
out_tensor = fn(
505-
tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0
518+
tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST
506519
).cpu()
507520

508521
if out_tensor.dtype != torch.uint8:
@@ -526,10 +539,10 @@ def _test_affine_translations(self, tensor, pil_img, scripted_affine):
526539
]
527540
for t in test_configs:
528541

529-
out_pil_img = F.affine(pil_img, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0)
542+
out_pil_img = F.affine(pil_img, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST)
530543

531544
for fn in [F.affine, scripted_affine]:
532-
out_tensor = fn(tensor, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0)
545+
out_tensor = fn(tensor, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST)
533546

534547
if out_tensor.dtype != torch.uint8:
535548
out_tensor = out_tensor.to(torch.uint8)
@@ -550,13 +563,13 @@ def _test_affine_all_ops(self, tensor, pil_img, scripted_affine):
550563
(-45, [-10, -10], 1.2, [4.0, 5.0]),
551564
(-90, [0, 0], 1.0, [0.0, 0.0]),
552565
]
553-
for r in [0, ]:
566+
for r in [NEAREST, ]:
554567
for a, t, s, sh in test_configs:
555-
out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, resample=r)
568+
out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, interpolation=r)
556569
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
557570

558571
for fn in [F.affine, scripted_affine]:
559-
out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, resample=r).cpu()
572+
out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, interpolation=r).cpu()
560573

561574
if out_tensor.dtype != torch.uint8:
562575
out_tensor = out_tensor.to(torch.uint8)
@@ -605,18 +618,36 @@ def test_affine(self):
605618
batch_tensors, F.affine, angle=-43, translate=[-3, 4], scale=1.2, shear=[4.0, 5.0]
606619
)
607620

621+
tensor, pil_img = data[0]
622+
# assert deprecation warning and non-BC
623+
with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"):
624+
res1 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=2)
625+
res2 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=BILINEAR)
626+
self.assertTrue(res1.equal(res2))
627+
628+
# assert changed type warning
629+
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationModes"):
630+
res1 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=2)
631+
res2 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=BILINEAR)
632+
self.assertTrue(res1.equal(res2))
633+
634+
with self.assertWarnsRegex(UserWarning, r"Argument fillcolor is deprecated and will be removed"):
635+
res1 = F.affine(pil_img, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], fillcolor=10)
636+
res2 = F.affine(pil_img, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], fill=10)
637+
self.assertEqual(res1, res2)
638+
608639
def _test_rotate_all_options(self, tensor, pil_img, scripted_rotate, centers):
609640
img_size = pil_img.size
610641
dt = tensor.dtype
611-
for r in [0, ]:
642+
for r in [NEAREST, ]:
612643
for a in range(-180, 180, 17):
613644
for e in [True, False]:
614645
for c in centers:
615646

616-
out_pil_img = F.rotate(pil_img, angle=a, resample=r, expand=e, center=c)
647+
out_pil_img = F.rotate(pil_img, angle=a, interpolation=r, expand=e, center=c)
617648
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
618649
for fn in [F.rotate, scripted_rotate]:
619-
out_tensor = fn(tensor, angle=a, resample=r, expand=e, center=c).cpu()
650+
out_tensor = fn(tensor, angle=a, interpolation=r, expand=e, center=c).cpu()
620651

621652
if out_tensor.dtype != torch.uint8:
622653
out_tensor = out_tensor.to(torch.uint8)
@@ -673,12 +704,24 @@ def test_rotate(self):
673704

674705
center = (20, 22)
675706
self._test_fn_on_batch(
676-
batch_tensors, F.rotate, angle=32, resample=0, expand=True, center=center
707+
batch_tensors, F.rotate, angle=32, interpolation=NEAREST, expand=True, center=center
677708
)
709+
tensor, pil_img = data[0]
710+
# assert deprecation warning and non-BC
711+
with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"):
712+
res1 = F.rotate(tensor, 45, resample=2)
713+
res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
714+
self.assertTrue(res1.equal(res2))
715+
716+
# assert changed type warning
717+
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationModes"):
718+
res1 = F.rotate(tensor, 45, interpolation=2)
719+
res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
720+
self.assertTrue(res1.equal(res2))
678721

679722
def _test_perspective(self, tensor, pil_img, scripted_transform, test_configs):
680723
dt = tensor.dtype
681-
for r in [0, ]:
724+
for r in [NEAREST, ]:
682725
for spoints, epoints in test_configs:
683726
out_pil_img = F.perspective(pil_img, startpoints=spoints, endpoints=epoints, interpolation=r)
684727
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
@@ -739,9 +782,17 @@ def test_perspective(self):
739782

740783
for spoints, epoints in test_configs:
741784
self._test_fn_on_batch(
742-
batch_tensors, F.perspective, startpoints=spoints, endpoints=epoints, interpolation=0
785+
batch_tensors, F.perspective, startpoints=spoints, endpoints=epoints, interpolation=NEAREST
743786
)
744787

788+
# assert changed type warning
789+
spoints = [[0, 0], [33, 0], [33, 25], [0, 25]]
790+
epoints = [[3, 2], [32, 3], [30, 24], [2, 25]]
791+
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationModes"):
792+
res1 = F.perspective(tensor, startpoints=spoints, endpoints=epoints, interpolation=2)
793+
res2 = F.perspective(tensor, startpoints=spoints, endpoints=epoints, interpolation=BILINEAR)
794+
self.assertTrue(res1.equal(res2))
795+
745796
def test_gaussian_blur(self):
746797
small_image_tensor = torch.from_numpy(
747798
np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))

test/test_transforms.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1492,11 +1492,21 @@ def test_random_rotation(self):
14921492

14931493
t = transforms.RandomRotation((-10, 10))
14941494
angle = t.get_params(t.degrees)
1495-
self.assertTrue(angle > -10 and angle < 10)
1495+
self.assertTrue(-10 < angle < 10)
14961496

14971497
# Checking if RandomRotation can be printed as string
14981498
t.__repr__()
14991499

1500+
# assert deprecation warning and non-BC
1501+
with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"):
1502+
t = transforms.RandomRotation((-10, 10), resample=2)
1503+
self.assertEqual(t.interpolation, transforms.InterpolationModes.BILINEAR)
1504+
1505+
# assert changed type warning
1506+
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationModes"):
1507+
t = transforms.RandomRotation((-10, 10), interpolation=2)
1508+
self.assertEqual(t.interpolation, transforms.InterpolationModes.BILINEAR)
1509+
15001510
def test_random_affine(self):
15011511

15021512
with self.assertRaises(ValueError):
@@ -1537,8 +1547,22 @@ def test_random_affine(self):
15371547
# Checking if RandomAffine can be printed as string
15381548
t.__repr__()
15391549

1540-
t = transforms.RandomAffine(10, resample=Image.BILINEAR)
1541-
self.assertIn("Image.BILINEAR", t.__repr__())
1550+
t = transforms.RandomAffine(10, interpolation=transforms.InterpolationModes.BILINEAR)
1551+
self.assertIn("bilinear", t.__repr__())
1552+
1553+
# assert deprecation warning and non-BC
1554+
with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"):
1555+
t = transforms.RandomAffine(10, resample=2)
1556+
self.assertEqual(t.interpolation, transforms.InterpolationModes.BILINEAR)
1557+
1558+
with self.assertWarnsRegex(UserWarning, r"Argument fillcolor is deprecated and will be removed"):
1559+
t = transforms.RandomAffine(10, fillcolor=10)
1560+
self.assertEqual(t.fill, 10)
1561+
1562+
# assert changed type warning
1563+
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationModes"):
1564+
t = transforms.RandomAffine(10, interpolation=2)
1565+
self.assertEqual(t.interpolation, transforms.InterpolationModes.BILINEAR)
15421566

15431567
def test_to_grayscale(self):
15441568
"""Unit tests for grayscale transform"""

test/test_transforms_tensor.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22
import torch
33
from torchvision import transforms as T
44
from torchvision.transforms import functional as F
5-
6-
from PIL.Image import NEAREST, BILINEAR, BICUBIC
5+
from torchvision.transforms import InterpolationModes
76

87
import numpy as np
98

@@ -12,6 +11,9 @@
1211
from common_utils import TransformsTester, get_tmp_dir, int_dtypes, float_dtypes
1312

1413

14+
NEAREST, BILINEAR, BICUBIC = InterpolationModes.NEAREST, InterpolationModes.BILINEAR, InterpolationModes.BICUBIC
15+
16+
1517
class Tester(TransformsTester):
1618

1719
def setUp(self):
@@ -349,7 +351,7 @@ def test_random_affine(self):
349351
for interpolation in [NEAREST, BILINEAR]:
350352
transform = T.RandomAffine(
351353
degrees=degrees, translate=translate,
352-
scale=scale, shear=shear, resample=interpolation
354+
scale=scale, shear=shear, interpolation=interpolation
353355
)
354356
s_transform = torch.jit.script(transform)
355357

@@ -368,7 +370,7 @@ def test_random_rotate(self):
368370
for degrees in [45, 35.0, (-45, 45), [-90.0, 90.0]]:
369371
for interpolation in [NEAREST, BILINEAR]:
370372
transform = T.RandomRotation(
371-
degrees=degrees, resample=interpolation, expand=expand, center=center
373+
degrees=degrees, interpolation=interpolation, expand=expand, center=center
372374
)
373375
s_transform = torch.jit.script(transform)
374376

0 commit comments

Comments
 (0)