Skip to content

modified transforms.py to accept list of PIL images #611

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 10 commits into from
42 changes: 39 additions & 3 deletions torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,17 @@ class RandomHorizontalFlip(object):
def __init__(self, p=0.5):
self.p = p

@staticmethod
def get_params(p):
"""Get parameters for ``crop`` for a random crop.

Args:
p : probability of flipping
Returns:
tuple: bool
"""
return random.random() < p

def __call__(self, img):
"""
Args:
Expand All @@ -485,7 +496,8 @@ def __call__(self, img):
Returns:
PIL Image: Randomly flipped image.
"""
if random.random() < self.p:
to_flip = self.get_params(self.p)
if to_flip:
return F.hflip(img)
return img

Expand All @@ -503,6 +515,17 @@ class RandomVerticalFlip(object):
def __init__(self, p=0.5):
self.p = p

@staticmethod
def get_params(p):
"""Get parameters for ``crop`` for a random crop.

Args:
p : probability of flipping
Returns:
tuple: bool
"""
return random.random() < p

def __call__(self, img):
"""
Args:
Expand All @@ -511,7 +534,8 @@ def __call__(self, img):
Returns:
PIL Image: Randomly flipped image.
"""
if random.random() < self.p:
to_flip = self.get_params(self.p)
if to_flip:
return F.vflip(img)
return img

Expand Down Expand Up @@ -1068,6 +1092,17 @@ class RandomGrayscale(object):
def __init__(self, p=0.1):
self.p = p

@staticmethod
def get_params(p):
"""Get parameters for ``crop`` for a random crop.

Args:
p : probability of converting to grayscale
Returns:
tuple: bool
"""
return random.random() < p

def __call__(self, img):
"""
Args:
Expand All @@ -1077,7 +1112,8 @@ def __call__(self, img):
PIL Image: Randomly grayscaled image.
"""
num_output_channels = 1 if img.mode == 'L' else 3
if random.random() < self.p:
to_convert = self.get_params(self.p)
if to_convert:
return F.to_grayscale(img, num_output_channels=num_output_channels)
return img

Expand Down