1
1
import warnings
2
- from typing import List , Optional
2
+ from typing import Callable , Dict , Optional , List
3
3
4
- from torch import nn
4
+ from torch import nn , Tensor
5
5
from torchvision .ops import misc as misc_nn_ops
6
6
from torchvision .ops .feature_pyramid_network import FeaturePyramidNetwork , LastLevelMaxPool , ExtraFPNBlock
7
7
@@ -29,7 +29,14 @@ class BackboneWithFPN(nn.Module):
29
29
out_channels (int): the number of channels in the FPN
30
30
"""
31
31
32
- def __init__ (self , backbone , return_layers , in_channels_list , out_channels , extra_blocks = None ):
32
+ def __init__ (
33
+ self ,
34
+ backbone : nn .Module ,
35
+ return_layers : Dict [str , str ],
36
+ in_channels_list : List [int ],
37
+ out_channels : int ,
38
+ extra_blocks : Optional [ExtraFPNBlock ] = None ,
39
+ ) -> None :
33
40
super (BackboneWithFPN , self ).__init__ ()
34
41
35
42
if extra_blocks is None :
@@ -43,20 +50,20 @@ def __init__(self, backbone, return_layers, in_channels_list, out_channels, extr
43
50
)
44
51
self .out_channels = out_channels
45
52
46
- def forward (self , x ) :
53
+ def forward (self , x : Tensor ) -> Dict [ str , Tensor ] :
47
54
x = self .body (x )
48
55
x = self .fpn (x )
49
56
return x
50
57
51
58
52
59
def resnet_fpn_backbone (
53
- backbone_name ,
54
- pretrained ,
55
- norm_layer = misc_nn_ops .FrozenBatchNorm2d ,
56
- trainable_layers = 3 ,
57
- returned_layers = None ,
58
- extra_blocks = None ,
59
- ):
60
+ backbone_name : str ,
61
+ pretrained : bool ,
62
+ norm_layer : Callable [..., nn . Module ] = misc_nn_ops .FrozenBatchNorm2d ,
63
+ trainable_layers : int = 3 ,
64
+ returned_layers : Optional [ List [ int ]] = None ,
65
+ extra_blocks : Optional [ ExtraFPNBlock ] = None ,
66
+ ) -> BackboneWithFPN :
60
67
"""
61
68
Constructs a specified ResNet backbone with FPN on top. Freezes the specified number of layers in the backbone.
62
69
@@ -80,7 +87,7 @@ def resnet_fpn_backbone(
80
87
backbone_name (string): resnet architecture. Possible values are 'ResNet', 'resnet18', 'resnet34', 'resnet50',
81
88
'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2'
82
89
pretrained (bool): If True, returns a model with backbone pre-trained on ImageNet
83
- norm_layer (torchvision.ops ): it is recommended to use the default value. For details visit:
90
+ norm_layer (callable ): it is recommended to use the default value. For details visit:
84
91
(https://github.com/facebookresearch/maskrcnn-benchmark/issues/267)
85
92
trainable_layers (int): number of trainable (not frozen) resnet layers starting from final block.
86
93
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
@@ -101,7 +108,8 @@ def _resnet_backbone_config(
101
108
trainable_layers : int ,
102
109
returned_layers : Optional [List [int ]],
103
110
extra_blocks : Optional [ExtraFPNBlock ],
104
- ):
111
+ ) -> BackboneWithFPN :
112
+
105
113
# select layers that won't be frozen
106
114
assert 0 <= trainable_layers <= 5
107
115
layers_to_train = ["layer4" , "layer3" , "layer2" , "layer1" , "conv1" ][:trainable_layers ]
@@ -125,8 +133,13 @@ def _resnet_backbone_config(
125
133
return BackboneWithFPN (backbone , return_layers , in_channels_list , out_channels , extra_blocks = extra_blocks )
126
134
127
135
128
- def _validate_trainable_layers (pretrained , trainable_backbone_layers , max_value , default_value ):
129
- # dont freeze any layers if pretrained model or backbone is not used
136
+ def _validate_trainable_layers (
137
+ pretrained : bool ,
138
+ trainable_backbone_layers : Optional [int ],
139
+ max_value : int ,
140
+ default_value : int ,
141
+ ) -> int :
142
+ # don't freeze any layers if pretrained model or backbone is not used
130
143
if not pretrained :
131
144
if trainable_backbone_layers is not None :
132
145
warnings .warn (
@@ -144,14 +157,15 @@ def _validate_trainable_layers(pretrained, trainable_backbone_layers, max_value,
144
157
145
158
146
159
def mobilenet_backbone (
147
- backbone_name ,
148
- pretrained ,
149
- fpn ,
150
- norm_layer = misc_nn_ops .FrozenBatchNorm2d ,
151
- trainable_layers = 2 ,
152
- returned_layers = None ,
153
- extra_blocks = None ,
154
- ):
160
+ backbone_name : str ,
161
+ pretrained : bool ,
162
+ fpn : bool ,
163
+ norm_layer : Callable [..., nn .Module ] = misc_nn_ops .FrozenBatchNorm2d ,
164
+ trainable_layers : int = 2 ,
165
+ returned_layers : Optional [List [int ]] = None ,
166
+ extra_blocks : Optional [ExtraFPNBlock ] = None ,
167
+ ) -> nn .Module :
168
+
155
169
backbone = mobilenet .__dict__ [backbone_name ](pretrained = pretrained , norm_layer = norm_layer ).features
156
170
157
171
# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
@@ -185,5 +199,5 @@ def mobilenet_backbone(
185
199
# depthwise linear combination of channels to reduce their size
186
200
nn .Conv2d (backbone [- 1 ].out_channels , out_channels , 1 ),
187
201
)
188
- m .out_channels = out_channels
202
+ m .out_channels = out_channels # type: ignore[assignment]
189
203
return m
0 commit comments