import torch

import torchvision.models
from torchvision.ops import MultiScaleRoIAlign
from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionProposalNetwork
from torchvision.models.detection.roi_heads import RoIHeads
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead

import unittest

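# These tests exercise "negative samples": images whose targets contain zero
# ground-truth boxes. The RPN and the RoI heads must treat every anchor and
# proposal as background in that case, and the box-regression losses must
# come out as exactly zero.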
class Tester(unittest.TestCase):

    def test_targets_to_anchors(self):
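        # An image whose ground-truth box set is empty: the RPN should label
        # every anchor as background.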
        boxes = torch.zeros((0, 4), dtype=torch.float32)
        negative_target = {"boxes": boxes,
                           # zero boxes -> zero labels; shape (0,) keeps the
                           # target internally consistent
                           "labels": torch.zeros((0,), dtype=torch.int64),
                           "image_id": 4,
                           "area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
                           "iscrowd": torch.zeros((0,), dtype=torch.int64)}

        anchors = [torch.randint(-50, 50, (3, 4), dtype=torch.float32)]
        targets = [negative_target]

        anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
        aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
        rpn_anchor_generator = AnchorGenerator(
            anchor_sizes, aspect_ratios
        )
        rpn_head = RPNHead(4, rpn_anchor_generator.num_anchors_per_location()[0])

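        # The positional arguments below map onto RegionProposalNetwork's
        # parameters (per torchvision's signature): fg_iou_thresh=0.5,
        # bg_iou_thresh=0.3, batch_size_per_image=256, positive_fraction=0.5,
        # pre_nms_top_n=2000, post_nms_top_n=2000, nms_thresh=0.7.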
        head = RegionProposalNetwork(
            rpn_anchor_generator, rpn_head,
            0.5, 0.3,
            256, 0.5,
            2000, 2000, 0.7)

        labels, matched_gt_boxes = head.assign_targets_to_anchors(anchors, targets)

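        # No ground truth: every anchor is background, so the labels sum to 0
        # and the matched boxes are the all-zero placeholder tensor.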
        self.assertEqual(labels[0].sum(), 0)
        self.assertEqual(labels[0].shape, torch.Size([anchors[0].shape[0]]))
        self.assertEqual(labels[0].dtype, torch.float32)

        self.assertEqual(matched_gt_boxes[0].sum(), 0)
        self.assertEqual(matched_gt_boxes[0].shape, anchors[0].shape)
        self.assertEqual(matched_gt_boxes[0].dtype, torch.float32)

    def test_assign_targets_to_proposals(self):
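        # Proposals matched against an empty ground-truth set should all be
        # assigned to background.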
        proposals = [torch.randint(-50, 50, (20, 4), dtype=torch.float32)]
        gt_boxes = [torch.zeros((0, 4), dtype=torch.float32)]
        gt_labels = [torch.tensor([[0]], dtype=torch.int64)]

        box_roi_pool = MultiScaleRoIAlign(
            featmap_names=['0', '1', '2', '3'],
            output_size=7,
            sampling_ratio=2)

        resolution = box_roi_pool.output_size[0]
        representation_size = 1024
        box_head = TwoMLPHead(
            4 * resolution ** 2,
            representation_size)

        box_predictor = FastRCNNPredictor(
            representation_size,
            2)

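        # The positional arguments below map onto RoIHeads' parameters (per
        # torchvision's signature): fg_iou_thresh=0.5, bg_iou_thresh=0.5,
        # batch_size_per_image=512, positive_fraction=0.25,
        # bbox_reg_weights=None, score_thresh=0.05, nms_thresh=0.5,
        # detections_per_img=100.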
        roi_heads = RoIHeads(
            # Box
            box_roi_pool, box_head, box_predictor,
            0.5, 0.5,
            512, 0.25,
            None,
            0.05, 0.5, 100)

        matched_idxs, labels = roi_heads.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)

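        # All proposals should be matched to the placeholder index 0 and
        # labeled background (0).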
        self.assertEqual(matched_idxs[0].sum(), 0)
        self.assertEqual(matched_idxs[0].shape, torch.Size([proposals[0].shape[0]]))
        self.assertEqual(matched_idxs[0].dtype, torch.int64)

        self.assertEqual(labels[0].sum(), 0)
        self.assertEqual(labels[0].shape, torch.Size([proposals[0].shape[0]]))
        self.assertEqual(labels[0].dtype, torch.int64)

    def test_forward_negative_sample(self):
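        # End-to-end check: a full training-mode forward pass on a negative
        # sample must succeed and produce zero box-regression losses.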
        model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 2)

        images = [torch.rand((3, 100, 100), dtype=torch.float32)]
        boxes = torch.zeros((0, 4), dtype=torch.float32)
        negative_target = {"boxes": boxes,
                           # zero boxes -> zero labels, as above
                           "labels": torch.zeros((0,), dtype=torch.int64),
                           "image_id": 4,
                           "area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
                           "iscrowd": torch.zeros((0,), dtype=torch.int64)}

        targets = [negative_target]

        loss_dict = model(images, targets)

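        # With no positive matches there is nothing to regress against, so
        # both box-regression losses must be exactly zero; the classification
        # losses are still computed.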
        self.assertEqual(loss_dict["loss_box_reg"], torch.tensor(0.))
        self.assertEqual(loss_dict["loss_rpn_box_reg"], torch.tensor(0.))


if __name__ == '__main__':
    unittest.main()