diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
index b8f20a26b24..8b1665a3d31 100644
--- a/test/test_prototype_transforms.py
+++ b/test/test_prototype_transforms.py
@@ -1488,16 +1488,13 @@ def test__transform(self, mocker):

         fn.assert_has_calls(expected_calls)

-        expected_within_targets = sum(is_within_crop_area)
-
         # check number of bboxes vs number of labels:
         output_bboxes = output[1]
         assert isinstance(output_bboxes, datapoints.BoundingBox)
-        assert len(output_bboxes) == expected_within_targets
+        assert (output_bboxes[~is_within_crop_area] == 0).all()

         output_masks = output[2]
         assert isinstance(output_masks, datapoints.Mask)
-        assert len(output_masks) == expected_within_targets


 class TestScaleJitter:
@@ -2253,10 +2250,11 @@ def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor):


 @pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image))
-@pytest.mark.parametrize("label_type", (torch.Tensor, list))
 @pytest.mark.parametrize("data_augmentation", ("hflip", "lsj", "multiscale", "ssd", "ssdlite"))
 @pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImageTensor))
-def test_detection_preset(image_type, label_type, data_augmentation, to_tensor):
+@pytest.mark.parametrize("sanitize", (True, False))
+def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
+    torch.manual_seed(0)
     if data_augmentation == "hflip":
         t = [
             transforms.RandomHorizontalFlip(p=1),
@@ -2290,20 +2288,20 @@ def test_detection_preset(image_type, label_type, data_augmentation, to_tensor):
         t = [
             transforms.RandomPhotometricDistort(p=1),
             transforms.RandomZoomOut(fill=defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0})),
-            # TODO: put back IoUCrop once we remove its hard requirement for Labels
-            # transforms.RandomIoUCrop(),
+            transforms.RandomIoUCrop(),
             transforms.RandomHorizontalFlip(p=1),
             to_tensor(),
             transforms.ConvertImageDtype(torch.float),
         ]
     elif data_augmentation == "ssdlite":
         t = [
-            # TODO: put back IoUCrop once we remove its hard requirement for Labels
-            # transforms.RandomIoUCrop(),
+            transforms.RandomIoUCrop(),
             transforms.RandomHorizontalFlip(p=1),
             to_tensor(),
             transforms.ConvertImageDtype(torch.float),
         ]
+    if sanitize:
+        t += [transforms.SanitizeBoundingBoxes()]
     t = transforms.Compose(t)

     num_boxes = 5
@@ -2317,10 +2315,7 @@ def test_detection_preset(image_type, label_type, data_augmentation, to_tensor):
         assert is_simple_tensor(image)

     label = torch.randint(0, 10, size=(num_boxes,))
-    if label_type is list:
-        label = label.tolist()

-    # TODO: is the shape of the boxes OK? Should it be (1, num_boxes, 4)?? Same for masks
     boxes = torch.randint(0, min(H, W) // 2, size=(num_boxes, 4))
     boxes[:, 2:] += boxes[:, :2]
     boxes = boxes.clamp(min=0, max=min(H, W))
@@ -2343,8 +2338,19 @@ def test_detection_preset(image_type, label_type, data_augmentation, to_tensor):

     assert isinstance(out["image"], datapoints.Image)
     assert isinstance(out["label"], type(sample["label"]))
-    out["label"] = torch.tensor(out["label"])
-    assert out["boxes"].shape[0] == out["masks"].shape[0] == out["label"].shape[0] == num_boxes
+    num_boxes_expected = {
+        # ssd and ssdlite contain RandomIoUCrop, which may "remove" some boxes.
+        # Strictly speaking it doesn't remove them: it just marks some boxes as
+        # degenerate, and those boxes are later removed by
+        # SanitizeBoundingBoxes(), which we add to the pipelines when the
+        # sanitize param is True.
+        # Note that the values below are probably specific to the random seed
+        # set above (which is fine).
+        (True, "ssd"): 4,
+        (True, "ssdlite"): 4,
+    }.get((sanitize, data_augmentation), num_boxes)
+
+    assert out["boxes"].shape[0] == out["masks"].shape[0] == out["label"].shape[0] == num_boxes_expected


 @pytest.mark.parametrize("min_size", (1, 10))
@@ -2377,7 +2383,7 @@ def test_sanitize_bounding_boxes(min_size, labels_getter):
     valid_indices = [i for (i, is_valid) in enumerate(is_valid_mask) if is_valid]

     boxes = torch.tensor(boxes)
-    labels = torch.arange(boxes.shape[-2])
+    labels = torch.arange(boxes.shape[0])

     boxes = datapoints.BoundingBox(
         boxes,
@@ -2385,12 +2391,15 @@ def test_sanitize_bounding_boxes(min_size, labels_getter):
         spatial_size=(H, W),
     )

+    masks = datapoints.Mask(torch.randint(0, 2, size=(boxes.shape[0], H, W)))
+
     sample = {
         "image": torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8),
         "labels": labels,
         "boxes": boxes,
         "whatever": torch.rand(10),
         "None": None,
+        "masks": masks,
     }

     out = transforms.SanitizeBoundingBoxes(min_size=min_size, labels_getter=labels_getter)(sample)
@@ -2402,7 +2411,7 @@ def test_sanitize_bounding_boxes(min_size, labels_getter):
         assert out["labels"] is sample["labels"]
     else:
         assert isinstance(out["labels"], torch.Tensor)
-        assert out["boxes"].shape[:-1] == out["labels"].shape
+        assert out["boxes"].shape[0] == out["labels"].shape[0] == out["masks"].shape[0]

         # This works because we conveniently set labels to arange(num_boxes)
         assert out["labels"].tolist() == valid_indices
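
For context, the contract these tests now exercise can be sketched as follows. This snippet is illustrative and not part of the diff; the import paths (torchvision.datapoints, torchvision.transforms.v2) mirror the files touched below and may differ across torchvision versions.

import torch
from torchvision import datapoints
from torchvision.transforms import v2 as transforms

torch.manual_seed(0)
H, W = 32, 32
sample = {
    "image": datapoints.Image(torch.randint(0, 256, size=(3, H, W), dtype=torch.uint8)),
    "boxes": datapoints.BoundingBox(
        torch.tensor([[1, 1, 10, 10], [20, 20, 30, 30]]),
        format=datapoints.BoundingBoxFormat.XYXY,
        spatial_size=(H, W),
    ),
    "labels": torch.tensor([3, 7]),
}

# RandomIoUCrop zeroes out boxes whose center falls outside the sampled crop
# (when a crop is applied); SanitizeBoundingBoxes then drops those degenerate
# boxes together with their labels, so the two lengths stay in sync.
pipeline = transforms.Compose([transforms.RandomIoUCrop(), transforms.SanitizeBoundingBoxes()])
out = pipeline(sample)
print(len(out["boxes"]), len(out["labels"]))  # equal, and <= 2
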
diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py
index 9b3482f3f0a..ebee2eec58f 100644
--- a/test/test_prototype_transforms_consistency.py
+++ b/test/test_prototype_transforms_consistency.py
@@ -1090,13 +1090,16 @@ def make_datapoints(self, with_mask=True):
     "t_ref, t, data_kwargs",
     [
         (det_transforms.RandomHorizontalFlip(p=1.0), v2_transforms.RandomHorizontalFlip(p=1.0), {}),
-        # FIXME: make
-        # v2_transforms.Compose([
-        #     v2_transforms.RandomIoUCrop(),
-        #     v2_transforms.SanitizeBoundingBoxes()
-        # ])
-        # work
-        # (det_transforms.RandomIoUCrop(), v2_transforms.RandomIoUCrop(), {"with_mask": False}),
+        (
+            det_transforms.RandomIoUCrop(),
+            v2_transforms.Compose(
+                [
+                    v2_transforms.RandomIoUCrop(),
+                    v2_transforms.SanitizeBoundingBoxes(labels_getter=lambda sample: sample[1]["labels"]),
+                ]
+            ),
+            {"with_mask": False},
+        ),
        (det_transforms.RandomZoomOut(), v2_transforms.RandomZoomOut(), {"with_mask": False}),
        (det_transforms.ScaleJitter((1024, 1024)), v2_transforms.ScaleJitter((1024, 1024)), {}),
        (
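
The custom labels_getter above is needed because the consistency harness feeds (image, target) pairs rather than a flat dict, so "labels" lives one level deep and the default lookup cannot find it. A hedged sketch of that call shape (import paths and exact call semantics are assumptions, as above):

import torch
from torchvision import datapoints
from torchvision.transforms import v2 as v2_transforms

image = datapoints.Image(torch.randint(0, 256, size=(3, 32, 32), dtype=torch.uint8))
target = {
    "boxes": datapoints.BoundingBox(
        torch.tensor([[0, 0, 0, 0], [5, 5, 20, 20]]),  # first box is degenerate
        format=datapoints.BoundingBoxFormat.XYXY,
        spatial_size=(32, 32),
    ),
    "labels": torch.tensor([1, 2]),
}

# sample[1] is the target dict, mirroring the lambda in the parametrization.
sanitize = v2_transforms.SanitizeBoundingBoxes(labels_getter=lambda sample: sample[1]["labels"])
out_image, out_target = sanitize(image, target)
assert len(out_target["boxes"]) == len(out_target["labels"]) == 1
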
-
             # check for any valid boxes with centers within the crop area
             xyxy_bboxes = F.convert_format_bounding_box(
                 bboxes.as_subclass(torch.Tensor), bboxes.format, datapoints.BoundingBoxFormat.XYXY
@@ -745,23 +743,16 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
             return dict(top=top, left=left, height=new_h, width=new_w, is_within_crop_area=is_within_crop_area)

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        # FIXME: refactor this to not remove anything
         if len(params) < 1:
             return inpt

-        is_within_crop_area = params["is_within_crop_area"]
-
         output = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"])

         if isinstance(output, datapoints.BoundingBox):
-            bboxes = output[is_within_crop_area]
-            bboxes = F.clamp_bounding_box(bboxes, output.format, output.spatial_size)
-            output = datapoints.BoundingBox.wrap_like(output, bboxes)
-        elif isinstance(output, datapoints.Mask):
-            # apply is_within_crop_area if mask is one-hot encoded
-            masks = output[is_within_crop_area]
-            output = datapoints.Mask.wrap_like(output, masks)
+            # We "mark" the invalid boxes as degenerate, so that they can be
+            # removed by a later call to SanitizeBoundingBoxes()
+            output[~params["is_within_crop_area"]] = 0

         return output
diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py
index de1b7ce0022..6dd0755cfbb 100644
--- a/torchvision/transforms/v2/_misc.py
+++ b/torchvision/transforms/v2/_misc.py
@@ -265,14 +265,14 @@ def forward(self, *inputs: Any) -> Any:
             ),
         )
         ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
-        mask = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(dim=-1)
+        valid = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(dim=-1)
         # TODO: Do we really need to check for out of bounds here? All
         # transforms should be clamping anyway, so this should never happen?
         image_h, image_w = boxes.spatial_size
-        mask &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w)
-        mask &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h)
+        valid &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w)
+        valid &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h)

-        params = dict(mask=mask, labels=labels)
+        params = dict(valid=valid, labels=labels)
         flat_outputs = [
             # Even-though it may look like we're transforming all inputs, we don't:
             # _transform() will only care about BoundingBoxes and the labels
@@ -284,7 +284,9 @@ def forward(self, *inputs: Any) -> Any:

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if (inpt is not None and inpt is params["labels"]) or isinstance(inpt, datapoints.BoundingBox):
-            inpt = inpt[params["mask"]]
+        if (inpt is not None and inpt is params["labels"]) or isinstance(
+            inpt, (datapoints.BoundingBox, datapoints.Mask)
+        ):
+            inpt = inpt[params["valid"]]
         return inpt
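
The _geometry.py change above swaps removal for in-place marking. A minimal illustration, on a plain tensor, of what output[~is_within_crop_area] = 0 leaves behind:

import torch

boxes = torch.tensor([[2, 2, 8, 8], [10, 10, 20, 20], [0, 0, 30, 30]])
is_within_crop_area = torch.tensor([True, False, True])

# Zero out the boxes whose center fell outside the crop; length is unchanged.
boxes[~is_within_crop_area] = 0
print(boxes)
# tensor([[ 2,  2,  8,  8],
#         [ 0,  0,  0,  0],
#         [ 0,  0, 30, 30]])
# The all-zero row has width == height == 0, which is exactly what
# SanitizeBoundingBoxes (with min_size >= 1) filters out downstream.
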
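
Likewise, the renamed `valid` mask in _misc.py can be traced with a small worked example that mirrors the forward() arithmetic above (XYXY boxes assumed; values chosen for illustration):

import torch

min_size = 1
image_h, image_w = 32, 32
boxes = torch.tensor(
    [
        [5.0, 5.0, 20.0, 20.0],    # valid
        [0.0, 0.0, 0.0, 0.0],      # degenerate: zero width and height
        [10.0, 10.0, 40.0, 20.0],  # out of bounds: x2 > image_w
    ]
)

# Same three conditions as in SanitizeBoundingBoxes.forward():
ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
valid = (ws >= min_size) & (hs >= min_size) & (boxes >= 0).all(dim=-1)
valid &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w)
valid &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h)

print(valid)         # tensor([ True, False, False])
print(boxes[valid])  # only the first box survives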