Added typing annotations to models/detection [1/n] #4220
Diff: torchvision/models/detection/faster_rcnn.py
@@ -1,3 +1,5 @@
+from typing import Optional, Tuple, Any, cast, List
+
 import torch.nn.functional as F
 from torch import nn
 from torchvision.ops import MultiScaleRoIAlign
@@ -55,10 +57,10 @@ class FasterRCNN(GeneralizedRCNN):
             If box_predictor is specified, num_classes should be None.
         min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
         max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
-        image_mean (Tuple[float, float, float]): mean values used for input normalization.
+        image_mean (List[float]): mean values used for input normalization.
             They are generally the mean values of the dataset on which the backbone has been trained
             on
-        image_std (Tuple[float, float, float]): std values used for input normalization.
+        image_std (List[float]): std values used for input normalization.
             They are generally the std values of the dataset on which the backbone has been trained on
         rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
             maps.
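As a usage sketch of what the documented `List[float]` type accepts, the normalization statistics can be passed as plain Python lists. The backbone helper and the ImageNet mean/std values below are illustrative only and are not part of this diff.

```python
# Sketch: constructing a FasterRCNN with normalization stats as List[float].
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

backbone = resnet_fpn_backbone("resnet50", pretrained=False)
model = FasterRCNN(
    backbone,
    num_classes=91,
    image_mean=[0.485, 0.456, 0.406],  # plain lists now match the documented type
    image_std=[0.229, 0.224, 0.225],
)
```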
@@ -143,39 +145,39 @@ class FasterRCNN(GeneralizedRCNN):

     def __init__(
         self,
-        backbone,
-        num_classes=None,
+        backbone: nn.Module,
+        num_classes: Optional[int] = None,
         # transform parameters
-        min_size=800,
-        max_size=1333,
-        image_mean=None,
-        image_std=None,
+        min_size: int = 800,
+        max_size: int = 1333,
+        image_mean: Optional[List[float]] = None,
+        image_std: Optional[List[float]] = None,
         # RPN parameters
-        rpn_anchor_generator=None,
-        rpn_head=None,
-        rpn_pre_nms_top_n_train=2000,
-        rpn_pre_nms_top_n_test=1000,
-        rpn_post_nms_top_n_train=2000,
-        rpn_post_nms_top_n_test=1000,
-        rpn_nms_thresh=0.7,
-        rpn_fg_iou_thresh=0.7,
-        rpn_bg_iou_thresh=0.3,
-        rpn_batch_size_per_image=256,
-        rpn_positive_fraction=0.5,
-        rpn_score_thresh=0.0,
+        rpn_anchor_generator: Optional[AnchorGenerator] = None,
+        rpn_head: Optional[nn.Module] = None,
+        rpn_pre_nms_top_n_train: int = 2000,
+        rpn_pre_nms_top_n_test: int = 1000,
+        rpn_post_nms_top_n_train: int = 2000,
+        rpn_post_nms_top_n_test: int = 1000,
+        rpn_nms_thresh: float = 0.7,
+        rpn_fg_iou_thresh: float = 0.7,
+        rpn_bg_iou_thresh: float = 0.3,
+        rpn_batch_size_per_image: int = 256,
+        rpn_positive_fraction: float = 0.5,
+        rpn_score_thresh: float = 0.0,
         # Box parameters
-        box_roi_pool=None,
-        box_head=None,
-        box_predictor=None,
-        box_score_thresh=0.05,
-        box_nms_thresh=0.5,
-        box_detections_per_img=100,
-        box_fg_iou_thresh=0.5,
-        box_bg_iou_thresh=0.5,
-        box_batch_size_per_image=512,
-        box_positive_fraction=0.25,
-        bbox_reg_weights=None,
-    ):
+        box_roi_pool: Optional[MultiScaleRoIAlign] = None,
+        box_head: Optional[nn.Module] = None,
Review thread on the `box_head` annotation:

oke-aditya: I was a bit unsure on these. I think that users can write their own classes and just pass them to these arguments. Let me know, I will change all of these otherwise.

pmeier: If we allow passing custom classes, we need structural duck typing here. Meaning we need to define a custom Protocol.

oke-aditya: I'm unsure of what you mean @pmeier. I'm new to type hints.

pmeier: Yes. As a rule of thumb, the input types should be as loose as possible and the output types as strict as possible. After some digging, I think … is the only time we do anything with the object passed as `box_head`. @frgfm's comment might have some merit in other cases: for example, if we at some point access … Especially if users are allowed to pass custom objects, I would always prefer 2. over 1. because it doesn't require any changes on the user side. I don't know if this works with torchscript though. Let's discuss this if the need for something like this arises.
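For reference, a minimal sketch of the structural ("duck") typing idea raised above. The protocol name and its single method are hypothetical and not part of this PR.

```python
# Hypothetical Protocol illustrating structural typing for box_head-like arguments.
from typing import Protocol  # typing_extensions.Protocol on Python < 3.8

from torch import Tensor


class BoxHeadLike(Protocol):
    def __call__(self, x: Tensor) -> Tensor: ...


# Anything with a compatible __call__ (an nn.Module subclass or a plain function)
# satisfies BoxHeadLike without having to inherit from a specific base class.
```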
+        box_predictor: Optional[nn.Module] = None,
+        box_score_thresh: float = 0.05,
+        box_nms_thresh: float = 0.5,
+        box_detections_per_img: int = 100,
+        box_fg_iou_thresh: float = 0.5,
+        box_bg_iou_thresh: float = 0.5,
+        box_batch_size_per_image: int = 512,
+        box_positive_fraction: float = 0.25,
+        bbox_reg_weights: Optional[Tuple[float, ...]] = None,
+    ) -> None:

         if not hasattr(backbone, "out_channels"):
             raise ValueError(
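Note that `backbone: nn.Module` cannot express the `out_channels` attribute the constructor checks for at runtime. A hedged sketch of a custom backbone that satisfies this runtime convention (not taken from the PR):

```python
# Sketch: a minimal custom backbone. Statically it is just an nn.Module;
# the out_channels attribute is a runtime convention FasterRCNN relies on.
import torch
from torch import nn


class TinyBackbone(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.body = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.out_channels = 64  # checked via hasattr() in FasterRCNN.__init__

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.body(x)
```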
@@ -194,7 +196,7 @@ def __init__(
             if box_predictor is None:
                 raise ValueError("num_classes should not be None when box_predictor " "is not specified")

-        out_channels = backbone.out_channels
+        out_channels = cast(int, backbone.out_channels)

         if rpn_anchor_generator is None:
             anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
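`cast` only affects the type checker. A small sketch (names are illustrative) of why it is needed when reading an attribute that `nn.Module`'s own annotations do not declare as an `int`:

```python
# Sketch: typing.cast is a no-op at runtime; it only tells mypy what type to
# assume for an expression it cannot infer on its own.
from typing import cast

from torch import nn


class Backbone(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.out_channels = 64


def channels_of(backbone: nn.Module) -> int:
    # nn.Module itself does not declare out_channels as int, so without the
    # cast mypy would not accept this expression for the int return type.
    return cast(int, backbone.out_channels)


print(channels_of(Backbone()))  # 64
```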
@@ -229,7 +231,7 @@ def __init__(

         if box_predictor is None:
             representation_size = 1024
-            box_predictor = FastRCNNPredictor(representation_size, num_classes)
+            box_predictor = FastRCNNPredictor(representation_size, num_classes)  # type: ignore[arg-type]

         roi_heads = RoIHeads(
             # Box
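The ignore is needed because `num_classes` is `Optional[int]` and mypy cannot track the cross-variable invariant that it is non-None whenever `box_predictor` is None. A simplified sketch of that control flow and of an `assert`-based alternative to the ignore comment:

```python
# Sketch mirroring the __init__ validation: mypy does not narrow num_classes
# based on box_predictor, so the later use needs an ignore, an assert, or a cast.
from typing import Optional


def build(num_classes: Optional[int], box_predictor: Optional[object]) -> None:
    if num_classes is not None:
        if box_predictor is not None:
            raise ValueError("num_classes should be None when box_predictor is specified")
    else:
        if box_predictor is None:
            raise ValueError("num_classes should not be None when box_predictor is not specified")

    if box_predictor is None:
        assert num_classes is not None  # narrows Optional[int] -> int for mypy
        n: int = num_classes  # stand-in for FastRCNNPredictor(representation_size, num_classes)
```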
@@ -264,7 +266,7 @@ class TwoMLPHead(nn.Module):
         representation_size (int): size of the intermediate representation
     """

-    def __init__(self, in_channels, representation_size):
+    def __init__(self, in_channels: int, representation_size: int) -> None:
         super(TwoMLPHead, self).__init__()

         self.fc6 = nn.Linear(in_channels, representation_size)
@@ -289,7 +291,7 @@ class FastRCNNPredictor(nn.Module):
         num_classes (int): number of output classes (including background)
     """

-    def __init__(self, in_channels, num_classes):
+    def __init__(self, in_channels: int, num_classes: int) -> None:
         super(FastRCNNPredictor, self).__init__()
         self.cls_score = nn.Linear(in_channels, num_classes)
         self.bbox_pred = nn.Linear(in_channels, num_classes * 4)
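For context, `FastRCNNPredictor` is most often instantiated by users when swapping the box predictor for fine-tuning; the standard recipe from the torchvision detection tutorial, shown here as a usage sketch:

```python
# Usage sketch: replace the pretrained box predictor with one sized for a
# custom dataset (the usual fine-tuning pattern).
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

model = fasterrcnn_resnet50_fpn(pretrained=True)
num_classes = 3  # e.g. two object classes + background
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
```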
@@ -312,8 +314,13 @@ def forward(self, x):


 def fasterrcnn_resnet50_fpn(
-    pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs
-):
+    pretrained: bool = False,
+    progress: bool = True,
+    num_classes: int = 91,
+    pretrained_backbone: bool = True,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> FasterRCNN:
     """
     Constructs a Faster R-CNN model with a ResNet-50-FPN backbone.

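With the annotated signature, the builder's return type is now statically known. A small usage sketch (random input, no pretrained weights downloaded):

```python
# Usage sketch: the annotated factory is statically typed as returning FasterRCNN.
import torch
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FasterRCNN

model: FasterRCNN = fasterrcnn_resnet50_fpn(pretrained=False, pretrained_backbone=False)
model.eval()
with torch.no_grad():
    outputs = model([torch.rand(3, 300, 400)])
# outputs is a list with one dict per image: keys "boxes", "labels", "scores"
```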
@@ -395,14 +402,15 @@ def fasterrcnn_resnet50_fpn(


 def _fasterrcnn_mobilenet_v3_large_fpn(
-    weights_name,
-    pretrained=False,
-    progress=True,
-    num_classes=91,
-    pretrained_backbone=True,
-    trainable_backbone_layers=None,
-    **kwargs,
-):
+    weights_name: str,
+    pretrained: bool = False,
+    progress: bool = True,
+    num_classes: int = 91,
+    pretrained_backbone: bool = True,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> FasterRCNN:
+
     trainable_backbone_layers = _validate_trainable_layers(
         pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3
     )
@@ -436,8 +444,14 @@ def _fasterrcnn_mobilenet_v3_large_fpn(


 def fasterrcnn_mobilenet_v3_large_320_fpn(
-    pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs
-):
+    pretrained: bool = False,
+    progress: bool = True,
+    num_classes: int = 91,
+    pretrained_backbone: bool = True,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> FasterRCNN:
+
     """
     Constructs a low resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone tunned for mobile use-cases.
     It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
@@ -481,8 +495,14 @@ def fasterrcnn_mobilenet_v3_large_320_fpn(


 def fasterrcnn_mobilenet_v3_large_fpn(
-    pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs
-):
+    pretrained: bool = False,
+    progress: bool = True,
+    num_classes: int = 91,
+    pretrained_backbone: bool = True,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> FasterRCNN:
+
     """
     Constructs a high resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone.
     It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See