-
Notifications
You must be signed in to change notification settings - Fork 7.1k
modified transforms.py to accept list of PIL images #611
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
if input is a list of PIL images it returns a list of transformed imgs else it retains its old behaviour. if params need to be computed for every img then the params are computed based on the first img of the list. this change was made to ensure that a set of img have the same random transforms applied to them, for example in the image segmentation.
modified code to accept list of pil images
Hi, Thanks for the PR! I'm not sure that this is the way we would want those problems to be handled. It indeed adds a strong assumption that everything in the list if a PIL Image. For this reason, I think that we really want to keep the class interface as simple as possible, and let the user leverage the functional interface to perform their fine-grained transformations themselves. What do you think? |
You do have a point. Then is it possible for us to have a uniform interface like get_params for transforms like random_grayscale etc which would return a random bool that we can use to chain methods inside the dataset. |
Oh, yes, definitely! Could you send a PR fixing this in |
Revert "modified code to accept list of pil images"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for updating the PR. I have a few comments
torchvision/transforms/transforms.py
Outdated
@@ -503,6 +515,16 @@ class RandomVerticalFlip(object): | |||
def __init__(self, p=0.5): | |||
self.p = p | |||
|
|||
def get_params(p): |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torchvision/transforms/transforms.py
Outdated
@@ -485,7 +496,8 @@ def __call__(self, img): | |||
Returns: | |||
PIL Image: Randomly flipped image. | |||
""" | |||
if random.random() < self.p: | |||
to_flip = self.get_params(self.p) | |||
if to_flip : |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torchvision/transforms/transforms.py
Outdated
@@ -511,7 +533,8 @@ def __call__(self, img): | |||
Returns: | |||
PIL Image: Randomly flipped image. | |||
""" | |||
if random.random() < self.p: | |||
to_flip = self.get_params(self.p) | |||
if to_flip : |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torchvision/transforms/transforms.py
Outdated
@@ -1037,6 +1060,7 @@ class Grayscale(object): | |||
def __init__(self, num_output_channels=1): | |||
self.num_output_channels = num_output_channels | |||
|
|||
|
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torchvision/transforms/transforms.py
Outdated
@@ -1068,6 +1092,16 @@ class RandomGrayscale(object): | |||
def __init__(self, p=0.1): | |||
self.p = p | |||
|
|||
def get_params(p): |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torchvision/transforms/transforms.py
Outdated
@@ -1077,7 +1111,8 @@ def __call__(self, img): | |||
PIL Image: Randomly grayscaled image. | |||
""" | |||
num_output_channels = 1 if img.mode == 'L' else 3 | |||
if random.random() < self.p: | |||
to_convert = self.get_params(self.p) | |||
if to_convert : |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
Actually, now I remember why we didn't have |
I was planning to take a list of transforms as input and then apply them sequentially using compose. I can just pass two lists one for the functional equivalents and one for the actual transforms, now if the actual transforms have a standard interface through which we can get parameters it will make it simpler to abstract away the kind of transforms that we would be applying out of the dataset code |
The problem is that class MyTransform(object):
def __init__(self, **kwargs):
# store params here
def __call__(self, img, masks, etc):
if random.rand() > self.prob1:
# do my magic
return img, masks, etc |
if input is a list of images then output is a list of images else old behaviour is retained
if params need to be computed for every img, then they are computed based on the fist img of the list
this was done to apply the same random transform on each img for example - image segmentation