15
15
16
16
17
17
def fastrcnn_loss (class_logits , box_regression , labels , regression_targets ):
18
- # type: (Tensor, Tensor, List[Tensor], List[Tensor])
18
+ # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
19
19
"""
20
20
Computes the loss for Faster R-CNN.
21
21
@@ -55,7 +55,7 @@ def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
55
55
56
56
57
57
def maskrcnn_inference (x , labels ):
58
- # type: (Tensor, List[Tensor])
58
+ # type: (Tensor, List[Tensor]) -> List[Tensor]
59
59
"""
60
60
From the results of the CNN, post process the masks
61
61
by taking the mask corresponding to the class with max
@@ -85,7 +85,7 @@ def maskrcnn_inference(x, labels):
85
85
86
86
87
87
def project_masks_on_boxes (gt_masks , boxes , matched_idxs , M ):
88
- # type: (Tensor, Tensor, Tensor, int)
88
+ # type: (Tensor, Tensor, Tensor, int) -> Tensor
89
89
"""
90
90
Given segmentation masks and the bounding boxes corresponding
91
91
to the location of the masks in the image, this function
@@ -100,7 +100,7 @@ def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
100
100
101
101
102
102
def maskrcnn_loss (mask_logits , proposals , gt_masks , gt_labels , mask_matched_idxs ):
103
- # type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor])
103
+ # type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor]) -> Tensor
104
104
"""
105
105
Arguments:
106
106
proposals (list[BoxList])
@@ -133,7 +133,7 @@ def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs
133
133
134
134
135
135
def keypoints_to_heatmap (keypoints , rois , heatmap_size ):
136
- # type: (Tensor, Tensor, int)
136
+ # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
137
137
offset_x = rois [:, 0 ]
138
138
offset_y = rois [:, 1 ]
139
139
scale_x = heatmap_size / (rois [:, 2 ] - rois [:, 0 ])
@@ -277,7 +277,7 @@ def heatmaps_to_keypoints(maps, rois):
277
277
278
278
279
279
def keypointrcnn_loss (keypoint_logits , proposals , gt_keypoints , keypoint_matched_idxs ):
280
- # type: (Tensor, List[Tensor], List[Tensor], List[Tensor])
280
+ # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
281
281
N , K , H , W = keypoint_logits .shape
282
282
assert H == W
283
283
discretization_size = H
@@ -307,7 +307,7 @@ def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched
307
307
308
308
309
309
def keypointrcnn_inference (x , boxes ):
310
- # type: (Tensor, List[Tensor])
310
+ # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
311
311
kp_probs = []
312
312
kp_scores = []
313
313
@@ -323,7 +323,7 @@ def keypointrcnn_inference(x, boxes):
323
323
324
324
325
325
def _onnx_expand_boxes (boxes , scale ):
326
- # type: (Tensor, float)
326
+ # type: (Tensor, float) -> Tensor
327
327
w_half = (boxes [:, 2 ] - boxes [:, 0 ]) * .5
328
328
h_half = (boxes [:, 3 ] - boxes [:, 1 ]) * .5
329
329
x_c = (boxes [:, 2 ] + boxes [:, 0 ]) * .5
@@ -344,7 +344,7 @@ def _onnx_expand_boxes(boxes, scale):
344
344
# but are kept here for the moment while we need them
345
345
# temporarily for paste_mask_in_image
346
346
def expand_boxes (boxes , scale ):
347
- # type: (Tensor, float)
347
+ # type: (Tensor, float) -> Tensor
348
348
if torchvision ._is_tracing ():
349
349
return _onnx_expand_boxes (boxes , scale )
350
350
w_half = (boxes [:, 2 ] - boxes [:, 0 ]) * .5
@@ -370,7 +370,7 @@ def expand_masks_tracing_scale(M, padding):
370
370
371
371
372
372
def expand_masks (mask , padding ):
373
- # type: (Tensor, int)
373
+ # type: (Tensor, int) -> Tuple[Tensor, float]
374
374
M = mask .shape [- 1 ]
375
375
if torch ._C ._get_tracing_state (): # could not import is_tracing(), not sure why
376
376
scale = expand_masks_tracing_scale (M , padding )
@@ -381,7 +381,7 @@ def expand_masks(mask, padding):
381
381
382
382
383
383
def paste_mask_in_image (mask , box , im_h , im_w ):
384
- # type: (Tensor, Tensor, int, int)
384
+ # type: (Tensor, Tensor, int, int) -> Tensor
385
385
TO_REMOVE = 1
386
386
w = int (box [2 ] - box [0 ] + TO_REMOVE )
387
387
h = int (box [3 ] - box [1 ] + TO_REMOVE )
@@ -459,7 +459,7 @@ def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
459
459
460
460
461
461
def paste_masks_in_image (masks , boxes , img_shape , padding = 1 ):
462
- # type: (Tensor, Tensor, Tuple[int, int], int)
462
+ # type: (Tensor, Tensor, Tuple[int, int], int) -> Tensor
463
463
masks , scale = expand_masks (masks , padding = padding )
464
464
boxes = expand_boxes (boxes , scale ).to (dtype = torch .int64 )
465
465
im_h , im_w = img_shape
@@ -558,7 +558,7 @@ def has_keypoint(self):
558
558
return True
559
559
560
560
def assign_targets_to_proposals (self , proposals , gt_boxes , gt_labels ):
561
- # type: (List[Tensor], List[Tensor], List[Tensor])
561
+ # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
562
562
matched_idxs = []
563
563
labels = []
564
564
for proposals_in_image , gt_boxes_in_image , gt_labels_in_image in zip (proposals , gt_boxes , gt_labels ):
@@ -595,7 +595,7 @@ def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
595
595
return matched_idxs , labels
596
596
597
597
def subsample (self , labels ):
598
- # type: (List[Tensor])
598
+ # type: (List[Tensor]) -> List[Tensor]
599
599
sampled_pos_inds , sampled_neg_inds = self .fg_bg_sampler (labels )
600
600
sampled_inds = []
601
601
for img_idx , (pos_inds_img , neg_inds_img ) in enumerate (
@@ -606,7 +606,7 @@ def subsample(self, labels):
606
606
return sampled_inds
607
607
608
608
def add_gt_proposals (self , proposals , gt_boxes ):
609
- # type: (List[Tensor], List[Tensor])
609
+ # type: (List[Tensor], List[Tensor]) -> List[Tensor]
610
610
proposals = [
611
611
torch .cat ((proposal , gt_box ))
612
612
for proposal , gt_box in zip (proposals , gt_boxes )
@@ -615,22 +615,25 @@ def add_gt_proposals(self, proposals, gt_boxes):
615
615
return proposals
616
616
617
617
def DELTEME_all (self , the_list ):
618
- # type: (List[bool])
618
+ # type: (List[bool]) -> bool
619
619
for i in the_list :
620
620
if not i :
621
621
return False
622
622
return True
623
623
624
624
def check_targets (self , targets ):
625
- # type: (Optional[List[Dict[str, Tensor]]])
625
+ # type: (Optional[List[Dict[str, Tensor]]]) -> None
626
626
assert targets is not None
627
627
assert self .DELTEME_all (["boxes" in t for t in targets ])
628
628
assert self .DELTEME_all (["labels" in t for t in targets ])
629
629
if self .has_mask ():
630
630
assert self .DELTEME_all (["masks" in t for t in targets ])
631
631
632
- def select_training_samples (self , proposals , targets ):
633
- # type: (List[Tensor], Optional[List[Dict[str, Tensor]]])
632
+ def select_training_samples (self ,
633
+ proposals , # type: List[Tensor]
634
+ targets # type: Optional[List[Dict[str, Tensor]]]
635
+ ):
636
+ # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
634
637
self .check_targets (targets )
635
638
assert targets is not None
636
639
dtype = proposals [0 ].dtype
@@ -662,8 +665,13 @@ def select_training_samples(self, proposals, targets):
662
665
regression_targets = self .box_coder .encode (matched_gt_boxes , proposals )
663
666
return proposals , matched_idxs , labels , regression_targets
664
667
665
- def postprocess_detections (self , class_logits , box_regression , proposals , image_shapes ):
666
- # type: (Tensor, Tensor, List[Tensor], List[Tuple[int, int]])
668
+ def postprocess_detections (self ,
669
+ class_logits , # type: Tensor
670
+ box_regression , # type: Tensor
671
+ proposals , # type: List[Tensor]
672
+ image_shapes # type: List[Tuple[int, int]]
673
+ ):
674
+ # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
667
675
device = class_logits .device
668
676
num_classes = class_logits .shape [- 1 ]
669
677
@@ -715,8 +723,13 @@ def postprocess_detections(self, class_logits, box_regression, proposals, image_
715
723
716
724
return all_boxes , all_scores , all_labels
717
725
718
- def forward (self , features , proposals , image_shapes , targets = None ):
719
- # type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]], Optional[List[Dict[str, Tensor]]])
726
+ def forward (self ,
727
+ features , # type: Dict[str, Tensor]
728
+ proposals , # type: List[Tensor]
729
+ image_shapes , # type: List[Tuple[int, int]]
730
+ targets = None # type: Optional[List[Dict[str, Tensor]]]
731
+ ):
732
+ # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
720
733
"""
721
734
Arguments:
722
735
features (List[Tensor])
0 commit comments