Skip to content

Commit 7aeedd1

Browse files
shapovalov authored and facebook-github-bot committed
When bounding boxes are cached in metadata, don’t crash on load_masks=False
Summary: We currently support caching bounding boxes in MaskAnnotation. If present, they are not re-computed from the mask. However, the masks need to be loaded for the bbox to be set. This diff fixes that. Even if load_masks / load_blobs are unset, the bounding box can be picked up from the metadata.

Reviewed By: bottler

Differential Revision: D45144918

fbshipit-source-id: 8a2e2c115e96070b6fcdc29cbe57e1cee606ddcd
1 parent 0e3138e commit 7aeedd1

File tree

2 files changed

+25
-24
lines changed

2 files changed

+25
-24
lines changed

pytorch3d/implicitron/dataset/frame_data.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -555,12 +555,19 @@ def build(
555555
else None,
556556
)
557557

558-
if load_blobs and self.load_masks and frame_annotation.mask is not None:
559-
(
560-
frame_data.fg_probability,
561-
frame_data.mask_path,
562-
frame_data.bbox_xywh,
563-
) = self._load_fg_probability(frame_annotation)
558+
mask_annotation = frame_annotation.mask
559+
if mask_annotation is not None:
560+
fg_mask_np: Optional[np.ndarray] = None
561+
if load_blobs and self.load_masks:
562+
fg_mask_np, mask_path = self._load_fg_probability(frame_annotation)
563+
frame_data.mask_path = mask_path
564+
frame_data.fg_probability = safe_as_tensor(fg_mask_np, torch.float)
565+
566+
bbox_xywh = mask_annotation.bounding_box_xywh
567+
if bbox_xywh is None and fg_mask_np is not None:
568+
bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr)
569+
570+
frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.long)
564571

565572
if frame_annotation.image is not None:
566573
image_size_hw = safe_as_tensor(frame_annotation.image.size, torch.long)
@@ -604,25 +611,15 @@ def build(
604611

605612
def _load_fg_probability(
606613
self, entry: types.FrameAnnotation
607-
) -> Tuple[Optional[torch.Tensor], Optional[str], Optional[torch.Tensor]]:
608-
614+
) -> Tuple[np.ndarray, str]:
609615
full_path = os.path.join(self.dataset_root, entry.mask.path) # pyre-ignore
610616
fg_probability = load_mask(self._local_path(full_path))
611-
# we can use provided bbox_xywh or calculate it based on mask
612-
# saves time to skip bbox calculation
613-
# pyre-ignore
614-
bbox_xywh = entry.mask.bounding_box_xywh or get_bbox_from_mask(
615-
fg_probability, self.box_crop_mask_thr
616-
)
617617
if fg_probability.shape[-2:] != entry.image.size:
618618
raise ValueError(
619619
f"bad mask size: {fg_probability.shape[-2:]} vs {entry.image.size}!"
620620
)
621-
return (
622-
safe_as_tensor(fg_probability, torch.float),
623-
full_path,
624-
safe_as_tensor(bbox_xywh, torch.long),
625-
)
621+
622+
return fg_probability, full_path
626623

627624
def _load_images(
628625
self,

tests/implicitron/test_frame_data_builder.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from pytorch3d.implicitron.dataset.dataset_base import FrameData
1818
from pytorch3d.implicitron.dataset.frame_data import FrameDataBuilder
1919
from pytorch3d.implicitron.dataset.utils import (
20+
get_bbox_from_mask,
2021
load_16big_png_depth,
2122
load_1bit_png_mask,
2223
load_depth,
@@ -107,11 +108,14 @@ def test_load_and_adjust_frame_data(self):
107108
)
108109
self.frame_data.effective_image_size_hw = self.frame_data.image_size_hw
109110

110-
(
111-
self.frame_data.fg_probability,
112-
self.frame_data.mask_path,
113-
self.frame_data.bbox_xywh,
114-
) = self.frame_data_builder._load_fg_probability(self.frame_annotation)
111+
fg_mask_np, mask_path = self.frame_data_builder._load_fg_probability(
112+
self.frame_annotation
113+
)
114+
self.frame_data.mask_path = mask_path
115+
self.frame_data.fg_probability = safe_as_tensor(fg_mask_np, torch.float)
116+
mask_thr = self.frame_data_builder.box_crop_mask_thr
117+
bbox_xywh = get_bbox_from_mask(fg_mask_np, mask_thr)
118+
self.frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.long)
115119

116120
self.assertIsNotNone(self.frame_data.mask_path)
117121
self.assertTrue(torch.is_tensor(self.frame_data.fg_probability))

0 commit comments

Comments
 (0)