3
3
import torch
4
4
import torch .nn as nn
5
5
import torch .nn .functional as F
6
- from torch .jit .annotations import Optional
7
6
from torch import Tensor
8
7
from .utils import load_state_dict_from_url
8
+ from typing import Callable , Any , Optional , Tuple , List
9
9
10
10
11
11
__all__ = ['Inception3' , 'inception_v3' , 'InceptionOutputs' , '_InceptionOutputs' ]
24
24
_InceptionOutputs = InceptionOutputs
25
25
26
26
27
- def inception_v3 (pretrained = False , progress = True , ** kwargs ) :
27
+ def inception_v3 (pretrained : bool = False , progress : bool = True , ** kwargs : Any ) -> "Inception3" :
28
28
r"""Inception v3 model architecture from
29
29
`"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.
30
30
@@ -63,8 +63,14 @@ def inception_v3(pretrained=False, progress=True, **kwargs):
63
63
64
64
class Inception3 (nn .Module ):
65
65
66
- def __init__ (self , num_classes = 1000 , aux_logits = True , transform_input = False ,
67
- inception_blocks = None , init_weights = None ):
66
+ def __init__ (
67
+ self ,
68
+ num_classes : int = 1000 ,
69
+ aux_logits : bool = True ,
70
+ transform_input : bool = False ,
71
+ inception_blocks : Optional [List [Callable [..., nn .Module ]]] = None ,
72
+ init_weights : Optional [bool ] = None
73
+ ) -> None :
68
74
super (Inception3 , self ).__init__ ()
69
75
if inception_blocks is None :
70
76
inception_blocks = [
@@ -124,15 +130,15 @@ def __init__(self, num_classes=1000, aux_logits=True, transform_input=False,
124
130
nn .init .constant_ (m .weight , 1 )
125
131
nn .init .constant_ (m .bias , 0 )
126
132
127
- def _transform_input (self , x ) :
133
+ def _transform_input (self , x : Tensor ) -> Tensor :
128
134
if self .transform_input :
129
135
x_ch0 = torch .unsqueeze (x [:, 0 ], 1 ) * (0.229 / 0.5 ) + (0.485 - 0.5 ) / 0.5
130
136
x_ch1 = torch .unsqueeze (x [:, 1 ], 1 ) * (0.224 / 0.5 ) + (0.456 - 0.5 ) / 0.5
131
137
x_ch2 = torch .unsqueeze (x [:, 2 ], 1 ) * (0.225 / 0.5 ) + (0.406 - 0.5 ) / 0.5
132
138
x = torch .cat ((x_ch0 , x_ch1 , x_ch2 ), 1 )
133
139
return x
134
140
135
- def _forward (self , x ) :
141
+ def _forward (self , x : Tensor ) -> Tuple [ Tensor , Optional [ Tensor ]] :
136
142
# N x 3 x 299 x 299
137
143
x = self .Conv2d_1a_3x3 (x )
138
144
# N x 32 x 149 x 149
@@ -188,13 +194,13 @@ def _forward(self, x):
188
194
return x , aux
189
195
190
196
@torch.jit.unused
def eager_outputs(self, x: Tensor, aux: Optional[Tensor]) -> InceptionOutputs:
    """Package logits for eager-mode execution.

    Returns the ``InceptionOutputs`` named tuple (main + aux logits) only
    while training with ``aux_logits`` enabled; otherwise just the main
    logits tensor.
    """
    use_aux = self.training and self.aux_logits
    if not use_aux:
        return x  # type: ignore[return-value]
    return InceptionOutputs(x, aux)
196
202
197
- def forward (self , x ) :
203
+ def forward (self , x : Tensor ) -> InceptionOutputs :
198
204
x = self ._transform_input (x )
199
205
x , aux = self ._forward (x )
200
206
aux_defined = self .training and self .aux_logits
@@ -208,7 +214,12 @@ def forward(self, x):
208
214
209
215
class InceptionA (nn .Module ):
210
216
211
- def __init__ (self , in_channels , pool_features , conv_block = None ):
217
+ def __init__ (
218
+ self ,
219
+ in_channels : int ,
220
+ pool_features : int ,
221
+ conv_block : Optional [Callable [..., nn .Module ]] = None
222
+ ) -> None :
212
223
super (InceptionA , self ).__init__ ()
213
224
if conv_block is None :
214
225
conv_block = BasicConv2d
@@ -223,7 +234,7 @@ def __init__(self, in_channels, pool_features, conv_block=None):
223
234
224
235
self .branch_pool = conv_block (in_channels , pool_features , kernel_size = 1 )
225
236
226
- def _forward (self , x ) :
237
+ def _forward (self , x : Tensor ) -> List [ Tensor ] :
227
238
branch1x1 = self .branch1x1 (x )
228
239
229
240
branch5x5 = self .branch5x5_1 (x )
@@ -239,14 +250,18 @@ def _forward(self, x):
239
250
outputs = [branch1x1 , branch5x5 , branch3x3dbl , branch_pool ]
240
251
return outputs
241
252
242
def forward(self, x: Tensor) -> Tensor:
    """Run all InceptionA branches and concatenate along channels."""
    return torch.cat(self._forward(x), 1)
245
256
246
257
247
258
class InceptionB (nn .Module ):
248
259
249
- def __init__ (self , in_channels , conv_block = None ):
260
+ def __init__ (
261
+ self ,
262
+ in_channels : int ,
263
+ conv_block : Optional [Callable [..., nn .Module ]] = None
264
+ ) -> None :
250
265
super (InceptionB , self ).__init__ ()
251
266
if conv_block is None :
252
267
conv_block = BasicConv2d
@@ -256,7 +271,7 @@ def __init__(self, in_channels, conv_block=None):
256
271
self .branch3x3dbl_2 = conv_block (64 , 96 , kernel_size = 3 , padding = 1 )
257
272
self .branch3x3dbl_3 = conv_block (96 , 96 , kernel_size = 3 , stride = 2 )
258
273
259
- def _forward (self , x ) :
274
+ def _forward (self , x : Tensor ) -> List [ Tensor ] :
260
275
branch3x3 = self .branch3x3 (x )
261
276
262
277
branch3x3dbl = self .branch3x3dbl_1 (x )
@@ -268,14 +283,19 @@ def _forward(self, x):
268
283
outputs = [branch3x3 , branch3x3dbl , branch_pool ]
269
284
return outputs
270
285
271
def forward(self, x: Tensor) -> Tensor:
    """Run all InceptionB branches and concatenate along channels."""
    return torch.cat(self._forward(x), 1)
274
289
275
290
276
291
class InceptionC (nn .Module ):
277
292
278
- def __init__ (self , in_channels , channels_7x7 , conv_block = None ):
293
+ def __init__ (
294
+ self ,
295
+ in_channels : int ,
296
+ channels_7x7 : int ,
297
+ conv_block : Optional [Callable [..., nn .Module ]] = None
298
+ ) -> None :
279
299
super (InceptionC , self ).__init__ ()
280
300
if conv_block is None :
281
301
conv_block = BasicConv2d
@@ -294,7 +314,7 @@ def __init__(self, in_channels, channels_7x7, conv_block=None):
294
314
295
315
self .branch_pool = conv_block (in_channels , 192 , kernel_size = 1 )
296
316
297
- def _forward (self , x ) :
317
+ def _forward (self , x : Tensor ) -> List [ Tensor ] :
298
318
branch1x1 = self .branch1x1 (x )
299
319
300
320
branch7x7 = self .branch7x7_1 (x )
@@ -313,14 +333,18 @@ def _forward(self, x):
313
333
outputs = [branch1x1 , branch7x7 , branch7x7dbl , branch_pool ]
314
334
return outputs
315
335
316
def forward(self, x: Tensor) -> Tensor:
    """Run all InceptionC branches and concatenate along channels."""
    return torch.cat(self._forward(x), 1)
319
339
320
340
321
341
class InceptionD (nn .Module ):
322
342
323
- def __init__ (self , in_channels , conv_block = None ):
343
+ def __init__ (
344
+ self ,
345
+ in_channels : int ,
346
+ conv_block : Optional [Callable [..., nn .Module ]] = None
347
+ ) -> None :
324
348
super (InceptionD , self ).__init__ ()
325
349
if conv_block is None :
326
350
conv_block = BasicConv2d
@@ -332,7 +356,7 @@ def __init__(self, in_channels, conv_block=None):
332
356
self .branch7x7x3_3 = conv_block (192 , 192 , kernel_size = (7 , 1 ), padding = (3 , 0 ))
333
357
self .branch7x7x3_4 = conv_block (192 , 192 , kernel_size = 3 , stride = 2 )
334
358
335
- def _forward (self , x ) :
359
+ def _forward (self , x : Tensor ) -> List [ Tensor ] :
336
360
branch3x3 = self .branch3x3_1 (x )
337
361
branch3x3 = self .branch3x3_2 (branch3x3 )
338
362
@@ -345,14 +369,18 @@ def _forward(self, x):
345
369
outputs = [branch3x3 , branch7x7x3 , branch_pool ]
346
370
return outputs
347
371
348
def forward(self, x: Tensor) -> Tensor:
    """Run all InceptionD branches and concatenate along channels."""
    return torch.cat(self._forward(x), 1)
351
375
352
376
353
377
class InceptionE (nn .Module ):
354
378
355
- def __init__ (self , in_channels , conv_block = None ):
379
+ def __init__ (
380
+ self ,
381
+ in_channels : int ,
382
+ conv_block : Optional [Callable [..., nn .Module ]] = None
383
+ ) -> None :
356
384
super (InceptionE , self ).__init__ ()
357
385
if conv_block is None :
358
386
conv_block = BasicConv2d
@@ -369,7 +397,7 @@ def __init__(self, in_channels, conv_block=None):
369
397
370
398
self .branch_pool = conv_block (in_channels , 192 , kernel_size = 1 )
371
399
372
- def _forward (self , x ) :
400
+ def _forward (self , x : Tensor ) -> List [ Tensor ] :
373
401
branch1x1 = self .branch1x1 (x )
374
402
375
403
branch3x3 = self .branch3x3_1 (x )
@@ -393,24 +421,29 @@ def _forward(self, x):
393
421
outputs = [branch1x1 , branch3x3 , branch3x3dbl , branch_pool ]
394
422
return outputs
395
423
396
def forward(self, x: Tensor) -> Tensor:
    """Run all InceptionE branches and concatenate along channels."""
    return torch.cat(self._forward(x), 1)
399
427
400
428
401
429
class InceptionAux (nn .Module ):
402
430
403
def __init__(
    self,
    in_channels: int,
    num_classes: int,
    conv_block: Optional[Callable[..., nn.Module]] = None
) -> None:
    """Auxiliary classifier head.

    Args:
        in_channels: number of channels in the incoming feature map.
        num_classes: size of the logit vector produced by ``self.fc``.
        conv_block: factory used to build conv layers; ``BasicConv2d``
            when not supplied.
    """
    super(InceptionAux, self).__init__()
    block = BasicConv2d if conv_block is None else conv_block
    self.conv0 = block(in_channels, 128, kernel_size=1)
    self.conv1 = block(128, 768, kernel_size=5)
    # NOTE(review): stddev is presumably read by the model's weight
    # initialization routine — TODO confirm.
    self.conv1.stddev = 0.01  # type: ignore[assignment]
    self.fc = nn.Linear(768, num_classes)
    self.fc.stddev = 0.001  # type: ignore[assignment]
412
445
413
- def forward (self , x ) :
446
+ def forward (self , x : Tensor ) -> Tensor :
414
447
# N x 768 x 17 x 17
415
448
x = F .avg_pool2d (x , kernel_size = 5 , stride = 3 )
416
449
# N x 768 x 5 x 5
@@ -430,12 +463,17 @@ def forward(self, x):
430
463
431
464
class BasicConv2d (nn .Module ):
432
465
433
def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None:
    """Bias-free Conv2d followed by BatchNorm2d.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the convolution.
        **kwargs: forwarded verbatim to ``nn.Conv2d`` (kernel_size,
            stride, padding, ...).
    """
    super(BasicConv2d, self).__init__()
    # bias is omitted because the following BatchNorm supplies the shift.
    self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
    self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
437
475
438
def forward(self, x: Tensor) -> Tensor:
    """Apply conv -> batch-norm -> in-place ReLU."""
    return F.relu(self.bn(self.conv(x)), inplace=True)
0 commit comments