Skip to content

Commit 9c20193

Browse files
authored
Make RandomVerticalFlip torchscriptable (#2283)
* Make RandomVerticalFlip torchscriptable * Fix lint
1 parent 016784b commit 9c20193

File tree

4 files changed

+45
-16
lines changed

4 files changed

+45
-16
lines changed

test/test_transforms_tensor.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,27 @@ def compareTensorToPIL(self, tensor, pil_image):
1818
pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1)))
1919
self.assertTrue(tensor.equal(pil_tensor))
2020

21-
def test_random_horizontal_flip(self):
21+
def _test_flip(self, func, method):
2222
tensor, pil_img = self._create_data()
23-
flip_tensor = F.hflip(tensor)
24-
flip_pil_img = F.hflip(pil_img)
23+
flip_tensor = getattr(F, func)(tensor)
24+
flip_pil_img = getattr(F, func)(pil_img)
2525
self.compareTensorToPIL(flip_tensor, flip_pil_img)
2626

27-
scripted_fn = torch.jit.script(F.hflip)
27+
scripted_fn = torch.jit.script(getattr(F, func))
2828
flip_tensor_script = scripted_fn(tensor)
2929
self.assertTrue(flip_tensor.equal(flip_tensor_script))
3030

3131
# test for class interface
32-
f = T.RandomHorizontalFlip()
32+
f = getattr(T, method)()
3333
scripted_fn = torch.jit.script(f)
3434
scripted_fn(tensor)
3535

36+
def test_random_horizontal_flip(self):
37+
self._test_flip('hflip', 'RandomHorizontalFlip')
38+
39+
def test_random_vertical_flip(self):
40+
self._test_flip('vflip', 'RandomVerticalFlip')
41+
3642

3743
if __name__ == '__main__':
3844
unittest.main()

torchvision/transforms/functional.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -537,19 +537,22 @@ def perspective(img, startpoints, endpoints, interpolation=Image.BICUBIC, fill=N
537537
return img.transform(img.size, Image.PERSPECTIVE, coeffs, interpolation, **opts)
538538

539539

540-
def vflip(img):
541-
"""Vertically flip the given PIL Image.
540+
def vflip(img: Tensor) -> Tensor:
541+
"""Vertically flip the given PIL Image or torch Tensor.
542542
543543
Args:
544-
img (PIL Image): Image to be flipped.
544+
img (PIL Image or Torch Tensor): Image to be flipped. If img
545+
is a Tensor, it is expected to be in [..., H, W] format,
546+
where ... means it can have an arbitrary number of trailing
547+
dimensions.
545548
546549
Returns:
547550
PIL Image: Vertically flipped image.
548551
"""
549-
if not _is_pil_image(img):
550-
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
552+
if not isinstance(img, torch.Tensor):
553+
return F_pil.vflip(img)
551554

552-
return img.transpose(Image.FLIP_TOP_BOTTOM)
555+
return F_t.vflip(img)
553556

554557

555558
def five_crop(img, size):

torchvision/transforms/functional_pil.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,19 @@ def hflip(img):
2828
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
2929

3030
return img.transpose(Image.FLIP_LEFT_RIGHT)
31+
32+
33+
@torch.jit.unused
34+
def vflip(img):
35+
"""Vertically flip the given PIL Image.
36+
37+
Args:
38+
img (PIL Image): Image to be flipped.
39+
40+
Returns:
41+
PIL Image: Vertically flipped image.
42+
"""
43+
if not _is_pil_image(img):
44+
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
45+
46+
return img.transpose(Image.FLIP_TOP_BOTTOM)

torchvision/transforms/transforms.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -530,25 +530,29 @@ def __repr__(self):
530530
return self.__class__.__name__ + '(p={})'.format(self.p)
531531

532532

533-
class RandomVerticalFlip(object):
533+
class RandomVerticalFlip(torch.nn.Module):
534534
"""Vertically flip the given PIL Image randomly with a given probability.
535+
The image can be a PIL Image or a torch Tensor, in which case it is expected
536+
to have [..., H, W] shape, where ... means an arbitrary number of leading
537+
dimensions
535538
536539
Args:
537540
p (float): probability of the image being flipped. Default value is 0.5
538541
"""
539542

540543
def __init__(self, p=0.5):
544+
super().__init__()
541545
self.p = p
542546

543-
def __call__(self, img):
547+
def forward(self, img):
544548
"""
545549
Args:
546-
img (PIL Image): Image to be flipped.
550+
img (PIL Image or Tensor): Image to be flipped.
547551
548552
Returns:
549-
PIL Image: Randomly flipped image.
553+
PIL Image or Tensor: Randomly flipped image.
550554
"""
551-
if random.random() < self.p:
555+
if torch.rand(1) < self.p:
552556
return F.vflip(img)
553557
return img
554558

0 commit comments

Comments
 (0)