diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py index 8f2a96e2be1..790740fe9c5 100644 --- a/torchvision/models/detection/faster_rcnn.py +++ b/torchvision/models/detection/faster_rcnn.py @@ -383,15 +383,15 @@ def fasterrcnn_resnet50_fpn( Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is passed (the default) this value is set to 3. """ - trainable_backbone_layers = _validate_trainable_layers( - pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3 - ) + is_trained = pretrained or pretrained_backbone + trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) + norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d if pretrained: # no need to download the backbone if pretrained is set pretrained_backbone = False - backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d) + backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer) backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers) model = FasterRCNN(backbone, num_classes, **kwargs) if pretrained: @@ -410,16 +410,14 @@ def _fasterrcnn_mobilenet_v3_large_fpn( trainable_backbone_layers=None, **kwargs, ): - trainable_backbone_layers = _validate_trainable_layers( - pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3 - ) + is_trained = pretrained or pretrained_backbone + trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 6, 3) + norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d if pretrained: pretrained_backbone = False - backbone = mobilenet_v3_large( - pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d - ) + backbone = mobilenet_v3_large(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer) backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers) anchor_sizes = ( diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index 32a413e9cb1..c4c2e6f5842 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -686,15 +686,15 @@ def fcos_resnet50_fpn( from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is passed (the default) this value is set to 3. Default: None """ - trainable_backbone_layers = _validate_trainable_layers( - pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3 - ) + is_trained = pretrained or pretrained_backbone + trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) + norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d if pretrained: # no need to download the backbone if pretrained is set pretrained_backbone = False - backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d) + backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer) backbone = _resnet_fpn_extractor( backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256) ) diff --git a/torchvision/models/detection/keypoint_rcnn.py b/torchvision/models/detection/keypoint_rcnn.py index 93e966bae4b..9f23e66e0c5 100644 --- a/torchvision/models/detection/keypoint_rcnn.py +++ b/torchvision/models/detection/keypoint_rcnn.py @@ -365,15 +365,15 @@ def keypointrcnn_resnet50_fpn( Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is passed (the default) this value is set to 3. """ - trainable_backbone_layers = _validate_trainable_layers( - pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3 - ) + is_trained = pretrained or pretrained_backbone + trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) + norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d if pretrained: # no need to download the backbone if pretrained is set pretrained_backbone = False - backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d) + backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer) backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers) model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs) if pretrained: diff --git a/torchvision/models/detection/mask_rcnn.py b/torchvision/models/detection/mask_rcnn.py index f4278cfb502..37f88116c5e 100644 --- a/torchvision/models/detection/mask_rcnn.py +++ b/torchvision/models/detection/mask_rcnn.py @@ -360,15 +360,15 @@ def maskrcnn_resnet50_fpn( Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is passed (the default) this value is set to 3. """ - trainable_backbone_layers = _validate_trainable_layers( - pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3 - ) + is_trained = pretrained or pretrained_backbone + trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) + norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d if pretrained: # no need to download the backbone if pretrained is set pretrained_backbone = False - backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d) + backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer) backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers) model = MaskRCNN(backbone, num_classes, **kwargs) if pretrained: diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index 1909f6a8b73..4f79b5ddbfc 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -626,15 +626,15 @@ def retinanet_resnet50_fpn( Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is passed (the default) this value is set to 3. """ - trainable_backbone_layers = _validate_trainable_layers( - pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3 - ) + is_trained = pretrained or pretrained_backbone + trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) + norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d if pretrained: # no need to download the backbone if pretrained is set pretrained_backbone = False - backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d) + backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer) # skip P2 because it generates too many anchors (according to their paper) backbone = _resnet_fpn_extractor( backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256) diff --git a/torchvision/prototype/models/detection/faster_rcnn.py b/torchvision/prototype/models/detection/faster_rcnn.py index 764cc3fe042..4fbbbc0c1e8 100644 --- a/torchvision/prototype/models/detection/faster_rcnn.py +++ b/torchvision/prototype/models/detection/faster_rcnn.py @@ -1,5 +1,6 @@ from typing import Any, Optional, Union +from torch import nn from torchvision.prototype.transforms import CocoEval from torchvision.transforms.functional import InterpolationMode @@ -103,11 +104,11 @@ def fasterrcnn_resnet50_fpn( elif num_classes is None: num_classes = 91 - trainable_backbone_layers = _validate_trainable_layers( - weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3 - ) + is_trained = weights is not None or weights_backbone is not None + trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) + norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d - backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d) + backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer) backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers) model = FasterRCNN(backbone, num_classes=num_classes, **kwargs) @@ -134,11 +135,11 @@ def _fasterrcnn_mobilenet_v3_large_fpn( elif num_classes is None: num_classes = 91 - trainable_backbone_layers = _validate_trainable_layers( - weights is not None or weights_backbone is not None, trainable_backbone_layers, 6, 3 - ) + is_trained = weights is not None or weights_backbone is not None + trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 6, 3) + norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d - backbone = mobilenet_v3_large(weights=weights_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d) + backbone = mobilenet_v3_large(weights=weights_backbone, progress=progress, norm_layer=norm_layer) backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers) anchor_sizes = ( ( diff --git a/torchvision/prototype/models/detection/fcos.py b/torchvision/prototype/models/detection/fcos.py index cf3007290a8..faa181b60b0 100644 --- a/torchvision/prototype/models/detection/fcos.py +++ b/torchvision/prototype/models/detection/fcos.py @@ -1,5 +1,6 @@ from typing import Any, Optional +from torch import nn from torchvision.prototype.transforms import CocoEval from torchvision.transforms.functional import InterpolationMode @@ -63,11 +64,11 @@ def fcos_resnet50_fpn( elif num_classes is None: num_classes = 91 - trainable_backbone_layers = _validate_trainable_layers( - weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3 - ) + is_trained = weights is not None or weights_backbone is not None + trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) + norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d - backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d) + backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer) backbone = _resnet_fpn_extractor( backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256) ) diff --git a/torchvision/prototype/models/detection/keypoint_rcnn.py b/torchvision/prototype/models/detection/keypoint_rcnn.py index 976cccadd39..c10d761fa26 100644 --- a/torchvision/prototype/models/detection/keypoint_rcnn.py +++ b/torchvision/prototype/models/detection/keypoint_rcnn.py @@ -1,5 +1,6 @@ from typing import Any, Optional +from torch import nn from torchvision.prototype.transforms import CocoEval from torchvision.transforms.functional import InterpolationMode @@ -91,11 +92,11 @@ def keypointrcnn_resnet50_fpn( if num_keypoints is None: num_keypoints = 17 - trainable_backbone_layers = _validate_trainable_layers( - weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3 - ) + is_trained = weights is not None or weights_backbone is not None + trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) + norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d - backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d) + backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer) backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers) model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs) diff --git a/torchvision/prototype/models/detection/mask_rcnn.py b/torchvision/prototype/models/detection/mask_rcnn.py index af67f21c3e1..3e438dab160 100644 --- a/torchvision/prototype/models/detection/mask_rcnn.py +++ b/torchvision/prototype/models/detection/mask_rcnn.py @@ -1,5 +1,6 @@ from typing import Any, Optional +from torch import nn from torchvision.prototype.transforms import CocoEval from torchvision.transforms.functional import InterpolationMode @@ -64,11 +65,11 @@ def maskrcnn_resnet50_fpn( elif num_classes is None: num_classes = 91 - trainable_backbone_layers = _validate_trainable_layers( - weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3 - ) + is_trained = weights is not None or weights_backbone is not None + trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) + norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d - backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d) + backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer) backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers) model = MaskRCNN(backbone, num_classes=num_classes, **kwargs) diff --git a/torchvision/prototype/models/detection/retinanet.py b/torchvision/prototype/models/detection/retinanet.py index b0c02b1e30c..b819150ade0 100644 --- a/torchvision/prototype/models/detection/retinanet.py +++ b/torchvision/prototype/models/detection/retinanet.py @@ -1,5 +1,6 @@ from typing import Any, Optional +from torch import nn from torchvision.prototype.transforms import CocoEval from torchvision.transforms.functional import InterpolationMode @@ -64,11 +65,11 @@ def retinanet_resnet50_fpn( elif num_classes is None: num_classes = 91 - trainable_backbone_layers = _validate_trainable_layers( - weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3 - ) + is_trained = weights is not None or weights_backbone is not None + trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) + norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d - backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d) + backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer) # skip P2 because it generates too many anchors (according to their paper) backbone = _resnet_fpn_extractor( backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256)