Commit e75b497

mnc537 and fmassa authored
Train Faster R-CNN with negative samples (#1911)
* modified FasterRCNN to accept negative samples
* remove debug lines
* Change torch.zeros_like to torch.zeros
* Add unit tests
* take the `device` into account

Co-authored-by: Francisco Massa <[email protected]>
1 parent d45a77d commit e75b497
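In user code, a negative sample is simply a target whose "boxes" tensor has zero rows; after this commit the model trains on such images directly. A minimal sketch mirroring test_forward_negative_sample below (the model choice and the backward step are illustrative assumptions, not part of the commit):

import torch
import torchvision

# A background-only target: zero ground-truth boxes in this image.
boxes = torch.zeros((0, 4), dtype=torch.float32)
negative_target = {"boxes": boxes,
                   "labels": torch.zeros((1, 1), dtype=torch.int64),
                   "image_id": 4,
                   "area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
                   "iscrowd": torch.zeros((0,), dtype=torch.int64)}

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model.train()

images = [torch.rand((3, 100, 100), dtype=torch.float32)]
loss_dict = model(images, [negative_target])

# Box-regression terms are zero for a negative sample; the classification
# terms still push predictions towards "background".
sum(loss_dict.values()).backward()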

File tree

3 files changed: +165 -31 lines changed

test/test_models_detection_negative_samples.py

Lines changed: 110 additions & 0 deletions

@@ -0,0 +1,110 @@
+import torch
+
+import torchvision.models
+from torchvision.ops import MultiScaleRoIAlign
+from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionProposalNetwork
+from torchvision.models.detection.roi_heads import RoIHeads
+from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead
+
+import unittest
+
+
+class Tester(unittest.TestCase):
+
+    def test_targets_to_anchors(self):
+        boxes = torch.zeros((0, 4), dtype=torch.float32)
+        negative_target = {"boxes": boxes,
+                           "labels": torch.zeros((1, 1), dtype=torch.int64),
+                           "image_id": 4,
+                           "area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
+                           "iscrowd": torch.zeros((0,), dtype=torch.int64)}
+
+        anchors = [torch.randint(-50, 50, (3, 4), dtype=torch.float32)]
+        targets = [negative_target]
+
+        anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
+        aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
+        rpn_anchor_generator = AnchorGenerator(
+            anchor_sizes, aspect_ratios
+        )
+        rpn_head = RPNHead(4, rpn_anchor_generator.num_anchors_per_location()[0])
+
+        head = RegionProposalNetwork(
+            rpn_anchor_generator, rpn_head,
+            0.5, 0.3,
+            256, 0.5,
+            2000, 2000, 0.7)
+
+        labels, matched_gt_boxes = head.assign_targets_to_anchors(anchors, targets)
+
+        self.assertEqual(labels[0].sum(), 0)
+        self.assertEqual(labels[0].shape, torch.Size([anchors[0].shape[0]]))
+        self.assertEqual(labels[0].dtype, torch.float32)
+
+        self.assertEqual(matched_gt_boxes[0].sum(), 0)
+        self.assertEqual(matched_gt_boxes[0].shape, anchors[0].shape)
+        self.assertEqual(matched_gt_boxes[0].dtype, torch.float32)
+
+    def test_assign_targets_to_proposals(self):
+
+        proposals = [torch.randint(-50, 50, (20, 4), dtype=torch.float32)]
+        gt_boxes = [torch.zeros((0, 4), dtype=torch.float32)]
+        gt_labels = [torch.tensor([[0]], dtype=torch.int64)]
+
+        box_roi_pool = MultiScaleRoIAlign(
+            featmap_names=['0', '1', '2', '3'],
+            output_size=7,
+            sampling_ratio=2)
+
+        resolution = box_roi_pool.output_size[0]
+        representation_size = 1024
+        box_head = TwoMLPHead(
+            4 * resolution ** 2,
+            representation_size)
+
+        representation_size = 1024
+        box_predictor = FastRCNNPredictor(
+            representation_size,
+            2)
+
+        roi_heads = RoIHeads(
+            # Box
+            box_roi_pool, box_head, box_predictor,
+            0.5, 0.5,
+            512, 0.25,
+            None,
+            0.05, 0.5, 100)
+
+        matched_idxs, labels = roi_heads.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
+
+        self.assertEqual(matched_idxs[0].sum(), 0)
+        self.assertEqual(matched_idxs[0].shape, torch.Size([proposals[0].shape[0]]))
+        self.assertEqual(matched_idxs[0].dtype, torch.int64)
+
+        self.assertEqual(labels[0].sum(), 0)
+        self.assertEqual(labels[0].shape, torch.Size([proposals[0].shape[0]]))
+        self.assertEqual(labels[0].dtype, torch.int64)
+
+    def test_forward_negative_sample(self):
+        model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
+        in_features = model.roi_heads.box_predictor.cls_score.in_features
+        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 2)
+
+        images = [torch.rand((3, 100, 100), dtype=torch.float32)]
+        boxes = torch.zeros((0, 4), dtype=torch.float32)
+        negative_target = {"boxes": boxes,
+                           "labels": torch.zeros((1, 1), dtype=torch.int64),
+                           "image_id": 4,
+                           "area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
+                           "iscrowd": torch.zeros((0,), dtype=torch.int64)}
+
+        targets = [negative_target]
+
+        loss_dict = model(images, targets)
+
+        self.assertEqual(loss_dict["loss_box_reg"], torch.tensor(0.))
+        self.assertEqual(loss_dict["loss_rpn_box_reg"], torch.tensor(0.))
+
+
+if __name__ == '__main__':
+    unittest.main()

torchvision/models/detection/roi_heads.py

Lines changed: 30 additions & 13 deletions
@@ -574,22 +574,33 @@ def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
         matched_idxs = []
         labels = []
         for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):
-            # set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
-            match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
-            matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)
 
-            clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)
+            if gt_boxes_in_image.numel() == 0:
+                # Background image
+                device = proposals_in_image.device
+                clamped_matched_idxs_in_image = torch.zeros(
+                    (proposals_in_image.shape[0],), dtype=torch.int64, device=device
+                )
+                labels_in_image = torch.zeros(
+                    (proposals_in_image.shape[0],), dtype=torch.int64, device=device
+                )
+            else:
+                # set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
+                match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
+                matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)
 
-            labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image]
-            labels_in_image = labels_in_image.to(dtype=torch.int64)
+                clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)
 
-            # Label background (below the low threshold)
-            bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD
-            labels_in_image[bg_inds] = torch.tensor(0)
+                labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image]
+                labels_in_image = labels_in_image.to(dtype=torch.int64)
 
-            # Label ignore proposals (between low and high thresholds)
-            ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS
-            labels_in_image[ignore_inds] = torch.tensor(-1)  # -1 is ignored by sampler
+                # Label background (below the low threshold)
+                bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD
+                labels_in_image[bg_inds] = torch.tensor(0)
+
+                # Label ignore proposals (between low and high thresholds)
+                ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS
+                labels_in_image[ignore_inds] = torch.tensor(-1)  # -1 is ignored by sampler
 
             matched_idxs.append(clamped_matched_idxs_in_image)
             labels.append(labels_in_image)

@@ -635,6 +646,8 @@ def select_training_samples(self, proposals, targets):
         self.check_targets(targets)
         assert targets is not None
         dtype = proposals[0].dtype
+        device = proposals[0].device
+
         gt_boxes = [t["boxes"].to(dtype) for t in targets]
         gt_labels = [t["labels"] for t in targets]
 

@@ -652,7 +665,11 @@ def select_training_samples(self, proposals, targets):
             proposals[img_id] = proposals[img_id][img_sampled_inds]
             labels[img_id] = labels[img_id][img_sampled_inds]
             matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds]
-            matched_gt_boxes.append(gt_boxes[img_id][matched_idxs[img_id]])
+
+            gt_boxes_in_image = gt_boxes[img_id]
+            if gt_boxes_in_image.numel() == 0:
+                gt_boxes_in_image = torch.zeros((1, 4), dtype=dtype, device=device)
+            matched_gt_boxes.append(gt_boxes_in_image[matched_idxs[img_id]])
 
         regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
         return proposals, matched_idxs, labels, regression_targets
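Note the dummy box in select_training_samples: for a background image every entry of matched_idxs is 0, so a single all-zero box is substituted purely to keep gt_boxes_in_image[matched_idxs[img_id]] in bounds. Since no proposal in such an image is ever labeled positive, the box-regression loss contributes nothing, which is exactly what test_forward_negative_sample asserts. A quick sanity check of the new branch (assumes roi_heads is constructed as in test_assign_targets_to_proposals above; this snippet is illustrative, not part of the commit):

proposals = [torch.randint(-50, 50, (20, 4), dtype=torch.float32)]
gt_boxes = [torch.zeros((0, 4), dtype=torch.float32)]   # no ground truth
gt_labels = [torch.tensor([[0]], dtype=torch.int64)]

matched_idxs, labels = roi_heads.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)

# Every proposal matches the dummy index 0 and is labeled background,
# so the sampler only ever draws negatives from this image.
assert matched_idxs[0].eq(0).all()
assert labels[0].eq(0).all()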

torchvision/models/detection/rpn.py

Lines changed: 25 additions & 18 deletions
@@ -336,24 +336,31 @@ def assign_targets_to_anchors(self, anchors, targets):
         matched_gt_boxes = []
         for anchors_per_image, targets_per_image in zip(anchors, targets):
             gt_boxes = targets_per_image["boxes"]
-            match_quality_matrix = box_ops.box_iou(gt_boxes, anchors_per_image)
-            matched_idxs = self.proposal_matcher(match_quality_matrix)
-            # get the targets corresponding GT for each proposal
-            # NB: need to clamp the indices because we can have a single
-            # GT in the image, and matched_idxs can be -2, which goes
-            # out of bounds
-            matched_gt_boxes_per_image = gt_boxes[matched_idxs.clamp(min=0)]
-
-            labels_per_image = matched_idxs >= 0
-            labels_per_image = labels_per_image.to(dtype=torch.float32)
-
-            # Background (negative examples)
-            bg_indices = matched_idxs == self.proposal_matcher.BELOW_LOW_THRESHOLD
-            labels_per_image[bg_indices] = torch.tensor(0.0)
-
-            # discard indices that are between thresholds
-            inds_to_discard = matched_idxs == self.proposal_matcher.BETWEEN_THRESHOLDS
-            labels_per_image[inds_to_discard] = torch.tensor(-1.0)
+
+            if gt_boxes.numel() == 0:
+                # Background image (negative example)
+                device = anchors_per_image.device
+                matched_gt_boxes_per_image = torch.zeros(anchors_per_image.shape, dtype=torch.float32, device=device)
+                labels_per_image = torch.zeros((anchors_per_image.shape[0],), dtype=torch.float32, device=device)
+            else:
+                match_quality_matrix = box_ops.box_iou(gt_boxes, anchors_per_image)
+                matched_idxs = self.proposal_matcher(match_quality_matrix)
+                # get the targets corresponding GT for each proposal
+                # NB: need to clamp the indices because we can have a single
+                # GT in the image, and matched_idxs can be -2, which goes
+                # out of bounds
+                matched_gt_boxes_per_image = gt_boxes[matched_idxs.clamp(min=0)]
+
+                labels_per_image = matched_idxs >= 0
+                labels_per_image = labels_per_image.to(dtype=torch.float32)
+
+                # Background (negative examples)
+                bg_indices = matched_idxs == self.proposal_matcher.BELOW_LOW_THRESHOLD
+                labels_per_image[bg_indices] = torch.tensor(0.0)
+
+                # discard indices that are between thresholds
+                inds_to_discard = matched_idxs == self.proposal_matcher.BETWEEN_THRESHOLDS
+                labels_per_image[inds_to_discard] = torch.tensor(-1.0)
 
             labels.append(labels_per_image)
             matched_gt_boxes.append(matched_gt_boxes_per_image)