diff --git a/mypy.ini b/mypy.ini
index a2733d3ae3b..bc3cc31bf15 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -25,10 +25,6 @@ ignore_errors = True
 
 ignore_errors = True
 
-[mypy-torchvision.models.detection.transform]
-
-ignore_errors = True
-
 [mypy-torchvision.models.detection.roi_heads]
 
 ignore_errors = True
diff --git a/torchvision/models/detection/transform.py b/torchvision/models/detection/transform.py
index e4a1134b85c..764aa268d14 100644
--- a/torchvision/models/detection/transform.py
+++ b/torchvision/models/detection/transform.py
@@ -1,5 +1,5 @@
 import math
-from typing import List, Tuple, Dict, Optional
+from typing import List, Tuple, Dict, Optional, cast, Union
 
 import torch
 import torchvision
@@ -19,7 +19,8 @@ def _get_shape_onnx(image: Tensor) -> Tensor:
 @torch.jit.unused
 def _fake_cast_onnx(v: Tensor) -> float:
     # ONNX requires a tensor but here we fake its type for JIT.
-    return v
+    # cast is a no-op at runtime and it's there only to help mypy.
+    return cast(float, v)
 
 
 def _resize_image_and_masks(
@@ -29,6 +30,7 @@ def _resize_image_and_masks(
     target: Optional[Dict[str, Tensor]] = None,
     fixed_size: Optional[Tuple[int, int]] = None,
 ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
+
     if torchvision._is_tracing():
         im_shape = _get_shape_onnx(image)
     else:
@@ -85,13 +87,13 @@ class GeneralizedRCNNTransform(nn.Module):
 
     def __init__(
         self,
-        min_size: int,
+        min_size: Union[int, List[int]],
         max_size: int,
         image_mean: List[float],
         image_std: List[float],
         size_divisible: int = 32,
         fixed_size: Optional[Tuple[int, int]] = None,
-    ):
+    ) -> None:
         super(GeneralizedRCNNTransform, self).__init__()
         if not isinstance(min_size, (list, tuple)):
             min_size = (min_size,)
@@ -179,12 +181,12 @@ def resize(
             return image, target
 
         bbox = target["boxes"]
-        bbox = resize_boxes(bbox, (h, w), image.shape[-2:])
+        bbox = resize_boxes(bbox, [h, w], list(image.shape[-2:]))
         target["boxes"] = bbox
 
         if "keypoints" in target:
             keypoints = target["keypoints"]
-            keypoints = resize_keypoints(keypoints, (h, w), image.shape[-2:])
+            keypoints = resize_keypoints(keypoints, [h, w], list(image.shape[-2:]))
             target["keypoints"] = keypoints
         return image, target
 
@@ -192,7 +194,7 @@
     # batch_images() that is supported by ONNX tracing.
     @torch.jit.unused
     def _onnx_batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor:
-        max_size = []
+        max_size: List[Tensor] = []
        for i in range(images[0].dim()):
             max_size_i = torch.max(torch.stack([img.shape[i] for img in images]).to(torch.float32)).to(torch.int64)
             max_size.append(max_size_i)
@@ -242,8 +244,8 @@ def batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor
     def postprocess(
         self,
         result: List[Dict[str, Tensor]],
-        image_shapes: List[Tuple[int, int]],
-        original_image_sizes: List[Tuple[int, int]],
+        image_shapes: List[List[int]],
+        original_image_sizes: List[List[int]],
     ) -> List[Dict[str, Tensor]]:
         if self.training:
             return result
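
Note for reviewers: the two sketches below are illustrative only and not part of the patch. The first shows why `cast(float, v)` in `_fake_cast_onnx` satisfies mypy without changing runtime behavior: `typing.cast` simply returns its second argument, so the traced ONNX graph still receives a `Tensor` even though the annotation says `float`. The helper name `fake_cast` here is hypothetical.

```python
from typing import cast

import torch
from torch import Tensor


def fake_cast(v: Tensor) -> float:
    # typing.cast performs no conversion at runtime; it only tells the
    # type checker to treat the value as a float from here on, so the
    # Tensor passes through unchanged during ONNX tracing.
    return cast(float, v)


t = torch.tensor(3.0)
print(type(fake_cast(t)))  # <class 'torch.Tensor'>: the value is untouched
```

Similarly, widening `min_size` to `Union[int, List[int]]` just makes the annotation match the normalization the constructor already performs. A hypothetical standalone version of that branch:

```python
from typing import List, Tuple, Union


def normalize_min_size(min_size: Union[int, List[int]]) -> Tuple[int, ...]:
    # A bare int becomes a one-element tuple, mirroring the
    # `if not isinstance(min_size, (list, tuple))` check in __init__.
    if not isinstance(min_size, (list, tuple)):
        min_size = [min_size]
    return tuple(min_size)


assert normalize_min_size(800) == (800,)
assert normalize_min_size([640, 800]) == (640, 800)
```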