diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index a640ea403f5..108b5dfdcc0 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -477,6 +477,17 @@ class RandomHorizontalFlip(object): def __init__(self, p=0.5): self.p = p + @staticmethod + def get_params(p): + """Get parameters for ``crop`` for a random crop. + + Args: + p : probability of flipping + Returns: + tuple: bool + """ + return random.random() < p + def __call__(self, img): """ Args: @@ -485,7 +496,8 @@ def __call__(self, img): Returns: PIL Image: Randomly flipped image. """ - if random.random() < self.p: + to_flip = self.get_params(self.p) + if to_flip: return F.hflip(img) return img @@ -503,6 +515,17 @@ class RandomVerticalFlip(object): def __init__(self, p=0.5): self.p = p + @staticmethod + def get_params(p): + """Get parameters for ``crop`` for a random crop. + + Args: + p : probability of flipping + Returns: + tuple: bool + """ + return random.random() < p + def __call__(self, img): """ Args: @@ -511,7 +534,8 @@ def __call__(self, img): Returns: PIL Image: Randomly flipped image. """ - if random.random() < self.p: + to_flip = self.get_params(self.p) + if to_flip: return F.vflip(img) return img @@ -1068,6 +1092,17 @@ class RandomGrayscale(object): def __init__(self, p=0.1): self.p = p + @staticmethod + def get_params(p): + """Get parameters for ``crop`` for a random crop. + + Args: + p : probability of converting to grayscale + Returns: + tuple: bool + """ + return random.random() < p + def __call__(self, img): """ Args: @@ -1077,7 +1112,8 @@ def __call__(self, img): PIL Image: Randomly grayscaled image. """ num_output_channels = 1 if img.mode == 'L' else 3 - if random.random() < self.p: + to_convert = self.get_params(self.p) + if to_convert: return F.to_grayscale(img, num_output_channels=num_output_channels) return img