From bd082fbf3bde41d5451075bf032532ffb67dda00 Mon Sep 17 00:00:00 2001 From: Matheus Centa Date: Tue, 12 May 2020 15:59:15 +0200 Subject: [PATCH 1/3] Check target boxes input on generalized_rcnn.py --- torchvision/models/detection/generalized_rcnn.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/torchvision/models/detection/generalized_rcnn.py b/torchvision/models/detection/generalized_rcnn.py index e58517c6a21..a39c5585fcc 100644 --- a/torchvision/models/detection/generalized_rcnn.py +++ b/torchvision/models/detection/generalized_rcnn.py @@ -57,6 +57,16 @@ def forward(self, images, targets=None): """ if self.training and targets is None: raise ValueError("In training mode, targets should be passed") + if self.training: + boxes = targets["boxes"] + if isinstance(boxes, torch.Tensor): + if len(boxes.shape) != 2 or boxes.shape[-1] != 4: + raise ValueError("Expected target boxes to be a tensor of " + "shape [N, 4], got {:}.".format(boxes.shape)) + else: + raise ValueError("Expected target boxes to be of type Tensor, " + "got {:}.".format(type(boxes))) + original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], []) for img in images: val = img.shape[-2:] From 26c8bced3dae261c0dd4ac6127947f20b03d218f Mon Sep 17 00:00:00 2001 From: Matheus Centa Date: Wed, 13 May 2020 17:18:23 +0200 Subject: [PATCH 2/3] Fix target box validation in generalized_rcnn.py --- .../models/detection/generalized_rcnn.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/torchvision/models/detection/generalized_rcnn.py b/torchvision/models/detection/generalized_rcnn.py index a39c5585fcc..9c1c3a4411b 100644 --- a/torchvision/models/detection/generalized_rcnn.py +++ b/torchvision/models/detection/generalized_rcnn.py @@ -58,14 +58,17 @@ def forward(self, images, targets=None): if self.training and targets is None: raise ValueError("In training mode, targets should be passed") if self.training: - boxes = targets["boxes"] - if isinstance(boxes, torch.Tensor): - if len(boxes.shape) != 2 or boxes.shape[-1] != 4: - raise ValueError("Expected target boxes to be a tensor of " - "shape [N, 4], got {:}.".format(boxes.shape)) - else: - raise ValueError("Expected target boxes to be of type Tensor, " - "got {:}.".format(type(boxes))) + assert targets is not None + for target in targets: + boxes = target["boxes"] + if isinstance(boxes, torch.Tensor): + if len(boxes.shape) != 2 or boxes.shape[-1] != 4: + raise ValueError("Expected target boxes to be a tensor" + "of shape [N, 4], got {:}.".format( + boxes.shape)) + else: + raise ValueError("Expected target boxes to be of type " + "Tensor, got {:}.".format(type(boxes))) original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], []) for img in images: From 0841afb95cb555fed092954c783568ffacb21957 Mon Sep 17 00:00:00 2001 From: Matheus Centa Date: Wed, 13 May 2020 17:19:23 +0200 Subject: [PATCH 3/3] Add tests for input validation of detection models --- test/test_models.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/test/test_models.py b/test/test_models.py index 32d40c86b49..099bd355b8e 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -155,6 +155,24 @@ def compute_mean_std(tensor): # self.check_script(model, name) self.checkModule(model, name, ([x],)) + def _test_detection_model_validation(self, name): + set_rng_seed(0) + model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False) + input_shape = (1, 3, 300, 300) + x = [torch.rand(input_shape)] + + # validate that targets are present in training + self.assertRaises(ValueError, model, x) + + # validate type + targets = [{'boxes': 0.}] + self.assertRaises(ValueError, model, x, targets=targets) + + # validate boxes shape + for boxes in (torch.rand((4,)), torch.rand((1, 5))): + targets = [{'boxes': boxes}] + self.assertRaises(ValueError, model, x, targets=targets) + def _test_video_model(self, name): # the default input shape is # bs * num_channels * clip_len * h *w @@ -303,6 +321,11 @@ def do_test(self, model_name=model_name): setattr(ModelTester, "test_" + model_name, do_test) + def do_validation_test(self, model_name=model_name): + self._test_detection_model_validation(model_name) + + setattr(ModelTester, "test_" + model_name + "_validation", do_validation_test) + for model_name in get_available_video_models():