add end-to-end example gallery for transforms v2 #7302
@@ -0,0 +1 @@
../../astronaut.jpg
@@ -0,0 +1 @@
../../dog2.jpg
@@ -0,0 +1 @@
{"images": [{"file_name": "000000000001.jpg", "height": 512, "width": 512, "id": 1}, {"file_name": "000000000002.jpg", "height": 500, "width": 500, "id": 2}], "annotations": [{"segmentation": [[40.0, 511.0, 26.0, 487.0, 28.0, 438.0, 17.0, 397.0, 24.0, 346.0, 38.0, 306.0, 61.0, 250.0, 111.0, 206.0, 111.0, 187.0, 120.0, 183.0, 136.0, 159.0, 159.0, 150.0, 181.0, 148.0, 182.0, 132.0, 175.0, 132.0, 168.0, 120.0, 154.0, 102.0, 153.0, 62.0, 188.0, 35.0, 191.0, 29.0, 208.0, 20.0, 210.0, 22.0, 227.0, 16.0, 240.0, 16.0, 276.0, 31.0, 285.0, 39.0, 301.0, 88.0, 297.0, 108.0, 281.0, 128.0, 273.0, 138.0, 266.0, 138.0, 264.0, 153.0, 257.0, 162.0, 256.0, 174.0, 284.0, 197.0, 300.0, 221.0, 303.0, 236.0, 337.0, 258.0, 357.0, 306.0, 361.0, 351.0, 358.0, 511.0]], "iscrowd": 0, "image_id": 1, "bbox": [17.0, 16.0, 344.0, 495.0], "category_id": 1, "id": 1}, {"segmentation": [[0.0, 411.0, 43.0, 401.0, 99.0, 395.0, 105.0, 351.0, 124.0, 326.0, 181.0, 294.0, 227.0, 280.0, 245.0, 262.0, 259.0, 234.0, 262.0, 207.0, 271.0, 140.0, 283.0, 139.0, 301.0, 162.0, 309.0, 181.0, 341.0, 175.0, 362.0, 139.0, 369.0, 139.0, 377.0, 163.0, 378.0, 203.0, 381.0, 212.0, 380.0, 220.0, 382.0, 242.0, 404.0, 264.0, 392.0, 293.0, 384.0, 295.0, 385.0, 316.0, 399.0, 343.0, 391.0, 448.0, 452.0, 475.0, 457.0, 494.0, 436.0, 498.0, 402.0, 491.0, 369.0, 488.0, 366.0, 496.0, 319.0, 496.0, 302.0, 485.0, 226.0, 469.0, 128.0, 456.0, 74.0, 458.0, 29.0, 439.0, 0.0, 445.0]], "iscrowd": 0, "image_id": 2, "bbox": [0.0, 139.0, 457.0, 359.0], "category_id": 18, "id": 2}]}
@@ -0,0 +1,152 @@
""" | ||
================================================== | ||
transforms v2: End-to-end object detection example | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry for being late for the review. - transforms v2: End-to-end object detection example
+ Transforms v2: End-to-end object detection example |
||
================================================== | ||
|
||
Object detection is not supported out of the box by ``torchvision.transforms`` v1, since it only supports images. | ||
``torchvision.transforms.v2`` enables jointly transforming images, videos, bounding boxes, and masks. This example | ||
showcases an end-to-end object detection training using the stable ``torchvisio.datasets`` and ``torchvision.models`` as | ||
well as the new ``torchvision.transforms.v2`` v2 API. | ||
""" | ||

import pathlib
from collections import defaultdict

import PIL.Image

import torch
import torch.utils.data

import torchvision


# sphinx_gallery_thumbnail_number = -1

[review comment] Is this necessary? Looks a bit weird once rendered: https://output.circle-artifacts.com/output/job/a6035cf7-ba73-4f96-b92e-1b2aa5b18cfc/artifacts/0/docs/auto_examples/plot_transforms_v2_e2e.html#sphx-glr-auto-examples-plot-transforms-v2-e2e-py
[review comment] I remembered that there is an option to hide this kind of config comment: https://sphinx-gallery.github.io/stable/configuration.html#removing-config-comments Will send a PR.

def show(sample):
    # Draw the bounding boxes stored in ``target`` on top of ``image`` and display the result
    import matplotlib.pyplot as plt

    from torchvision.transforms.v2 import functional as F
    from torchvision.utils import draw_bounding_boxes

    image, target = sample
    if isinstance(image, PIL.Image.Image):
        image = F.to_image_tensor(image)
    image = F.convert_dtype(image, torch.uint8)
    annotated_image = draw_bounding_boxes(image, target["boxes"], colors="yellow", width=3)

    fig, ax = plt.subplots()
    ax.imshow(annotated_image.permute(1, 2, 0).numpy())
    ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    fig.tight_layout()

    fig.show()


# We are using BETA APIs, so we deactivate the associated warning, thereby acknowledging that
# some APIs may slightly change in the future
torchvision.disable_beta_transforms_warning()

from torchvision import models, datasets
import torchvision.transforms.v2 as transforms
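

########################################################################################################################
# As a minimal, illustrative sketch of the joint-transformation idea (not part of the original example; the toy values
# and the use of the beta ``torchvision.datapoints`` API here are assumptions): an image and a bounding box passed to a
# v2 transform together are flipped consistently.

from torchvision import datapoints

toy_image = datapoints.Image(torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8))
toy_boxes = datapoints.BoundingBox(
    [[10.0, 10.0, 30.0, 40.0]], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(64, 64)
)
flipped_image, flipped_boxes = transforms.RandomHorizontalFlip(p=1.0)(toy_image, toy_boxes)
print(flipped_boxes)  # the x-coordinates are mirrored to match the flipped image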


########################################################################################################################
# We start off by loading the :class:`~torchvision.datasets.CocoDetection` dataset to have a look at what it currently
# returns, and we'll see how to convert it to a format that is compatible with our new transforms.


def load_example_coco_detection_dataset(**kwargs):
    # This loads fake data for illustration purposes of this example. In practice, you'll have
    # to replace this with the proper data.
    root = pathlib.Path("assets") / "coco"
    return datasets.CocoDetection(str(root / "images"), str(root / "instances.json"), **kwargs)


dataset = load_example_coco_detection_dataset()

sample = dataset[0]
image, target = sample
print(type(image))
print(type(target), type(target[0]), list(target[0].keys()))
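

########################################################################################################################
# For illustration (a small addition, not part of the original example), here is one raw annotation entry. Note that
# the raw COCO ``"bbox"`` is in ``[x, y, width, height]`` format, which is not what the models or
# :func:`~torchvision.utils.draw_bounding_boxes` expect.

print(target[0]["bbox"])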


########################################################################################################################
# The dataset returns a two-tuple with the first item being a :class:`PIL.Image.Image` and the second one a list of
# dictionaries, each of which contains the annotations for a single object instance. As is, this format is compatible
# neither with ``torchvision.transforms.v2`` nor with the models. To overcome that, we provide the
# :func:`~torchvision.datasets.wrap_dataset_for_transforms_v2` function. For
# :class:`~torchvision.datasets.CocoDetection`, this changes the target structure to a single dictionary of lists. It
# also adds the key-value pairs ``"boxes"``, ``"masks"``, and ``"labels"`` wrapped in the corresponding
# ``torchvision.datapoints``.

dataset = datasets.wrap_dataset_for_transforms_v2(dataset)

sample = dataset[0]
image, target = sample
print(type(image))
print(type(target), list(target.keys()))

print(type(target["boxes"]), type(target["masks"]), type(target["labels"]))
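

########################################################################################################################
# As a quick, purely illustrative sanity check (not part of the original example): the wrapped ``"boxes"`` entry is a
# tensor subclass with one row per object instance in ``(xmin, ymin, xmax, ymax)`` coordinates, which is what
# :func:`~torchvision.utils.draw_bounding_boxes` inside the ``show`` helper above expects.

print(target["boxes"].shape)
print(target["labels"])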


########################################################################################################################
# As a baseline, let's have a look at a sample without transformations:

show(sample)


########################################################################################################################
# With the dataset properly set up, we can now define the augmentation pipeline. This is done the same way as in
# ``torchvision.transforms`` v1, but the pipeline now also handles bounding boxes and masks without any extra
# configuration.
transform = transforms.Compose(
    [
        transforms.RandomPhotometricDistort(),
        transforms.RandomZoomOut(
            # fill value per input type: an RGB color for PIL images, 0 for everything else (e.g. masks)
            fill=defaultdict(lambda: 0, {PIL.Image.Image: (123, 117, 104)})
        ),
        transforms.RandomIoUCrop(),
        transforms.RandomHorizontalFlip(),
        transforms.ToImageTensor(),
        transforms.ConvertImageDtype(torch.float32),
        transforms.SanitizeBoundingBoxes(),
    ]
)

########################################################################################################################
# .. note::
#     Although the :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes` transform is a no-op in this example, it
#     should be placed at least once at the end of a detection pipeline to remove degenerate bounding boxes as well as
#     the corresponding labels and, optionally, masks. It is particularly critical to add it if
#     :class:`~torchvision.transforms.v2.RandomIoUCrop` was used.
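
########################################################################################################################
# As a rough, illustrative sketch of what "degenerate" means here (made-up values, not part of the original example):
# a box whose width or height is not positive. The check below mimics what the sanitization removes.

example_boxes = torch.tensor([[10.0, 10.0, 50.0, 80.0], [20.0, 30.0, 60.0, 30.0]])  # the second box has zero height
valid = (example_boxes[:, 2] > example_boxes[:, 0]) & (example_boxes[:, 3] > example_boxes[:, 1])
print(valid)  # only the first box would be kept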


########################################################################################################################
# Let's see how the sample looks with our augmentation pipeline in place:
dataset = load_example_coco_detection_dataset(transforms=transform)
dataset = datasets.wrap_dataset_for_transforms_v2(dataset)

torch.manual_seed(3141)
sample = dataset[0]

show(sample)


########################################################################################################################
# We can see that the color of the image was distorted, that we zoomed out on it (off center), and that it was flipped
# horizontally. Throughout all of this, the bounding boxes were transformed accordingly. And without any further ado,
# we can start training.

data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    # We need a custom collation function here, since the object detection models expect a
    # sequence of images and target dictionaries. The default collation function tries to
    # `torch.stack` the individual elements, which fails in general for object detection,
    # because the number of object instances varies between the samples. This is the same as
    # for `torchvision.transforms` v1.
    collate_fn=lambda batch: tuple(zip(*batch)),
)
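

########################################################################################################################
# Purely as an illustration of the collation above (a sketch, not part of the original example): each batch is a tuple
# of images and a tuple of target dictionaries rather than stacked tensors, since images can have different sizes and
# each target holds a different number of instances.

batch_images, batch_targets = next(iter(data_loader))
print(type(batch_images), len(batch_images), batch_images[0].shape)
print(type(batch_targets), len(batch_targets), list(batch_targets[0].keys()))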

model = models.get_model("ssd300_vgg16", weights=None, weights_backbone=None).train()

for images, targets in data_loader:
    # In training mode, the detection models take images and targets and return a dictionary of losses
    loss_dict = model(images, targets)
    print(loss_dict)
    # Put your training logic here
    break
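

########################################################################################################################
# As a final, purely illustrative sketch (not part of the original example): in eval mode the same model takes only the
# images and returns per-image prediction dictionaries instead of a loss dict. With untrained weights the predictions
# are meaningless; this only shows the calling convention.

model.eval()
with torch.no_grad():
    predictions = model(images)  # ``images`` is still bound from the loop above
print(list(predictions[0].keys()))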