
Commit 870bebf

NicolasHug authored and facebook-github-bot committed
[fbsync] Convert "Visualization utilities" notebook into a sphinx-gallery example (#3774)
Reviewed By: cpuhrsch
Differential Revision: D28538756
fbshipit-source-id: 9903ad98e78f7efb4db8dc0a0b90ce65173d0963
1 parent 51b10ec commit 870bebf

File tree

4 files changed: +134, -15 lines changed

docs/Makefile

Lines changed: 1 addition & 1 deletion

@@ -26,7 +26,7 @@ html-noplot: # Avoids running the gallery examples, which may take time

 clean:
 	rm -rf $(BUILDDIR)/*
-	rm -rf auto_examples/
+	rm -rf $(SOURCEDIR)/auto_examples/

 .PHONY: help Makefile docset
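For context, not part of this diff: sphinx-gallery writes the pages it generates into a directory configured in the Sphinx conf.py, typically under the Sphinx source directory, which is consistent with the clean target now deleting $(SOURCEDIR)/auto_examples/. A minimal sketch of that kind of configuration follows; the option values are assumptions for illustration, not taken from this commit.

# docs/source/conf.py (sketch; paths are assumptions, not from this commit)
sphinx_gallery_conf = {
    # where the .py example scripts (like gallery/plot_visualization_utils.py)
    # live, relative to conf.py
    'examples_dirs': '../../gallery',
    # where sphinx-gallery writes the generated pages; this is the
    # auto_examples/ folder that the `clean` target above removes
    'gallery_dirs': 'auto_examples',
}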

gallery/plot_scripted_tensor_transforms.py

Lines changed: 1 addition & 1 deletion

@@ -34,7 +34,7 @@


 plt.rcParams["savefig.bbox"] = 'tight'
-torch.manual_seed(0)
+torch.manual_seed(1)


 def show(imgs):

gallery/plot_visualization_utils.py

Lines changed: 131 additions & 0 deletions

@@ -0,0 +1,131 @@
"""
=======================
Visualization utilities
=======================

This example illustrates some of the utilities that torchvision offers for
visualizing images, bounding boxes, and segmentation masks.
"""


import torch
import numpy as np
import scipy.misc
import matplotlib.pyplot as plt

import torchvision.transforms.functional as F


plt.rcParams["savefig.bbox"] = 'tight'


def show(imgs):
    if not isinstance(imgs, list):
        imgs = [imgs]
    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
    for i, img in enumerate(imgs):
        img = F.to_pil_image(img.to('cpu'))
        axs[0, i].imshow(np.asarray(img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])


####################################
# Visualizing a grid of images
# ----------------------------
# The :func:`~torchvision.utils.make_grid` function can be used to create a
# tensor that represents multiple images in a grid. This util requires images
# of dtype ``uint8`` as input.

from torchvision.utils import make_grid
from torchvision.io import read_image
from pathlib import Path

dog1_int = read_image(str(Path('assets') / 'dog1.jpg'))
dog2_int = read_image(str(Path('assets') / 'dog2.jpg'))

grid = make_grid([dog1_int, dog2_int, dog1_int, dog2_int])
show(grid)

####################################
# Visualizing bounding boxes
# --------------------------
# We can use :func:`~torchvision.utils.draw_bounding_boxes` to draw boxes on an
# image. We can set the colors, labels, and width, as well as the font and font size!
# The boxes are in ``(xmin, ymin, xmax, ymax)`` format.
# from torchvision.utils import draw_bounding_boxes

from torchvision.utils import draw_bounding_boxes


boxes = torch.tensor([[50, 50, 100, 200], [210, 150, 350, 430]], dtype=torch.float)
colors = ["blue", "yellow"]
result = draw_bounding_boxes(dog1_int, boxes, colors=colors, width=5)
show(result)


#####################################
# Naturally, we can also plot bounding boxes produced by torchvision detection
# models. Here is a demo with a Faster R-CNN model loaded from
# :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn`.
# You can also try using a RetinaNet with
# :func:`~torchvision.models.detection.retinanet_resnet50_fpn`.

from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.transforms.functional import convert_image_dtype


dog1_float = convert_image_dtype(dog1_int, dtype=torch.float)
dog2_float = convert_image_dtype(dog2_int, dtype=torch.float)
batch = torch.stack([dog1_float, dog2_float])

model = fasterrcnn_resnet50_fpn(pretrained=True, progress=False)
model = model.eval()

outputs = model(batch)
print(outputs)

#####################################
# Let's plot the boxes detected by our model. We will only plot the boxes with a
# score greater than a given threshold.

threshold = .8
dogs_with_boxes = [
    draw_bounding_boxes(dog_int, boxes=output['boxes'][output['scores'] > threshold], width=4)
    for dog_int, output in zip((dog1_int, dog2_int), outputs)
]
show(dogs_with_boxes)

#####################################
# Visualizing segmentation masks
# ------------------------------
# The :func:`~torchvision.utils.draw_segmentation_masks` function can be used to
# draw segmentation masks on images. We can set the colors as well as the
# transparency of the masks.
#
# Here is a demo with torchvision's FCN ResNet-50, loaded with
# :func:`~torchvision.models.segmentation.fcn_resnet50`.
# You can also try using
# DeepLabv3 (:func:`~torchvision.models.segmentation.deeplabv3_resnet50`)
# or LR-ASPP MobileNetV3 models
# (:func:`~torchvision.models.segmentation.lraspp_mobilenet_v3_large`).
#
# Like :func:`~torchvision.utils.draw_bounding_boxes`,
# :func:`~torchvision.utils.draw_segmentation_masks` requires a single RGB image
# of dtype ``uint8``.

from torchvision.models.segmentation import fcn_resnet50
from torchvision.utils import draw_segmentation_masks


model = fcn_resnet50(pretrained=True, progress=False)
model = model.eval()

# The model expects the batch to be normalized
batch = F.normalize(batch, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
outputs = model(batch)

dogs_with_masks = [
    draw_segmentation_masks(dog_int, masks=masks, alpha=0.6)
    for dog_int, masks in zip((dog1_int, dog2_int), outputs['out'])
]
show(dogs_with_masks)
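A note on the layout of the script above (this explanation is not part of the commit): sphinx-gallery uses the module docstring as the title and intro of the generated page, and splits the rest of the file into alternating text and code cells wherever a comment block opens with a long run of '#' characters. A minimal sketch of that convention, with made-up content:

"""
Page title
==========

The module docstring becomes the title and intro of the rendered page.
"""

import torch  # first code cell; executed when the docs are built

####################################
# A text cell
# -----------
# Lines in a comment block opened by a long run of '#' characters are
# rendered as reStructuredText prose; the code below becomes the next
# executed cell, and its output and figures are embedded in the page.

print(torch.rand(2, 2))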

torchvision/utils.py

Lines changed: 1 addition & 13 deletions

@@ -40,10 +40,6 @@ def make_grid(

     Returns:
         grid (Tensor): the tensor containing grid of images.
-
-    Example:
-        See this notebook
-        `here <https://github.com/pytorch/vision/blob/master/examples/python/visualization_utils.ipynb>`_
     """
     if not (torch.is_tensor(tensor) or
             (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):

@@ -174,10 +170,6 @@ def draw_bounding_boxes(

     Returns:
         img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted.
-
-    Example:
-        See this notebook
-        `linked <https://github.com/pytorch/vision/blob/master/examples/python/visualization_utils.ipynb>`_
     """

     if not isinstance(image, torch.Tensor):

@@ -239,16 +231,12 @@ def draw_segmentation_masks(
     Args:
         image (Tensor): Tensor of shape (3 x H x W) and dtype uint8.
         masks (Tensor): Tensor of shape (num_masks, H, W). Each containing probability of predicted class.
-        alpha (float): Float number between 0 and 1 denoting factor of transpaerency of masks.
+        alpha (float): Float number between 0 and 1 denoting factor of transparency of masks.
         colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of masks. The colors can
             be represented as `str` or `Tuple[int, int, int]`.

     Returns:
         img (Tensor[C, H, W]): Image Tensor of dtype uint8 with segmentation masks plotted.
-
-    Example:
-        See this notebook
-        `attached <https://github.com/pytorch/vision/blob/master/examples/python/visualization_utils.ipynb>`_
     """

     if not isinstance(image, torch.Tensor):
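Since the notebook links are gone from these docstrings (the gallery page above replaces them), here is a tiny standalone sketch of the draw_segmentation_masks call described by the updated docstring. The inputs are illustrative placeholders, not values from this commit:

import torch
from torchvision.utils import draw_segmentation_masks

# Illustrative inputs matching the documented shapes: a uint8 RGB image of
# shape (3, H, W) and soft masks of shape (num_masks, H, W).
image = torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8)
masks = torch.rand(2, 64, 64)  # per-pixel class probabilities
result = draw_segmentation_masks(image, masks=masks, alpha=0.6)
print(result.shape, result.dtype)  # expected: torch.Size([3, 64, 64]) torch.uint8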
