Skip to content

Add typing in detection module (faster rcnn). #4636

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,6 @@ ignore_errors = True

ignore_errors = True

[mypy-torchvision.models.detection.faster_rcnn]

ignore_errors = True

[mypy-torchvision.models.detection.mask_rcnn]

ignore_errors = True
Expand Down
123 changes: 70 additions & 53 deletions torchvision/models/detection/faster_rcnn.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Tuple, List, Optional, Any, cast

import torch.nn.functional as F
from torch import nn
from torch import nn, Tensor
from torchvision.ops import MultiScaleRoIAlign

from ..._internally_replaced_utils import load_state_dict_from_url
Expand Down Expand Up @@ -55,10 +57,10 @@ class FasterRCNN(GeneralizedRCNN):
If box_predictor is specified, num_classes should be None.
min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
image_mean (Tuple[float, float, float]): mean values used for input normalization.
image_mean (List[float]): mean values used for input normalization.
They are generally the mean values of the dataset on which the backbone has been trained
on
image_std (Tuple[float, float, float]): std values used for input normalization.
image_std (List[float]): std values used for input normalization.
They are generally the std values of the dataset on which the backbone has been trained on
rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
maps.
Expand Down Expand Up @@ -143,39 +145,39 @@ class FasterRCNN(GeneralizedRCNN):

def __init__(
self,
backbone,
num_classes=None,
backbone: nn.Module,
num_classes: Optional[int] = None,
# transform parameters
min_size=800,
max_size=1333,
image_mean=None,
image_std=None,
min_size: int = 800,
max_size: int = 1333,
image_mean: Optional[List[float]] = None,
image_std: Optional[List[float]] = None,
# RPN parameters
rpn_anchor_generator=None,
rpn_head=None,
rpn_pre_nms_top_n_train=2000,
rpn_pre_nms_top_n_test=1000,
rpn_post_nms_top_n_train=2000,
rpn_post_nms_top_n_test=1000,
rpn_nms_thresh=0.7,
rpn_fg_iou_thresh=0.7,
rpn_bg_iou_thresh=0.3,
rpn_batch_size_per_image=256,
rpn_positive_fraction=0.5,
rpn_score_thresh=0.0,
rpn_anchor_generator: Optional[AnchorGenerator] = None,
rpn_head: Optional[nn.Module] = None,
rpn_pre_nms_top_n_train: int = 2000,
rpn_pre_nms_top_n_test: int = 1000,
rpn_post_nms_top_n_train: int = 2000,
rpn_post_nms_top_n_test: int = 1000,
rpn_nms_thresh: float = 0.7,
rpn_fg_iou_thresh: float = 0.7,
rpn_bg_iou_thresh: float = 0.3,
rpn_batch_size_per_image: int = 256,
rpn_positive_fraction: float = 0.5,
rpn_score_thresh: float = 0.0,
# Box parameters
box_roi_pool=None,
box_head=None,
box_predictor=None,
box_score_thresh=0.05,
box_nms_thresh=0.5,
box_detections_per_img=100,
box_fg_iou_thresh=0.5,
box_bg_iou_thresh=0.5,
box_batch_size_per_image=512,
box_positive_fraction=0.25,
bbox_reg_weights=None,
):
box_roi_pool: Optional[MultiScaleRoIAlign] = None,
box_head: Optional[nn.Module] = None,
box_predictor: Optional[nn.Module] = None,
box_score_thresh: float = 0.05,
box_nms_thresh: float = 0.5,
box_detections_per_img: int = 100,
box_fg_iou_thresh: float = 0.5,
box_bg_iou_thresh: float = 0.5,
box_batch_size_per_image: int = 512,
box_positive_fraction: float = 0.25,
bbox_reg_weights: Optional[Tuple[float, float, float, float]] = None,
) -> None:

if not hasattr(backbone, "out_channels"):
raise ValueError(
Expand All @@ -194,7 +196,7 @@ def __init__(
if box_predictor is None:
raise ValueError("num_classes should not be None when box_predictor " "is not specified")

out_channels = backbone.out_channels
out_channels = cast(int, backbone.out_channels)

if rpn_anchor_generator is None:
anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
Expand Down Expand Up @@ -229,7 +231,7 @@ def __init__(

if box_predictor is None:
representation_size = 1024
box_predictor = FastRCNNPredictor(representation_size, num_classes)
box_predictor = FastRCNNPredictor(representation_size, num_classes) # type: ignore[arg-type]

roi_heads = RoIHeads(
# Box
Expand Down Expand Up @@ -264,13 +266,13 @@ class TwoMLPHead(nn.Module):
representation_size (int): size of the intermediate representation
"""

def __init__(self, in_channels, representation_size):
def __init__(self, in_channels: int, representation_size: int) -> None:
super(TwoMLPHead, self).__init__()

self.fc6 = nn.Linear(in_channels, representation_size)
self.fc7 = nn.Linear(representation_size, representation_size)

def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
x = x.flatten(start_dim=1)

x = F.relu(self.fc6(x))
Expand All @@ -289,12 +291,12 @@ class FastRCNNPredictor(nn.Module):
num_classes (int): number of output classes (including background)
"""

def __init__(self, in_channels, num_classes):
def __init__(self, in_channels: int, num_classes: int) -> None:
super(FastRCNNPredictor, self).__init__()
self.cls_score = nn.Linear(in_channels, num_classes)
self.bbox_pred = nn.Linear(in_channels, num_classes * 4)

def forward(self, x):
def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
if x.dim() == 4:
assert list(x.shape[2:]) == [1, 1]
x = x.flatten(start_dim=1)
Expand All @@ -312,8 +314,13 @@ def forward(self, x):


def fasterrcnn_resnet50_fpn(
pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs
):
pretrained: bool = False,
progress: bool = True,
num_classes: int = 91,
pretrained_backbone: bool = True,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> FasterRCNN:
"""
Constructs a Faster R-CNN model with a ResNet-50-FPN backbone.

Expand Down Expand Up @@ -395,14 +402,14 @@ def fasterrcnn_resnet50_fpn(


def _fasterrcnn_mobilenet_v3_large_fpn(
weights_name,
pretrained=False,
progress=True,
num_classes=91,
pretrained_backbone=True,
trainable_backbone_layers=None,
**kwargs,
):
weights_name: str,
pretrained: bool = False,
progress: bool = True,
num_classes: int = 91,
pretrained_backbone: bool = True,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> FasterRCNN:
trainable_backbone_layers = _validate_trainable_layers(
pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3
)
Expand Down Expand Up @@ -436,8 +443,13 @@ def _fasterrcnn_mobilenet_v3_large_fpn(


def fasterrcnn_mobilenet_v3_large_320_fpn(
pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs
):
pretrained: bool = False,
progress: bool = True,
num_classes: int = 91,
pretrained_backbone: bool = True,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> FasterRCNN:
"""
Constructs a low resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone tuned for mobile use-cases.
It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
Expand Down Expand Up @@ -481,8 +493,13 @@ def fasterrcnn_mobilenet_v3_large_320_fpn(


def fasterrcnn_mobilenet_v3_large_fpn(
pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs
):
pretrained: bool = False,
progress: bool = True,
num_classes: int = 91,
pretrained_backbone: bool = True,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> FasterRCNN:
"""
Constructs a high resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone.
It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
Expand Down