Skip to content

Commit 18569be

Browse files
prabhat00155datumbox
authored andcommitted
[fbsync] Add typing annotations to detection/generalized_rcnn (#4631)
Summary: * Update typing * Fix bug * Unblock mypy * Ignore small error Reviewed By: kazhang Differential Revision: D32216673 fbshipit-source-id: cc73b84b830c1e0c31502a2cb6cd7408a899d6ba Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent a474fa5 commit 18569be

File tree

2 files changed

+3
-7
lines changed

2 files changed

+3
-7
lines changed

mypy.ini

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,6 @@ ignore_errors = True
4646

4747
ignore_errors = True
4848

49-
[mypy-torchvision.models.detection.generalized_rcnn]
50-
51-
ignore_errors = True
52-
5349
[mypy-torchvision.models.detection.faster_rcnn]
5450

5551
ignore_errors = True

torchvision/models/detection/generalized_rcnn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class GeneralizedRCNN(nn.Module):
2525
the model
2626
"""
2727

28-
def __init__(self, backbone, rpn, roi_heads, transform):
28+
def __init__(self, backbone: nn.Module, rpn: nn.Module, roi_heads: nn.Module, transform: nn.Module) -> None:
2929
super().__init__()
3030
_log_api_usage_once(self)
3131
self.transform = transform
@@ -48,7 +48,7 @@ def forward(self, images, targets=None):
4848
"""
4949
Args:
5050
images (list[Tensor]): images to be processed
51-
targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)
51+
targets (list[Dict[str, Tensor]]): ground-truth boxes present in the image (optional)
5252
5353
Returns:
5454
result (list[BoxList] or dict[Tensor]): the output from the model.
@@ -97,7 +97,7 @@ def forward(self, images, targets=None):
9797
features = OrderedDict([("0", features)])
9898
proposals, proposal_losses = self.rpn(images, features, targets)
9999
detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
100-
detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)
100+
detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes) # type: ignore[operator]
101101

102102
losses = {}
103103
losses.update(detector_losses)

0 commit comments

Comments
 (0)