Skip to content

Commit 67e7879

Browse files
authored
Added annotation typing to mnasnet (#2856)
* style: Added annotation typing for mnasnet * refactor: Removed un-necessary import
1 parent aa4cf03 commit 67e7879

File tree

1 file changed

+31
-17
lines changed

1 file changed

+31
-17
lines changed

torchvision/models/mnasnet.py

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import warnings
22

33
import torch
4+
from torch import Tensor
45
import torch.nn as nn
56
from .utils import load_state_dict_from_url
7+
from typing import Any, Dict, List
68

79
__all__ = ['MNASNet', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3']
810

@@ -22,8 +24,15 @@
2224

2325
class _InvertedResidual(nn.Module):
2426

25-
def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor,
26-
bn_momentum=0.1):
27+
def __init__(
28+
self,
29+
in_ch: int,
30+
out_ch: int,
31+
kernel_size: int,
32+
stride: int,
33+
expansion_factor: int,
34+
bn_momentum: float = 0.1
35+
):
2736
super(_InvertedResidual, self).__init__()
2837
assert stride in [1, 2]
2938
assert kernel_size in [3, 5]
@@ -43,15 +52,15 @@ def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor,
4352
nn.Conv2d(mid_ch, out_ch, 1, bias=False),
4453
nn.BatchNorm2d(out_ch, momentum=bn_momentum))
4554

46-
def forward(self, input):
55+
def forward(self, input: Tensor) -> Tensor:
4756
if self.apply_residual:
4857
return self.layers(input) + input
4958
else:
5059
return self.layers(input)
5160

5261

53-
def _stack(in_ch, out_ch, kernel_size, stride, exp_factor, repeats,
54-
bn_momentum):
62+
def _stack(in_ch: int, out_ch: int, kernel_size: int, stride: int, exp_factor: int, repeats: int,
63+
bn_momentum: float) -> nn.Sequential:
5564
""" Creates a stack of inverted residuals. """
5665
assert repeats >= 1
5766
# First one has no skip, because feature map size changes.
@@ -65,7 +74,7 @@ def _stack(in_ch, out_ch, kernel_size, stride, exp_factor, repeats,
6574
return nn.Sequential(first, *remaining)
6675

6776

68-
def _round_to_multiple_of(val, divisor, round_up_bias=0.9):
77+
def _round_to_multiple_of(val: float, divisor: int, round_up_bias: float = 0.9) -> int:
6978
""" Asymmetric rounding to make `val` divisible by `divisor`. With default
7079
bias, will round up, unless the number is no more than 10% greater than the
7180
smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88. """
@@ -74,7 +83,7 @@ def _round_to_multiple_of(val, divisor, round_up_bias=0.9):
7483
return new_val if new_val >= round_up_bias * val else new_val + divisor
7584

7685

77-
def _get_depths(alpha):
86+
def _get_depths(alpha: float) -> List[int]:
7887
""" Scales tensor depths as in reference MobileNet code, prefers rouding up
7988
rather than down. """
8089
depths = [32, 16, 24, 40, 80, 96, 192, 320]
@@ -95,7 +104,12 @@ class MNASNet(torch.nn.Module):
95104
# Version 2 adds depth scaling in the initial stages of the network.
96105
_version = 2
97106

98-
def __init__(self, alpha, num_classes=1000, dropout=0.2):
107+
def __init__(
108+
self,
109+
alpha: float,
110+
num_classes: int = 1000,
111+
dropout: float = 0.2
112+
):
99113
super(MNASNet, self).__init__()
100114
assert alpha > 0.0
101115
self.alpha = alpha
@@ -130,13 +144,13 @@ def __init__(self, alpha, num_classes=1000, dropout=0.2):
130144
nn.Linear(1280, num_classes))
131145
self._initialize_weights()
132146

133-
def forward(self, x):
147+
def forward(self, x: Tensor) -> Tensor:
134148
x = self.layers(x)
135149
# Equivalent to global avgpool and removing H and W dimensions.
136150
x = x.mean([2, 3])
137151
return self.classifier(x)
138152

139-
def _initialize_weights(self):
153+
def _initialize_weights(self) -> None:
140154
for m in self.modules():
141155
if isinstance(m, nn.Conv2d):
142156
nn.init.kaiming_normal_(m.weight, mode="fan_out",
@@ -151,8 +165,8 @@ def _initialize_weights(self):
151165
nonlinearity="sigmoid")
152166
nn.init.zeros_(m.bias)
153167

154-
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
155-
missing_keys, unexpected_keys, error_msgs):
168+
def _load_from_state_dict(self, state_dict: Dict, prefix: str, local_metadata: Dict, strict: bool,
169+
missing_keys: List[str], unexpected_keys: List[str], error_msgs: List[str]) -> None:
156170
version = local_metadata.get("version", None)
157171
assert version in [1, 2]
158172

@@ -192,7 +206,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
192206
unexpected_keys, error_msgs)
193207

194208

195-
def _load_pretrained(model_name, model, progress):
209+
def _load_pretrained(model_name: str, model: nn.Module, progress: bool) -> None:
196210
if model_name not in _MODEL_URLS or _MODEL_URLS[model_name] is None:
197211
raise ValueError(
198212
"No checkpoint is available for model type {}".format(model_name))
@@ -201,7 +215,7 @@ def _load_pretrained(model_name, model, progress):
201215
load_state_dict_from_url(checkpoint_url, progress=progress))
202216

203217

204-
def mnasnet0_5(pretrained=False, progress=True, **kwargs):
218+
def mnasnet0_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet:
205219
"""MNASNet with depth multiplier of 0.5 from
206220
`"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
207221
<https://arxiv.org/pdf/1807.11626.pdf>`_.
@@ -215,7 +229,7 @@ def mnasnet0_5(pretrained=False, progress=True, **kwargs):
215229
return model
216230

217231

218-
def mnasnet0_75(pretrained=False, progress=True, **kwargs):
232+
def mnasnet0_75(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet:
219233
"""MNASNet with depth multiplier of 0.75 from
220234
`"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
221235
<https://arxiv.org/pdf/1807.11626.pdf>`_.
@@ -229,7 +243,7 @@ def mnasnet0_75(pretrained=False, progress=True, **kwargs):
229243
return model
230244

231245

232-
def mnasnet1_0(pretrained=False, progress=True, **kwargs):
246+
def mnasnet1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet:
233247
"""MNASNet with depth multiplier of 1.0 from
234248
`"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
235249
<https://arxiv.org/pdf/1807.11626.pdf>`_.
@@ -243,7 +257,7 @@ def mnasnet1_0(pretrained=False, progress=True, **kwargs):
243257
return model
244258

245259

246-
def mnasnet1_3(pretrained=False, progress=True, **kwargs):
260+
def mnasnet1_3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet:
247261
"""MNASNet with depth multiplier of 1.3 from
248262
`"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
249263
<https://arxiv.org/pdf/1807.11626.pdf>`_.

0 commit comments

Comments
 (0)