🚀 The feature
Note: To track the progress of the project, check out this board.
The current transforms in the `torchvision.transforms` namespace are limited to images. This makes it hard to use them for tasks that require the transform to be applied not only to the input image, but also to the target. For example, in object detection, resizing or cropping the input image also affects the bounding boxes.

This project aims to resolve this by providing transforms that can handle the full sample, possibly including images, bounding boxes, segmentation masks, and so on, without the need for user intervention. The implementation of this project happens in the `torchvision.prototype.transforms` namespace. For example:
```python
import torch

from torchvision.prototype import features, transforms

# this will be supplied by a dataset from torchvision.prototype.datasets
image = features.EncodedImage.from_path("test/assets/fakedata/logos/rgb_pytorch.png")
label = features.Label(0)

transform = transforms.Compose(
    transforms.DecodeImage(),
    transforms.RandomHorizontalFlip(p=1.0),
    transforms.Resize((100, 300)),
)

transformed_image1 = transform(image)
transformed_image2, transformed_label = transform(image, label)

# whether we call the transform with a label or not has no effect on the image transform
assert transformed_image1.eq(transformed_image2).all()
# the transform is a no-op for labels
assert transformed_label.eq(label).all()

# this will be supplied by a dataset from torchvision.prototype.datasets
bounding_box = features.BoundingBox(
    [60, 30, 15, 15], format=features.BoundingBoxFormat.CXCYWH, image_size=(100, 100)
)

transformed_image, transformed_bounding_box = transform(image, bounding_box)

# this will be supplied by a dataset from torchvision.prototype.datasets
segmentation_mask = torch.zeros((100, 100), dtype=torch.bool)
segmentation_mask[24:36, 55:66] = True
segmentation_mask = features.SegmentationMask(segmentation_mask)

transformed_image, transformed_segmentation_mask = transform(image, segmentation_mask)
```
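The mechanism that makes this possible can be illustrated with a minimal sketch: the transform receives the full sample, dispatches on the type of each input, and passes anything it does not recognize (such as labels) through unchanged. The classes and the `RandomHorizontalFlipSketch` transform below are hypothetical stand-ins for this sketch, not the actual `torchvision.prototype` implementation.

```python
import torch


# Hypothetical, simplified stand-ins for the feature classes in
# torchvision.prototype.features; they only exist to make the sketch runnable.
class Image:
    def __init__(self, data):
        self.data = data


class Label:
    def __init__(self, data):
        self.data = data


class BoundingBoxXYXY:
    def __init__(self, data, image_size):
        self.data = data
        self.image_size = image_size  # (height, width)


class RandomHorizontalFlipSketch:
    """Flips images and bounding boxes jointly; leaves every other input untouched."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, *inputs):
        flip = bool(torch.rand(()) < self.p)
        outputs = tuple(self._transform(inp, flip) for inp in inputs)
        return outputs[0] if len(outputs) == 1 else outputs

    def _transform(self, inp, flip):
        if not flip:
            return inp
        if isinstance(inp, Image):
            # flip along the width axis of a (..., H, W) image tensor
            return Image(inp.data.flip(-1))
        if isinstance(inp, BoundingBoxXYXY):
            x1, y1, x2, y2 = inp.data.unbind(-1)
            width = inp.image_size[1]
            flipped = torch.stack([width - x2, y1, width - x1, y2], dim=-1)
            return BoundingBoxXYXY(flipped, inp.image_size)
        # labels and any unrecognized input pass through unchanged
        return inp


image = Image(torch.rand(3, 100, 100))
box = BoundingBoxXYXY(torch.tensor([10.0, 20.0, 30.0, 40.0]), image_size=(100, 100))
label = Label(torch.tensor(0))

flip = RandomHorizontalFlipSketch(p=1.0)
new_image, new_box, new_label = flip(image, box, label)
assert new_label is label  # labels pass through, just like in the example above
```

In the real prototype API the feature objects carry their own metadata (for example, the bounding box format and image size in the snippet above), so this dispatch stays invisible to the user.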
Classification
- Port transforms from `torchvision.transforms` (port transforms from the old to the new API, #5520)
- Port transforms from the classification references (#5780)
Detection
- New kernels for bounding boxes ([RFC] Implement transforms primitives for Bounding Boxes, #5514); a sketch of what such kernels could look like follows this list
- Port transforms from the detection references (Migrate detection reference transforms to `torchvision.prototype.transforms`, #5542)
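For illustration only, low-level bounding box kernels could look roughly like the following; the function names and signatures are assumptions made for this sketch, not the API proposed in #5514.

```python
import torch


def horizontal_flip_bounding_box(bounding_box: torch.Tensor, image_size) -> torch.Tensor:
    # hypothetical kernel: flip XYXY boxes inside an image of size (height, width)
    x1, y1, x2, y2 = bounding_box.unbind(-1)
    width = image_size[1]
    return torch.stack([width - x2, y1, width - x1, y2], dim=-1)


def resize_bounding_box(bounding_box: torch.Tensor, old_size, new_size) -> torch.Tensor:
    # hypothetical kernel: rescale XYXY boxes from old_size (H, W) to new_size (H, W)
    old_height, old_width = old_size
    new_height, new_width = new_size
    scale = torch.tensor(
        [new_width / old_width, new_height / old_height,
         new_width / old_width, new_height / old_height]
    )
    return bounding_box * scale
```

Such kernels operate on plain coordinate tensors; the feature classes would supply the format conversion and the image size metadata around them.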
Segmentation
- New kernels for segmentation masks ([RFC] Implement transforms primitives for Segmentation Masks, #5782); a sketch of what such kernels could look like follows this list
- Port transforms from the segmentation references. This will be achieved automatically once the item above is completed, since there are no special segmentation transforms in our references: once the respective kernels are implemented, the transforms will handle the segmentation case out of the box.
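As a rough illustration of what such mask kernels could look like (the names and signatures are assumptions for this sketch, not the API from #5782), the main constraint is to use operations that cannot introduce new class values, for example nearest-neighbor interpolation when resizing:

```python
import torch
import torch.nn.functional as F


def horizontal_flip_segmentation_mask(mask: torch.Tensor) -> torch.Tensor:
    # hypothetical kernel: flip a (..., H, W) mask along its width axis
    return mask.flip(-1)


def resize_segmentation_mask(mask: torch.Tensor, size) -> torch.Tensor:
    # hypothetical kernel: resize an (H, W) mask with nearest-neighbor interpolation
    # so that no interpolated (i.e. invalid) class values appear
    resized = F.interpolate(mask[None, None].float(), size=size, mode="nearest")
    return resized[0, 0].to(mask.dtype)
```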
Other