|
16 | 16 |
|
17 | 17 | boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
|
18 | 18 |
|
| 19 | +keypoints = torch.tensor([[[10, 10], [5, 5], [2, 2]], [[20, 20], [30, 30], [3, 3]]], dtype=torch.float) |
| 20 | + |
19 | 21 |
|
20 | 22 | def test_make_grid_not_inplace():
|
21 | 23 | t = torch.rand(5, 3, 10, 10)
|
@@ -248,5 +250,58 @@ def test_draw_segmentation_masks_errors():
|
248 | 250 | utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors)
|
249 | 251 |
|
250 | 252 |
|
| 253 | +def test_draw_keypoints_vanilla(): |
| 254 | + # Keypoints is declared on top as global variable |
| 255 | + keypoints_cp = keypoints.clone() |
| 256 | + |
| 257 | + img = torch.full((3, 100, 100), 0, dtype=torch.uint8) |
| 258 | + img_cp = img.clone() |
| 259 | + result = utils.draw_keypoints(img, keypoints, colors="red", connectivity=((0, 1),)) |
| 260 | + path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_keypoint_vanilla.png") |
| 261 | + if not os.path.exists(path): |
| 262 | + res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy()) |
| 263 | + res.save(path) |
| 264 | + |
| 265 | + expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1) |
| 266 | + assert_equal(result, expected) |
| 267 | + # Check that keypoints are not modified inplace |
| 268 | + assert_equal(keypoints, keypoints_cp) |
| 269 | + # Check that image is not modified in place |
| 270 | + assert_equal(img, img_cp) |
| 271 | + |
| 272 | + |
| 273 | +@pytest.mark.parametrize("colors", ["red", "#FF00FF", (1, 34, 122)]) |
| 274 | +def test_draw_keypoints_colored(colors): |
| 275 | + # Keypoints is declared on top as global variable |
| 276 | + keypoints_cp = keypoints.clone() |
| 277 | + |
| 278 | + img = torch.full((3, 100, 100), 0, dtype=torch.uint8) |
| 279 | + img_cp = img.clone() |
| 280 | + result = utils.draw_keypoints(img, keypoints, colors=colors, connectivity=((0, 1),)) |
| 281 | + assert result.size(0) == 3 |
| 282 | + assert_equal(keypoints, keypoints_cp) |
| 283 | + assert_equal(img, img_cp) |
| 284 | + |
| 285 | + |
| 286 | +def test_draw_keypoints_errors(): |
| 287 | + h, w = 10, 10 |
| 288 | + img = torch.full((3, 100, 100), 0, dtype=torch.uint8) |
| 289 | + |
| 290 | + with pytest.raises(TypeError, match="The image must be a tensor"): |
| 291 | + utils.draw_keypoints(image="Not A Tensor Image", keypoints=keypoints) |
| 292 | + with pytest.raises(ValueError, match="The image dtype must be"): |
| 293 | + img_bad_dtype = torch.full((3, h, w), 0, dtype=torch.int64) |
| 294 | + utils.draw_keypoints(image=img_bad_dtype, keypoints=keypoints) |
| 295 | + with pytest.raises(ValueError, match="Pass individual images, not batches"): |
| 296 | + batch = torch.randint(0, 256, size=(10, 3, h, w), dtype=torch.uint8) |
| 297 | + utils.draw_keypoints(image=batch, keypoints=keypoints) |
| 298 | + with pytest.raises(ValueError, match="Pass an RGB image"): |
| 299 | + one_channel = torch.randint(0, 256, size=(1, h, w), dtype=torch.uint8) |
| 300 | + utils.draw_keypoints(image=one_channel, keypoints=keypoints) |
| 301 | + with pytest.raises(ValueError, match="keypoints must be of shape"): |
| 302 | + invalid_keypoints = torch.tensor([[10, 10, 10, 10], [5, 6, 7, 8]], dtype=torch.float) |
| 303 | + utils.draw_keypoints(image=img, keypoints=invalid_keypoints) |
| 304 | + |
| 305 | + |
251 | 306 | if __name__ == "__main__":
|
252 | 307 | pytest.main([__file__])
|
0 commit comments