Skip to content

Commit 65591f1

Browse files
authored
Added annotation typing to squeezenet (#2865)
* style: Added annotation typing for squeezenet * feat: Added typing for kwargs
1 parent 67e7879 commit 65591f1

File tree

1 file changed

+18
-8
lines changed

1 file changed

+18
-8
lines changed

torchvision/models/squeezenet.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import torch.nn as nn
33
import torch.nn.init as init
44
from .utils import load_state_dict_from_url
5+
from typing import Any
56

67
__all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1']
78

@@ -13,8 +14,13 @@
1314

1415
class Fire(nn.Module):
1516

16-
def __init__(self, inplanes, squeeze_planes,
17-
expand1x1_planes, expand3x3_planes):
17+
def __init__(
18+
self,
19+
inplanes: int,
20+
squeeze_planes: int,
21+
expand1x1_planes: int,
22+
expand3x3_planes: int
23+
):
1824
super(Fire, self).__init__()
1925
self.inplanes = inplanes
2026
self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
@@ -26,7 +32,7 @@ def __init__(self, inplanes, squeeze_planes,
2632
kernel_size=3, padding=1)
2733
self.expand3x3_activation = nn.ReLU(inplace=True)
2834

29-
def forward(self, x):
35+
def forward(self, x: torch.Tensor) -> torch.Tensor:
3036
x = self.squeeze_activation(self.squeeze(x))
3137
return torch.cat([
3238
self.expand1x1_activation(self.expand1x1(x)),
@@ -36,7 +42,11 @@ def forward(self, x):
3642

3743
class SqueezeNet(nn.Module):
3844

39-
def __init__(self, version='1_0', num_classes=1000):
45+
def __init__(
46+
self,
47+
version: str = '1_0',
48+
num_classes: int = 1000
49+
):
4050
super(SqueezeNet, self).__init__()
4151
self.num_classes = num_classes
4252
if version == '1_0':
@@ -96,13 +106,13 @@ def __init__(self, version='1_0', num_classes=1000):
96106
if m.bias is not None:
97107
init.constant_(m.bias, 0)
98108

99-
def forward(self, x):
109+
def forward(self, x: torch.Tensor) -> torch.Tensor:
100110
x = self.features(x)
101111
x = self.classifier(x)
102112
return torch.flatten(x, 1)
103113

104114

105-
def _squeezenet(version, pretrained, progress, **kwargs):
115+
def _squeezenet(version: str, pretrained: bool, progress: bool, **kwargs: Any) -> SqueezeNet:
106116
model = SqueezeNet(version, **kwargs)
107117
if pretrained:
108118
arch = 'squeezenet' + version
@@ -112,7 +122,7 @@ def _squeezenet(version, pretrained, progress, **kwargs):
112122
return model
113123

114124

115-
def squeezenet1_0(pretrained=False, progress=True, **kwargs):
125+
def squeezenet1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> SqueezeNet:
116126
r"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level
117127
accuracy with 50x fewer parameters and <0.5MB model size"
118128
<https://arxiv.org/abs/1602.07360>`_ paper.
@@ -124,7 +134,7 @@ def squeezenet1_0(pretrained=False, progress=True, **kwargs):
124134
return _squeezenet('1_0', pretrained, progress, **kwargs)
125135

126136

127-
def squeezenet1_1(pretrained=False, progress=True, **kwargs):
137+
def squeezenet1_1(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> SqueezeNet:
128138
r"""SqueezeNet 1.1 model from the `official SqueezeNet repo
129139
<https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>`_.
130140
SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters

0 commit comments

Comments
 (0)