## Proposal
Following discussions started in #9, #230, #240, and most recently in #610, I would like to propose the following change to `transforms.Compose` (and the transform classes) that would allow easy processing of multiple images with the same parameters using existing infrastructure. I think this would be very useful for segmentation tasks, where both the image and label need to be transformed in exactly the same way.
## Details
Currently the problem is that each transform, when called, implicitly generates randomized parameters (if it is a random transform) right before computing the transformation. In my opinion, it doesn't have to be this way: parameter generation (`get_params`) is already separated from the actual image operation (which relies on the functional backend). My idea comes in two parts: first, completely decouple parameter generation from transformation; then, allow `Compose` to generate parameters once and apply transformations multiple times.

Step 1, using `RandomResizedCrop` as an example:
- Add a `generate_params` method to access the existing `get_params` without the need to pass specific arguments. This method would look exactly the same for every transform that needs any random parameters; passing the specific arguments to `get_params` will be implementation-dependent.

```python
def generate_params(self, image):
    return self.get_params(image, self.scale, self.ratio)
```
- Allow `__call__` to optionally accept a tuple of pre-generated params:

```python
def __call__(self, img, params=None):
    if params is not None:
        i, j, h, w = params
    else:
        i, j, h, w = self.get_params(img, self.scale, self.ratio)
    return F.resized_crop(img, i, j, h, w, self.size, self.interpolation)
```
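To illustrate the two-step API, here is a minimal, self-contained sketch using a toy transform (`ToyRandomCrop` and its list-of-lists "images" are hypothetical stand-ins, not torchvision code): parameters are generated once, then the same crop window is applied to an image and its label.

```python
import random

# Hypothetical stand-in for a random transform with the proposed split:
# generate_params() draws randomness once, __call__() optionally reuses it.
class ToyRandomCrop:
    def __init__(self, size):
        self.size = size

    def get_params(self, img, size):
        # img is a (height x width) list of lists
        h, w = len(img), len(img[0])
        i = random.randint(0, h - size)
        j = random.randint(0, w - size)
        return i, j, size, size

    def generate_params(self, img):
        return self.get_params(img, self.size)

    def __call__(self, img, params=None):
        if params is None:
            params = self.get_params(img, self.size)
        i, j, h, w = params
        return [row[j:j + w] for row in img[i:i + h]]

# A 4x4 "image" and a matching "label"; label pixel = image pixel + 100.
image = [[10 * r + c for c in range(4)] for r in range(4)]
label = [[100 + 10 * r + c for c in range(4)] for r in range(4)]

t = ToyRandomCrop(2)
params = t.generate_params(image)     # randomness happens exactly once
img_crop = t(image, params)
lbl_crop = t(label, params)
# Both crops come from the same window, so the +100 offset is preserved.
assert lbl_crop[0][0] - img_crop[0][0] == 100
```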
Step 2 is enabling this functionality in `Compose` by changing `__call__` to accept iterables. Alternatively, we could subclass it entirely, which I will do in this example:
```python
class MultiCompose(Compose):
    def __call__(self, imgs):
        for t in self.transforms:
            try:
                # Generate params once, from the first item of the tuple.
                params = t.generate_params(imgs[0])
            except AttributeError:
                params = None
            imgs = tuple(t(img, params) for img in imgs)
        return imgs
```
Subclassing offers some advantages: for example, interpolation methods could be bound to iterable indices at `__init__`, so we could interpolate the first item bilinearly and the second with nearest-neighbour (ideal for segmentation).
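A sketch of that per-index interpolation idea, under stated assumptions: the `interpolations` constructor argument and the `interpolation` keyword on each transform call are hypothetical, not existing torchvision API, and the recording transform below only exists to make the wiring visible.

```python
# Self-contained sketch: MultiCompose binds one interpolation mode per
# tuple position at construction time (an assumption of this proposal).
class MultiCompose:
    def __init__(self, transforms, interpolations=None):
        self.transforms = transforms
        # e.g. interpolations=('bilinear', 'nearest') for (image, mask)
        self.interpolations = interpolations

    def __call__(self, imgs):
        for t in self.transforms:
            # Static transforms have no generate_params; fall back to None.
            gen = getattr(t, 'generate_params', None)
            params = gen(imgs[0]) if gen else None
            modes = self.interpolations or (None,) * len(imgs)
            imgs = tuple(t(img, params, interpolation=m)
                         for img, m in zip(imgs, modes))
        return imgs

# Toy transform that records which interpolation mode it was asked to use.
class RecordingResize:
    def __call__(self, img, params=None, interpolation=None):
        return (img, interpolation)

mc = MultiCompose([RecordingResize()],
                  interpolations=('bilinear', 'nearest'))
out = mc(('img', 'mask'))
assert out == (('img', 'bilinear'), ('mask', 'nearest'))
```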
## Alternative approach
Instead of doing try/except in the `Compose` subclass, all transforms could be changed to inherit from a new `BaseTransform` abstract class, which could define `generate_params` as a trivial function returning `None`. Then we could just do:
```python
class MultiCompose(Compose):
    def __call__(self, imgs):
        for t in self.transforms:
            params = t.generate_params(imgs[0])
            imgs = tuple(t(img, params) for img in imgs)
        return imgs
```
because static transforms like `Pad` would simply return `None`, while any random transforms would need to define `generate_params` accordingly.

Yes, I do realize that this requires a slight refactoring of e.g. `RandomHorizontalFlip`.
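A minimal sketch of what that base class and refactoring could look like (the method bodies are illustrative stubs, not the actual torchvision implementations):

```python
import random

# Proposed base class: generate_params defaults to None, so static
# transforms need no changes; random transforms override it.
class BaseTransform:
    def generate_params(self, img):
        return None  # static transforms have nothing to randomize

    def __call__(self, img, params=None):
        raise NotImplementedError

class Pad(BaseTransform):
    def __init__(self, padding):
        self.padding = padding

    def __call__(self, img, params=None):
        # params is always None here; padding is deterministic.
        return img  # a real implementation would pad the image

class RandomHorizontalFlip(BaseTransform):
    def generate_params(self, img):
        return random.random() < 0.5  # flip decision made once

    def __call__(self, img, params=None):
        flip = self.generate_params(img) if params is None else params
        # On a sequence stand-in, "flipping" is just reversal.
        return img[::-1] if flip else img
```

With this in place, `MultiCompose` can call `generate_params` unconditionally: `Pad` yields `None`, while `RandomHorizontalFlip` yields a flip decision shared by every item in the tuple.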
## Usage
The user could subclass `Dataset` to yield `(image, label)` tuples. This change would allow them to apply custom preprocessing/augmentation transforms separately, instead of hard-coding them in the dataset implementation using the functional backend. It would look sort of like this:
```python
data_transform = transforms.MultiCompose([
    transforms.RandomCrop(256),
    transforms.ToTensor()
])

data_source = datasets.SegmentationDataset(  # user class
    root='/path/to/data/',
    transform=data_transform
)

loader = torch.utils.data.DataLoader(data_source, batch_size=4, num_workers=8)
```
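For completeness, the user-side `SegmentationDataset` could look roughly like this sketch. File loading is stubbed out with in-memory samples, and the class itself is hypothetical; the only relevant part is that `__getitem__` feeds the whole `(image, label)` pair through the shared transform.

```python
# Hypothetical user dataset: yields (image, label) pairs and applies one
# transform to both, so random parameters are shared within each pair.
class SegmentationDataset:
    def __init__(self, samples, transform=None):
        self.samples = samples  # list of (image, label) pairs
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        image, label = self.samples[idx]
        if self.transform is not None:
            # The transform receives the pair as a tuple, MultiCompose-style.
            image, label = self.transform((image, label))
        return image, label

# Stand-in transform so the sketch runs without torchvision.
ds = SegmentationDataset([('img0', 'lbl0')],
                         transform=lambda pair: tuple(s.upper() for s in pair))
assert ds[0] == ('IMG0', 'LBL0')
```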
I think this would be a significantly more convenient way of doing this.
Let me know if you think this is worth pursuing. I will have some free time next week, so if this would be useful and has a chance of being merged, I'd happily implement it myself. If you see any potential pitfalls or backwards-compatibility-breaking caveats, please tell me as well.
## Addenda
I later found PR #611, but it seems to have been abandoned by now, having encountered some issues that I think my plan of attack can overcome.

Some of the problems here stem from the fact that the `get_params` methods, since their introduction in #311, do not share an interface between classes. Instead, getting params for each transform is a completely different call. This feels anti-OOP and counter-intuitive to me; are there any reasons why it was made this way? @alykhantejani?
cc @vfdev-5