diff --git a/mypy.ini b/mypy.ini
index d2bbe22614f..44bca6cd832 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -46,10 +46,6 @@ ignore_errors = True
 
 ignore_errors = True
 
-[mypy-torchvision.models.detection.generalized_rcnn]
-
-ignore_errors = True
-
 [mypy-torchvision.models.detection.faster_rcnn]
 
 ignore_errors = True
diff --git a/torchvision/models/detection/generalized_rcnn.py b/torchvision/models/detection/generalized_rcnn.py
index bd4ff74cea0..f02f2b33928 100644
--- a/torchvision/models/detection/generalized_rcnn.py
+++ b/torchvision/models/detection/generalized_rcnn.py
@@ -25,7 +25,7 @@ class GeneralizedRCNN(nn.Module):
             the model
     """
 
-    def __init__(self, backbone, rpn, roi_heads, transform):
+    def __init__(self, backbone: nn.Module, rpn: nn.Module, roi_heads: nn.Module, transform: nn.Module) -> None:
         super().__init__()
         _log_api_usage_once(self)
         self.transform = transform
@@ -36,19 +36,26 @@ def __init__(self, backbone, rpn, roi_heads, transform):
         self._has_warned = False
 
     @torch.jit.unused
-    def eager_outputs(self, losses, detections):
-        # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]]
+    def eager_outputs(
+        self,
+        losses: Dict[str, Tensor],
+        detections: List[Dict[str, Tensor]],
+    ) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]]:
+
         if self.training:
             return losses
 
         return detections
 
-    def forward(self, images, targets=None):
-        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
+    def forward(
+        self,
+        images: List[Tensor],
+        targets: Optional[List[Dict[str, Tensor]]] = None,
+    ) -> Union[Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]], Dict[str, Tensor], List[Dict[str, Tensor]]]:
         """
         Args:
             images (list[Tensor]): images to be processed
-            targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)
+            targets (list[Dict[str, Tensor]]): ground-truth boxes present in the image (optional)
 
         Returns:
             result (list[BoxList] or dict[Tensor]): the output from the model.
@@ -97,7 +104,7 @@ def forward(self, images, targets=None):
             features = OrderedDict([("0", features)])
         proposals, proposal_losses = self.rpn(images, features, targets)
         detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
-        detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)
+        detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)  # type: ignore[operator]
 
         losses = {}
         losses.update(detector_losses)
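
For reviewers, the pattern being applied is the conversion of mypy comment-style signatures (# type: (...) -> ...) into inline PEP 484 annotations, which lets the per-module ignore_errors entry for generalized_rcnn be dropped from mypy.ini. The snippet below is a minimal standalone sketch of that pattern, not torchvision code: TinyDetector is a hypothetical module used only to mirror the training/eval switch that eager_outputs implements with the new annotation style.

from typing import Dict, List, Optional, Union

import torch
from torch import Tensor, nn


class TinyDetector(nn.Module):
    # Hypothetical stand-in for GeneralizedRCNN, only to illustrate the
    # inline-annotation style introduced by this patch.

    @torch.jit.unused
    def eager_outputs(
        self,
        losses: Dict[str, Tensor],
        detections: List[Dict[str, Tensor]],
    ) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]]:
        # Return the loss dict while training, the detections otherwise.
        if self.training:
            return losses
        return detections

    def forward(
        self,
        images: List[Tensor],
        targets: Optional[List[Dict[str, Tensor]]] = None,
    ) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]]:
        # Dummy losses/detections standing in for the RPN and RoI heads.
        losses = {"loss_box_reg": torch.zeros(1)}
        detections = [{"boxes": torch.zeros(0, 4)} for _ in images]
        return self.eager_outputs(losses, detections)


model = TinyDetector().eval()
print(type(model([torch.rand(3, 224, 224)])))  # <class 'list'> in eval mode

TorchScript accepts both the old comment signatures and the inline form, so moving to inline annotations keeps scripting behaviour unchanged while letting mypy check the module directly.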