Skip to content

Commit 466fccb

Browse files
haifeng-jintensorflower-gardener
authored andcommitted
Update MobileNetV3 architecture, move some layers from body to head, which only appears when include_top=True.
Pretrained weights are updated and tested. PiperOrigin-RevId: 401845877
1 parent a4b2c80 commit 466fccb

File tree

2 files changed

+52
-27
lines changed

2 files changed

+52
-27
lines changed

keras/applications/applications_test.py

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@
4848
(inception_resnet_v2.InceptionResNetV2, 1536),
4949
(mobilenet.MobileNet, 1024),
5050
(mobilenet_v2.MobileNetV2, 1280),
51-
(mobilenet_v3.MobileNetV3Small, 1024),
52-
(mobilenet_v3.MobileNetV3Large, 1280),
51+
(mobilenet_v3.MobileNetV3Small, 576),
52+
(mobilenet_v3.MobileNetV3Large, 960),
5353
(densenet.DenseNet121, 1024),
5454
(densenet.DenseNet169, 1664),
5555
(densenet.DenseNet201, 1920),
@@ -70,6 +70,23 @@
7070

7171
MODEL_LIST = MODEL_LIST_NO_NASNET + NASNET_LIST
7272

73+
# Parameters for loading weights for MobileNetV3.
74+
# (class, alpha, minimalistic, include_top)
75+
MOBILENET_V3_FOR_WEIGHTS = [
76+
(mobilenet_v3.MobileNetV3Large, 0.75, False, False),
77+
(mobilenet_v3.MobileNetV3Large, 1.0, False, False),
78+
(mobilenet_v3.MobileNetV3Large, 1.0, True, False),
79+
(mobilenet_v3.MobileNetV3Large, 0.75, False, True),
80+
(mobilenet_v3.MobileNetV3Large, 1.0, False, True),
81+
(mobilenet_v3.MobileNetV3Large, 1.0, True, True),
82+
(mobilenet_v3.MobileNetV3Small, 0.75, False, False),
83+
(mobilenet_v3.MobileNetV3Small, 1.0, False, False),
84+
(mobilenet_v3.MobileNetV3Small, 1.0, True, False),
85+
(mobilenet_v3.MobileNetV3Small, 0.75, False, True),
86+
(mobilenet_v3.MobileNetV3Small, 1.0, False, True),
87+
(mobilenet_v3.MobileNetV3Small, 1.0, True, True),
88+
]
89+
7390

7491
class ApplicationsTest(tf.test.TestCase, parameterized.TestCase):
7592

@@ -93,7 +110,7 @@ def test_application_base(self, app, _):
93110

94111
@parameterized.parameters(*MODEL_LIST)
95112
def test_application_notop(self, app, last_dim):
96-
if 'NASNet' or 'MobileNetV3' in app.__name__:
113+
if 'NASNet' in app.__name__:
97114
only_check_last_dim = True
98115
else:
99116
only_check_last_dim = False
@@ -119,10 +136,7 @@ def test_application_variable_input_channels(self, app, last_dim):
119136
input_shape = (None, None, 1)
120137
output_shape = _get_output_shape(
121138
lambda: app(weights=None, include_top=False, input_shape=input_shape))
122-
if 'MobileNetV3' in app.__name__:
123-
self.assertShapeEqual(output_shape, (None, 1, 1, last_dim))
124-
else:
125-
self.assertShapeEqual(output_shape, (None, None, None, last_dim))
139+
self.assertShapeEqual(output_shape, (None, None, None, last_dim))
126140
backend.clear_session()
127141

128142
if backend.image_data_format() == 'channels_first':
@@ -131,12 +145,23 @@ def test_application_variable_input_channels(self, app, last_dim):
131145
input_shape = (None, None, 4)
132146
output_shape = _get_output_shape(
133147
lambda: app(weights=None, include_top=False, input_shape=input_shape))
134-
if 'MobileNetV3' in app.__name__:
135-
self.assertShapeEqual(output_shape, (None, 1, 1, last_dim))
136-
else:
137-
self.assertShapeEqual(output_shape, (None, None, None, last_dim))
148+
self.assertShapeEqual(output_shape, (None, None, None, last_dim))
138149
backend.clear_session()
139150

151+
@parameterized.parameters(*MOBILENET_V3_FOR_WEIGHTS)
152+
def test_mobilenet_v3_load_weights(
153+
self,
154+
mobilenet_class,
155+
alpha,
156+
minimalistic,
157+
include_top):
158+
mobilenet_class(
159+
input_shape=(224, 224, 3),
160+
weights='imagenet',
161+
alpha=alpha,
162+
minimalistic=minimalistic,
163+
include_top=include_top)
164+
140165

141166
def _get_output_shape(model_fn):
142167
model = model_fn()

keras/applications/mobilenet_v3.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -33,17 +33,17 @@
3333
'keras-applications/mobilenet_v3/')
3434
WEIGHTS_HASHES = {
3535
'large_224_0.75_float': ('765b44a33ad4005b3ac83185abf1d0eb',
36-
'e7b4d1071996dd51a2c2ca2424570e20'),
36+
'40af19a13ebea4e2ee0c676887f69a2e'),
3737
'large_224_1.0_float': ('59e551e166be033d707958cf9e29a6a7',
38-
'037116398e07f018c0005ffcb0406831'),
38+
'07fb09a5933dd0c8eaafa16978110389'),
3939
'large_minimalistic_224_1.0_float': ('675e7b876c45c57e9e63e6d90a36599c',
40-
'a2c33aed672524d1d0b4431808177695'),
40+
'ec5221f64a2f6d1ef965a614bdae7973'),
4141
'small_224_0.75_float': ('cb65d4e5be93758266aa0a7f2c6708b7',
42-
'4d2fe46f1c1f38057392514b0df1d673'),
42+
'ebdb5cc8e0b497cd13a7c275d475c819'),
4343
'small_224_1.0_float': ('8768d4c2e7dee89b9d02b2d03d65d862',
44-
'be7100780f875c06bcab93d76641aa26'),
44+
'd3e8ec802a04aa4fc771ee12a9a9b836'),
4545
'small_minimalistic_224_1.0_float': ('99cd97fb2fcdad2bf028eb838de69e37',
46-
'20d4e357df3f7a6361f3a288857b1051'),
46+
'cde8136e733e811080d9fcd8a252f7e4'),
4747
}
4848

4949
layers = VersionAwareLayers()
@@ -310,16 +310,16 @@ def MobileNetV3(stack_fn,
310310
axis=channel_axis, epsilon=1e-3,
311311
momentum=0.999, name='Conv_1/BatchNorm')(x)
312312
x = activation(x)
313-
x = layers.GlobalAveragePooling2D(keepdims=True)(x)
314-
x = layers.Conv2D(
315-
last_point_ch,
316-
kernel_size=1,
317-
padding='same',
318-
use_bias=True,
319-
name='Conv_2')(x)
320-
x = activation(x)
321-
322313
if include_top:
314+
x = layers.GlobalAveragePooling2D(keepdims=True)(x)
315+
x = layers.Conv2D(
316+
last_point_ch,
317+
kernel_size=1,
318+
padding='same',
319+
use_bias=True,
320+
name='Conv_2')(x)
321+
x = activation(x)
322+
323323
if dropout_rate > 0:
324324
x = layers.Dropout(dropout_rate)(x)
325325
x = layers.Conv2D(classes, kernel_size=1, padding='same', name='Logits')(x)
@@ -350,7 +350,7 @@ def MobileNetV3(stack_fn,
350350
file_name = 'weights_mobilenet_v3_' + model_name + '.h5'
351351
file_hash = WEIGHTS_HASHES[model_name][0]
352352
else:
353-
file_name = 'weights_mobilenet_v3_' + model_name + '_no_top.h5'
353+
file_name = 'weights_mobilenet_v3_' + model_name + '_no_top_v2.h5'
354354
file_hash = WEIGHTS_HASHES[model_name][1]
355355
weights_path = data_utils.get_file(
356356
file_name,

0 commit comments

Comments
 (0)