@@ -31,34 +31,39 @@ def _make_divisible(v, divisor, min_value=None):
 
 
 class ConvBNReLU(nn.Sequential):
-    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
+    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, norm_layer=None):
         padding = (kernel_size - 1) // 2
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
         super(ConvBNReLU, self).__init__(
             nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
-            nn.BatchNorm2d(out_planes),
+            norm_layer(out_planes),
             nn.ReLU6(inplace=True)
         )
 
 
 class InvertedResidual(nn.Module):
-    def __init__(self, inp, oup, stride, expand_ratio):
+    def __init__(self, inp, oup, stride, expand_ratio, norm_layer=None):
         super(InvertedResidual, self).__init__()
         self.stride = stride
         assert stride in [1, 2]
 
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+
         hidden_dim = int(round(inp * expand_ratio))
         self.use_res_connect = self.stride == 1 and inp == oup
 
         layers = []
         if expand_ratio != 1:
             # pw
-            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
+            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer))
         layers.extend([
             # dw
-            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
+            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer),
             # pw-linear
             nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
-            nn.BatchNorm2d(oup),
+            norm_layer(oup),
         ])
         self.conv = nn.Sequential(*layers)
 
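Note: ConvBNReLU and InvertedResidual above now accept an optional norm_layer callable and fall back to nn.BatchNorm2d when it is None. A minimal usage sketch, assuming the patched classes are importable from torchvision.models.mobilenet (the import path and the num_groups=8 choice are illustrative assumptions, not part of this patch):

import functools

import torch
from torch import nn
# Assumption: a torchvision build that already contains the patched classes above.
from torchvision.models.mobilenet import ConvBNReLU, InvertedResidual

# norm_layer is called as norm_layer(num_channels), so GroupNorm's num_groups
# argument must be bound in advance; 8 is an arbitrary example value.
group_norm = functools.partial(nn.GroupNorm, 8)

stem = ConvBNReLU(3, 32, stride=2, norm_layer=group_norm)
block = InvertedResidual(32, 16, stride=1, expand_ratio=1, norm_layer=group_norm)

x = torch.randn(1, 3, 224, 224)
print(block(stem(x)).shape)  # expected: torch.Size([1, 16, 112, 112])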
@@ -75,7 +80,8 @@ def __init__(self,
                  width_mult=1.0,
                  inverted_residual_setting=None,
                  round_nearest=8,
-                 block=None):
+                 block=None,
+                 norm_layer=None):
         """
         MobileNet V2 main class
 
@@ -86,12 +92,17 @@ def __init__(self,
             round_nearest (int): Round the number of channels in each layer to be a multiple of this number
             Set to 1 to turn off rounding
             block: Module specifying inverted residual building block for mobilenet
+            norm_layer: Module specifying the normalization layer to use
 
         """
         super(MobileNetV2, self).__init__()
 
         if block is None:
             block = InvertedResidual
+
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+
         input_channel = 32
         last_channel = 1280
 
@@ -115,16 +126,16 @@ def __init__(self,
         # building first layer
         input_channel = _make_divisible(input_channel * width_mult, round_nearest)
         self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
-        features = [ConvBNReLU(3, input_channel, stride=2)]
+        features = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)]
         # building inverted residual blocks
         for t, c, n, s in inverted_residual_setting:
             output_channel = _make_divisible(c * width_mult, round_nearest)
             for i in range(n):
                 stride = s if i == 0 else 1
-                features.append(block(input_channel, output_channel, stride, expand_ratio=t))
+                features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
                 input_channel = output_channel
         # building last several layers
-        features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
+        features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer))
         # make it nn.Sequential
         self.features = nn.Sequential(*features)
 
@@ -140,7 +151,7 @@ def __init__(self,
                 nn.init.kaiming_normal_(m.weight, mode='fan_out')
                 if m.bias is not None:
                     nn.init.zeros_(m.bias)
-            elif isinstance(m, nn.BatchNorm2d):
+            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                 nn.init.ones_(m.weight)
                 nn.init.zeros_(m.bias)
             elif isinstance(m, nn.Linear):
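Taken together, MobileNetV2 itself now exposes norm_layer, threads it into every ConvBNReLU and InvertedResidual block, and initializes GroupNorm weights the same way as BatchNorm2d. An end-to-end sketch, assuming a torchvision build that includes this patch (the functools.partial(nn.GroupNorm, 8) choice is an illustrative assumption; every channel count at width_mult=1.0 happens to be divisible by 8):

import functools

import torch
from torch import nn
# Assumption: MobileNetV2 here is the patched class from this diff.
from torchvision.models import MobileNetV2

# Default behaviour is unchanged: norm_layer=None falls back to nn.BatchNorm2d.
bn_model = MobileNetV2()

# Swap in GroupNorm everywhere a normalization layer is constructed.
gn_model = MobileNetV2(norm_layer=functools.partial(nn.GroupNorm, 8))

x = torch.randn(2, 3, 224, 224)
print(gn_model(x).shape)  # expected: torch.Size([2, 1000])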