
Commit c009b0f

Adapted almost all functional tensor tests on CPU/CUDA

- fixed bug with transforms using generated grid
- remains *_crop, blocked by pytorch#2568
- TODO: test_adjustments
1 parent a75fdd4 commit c009b0f

File tree

- test/test_functional_tensor.py
- torchvision/transforms/functional.py
- torchvision/transforms/functional_tensor.py

3 files changed: +88 -56 lines changed

test/test_functional_tensor.py

Lines changed: 82 additions & 55 deletions
@@ -15,18 +15,31 @@
 import torchvision.transforms.functional as F
 
 
+def run_on_cpu_and_cuda(fn):
+
+    devices = ["cpu", ]
+    if torch.cuda.is_available():
+        devices.append("cuda")
+
+    def wrapper(self, *args, **kwargs):
+        for device in devices:
+            fn(self, device, *args, **kwargs)
+
+    return wrapper
+
+
 class Tester(unittest.TestCase):
 
-    def _create_data(self, height=3, width=3, channels=3):
-        tensor = torch.randint(0, 255, (channels, height, width), dtype=torch.uint8)
-        pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().numpy())
+    def _create_data(self, height=3, width=3, channels=3, device="cpu"):
+        tensor = torch.randint(0, 255, (channels, height, width), dtype=torch.uint8, device=device)
+        pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().cpu().numpy())
         return tensor, pil_img
 
     def compareTensorToPIL(self, tensor, pil_image, msg=None):
         pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1)))
         if msg is None:
             msg = "tensor:\n{} \ndid not equal PIL tensor:\n{}".format(tensor, pil_tensor)
-        self.assertTrue(tensor.equal(pil_tensor), msg)
+        self.assertTrue(tensor.cpu().equal(pil_tensor), msg)
 
     def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None):
         pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1))).to(tensor)
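For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of how the new run_on_cpu_and_cuda helper fans one test method out over every available device. The Demo class and test_zeros method are illustrative only, not part of the commit:

    import unittest

    import torch


    def run_on_cpu_and_cuda(fn):
        # Devices are collected once, at decoration time.
        devices = ["cpu", ]
        if torch.cuda.is_available():
            devices.append("cuda")

        def wrapper(self, *args, **kwargs):
            # unittest still invokes the test as wrapper(self); the wrapper
            # injects each device string as the first extra argument.
            for device in devices:
                fn(self, device, *args, **kwargs)

        return wrapper


    class Demo(unittest.TestCase):

        @run_on_cpu_and_cuda
        def test_zeros(self, device):
            t = torch.zeros(2, 2, device=device)
            self.assertEqual(t.device.type, device)


    if __name__ == "__main__":
        unittest.main()

One trade-off of this approach: both devices run under a single test name, so a failure message has to carry the device itself (or leave it to be inferred from the compared tensors) rather than showing up in the test id.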
@@ -36,9 +49,10 @@ def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None):
             msg="{}: mae={}, tol={}: \n{}\nvs\n{}".format(msg, mae, tol, tensor[0, :10, :10], pil_tensor[0, :10, :10])
         )
 
-    def test_vflip(self):
+    @run_on_cpu_and_cuda
+    def test_vflip(self, device):
         script_vflip = torch.jit.script(F_t.vflip)
-        img_tensor = torch.randn(3, 16, 16)
+        img_tensor = torch.randn(3, 16, 16).to(device=device)
         img_tensor_clone = img_tensor.clone()
         vflipped_img = F_t.vflip(img_tensor)
         vflipped_img_again = F_t.vflip(vflipped_img)
@@ -49,9 +63,10 @@ def test_vflip(self):
         vflipped_img_script = script_vflip(img_tensor)
         self.assertTrue(torch.equal(vflipped_img, vflipped_img_script))
 
-    def test_hflip(self):
+    @run_on_cpu_and_cuda
+    def test_hflip(self, device):
         script_hflip = torch.jit.script(F_t.hflip)
-        img_tensor = torch.randn(3, 16, 16)
+        img_tensor = torch.randn(3, 16, 16).to(device)
         img_tensor_clone = img_tensor.clone()
         hflipped_img = F_t.hflip(img_tensor)
         hflipped_img_again = F_t.hflip(hflipped_img)
@@ -62,9 +77,10 @@ def test_hflip(self):
         hflipped_img_script = script_hflip(img_tensor)
         self.assertTrue(torch.equal(hflipped_img, hflipped_img_script))
 
-    def test_crop(self):
+    @run_on_cpu_and_cuda
+    def test_crop(self, device):
         script_crop = torch.jit.script(F_t.crop)
-        img_tensor = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8)
+        img_tensor = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8, device=device)
         img_tensor_clone = img_tensor.clone()
         top = random.randint(0, 15)
         left = random.randint(0, 15)
@@ -73,7 +89,7 @@ def test_crop(self):
         img_cropped = F_t.crop(img_tensor, top, left, height, width)
         img_PIL = transforms.ToPILImage()(img_tensor)
         img_PIL_cropped = F.crop(img_PIL, top, left, height, width)
-        img_cropped_GT = transforms.ToTensor()(img_PIL_cropped)
+        img_cropped_GT = transforms.ToTensor()(img_PIL_cropped).to(device)
         self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
         self.assertTrue(torch.equal(img_cropped, (img_cropped_GT * 255).to(torch.uint8)),
                         "functional_tensor crop not working")
@@ -203,17 +219,15 @@ def test_center_crop(self):
         img_tensor_clone = img_tensor.clone()
         cropped_tensor = F_t.center_crop(img_tensor, [10, 10])
         cropped_pil_image = F.center_crop(transforms.ToPILImage()(img_tensor), [10, 10])
-        cropped_pil_tensor = (transforms.ToTensor()(cropped_pil_image) * 255).to(torch.uint8)
+        cropped_pil_tensor = (transforms.ToTensor()(cropped_pil_image) * 255).to(dtype=torch.uint8)
         self.assertTrue(torch.equal(cropped_tensor, cropped_pil_tensor))
         self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
         # scriptable function test
         cropped_script = script_center_crop(img_tensor, [10, 10])
         self.assertTrue(torch.equal(cropped_script, cropped_tensor))
 
     def test_five_crop(self):
-        script_five_crop = torch.jit.script(F_t.five_crop)
         img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8)
-        img_tensor_clone = img_tensor.clone()
         cropped_tensor = F_t.five_crop(img_tensor, [10, 10])
         cropped_pil_image = F.five_crop(transforms.ToPILImage()(img_tensor), [10, 10])
         self.assertTrue(torch.equal(cropped_tensor[0],
@@ -226,11 +240,6 @@ def test_five_crop(self):
                                     (transforms.ToTensor()(cropped_pil_image[3]) * 255).to(torch.uint8)))
         self.assertTrue(torch.equal(cropped_tensor[4],
                                     (transforms.ToTensor()(cropped_pil_image[4]) * 255).to(torch.uint8)))
-        self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
-        # scriptable function test
-        cropped_script = script_five_crop(img_tensor, [10, 10])
-        for cropped_script_img, cropped_tensor_img in zip(cropped_script, cropped_tensor):
-            self.assertTrue(torch.equal(cropped_script_img, cropped_tensor_img))
 
     def test_ten_crop(self):
         script_ten_crop = torch.jit.script(F_t.ten_crop)
@@ -264,9 +273,10 @@ def test_ten_crop(self):
         for cropped_script_img, cropped_tensor_img in zip(cropped_script, cropped_tensor):
             self.assertTrue(torch.equal(cropped_script_img, cropped_tensor_img))
 
-    def test_pad(self):
+    @run_on_cpu_and_cuda
+    def test_pad(self, device):
         script_fn = torch.jit.script(F_t.pad)
-        tensor, pil_img = self._create_data(7, 8)
+        tensor, pil_img = self._create_data(7, 8, device=device)
 
         for dt in [None, torch.float32, torch.float64]:
             if dt is not None:
@@ -302,9 +312,10 @@ def test_pad(self):
         with self.assertRaises(ValueError, msg="Padding can not be negative for symmetric padding_mode"):
             F_t.pad(tensor, (-2, -3), padding_mode="symmetric")
 
-    def test_adjust_gamma(self):
-        script_fn = torch.jit.script(F_t.adjust_gamma)
-        tensor, pil_img = self._create_data(26, 36)
+    @run_on_cpu_and_cuda
+    def test_adjust_gamma(self, device):
+        script_fn = torch.jit.script(F.adjust_gamma)
+        tensor, pil_img = self._create_data(26, 36, device=device)
 
         for dt in [torch.float64, torch.float32, None]:
 
@@ -315,8 +326,8 @@ def test_adjust_gamma(self):
             gains = [0.7, 1.0, 1.3]
             for gamma, gain in zip(gammas, gains):
 
-                adjusted_tensor = F_t.adjust_gamma(tensor, gamma, gain)
-                adjusted_pil = F_pil.adjust_gamma(pil_img, gamma, gain)
+                adjusted_tensor = F.adjust_gamma(tensor, gamma, gain)
+                adjusted_pil = F.adjust_gamma(pil_img, gamma, gain)
                 scripted_result = script_fn(tensor, gamma, gain)
                 self.assertEqual(adjusted_tensor.dtype, scripted_result.dtype)
                 self.assertEqual(adjusted_tensor.size()[1:], adjusted_pil.size[::-1])
@@ -327,11 +338,12 @@ def test_adjust_gamma(self):
 
                 self.compareTensorToPIL(rbg_tensor, adjusted_pil)
 
-                self.assertTrue(adjusted_tensor.equal(scripted_result))
+                self.assertTrue(adjusted_tensor.allclose(scripted_result))
 
-    def test_resize(self):
+    @run_on_cpu_and_cuda
+    def test_resize(self, device):
         script_fn = torch.jit.script(F_t.resize)
-        tensor, pil_img = self._create_data(26, 36)
+        tensor, pil_img = self._create_data(26, 36, device=device)
 
         for dt in [None, torch.float32, torch.float64]:
             if dt is not None:
@@ -367,28 +379,30 @@ def test_resize(self):
                     resize_result = script_fn(tensor, size=script_size, interpolation=interpolation)
                     self.assertTrue(resized_tensor.equal(resize_result), msg="{}, {}".format(size, interpolation))
 
-    def test_resized_crop(self):
+    @run_on_cpu_and_cuda
+    def test_resized_crop(self, device):
         # test values of F.resized_crop in several cases:
         # 1) resize to the same size, crop to the same size => should be identity
-        tensor, _ = self._create_data(26, 36)
+        tensor, _ = self._create_data(26, 36, device=device)
         for i in [0, 2, 3]:
             out_tensor = F.resized_crop(tensor, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=i)
             self.assertTrue(tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]))
 
         # 2) resize by half and crop a TL corner
-        tensor, _ = self._create_data(26, 36)
+        tensor, _ = self._create_data(26, 36, device=device)
         out_tensor = F.resized_crop(tensor, top=0, left=0, height=20, width=30, size=[10, 15], interpolation=0)
         expected_out_tensor = tensor[:, :20:2, :30:2]
         self.assertTrue(
             expected_out_tensor.equal(out_tensor),
             msg="{} vs {}".format(expected_out_tensor[0, :10, :10], out_tensor[0, :10, :10])
         )
 
-    def test_affine(self):
+    @run_on_cpu_and_cuda
+    def test_affine(self, device):
         # Tests on square and rectangular images
         scripted_affine = torch.jit.script(F.affine)
 
-        for tensor, pil_img in [self._create_data(26, 26), self._create_data(32, 26)]:
+        for tensor, pil_img in [self._create_data(26, 26, device=device), self._create_data(32, 26, device=device)]:
 
             # 1) identity map
             out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
@@ -412,8 +426,16 @@ def test_affine(self):
                 (180, torch.rot90(tensor, k=2, dims=(-1, -2))),
             ]
             for a, true_tensor in test_configs:
+
+                out_pil_img = F.affine(
+                    pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0
+                )
+                out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))).to(device)
+
                 for fn in [F.affine, scripted_affine]:
-                    out_tensor = fn(tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
+                    out_tensor = fn(
+                        tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0
+                    )
                     if true_tensor is not None:
                         self.assertTrue(
                             true_tensor.equal(out_tensor),
@@ -422,11 +444,6 @@ def test_affine(self):
                     else:
                         true_tensor = out_tensor
 
-                    out_pil_img = F.affine(
-                        pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0
-                    )
-                    out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
-
                     num_diff_pixels = (true_tensor != out_pil_tensor).sum().item() / 3.0
                     ratio_diff_pixels = num_diff_pixels / true_tensor.shape[-1] / true_tensor.shape[-2]
                     # Tolerance : less than 6% of different pixels
@@ -442,12 +459,16 @@ def test_affine(self):
                 90, 45, 15, -30, -60, -120
             ]
            for a in test_configs:
+
+                out_pil_img = F.affine(
+                    pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0
+                )
+                out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
+
                 for fn in [F.affine, scripted_affine]:
-                    out_tensor = fn(tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
-                    out_pil_img = F.affine(
-                        pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0
-                    )
-                    out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
+                    out_tensor = fn(
+                        tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0
+                    ).cpu()
 
                     num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
                     ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
@@ -465,9 +486,12 @@ def test_affine(self):
                 [10, 12], (-12, -13)
             ]
             for t in test_configs:
+
+                out_pil_img = F.affine(pil_img, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0)
+
                 for fn in [F.affine, scripted_affine]:
                     out_tensor = fn(tensor, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0)
-                    out_pil_img = F.affine(pil_img, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0)
+
                     self.compareTensorToPIL(out_tensor, out_pil_img)
 
             # 3) Test rotation + translation + scale + shear
@@ -489,23 +513,25 @@ def test_affine(self):
                 out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
 
                 for fn in [F.affine, scripted_affine]:
-                    out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, resample=r)
+                    out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, resample=r).cpu()
                     num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
                     ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
-                    # Tolerance : less than 5% of different pixels
+                    # Tolerance : less than 5% (cpu), 6% (cuda) of different pixels
+                    tol = 0.06 if device == "cuda" else 0.05
                     self.assertLess(
                         ratio_diff_pixels,
-                        0.05,
+                        tol,
                         msg="{}: {}\n{} vs \n{}".format(
                             (r, a, t, s, sh), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
                         )
                     )
 
-    def test_rotate(self):
+    @run_on_cpu_and_cuda
+    def test_rotate(self, device):
         # Tests on square image
         scripted_rotate = torch.jit.script(F.rotate)
 
-        for tensor, pil_img in [self._create_data(26, 26), self._create_data(32, 26)]:
+        for tensor, pil_img in [self._create_data(26, 26, device=device), self._create_data(32, 26, device=device)]:
 
             img_size = pil_img.size
             centers = [
@@ -522,7 +548,7 @@ def test_rotate(self):
                            out_pil_img = F.rotate(pil_img, angle=a, resample=r, expand=e, center=c)
                            out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
                            for fn in [F.rotate, scripted_rotate]:
-                               out_tensor = fn(tensor, angle=a, resample=r, expand=e, center=c)
+                               out_tensor = fn(tensor, angle=a, resample=r, expand=e, center=c).cpu()
 
                                self.assertEqual(
                                    out_tensor.shape,
@@ -545,11 +571,12 @@ def test_rotate(self):
                                    )
                                )
 
-    def test_perspective(self):
+    @run_on_cpu_and_cuda
+    def test_perspective(self, device):
 
         from torchvision.transforms import RandomPerspective
 
-        for tensor, pil_img in [self._create_data(26, 34), self._create_data(26, 26)]:
+        for tensor, pil_img in [self._create_data(26, 34, device=device), self._create_data(26, 26, device=device)]:
 
             scripted_tranform = torch.jit.script(F.perspective)
 
@@ -569,7 +596,7 @@ def test_perspective(self):
                out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
 
                for fn in [F.perspective, scripted_tranform]:
-                   out_tensor = fn(tensor, startpoints=spoints, endpoints=epoints, interpolation=r)
+                   out_tensor = fn(tensor, startpoints=spoints, endpoints=epoints, interpolation=r).cpu()
 
                    num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
                    ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]

torchvision/transforms/functional.py

Lines changed: 3 additions & 1 deletion
@@ -223,9 +223,11 @@ def to_pil_image(pic, mode=None):
         pic = np.expand_dims(pic, 2)
 
     npimg = pic
-    if isinstance(pic, torch.FloatTensor) and mode != 'F':
+    if pic.is_floating_point() and mode != 'F':
         pic = pic.mul(255).byte()
     if isinstance(pic, torch.Tensor):
+        if pic.device != torch.device("cpu"):
+            pic = pic.cpu()
         npimg = np.transpose(pic.numpy(), (1, 2, 0))
 
     if not isinstance(npimg, np.ndarray):
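The replaced check was subtly wrong beyond the device issue: torch.FloatTensor names the CPU float32 tensor type, so isinstance(pic, torch.FloatTensor) is False both for CUDA inputs (which are torch.cuda.FloatTensor) and for float64 inputs, and the 0-1 to 0-255 rescale was silently skipped for them; pic.is_floating_point() tests the dtype regardless of device. The explicit .cpu() is needed because Tensor.numpy() refuses to run on CUDA tensors. A hedged illustration, not taken from the commit:

    import torch

    t = torch.rand(3, 4, 4)                      # CPU, float32
    print(isinstance(t, torch.FloatTensor))      # True
    print(t.is_floating_point())                 # True

    d = torch.rand(3, 4, 4, dtype=torch.float64)
    print(isinstance(d, torch.FloatTensor))      # False: old check skipped the rescale
    print(d.is_floating_point())                 # True: new check applies it

    if torch.cuda.is_available():
        c = t.cuda()
        print(isinstance(c, torch.FloatTensor))  # False: c is a torch.cuda.FloatTensor
        print(c.is_floating_point())             # True: dtype test is device-agnostic
        # c.numpy() raises here, which is why to_pil_image now calls .cpu() first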

torchvision/transforms/functional_tensor.py

Lines changed: 3 additions & 0 deletions
@@ -659,6 +659,9 @@ def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str) -> Tensor:
         need_cast = True
         img = img.to(torch.float32)
 
+    if grid.device.type != img.device.type:
+        grid = grid.to(img)
+
     img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False)
 
     if need_squeeze:
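This hunk is the "bug with transforms using generated grid" from the commit message: affine, rotate and perspective sample the image with torch.nn.functional.grid_sample, which requires the image and the sampling grid to live on the same device, and the generated grid could end up on a different one than the input. A hedged sketch of the failure mode and the guard, assuming a CUDA machine and using affine_grid to stand in for the grid the transform helpers build:

    import torch
    from torch.nn.functional import affine_grid, grid_sample

    if torch.cuda.is_available():
        img = torch.rand(1, 3, 8, 8, device="cuda")

        # An identity sampling grid created on the CPU.
        theta = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])
        grid = affine_grid(theta, size=[1, 3, 8, 8], align_corners=False)

        # Without the guard, grid_sample(img, grid, ...) raises a
        # device-mismatch RuntimeError.
        if grid.device.type != img.device.type:
            grid = grid.to(img)  # Tensor.to(Tensor) matches device and dtype

        out = grid_sample(img, grid, mode="nearest", padding_mode="zeros", align_corners=False)
        assert out.device == img.device

Note that the guard compares device types, so it moves a CPU grid onto the CUDA image but would not reconcile two distinct CUDA ordinals; that is enough for the CPU-generated-grid case this commit fixes.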
