import torch
import torch.nn as nn
import torch.nn.functional as F
-from torch.jit.annotations import Optional, Tuple
from torch import Tensor
from .utils import load_state_dict_from_url
+from typing import Optional, Tuple, List, Callable, Any

__all__ = ['GoogLeNet', 'googlenet', "GoogLeNetOutputs", "_GoogLeNetOutputs"]

_GoogLeNetOutputs = GoogLeNetOutputs


-def googlenet(pretrained=False, progress=True, **kwargs):
+def googlenet(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> "GoogLeNet":
    r"""GoogLeNet (Inception v1) model architecture from
    `"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.
@@ -52,8 +52,8 @@ def googlenet(pretrained=False, progress=True, **kwargs):
        model.load_state_dict(state_dict)
        if not original_aux_logits:
            model.aux_logits = False
-            model.aux1 = None
-            model.aux2 = None
+            model.aux1 = None  # type: ignore[assignment]
+            model.aux2 = None  # type: ignore[assignment]
        return model

    return GoogLeNet(**kwargs)
@@ -62,8 +62,14 @@ def googlenet(pretrained=False, progress=True, **kwargs):
class GoogLeNet(nn.Module):
    __constants__ = ['aux_logits', 'transform_input']

-    def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, init_weights=None,
-                 blocks=None):
+    def __init__(
+        self,
+        num_classes: int = 1000,
+        aux_logits: bool = True,
+        transform_input: bool = False,
+        init_weights: Optional[bool] = None,
+        blocks: Optional[List[Callable[..., nn.Module]]] = None
+    ) -> None:
        super(GoogLeNet, self).__init__()
        if blocks is None:
            blocks = [BasicConv2d, Inception, InceptionAux]
@@ -104,8 +110,8 @@ def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, ini
            self.aux1 = inception_aux_block(512, num_classes)
            self.aux2 = inception_aux_block(528, num_classes)
        else:
-            self.aux1 = None
-            self.aux2 = None
+            self.aux1 = None  # type: ignore[assignment]
+            self.aux2 = None  # type: ignore[assignment]

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.2)
@@ -114,7 +120,7 @@ def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, ini
        if init_weights:
            self._initialize_weights()

-    def _initialize_weights(self):
+    def _initialize_weights(self) -> None:
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                import scipy.stats as stats
@@ -127,17 +133,15 @@ def _initialize_weights(self):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

-    def _transform_input(self, x):
-        # type: (Tensor) -> Tensor
+    def _transform_input(self, x: Tensor) -> Tensor:
        if self.transform_input:
            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
        return x

-    def _forward(self, x):
-        # type: (Tensor) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]
+    def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
        # N x 3 x 224 x 224
        x = self.conv1(x)
        # N x 64 x 112 x 112
@@ -199,8 +203,7 @@ def eager_outputs(self, x: Tensor, aux2: Tensor, aux1: Optional[Tensor]) -> Goog
        else:
            return x  # type: ignore[return-value]

-    def forward(self, x):
-        # type: (Tensor) -> GoogLeNetOutputs
+    def forward(self, x: Tensor) -> GoogLeNetOutputs:
        x = self._transform_input(x)
        x, aux1, aux2 = self._forward(x)
        aux_defined = self.training and self.aux_logits
@@ -214,8 +217,17 @@ def forward(self, x):

class Inception(nn.Module):

-    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj,
-                 conv_block=None):
+    def __init__(
+        self,
+        in_channels: int,
+        ch1x1: int,
+        ch3x3red: int,
+        ch3x3: int,
+        ch5x5red: int,
+        ch5x5: int,
+        pool_proj: int,
+        conv_block: Optional[Callable[..., nn.Module]] = None
+    ) -> None:
        super(Inception, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
@@ -238,7 +250,7 @@ def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_pr
            conv_block(in_channels, pool_proj, kernel_size=1)
        )

-    def _forward(self, x):
+    def _forward(self, x: Tensor) -> List[Tensor]:
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch3 = self.branch3(x)
@@ -247,14 +259,19 @@ def _forward(self, x):
        outputs = [branch1, branch2, branch3, branch4]
        return outputs

-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionAux(nn.Module):

-    def __init__(self, in_channels, num_classes, conv_block=None):
+    def __init__(
+        self,
+        in_channels: int,
+        num_classes: int,
+        conv_block: Optional[Callable[..., nn.Module]] = None
+    ) -> None:
        super(InceptionAux, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
@@ -263,7 +280,7 @@ def __init__(self, in_channels, num_classes, conv_block=None):
        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
        # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
        x = F.adaptive_avg_pool2d(x, (4, 4))
        # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
@@ -283,12 +300,17 @@ def forward(self, x):

class BasicConv2d(nn.Module):

-    def __init__(self, in_channels, out_channels, **kwargs):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        **kwargs: Any
+    ) -> None:
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
        x = self.conv(x)
        x = self.bn(x)
        return F.relu(x, inplace=True)
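
For reference, a minimal sketch of how the annotated `googlenet` entry point can be exercised. This is illustrative only and not part of the commit; it assumes torchvision is installed so the function is importable from `torchvision.models`, and passes `init_weights=False` to skip the scipy-based truncated-normal initialization.

import torch
from torchvision.models import googlenet

# Build an untrained network; aux_logits=True keeps the two auxiliary heads.
model = googlenet(pretrained=False, aux_logits=True, init_weights=False)
model.eval()

# In eval mode forward() returns a plain Tensor of logits (N x num_classes);
# in training mode with aux_logits=True it returns GoogLeNetOutputs instead.
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])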