Commit c786d75

pmeier and NicolasHug authored
add end-to-end example gallery for transforms v2 (#7302)
Co-authored-by: Nicolas Hug <[email protected]>
1 parent ed48bb1 commit c786d75

File tree: 5 files changed (+156, -0)

docs/requirements.txt

Lines changed: 1 addition & 0 deletions

@@ -5,3 +5,4 @@ sphinx-gallery>=0.11.1
 sphinx==5.0.0
 tabulate
 -e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
+pycocotools
gallery/assets/coco/images/000000000001.jpg

Lines changed: 1 addition & 0 deletions

../../astronaut.jpg
gallery/assets/coco/images/000000000002.jpg

Lines changed: 1 addition & 0 deletions

../../dog2.jpg

gallery/assets/coco/instances.json

Lines changed: 1 addition & 0 deletions
{"images": [{"file_name": "000000000001.jpg", "height": 512, "width": 512, "id": 1}, {"file_name": "000000000002.jpg", "height": 500, "width": 500, "id": 2}], "annotations": [{"segmentation": [[40.0, 511.0, 26.0, 487.0, 28.0, 438.0, 17.0, 397.0, 24.0, 346.0, 38.0, 306.0, 61.0, 250.0, 111.0, 206.0, 111.0, 187.0, 120.0, 183.0, 136.0, 159.0, 159.0, 150.0, 181.0, 148.0, 182.0, 132.0, 175.0, 132.0, 168.0, 120.0, 154.0, 102.0, 153.0, 62.0, 188.0, 35.0, 191.0, 29.0, 208.0, 20.0, 210.0, 22.0, 227.0, 16.0, 240.0, 16.0, 276.0, 31.0, 285.0, 39.0, 301.0, 88.0, 297.0, 108.0, 281.0, 128.0, 273.0, 138.0, 266.0, 138.0, 264.0, 153.0, 257.0, 162.0, 256.0, 174.0, 284.0, 197.0, 300.0, 221.0, 303.0, 236.0, 337.0, 258.0, 357.0, 306.0, 361.0, 351.0, 358.0, 511.0]], "iscrowd": 0, "image_id": 1, "bbox": [17.0, 16.0, 344.0, 495.0], "category_id": 1, "id": 1}, {"segmentation": [[0.0, 411.0, 43.0, 401.0, 99.0, 395.0, 105.0, 351.0, 124.0, 326.0, 181.0, 294.0, 227.0, 280.0, 245.0, 262.0, 259.0, 234.0, 262.0, 207.0, 271.0, 140.0, 283.0, 139.0, 301.0, 162.0, 309.0, 181.0, 341.0, 175.0, 362.0, 139.0, 369.0, 139.0, 377.0, 163.0, 378.0, 203.0, 381.0, 212.0, 380.0, 220.0, 382.0, 242.0, 404.0, 264.0, 392.0, 293.0, 384.0, 295.0, 385.0, 316.0, 399.0, 343.0, 391.0, 448.0, 452.0, 475.0, 457.0, 494.0, 436.0, 498.0, 402.0, 491.0, 369.0, 488.0, 366.0, 496.0, 319.0, 496.0, 302.0, 485.0, 226.0, 469.0, 128.0, 456.0, 74.0, 458.0, 29.0, 439.0, 0.0, 445.0]], "iscrowd": 0, "image_id": 2, "bbox": [0.0, 139.0, 457.0, 359.0], "category_id": 18, "id": 2}]}

gallery/plot_transforms_v2_e2e.py

Lines changed: 152 additions & 0 deletions

"""
==================================================
transforms v2: End-to-end object detection example
==================================================

Object detection is not supported out of the box by ``torchvision.transforms`` v1, since it only supports images.
``torchvision.transforms.v2`` enables jointly transforming images, videos, bounding boxes, and masks. This example
showcases end-to-end object detection training using the stable ``torchvision.datasets`` and ``torchvision.models`` as
well as the new ``torchvision.transforms.v2`` API.
"""

import pathlib
from collections import defaultdict

import PIL.Image

import torch
import torch.utils.data

import torchvision


# sphinx_gallery_thumbnail_number = -1
def show(sample):
    import matplotlib.pyplot as plt

    from torchvision.transforms.v2 import functional as F
    from torchvision.utils import draw_bounding_boxes

    image, target = sample
    if isinstance(image, PIL.Image.Image):
        image = F.to_image_tensor(image)
    image = F.convert_dtype(image, torch.uint8)
    annotated_image = draw_bounding_boxes(image, target["boxes"], colors="yellow", width=3)

    fig, ax = plt.subplots()
    ax.imshow(annotated_image.permute(1, 2, 0).numpy())
    ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    fig.tight_layout()

    fig.show()


# We are using BETA APIs, so we deactivate the associated warning, thereby acknowledging that
# some APIs may slightly change in the future
torchvision.disable_beta_transforms_warning()

from torchvision import models, datasets
import torchvision.transforms.v2 as transforms

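
########################################################################################################################
# As a quick illustration of what "jointly transforming" means, here is a minimal sketch using the BETA
# ``torchvision.datapoints`` API that ships alongside transforms v2 (class and argument names may differ in other
# releases). The toy image and box below are made up for illustration only; the same random flip is applied to both.

from torchvision import datapoints

toy_image = datapoints.Image(torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8))
toy_boxes = datapoints.BoundingBox(
    [[10.0, 20.0, 110.0, 120.0]],
    format=datapoints.BoundingBoxFormat.XYXY,
    spatial_size=(256, 256),
)
flipped_image, flipped_boxes = transforms.RandomHorizontalFlip(p=1.0)(toy_image, toy_boxes)
print(flipped_boxes)  # the box coordinates were flipped together with the image
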

########################################################################################################################
# We start off by loading the :class:`~torchvision.datasets.CocoDetection` dataset to have a look at what it currently
# returns, and we'll see how to convert it to a format that is compatible with our new transforms.


def load_example_coco_detection_dataset(**kwargs):
    # This loads fake data for illustration purposes. In practice, you'll have to replace it
    # with the proper data.
    root = pathlib.Path("assets") / "coco"
    return datasets.CocoDetection(str(root / "images"), str(root / "instances.json"), **kwargs)


dataset = load_example_coco_detection_dataset()

sample = dataset[0]
image, target = sample
print(type(image))
print(type(target), type(target[0]), list(target[0].keys()))


########################################################################################################################
# The dataset returns a two-tuple with the first item being a :class:`PIL.Image.Image` and the second one a list of
# dictionaries, each of which contains the annotations for a single object instance. As is, this format is compatible
# neither with ``torchvision.transforms.v2`` nor with the models. To overcome that, we provide the
# :func:`~torchvision.datasets.wrap_dataset_for_transforms_v2` function. For
# :class:`~torchvision.datasets.CocoDetection`, this changes the target structure to a single dictionary of lists. It
# also adds the key-value pairs ``"boxes"``, ``"masks"``, and ``"labels"`` wrapped in the corresponding
# ``torchvision.datapoints``.

dataset = datasets.wrap_dataset_for_transforms_v2(dataset)

sample = dataset[0]
image, target = sample
print(type(image))
print(type(target), list(target.keys()))
print(type(target["boxes"]), type(target["masks"]), type(target["labels"]))

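
########################################################################################################################
# Besides behaving like regular tensors, these datapoints carry metadata that the v2 transforms rely on. A minimal
# sketch, assuming the attribute names of the BETA API in this torchvision version (they may differ in later releases):

print(target["boxes"].format)        # the coordinate convention of the boxes, e.g. XYXY
print(target["boxes"].spatial_size)  # the (height, width) of the corresponding image
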

########################################################################################################################
# As a baseline, let's have a look at a sample without transformations:

show(sample)


########################################################################################################################
# With the dataset properly set up, we can now define the augmentation pipeline. This is done the same way as in
# ``torchvision.transforms`` v1, but now handles bounding boxes and masks without any extra configuration.

transform = transforms.Compose(
    [
        transforms.RandomPhotometricDistort(),
        transforms.RandomZoomOut(
            fill=defaultdict(lambda: 0, {PIL.Image.Image: (123, 117, 104)})
        ),
        transforms.RandomIoUCrop(),
        transforms.RandomHorizontalFlip(),
        transforms.ToImageTensor(),
        transforms.ConvertImageDtype(torch.float32),
        transforms.SanitizeBoundingBoxes(),
    ]
)

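
########################################################################################################################
# The pipeline can also be called directly on an ``(image, target)`` pair from the wrapped dataset; images, bounding
# boxes, and masks are transformed jointly. A minimal sketch, assuming the BETA transforms v2 behavior of this
# torchvision version:

transformed_image, transformed_target = transform(image, target)
print(type(transformed_image), transformed_image.dtype, transformed_image.shape)
print(type(transformed_target["boxes"]), len(transformed_target["boxes"]))
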

########################################################################################################################
# .. note::
#    Although the :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes` transform is a no-op in this example, it
#    should be placed at least once at the end of a detection pipeline to remove degenerate bounding boxes as well as
#    the corresponding labels and optionally masks. It is particularly critical to add it if
#    :class:`~torchvision.transforms.v2.RandomIoUCrop` was used.
#
# Let's see how the sample looks with our augmentation pipeline in place:

dataset = load_example_coco_detection_dataset(transforms=transform)
dataset = datasets.wrap_dataset_for_transforms_v2(dataset)

torch.manual_seed(3141)
sample = dataset[0]

show(sample)


########################################################################################################################
# We can see that the color of the image was distorted, we zoomed out on it (off center), and it was flipped
# horizontally. Throughout all of this, the bounding box was transformed accordingly. And without any further ado, we
# can start training.

data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    # We need a custom collation function here, since the object detection models expect a
    # sequence of images and target dictionaries. The default collation function tries to
    # `torch.stack` the individual elements, which fails in general for object detection,
    # because the number of object instances varies between the samples. This is the same as
    # for `torchvision.transforms` v1
    collate_fn=lambda batch: tuple(zip(*batch)),
)

model = models.get_model("ssd300_vgg16", weights=None, weights_backbone=None).train()

for images, targets in data_loader:
    loss_dict = model(images, targets)
    print(loss_dict)
    # Put your training logic here
    break
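
########################################################################################################################
# A minimal sketch of what the "training logic" placeholder above could look like, assuming plain SGD. In a real setup
# you would also move the model and data to the right device, iterate for multiple epochs, and add a learning-rate
# schedule, evaluation, and checkpointing.

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

for images, targets in data_loader:
    loss_dict = model(images, targets)  # detection models return a dict of losses in training mode
    loss = sum(loss_dict.values())      # combine the individual loss terms
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    break  # remove this to iterate over the full dataset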
