|
15 | 15 | from .rpn import RPNHead, RegionProposalNetwork
|
16 | 16 | from .roi_heads import RoIHeads
|
17 | 17 | from .transform import GeneralizedRCNNTransform
|
18 |
| -from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers |
| 18 | +from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers, mobilenet_backbone |
19 | 19 |
|
20 | 20 |
|
21 | 21 | __all__ = [
|
22 |
| - "FasterRCNN", "fasterrcnn_resnet50_fpn", |
| 22 | + "FasterRCNN", "fasterrcnn_resnet50_fpn", "fasterrcnn_mobilenet_v3_large", "fasterrcnn_mobilenet_v3_large_fpn" |
23 | 23 | ]
|
24 | 24 |
|
25 | 25 |
|
@@ -291,6 +291,8 @@ def forward(self, x):
|
# Checkpoint download URLs, looked up by weights name when ``pretrained=True``.
# A value of None marks an entry that has no checkpoint URL registered.
model_urls = {
    "fasterrcnn_resnet50_fpn_coco": "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
    "fasterrcnn_mobilenet_v3_large_coco": None,
    "fasterrcnn_mobilenet_v3_large_fpn_coco": None,
}
|
295 | 297 |
|
296 | 298 |
|
@@ -367,3 +369,83 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
|
367 | 369 | model.load_state_dict(state_dict)
|
368 | 370 | overwrite_eps(model, 0.0)
|
369 | 371 | return model
|
| 372 | + |
| 373 | + |
def fasterrcnn_mobilenet_v3_large(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True,
                                  trainable_backbone_layers=None, **kwargs):
    """
    Constructs a Faster R-CNN model with a MobileNetV3-Large backbone. It works similarly
    to Faster R-CNN with ResNet-50 FPN backbone. See `fasterrcnn_resnet50_fpn` for more details.

    Example::

        >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large(pretrained=True)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output classes of the model (including the background)
        pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
        trainable_backbone_layers (int): number of trainable (not frozen) backbone layers starting from final block.
            Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable.
        **kwargs: extra keyword arguments forwarded to the ``FasterRCNN`` constructor.

    Raises:
        ValueError: if ``pretrained`` is True but ``model_urls`` has no checkpoint URL for this model.
    """
    weights_name = 'fasterrcnn_mobilenet_v3_large_coco'
    # Fail fast with a clear message instead of handing a None URL to the downloader
    # after the whole model has already been built.
    if pretrained and model_urls.get(weights_name) is None:
        raise ValueError("No checkpoint is available for model {}".format(weights_name))

    trainable_backbone_layers = _validate_trainable_layers(
        pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3)

    if pretrained:
        # The full-model checkpoint is loaded below, so the backbone is built without its own weights.
        pretrained_backbone = False
    # Third positional argument is False here; the FPN variant of this builder passes True.
    backbone = mobilenet_backbone("mobilenet_v3_large", pretrained_backbone, False,
                                  trainable_layers=trainable_backbone_layers)

    # One tuple of sizes / ratios: a single entry in each outer tuple.
    anchor_sizes = ((32, 64, 128, 256, 512), )
    aspect_ratios = ((0.5, 1.0, 2.0), )

    model = FasterRCNN(backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios),
                       **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress)
        model.load_state_dict(state_dict)
    return model
| 412 | + |
| 413 | + |
def fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True,
                                      trainable_backbone_layers=None, **kwargs):
    """
    Constructs a Faster R-CNN model with a MobileNetV3-Large FPN backbone. It works similarly
    to Faster R-CNN with ResNet-50 FPN backbone. See `fasterrcnn_resnet50_fpn` for more details.

    Example::

        >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=True)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output classes of the model (including the background)
        pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
        trainable_backbone_layers (int): number of trainable (not frozen) backbone layers starting from final block.
            Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable.
        **kwargs: extra keyword arguments forwarded to the ``FasterRCNN`` constructor.

    Raises:
        ValueError: if ``pretrained`` is True but ``model_urls`` has no checkpoint URL for this model.
    """
    weights_name = 'fasterrcnn_mobilenet_v3_large_fpn_coco'
    # Fail fast with a clear message instead of handing a None URL to the downloader
    # after the whole model has already been built.
    if pretrained and model_urls.get(weights_name) is None:
        raise ValueError("No checkpoint is available for model {}".format(weights_name))

    trainable_backbone_layers = _validate_trainable_layers(
        pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3)

    if pretrained:
        # The full-model checkpoint is loaded below, so the backbone is built without its own weights.
        pretrained_backbone = False
    # Third positional argument is True here; the non-FPN variant of this builder passes False.
    backbone = mobilenet_backbone("mobilenet_v3_large", pretrained_backbone, True,
                                  trainable_layers=trainable_backbone_layers)

    # The same sizes are repeated three times; aspect ratios are repeated to match that count.
    anchor_sizes = ((32, 64, 128, 256, 512, ), ) * 3
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)

    model = FasterRCNN(backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios),
                       **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress)
        model.load_state_dict(state_dict)
    return model
0 commit comments