Skip to content

Commit 11a39aa

Browse files
authored
Make RandomHorizontalFlip torchscriptable (#2282)
* Make RandomHorizontalFlip torchscriptable * Make _is_tensor_a_torch_image more generic * Make RandomVerticalFlip torchscriptable (#2283) * Make RandomVerticalFlip torchscriptable * Fix lint
1 parent de52437 commit 11a39aa

File tree

5 files changed

+132
-25
lines changed

5 files changed

+132
-25
lines changed

test/test_transforms_tensor.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import torch
2+
from torchvision import transforms as T
3+
from torchvision.transforms import functional as F
4+
from PIL import Image
5+
6+
import numpy as np
7+
8+
import unittest
9+
10+
11+
class Tester(unittest.TestCase):
12+
def _create_data(self, height=3, width=3, channels=3):
13+
tensor = torch.randint(0, 255, (channels, height, width), dtype=torch.uint8)
14+
pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().numpy())
15+
return tensor, pil_img
16+
17+
def compareTensorToPIL(self, tensor, pil_image):
18+
pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1)))
19+
self.assertTrue(tensor.equal(pil_tensor))
20+
21+
def _test_flip(self, func, method):
22+
tensor, pil_img = self._create_data()
23+
flip_tensor = getattr(F, func)(tensor)
24+
flip_pil_img = getattr(F, func)(pil_img)
25+
self.compareTensorToPIL(flip_tensor, flip_pil_img)
26+
27+
scripted_fn = torch.jit.script(getattr(F, func))
28+
flip_tensor_script = scripted_fn(tensor)
29+
self.assertTrue(flip_tensor.equal(flip_tensor_script))
30+
31+
# test for class interface
32+
f = getattr(T, method)()
33+
scripted_fn = torch.jit.script(f)
34+
scripted_fn(tensor)
35+
36+
def test_random_horizontal_flip(self):
37+
self._test_flip('hflip', 'RandomHorizontalFlip')
38+
39+
def test_random_vertical_flip(self):
40+
self._test_flip('vflip', 'RandomVerticalFlip')
41+
42+
43+
if __name__ == '__main__':
44+
unittest.main()

torchvision/transforms/functional.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import torch
2+
from torch import Tensor
23
import math
34
from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION
45
try:
@@ -11,6 +12,9 @@
1112
from collections.abc import Sequence, Iterable
1213
import warnings
1314

15+
from . import functional_pil as F_pil
16+
from . import functional_tensor as F_t
17+
1418

1519
def _is_pil_image(img):
1620
if accimage is not None:
@@ -434,19 +438,22 @@ def resized_crop(img, top, left, height, width, size, interpolation=Image.BILINE
434438
return img
435439

436440

437-
def hflip(img):
438-
"""Horizontally flip the given PIL Image.
441+
def hflip(img: Tensor) -> Tensor:
442+
"""Horizontally flip the given PIL Image or torch Tensor.
439443
440444
Args:
441-
img (PIL Image): Image to be flipped.
445+
img (PIL Image or Torch Tensor): Image to be flipped. If img
446+
is a Tensor, it is expected to be in [..., H, W] format,
447+
where ... means it can have an arbitrary number of leading
448+
dimensions.
442449
443450
Returns:
444451
PIL Image or Tensor: Horizontally flipped image.
445452
"""
446-
if not _is_pil_image(img):
447-
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
453+
if not isinstance(img, torch.Tensor):
454+
return F_pil.hflip(img)
448455

449-
return img.transpose(Image.FLIP_LEFT_RIGHT)
456+
return F_t.hflip(img)
450457

451458

452459
def _parse_fill(fill, img, min_pil_version):
@@ -536,19 +543,22 @@ def perspective(img, startpoints, endpoints, interpolation=Image.BICUBIC, fill=N
536543
return img.transform(img.size, Image.PERSPECTIVE, coeffs, interpolation, **opts)
537544

538545

539-
def vflip(img):
540-
"""Vertically flip the given PIL Image.
546+
def vflip(img: Tensor) -> Tensor:
547+
"""Vertically flip the given PIL Image or torch Tensor.
541548
542549
Args:
543-
img (PIL Image): Image to be flipped.
550+
img (PIL Image or Torch Tensor): Image to be flipped. If img
551+
is a Tensor, it is expected to be in [..., H, W] format,
552+
where ... means it can have an arbitrary number of leading
553+
dimensions.
544554
545555
Returns:
546556
PIL Image or Tensor: Vertically flipped image.
547557
"""
548-
if not _is_pil_image(img):
549-
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
558+
if not isinstance(img, torch.Tensor):
559+
return F_pil.vflip(img)
550560

551-
return img.transpose(Image.FLIP_TOP_BOTTOM)
561+
return F_t.vflip(img)
552562

553563

554564
def five_crop(img, size):

torchvision/transforms/functional_pil.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import torch
2+
try:
3+
import accimage
4+
except ImportError:
5+
accimage = None
6+
from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION
7+
8+
9+
@torch.jit.unused
10+
def _is_pil_image(img):
11+
if accimage is not None:
12+
return isinstance(img, (Image.Image, accimage.Image))
13+
else:
14+
return isinstance(img, Image.Image)
15+
16+
17+
@torch.jit.unused
18+
def hflip(img):
19+
"""Horizontally flip the given PIL Image.
20+
21+
Args:
22+
img (PIL Image): Image to be flipped.
23+
24+
Returns:
25+
PIL Image: Horizontally flipped image.
26+
"""
27+
if not _is_pil_image(img):
28+
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
29+
30+
return img.transpose(Image.FLIP_LEFT_RIGHT)
31+
32+
33+
@torch.jit.unused
34+
def vflip(img):
35+
"""Vertically flip the given PIL Image.
36+
37+
Args:
38+
img (PIL Image): Image to be flipped.
39+
40+
Returns:
41+
PIL Image: Vertically flipped image.
42+
"""
43+
if not _is_pil_image(img):
44+
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
45+
46+
return img.transpose(Image.FLIP_TOP_BOTTOM)

torchvision/transforms/functional_tensor.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
import torch
2-
import torchvision.transforms.functional as F
32
from torch import Tensor
43
from torch.jit.annotations import Optional, List, BroadcastingList2, Tuple
54

65

76
def _is_tensor_a_torch_image(input):
8-
return len(input.shape) == 3
7+
return input.ndim >= 2
98

109

1110
def vflip(img):

torchvision/transforms/transforms.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -500,51 +500,59 @@ def __repr__(self):
500500
return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding)
501501

502502

503-
class RandomHorizontalFlip(object):
504-
"""Horizontally flip the given PIL Image randomly with a given probability.
503+
class RandomHorizontalFlip(torch.nn.Module):
504+
"""Horizontally flip the given image randomly with a given probability.
505+
The image can be a PIL Image or a torch Tensor, in which case it is expected
506+
to have [..., H, W] shape, where ... means an arbitrary number of leading
507+
dimensions.
505508
506509
Args:
507510
p (float): probability of the image being flipped. Default value is 0.5
508511
"""
509512

510513
def __init__(self, p=0.5):
514+
super().__init__()
511515
self.p = p
512516

513-
def __call__(self, img):
517+
def forward(self, img):
514518
"""
515519
Args:
516-
img (PIL Image): Image to be flipped.
520+
img (PIL Image or Tensor): Image to be flipped.
517521
518522
Returns:
519-
PIL Image: Randomly flipped image.
523+
PIL Image or Tensor: Randomly flipped image.
520524
"""
521-
if random.random() < self.p:
525+
if torch.rand(1) < self.p:
522526
return F.hflip(img)
523527
return img
524528

525529
def __repr__(self):
526530
return self.__class__.__name__ + '(p={})'.format(self.p)
527531

528532

529-
class RandomVerticalFlip(object):
533+
class RandomVerticalFlip(torch.nn.Module):
530534
"""Vertically flip the given image randomly with a given probability.
535+
The image can be a PIL Image or a torch Tensor, in which case it is expected
536+
to have [..., H, W] shape, where ... means an arbitrary number of leading
537+
dimensions.
531538
532539
Args:
533540
p (float): probability of the image being flipped. Default value is 0.5
534541
"""
535542

536543
def __init__(self, p=0.5):
544+
super().__init__()
537545
self.p = p
538546

539-
def __call__(self, img):
547+
def forward(self, img):
540548
"""
541549
Args:
542-
img (PIL Image): Image to be flipped.
550+
img (PIL Image or Tensor): Image to be flipped.
543551
544552
Returns:
545-
PIL Image: Randomly flipped image.
553+
PIL Image or Tensor: Randomly flipped image.
546554
"""
547-
if random.random() < self.p:
555+
if torch.rand(1) < self.p:
548556
return F.vflip(img)
549557
return img
550558

0 commit comments

Comments
 (0)