@@ -1,8 +1,10 @@
 import warnings
 
 import torch
+from torch import Tensor
 import torch.nn as nn
 from .utils import load_state_dict_from_url
+from typing import Any, Dict, List
 
 __all__ = ['MNASNet', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3']
 
@@ -22,8 +24,15 @@
 
 class _InvertedResidual(nn.Module):
 
-    def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor,
-                 bn_momentum=0.1):
+    def __init__(
+        self,
+        in_ch: int,
+        out_ch: int,
+        kernel_size: int,
+        stride: int,
+        expansion_factor: int,
+        bn_momentum: float = 0.1
+    ):
         super(_InvertedResidual, self).__init__()
         assert stride in [1, 2]
         assert kernel_size in [3, 5]
@@ -43,15 +52,15 @@ def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor,
             nn.Conv2d(mid_ch, out_ch, 1, bias=False),
             nn.BatchNorm2d(out_ch, momentum=bn_momentum))
 
-    def forward(self, input):
+    def forward(self, input: Tensor) -> Tensor:
         if self.apply_residual:
             return self.layers(input) + input
         else:
             return self.layers(input)
 
 
-def _stack(in_ch, out_ch, kernel_size, stride, exp_factor, repeats,
-           bn_momentum):
+def _stack(in_ch: int, out_ch: int, kernel_size: int, stride: int, exp_factor: int, repeats: int,
+           bn_momentum: float) -> nn.Sequential:
     """ Creates a stack of inverted residuals. """
     assert repeats >= 1
     # First one has no skip, because feature map size changes.
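The middle of `_stack` is collapsed in this view, but the docstring, the comment about the first block, and the `return nn.Sequential(first, *remaining)` line in the next hunk pin down the pattern. A minimal sketch of that pattern (the name `_stack_sketch` and the body are illustrative, not necessarily the file's exact code):

```python
def _stack_sketch(in_ch: int, out_ch: int, kernel_size: int, stride: int,
                  exp_factor: int, repeats: int, bn_momentum: float) -> nn.Sequential:
    # The first block changes resolution and/or channel count, so no skip.
    first = _InvertedResidual(in_ch, out_ch, kernel_size, stride,
                              exp_factor, bn_momentum=bn_momentum)
    # Remaining blocks preserve shape (stride 1, out_ch -> out_ch), which is
    # exactly when _InvertedResidual.forward applies the residual connection.
    rest = [_InvertedResidual(out_ch, out_ch, kernel_size, 1,
                              exp_factor, bn_momentum=bn_momentum)
            for _ in range(repeats - 1)]
    return nn.Sequential(first, *rest)
```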
@@ -65,7 +74,7 @@ def _stack(in_ch, out_ch, kernel_size, stride, exp_factor, repeats,
     return nn.Sequential(first, *remaining)
 
 
-def _round_to_multiple_of(val, divisor, round_up_bias=0.9):
+def _round_to_multiple_of(val: float, divisor: int, round_up_bias: float = 0.9) -> int:
     """ Asymmetric rounding to make `val` divisible by `divisor`. With default
     bias, will round up, unless the number is no more than 10% greater than the
     smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88. """
@@ -74,7 +83,7 @@ def _round_to_multiple_of(val, divisor, round_up_bias=0.9):
     return new_val if new_val >= round_up_bias * val else new_val + divisor
 
 
-def _get_depths(alpha):
+def _get_depths(alpha: float) -> List[int]:
     """ Scales tensor depths as in reference MobileNet code, prefers rounding up
     rather than down. """
     depths = [32, 16, 24, 40, 80, 96, 192, 320]
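The rounding docstring is easy to sanity-check with a standalone sketch of the behavior it describes (the function name and the `new_val` line here are illustrative; only the final return line is visible in the diff above):

```python
def round_to_multiple_of(val: float, divisor: int, round_up_bias: float = 0.9) -> int:
    # Round to the nearest multiple of `divisor`, never below `divisor` itself.
    new_val = max(divisor, int(val + divisor / 2) // divisor * divisor)
    # Round up instead when rounding down would keep less than
    # `round_up_bias` (90% by default) of the requested value.
    return new_val if new_val >= round_up_bias * val else new_val + divisor

assert round_to_multiple_of(83, 8) == 80  # 80 >= 0.9 * 83, keep the lower multiple
assert round_to_multiple_of(84, 8) == 88  # nearest multiple of 8 is already 88
```

`_get_depths` then applies this rounding to each base depth scaled by `alpha`, which is why its annotated return type is `List[int]`.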
@@ -95,7 +104,12 @@ class MNASNet(torch.nn.Module):
     # Version 2 adds depth scaling in the initial stages of the network.
     _version = 2
 
-    def __init__(self, alpha, num_classes=1000, dropout=0.2):
+    def __init__(
+        self,
+        alpha: float,
+        num_classes: int = 1000,
+        dropout: float = 0.2
+    ):
         super(MNASNet, self).__init__()
         assert alpha > 0.0
         self.alpha = alpha
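With the constructor annotated, bad keyword values surface in a type checker rather than at runtime. A small usage sketch (assuming `MNASNet` is imported via `torchvision.models`, consistent with `__all__` above):

```python
from torchvision.models import MNASNet

net = MNASNet(alpha=0.5, num_classes=10, dropout=0.2)  # alpha scales channel depths
# net = MNASNet(alpha="0.5")  # mypy would now flag this str-for-float mix-up
```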
@@ -130,13 +144,13 @@ def __init__(self, alpha, num_classes=1000, dropout=0.2):
             nn.Linear(1280, num_classes))
         self._initialize_weights()
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         x = self.layers(x)
         # Equivalent to global avgpool and removing H and W dimensions.
         x = x.mean([2, 3])
         return self.classifier(x)
 
-    def _initialize_weights(self):
+    def _initialize_weights(self) -> None:
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
                 nn.init.kaiming_normal_(m.weight, mode="fan_out",
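The comment in `forward` can be verified directly: taking the mean over the H and W dimensions is the same as global average pooling followed by dropping those dimensions.

```python
import torch
import torch.nn.functional as F

x = torch.rand(2, 1280, 7, 7)
a = x.mean([2, 3])                          # mean over H and W -> (2, 1280)
b = F.adaptive_avg_pool2d(x, 1).flatten(1)  # global avgpool, then flatten -> (2, 1280)
assert torch.allclose(a, b)
```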
@@ -151,8 +165,8 @@ def _initialize_weights(self):
                                          nonlinearity="sigmoid")
                 nn.init.zeros_(m.bias)
 
-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
-                              missing_keys, unexpected_keys, error_msgs):
+    def _load_from_state_dict(self, state_dict: Dict, prefix: str, local_metadata: Dict, strict: bool,
+                              missing_keys: List[str], unexpected_keys: List[str], error_msgs: List[str]) -> None:
         version = local_metadata.get("version", None)
         assert version in [1, 2]
 
@@ -192,7 +206,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
             unexpected_keys, error_msgs)
 
 
-def _load_pretrained(model_name, model, progress):
+def _load_pretrained(model_name: str, model: nn.Module, progress: bool) -> None:
     if model_name not in _MODEL_URLS or _MODEL_URLS[model_name] is None:
         raise ValueError(
             "No checkpoint is available for model type {}".format(model_name))
@@ -201,7 +215,7 @@ def _load_pretrained(model_name, model, progress):
         load_state_dict_from_url(checkpoint_url, progress=progress))
 
 
-def mnasnet0_5(pretrained=False, progress=True, **kwargs):
+def mnasnet0_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet:
     """MNASNet with depth multiplier of 0.5 from
     `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
     <https://arxiv.org/pdf/1807.11626.pdf>`_.
@@ -215,7 +229,7 @@ def mnasnet0_5(pretrained=False, progress=True, **kwargs):
     return model
 
 
-def mnasnet0_75(pretrained=False, progress=True, **kwargs):
+def mnasnet0_75(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet:
     """MNASNet with depth multiplier of 0.75 from
     `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
     <https://arxiv.org/pdf/1807.11626.pdf>`_.
@@ -229,7 +243,7 @@ def mnasnet0_75(pretrained=False, progress=True, **kwargs):
     return model
 
 
-def mnasnet1_0(pretrained=False, progress=True, **kwargs):
+def mnasnet1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet:
     """MNASNet with depth multiplier of 1.0 from
     `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
     <https://arxiv.org/pdf/1807.11626.pdf>`_.
@@ -243,7 +257,7 @@ def mnasnet1_0(pretrained=False, progress=True, **kwargs):
     return model
 
 
-def mnasnet1_3(pretrained=False, progress=True, **kwargs):
+def mnasnet1_3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet:
     """MNASNet with depth multiplier of 1.3 from
     `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
     <https://arxiv.org/pdf/1807.11626.pdf>`_.
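Taken together, the annotated factory functions give editors and type checkers a precise return type to work with. A minimal end-to-end sketch of how the typed API is consumed:

```python
import torch
from torchvision.models import mnasnet1_0

model = mnasnet1_0(pretrained=False)  # now inferred as MNASNet, not a bare nn.Module
model.eval()
with torch.no_grad():
    logits = model(torch.rand(1, 3, 224, 224))  # Tensor of shape (1, 1000)
```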