1
1
import torch
2
+ from torch import Tensor
2
3
import torch .nn as nn
3
4
from .utils import load_state_dict_from_url
5
+ from typing import Callable , Any , List
4
6
5
7
6
8
__all__ = [
16
18
}
17
19
18
20
19
- def channel_shuffle (x , groups ):
20
- # type: (torch.Tensor, int) -> torch.Tensor
21
+ def channel_shuffle (x : Tensor , groups : int ) -> Tensor :
21
22
batchsize , num_channels , height , width = x .data .size ()
22
23
channels_per_group = num_channels // groups
23
24
@@ -34,7 +35,12 @@ def channel_shuffle(x, groups):
34
35
35
36
36
37
class InvertedResidual (nn .Module ):
37
- def __init__ (self , inp , oup , stride ):
38
+ def __init__ (
39
+ self ,
40
+ inp : int ,
41
+ oup : int ,
42
+ stride : int
43
+ ) -> None :
38
44
super (InvertedResidual , self ).__init__ ()
39
45
40
46
if not (1 <= stride <= 3 ):
@@ -68,10 +74,17 @@ def __init__(self, inp, oup, stride):
68
74
)
69
75
70
76
@staticmethod
71
- def depthwise_conv (i , o , kernel_size , stride = 1 , padding = 0 , bias = False ):
77
+ def depthwise_conv (
78
+ i : int ,
79
+ o : int ,
80
+ kernel_size : int ,
81
+ stride : int = 1 ,
82
+ padding : int = 0 ,
83
+ bias : bool = False
84
+ ) -> nn .Conv2d :
72
85
return nn .Conv2d (i , o , kernel_size , stride , padding , bias = bias , groups = i )
73
86
74
- def forward (self , x ) :
87
+ def forward (self , x : Tensor ) -> Tensor :
75
88
if self .stride == 1 :
76
89
x1 , x2 = x .chunk (2 , dim = 1 )
77
90
out = torch .cat ((x1 , self .branch2 (x2 )), dim = 1 )
@@ -84,7 +97,13 @@ def forward(self, x):
84
97
85
98
86
99
class ShuffleNetV2 (nn .Module ):
87
- def __init__ (self , stages_repeats , stages_out_channels , num_classes = 1000 , inverted_residual = InvertedResidual ):
100
+ def __init__ (
101
+ self ,
102
+ stages_repeats : List [int ],
103
+ stages_out_channels : List [int ],
104
+ num_classes : int = 1000 ,
105
+ inverted_residual : Callable [..., nn .Module ] = InvertedResidual
106
+ ) -> None :
88
107
super (ShuffleNetV2 , self ).__init__ ()
89
108
90
109
if len (stages_repeats ) != 3 :
@@ -104,6 +123,10 @@ def __init__(self, stages_repeats, stages_out_channels, num_classes=1000, invert
104
123
105
124
self .maxpool = nn .MaxPool2d (kernel_size = 3 , stride = 2 , padding = 1 )
106
125
126
+ # Static annotations for mypy
127
+ self .stage2 : nn .Sequential
128
+ self .stage3 : nn .Sequential
129
+ self .stage4 : nn .Sequential
107
130
stage_names = ['stage{}' .format (i ) for i in [2 , 3 , 4 ]]
108
131
for name , repeats , output_channels in zip (
109
132
stage_names , stages_repeats , self ._stage_out_channels [1 :]):
@@ -122,7 +145,7 @@ def __init__(self, stages_repeats, stages_out_channels, num_classes=1000, invert
122
145
123
146
self .fc = nn .Linear (output_channels , num_classes )
124
147
125
- def _forward_impl (self , x ) :
148
+ def _forward_impl (self , x : Tensor ) -> Tensor :
126
149
# See note [TorchScript super()]
127
150
x = self .conv1 (x )
128
151
x = self .maxpool (x )
@@ -134,11 +157,11 @@ def _forward_impl(self, x):
134
157
x = self .fc (x )
135
158
return x
136
159
137
- def forward (self , x ) :
160
+ def forward (self , x : Tensor ) -> Tensor :
138
161
return self ._forward_impl (x )
139
162
140
163
141
- def _shufflenetv2 (arch , pretrained , progress , * args , ** kwargs ) :
164
+ def _shufflenetv2 (arch : str , pretrained : bool , progress : bool , * args : Any , ** kwargs : Any ) -> ShuffleNetV2 :
142
165
model = ShuffleNetV2 (* args , ** kwargs )
143
166
144
167
if pretrained :
@@ -152,7 +175,7 @@ def _shufflenetv2(arch, pretrained, progress, *args, **kwargs):
152
175
return model
153
176
154
177
155
- def shufflenet_v2_x0_5 (pretrained = False , progress = True , ** kwargs ) :
178
+ def shufflenet_v2_x0_5 (pretrained : bool = False , progress : bool = True , ** kwargs : Any ) -> ShuffleNetV2 :
156
179
"""
157
180
Constructs a ShuffleNetV2 with 0.5x output channels, as described in
158
181
`"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
@@ -166,7 +189,7 @@ def shufflenet_v2_x0_5(pretrained=False, progress=True, **kwargs):
166
189
[4 , 8 , 4 ], [24 , 48 , 96 , 192 , 1024 ], ** kwargs )
167
190
168
191
169
- def shufflenet_v2_x1_0 (pretrained = False , progress = True , ** kwargs ) :
192
+ def shufflenet_v2_x1_0 (pretrained : bool = False , progress : bool = True , ** kwargs : Any ) -> ShuffleNetV2 :
170
193
"""
171
194
Constructs a ShuffleNetV2 with 1.0x output channels, as described in
172
195
`"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
@@ -180,7 +203,7 @@ def shufflenet_v2_x1_0(pretrained=False, progress=True, **kwargs):
180
203
[4 , 8 , 4 ], [24 , 116 , 232 , 464 , 1024 ], ** kwargs )
181
204
182
205
183
- def shufflenet_v2_x1_5 (pretrained = False , progress = True , ** kwargs ) :
206
+ def shufflenet_v2_x1_5 (pretrained : bool = False , progress : bool = True , ** kwargs : Any ) -> ShuffleNetV2 :
184
207
"""
185
208
Constructs a ShuffleNetV2 with 1.5x output channels, as described in
186
209
`"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
@@ -194,7 +217,7 @@ def shufflenet_v2_x1_5(pretrained=False, progress=True, **kwargs):
194
217
[4 , 8 , 4 ], [24 , 176 , 352 , 704 , 1024 ], ** kwargs )
195
218
196
219
197
- def shufflenet_v2_x2_0 (pretrained = False , progress = True , ** kwargs ) :
220
+ def shufflenet_v2_x2_0 (pretrained : bool = False , progress : bool = True , ** kwargs : Any ) -> ShuffleNetV2 :
198
221
"""
199
222
Constructs a ShuffleNetV2 with 2.0x output channels, as described in
200
223
`"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
0 commit comments