Skip to content

Commit 6239988

Browse files
committed
Make RandomHorizontalFlip torchscriptable
1 parent 98aa805 commit 6239988

File tree

5 files changed

+91
-13
lines changed

5 files changed

+91
-13
lines changed

test/test_transforms_tensor.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import torch
2+
from torchvision import transforms as T
3+
from torchvision.transforms import functional as F
4+
from PIL import Image
5+
6+
import numpy as np
7+
8+
import unittest
9+
10+
11+
class Tester(unittest.TestCase):
12+
def _create_data(self, height=3, width=3, channels=3):
13+
tensor = torch.randint(0, 255, (channels, height, width), dtype=torch.uint8)
14+
pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().numpy())
15+
return tensor, pil_img
16+
17+
def compareTensorToPIL(self, tensor, pil_image):
18+
pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1)))
19+
self.assertTrue(tensor.equal(pil_tensor))
20+
21+
def test_random_horizontal_flip(self):
22+
tensor, pil_img = self._create_data()
23+
flip_tensor = F.hflip(tensor)
24+
flip_pil_img = F.hflip(pil_img)
25+
self.compareTensorToPIL(flip_tensor, flip_pil_img)
26+
27+
scripted_fn = torch.jit.script(F.hflip)
28+
flip_tensor_script = scripted_fn(tensor)
29+
self.assertTrue(flip_tensor.equal(flip_tensor_script))
30+
31+
# test for class interface
32+
f = T.RandomHorizontalFlip()
33+
scripted_fn = torch.jit.script(f)
34+
scripted_fn(tensor)
35+
36+
37+
if __name__ == '__main__':
38+
unittest.main()

torchvision/transforms/functional.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import torch
2+
from torch import Tensor
23
import math
34
from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION
45
try:
@@ -11,6 +12,9 @@
1112
from collections.abc import Sequence, Iterable
1213
import warnings
1314

15+
from . import functional_pil as F_pil
16+
from . import functional_tensor as F_t
17+
1418

1519
def _is_pil_image(img):
1620
if accimage is not None:
@@ -428,19 +432,22 @@ def resized_crop(img, top, left, height, width, size, interpolation=Image.BILINE
428432
return img
429433

430434

431-
def hflip(img):
432-
"""Horizontally flip the given PIL Image.
435+
def hflip(img: Tensor) -> Tensor:
436+
"""Horizontally flip the given PIL Image or torch Tensor.
433437
434438
Args:
435-
img (PIL Image): Image to be flipped.
439+
img (PIL Image or Torch Tensor): Image to be flipped. If img
440+
is a Tensor, it is expected to be in [..., H, W] format,
441+
where ... means it can have an arbitrary number of trailing
442+
dimensions.
436443
437444
Returns:
438445
PIL Image: Horizontally flipped image.
439446
"""
440-
if not _is_pil_image(img):
441-
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
447+
if not isinstance(img, torch.Tensor):
448+
return F_pil.hflip(img)
442449

443-
return img.transpose(Image.FLIP_LEFT_RIGHT)
450+
return F_t.hflip(img)
444451

445452

446453
def _parse_fill(fill, img, min_pil_version):
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import torch
2+
try:
3+
import accimage
4+
except ImportError:
5+
accimage = None
6+
from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION
7+
8+
9+
@torch.jit.unused
10+
def _is_pil_image(img):
11+
if accimage is not None:
12+
return isinstance(img, (Image.Image, accimage.Image))
13+
else:
14+
return isinstance(img, Image.Image)
15+
16+
17+
@torch.jit.unused
18+
def hflip(img):
19+
"""Horizontally flip the given PIL Image.
20+
21+
Args:
22+
img (PIL Image): Image to be flipped.
23+
24+
Returns:
25+
PIL Image: Horizontally flipped image.
26+
"""
27+
if not _is_pil_image(img):
28+
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
29+
30+
return img.transpose(Image.FLIP_LEFT_RIGHT)

torchvision/transforms/functional_tensor.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import torch
2-
import torchvision.transforms.functional as F
32
from torch import Tensor
43
from torch.jit.annotations import Optional, List, BroadcastingList2, Tuple
54

torchvision/transforms/transforms.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -500,25 +500,29 @@ def __repr__(self):
500500
return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding)
501501

502502

503-
class RandomHorizontalFlip(object):
504-
"""Horizontally flip the given PIL Image randomly with a given probability.
503+
class RandomHorizontalFlip(torch.nn.Module):
504+
"""Horizontally flip the given image randomly with a given probability.
505+
The image can be a PIL Image or a torch Tensor, in which case it is expected
506+
to have [..., H, W] shape, where ... means an arbitrary number of leading
507+
dimensions
505508
506509
Args:
507510
p (float): probability of the image being flipped. Default value is 0.5
508511
"""
509512

510513
def __init__(self, p=0.5):
514+
super().__init__()
511515
self.p = p
512516

513-
def __call__(self, img):
517+
def forward(self, img):
514518
"""
515519
Args:
516-
img (PIL Image): Image to be flipped.
520+
img (PIL Image or Tensor): Image to be flipped.
517521
518522
Returns:
519-
PIL Image: Randomly flipped image.
523+
PIL Image or Tensor: Randomly flipped image.
520524
"""
521-
if random.random() < self.p:
525+
if torch.rand(1) < self.p:
522526
return F.hflip(img)
523527
return img
524528

0 commit comments

Comments
 (0)