diff --git a/docs/requirements.txt b/docs/requirements.txt
index 09a11359ae7..2a50d9b8f45 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -5,3 +5,4 @@ sphinx-gallery>=0.11.1
 sphinx==5.0.0
 tabulate
 -e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
+pycocotools
diff --git a/gallery/assets/coco/images/000000000001.jpg b/gallery/assets/coco/images/000000000001.jpg
new file mode 120000
index 00000000000..9be80c7c273
--- /dev/null
+++ b/gallery/assets/coco/images/000000000001.jpg
@@ -0,0 +1 @@
+../../astronaut.jpg
\ No newline at end of file
diff --git a/gallery/assets/coco/images/000000000002.jpg b/gallery/assets/coco/images/000000000002.jpg
new file mode 120000
index 00000000000..9f8efef9928
--- /dev/null
+++ b/gallery/assets/coco/images/000000000002.jpg
@@ -0,0 +1 @@
+../../dog2.jpg
\ No newline at end of file
diff --git a/gallery/assets/coco/instances.json b/gallery/assets/coco/instances.json
new file mode 100644
index 00000000000..fe0e09270bf
--- /dev/null
+++ b/gallery/assets/coco/instances.json
@@ -0,0 +1 @@
+{"images": [{"file_name": "000000000001.jpg", "height": 512, "width": 512, "id": 1}, {"file_name": "000000000002.jpg", "height": 500, "width": 500, "id": 2}], "annotations": [{"segmentation": [[40.0, 511.0, 26.0, 487.0, 28.0, 438.0, 17.0, 397.0, 24.0, 346.0, 38.0, 306.0, 61.0, 250.0, 111.0, 206.0, 111.0, 187.0, 120.0, 183.0, 136.0, 159.0, 159.0, 150.0, 181.0, 148.0, 182.0, 132.0, 175.0, 132.0, 168.0, 120.0, 154.0, 102.0, 153.0, 62.0, 188.0, 35.0, 191.0, 29.0, 208.0, 20.0, 210.0, 22.0, 227.0, 16.0, 240.0, 16.0, 276.0, 31.0, 285.0, 39.0, 301.0, 88.0, 297.0, 108.0, 281.0, 128.0, 273.0, 138.0, 266.0, 138.0, 264.0, 153.0, 257.0, 162.0, 256.0, 174.0, 284.0, 197.0, 300.0, 221.0, 303.0, 236.0, 337.0, 258.0, 357.0, 306.0, 361.0, 351.0, 358.0, 511.0]], "iscrowd": 0, "image_id": 1, "bbox": [17.0, 16.0, 344.0, 495.0], "category_id": 1, "id": 1}, {"segmentation": [[0.0, 411.0, 43.0, 401.0, 99.0, 395.0, 105.0, 351.0, 124.0, 326.0, 181.0, 294.0, 227.0, 280.0, 245.0, 262.0, 259.0, 234.0, 262.0, 207.0, 271.0, 140.0, 283.0, 139.0, 301.0, 162.0, 309.0, 181.0, 341.0, 175.0, 362.0, 139.0, 369.0, 139.0, 377.0, 163.0, 378.0, 203.0, 381.0, 212.0, 380.0, 220.0, 382.0, 242.0, 404.0, 264.0, 392.0, 293.0, 384.0, 295.0, 385.0, 316.0, 399.0, 343.0, 391.0, 448.0, 452.0, 475.0, 457.0, 494.0, 436.0, 498.0, 402.0, 491.0, 369.0, 488.0, 366.0, 496.0, 319.0, 496.0, 302.0, 485.0, 226.0, 469.0, 128.0, 456.0, 74.0, 458.0, 29.0, 439.0, 0.0, 445.0]], "iscrowd": 0, "image_id": 2, "bbox": [0.0, 139.0, 457.0, 359.0], "category_id": 18, "id": 2}]}
diff --git a/gallery/plot_transforms_v2_e2e.py b/gallery/plot_transforms_v2_e2e.py
new file mode 100644
index 00000000000..938578e4af9
--- /dev/null
+++ b/gallery/plot_transforms_v2_e2e.py
@@ -0,0 +1,152 @@
+"""
+==================================================
+transforms v2: End-to-end object detection example
+==================================================
+
+Object detection is not supported out of the box by ``torchvision.transforms`` v1, since it only supports images.
+``torchvision.transforms.v2`` enables jointly transforming images, videos, bounding boxes, and masks. This example
+showcases end-to-end object detection training using the stable ``torchvision.datasets`` and ``torchvision.models`` as
+well as the new ``torchvision.transforms.v2`` API.
+"""
+
+import pathlib
+from collections import defaultdict
+
+import PIL.Image
+
+import torch
+import torch.utils.data
+
+import torchvision
+
+
+# sphinx_gallery_thumbnail_number = -1
+def show(sample):
+    import matplotlib.pyplot as plt
+
+    from torchvision.transforms.v2 import functional as F
+    from torchvision.utils import draw_bounding_boxes
+
+    image, target = sample
+    if isinstance(image, PIL.Image.Image):
+        image = F.to_image_tensor(image)
+    image = F.convert_dtype(image, torch.uint8)
+    annotated_image = draw_bounding_boxes(image, target["boxes"], colors="yellow", width=3)
+
+    fig, ax = plt.subplots()
+    ax.imshow(annotated_image.permute(1, 2, 0).numpy())
+    ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
+    fig.tight_layout()
+
+    fig.show()
+
+
+# We are using BETA APIs, so we deactivate the associated warning, thereby acknowledging that
+# some APIs may slightly change in the future
+torchvision.disable_beta_transforms_warning()
+
+from torchvision import models, datasets
+import torchvision.transforms.v2 as transforms
+
+
+########################################################################################################################
+# We start off by loading the :class:`~torchvision.datasets.CocoDetection` dataset to have a look at what it currently
+# returns, and we'll see how to convert it to a format that is compatible with our new transforms.
+
+
+def load_example_coco_detection_dataset(**kwargs):
+    # This loads fake data for illustration purposes of this example. In practice, you'll have
+    # to replace this with the proper data.
+    root = pathlib.Path("assets") / "coco"
+    return datasets.CocoDetection(str(root / "images"), str(root / "instances.json"), **kwargs)
+
+
+dataset = load_example_coco_detection_dataset()
+
+sample = dataset[0]
+image, target = sample
+print(type(image))
+print(type(target), type(target[0]), list(target[0].keys()))
+
+
+########################################################################################################################
+# The dataset returns a two-tuple with the first item being a :class:`PIL.Image.Image` and the second one a list of
+# dictionaries, each of which contains the annotations for a single object instance. As is, this format is not
+# compatible with ``torchvision.transforms.v2``, nor with the models. To overcome that, we provide the
+# :func:`~torchvision.datasets.wrap_dataset_for_transforms_v2` function. For
+# :class:`~torchvision.datasets.CocoDetection`, this changes the target structure to a single dictionary of lists. It
+# also adds the key-value pairs ``"boxes"``, ``"masks"``, and ``"labels"`` wrapped in the corresponding
+# ``torchvision.datapoints``.
+
+dataset = datasets.wrap_dataset_for_transforms_v2(dataset)
+
+sample = dataset[0]
+image, target = sample
+print(type(image))
+print(type(target), list(target.keys()))
+print(type(target["boxes"]), type(target["masks"]), type(target["labels"]))
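+
+########################################################################################################################
+# As an optional sanity check (not strictly needed for training), we can also look at the shapes of the wrapped entries.
+# The datapoint classes are ``torch.Tensor`` subclasses, so plain tensor attributes are available here; we would expect
+# the boxes to have shape ``(N, 4)``, the masks ``(N, H, W)``, and the labels ``(N,)``, with ``N`` being the number of
+# object instances in the image.
+
+print(target["boxes"].shape, target["boxes"].dtype)
+print(target["masks"].shape, target["masks"].dtype)
+print(target["labels"].shape, target["labels"].dtype)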
+
+########################################################################################################################
+# As a baseline, let's have a look at a sample without transformations:
+
+show(sample)
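+
+
+########################################################################################################################
+# The ``show`` helper above only draws the bounding boxes. Since the wrapped target also carries ``"masks"``, we could
+# visualize those as well. The snippet below is just one possible sketch of how to do that with
+# :func:`~torchvision.utils.draw_segmentation_masks`; it assumes the masks are per-instance binary masks of shape
+# ``(N, H, W)`` that can be cast to ``bool``.
+
+import matplotlib.pyplot as plt
+
+from torchvision.transforms.v2 import functional as F
+from torchvision.utils import draw_segmentation_masks
+
+image, target = sample
+# Convert the PIL image to a uint8 tensor, mirroring what the ``show`` helper does for the boxes
+image = F.convert_dtype(F.to_image_tensor(image), torch.uint8)
+masked_image = draw_segmentation_masks(image, target["masks"].to(torch.bool), alpha=0.6)
+
+fig, ax = plt.subplots()
+ax.imshow(masked_image.permute(1, 2, 0).numpy())
+ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
+fig.show()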
+
+
+########################################################################################################################
+# With the dataset properly set up, we can now define the augmentation pipeline. This is done the same way as in
+# ``torchvision.transforms`` v1, but it now also handles bounding boxes and masks without any extra configuration.
+
+transform = transforms.Compose(
+    [
+        transforms.RandomPhotometricDistort(),
+        transforms.RandomZoomOut(
+            fill=defaultdict(lambda: 0, {PIL.Image.Image: (123, 117, 104)})
+        ),
+        transforms.RandomIoUCrop(),
+        transforms.RandomHorizontalFlip(),
+        transforms.ToImageTensor(),
+        transforms.ConvertImageDtype(torch.float32),
+        transforms.SanitizeBoundingBoxes(),
+    ]
+)
+
+########################################################################################################################
+# .. note::
+#    Although the :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes` transform is a no-op in this example, it
+#    should be placed at least once at the end of a detection pipeline to remove degenerate bounding boxes as well as
+#    the corresponding labels and, optionally, masks. It is particularly critical to add it if
+#    :class:`~torchvision.transforms.v2.RandomIoUCrop` was used.
+#
+# Let's see how the sample looks with our augmentation pipeline in place:
+
+dataset = load_example_coco_detection_dataset(transforms=transform)
+dataset = datasets.wrap_dataset_for_transforms_v2(dataset)
+
+torch.manual_seed(3141)
+sample = dataset[0]
+
+show(sample)
+
+
+########################################################################################################################
+# We can see that the color of the image was distorted, we zoomed out on it (off center), and flipped it horizontally.
+# Throughout all of this, the bounding box was transformed accordingly. And without any further ado, we can start
+# training.
+
+data_loader = torch.utils.data.DataLoader(
+    dataset,
+    batch_size=2,
+    # We need a custom collation function here, since the object detection models expect a
+    # sequence of images and target dictionaries. The default collation function tries to
+    # `torch.stack` the individual elements, which fails in general for object detection,
+    # because the number of object instances varies between the samples. This is the same for
+    # `torchvision.transforms` v1.
+    collate_fn=lambda batch: tuple(zip(*batch)),
+)
+
+model = models.get_model("ssd300_vgg16", weights=None, weights_backbone=None).train()
+
+for images, targets in data_loader:
+    loss_dict = model(images, targets)
+    print(loss_dict)
+    # Put your training logic here
+    break
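+
+
+########################################################################################################################
+# As a rough sketch of what the training logic could look like, we can sum the individual losses returned by the model,
+# backpropagate, and take an optimizer step. This is illustrative only: the optimizer choice and learning rate below are
+# placeholders rather than a tuned recipe for ``ssd300_vgg16``.
+
+optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
+
+for images, targets in data_loader:
+    optimizer.zero_grad()
+    loss_dict = model(images, targets)
+    # In training mode, the detection models return a dict of losses; the total loss is their sum
+    loss = sum(loss_dict.values())
+    loss.backward()
+    optimizer.step()
+    print(f"total loss: {loss.item():.3f}")
+    break  # we only run a single step here, since this example uses two fake images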