Commit a733627

NicolasHug authored and facebook-github-bot committed
[fbsync] add gallery for transforms v2 (#7331)
Reviewed By: vmoens
Differential Revision: D44416549
fbshipit-source-id: f343c4f247ba09d5e6c7c137e9a05a0c7769a499
1 parent 5dd868f commit a733627

1 file changed: +109 −0 lines changed

gallery/plot_transforms_v2.py

Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
"""
==================================
Getting started with transforms v2
==================================

Most computer vision tasks are not supported out of the box by ``torchvision.transforms`` v1, since it only supports
images. ``torchvision.transforms.v2`` enables jointly transforming images, videos, bounding boxes, and masks. This
example showcases the core functionality of the new ``torchvision.transforms.v2`` API.
"""

import pathlib

import torch
import torchvision

def load_data():
    from torchvision.io import read_image
    from torchvision import datapoints
    from torchvision.ops import masks_to_boxes

    assets_directory = pathlib.Path("assets")

    path = assets_directory / "FudanPed00054.png"
    image = datapoints.Image(read_image(str(path)))
    merged_masks = read_image(str(assets_directory / "FudanPed00054_mask.png"))

    # The mask image stores all instances in a single channel; every non-zero value marks one instance.
    labels = torch.unique(merged_masks)[1:]

    # Split the merged mask into one boolean mask per instance.
    masks = datapoints.Mask(merged_masks == labels.view(-1, 1, 1))

    # Compute one XYXY bounding box per instance mask.
    bounding_boxes = datapoints.BoundingBox(
        masks_to_boxes(masks), format=datapoints.BoundingBoxFormat.XYXY, spatial_size=image.shape[-2:]
    )

    return path, image, bounding_boxes, masks, labels

########################################################################################################################
# The :mod:`torchvision.transforms.v2` API supports images, videos, bounding boxes, and instance and segmentation
# masks. Thus, it offers native support for many computer vision tasks, like image and video classification, object
# detection, or instance and semantic segmentation. Still, the interface is the same, making
# :mod:`torchvision.transforms.v2` a drop-in replacement for the existing :mod:`torchvision.transforms` API, aka v1.

# We are using BETA APIs, so we deactivate the associated warning, thereby acknowledging that
# some APIs may slightly change in the future.
torchvision.disable_beta_transforms_warning()
import torchvision.transforms.v2 as transforms

transform = transforms.Compose(
    [
        transforms.ColorJitter(contrast=0.5),
        transforms.RandomRotation(30),
        transforms.CenterCrop(480),
    ]
)

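########################################################################################################################
# To illustrate the "drop-in" claim above, the same pipeline can also be written against the v1 API with identical
# class names and parameters. This is only a sketch for comparison (the ``transforms_v1`` alias is used here just for
# this illustration); the v1 pipeline handles images and plain tensors, but not bounding boxes or masks.

import torchvision.transforms as transforms_v1

transform_v1 = transforms_v1.Compose(
    [
        transforms_v1.ColorJitter(contrast=0.5),
        transforms_v1.RandomRotation(30),
        transforms_v1.CenterCrop(480),
    ]
)
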
########################################################################################################################
# :mod:`torchvision.transforms.v2` natively supports jointly transforming multiple inputs while making sure that
# potential random behavior is consistent across all inputs. However, it doesn't enforce a specific input structure or
# order.

path, image, bounding_boxes, masks, labels = load_data()

torch.manual_seed(0)
new_image = transform(image)  # Image Classification
new_image, new_bounding_boxes, new_labels = transform(image, bounding_boxes, labels)  # Object Detection
new_image, new_bounding_boxes, new_masks, new_labels = transform(
    image, bounding_boxes, masks, labels
)  # Instance Segmentation
new_image, new_target = transform((image, {"boxes": bounding_boxes, "labels": labels}))  # Arbitrary Structure

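########################################################################################################################
# A quick sanity check of that consistency (a minimal sketch, assuming each transform samples its random parameters
# once per call from the global RNG): after re-seeding, transforming the image alone and transforming it jointly with
# the bounding boxes should yield the same image.

torch.manual_seed(0)
image_alone = transform(image)
torch.manual_seed(0)
image_joint, _ = transform(image, bounding_boxes)
print(torch.equal(image_alone, image_joint))  # expected to print True under the assumption above
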
########################################################################################################################
# Under the hood, :mod:`torchvision.transforms.v2` relies on :mod:`torchvision.datapoints` for the dispatch to the
# appropriate function for the input data: :ref:`sphx_glr_auto_examples_plot_datapoints.py`. Note, however, that as a
# regular user you likely don't have to touch this yourself. See
# :ref:`sphx_glr_auto_examples_plot_transforms_v2_e2e.py`.
#
# All "foreign" types like :class:`str`'s or :class:`pathlib.Path`'s are passed through, which allows storing extra
# information directly with the sample:

sample = {"path": path, "image": image}
new_sample = transform(sample)

assert new_sample["path"] is sample["path"]

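########################################################################################################################
# A minimal illustration of the datapoint-based dispatch mentioned above: the outputs of the detection example keep
# their datapoint types, which is what the transforms dispatch on (a sketch, just printing the types).

print(type(new_bounding_boxes).__name__, type(new_masks).__name__)  # expected to print: BoundingBox Mask
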
########################################################################################################################
# As stated above, :mod:`torchvision.transforms.v2` is a drop-in replacement for :mod:`torchvision.transforms` and thus
# also supports transforming plain :class:`torch.Tensor`'s as images or videos if applicable. This is achieved with a
# simple heuristic:
#
# * If we find an explicit image or video (:class:`torchvision.datapoints.Image`, :class:`torchvision.datapoints.Video`,
#   or :class:`PIL.Image.Image`) in the input, all other plain tensors are passed through.
# * If there is no explicit image or video, only the first plain :class:`torch.Tensor` will be transformed as image or
#   video, while all others will be passed through.

plain_tensor_image = torch.rand(image.shape)

print(image.shape, plain_tensor_image.shape)

# Passing a plain tensor together with an explicit image will not transform the former
plain_tensor_image, image = transform(plain_tensor_image, image)

print(image.shape, plain_tensor_image.shape)

# Passing a plain tensor without an explicit image will transform the former
plain_tensor_image, _ = transform(plain_tensor_image, bounding_boxes)

print(image.shape, plain_tensor_image.shape)
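
########################################################################################################################
# If the plain tensor should be transformed as well, it can be wrapped into a :class:`torchvision.datapoints.Image`
# to make it an explicit image. This is a minimal sketch on top of the example above: with two explicit images in the
# input, both are transformed.

from torchvision import datapoints

wrapped_image = datapoints.Image(torch.rand(image.shape))
new_wrapped_image, new_image = transform(wrapped_image, image)

print(new_wrapped_image.shape, new_image.shape)  # the wrapped tensor is now treated as an image and transformed too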
