-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Make RandomIoUCrop compatible with SanitizeBoundingBoxes #7268
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
53cc2fb
Make RandomIoUCrop compatible with SanitizeBoundingBoxes
NicolasHug 5eb4e5d
mask -> valid_indices
NicolasHug 5653d46
fix RandomIoUCrop and tests
pmeier c1d80c7
valid_indices -> valid
pmeier 64c8089
cleanup
pmeier 626b467
Merge branch 'main' into sanitize_and_ioucrop
pmeier 6a4a775
Merge branch 'main' into sanitize_and_ioucrop
pmeier File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1488,16 +1488,13 @@ def test__transform(self, mocker): | |
|
||
fn.assert_has_calls(expected_calls) | ||
|
||
expected_within_targets = sum(is_within_crop_area) | ||
|
||
# check number of bboxes vs number of labels: | ||
output_bboxes = output[1] | ||
assert isinstance(output_bboxes, datapoints.BoundingBox) | ||
assert len(output_bboxes) == expected_within_targets | ||
assert (output_bboxes[~is_within_crop_area] == 0).all() | ||
|
||
output_masks = output[2] | ||
pmeier marked this conversation as resolved.
Show resolved
Hide resolved
|
||
assert isinstance(output_masks, datapoints.Mask) | ||
assert len(output_masks) == expected_within_targets | ||
|
||
|
||
class TestScaleJitter: | ||
|
@@ -2253,10 +2250,11 @@ def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor): | |
|
||
|
||
@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image)) | ||
@pytest.mark.parametrize("label_type", (torch.Tensor, list)) | ||
@pytest.mark.parametrize("data_augmentation", ("hflip", "lsj", "multiscale", "ssd", "ssdlite")) | ||
@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImageTensor)) | ||
def test_detection_preset(image_type, label_type, data_augmentation, to_tensor): | ||
@pytest.mark.parametrize("sanitize", (True, False)) | ||
def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize): | ||
torch.manual_seed(0) | ||
if data_augmentation == "hflip": | ||
t = [ | ||
transforms.RandomHorizontalFlip(p=1), | ||
|
@@ -2290,20 +2288,20 @@ def test_detection_preset(image_type, label_type, data_augmentation, to_tensor): | |
t = [ | ||
transforms.RandomPhotometricDistort(p=1), | ||
transforms.RandomZoomOut(fill=defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0})), | ||
# TODO: put back IoUCrop once we remove its hard requirement for Labels | ||
# transforms.RandomIoUCrop(), | ||
transforms.RandomIoUCrop(), | ||
transforms.RandomHorizontalFlip(p=1), | ||
to_tensor(), | ||
transforms.ConvertImageDtype(torch.float), | ||
] | ||
elif data_augmentation == "ssdlite": | ||
t = [ | ||
# TODO: put back IoUCrop once we remove its hard requirement for Labels | ||
# transforms.RandomIoUCrop(), | ||
transforms.RandomIoUCrop(), | ||
transforms.RandomHorizontalFlip(p=1), | ||
to_tensor(), | ||
transforms.ConvertImageDtype(torch.float), | ||
] | ||
if sanitize: | ||
t += [transforms.SanitizeBoundingBoxes()] | ||
t = transforms.Compose(t) | ||
|
||
num_boxes = 5 | ||
|
@@ -2317,10 +2315,7 @@ def test_detection_preset(image_type, label_type, data_augmentation, to_tensor): | |
assert is_simple_tensor(image) | ||
|
||
label = torch.randint(0, 10, size=(num_boxes,)) | ||
if label_type is list: | ||
label = label.tolist() | ||
|
||
# TODO: is the shape of the boxes OK? Should it be (1, num_boxes, 4)?? Same for masks | ||
boxes = torch.randint(0, min(H, W) // 2, size=(num_boxes, 4)) | ||
boxes[:, 2:] += boxes[:, :2] | ||
boxes = boxes.clamp(min=0, max=min(H, W)) | ||
|
@@ -2343,8 +2338,19 @@ def test_detection_preset(image_type, label_type, data_augmentation, to_tensor): | |
assert isinstance(out["image"], datapoints.Image) | ||
assert isinstance(out["label"], type(sample["label"])) | ||
|
||
out["label"] = torch.tensor(out["label"]) | ||
assert out["boxes"].shape[0] == out["masks"].shape[0] == out["label"].shape[0] == num_boxes | ||
num_boxes_expected = { | ||
# ssd and ssdlite contain RandomIoUCrop which may "remove" some bbox. It | ||
# doesn't remove them strictly speaking, it just marks some boxes as | ||
# degenerate and those boxes will be later removed by | ||
# SanitizeBoundingBoxes(), which we add to the pipelines if the sanitize | ||
# param is True. | ||
# Note that the values below are probably specific to the random seed | ||
# set above (which is fine). | ||
(True, "ssd"): 4, | ||
(True, "ssdlite"): 4, | ||
}.get((sanitize, data_augmentation), num_boxes) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If I was reviewing that code, I'd probably swear a bit (I can refactor if needed) |
||
|
||
assert out["boxes"].shape[0] == out["masks"].shape[0] == out["label"].shape[0] == num_boxes_expected | ||
|
||
|
||
@pytest.mark.parametrize("min_size", (1, 10)) | ||
|
@@ -2377,20 +2383,23 @@ def test_sanitize_bounding_boxes(min_size, labels_getter): | |
valid_indices = [i for (i, is_valid) in enumerate(is_valid_mask) if is_valid] | ||
|
||
boxes = torch.tensor(boxes) | ||
labels = torch.arange(boxes.shape[-2]) | ||
labels = torch.arange(boxes.shape[0]) | ||
|
||
boxes = datapoints.BoundingBox( | ||
boxes, | ||
format=datapoints.BoundingBoxFormat.XYXY, | ||
spatial_size=(H, W), | ||
) | ||
|
||
masks = datapoints.Mask(torch.randint(0, 2, size=(boxes.shape[0], H, W))) | ||
pmeier marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
sample = { | ||
"image": torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8), | ||
"labels": labels, | ||
"boxes": boxes, | ||
"whatever": torch.rand(10), | ||
"None": None, | ||
"masks": masks, | ||
} | ||
|
||
out = transforms.SanitizeBoundingBoxes(min_size=min_size, labels_getter=labels_getter)(sample) | ||
|
@@ -2402,7 +2411,7 @@ def test_sanitize_bounding_boxes(min_size, labels_getter): | |
assert out["labels"] is sample["labels"] | ||
else: | ||
assert isinstance(out["labels"], torch.Tensor) | ||
assert out["boxes"].shape[:-1] == out["labels"].shape | ||
assert out["boxes"].shape[0] == out["labels"].shape[0] == out["masks"].shape[0] | ||
# This works because we conveniently set labels to arange(num_boxes) | ||
assert out["labels"].tolist() == valid_indices | ||
|
||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.