Skip to content

Commit 8348c3a

Browse files
committed
Minor refactoring of a private method to make it reusuable.
1 parent 7bf6e7b commit 8348c3a

File tree

6 files changed

+25
-29
lines changed

6 files changed

+25
-29
lines changed

test/test_models_detection_utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,17 +36,17 @@ def test_resnet_fpn_backbone_frozen_layers(self):
3636

3737
def test_validate_resnet_inputs_detection(self):
3838
# default number of backbone layers to train
39-
ret = backbone_utils._validate_resnet_trainable_layers(
40-
pretrained=True, trainable_backbone_layers=None)
39+
ret = backbone_utils._validate_trainable_layers(
40+
pretrained=True, trainable_backbone_layers=None, max_value=5, default_value=3)
4141
self.assertEqual(ret, 3)
4242
# can't go beyond 5
4343
with self.assertRaises(AssertionError):
44-
ret = backbone_utils._validate_resnet_trainable_layers(
45-
pretrained=True, trainable_backbone_layers=6)
44+
ret = backbone_utils._validate_trainable_layers(
45+
pretrained=True, trainable_backbone_layers=6, max_value=5, default_value=3)
4646
# if not pretrained, should use all trainable layers and warn
4747
with self.assertWarns(UserWarning):
48-
ret = backbone_utils._validate_resnet_trainable_layers(
49-
pretrained=False, trainable_backbone_layers=0)
48+
ret = backbone_utils._validate_trainable_layers(
49+
pretrained=False, trainable_backbone_layers=0, max_value=5, default_value=3)
5050
self.assertEqual(ret, 5)
5151

5252
def test_transform_copy_targets(self):

torchvision/models/detection/backbone_utils.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import warnings
2-
from collections import OrderedDict
32
from torch import nn
43
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool
54

@@ -108,17 +107,18 @@ def resnet_fpn_backbone(
108107
return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)
109108

110109

111-
def _validate_resnet_trainable_layers(pretrained, trainable_backbone_layers):
110+
def _validate_trainable_layers(pretrained, trainable_backbone_layers, max_value, default_value):
112111
# dont freeze any layers if pretrained model or backbone is not used
113112
if not pretrained:
114113
if trainable_backbone_layers is not None:
115114
warnings.warn(
116115
"Changing trainable_backbone_layers has not effect if "
117116
"neither pretrained nor pretrained_backbone have been set to True, "
118-
"falling back to trainable_backbone_layers=5 so that all layers are trainable")
119-
trainable_backbone_layers = 5
120-
# by default, freeze first 2 blocks following Faster R-CNN
117+
"falling back to trainable_backbone_layers={} so that all layers are trainable".format(max_value))
118+
trainable_backbone_layers = max_value
119+
120+
# by default freeze first blocks
121121
if trainable_backbone_layers is None:
122-
trainable_backbone_layers = 3
123-
assert trainable_backbone_layers <= 5 and trainable_backbone_layers >= 0
122+
trainable_backbone_layers = default_value
123+
assert 0 <= trainable_backbone_layers <= max_value
124124
return trainable_backbone_layers

torchvision/models/detection/faster_rcnn.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from .rpn import RPNHead, RegionProposalNetwork
1616
from .roi_heads import RoIHeads
1717
from .transform import GeneralizedRCNNTransform
18-
from .backbone_utils import resnet_fpn_backbone, _validate_resnet_trainable_layers
18+
from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers
1919

2020

2121
__all__ = [
@@ -353,9 +353,8 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
353353
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
354354
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
355355
"""
356-
# check default parameters and by default set it to 3 if possible
357-
trainable_backbone_layers = _validate_resnet_trainable_layers(
358-
pretrained or pretrained_backbone, trainable_backbone_layers)
356+
trainable_backbone_layers = _validate_trainable_layers(
357+
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3)
359358

360359
if pretrained:
361360
# no need to download the backbone if pretrained is set

torchvision/models/detection/keypoint_rcnn.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from ..utils import load_state_dict_from_url
88

99
from .faster_rcnn import FasterRCNN
10-
from .backbone_utils import resnet_fpn_backbone, _validate_resnet_trainable_layers
10+
from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers
1111

1212

1313
__all__ = [
@@ -322,9 +322,8 @@ def keypointrcnn_resnet50_fpn(pretrained=False, progress=True,
322322
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
323323
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
324324
"""
325-
# check default parameters and by default set it to 3 if possible
326-
trainable_backbone_layers = _validate_resnet_trainable_layers(
327-
pretrained or pretrained_backbone, trainable_backbone_layers)
325+
trainable_backbone_layers = _validate_trainable_layers(
326+
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3)
328327

329328
if pretrained:
330329
# no need to download the backbone if pretrained is set

torchvision/models/detection/mask_rcnn.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from ..utils import load_state_dict_from_url
99

1010
from .faster_rcnn import FasterRCNN
11-
from .backbone_utils import resnet_fpn_backbone, _validate_resnet_trainable_layers
11+
from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers
1212

1313
__all__ = [
1414
"MaskRCNN", "maskrcnn_resnet50_fpn",
@@ -317,9 +317,8 @@ def maskrcnn_resnet50_fpn(pretrained=False, progress=True,
317317
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
318318
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
319319
"""
320-
# check default parameters and by default set it to 3 if possible
321-
trainable_backbone_layers = _validate_resnet_trainable_layers(
322-
pretrained or pretrained_backbone, trainable_backbone_layers)
320+
trainable_backbone_layers = _validate_trainable_layers(
321+
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3)
323322

324323
if pretrained:
325324
# no need to download the backbone if pretrained is set

torchvision/models/detection/retinanet.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from . import _utils as det_utils
1313
from .anchor_utils import AnchorGenerator
1414
from .transform import GeneralizedRCNNTransform
15-
from .backbone_utils import resnet_fpn_backbone, _validate_resnet_trainable_layers
15+
from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers
1616
from ...ops.feature_pyramid_network import LastLevelP6P7
1717
from ...ops import sigmoid_focal_loss
1818
from ...ops import boxes as box_ops
@@ -605,9 +605,8 @@ def retinanet_resnet50_fpn(pretrained=False, progress=True,
605605
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
606606
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
607607
"""
608-
# check default parameters and by default set it to 3 if possible
609-
trainable_backbone_layers = _validate_resnet_trainable_layers(
610-
pretrained or pretrained_backbone, trainable_backbone_layers)
608+
trainable_backbone_layers = _validate_trainable_layers(
609+
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3)
611610

612611
if pretrained:
613612
# no need to download the backbone if pretrained is set

0 commit comments

Comments
 (0)