2
2
import torch .nn as nn
3
3
import torch .nn .init as init
4
4
from .utils import load_state_dict_from_url
5
+ from typing import Any
5
6
6
7
__all__ = ['SqueezeNet' , 'squeezenet1_0' , 'squeezenet1_1' ]
7
8
13
14
14
15
class Fire (nn .Module ):
15
16
16
- def __init__ (self , inplanes , squeeze_planes ,
17
- expand1x1_planes , expand3x3_planes ):
17
+ def __init__ (
18
+ self ,
19
+ inplanes : int ,
20
+ squeeze_planes : int ,
21
+ expand1x1_planes : int ,
22
+ expand3x3_planes : int
23
+ ):
18
24
super (Fire , self ).__init__ ()
19
25
self .inplanes = inplanes
20
26
self .squeeze = nn .Conv2d (inplanes , squeeze_planes , kernel_size = 1 )
@@ -26,7 +32,7 @@ def __init__(self, inplanes, squeeze_planes,
26
32
kernel_size = 3 , padding = 1 )
27
33
self .expand3x3_activation = nn .ReLU (inplace = True )
28
34
29
- def forward (self , x ) :
35
+ def forward (self , x : torch . Tensor ) -> torch . Tensor :
30
36
x = self .squeeze_activation (self .squeeze (x ))
31
37
return torch .cat ([
32
38
self .expand1x1_activation (self .expand1x1 (x )),
@@ -36,7 +42,11 @@ def forward(self, x):
36
42
37
43
class SqueezeNet (nn .Module ):
38
44
39
- def __init__ (self , version = '1_0' , num_classes = 1000 ):
45
+ def __init__ (
46
+ self ,
47
+ version : str = '1_0' ,
48
+ num_classes : int = 1000
49
+ ):
40
50
super (SqueezeNet , self ).__init__ ()
41
51
self .num_classes = num_classes
42
52
if version == '1_0' :
@@ -96,13 +106,13 @@ def __init__(self, version='1_0', num_classes=1000):
96
106
if m .bias is not None :
97
107
init .constant_ (m .bias , 0 )
98
108
99
- def forward (self , x ) :
109
+ def forward (self , x : torch . Tensor ) -> torch . Tensor :
100
110
x = self .features (x )
101
111
x = self .classifier (x )
102
112
return torch .flatten (x , 1 )
103
113
104
114
105
- def _squeezenet (version , pretrained , progress , ** kwargs ) :
115
+ def _squeezenet (version : str , pretrained : bool , progress : bool , ** kwargs : Any ) -> SqueezeNet :
106
116
model = SqueezeNet (version , ** kwargs )
107
117
if pretrained :
108
118
arch = 'squeezenet' + version
@@ -112,7 +122,7 @@ def _squeezenet(version, pretrained, progress, **kwargs):
112
122
return model
113
123
114
124
115
- def squeezenet1_0 (pretrained = False , progress = True , ** kwargs ) :
125
+ def squeezenet1_0 (pretrained : bool = False , progress : bool = True , ** kwargs : Any ) -> SqueezeNet :
116
126
r"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level
117
127
accuracy with 50x fewer parameters and <0.5MB model size"
118
128
<https://arxiv.org/abs/1602.07360>`_ paper.
@@ -124,7 +134,7 @@ def squeezenet1_0(pretrained=False, progress=True, **kwargs):
124
134
return _squeezenet ('1_0' , pretrained , progress , ** kwargs )
125
135
126
136
127
- def squeezenet1_1 (pretrained = False , progress = True , ** kwargs ) :
137
+ def squeezenet1_1 (pretrained : bool = False , progress : bool = True , ** kwargs : Any ) -> SqueezeNet :
128
138
r"""SqueezeNet 1.1 model from the `official SqueezeNet repo
129
139
<https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>`_.
130
140
SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters
0 commit comments