Skip to content

Check target boxes input on generalized_rcnn.py #2207

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,24 @@ def compute_mean_std(tensor):
# self.check_script(model, name)
self.checkModule(model, name, ([x],))

def _test_detection_model_validation(self, name):
set_rng_seed(0)
model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False)
input_shape = (1, 3, 300, 300)
x = [torch.rand(input_shape)]

# validate that targets are present in training
self.assertRaises(ValueError, model, x)

# validate type
targets = [{'boxes': 0.}]
self.assertRaises(ValueError, model, x, targets=targets)

# validate boxes shape
for boxes in (torch.rand((4,)), torch.rand((1, 5))):
targets = [{'boxes': boxes}]
self.assertRaises(ValueError, model, x, targets=targets)

def _test_video_model(self, name):
# the default input shape is
# bs * num_channels * clip_len * h *w
Expand Down Expand Up @@ -303,6 +321,11 @@ def do_test(self, model_name=model_name):

setattr(ModelTester, "test_" + model_name, do_test)

def do_validation_test(self, model_name=model_name):
self._test_detection_model_validation(model_name)

setattr(ModelTester, "test_" + model_name + "_validation", do_validation_test)


for model_name in get_available_video_models():

Expand Down
13 changes: 13 additions & 0 deletions torchvision/models/detection/generalized_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,19 @@ def forward(self, images, targets=None):
"""
if self.training and targets is None:
raise ValueError("In training mode, targets should be passed")
if self.training:
assert targets is not None
for target in targets:
boxes = target["boxes"]
if isinstance(boxes, torch.Tensor):
if len(boxes.shape) != 2 or boxes.shape[-1] != 4:
raise ValueError("Expected target boxes to be a tensor"
"of shape [N, 4], got {:}.".format(
boxes.shape))
else:
raise ValueError("Expected target boxes to be of type "
"Tensor, got {:}.".format(type(boxes)))

original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], [])
for img in images:
val = img.shape[-2:]
Expand Down