Skip to content

Commit af64a81

Browse files
oke-adityadatumbox
authored andcommitted
Add utility to draw keypoints (pytorch#4216)
* fix * Outline Keypoints API * Add utility * make it work :) * Fix optional type * Add connectivity, fmassa's advice 😃 * Minor code improvement * small fix * fix implementation * Add tests * Fix tests * Update colors * Fix bug and test more robustly * Add a comment, merge stuff * Fix fmt * Support single str for merging * Remove unnecessary vars. Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 9eb05e4 commit af64a81

File tree

4 files changed

+127
-1
lines changed

4 files changed

+127
-1
lines changed

docs/source/utils.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,6 @@ vizualization <sphx_glr_auto_examples_plot_visualization_utils.py>`.
1414

1515
draw_bounding_boxes
1616
draw_segmentation_masks
17+
draw_keypoints
1718
make_grid
1819
save_image
300 Bytes
Loading

test/test_utils.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
1818

19+
keypoints = torch.tensor([[[10, 10], [5, 5], [2, 2]], [[20, 20], [30, 30], [3, 3]]], dtype=torch.float)
20+
1921

2022
def test_make_grid_not_inplace():
2123
t = torch.rand(5, 3, 10, 10)
@@ -248,5 +250,58 @@ def test_draw_segmentation_masks_errors():
248250
utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors)
249251

250252

253+
def test_draw_keypoints_vanilla():
254+
# Keypoints is declared on top as global variable
255+
keypoints_cp = keypoints.clone()
256+
257+
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
258+
img_cp = img.clone()
259+
result = utils.draw_keypoints(img, keypoints, colors="red", connectivity=((0, 1),))
260+
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_keypoint_vanilla.png")
261+
if not os.path.exists(path):
262+
res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
263+
res.save(path)
264+
265+
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
266+
assert_equal(result, expected)
267+
# Check that keypoints are not modified inplace
268+
assert_equal(keypoints, keypoints_cp)
269+
# Check that image is not modified in place
270+
assert_equal(img, img_cp)
271+
272+
273+
@pytest.mark.parametrize("colors", ["red", "#FF00FF", (1, 34, 122)])
274+
def test_draw_keypoints_colored(colors):
275+
# Keypoints is declared on top as global variable
276+
keypoints_cp = keypoints.clone()
277+
278+
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
279+
img_cp = img.clone()
280+
result = utils.draw_keypoints(img, keypoints, colors=colors, connectivity=((0, 1),))
281+
assert result.size(0) == 3
282+
assert_equal(keypoints, keypoints_cp)
283+
assert_equal(img, img_cp)
284+
285+
286+
def test_draw_keypoints_errors():
287+
h, w = 10, 10
288+
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
289+
290+
with pytest.raises(TypeError, match="The image must be a tensor"):
291+
utils.draw_keypoints(image="Not A Tensor Image", keypoints=keypoints)
292+
with pytest.raises(ValueError, match="The image dtype must be"):
293+
img_bad_dtype = torch.full((3, h, w), 0, dtype=torch.int64)
294+
utils.draw_keypoints(image=img_bad_dtype, keypoints=keypoints)
295+
with pytest.raises(ValueError, match="Pass individual images, not batches"):
296+
batch = torch.randint(0, 256, size=(10, 3, h, w), dtype=torch.uint8)
297+
utils.draw_keypoints(image=batch, keypoints=keypoints)
298+
with pytest.raises(ValueError, match="Pass an RGB image"):
299+
one_channel = torch.randint(0, 256, size=(1, h, w), dtype=torch.uint8)
300+
utils.draw_keypoints(image=one_channel, keypoints=keypoints)
301+
with pytest.raises(ValueError, match="keypoints must be of shape"):
302+
invalid_keypoints = torch.tensor([[10, 10, 10, 10], [5, 6, 7, 8]], dtype=torch.float)
303+
utils.draw_keypoints(image=img, keypoints=invalid_keypoints)
304+
305+
251306
if __name__ == "__main__":
252307
pytest.main([__file__])

torchvision/utils.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import torch
88
from PIL import Image, ImageDraw, ImageFont, ImageColor
99

10-
__all__ = ["make_grid", "save_image", "draw_bounding_boxes", "draw_segmentation_masks"]
10+
__all__ = ["make_grid", "save_image", "draw_bounding_boxes", "draw_segmentation_masks", "draw_keypoints"]
1111

1212

1313
@torch.no_grad()
@@ -300,6 +300,76 @@ def draw_segmentation_masks(
300300
return out.to(out_dtype)
301301

302302

303+
@torch.no_grad()
304+
def draw_keypoints(
305+
image: torch.Tensor,
306+
keypoints: torch.Tensor,
307+
connectivity: Optional[Tuple[Tuple[int, int]]] = None,
308+
colors: Optional[Union[str, Tuple[int, int, int]]] = None,
309+
radius: int = 2,
310+
width: int = 3,
311+
) -> torch.Tensor:
312+
313+
"""
314+
Draws Keypoints on given RGB image.
315+
The values of the input image should be uint8 between 0 and 255.
316+
317+
Args:
318+
image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
319+
keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoints location for each of the N instances,
320+
in the format [x, y].
321+
connectivity (Tuple[Tuple[int, int]]]): A Tuple of tuple where,
322+
each tuple contains pair of keypoints to be connected.
323+
colors (str, Tuple): The color can be represented as
324+
PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
325+
radius (int): Integer denoting radius of keypoint.
326+
width (int): Integer denoting width of line connecting keypoints.
327+
328+
Returns:
329+
img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn.
330+
"""
331+
332+
if not isinstance(image, torch.Tensor):
333+
raise TypeError(f"The image must be a tensor, got {type(image)}")
334+
elif image.dtype != torch.uint8:
335+
raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
336+
elif image.dim() != 3:
337+
raise ValueError("Pass individual images, not batches")
338+
elif image.size()[0] != 3:
339+
raise ValueError("Pass an RGB image. Other Image formats are not supported")
340+
341+
if keypoints.ndim != 3:
342+
raise ValueError("keypoints must be of shape (num_instances, K, 2)")
343+
344+
ndarr = image.permute(1, 2, 0).numpy()
345+
img_to_draw = Image.fromarray(ndarr)
346+
draw = ImageDraw.Draw(img_to_draw)
347+
img_kpts = keypoints.to(torch.int64).tolist()
348+
349+
for kpt_id, kpt_inst in enumerate(img_kpts):
350+
for inst_id, kpt in enumerate(kpt_inst):
351+
x1 = kpt[0] - radius
352+
x2 = kpt[0] + radius
353+
y1 = kpt[1] - radius
354+
y2 = kpt[1] + radius
355+
draw.ellipse([x1, y1, x2, y2], fill=colors, outline=None, width=0)
356+
357+
if connectivity:
358+
for connection in connectivity:
359+
start_pt_x = kpt_inst[connection[0]][0]
360+
start_pt_y = kpt_inst[connection[0]][1]
361+
362+
end_pt_x = kpt_inst[connection[1]][0]
363+
end_pt_y = kpt_inst[connection[1]][1]
364+
365+
draw.line(
366+
((start_pt_x, start_pt_y), (end_pt_x, end_pt_y)),
367+
width=width,
368+
)
369+
370+
return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)
371+
372+
303373
def _generate_color_palette(num_masks: int):
304374
palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
305375
return [tuple((i * palette) % 255) for i in range(num_masks)]

0 commit comments

Comments
 (0)