@@ -1,5 +1,7 @@
 from torch import nn
+from torch import Tensor
 from .utils import load_state_dict_from_url
+from typing import Callable, Any, Optional, List


 __all__ = ['MobileNetV2', 'mobilenet_v2']
@@ -10,7 +12,7 @@
 }


-def _make_divisible(v, divisor, min_value=None):
+def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
     """
     This function is taken from the original tf repo.
     It ensures that all layers have a channel number that is divisible by 8
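For reference, the rounding rule this docstring describes, as a standalone sketch; the body follows the tf-repo recipe the docstring refers to, and the example widths are only illustrative:

from typing import Optional

def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
    # Round v to the nearest multiple of divisor, never below min_value
    # and never more than 10% below the original value.
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

print(_make_divisible(32 * 0.75, 8))  # 24 -- already a multiple of 8
print(_make_divisible(32 * 1.4, 8))   # 48 -- 44.8 rounds up to the nearest multiple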
@@ -31,7 +33,15 @@ def _make_divisible(v, divisor, min_value=None):


 class ConvBNReLU(nn.Sequential):
-    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, norm_layer=None):
+    def __init__(
+        self,
+        in_planes: int,
+        out_planes: int,
+        kernel_size: int = 3,
+        stride: int = 1,
+        groups: int = 1,
+        norm_layer: Optional[Callable[..., nn.Module]] = None
+    ) -> None:
         padding = (kernel_size - 1) // 2
         if norm_layer is None:
             norm_layer = nn.BatchNorm2d
@@ -43,7 +53,14 @@ def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, nor


 class InvertedResidual(nn.Module):
-    def __init__(self, inp, oup, stride, expand_ratio, norm_layer=None):
+    def __init__(
+        self,
+        inp: int,
+        oup: int,
+        stride: int,
+        expand_ratio: int,
+        norm_layer: Optional[Callable[..., nn.Module]] = None
+    ) -> None:
         super(InvertedResidual, self).__init__()
         self.stride = stride
         assert stride in [1, 2]
@@ -54,7 +71,7 @@ def __init__(self, inp, oup, stride, expand_ratio, norm_layer=None):
         hidden_dim = int(round(inp * expand_ratio))
         self.use_res_connect = self.stride == 1 and inp == oup

-        layers = []
+        layers: List[nn.Module] = []
         if expand_ratio != 1:
             # pw
             layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer))
@@ -67,21 +84,23 @@ def __init__(self, inp, oup, stride, expand_ratio, norm_layer=None):
         ])
         self.conv = nn.Sequential(*layers)

-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         if self.use_res_connect:
             return x + self.conv(x)
         else:
             return self.conv(x)


 class MobileNetV2(nn.Module):
-    def __init__(self,
-                 num_classes=1000,
-                 width_mult=1.0,
-                 inverted_residual_setting=None,
-                 round_nearest=8,
-                 block=None,
-                 norm_layer=None):
+    def __init__(
+        self,
+        num_classes: int = 1000,
+        width_mult: float = 1.0,
+        inverted_residual_setting: Optional[List[List[int]]] = None,
+        round_nearest: int = 8,
+        block: Optional[Callable[..., nn.Module]] = None,
+        norm_layer: Optional[Callable[..., nn.Module]] = None
+    ) -> None:
         """
         MobileNet V2 main class

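As context for the `use_res_connect` flag typed above: `InvertedResidual.forward` only adds the identity shortcut when the block keeps stride 1 and leaves the channel count unchanged. A quick standalone check (hypothetical channel and stride values, assuming the module lives at its usual torchvision path):

import torch
from torchvision.models.mobilenetv2 import InvertedResidual

# stride 1 and inp == oup -> the skip connection is used
print(InvertedResidual(inp=32, oup=32, stride=1, expand_ratio=6).use_res_connect)  # True

# stride 2 downsamples, so no residual; the output is just conv(x)
block = InvertedResidual(inp=32, oup=64, stride=2, expand_ratio=6)
print(block.use_res_connect)                     # False
print(block(torch.randn(1, 32, 56, 56)).shape)   # torch.Size([1, 64, 28, 28])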
@@ -126,7 +145,7 @@ def __init__(self,
         # building first layer
         input_channel = _make_divisible(input_channel * width_mult, round_nearest)
         self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
-        features = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)]
+        features: List[nn.Module] = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)]
         # building inverted residual blocks
         for t, c, n, s in inverted_residual_setting:
             output_channel = _make_divisible(c * width_mult, round_nearest)
@@ -158,20 +177,20 @@ def __init__(self,
                 nn.init.normal_(m.weight, 0, 0.01)
                 nn.init.zeros_(m.bias)

-    def _forward_impl(self, x):
+    def _forward_impl(self, x: Tensor) -> Tensor:
         # This exists since TorchScript doesn't support inheritance, so the superclass method
         # (this one) needs to have a name other than `forward` that can be accessed in a subclass
         x = self.features(x)
         # Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0]
-        x = nn.functional.adaptive_avg_pool2d(x, 1).reshape(x.shape[0], -1)
+        x = nn.functional.adaptive_avg_pool2d(x, (1, 1)).reshape(x.shape[0], -1)
         x = self.classifier(x)
         return x

-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         return self._forward_impl(x)


-def mobilenet_v2(pretrained=False, progress=True, **kwargs):
+def mobilenet_v2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV2:
     """
     Constructs a MobileNetV2 architecture from
     `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.
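And a minimal end-to-end check of the annotated entry point (assuming torchvision is installed; `pretrained=False` skips the weight download, and the 224x224 input size is just the usual ImageNet convention):

import torch
from torchvision.models import mobilenet_v2

model = mobilenet_v2(pretrained=False, num_classes=1000)
out = model(torch.randn(2, 3, 224, 224))
print(out.shape)  # torch.Size([2, 1000])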