Skip to content

Commit 66d777e

Browse files
oke-adityafmassa
andauthored
Improved utilities, adds examples, tests (#3594)
* start adding tests * add return type and doc * adds tests * add no fill tests * add rgb test * check inplace * bug fix * bug fix * rewrite make grid * add plotting demos * rename file * remove * updt * Add viz * updt * update readme, add links * complte bounding boxes * Complete the examples! * link fix * link fixed Co-authored-by: Francisco Massa <[email protected]>
1 parent 20a771e commit 66d777e

File tree

5 files changed

+768
-8
lines changed

5 files changed

+768
-8
lines changed

examples/python/README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
[Examples of Tensor Images transformations](https://github.com/pytorch/vision/blob/master/examples/python/tensor_transforms.ipynb)
55
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pytorch/vision/blob/master/examples/python/video_api.ipynb)
66
[Example of VideoAPI](https://github.com/pytorch/vision/blob/master/examples/python/video_api.ipynb)
7+
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pytorch/vision/blob/master/examples/python/visualization_utils.ipynb)
8+
[Example of Visualization Utils](https://github.com/pytorch/vision/blob/master/examples/python/visualization_utils.ipynb)
79

810

911
Prior to v0.8.0, transforms in torchvision have traditionally been PIL-centric and presented multiple limitations due to
@@ -16,3 +18,5 @@ features:
1618
- read and decode data directly as torch tensor with torchscript support (for PNG and JPEG image formats)
1719

1820
Furthermore, previously we used to provide a very high-level API for video decoding which left little control to the user. We're now expanding that API (and replacing it in the future) with a lower-level API that allows the user a frame-based access to a video.
21+
22+
Torchvision also provides utilities to visualize results. You can make a grid of images, and plot bounding boxes as well as segmentation masks. These utilities work standalone as well as with torchvision models for detection and segmentation.

examples/python/visualization_utils.ipynb

Lines changed: 683 additions & 0 deletions
Large diffs are not rendered by default.
360 Bytes
Loading

test/test_utils.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
import torchvision.transforms.functional as F
1010
from PIL import Image
1111

12+
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0],
13+
[10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
14+
1215
masks = torch.tensor([
1316
[
1417
[-2.2799, -2.2799, -2.2799, -2.2799, -2.2799],
@@ -106,8 +109,8 @@ def test_save_image_single_pixel_file_object(self):
106109

107110
def test_draw_boxes(self):
108111
img = torch.full((3, 100, 100), 255, dtype=torch.uint8)
109-
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0],
110-
[10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
112+
img_cp = img.clone()
113+
boxes_cp = boxes.clone()
111114
labels = ["a", "b", "c", "d"]
112115
colors = ["green", "#FF00FF", (0, 255, 0), "red"]
113116
result = utils.draw_bounding_boxes(img, boxes, labels=labels, colors=colors, fill=True)
@@ -119,9 +122,41 @@ def test_draw_boxes(self):
119122

120123
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
121124
self.assertTrue(torch.equal(result, expected))
125+
# Check if modification is not in place
126+
self.assertTrue(torch.all(torch.eq(boxes, boxes_cp)).item())
127+
self.assertTrue(torch.all(torch.eq(img, img_cp)).item())
128+
129+
def test_draw_boxes_vanilla(self):
130+
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
131+
img_cp = img.clone()
132+
boxes_cp = boxes.clone()
133+
result = utils.draw_bounding_boxes(img, boxes, fill=False, width=7)
134+
135+
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_vanilla.png")
136+
if not os.path.exists(path):
137+
res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
138+
res.save(path)
139+
140+
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
141+
self.assertTrue(torch.equal(result, expected))
142+
# Check if modification is not in place
143+
self.assertTrue(torch.all(torch.eq(boxes, boxes_cp)).item())
144+
self.assertTrue(torch.all(torch.eq(img, img_cp)).item())
145+
146+
def test_draw_invalid_boxes(self):
147+
img_tp = ((1, 1, 1), (1, 2, 3))
148+
img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float)
149+
img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8)
150+
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0],
151+
[10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
152+
self.assertRaises(TypeError, utils.draw_bounding_boxes, img_tp, boxes)
153+
self.assertRaises(ValueError, utils.draw_bounding_boxes, img_wrong1, boxes)
154+
self.assertRaises(ValueError, utils.draw_bounding_boxes, img_wrong2, boxes)
122155

123156
def test_draw_segmentation_masks_colors(self):
124157
img = torch.full((3, 5, 5), 255, dtype=torch.uint8)
158+
img_cp = img.clone()
159+
masks_cp = masks.clone()
125160
colors = ["#FF00FF", (0, 255, 0), "red"]
126161
result = utils.draw_segmentation_masks(img, masks, colors=colors)
127162

@@ -134,9 +169,14 @@ def test_draw_segmentation_masks_colors(self):
134169

135170
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
136171
self.assertTrue(torch.equal(result, expected))
172+
# Check if modification is not in place
173+
self.assertTrue(torch.all(torch.eq(img, img_cp)).item())
174+
self.assertTrue(torch.all(torch.eq(masks, masks_cp)).item())
137175

138176
def test_draw_segmentation_masks_no_colors(self):
139177
img = torch.full((3, 20, 20), 255, dtype=torch.uint8)
178+
img_cp = img.clone()
179+
masks_cp = masks.clone()
140180
result = utils.draw_segmentation_masks(img, masks, colors=None)
141181

142182
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets",
@@ -148,6 +188,20 @@ def test_draw_segmentation_masks_no_colors(self):
148188

149189
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
150190
self.assertTrue(torch.equal(result, expected))
191+
# Check if modification is not in place
192+
self.assertTrue(torch.all(torch.eq(img, img_cp)).item())
193+
self.assertTrue(torch.all(torch.eq(masks, masks_cp)).item())
194+
195+
def test_draw_invalid_masks(self):
196+
img_tp = ((1, 1, 1), (1, 2, 3))
197+
img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float)
198+
img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8)
199+
img_wrong3 = torch.full((4, 5, 5), 255, dtype=torch.uint8)
200+
201+
self.assertRaises(TypeError, utils.draw_segmentation_masks, img_tp, masks)
202+
self.assertRaises(ValueError, utils.draw_segmentation_masks, img_wrong1, masks)
203+
self.assertRaises(ValueError, utils.draw_segmentation_masks, img_wrong2, masks)
204+
self.assertRaises(ValueError, utils.draw_segmentation_masks, img_wrong3, masks)
151205

152206

153207
if __name__ == '__main__':

torchvision/utils.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ def make_grid(
2020
pad_value: int = 0,
2121
**kwargs
2222
) -> torch.Tensor:
23-
"""Make a grid of images.
23+
"""
24+
Make a grid of images.
2425
2526
Args:
2627
tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W)
@@ -37,9 +38,12 @@ def make_grid(
3738
images separately rather than the (min, max) over all images. Default: ``False``.
3839
pad_value (float, optional): Value for the padded pixels. Default: ``0``.
3940
40-
Example:
41-
See this notebook `here <https://gist.github.com/anonymous/bf16430f7750c023141c562f3e9f2a91>`_
41+
Returns:
42+
grid (Tensor): the tensor containing the grid of images.
4243
44+
Example:
45+
See this notebook
46+
`here <https://github.com/pytorch/vision/blob/master/examples/python/visualization_utils.ipynb>`_
4347
"""
4448
if not (torch.is_tensor(tensor) or
4549
(isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
@@ -117,7 +121,8 @@ def save_image(
117121
format: Optional[str] = None,
118122
**kwargs
119123
) -> None:
120-
"""Save a given Tensor into an image file.
124+
"""
125+
Save a given Tensor into an image file.
121126
122127
Args:
123128
tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
@@ -150,7 +155,7 @@ def draw_bounding_boxes(
150155
"""
151156
Draws bounding boxes on given image.
152157
The values of the input image should be uint8 between 0 and 255.
153-
If filled, Resulting Tensor should be saved as PNG image.
158+
If fill is True, the resulting Tensor should be saved as a PNG image.
154159
155160
Args:
156161
image (Tensor): Tensor of shape (C x H x W) and dtype uint8.
@@ -166,6 +171,13 @@ def draw_bounding_boxes(
166171
also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`,
167172
`/System/Library/Fonts/` and `~/Library/Fonts/` on macOS.
168173
font_size (int): The requested font size in points.
174+
175+
Returns:
176+
img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted.
177+
178+
Example:
179+
See this notebook
180+
`linked <https://github.com/pytorch/vision/blob/master/examples/python/visualization_utils.ipynb>`_
169181
"""
170182

171183
if not isinstance(image, torch.Tensor):
@@ -209,7 +221,7 @@ def draw_bounding_boxes(
209221
if labels is not None:
210222
draw.text((bbox[0], bbox[1]), labels[i], fill=color, font=txt_font)
211223

212-
return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1)
224+
return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)
213225

214226

215227
@torch.no_grad()
@@ -230,6 +242,13 @@ def draw_segmentation_masks(
230242
alpha (float): Float number between 0 and 1 denoting the transparency factor of the masks.
231243
colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of masks. The colors can
232244
be represented as `str` or `Tuple[int, int, int]`.
245+
246+
Returns:
247+
img (Tensor[C, H, W]): Image Tensor of dtype uint8 with segmentation masks plotted.
248+
249+
Example:
250+
See this notebook
251+
`attached <https://github.com/pytorch/vision/blob/master/examples/python/visualization_utils.ipynb>`_
233252
"""
234253

235254
if not isinstance(image, torch.Tensor):

0 commit comments

Comments
 (0)