Skip to content

Commit c00a181

Browse files
committed
Add keypoint support
1 parent b5e3b91 commit c00a181

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

torchvision/datapoints/_dataset_wrapper.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,13 @@ def wrapper(idx, sample):
248248
)
249249
batched_target["labels"] = torch.tensor(batched_target["category_id"])
250250

251+
if "keypoints" in batched_target:
252+
batched_target["keypoints"] = torch.as_tensor(batched_target["keypoints"], dtype=torch.float32).reshape(
253+
len(batched_target["keypoints"]), -1, 3
254+
)
255+
256+
257+
251258
return image, batched_target
252259

253260
return wrapper

0 commit comments

Comments
 (0)