compatibility layer between stable datasets and prototype transforms? #6662

@pmeier

The original plan was to roll out the datasets and transforms revamp at the same time since they somewhat depend on each other. However, it is becoming more and more likely that the prototype transforms will be finished sooner. Thus, we need some compatibility layer in the meantime. This issue explains how transforms are currently used with the datasets, what will or will not work without a compatibility layer, and what such a compatibility layer might look like.

Status quo

Most of our datasets support the transform and target_transform idiom. These transformations are applied separately to the first and second item of the raw sample returned by the dataset. For classification tasks this is usually sufficient, although I've never seen a practical use for target_transform:

dataset = torchvision.datasets.ImageFolder(
    traindir,
    presets.ClassificationPresetTrain(
        crop_size=train_crop_size,
        interpolation=interpolation,
        auto_augment_policy=auto_augment_policy,
        random_erase_prob=random_erase_prob,
    ),
)
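To make the idiom concrete, here is a minimal, dependency-free sketch of how a dataset applies the two callables. ToyClassificationDataset and its sample data are hypothetical stand-ins, not torchvision classes; the point is only that each callable sees one half of the sample:

```python
class ToyClassificationDataset:
    """Hypothetical stand-in for a dataset using the transform / target_transform idiom."""

    def __init__(self, samples, transform=None, target_transform=None):
        self.samples = samples
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        image, label = self.samples[index]
        # Each callable only ever sees its own half of the sample.
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label


dataset = ToyClassificationDataset(
    samples=[([0.0, 0.5, 1.0], 3)],
    transform=lambda image: [2 * value for value in image],
    target_transform=lambda label: label + 1,
)
image, label = dataset[0]
```

Because neither callable can see the other half, a decision made for the image (for example a random flip) cannot be communicated to the target.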

However, the separation of the transforms breaks down when image and label need to be transformed at the same time, e.g. for CutMix or MixUp. They are currently applied through a custom collation function for the dataloader:

mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)
collate_fn = lambda batch: mixupcutmix(*default_collate(batch)) # noqa: E731

Since these transforms do not work with the standard idioms, they never made it out of our references into the library.
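To illustrate why such transforms need image and label together, here is a rough, dependency-free sketch of the MixUp idea on a single pair of samples. The function mixup_pair is hypothetical and the mixing weight lam is fixed for readability; actual implementations sample it randomly and operate on whole batches:

```python
def mixup_pair(image_a, onehot_a, image_b, onehot_b, lam):
    """Blend two samples, using the same weight for the images and the one-hot labels."""
    mixed_image = [lam * a + (1 - lam) * b for a, b in zip(image_a, image_b)]
    mixed_label = [lam * a + (1 - lam) * b for a, b in zip(onehot_a, onehot_b)]
    return mixed_image, mixed_label


# Two tiny "images" (flattened to lists) with one-hot labels for two classes.
image, label = mixup_pair(
    image_a=[1.0, 1.0],
    onehot_a=[1.0, 0.0],
    image_b=[0.0, 0.0],
    onehot_b=[0.0, 1.0],
    lam=0.75,
)
```

The label has to be blended with exactly the weight that was used for the image, which a separate transform and target_transform cannot guarantee.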

The need to transform input and target at the same time is not a special case for other tasks such as segmentation or detection. Datasets for these tasks support the transforms parameter. It will be called with the complete sample and thus is able to support all use cases.
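A joint transform of this kind could look roughly like the sketch below. JointRandomHorizontalFlip is a hypothetical name and images are represented as nested lists to keep the example self-contained; it is not one of the transforms from our references:

```python
import random


class JointRandomHorizontalFlip:
    """Hypothetical joint transform: one random decision shared by image and mask."""

    def __init__(self, p=0.5, seed=None):
        self.p = p
        self.rng = random.Random(seed)

    def __call__(self, image, mask):
        # The coin is tossed once, so image and mask always stay in sync.
        if self.rng.random() < self.p:
            image = [row[::-1] for row in image]
            mask = [row[::-1] for row in mask]
        return image, mask


flip = JointRandomHorizontalFlip(p=1.0)
image, mask = flip([[1, 2, 3]], [[0, 0, 1]])
```

Such a callable has the signature that the transforms parameter expects, since it receives and returns the complete sample.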

Since even datasets for the same task have very diverse outputs, there were only two options without revamping the APIs completely:

  1. Unify the dataset outputs on the dataset itself.
  2. Unify the dataset outputs through a compatibility layer.

When this first came up in the past, we went with option 2. In our references we unified the output for a few select datasets for a specific task, so we can apply custom joint transformations to them. Since we didn't want to commit to the interface, neither the minimal compatibility layer nor the transformations made it into the library. Thus, although some of our datasets in theory support joint transformations, the users have to implement them themselves.

Do we need a compatibility layer?

The new transformations support the joint use case out of the box, meaning that all the custom transformations from our references are now part of the library. Plus, all transformations that previously only supported images, e.g. resizing or padding, now also support bounding boxes, masks and so on.

Which part of the sample is of which type is not communicated through the sample structure, e.g. the first element being an image and the second one a mask, but rather through the actual type of the object. We introduced several tensor subclasses that will be rolled out together with the transforms.
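The idea of dispatching on the type rather than on the position can be sketched without any dependencies. The Image and Mask classes below are plain-Python stand-ins, not the actual tensor subclasses, and pick_interpolation is a hypothetical helper:

```python
class Image(list):
    """Stand-in for an image type; not the actual prototype tensor subclass."""


class Mask(list):
    """Stand-in for a segmentation mask type."""


def pick_interpolation(item):
    # The decision depends on the type of the item, not on where it sits in the sample.
    if isinstance(item, Mask):
        return "nearest"
    if isinstance(item, Image):
        return "bilinear"
    return None  # anything else, e.g. a label, is left alone


# The order of the items does not matter, since only their types are inspected.
sample = {"mask": Mask([0, 1]), "image": Image([0.2, 0.8]), "label": 7}
modes = {key: pick_interpolation(value) for key, value in sample.items()}
```

With this scheme, a transform can handle arbitrarily structured samples as long as the individual items carry the right type.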

By treating simple tensors, i.e. not the new subclasses, as images, the new transformations are fully BC [1]. Thus, if you previously only used the separate transform and target_transform idiom, you can continue to do so and the new transforms will not get in your way:

import torch
from torchvision import datasets
from torchvision.prototype import transforms

transform = transforms.Compose(
    [
        transforms.PILToTensor(),
        transforms.Resize(256),
        transforms.CenterCrop(224),
    ]
)
dataset = datasets.ImageNet(..., transform=transform)

image, label = dataset[0]
assert isinstance(image, torch.Tensor)
assert image.shape[-2:] == (224, 224)
assert isinstance(label, int)

The transforms also work out of the box if you want to stick to PIL images:

import PIL.Image
from torchvision import datasets
from torchvision.prototype import transforms

transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
    ]
)
dataset = datasets.ImageNet(..., transform=transform)

image, label = dataset[0]
assert isinstance(image, PIL.Image.Image)
assert image.size == (224, 224)
assert isinstance(label, int)

Although it seems the new transforms can also be used out of the box if the dataset supports the transforms parameter, this is unfortunately not the case. While the new datasets will provide the sample parts wrapped into the new tensor subclasses, the old datasets, i.e. the only ones available during the roll-out of the new transforms, do not.

Without the wrapping, the transform does not pick up on bounding boxes and subsequently does not transform them:

import torch
import PIL.Image
from torchvision import datasets
from torchvision.prototype import transforms

transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
    ]
)
dataset = datasets.CocoDetection(..., transforms=transform)

image, target = dataset[0]
assert isinstance(image, PIL.Image.Image)
assert image.size == (224, 224)

assert len(target) == 8

bbox = target[2]["bbox"]
# bounding boxes were not downsized and thus are now out of sync with the image
torch.testing.assert_close([int(coord) for coord in bbox], [249, 229, 316, 245])

segmentation = target[2]["segmentation"]
# masks were not downsized and thus are now out of sync with the image. Plus, they are still encoded and the user has
# to decode them themselves
assert isinstance(segmentation, list) and all(isinstance(item, (int, float)) for item in segmentation)

Masks will be transformed, but without wrapping they will be treated as normal images. This means, by default InterpolationMode.BILINEAR is used for interpolation, which will corrupt the information:

import torch
from torchvision import datasets
from torchvision.prototype import transforms

transform = transforms.Compose(
    [
        transforms.PILToTensor(),
        # we convert to float here to make the bilinear interpolation visible
        transforms.ConvertImageDtype(torch.float64),
        transforms.Resize(256),
        transforms.CenterCrop(224),
    ]
)
dataset = datasets.VOCSegmentation(..., transforms=transform)

image, mask = dataset[0]
assert isinstance(image, torch.Tensor)
assert image.shape[-2:] == (224, 224)
assert isinstance(mask, torch.Tensor)
assert mask.shape[-2:] == (224, 224)
# If the interpolation worked correctly, we would only see integer values in the uint8 range of [0, 255]
assert torch.any(torch.fmod(mask * 255, 1) > 0)
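
Why interpolating values corrupts a mask can be seen on a single row of class IDs. The two helper functions below are simplified, hypothetical 1-D stand-ins for bilinear and nearest interpolation:

```python
def upsample_linear(values):
    """Insert the midpoint between each pair of neighbors (1-D linear interpolation)."""
    out = []
    for left, right in zip(values, values[1:]):
        out.extend([left, (left + right) / 2])
    out.append(values[-1])
    return out


def upsample_nearest(values):
    """Repeat the left neighbor instead of blending (1-D nearest interpolation)."""
    out = []
    for left in values[:-1]:
        out.extend([left, left])
    out.append(values[-1])
    return out


mask_row = [0, 0, 5, 5]  # class IDs: background (0) next to class 5
linear = upsample_linear(mask_row)
nearest = upsample_nearest(mask_row)

invented_by_linear = set(linear) - set(mask_row)
invented_by_nearest = set(nearest) - set(mask_row)
```

At the boundary between the two regions, the linear variant produces the value 2.5, which is neither background nor class 5, while the nearest variant only ever emits class IDs that were already present.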

Thus, if we don't provide a compatibility layer until our datasets wrap automatically, the prototype transforms don't bring any real benefit to the user of our datasets.

Proposal

I propose to provide a thin wrapper for the datasets that does nothing other than wrap the returned samples into the new tensor subclasses. This means that the new object behaves exactly like the original dataset, but upon accessing an element, i.e. calling __getitem__, the samples are wrapped before they are passed into the transforms.

from torchvision import datasets
from torchvision.prototype import transforms, features

transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
    ]
)
dataset = datasets.ImageNet(..., transform=transform)
dataset = features.VisionDatasetFeatureWrapper.from_torchvision_dataset(dataset)

image, label = dataset[0]
assert isinstance(image, features.Image)
assert image.image_size == (224, 224)
assert isinstance(label, features.Label)
assert label.to_categories() == "tench, Tinca tinca"
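
The mechanism can be sketched in plain Python. Everything below is hypothetical (ToyDataset, WrappedDataset, and the Image and Label stand-ins), and taking over the transforms attribute is just one assumed way to make the wrapping happen before the transforms run; it is not necessarily how the actual implementation works:

```python
class Image(list):
    """Stand-in for the image type."""


class Label(int):
    """Stand-in for the label type."""


class ToyDataset:
    """Hypothetical 'old' dataset that returns plain Python objects."""

    def __init__(self, samples, transforms=None):
        self.samples = samples
        self.transforms = transforms

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        sample = self.samples[index]
        if self.transforms is not None:
            sample = self.transforms(*sample)
        return sample


class WrappedDataset:
    """Hypothetical wrapper: behaves like the dataset, but wraps before transforming."""

    def __init__(self, dataset):
        self.dataset = dataset
        # Take over the transforms so that they run after the wrapping.
        self.transforms, dataset.transforms = dataset.transforms, None

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        image, label = self.dataset[index]
        sample = (Image(image), Label(label))
        if self.transforms is not None:
            sample = self.transforms(*sample)
        return sample


seen_types = []


def record_types(image, label):
    # Records what the transform actually receives.
    seen_types.append((type(image).__name__, type(label).__name__))
    return image, label


dataset = WrappedDataset(ToyDataset([([0.1, 0.9], 3)], transforms=record_types))
image, label = dataset[0]
```

The transform receives the typed objects although the underlying dataset knows nothing about them.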

Going back to the segmentation example from above, with the wrapper in place the segmentation mask is now correctly interpolated with InterpolationMode.NEAREST:

import torch
from torchvision import datasets
from torchvision.prototype import transforms, features

transform = transforms.Compose(
    [
        # we convert to float here so that bilinear interpolation, if it were still applied, would be visible
        transforms.ToDtype(torch.float64, features.Mask),
        transforms.Resize(256),
        transforms.CenterCrop(224),
    ]
)
dataset = datasets.VOCSegmentation(..., transforms=transform)
dataset = features.VisionDatasetFeatureWrapper.from_torchvision_dataset(dataset)

image, mask = dataset[0]
assert isinstance(mask, torch.Tensor)
assert mask.shape[-2:] == (224, 224)
assert not torch.any(torch.fmod(mask * 255, 1) > 0)

In general, the wrapper should not change the structure of the sample unless it is necessary to be able to properly use the new transformations. For example, the target of CocoDetection is a list of dictionaries, in which each dictionary holds the information for one object. Our models, however, require a dictionary where the value of the bounding box key is a (N, 4) tensor, where N is the number of objects. Furthermore, while our basic transforms can work with individual bounding boxes, more elaborate ones that we ported from the reference scripts also require this format.

Thus, if needed, we also perform this collation inside the wrapper:

import torch
from torchvision import datasets
from torchvision.prototype import transforms, features

transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
    ]
)
dataset = datasets.CocoDetection(..., transforms=transform)
dataset = features.VisionDatasetFeatureWrapper.from_torchvision_dataset(dataset)

image, target = dataset[0]

assert isinstance(image, features.Image)
assert image.shape[-2:] == (224, 224)

bbox = target["bbox"]
assert isinstance(bbox, features.BoundingBox)
assert bbox.shape == (8, 4)
torch.testing.assert_close(bbox[2].int().tolist(), [116, 106, 152, 114])

Furthermore, if the data is in an encoded state, like the masks that CocoDetection provides, it will be decoded so that it can be used directly by the transforms and models:

segmentation = target["segmentation"]
assert isinstance(segmentation, features.Mask)
assert segmentation.shape == (8, 224, 224)

The VisionDatasetFeatureWrapper class in the examples above is implemented as a proof of concept in #6663.
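
Independently of that proof of concept, the collation step described above, i.e. turning a list of per-object dictionaries into one dictionary of per-key collections, can be sketched as follows. The function collate_annotations and the sample values are hypothetical; an actual wrapper would additionally stack the boxes into a (N, 4) tensor:

```python
def collate_annotations(annotations, keys=("bbox", "category_id")):
    """Turn a list of per-object dicts into one dict of per-key lists."""
    return {key: [annotation[key] for annotation in annotations] for key in keys}


# Two made-up objects in the list-of-dicts layout.
annotations = [
    {"bbox": [10, 20, 30, 40], "category_id": 1},
    {"bbox": [5, 5, 15, 25], "category_id": 7},
]
target = collate_annotations(annotations)

num_objects = len(target["bbox"])
coords_per_box = {len(box) for box in target["bbox"]}
```

After the collation, the bounding boxes of all N objects sit under a single key, which is the layout that the models and the more elaborate transforms expect.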

Conclusion

If we don't roll out the new datasets at the same time as the new transformations, the transformations on their own will bring little value to the user. Their whole power can only be unleashed if we add a thin compatibility layer between them and the "old" datasets. I've proposed an, IMO clean, implementation for such a compatibility layer.

cc @vfdev-5 @datumbox @bjuncek

Footnotes

  1. Fully BC for what is discussed here. The only thing that will be BC breaking is that the new transforms will no longer be torch.jit.script'able whereas they were before.
