diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index eeaec1ac9fc..a9f76fcff07 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -18,21 +18,27 @@ def compareTensorToPIL(self, tensor, pil_image): pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1))) self.assertTrue(tensor.equal(pil_tensor)) - def test_random_horizontal_flip(self): + def _test_flip(self, func, method): tensor, pil_img = self._create_data() - flip_tensor = F.hflip(tensor) - flip_pil_img = F.hflip(pil_img) + flip_tensor = getattr(F, func)(tensor) + flip_pil_img = getattr(F, func)(pil_img) self.compareTensorToPIL(flip_tensor, flip_pil_img) - scripted_fn = torch.jit.script(F.hflip) + scripted_fn = torch.jit.script(getattr(F, func)) flip_tensor_script = scripted_fn(tensor) self.assertTrue(flip_tensor.equal(flip_tensor_script)) # test for class interface - f = T.RandomHorizontalFlip() + f = getattr(T, method)() scripted_fn = torch.jit.script(f) scripted_fn(tensor) + def test_random_horizontal_flip(self): + self._test_flip('hflip', 'RandomHorizontalFlip') + + def test_random_vertical_flip(self): + self._test_flip('vflip', 'RandomVerticalFlip') + if __name__ == '__main__': unittest.main() diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index fe2ac048fec..22adba6ccc5 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -537,19 +537,22 @@ def perspective(img, startpoints, endpoints, interpolation=Image.BICUBIC, fill=N return img.transform(img.size, Image.PERSPECTIVE, coeffs, interpolation, **opts) -def vflip(img): - """Vertically flip the given PIL Image. +def vflip(img: Tensor) -> Tensor: + """Vertically flip the given PIL Image or torch Tensor. Args: - img (PIL Image): Image to be flipped. + img (PIL Image or Torch Tensor): Image to be flipped. If img + is a Tensor, it is expected to be in [..., H, W] format, + where ... means it can have an arbitrary number of trailing + dimensions. Returns: PIL Image: Vertically flipped image. """ - if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + if not isinstance(img, torch.Tensor): + return F_pil.vflip(img) - return img.transpose(Image.FLIP_TOP_BOTTOM) + return F_t.vflip(img) def five_crop(img, size): diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py index 00200212e4d..e387924ad36 100644 --- a/torchvision/transforms/functional_pil.py +++ b/torchvision/transforms/functional_pil.py @@ -28,3 +28,19 @@ def hflip(img): raise TypeError('img should be PIL Image. Got {}'.format(type(img))) return img.transpose(Image.FLIP_LEFT_RIGHT) + + +@torch.jit.unused +def vflip(img): + """Vertically flip the given PIL Image. + + Args: + img (PIL Image): Image to be flipped. + + Returns: + PIL Image: Vertically flipped image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + return img.transpose(Image.FLIP_TOP_BOTTOM) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index bad2f9ab3f8..5c202f384ee 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -530,25 +530,29 @@ def __repr__(self): return self.__class__.__name__ + '(p={})'.format(self.p) -class RandomVerticalFlip(object): +class RandomVerticalFlip(torch.nn.Module): """Vertically flip the given PIL Image randomly with a given probability. + The image can be a PIL Image or a torch Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading + dimensions Args: p (float): probability of the image being flipped. Default value is 0.5 """ def __init__(self, p=0.5): + super().__init__() self.p = p - def __call__(self, img): + def forward(self, img): """ Args: - img (PIL Image): Image to be flipped. + img (PIL Image or Tensor): Image to be flipped. Returns: - PIL Image: Randomly flipped image. + PIL Image or Tensor: Randomly flipped image. """ - if random.random() < self.p: + if torch.rand(1) < self.p: return F.vflip(img) return img