|
1 | 1 | import warnings
|
2 |
| -from typing import Any, Optional |
| 2 | +from typing import Any, Optional, Union |
3 | 3 |
|
4 | 4 | from ....models.detection.faster_rcnn import (
|
5 |
| - _validate_trainable_layers, |
| 5 | + _mobilenet_extractor, |
6 | 6 | _resnet_fpn_extractor,
|
| 7 | + _validate_trainable_layers, |
| 8 | + AnchorGenerator, |
7 | 9 | FasterRCNN,
|
8 | 10 | misc_nn_ops,
|
9 | 11 | overwrite_eps,
|
10 | 12 | )
|
11 | 13 | from ...transforms.presets import CocoEval
|
12 | 14 | from .._api import Weights, WeightEntry
|
13 | 15 | from .._meta import _COCO_CATEGORIES
|
| 16 | +from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large |
14 | 17 | from ..resnet import ResNet50Weights, resnet50
|
15 | 18 |
|
16 | 19 |
|
17 |
| -__all__ = ["FasterRCNN", "FasterRCNNResNet50FPNWeights", "fasterrcnn_resnet50_fpn"] |
| 20 | +__all__ = [ |
| 21 | + "FasterRCNN", |
| 22 | + "FasterRCNNResNet50FPNWeights", |
| 23 | + "FasterRCNNMobileNetV3LargeFPNWeights", |
| 24 | + "FasterRCNNMobileNetV3Large320FPNWeights", |
| 25 | + "fasterrcnn_resnet50_fpn", |
| 26 | + "fasterrcnn_mobilenet_v3_large_fpn", |
| 27 | + "fasterrcnn_mobilenet_v3_large_320_fpn", |
| 28 | +] |
| 29 | + |
| 30 | + |
| 31 | +_common_meta = {"categories": _COCO_CATEGORIES} |
18 | 32 |
|
19 | 33 |
|
20 | 34 | class FasterRCNNResNet50FPNWeights(Weights):
|
21 | 35 | Coco_RefV1 = WeightEntry(
|
22 | 36 | url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
|
23 | 37 | transforms=CocoEval,
|
24 | 38 | meta={
|
25 |
| - "categories": _COCO_CATEGORIES, |
| 39 | + **_common_meta, |
26 | 40 | "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
|
27 | 41 | "map": 37.0,
|
28 | 42 | },
|
29 | 43 | )
|
30 | 44 |
|
31 | 45 |
|
| 46 | +class FasterRCNNMobileNetV3LargeFPNWeights(Weights): |
| 47 | + Coco_RefV1 = WeightEntry( |
| 48 | + url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth", |
| 49 | + transforms=CocoEval, |
| 50 | + meta={ |
| 51 | + **_common_meta, |
| 52 | + "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn", |
| 53 | + "map": 32.8, |
| 54 | + }, |
| 55 | + ) |
| 56 | + |
| 57 | + |
| 58 | +class FasterRCNNMobileNetV3Large320FPNWeights(Weights): |
| 59 | + Coco_RefV1 = WeightEntry( |
| 60 | + url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth", |
| 61 | + transforms=CocoEval, |
| 62 | + meta={ |
| 63 | + **_common_meta, |
| 64 | + "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn", |
| 65 | + "map": 22.8, |
| 66 | + }, |
| 67 | + ) |
| 68 | + |
| 69 | + |
32 | 70 | def fasterrcnn_resnet50_fpn(
|
33 | 71 | weights: Optional[FasterRCNNResNet50FPNWeights] = None,
|
34 | 72 | weights_backbone: Optional[ResNet50Weights] = None,
|
@@ -64,3 +102,109 @@ def fasterrcnn_resnet50_fpn(
|
64 | 102 | overwrite_eps(model, 0.0)
|
65 | 103 |
|
66 | 104 | return model
|
| 105 | + |
| 106 | + |
| 107 | +def _fasterrcnn_mobilenet_v3_large_fpn( |
| 108 | + weights: Optional[Union[FasterRCNNMobileNetV3LargeFPNWeights, FasterRCNNMobileNetV3Large320FPNWeights]] = None, |
| 109 | + weights_backbone: Optional[MobileNetV3LargeWeights] = None, |
| 110 | + progress: bool = True, |
| 111 | + num_classes: int = 91, |
| 112 | + trainable_backbone_layers: Optional[int] = None, |
| 113 | + **kwargs: Any, |
| 114 | +) -> FasterRCNN: |
| 115 | + if weights is not None: |
| 116 | + weights_backbone = None |
| 117 | + num_classes = len(weights.meta["categories"]) |
| 118 | + |
| 119 | + trainable_backbone_layers = _validate_trainable_layers( |
| 120 | + weights is not None or weights_backbone is not None, trainable_backbone_layers, 6, 3 |
| 121 | + ) |
| 122 | + |
| 123 | + backbone = mobilenet_v3_large(weights=weights_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d) |
| 124 | + backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers) |
| 125 | + anchor_sizes = ( |
| 126 | + ( |
| 127 | + 32, |
| 128 | + 64, |
| 129 | + 128, |
| 130 | + 256, |
| 131 | + 512, |
| 132 | + ), |
| 133 | + ) * 3 |
| 134 | + aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) |
| 135 | + model = FasterRCNN( |
| 136 | + backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs |
| 137 | + ) |
| 138 | + |
| 139 | + if weights is not None: |
| 140 | + model.load_state_dict(weights.state_dict(progress=progress)) |
| 141 | + |
| 142 | + return model |
| 143 | + |
| 144 | + |
| 145 | +def fasterrcnn_mobilenet_v3_large_fpn( |
| 146 | + weights: Optional[FasterRCNNMobileNetV3LargeFPNWeights] = None, |
| 147 | + weights_backbone: Optional[MobileNetV3LargeWeights] = None, |
| 148 | + progress: bool = True, |
| 149 | + num_classes: int = 91, |
| 150 | + trainable_backbone_layers: Optional[int] = None, |
| 151 | + **kwargs: Any, |
| 152 | +) -> FasterRCNN: |
| 153 | + if "pretrained" in kwargs: |
| 154 | + warnings.warn("The argument pretrained is deprecated, please use weights instead.") |
| 155 | + weights = FasterRCNNMobileNetV3LargeFPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None |
| 156 | + weights = FasterRCNNMobileNetV3LargeFPNWeights.verify(weights) |
| 157 | + if "pretrained_backbone" in kwargs: |
| 158 | + warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.") |
| 159 | + weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None |
| 160 | + weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone) |
| 161 | + |
| 162 | + defaults = { |
| 163 | + "rpn_score_thresh": 0.05, |
| 164 | + } |
| 165 | + |
| 166 | + kwargs = {**defaults, **kwargs} |
| 167 | + return _fasterrcnn_mobilenet_v3_large_fpn( |
| 168 | + weights, |
| 169 | + weights_backbone, |
| 170 | + progress, |
| 171 | + num_classes, |
| 172 | + trainable_backbone_layers, |
| 173 | + **kwargs, |
| 174 | + ) |
| 175 | + |
| 176 | + |
| 177 | +def fasterrcnn_mobilenet_v3_large_320_fpn( |
| 178 | + weights: Optional[FasterRCNNMobileNetV3Large320FPNWeights] = None, |
| 179 | + weights_backbone: Optional[MobileNetV3LargeWeights] = None, |
| 180 | + progress: bool = True, |
| 181 | + num_classes: int = 91, |
| 182 | + trainable_backbone_layers: Optional[int] = None, |
| 183 | + **kwargs: Any, |
| 184 | +) -> FasterRCNN: |
| 185 | + if "pretrained" in kwargs: |
| 186 | + warnings.warn("The argument pretrained is deprecated, please use weights instead.") |
| 187 | + weights = FasterRCNNMobileNetV3Large320FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None |
| 188 | + weights = FasterRCNNMobileNetV3Large320FPNWeights.verify(weights) |
| 189 | + if "pretrained_backbone" in kwargs: |
| 190 | + warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.") |
| 191 | + weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None |
| 192 | + weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone) |
| 193 | + |
| 194 | + defaults = { |
| 195 | + "min_size": 320, |
| 196 | + "max_size": 640, |
| 197 | + "rpn_pre_nms_top_n_test": 150, |
| 198 | + "rpn_post_nms_top_n_test": 150, |
| 199 | + "rpn_score_thresh": 0.05, |
| 200 | + } |
| 201 | + |
| 202 | + kwargs = {**defaults, **kwargs} |
| 203 | + return _fasterrcnn_mobilenet_v3_large_fpn( |
| 204 | + weights, |
| 205 | + weights_backbone, |
| 206 | + progress, |
| 207 | + num_classes, |
| 208 | + trainable_backbone_layers, |
| 209 | + **kwargs, |
| 210 | + ) |
0 commit comments