Proposal for extending transforms #230

Closed
fmassa opened this issue Aug 24, 2017 · 30 comments

@fmassa
Member

fmassa commented Aug 24, 2017

Up to now, we compose different transforms via the Compose transform, which works like nn.Sequential but for transforms. Each transform is applied independently to the input and the target.
While this is simple and efficient, there has been increasing demand for extending the available transforms to accept both input and target; see #9, #115, #221 for some examples. This would allow performing the same random transformations on both the input and the target, for tasks such as semantic segmentation.

There are a few possible approaches, and I'll summarize the ones that have been mentioned here:

  1. provide a set of random transforms working on pairs (or triplets, etc) of images, as proposed for example in [add] A PairRandomCrop for both input and target. #221. There are a couple of downsides to this approach: (a) it doesn't scale when we want to combine images coming from different domains (in object detection, for example, we have an image and bounding boxes, and we would need to reimplement a new transform for each pair, introducing some redundancy in the code), (b) it is hard-coded to work with 2 inputs (but could be extended to n inputs without difficulty)
  2. factor out the randomness in the transforms, so that we can use the same functionality we currently have, just needing to add a generate_seed call in __getitem__, as was proposed in Separate random generation from transforms #115. The drawback of this approach is that in some cases we need information from the input to be able to perform the transformation on the target (flipping a bounding box horizontally, for example, requires the image width, which is not available right away).
  3. provide a generic base class that samples the random transformation parameters and passes those parameters to the input arguments, as proposed by @chsasank in the slack channel (a rough sketch is shown right after this list). While this is an improvement over both [add] A PairRandomCrop for both input and target. #221 and Separate random generation from transforms #115, it still suffers from the same limitation as Separate random generation from transforms #115, since we still only pass one input argument at a time.
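
For concreteness, a rough sketch of what 3. could look like (hypothetical names, not an existing torchvision API):

import random
from PIL import Image

class RandomParamTransform(object):
    # samples its random parameters once per call and applies them to each
    # argument it receives, one at a time
    def get_params(self):
        raise NotImplementedError

    def apply(self, img, params):
        raise NotImplementedError

    def __call__(self, *imgs):
        params = self.get_params()
        return tuple(self.apply(img, params) for img in imgs)

class JointRandomHorizontalFlip(RandomParamTransform):
    def get_params(self):
        return {'flip': random.random() > 0.5}

    def apply(self, img, params):
        # works for PIL images; a bounding-box version would still need the
        # image width, which is exactly the limitation mentioned above
        return img.transpose(Image.FLIP_LEFT_RIGHT) if params['flip'] else img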

Also, all those options actually handle each input independently. Furthermore, for both 1. and 3., the order of the operations is fixed in the dataset class (first input transforms, then target transforms, then joint transforms, for example). There might be cases where it would be convenient not to be restricted to such orderings, and to alternate between single and joint transforms.

One possibility to address those issues would be to provide a set of functions like nn.Split, nn.Concat, nn.Select, etc, so that we can always pass all inputs to the transforms, and let the transforms be implemented by the user explicitly. This mimics the legacy nn behavior.
The downside of this approach is that it gets very complicated to write some transformations, and this doesn't buy us much.

Instead, one simpler approach (which has already been advocated in the past by @colesbury) is to let the user directly subclass the dataset and implement their complex transforms in there. This approach (4.) would look something like

import numpy as np

class VOCDatasetSegmentation(VOCDataset):
    def __init__(self, flip=False, **kwargs):
        super(VOCDatasetSegmentation, self).__init__(**kwargs)
        self.flip = flip

    def __getitem__(self, idx):
        image, target = super(VOCDatasetSegmentation, self).__getitem__(idx)
        do_flip = np.random.random() > 0.5
        if self.flip and do_flip:
            pass  # flip image and bbox here
        return image, target

A downside is that we can't easily re-use those transforms in a different Dataset class (COCO for example).

The question now is whether there is an intermediate design that keeps the simplicity of 4., without having to subclass the dataset and reimplement the same transforms every time.

Yet another possibility (5.) would be to let the user write the code of their dataset as follows:

class Dataset(object):
    def __init__(self, transforms=None):
        self.transforms = transforms
    def __getitem__(self, idx):
        # get image1, image2, bounding_box
        # the transforms takes all inputs into account
        if self.transforms:
            image1, image2, bounding_box = self.transforms(image1, image2, bounding_box)
        return image1, image2, bounding_box

from torchvision.transforms import random_horizontal_flip

class MyJointRandomFlipTransform(object):
    def __call__(self, image1, image2, bounding_box):
        # provide a functional interface for the current transforms
        # so that they can be easily reused, and have the parameters
        # of the transformation if needed
        image1, params = random_horizontal_flip(image1, return_params=True)
        # reuses the same transformations, if wanted
        image2 = random_horizontal_flip(image2, params=params)
        # no transformation in torchvision for bounding_box, have to do it
        # ourselves
        if params.flip:
            bounding_box[:, 1] = image1.size(2) - bounding_box[:, 1]
            bounding_box[:, 3] = image1.size(2) - bounding_box[:, 3]
        return image1, image2, bounding_box

In this way, we have the flexibility of subclassing the dataset, while being more modular and easier to implement.
There would be some differences in the way we currently write our datasets, but we could have a fallback implementation for backward compatibility, along the lines of

class StandardTransform(object):
    def __init__(self, transform, target_transform):
        self.transform = transform
        self.target_transform = target_transform

    def __call__(self, input, target):
        if self.transform:
            input = self.transform(input)
        if self.target_transform:
            target = self.target_transform(target)
        return input, target

and in the current datasets we would replace transform and target_transform with a single transforms argument, while keeping the old behavior

class Dataset(object):
    def __init__(self, path, transforms=None, transform=None, target_transform=None):
        # assert that only transforms or (transform, target_transform) can be set at a time
        if transforms is None:
            transforms = StandardTransform(transform, target_transform)
        self.transforms = transforms

    # getitem only uses transforms from now on
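
    # For example, __getitem__ could then look like this (just a sketch;
    # _load_sample is a stand-in for however the dataset actually reads a sample)
    def __getitem__(self, idx):
        input, target = self._load_sample(idx)
        input, target = self.transforms(input, target)
        return input, target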

What do you think? Do you see drawbacks on using such an approach?

@chsasank
Contributor

chsasank commented Aug 25, 2017

Wow, really nice writeup summarising possible approaches.

I think this is a cool idea! Simple to understand.

One drawback I can think of is how params are handled. Transforms usually have two kinds of params: some fixed and specified when the transform is inited (like output size in RandomCrop) and some random (like the top-left corner of the crop). Which of them are returned by random_horizontal_flip(image1, return_params=True)? (I assume random_horizontal_flip is an instance of RandomHorizontalFlip.) This can make things confusing. But I guess we can easily work around this.

Another drawback is that we cannot have the default transforms support bounding boxes. Although we can hack around it, it will come with 'gotchas', which is not cool.

Here's one solution: how about we separate param generation and the transform, like in your proposal, but with the implementation done as class methods. We can then allow the user to subclass the transform to do whatever they want. Let me illustrate with a RandomCrop example:

import random

class RandomCrop(object):
    def __init__(self, size):
        self.size = size

    def get_params(self, img):
        w, h = img.size
        th, tw = self.size
        x1 = random.randint(0, w - tw)
        y1 = random.randint(0, h - th)
        return x1, y1, x1 + tw, y1 + th

    @staticmethod
    def transform(img, x1, y1, x2, y2):
        # no self here; all required params have to be specified explicitly
        return img.crop((x1, y1, x2, y2))

    def __call__(self, img):
        params = self.get_params(img)
        img = self.transform(img, *params)
        return img

The assumption is that there will always be an image to transform. An object detection subclass could then look like:

class MyRandomCrop(RandomCrop):
    def __call__(self, img, bbox):
        x1, y1, x2, y2 = self.get_params(img)
        img = self.transform(img, x1, y1, x2, y2)
        # convention: bbox = (x, y, w, h)
        bbox[0] = bbox[0] - x1
        bbox[1] = bbox[1] - y1
        return img, bbox

The disadvantage of this solution is increased complexity. But then we can have an abstract class to document get_params().
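
Something like this could be the abstract base (just a rough sketch of what I mean):

class Transform(object):
    """Base class: subclasses implement get_params() to sample any random
    parameters, and transform() to apply them to a single image."""

    def get_params(self, img):
        # default: the transform has no random parameters
        return ()

    @staticmethod
    def transform(img, *params):
        raise NotImplementedError

    def __call__(self, img):
        params = self.get_params(img)
        return self.transform(img, *params)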

@fmassa
Member Author

fmassa commented Sep 1, 2017

@chsasank I like your proposal; it makes it simpler to write new transforms without having to mess with the functional interface that I proposed, and the return_params seemed quite ugly as well :)

Before we move forward, there is another point that I'd like to raise.
With the basic fallback implementation that I mentioned in the original post (StandardTransform), we can have the sequential and individual operations that are currently supported by the datasets.

For more complex cases where we want to apply joint transforms but also individual transforms (for example in segmentation, we perform colour augmentation on the input image, and random crops / flips on both the input and the target), I think the way to go would be to encourage the user to write their own Transform class that leverages other Transform classes and implements a __call__ that chains the operations together, like in pytorch, instead of trying to use the Compose transform to write complicated transformations. Something like

class MySegmentationTransform(object):

    def __init__(self, input_color_transform,
                 joint_random_crop_transform, joint_random_flip_transform):
        self.input_color_transform = input_color_transform
        self.joint_random_crop_transform = joint_random_crop_transform
        self.joint_random_flip_transform = joint_random_flip_transform

    def __call__(self, input, target):
        input, target = self.joint_random_crop_transform(input, target)
        input, target = self.joint_random_flip_transform(input, target)
        input = convert_pil_to_tensor(input)
        target = convert_pil_to_tensor(target)
        input = self.input_color_transform(input) / 255
        return input, target

What do you think?

cc @alykhantejani for feedback as well.

@chsasank
Contributor

chsasank commented Sep 1, 2017

This is cool, but I liked Compose :).
We can kind of modify Compose to work with joint transforms like this:
(Observe how my design uses dicts for returns)

class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, **kwargs):
        if len(kwargs) == 0:
            for t in self.transforms:
                image = t(image)

            return image
        else:
            kwargs['image'] = image
            for t in self.transforms:
                kwargs = t(**kwargs)

            return kwargs

Don't you think we should support segmentation and bounding boxes in this repo? In this design, it's not such a big deal. We will have methods transform_bboxes and transform_segmap for each transform.

Then __call__ will be like:

def __call__(self, image, **kwargs):
    params = self.get_params(image)
    image = self.transform(image, *params)

    if len(kwargs) > 0:
        for k, v in kwargs.items():
            if k.startswith('bboxes'):
                kwargs[k] = self.transform_bboxes(image, v, *params)
            elif k.startswith('segmap'):
                kwargs[k] = self.transform_segmap(image, v, *params)

        kwargs['image'] = image
        return kwargs
    else:
        return image

The user can then use transforms like this:

transform = Compose([RandomCrop((224, 224)), HorizontalFlip()])
sample = {'image': image, 'segmap': segmap, 'bboxes': bboxes}
transformed_sample = transform(**sample)

@fmassa
Member Author

fmassa commented Sep 1, 2017

I'm not super happy with the Compose that you propose. It seems a bit obscure and error-prone to me, but I'll let @alykhantejani give his view on it.
I don't think we should be providing anything for the moment that restricts/over-complexifies the transforms, so I'd refrain from adding specific support for bounding boxes.

Also, segmentation maps are images, right? So why should we handle them differently? The only reason I can see is if you want to use the Compose you proposed, but for some transforms (like color augmentation) you don't want to apply them to the segmentation map. This shouldn't be a problem if you just let the user write their Transforms themselves, right?

@chsasank
Contributor

chsasank commented Sep 1, 2017

Sure, it does look a bit complex. Let's leave it out. (Segmentation needs nearest neighbour interpolation too).

@alykhantejani
Contributor

Really good summary + ideas on this thread @fmassa and @chsasank

I like the idea of letting the user create transform graphs in any way they want, without overwhelming them with things like JoinTransform, Split and Sequential, so in this respect, I like the idea of the MySegmentationTransform object that @fmassa proposed.

I also like the idea of separating out the params for a transform from the implementation, but I was not a fan of the return_params kwarg since, as @chsasank mentioned, it's a bit opaque as to what is actually returned (users will have to continuously refer to docs or source code).

I'm not a huge fan of using dicts with string args for the inputs as, again, I think there is potential for errors and confusion on the user's part.

I would like to propose something that is a hybrid of the proposals above, which is that we create functional forms of the transforms where each function explicitly requires all of the parameters necessary for the transform, i.e. the random_crop function will look like def random_crop(img, x, y, w, h). The users will then pass a single callable to the Dataset (which we can make backwards compatible in the way proposed by @fmassa) which handles generating their random params and chaining all of their transforms together. So we would have something a bit like this:

import random

import torchvision.transforms as F

def my_segmentation_transform(input, target):
    x, y, w, h = get_random_crop_params(input, target)
    input = F.random_crop(input, x, y, w, h)
    target = F.random_crop(target, x, y, w, h)
    if random.random() > 0.5:
        input = F.hflip(input)
        target = F.hflip(target)
    input = F.color_transform(input)
    return F.image_to_tensor(input), F.image_to_tensor(target)

This is similar to @chsasank’s idea of using static methods for the transform (which are functions anyway), without the OOP overhead.

What do you guys think?

@chsasank
Contributor

chsasank commented Sep 3, 2017

Great. Let's do this. I like this idea. We'll create 'functional' versions of the transforms like you mentioned and use them in the class implementations, just like how pytorch deals with nn.functional and nn.
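
(For the analogy, in PyTorch the module form just wraps the functional form:)

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(1, 3)
assert torch.equal(F.relu(x), nn.ReLU()(x))  # same operation, two interfaces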

Shall I start working on this?

@chsasank
Contributor

chsasank commented Sep 3, 2017

I've made a first cut refactor at #240 :)
Have a look and let me know what you think.

@fmassa
Member Author

fmassa commented Sep 3, 2017

@chsasank Thanks, very efficient! Having a look at it now

@acgtyrant

Will you add the new transforms attribute to the existing dataset classes?

@alykhantejani
Contributor

@acgtyrant I think at first we'll add this new attribute to the datasets it makes sense for (i.e. Coco detection). In some cases there aren't transforms that you would apply jointly to both target and input, for example when the target is a number and the input is an image (i.e. most of the other current datasets).

@versatran01

Is there a reason why transforms are part of a dataset but not part of the loader? To me it makes more sense to pass transforms into loaders. Could someone explain the design choice here?

@soumith
Member

soumith commented Feb 11, 2018

the loader is agnostic of the dataset type that it is handling.

@versatran01

Say I have a dataset that I wish to split into train and test. I do this by creating 2 loaders, each with a subset sampler of some disjoint indices. And I also have 2 transforms: the one for training has some random data augmentation, and the one for testing has none. But with the current design, the dataset can only take in one transform. That's why I thought it would be nice for the loader to take the transforms. Can you think of a better way to handle this situation?

@alykhantejani
Contributor

Hi @versatran01,

In your example you could create two instances of the Dataset, one with your training transform(s) and one with the testing ones.

@Randl

Randl commented Feb 27, 2018

It might be more convenient to have training_transform and validation_transform in the dataset, allowing a single instance of Dataset to be used.

@fmassa
Member Author

fmassa commented Mar 5, 2018

@Randl I don't think this is necessary. We can do as @alykhantejani said and use two datasets, one for training and one for testing. There is simple functionality that allows us to split a dataset: Subset and random_split.
So this is what you'd need to do:

from torch.utils.data.dataset import random_split

base_dataset = MyDataset(...)
train_size, test_size = len(base_dataset) - 100, 100
train_dataset, test_dataset = random_split(base_dataset, [train_size, test_size])

And if you need different transforms for train/test, you can do something like

train_dataset = MyDataset(..., transform=train_transform)
val_dataset = MyDataset(..., transform=val_transform)

from torch.utils.data.dataset import Subset

indices = torch.randperm(len(train_dataset))
train_indices = indices[:-100]
val_indices = indices[-100:]

train_dataset = Subset(train_dataset, train_indices)
val_dataset = Subset(val_dataset, val_indices)

@Randl

Randl commented Mar 24, 2018

So now that we have #240, what are the steps/interface for joint transformations? I'd love to refactor the COCO dataloader.

Should I first introduce joint transformations that take an iterable of inputs and transform them?

cc @fmassa

@fmassa
Member Author

fmassa commented Mar 24, 2018

@Randl I'm working on a detectron implementation that uses COCO, and I'm thinking about some possible extra extensions of the transforms, which I'll be pushing here soon.

@Randl

Randl commented Mar 25, 2018

@fmassa If you're open to contributions, I'd be happy to participate.

@Randl

Randl commented May 14, 2018

Are there any updates? @fmassa

@fmassa
Member Author

fmassa commented May 15, 2018

@Randl there was some nice discussion yesterday on the pytorch/vision slack channel about this. A preview of what I'm planning to do can be found in https://github.com/pytorch/vision/tree/layers, but I'm still playing a bit with the API, which is why I haven't pushed it to master.

@agaldran

Hi!

It's the first time I write in the forum, sorry if I make any mistake or this is not the appropriate place.

I am working on image segmentation, and I found this thread; if I'm right, the recommended approach atm is the one proposed by @alykhantejani here. I took the idea and re-implemented torchvision.transforms in such a way that its default behavior is the same as usual, but if you pass a tuple (image, target) you get back (transform(image), transform(target)), and both are transformed in exactly the same way. For instance, in the Resize class, I do the following:

if target is not None:
    return (F.resize(img, self.size, self.interpolation),
            F.resize(target, self.size, self.interpolation_tg))
return F.resize(img, self.size, self.interpolation)

taking care that the interpolation in the target argument defaults to NEAREST, which is what one often wants in these situations.
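
A self-contained sketch of the idea looks roughly like this (the class name and interpolation_tg argument here are just illustrative, not necessarily exactly what's in the repo):

import torchvision.transforms.functional as F
from PIL import Image

class PairedResize(object):
    def __init__(self, size, interpolation=Image.BILINEAR, interpolation_tg=Image.NEAREST):
        self.size = size
        self.interpolation = interpolation
        self.interpolation_tg = interpolation_tg

    def __call__(self, img, target=None):
        # same spatial transform for image and target, but nearest-neighbour
        # interpolation for the target so label values are not mixed
        if target is not None:
            return (F.resize(img, self.size, self.interpolation),
                    F.resize(target, self.size, self.interpolation_tg))
        return F.resize(img, self.size, self.interpolation)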

I have shared this re-implementation, which I call "paired transforms", in this repo and documented it with a couple of jupyter notebooks. I have little experience in software development, so the code may not be bug-free or it could be quite sub-optimal, but I thought I would share it in case it is useful for someone trying to apply this kind of transform while minimally modifying their code.

Cheers,

Adrian

@Randl

Randl commented May 30, 2018

@fmassa Can you send me an invite? I'm not in the pytorch slack yet.

@fmassa
Member Author

fmassa commented Jun 6, 2018

Thanks @agaldran for sharing the code, we are thinking of ways of making it even easier to add joint transforms in torchvision, but the current recommended way is indeed to do as you did.

@Randl can you request an invite in https://pytorch.org/support/ ?

mdraw added a commit to ELEKTRONN/elektronn3 that referenced this issue Jun 15, 2018
WIP. Not sure about the API design of transforms.py. I decided to
use the general transform specification

    transformed_inp, transformed_target = transform(inp, target)

This has the disadvantage that it won't work with transforms that
are written in the current torchvision style

    transformed_inp = transform(inp)

but we will definitely need joint input-target-transforms, so I
thought it would be best to use a generalized form that *can* work with
targets, but can also ignore them (just returning the unmodified
target).

The most obvious alternative would have been to use two kinds of
transforms: One that only works on the input (like torchvision's) and
one that does joint transforms to inputs and targets.
I decided against that because it would be more confusing to have two
kinds of incompatible transforms that have to be clearly distinguishable
for the user and would also require at least one more parameter in dataset
constructors (transforms=..., joint_transforms=..., maybe even
target_transforms=...).

Related discussion: pytorch/vision#230
There are some good points for other API designs in there, so we might
reconsider this later.
@patricio-astudillo

patricio-astudillo commented Jul 11, 2018

Hi everyone, imgaug has a nice solution for this problem: they abstracted the step where the random parameters are fixed into one function called 'to_deterministic'. This returns an augmenter with no randomness left in it, which you can then apply to any number of images. One requirement of this solution is that you need to call to_deterministic each time you need a new sample, otherwise it will reuse the same random parameters.
source: aleju/imgaug#41
example: https://www.kaggle.com/c/data-science-bowl-2018/discussion/51824
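
To illustrate the pattern (a minimal sketch with imgaug; the specific augmenters are just examples):

import numpy as np
import imgaug.augmenters as iaa

image = np.zeros((128, 128, 3), dtype=np.uint8)  # stand-in input
mask = np.zeros((128, 128), dtype=np.uint8)      # stand-in target

seq = iaa.Sequential([iaa.Fliplr(0.5), iaa.Crop(px=(0, 16))])

seq_det = seq.to_deterministic()         # fix the random parameters for this sample
image_aug = seq_det.augment_image(image)
mask_aug = seq_det.augment_image(mask)   # the exact same parameters applied to the target
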
Could this be an idea for a possible fix?
@fmassa, @Randl?

@himat

himat commented Jul 2, 2019

@fmassa I would like to point out that this would seem to make more sense if it wasn't specifically limited to source and target, but rather worked for an arbitrary number of images.

import random
from os.path import join

import numpy as np
from PIL import Image

# In my custom dataset class
def __getitem__(self, idx):
    dep_id = self.dep_ids[idx]
    label = self.labels[idx]

    seed = np.random.randint(2147483647) # make a seed with numpy generator

    layer_list = []
    for layer_name in self.layer_names:
        layer_file_path = join(self.data_dir, layer_name, dep_id + ".tif")
        curr_layer = Image.open(layer_file_path)

        random.seed(seed) # reset seed every time so each layer gets the same transform
        if self.transform is not None:
            curr_layer = self.transform(curr_layer)

        layer_list.append(curr_layer)

    # Combine all layers together
    data = np.stack(layer_list, axis=0) # many_layers X H X W
    return data, label

A better way to handle many-channel images would be if you could just pass a list of images to the transform call, to be transformed in the same way.

An even more ideal way would be if the transforms supported np arrays directly instead of needing the data to be PIL Images. It would also work if PIL supported images with an arbitrary number of channels. Because in this case, I could do:

# Combine all layers together
data = np.stack(layer_list, axis=0) # many_layers X H X W

if self.transform:
    data = self.transform(data) # Transform all layers at once in this big "image"

Right now, this last way is not possible because I can't store a 7-layer image in a PIL Image object.

If you have an alternative to what I'm currently doing in the loop, that would be good to hear.

@fmassa
Member Author

fmassa commented Jul 3, 2019

@himat check https://github.com/pytorch/vision/blob/master/references/segmentation/transforms.py for how I currently do it. It's simple, and the most generic way I could find to solve those problems.

@himat

himat commented Jul 3, 2019

Right, but I'm just saying that it's not that generic if it's fixed to two objects. Why not just allow passing in a list of arbitrarily many images?

    def __call__(self, image, target):
        if random.random() < self.flip_prob:
            image = F.hflip(image)
            target = F.hflip(target)
        return image, target

vs

    def __call__(self, images):
        if random.random() < self.flip_prob:
            images = [F.hflip(x) for x in images]
        return images

@fmassa
Member Author

fmassa commented Jul 4, 2019

@himat it's generic in the sense that there is no magic happening inside, and it's straightforward for the user to write their own modified implementation.

(Recall that you might want to apply different types of interpolation depending on the input type, if it's an image or a segmentation mask)
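
For instance, a joint resize along the lines of the reference segmentation transforms looks roughly like this (a sketch, not the exact file contents), which is the kind of per-type handling a fully generic list-of-images API would hide:

import random
from PIL import Image
import torchvision.transforms.functional as F

class RandomResize(object):
    def __init__(self, min_size, max_size):
        self.min_size = min_size
        self.max_size = max_size

    def __call__(self, image, target):
        size = random.randint(self.min_size, self.max_size)
        image = F.resize(image, size)  # bilinear by default for the image
        target = F.resize(target, size, interpolation=Image.NEAREST)  # nearest for the mask
        return image, target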
