48
48
(inception_resnet_v2 .InceptionResNetV2 , 1536 ),
49
49
(mobilenet .MobileNet , 1024 ),
50
50
(mobilenet_v2 .MobileNetV2 , 1280 ),
51
- (mobilenet_v3 .MobileNetV3Small , 1024 ),
52
- (mobilenet_v3 .MobileNetV3Large , 1280 ),
51
+ (mobilenet_v3 .MobileNetV3Small , 576 ),
52
+ (mobilenet_v3 .MobileNetV3Large , 960 ),
53
53
(densenet .DenseNet121 , 1024 ),
54
54
(densenet .DenseNet169 , 1664 ),
55
55
(densenet .DenseNet201 , 1920 ),
70
70
71
71
MODEL_LIST = MODEL_LIST_NO_NASNET + NASNET_LIST
72
72
73
+ # Parameters for loading weights for MobileNetV3.
74
+ # (class, alpha, minimalistic, include_top)
75
+ MOBILENET_V3_FOR_WEIGHTS = [
76
+ (mobilenet_v3 .MobileNetV3Large , 0.75 , False , False ),
77
+ (mobilenet_v3 .MobileNetV3Large , 1.0 , False , False ),
78
+ (mobilenet_v3 .MobileNetV3Large , 1.0 , True , False ),
79
+ (mobilenet_v3 .MobileNetV3Large , 0.75 , False , True ),
80
+ (mobilenet_v3 .MobileNetV3Large , 1.0 , False , True ),
81
+ (mobilenet_v3 .MobileNetV3Large , 1.0 , True , True ),
82
+ (mobilenet_v3 .MobileNetV3Small , 0.75 , False , False ),
83
+ (mobilenet_v3 .MobileNetV3Small , 1.0 , False , False ),
84
+ (mobilenet_v3 .MobileNetV3Small , 1.0 , True , False ),
85
+ (mobilenet_v3 .MobileNetV3Small , 0.75 , False , True ),
86
+ (mobilenet_v3 .MobileNetV3Small , 1.0 , False , True ),
87
+ (mobilenet_v3 .MobileNetV3Small , 1.0 , True , True ),
88
+ ]
89
+
73
90
74
91
class ApplicationsTest (tf .test .TestCase , parameterized .TestCase ):
75
92
@@ -93,7 +110,7 @@ def test_application_base(self, app, _):
93
110
94
111
@parameterized .parameters (* MODEL_LIST )
95
112
def test_application_notop (self , app , last_dim ):
96
- if 'NASNet' or 'MobileNetV3' in app .__name__ :
113
+ if 'NASNet' in app .__name__ :
97
114
only_check_last_dim = True
98
115
else :
99
116
only_check_last_dim = False
@@ -119,10 +136,7 @@ def test_application_variable_input_channels(self, app, last_dim):
119
136
input_shape = (None , None , 1 )
120
137
output_shape = _get_output_shape (
121
138
lambda : app (weights = None , include_top = False , input_shape = input_shape ))
122
- if 'MobileNetV3' in app .__name__ :
123
- self .assertShapeEqual (output_shape , (None , 1 , 1 , last_dim ))
124
- else :
125
- self .assertShapeEqual (output_shape , (None , None , None , last_dim ))
139
+ self .assertShapeEqual (output_shape , (None , None , None , last_dim ))
126
140
backend .clear_session ()
127
141
128
142
if backend .image_data_format () == 'channels_first' :
@@ -131,12 +145,23 @@ def test_application_variable_input_channels(self, app, last_dim):
131
145
input_shape = (None , None , 4 )
132
146
output_shape = _get_output_shape (
133
147
lambda : app (weights = None , include_top = False , input_shape = input_shape ))
134
- if 'MobileNetV3' in app .__name__ :
135
- self .assertShapeEqual (output_shape , (None , 1 , 1 , last_dim ))
136
- else :
137
- self .assertShapeEqual (output_shape , (None , None , None , last_dim ))
148
+ self .assertShapeEqual (output_shape , (None , None , None , last_dim ))
138
149
backend .clear_session ()
139
150
151
+ @parameterized .parameters (* MOBILENET_V3_FOR_WEIGHTS )
152
+ def test_mobilenet_v3_load_weights (
153
+ self ,
154
+ mobilenet_class ,
155
+ alpha ,
156
+ minimalistic ,
157
+ include_top ):
158
+ mobilenet_class (
159
+ input_shape = (224 , 224 , 3 ),
160
+ weights = 'imagenet' ,
161
+ alpha = alpha ,
162
+ minimalistic = minimalistic ,
163
+ include_top = include_top )
164
+
140
165
141
166
def _get_output_shape (model_fn ):
142
167
model = model_fn ()
0 commit comments