
Commit f71316f

fmassa and Guanheng Zhang authored
Fix mypy type annotations (#1696)
* Fix mypy type annotations
* follow torchscript Tuple type
* redefine torch_choice output type
* change the type in cached_grid_anchors
* minor bug

Co-authored-by: Guanheng Zhang <[email protected]>
Co-authored-by: Guanheng Zhang <[email protected]>
1 parent 3ac864d commit f71316f
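For readers unfamiliar with the convention, the patch converts every function-signature type comment to the full mypy/TorchScript form, which requires an explicit "-> ReturnType". Below is a minimal sketch (function names and bodies are illustrative, not taken from the patch) of the two styles used throughout the diff: a single-line comment for short signatures, and per-argument comments plus "# type: (...) -> ..." for long ones.

from typing import Dict, List, Optional, Tuple

import torch
from torch import Tensor


def encode(reference_boxes, proposals):
    # type: (List[Tensor], List[Tensor]) -> List[Tensor]
    # Single-line form: argument types plus a mandatory return type.
    # Without "-> List[Tensor]", mypy reports a syntax error in the comment.
    return [b + p for b, p in zip(reference_boxes, proposals)]


def forward(features,      # type: Dict[str, Tensor]
            proposals,     # type: List[Tensor]
            targets=None   # type: Optional[List[Dict[str, Tensor]]]
            ):
    # type: (...) -> Tuple[List[Tensor], Dict[str, Tensor]]
    # Multi-line form for long signatures: one comment per argument,
    # then "(...)" with only the return type spelled out.
    losses = {}  # type: Dict[str, Tensor]
    if targets is not None:
        losses["dummy"] = torch.zeros(1)
    return proposals, losses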

9 files changed: +94 −70 lines changed


torchvision/models/detection/_utils.py

Lines changed: 6 additions & 6 deletions
@@ -20,7 +20,7 @@ class BalancedPositiveNegativeSampler(object):
     """
 
     def __init__(self, batch_size_per_image, positive_fraction):
-        # type: (int, float)
+        # type: (int, float) -> None
         """
         Arguments:
             batch_size_per_image (int): number of elements to be selected per image
@@ -30,7 +30,7 @@ def __init__(self, batch_size_per_image, positive_fraction):
         self.positive_fraction = positive_fraction
 
     def __call__(self, matched_idxs):
-        # type: (List[Tensor])
+        # type: (List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
         """
         Arguments:
             matched idxs: list of tensors containing -1, 0 or positive values.
@@ -139,7 +139,7 @@ class BoxCoder(object):
     """
 
     def __init__(self, weights, bbox_xform_clip=math.log(1000. / 16)):
-        # type: (Tuple[float, float, float, float], float)
+        # type: (Tuple[float, float, float, float], float) -> None
         """
         Arguments:
             weights (4-element tuple)
@@ -149,7 +149,7 @@ def __init__(self, weights, bbox_xform_clip=math.log(1000. / 16)):
         self.bbox_xform_clip = bbox_xform_clip
 
     def encode(self, reference_boxes, proposals):
-        # type: (List[Tensor], List[Tensor])
+        # type: (List[Tensor], List[Tensor]) -> List[Tensor]
         boxes_per_image = [len(b) for b in reference_boxes]
         reference_boxes = torch.cat(reference_boxes, dim=0)
         proposals = torch.cat(proposals, dim=0)
@@ -173,7 +173,7 @@ def encode_single(self, reference_boxes, proposals):
         return targets
 
     def decode(self, rel_codes, boxes):
-        # type: (Tensor, List[Tensor])
+        # type: (Tensor, List[Tensor]) -> Tensor
         assert isinstance(boxes, (list, tuple))
         assert isinstance(rel_codes, torch.Tensor)
         boxes_per_image = [b.size(0) for b in boxes]
@@ -251,7 +251,7 @@ class Matcher(object):
     }
 
     def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=False):
-        # type: (float, float, bool)
+        # type: (float, float, bool) -> None
         """
         Args:
             high_threshold (float): quality values greater than or equal to

torchvision/models/detection/generalized_rcnn.py

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ def eager_outputs(self, losses, detections):
         return detections
 
     def forward(self, images, targets=None):
-        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]])
+        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
         """
         Arguments:
             images (list[Tensor]): images to be processed

torchvision/models/detection/image_list.py

Lines changed: 2 additions & 2 deletions
@@ -14,7 +14,7 @@ class ImageList(object):
     """
 
     def __init__(self, tensors, image_sizes):
-        # type: (Tensor, List[Tuple[int, int]])
+        # type: (Tensor, List[Tuple[int, int]]) -> None
         """
         Arguments:
             tensors (tensor)
@@ -24,6 +24,6 @@ def __init__(self, tensors, image_sizes):
         self.image_sizes = image_sizes
 
     def to(self, device):
-        # type: (Device)  # noqa
+        # type: (Device) -> ImageList  # noqa
         cast_tensor = self.tensors.to(device)
         return ImageList(cast_tensor, self.image_sizes)

torchvision/models/detection/roi_heads.py

Lines changed: 36 additions & 23 deletions
@@ -15,7 +15,7 @@
 
 
 def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
-    # type: (Tensor, Tensor, List[Tensor], List[Tensor])
+    # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
     """
     Computes the loss for Faster R-CNN.
 
@@ -55,7 +55,7 @@ def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
 
 
 def maskrcnn_inference(x, labels):
-    # type: (Tensor, List[Tensor])
+    # type: (Tensor, List[Tensor]) -> List[Tensor]
     """
     From the results of the CNN, post process the masks
     by taking the mask corresponding to the class with max
@@ -85,7 +85,7 @@ def maskrcnn_inference(x, labels):
 
 
 def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
-    # type: (Tensor, Tensor, Tensor, int)
+    # type: (Tensor, Tensor, Tensor, int) -> Tensor
     """
     Given segmentation masks and the bounding boxes corresponding
     to the location of the masks in the image, this function
@@ -100,7 +100,7 @@ def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
 
 
 def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
-    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor])
+    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor]) -> Tensor
     """
     Arguments:
         proposals (list[BoxList])
@@ -133,7 +133,7 @@ def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs
 
 
 def keypoints_to_heatmap(keypoints, rois, heatmap_size):
-    # type: (Tensor, Tensor, int)
+    # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
     offset_x = rois[:, 0]
     offset_y = rois[:, 1]
     scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
@@ -277,7 +277,7 @@ def heatmaps_to_keypoints(maps, rois):
 
 
 def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
-    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor])
+    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
     N, K, H, W = keypoint_logits.shape
     assert H == W
     discretization_size = H
@@ -307,7 +307,7 @@ def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched
 
 
 def keypointrcnn_inference(x, boxes):
-    # type: (Tensor, List[Tensor])
+    # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
     kp_probs = []
     kp_scores = []
 
@@ -323,7 +323,7 @@ def keypointrcnn_inference(x, boxes):
 
 
 def _onnx_expand_boxes(boxes, scale):
-    # type: (Tensor, float)
+    # type: (Tensor, float) -> Tensor
     w_half = (boxes[:, 2] - boxes[:, 0]) * .5
     h_half = (boxes[:, 3] - boxes[:, 1]) * .5
     x_c = (boxes[:, 2] + boxes[:, 0]) * .5
@@ -344,7 +344,7 @@ def _onnx_expand_boxes(boxes, scale):
 # but are kept here for the moment while we need them
 # temporarily for paste_mask_in_image
 def expand_boxes(boxes, scale):
-    # type: (Tensor, float)
+    # type: (Tensor, float) -> Tensor
     if torchvision._is_tracing():
         return _onnx_expand_boxes(boxes, scale)
     w_half = (boxes[:, 2] - boxes[:, 0]) * .5
@@ -370,7 +370,7 @@ def expand_masks_tracing_scale(M, padding):
 
 
 def expand_masks(mask, padding):
-    # type: (Tensor, int)
+    # type: (Tensor, int) -> Tuple[Tensor, float]
     M = mask.shape[-1]
     if torch._C._get_tracing_state():  # could not import is_tracing(), not sure why
         scale = expand_masks_tracing_scale(M, padding)
@@ -381,7 +381,7 @@ def expand_masks(mask, padding):
 
 
 def paste_mask_in_image(mask, box, im_h, im_w):
-    # type: (Tensor, Tensor, int, int)
+    # type: (Tensor, Tensor, int, int) -> Tensor
     TO_REMOVE = 1
     w = int(box[2] - box[0] + TO_REMOVE)
     h = int(box[3] - box[1] + TO_REMOVE)
@@ -459,7 +459,7 @@ def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
 
 
 def paste_masks_in_image(masks, boxes, img_shape, padding=1):
-    # type: (Tensor, Tensor, Tuple[int, int], int)
+    # type: (Tensor, Tensor, Tuple[int, int], int) -> Tensor
     masks, scale = expand_masks(masks, padding=padding)
     boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
     im_h, im_w = img_shape
@@ -558,7 +558,7 @@ def has_keypoint(self):
         return True
 
     def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
-        # type: (List[Tensor], List[Tensor], List[Tensor])
+        # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
         matched_idxs = []
         labels = []
         for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):
@@ -595,7 +595,7 @@ def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
         return matched_idxs, labels
 
     def subsample(self, labels):
-        # type: (List[Tensor])
+        # type: (List[Tensor]) -> List[Tensor]
         sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
         sampled_inds = []
         for img_idx, (pos_inds_img, neg_inds_img) in enumerate(
@@ -606,7 +606,7 @@ def subsample(self, labels):
         return sampled_inds
 
     def add_gt_proposals(self, proposals, gt_boxes):
-        # type: (List[Tensor], List[Tensor])
+        # type: (List[Tensor], List[Tensor]) -> List[Tensor]
         proposals = [
             torch.cat((proposal, gt_box))
             for proposal, gt_box in zip(proposals, gt_boxes)
@@ -615,22 +615,25 @@ def add_gt_proposals(self, proposals, gt_boxes):
         return proposals
 
     def DELTEME_all(self, the_list):
-        # type: (List[bool])
+        # type: (List[bool]) -> bool
        for i in the_list:
            if not i:
                return False
        return True
 
     def check_targets(self, targets):
-        # type: (Optional[List[Dict[str, Tensor]]])
+        # type: (Optional[List[Dict[str, Tensor]]]) -> None
         assert targets is not None
         assert self.DELTEME_all(["boxes" in t for t in targets])
         assert self.DELTEME_all(["labels" in t for t in targets])
         if self.has_mask():
             assert self.DELTEME_all(["masks" in t for t in targets])
 
-    def select_training_samples(self, proposals, targets):
-        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]])
+    def select_training_samples(self,
+                                proposals,  # type: List[Tensor]
+                                targets     # type: Optional[List[Dict[str, Tensor]]]
+                                ):
+        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
         self.check_targets(targets)
         assert targets is not None
         dtype = proposals[0].dtype
@@ -662,8 +665,13 @@ def select_training_samples(self, proposals, targets):
         regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
         return proposals, matched_idxs, labels, regression_targets
 
-    def postprocess_detections(self, class_logits, box_regression, proposals, image_shapes):
-        # type: (Tensor, Tensor, List[Tensor], List[Tuple[int, int]])
+    def postprocess_detections(self,
+                               class_logits,    # type: Tensor
+                               box_regression,  # type: Tensor
+                               proposals,       # type: List[Tensor]
+                               image_shapes     # type: List[Tuple[int, int]]
+                               ):
+        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
         device = class_logits.device
         num_classes = class_logits.shape[-1]
 
@@ -715,8 +723,13 @@ def postprocess_detections(self, class_logits, box_regression, proposals, image_
 
         return all_boxes, all_scores, all_labels
 
-    def forward(self, features, proposals, image_shapes, targets=None):
-        # type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]], Optional[List[Dict[str, Tensor]]])
+    def forward(self,
+                features,      # type: Dict[str, Tensor]
+                proposals,     # type: List[Tensor]
+                image_shapes,  # type: List[Tuple[int, int]]
+                targets=None   # type: Optional[List[Dict[str, Tensor]]]
+                ):
+        # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
         """
         Arguments:
             features (List[Tensor])

torchvision/models/detection/rpn.py

Lines changed: 18 additions & 14 deletions
@@ -75,7 +75,7 @@ def __init__(
     # (scales, aspect_ratios) are usually an element of zip(self.scales, self.aspect_ratios)
     # This method assumes aspect ratio = height / width for an anchor.
     def generate_anchors(self, scales, aspect_ratios, dtype=torch.float32, device="cpu"):
-        # type: (List[int], List[float], int, Device)  # noqa: F821
+        # type: (List[int], List[float], int, Device) -> Tensor  # noqa: F821
         scales = torch.as_tensor(scales, dtype=dtype, device=device)
         aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
         h_ratios = torch.sqrt(aspect_ratios)
@@ -88,7 +88,7 @@ def generate_anchors(self, scales, aspect_ratios, dtype=torch.float32, device="c
         return base_anchors.round()
 
     def set_cell_anchors(self, dtype, device):
-        # type: (int, Device) -> None  # noqa: F821
+        # type: (int, Device) -> None  # noqa: F821
         if self.cell_anchors is not None:
             cell_anchors = self.cell_anchors
             assert cell_anchors is not None
@@ -114,7 +114,7 @@ def num_anchors_per_location(self):
     # For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2),
    # output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a.
     def grid_anchors(self, grid_sizes, strides):
-        # type: (List[List[int]], List[List[Tensor]])
+        # type: (List[List[int]], List[List[Tensor]]) -> List[Tensor]
         anchors = []
         cell_anchors = self.cell_anchors
         assert cell_anchors is not None
@@ -147,7 +147,7 @@ def grid_anchors(self, grid_sizes, strides):
         return anchors
 
     def cached_grid_anchors(self, grid_sizes, strides):
-        # type: (List[List[int]], List[List[Tensor]])
+        # type: (List[List[int]], List[List[Tensor]]) -> List[Tensor]
         key = str(grid_sizes) + str(strides)
         if key in self._cache:
             return self._cache[key]
@@ -156,7 +156,7 @@ def cached_grid_anchors(self, grid_sizes, strides):
         return anchors
 
     def forward(self, image_list, feature_maps):
-        # type: (ImageList, List[Tensor])
+        # type: (ImageList, List[Tensor]) -> List[Tensor]
         grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps])
         image_size = image_list.tensors.shape[-2:]
         dtype, device = feature_maps[0].dtype, feature_maps[0].device
@@ -200,7 +200,7 @@ def __init__(self, in_channels, num_anchors):
             torch.nn.init.constant_(l.bias, 0)
 
     def forward(self, x):
-        # type: (List[Tensor])
+        # type: (List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
         logits = []
         bbox_reg = []
         for feature in x:
@@ -211,15 +211,15 @@ def forward(self, x):
 
 
 def permute_and_flatten(layer, N, A, C, H, W):
-    # type: (Tensor, int, int, int, int, int)
+    # type: (Tensor, int, int, int, int, int) -> Tensor
     layer = layer.view(N, -1, C, H, W)
     layer = layer.permute(0, 3, 4, 1, 2)
     layer = layer.reshape(N, -1, C)
     return layer
 
 
 def concat_box_prediction_layers(box_cls, box_regression):
-    # type: (List[Tensor], List[Tensor])
+    # type: (List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
     box_cls_flattened = []
     box_regression_flattened = []
     # for each feature level, permute the outputs to make them be in the
@@ -325,7 +325,7 @@ def post_nms_top_n(self):
         return self._post_nms_top_n['testing']
 
     def assign_targets_to_anchors(self, anchors, targets):
-        # type: (List[Tensor], List[Dict[str, Tensor]])
+        # type: (List[Tensor], List[Dict[str, Tensor]]) -> Tuple[List[Tensor], List[Tensor]]
         labels = []
         matched_gt_boxes = []
         for anchors_per_image, targets_per_image in zip(anchors, targets):
@@ -361,7 +361,7 @@ def assign_targets_to_anchors(self, anchors, targets):
         return labels, matched_gt_boxes
 
     def _get_top_n_idx(self, objectness, num_anchors_per_level):
-        # type: (Tensor, List[int])
+        # type: (Tensor, List[int]) -> Tensor
         r = []
         offset = 0
         for ob in objectness.split(num_anchors_per_level, 1):
@@ -376,7 +376,7 @@ def _get_top_n_idx(self, objectness, num_anchors_per_level):
         return torch.cat(r, dim=1)
 
     def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_level):
-        # type: (Tensor, Tensor, List[Tuple[int, int]], List[int])
+        # type: (Tensor, Tensor, List[Tuple[int, int]], List[int]) -> Tuple[List[Tensor], List[Tensor]]
         num_images = proposals.shape[0]
         device = proposals.device
         # do not backprop throught objectness
@@ -416,7 +416,7 @@ def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_
         return final_boxes, final_scores
 
     def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets):
-        # type: (Tensor, Tensor, List[Tensor], List[Tensor])
+        # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
         """
         Arguments:
             objectness (Tensor)
@@ -453,8 +453,12 @@ def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets)
 
         return objectness_loss, box_loss
 
-    def forward(self, images, features, targets=None):
-        # type: (ImageList, Dict[str, Tensor], Optional[List[Dict[str, Tensor]]])
+    def forward(self,
+                images,       # type: ImageList
+                features,     # type: Dict[str, Tensor]
+                targets=None  # type: Optional[List[Dict[str, Tensor]]]
+                ):
+        # type: (...) -> Tuple[List[Tensor], Dict[str, Tensor]]
         """
         Arguments:
             images (ImageList): images for which we want to compute the predictions
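Since these comments are consumed by TorchScript as well as mypy, the return annotations have to use script-compatible containers such as Tuple[...] (the "follow torchscript Tuple type" item in the commit message). A small sketch, using a hypothetical function rather than one from rpn.py, of scripting a function annotated this way:

import torch
from torch import Tensor
from typing import Tuple


@torch.jit.script
def top_scores(scores, k):
    # type: (Tensor, int) -> Tuple[Tensor, Tensor]
    # TorchScript reads the comment above to type-check the compiled graph;
    # the Tuple[...] return annotation matches the two tensors returned here.
    values, idx = scores.topk(k)
    return values, idx


# Example usage: values, idx = top_scores(torch.rand(10), 3)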
