
[RFC] torchvision.transforms revamp #5754

@pmeier

🚀 The feature

Note: To track the progress of the project, check out this board.

The current transforms in the torchvision.transforms namespace are limited to images. This makes them hard to use for tasks that require a transform to be applied not only to the input image but also to the target. For example, in object detection, resizing or cropping the input image also changes the bounding boxes.
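The coupling between image and target can be sketched in plain Python. This is a hypothetical illustration, not torchvision code: `crop_image` and `crop_bounding_box` are made-up helpers, the image is a list of rows, and the box uses XYXY format. The box `[52.5, 22.5, 67.5, 37.5]` is the XYXY equivalent of the CXCYWH box `[60, 30, 15, 15]` used later in this issue. The key point is that both the image and the box must receive the same crop parameters.

```python
def crop_image(image, top, left, height, width):
    """Crop an image stored as a list of rows (H x W)."""
    return [row[left:left + width] for row in image[top:top + height]]


def crop_bounding_box(box, top, left, height, width):
    """Map an XYXY box into the coordinate frame of the same crop."""
    x1, y1, x2, y2 = box
    # Move the origin to the crop's top-left corner.
    x1, x2 = x1 - left, x2 - left
    y1, y2 = y1 - top, y2 - top

    # Clamp so the box cannot extend past the borders of the cropped image.
    def clamp(v, hi):
        return max(0, min(v, hi))

    return [clamp(x1, width), clamp(y1, height), clamp(x2, width), clamp(y2, height)]


# A 100 x 100 image and a box (XYXY) covering x in [52.5, 67.5], y in [22.5, 37.5].
image = [[0] * 100 for _ in range(100)]
box = [52.5, 22.5, 67.5, 37.5]

# The crop parameters are shared: cropping only the image would leave the box
# pointing at the wrong pixels.
params = dict(top=10, left=40, height=50, width=50)
cropped_image = crop_image(image, **params)
cropped_box = crop_bounding_box(box, **params)
```

Here `cropped_box` is `[12.5, 12.5, 27.5, 27.5]`. An image-only transform API has no place to apply the second call, which is the gap this RFC addresses.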

This project aims to resolve this by providing transforms that can handle the full sample, possibly including images, bounding boxes, segmentation masks, and so on, without the need for user intervention. The implementation of this project happens in the torchvision.prototype.transforms namespace. For example:

```python
from torchvision.prototype import features, transforms

# this will be supplied by a dataset from torchvision.prototype.datasets
image = features.EncodedImage.from_path("test/assets/fakedata/logos/rgb_pytorch.png")
label = features.Label(0)

transform = transforms.Compose(
    transforms.DecodeImage(),
    transforms.RandomHorizontalFlip(p=1.0),
    transforms.Resize((100, 300)),
)

transformed_image1 = transform(image)
transformed_image2, transformed_label = transform(image, label)

# whether we call the transform with label or not has no effect on the image transform
assert transformed_image1.eq(transformed_image2).all()
# the transform is a no-op for labels
assert transformed_label.eq(label).all()
```

[Figure: image before and after the transform]
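The behavior above (one transform accepting any number of inputs, transforming each according to its type, and passing labels through untouched) can be sketched with type-based dispatch. This is a minimal, hypothetical sketch in pure Python and not the prototype's implementation: the `Image`, `BoundingBox`, and `Label` classes and the `HorizontalFlip` transform below are stand-ins invented for illustration.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Image:
    data: List[List[int]]  # H x W grid of pixel values


@dataclass
class BoundingBox:
    cxcywh: List[float]
    image_size: Tuple[int, int]  # (height, width)


@dataclass
class Label:
    value: int


class HorizontalFlip:
    def _transform(self, inpt):
        # Dispatch on the type of each input in the sample.
        if isinstance(inpt, Image):
            return Image([row[::-1] for row in inpt.data])
        if isinstance(inpt, BoundingBox):
            cx, cy, w, h = inpt.cxcywh
            _, width = inpt.image_size
            # Mirroring x -> width - x moves the center; width and height are unchanged.
            return BoundingBox([width - cx, cy, w, h], inpt.image_size)
        # Types the transform does not know about (e.g. labels) pass through unchanged.
        return inpt

    def __call__(self, *inputs):
        outputs = tuple(self._transform(inpt) for inpt in inputs)
        # Mirror the calling convention: one input in, one output out.
        return outputs[0] if len(outputs) == 1 else outputs


flip = HorizontalFlip()
flipped_image = flip(Image([[1, 2, 3], [4, 5, 6]]))
flipped_image2, label = flip(Image([[1, 2, 3], [4, 5, 6]]), Label(0))
```

Because the label branch is the fallback, calling the transform with or without a label yields the same image, matching the two assertions in the example above.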

```python
# this will be supplied by a dataset from torchvision.prototype.datasets
bounding_box = features.BoundingBox(
    [60, 30, 15, 15], format=features.BoundingBoxFormat.CXCYWH, image_size=(100, 100)
)

transformed_image, transformed_bounding_box = transform(image, bounding_box)
```

[Figure: image and bounding box before and after the transform]
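The arithmetic the transform has to perform on the bounding box can be worked through by hand. This sketch assumes the image is 100 x 100 (as stated by `image_size`), that `Resize((100, 300))` means (height, width), and that box coordinates scale linearly with the image; `hflip_cxcywh` and `resize_cxcywh` are hypothetical helpers, not torchvision functions.

```python
def hflip_cxcywh(box, image_size):
    """Mirror a CXCYWH box horizontally inside an image of size (height, width)."""
    cx, cy, w, h = box
    _, width = image_size
    return [width - cx, cy, w, h]


def resize_cxcywh(box, old_size, new_size):
    """Scale a CXCYWH box when the image goes from old_size to new_size (height, width)."""
    cx, cy, w, h = box
    (old_h, old_w), (new_h, new_w) = old_size, new_size
    sx, sy = new_w / old_w, new_h / old_h
    return [cx * sx, cy * sy, w * sx, h * sy]


box = [60, 30, 15, 15]
flipped = hflip_cxcywh(box, (100, 100))                # [40, 30, 15, 15]
resized = resize_cxcywh(flipped, (100, 100), (100, 300))  # [120.0, 30.0, 45.0, 15.0]
```

The flip only moves the center, while the non-uniform resize stretches both the x-coordinate of the center and the box width by a factor of 3.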

```python
import torch

# this will be supplied by a dataset from torchvision.prototype.datasets
segmentation_mask = torch.zeros((100, 100), dtype=torch.bool)
segmentation_mask[24:36, 55:66] = True
segmentation_mask = features.SegmentationMask(segmentation_mask)

transformed_image, transformed_segmentation_mask = transform(image, segmentation_mask)
```

[Figure: image and segmentation mask before and after the transform]
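For segmentation masks, the same geometric operations apply, but resizing should use nearest-neighbor sampling so the mask stays binary rather than blending True and False into intermediate values. The sketch below is a pure-Python illustration under that assumption; `hflip_mask` and `resize_mask_nearest` are hypothetical helpers, and the floor-based index mapping is one common choice of nearest-neighbor rule, not necessarily the exact one torchvision uses.

```python
def hflip_mask(mask):
    """Mirror a mask (list of rows) horizontally."""
    return [row[::-1] for row in mask]


def resize_mask_nearest(mask, new_h, new_w):
    """Nearest-neighbor resize: each output pixel copies one source pixel, never a blend."""
    old_h, old_w = len(mask), len(mask[0])
    return [
        [mask[i * old_h // new_h][j * old_w // new_w] for j in range(new_w)]
        for i in range(new_h)
    ]


# Same mask as above: rows 24..35 and columns 55..65 are True in a 100 x 100 grid.
mask = [[24 <= i < 36 and 55 <= j < 66 for j in range(100)] for i in range(100)]

out = resize_mask_nearest(hflip_mask(mask), 100, 300)
true_cols = [j for j, v in enumerate(out[30]) if v]  # columns 102..134
```

The flip moves the True columns from 55..65 to 34..44, and tripling the width maps them to columns 102..134, while the rows 24..35 are unchanged.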

Classification

Detection

Segmentation

Other

cc @vfdev-5 @datumbox
