4848 (inception_resnet_v2 .InceptionResNetV2 , 1536 ),
4949 (mobilenet .MobileNet , 1024 ),
5050 (mobilenet_v2 .MobileNetV2 , 1280 ),
51- (mobilenet_v3 .MobileNetV3Small , 1024 ),
52- (mobilenet_v3 .MobileNetV3Large , 1280 ),
51+ (mobilenet_v3 .MobileNetV3Small , 576 ),
52+ (mobilenet_v3 .MobileNetV3Large , 960 ),
5353 (densenet .DenseNet121 , 1024 ),
5454 (densenet .DenseNet169 , 1664 ),
5555 (densenet .DenseNet201 , 1920 ),
7070
7171MODEL_LIST = MODEL_LIST_NO_NASNET + NASNET_LIST
7272
73+ # Parameters for loading weights for MobileNetV3.
74+ # (class, alpha, minimalistic, include_top)
75+ MOBILENET_V3_FOR_WEIGHTS = [
76+ (mobilenet_v3 .MobileNetV3Large , 0.75 , False , False ),
77+ (mobilenet_v3 .MobileNetV3Large , 1.0 , False , False ),
78+ (mobilenet_v3 .MobileNetV3Large , 1.0 , True , False ),
79+ (mobilenet_v3 .MobileNetV3Large , 0.75 , False , True ),
80+ (mobilenet_v3 .MobileNetV3Large , 1.0 , False , True ),
81+ (mobilenet_v3 .MobileNetV3Large , 1.0 , True , True ),
82+ (mobilenet_v3 .MobileNetV3Small , 0.75 , False , False ),
83+ (mobilenet_v3 .MobileNetV3Small , 1.0 , False , False ),
84+ (mobilenet_v3 .MobileNetV3Small , 1.0 , True , False ),
85+ (mobilenet_v3 .MobileNetV3Small , 0.75 , False , True ),
86+ (mobilenet_v3 .MobileNetV3Small , 1.0 , False , True ),
87+ (mobilenet_v3 .MobileNetV3Small , 1.0 , True , True ),
88+ ]
89+
7390
7491class ApplicationsTest (tf .test .TestCase , parameterized .TestCase ):
7592
@@ -93,7 +110,7 @@ def test_application_base(self, app, _):
93110
94111 @parameterized .parameters (* MODEL_LIST )
95112 def test_application_notop (self , app , last_dim ):
96- if 'NASNet' or 'MobileNetV3' in app .__name__ :
113+ if 'NASNet' in app .__name__ :
97114 only_check_last_dim = True
98115 else :
99116 only_check_last_dim = False
@@ -119,10 +136,7 @@ def test_application_variable_input_channels(self, app, last_dim):
119136 input_shape = (None , None , 1 )
120137 output_shape = _get_output_shape (
121138 lambda : app (weights = None , include_top = False , input_shape = input_shape ))
122- if 'MobileNetV3' in app .__name__ :
123- self .assertShapeEqual (output_shape , (None , 1 , 1 , last_dim ))
124- else :
125- self .assertShapeEqual (output_shape , (None , None , None , last_dim ))
139+ self .assertShapeEqual (output_shape , (None , None , None , last_dim ))
126140 backend .clear_session ()
127141
128142 if backend .image_data_format () == 'channels_first' :
@@ -131,12 +145,23 @@ def test_application_variable_input_channels(self, app, last_dim):
131145 input_shape = (None , None , 4 )
132146 output_shape = _get_output_shape (
133147 lambda : app (weights = None , include_top = False , input_shape = input_shape ))
134- if 'MobileNetV3' in app .__name__ :
135- self .assertShapeEqual (output_shape , (None , 1 , 1 , last_dim ))
136- else :
137- self .assertShapeEqual (output_shape , (None , None , None , last_dim ))
148+ self .assertShapeEqual (output_shape , (None , None , None , last_dim ))
138149 backend .clear_session ()
139150
151+ @parameterized .parameters (* MOBILENET_V3_FOR_WEIGHTS )
152+ def test_mobilenet_v3_load_weights (
153+ self ,
154+ mobilenet_class ,
155+ alpha ,
156+ minimalistic ,
157+ include_top ):
158+ mobilenet_class (
159+ input_shape = (224 , 224 , 3 ),
160+ weights = 'imagenet' ,
161+ alpha = alpha ,
162+ minimalistic = minimalistic ,
163+ include_top = include_top )
164+
140165
141166def _get_output_shape (model_fn ):
142167 model = model_fn ()
0 commit comments