
Commit d4cd0be

Added annotation typing to mobilenet (#2862)
* style: Added annotation typing for mobilenet
* fix: Fixed type hinting of adaptive pooling
* refactor: Removed unnecessary import
* fix: Fixed constructor typing
* fix: Fixed list typing
1 parent 59c9742 commit d4cd0be

File tree

1 file changed (+36 lines, -17 lines)


torchvision/models/mobilenet.py

Lines changed: 36 additions & 17 deletions
@@ -1,5 +1,7 @@
 from torch import nn
+from torch import Tensor
 from .utils import load_state_dict_from_url
+from typing import Callable, Any, Optional, List
 
 
 __all__ = ['MobileNetV2', 'mobilenet_v2']
@@ -10,7 +12,7 @@
 }
 
 
-def _make_divisible(v, divisor, min_value=None):
+def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
     """
     This function is taken from the original tf repo.
     It ensures that all layers have a channel number that is divisible by 8
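The hunk only changes the signature; for reference, the body (untouched by this commit) rounds `v` to the nearest multiple of `divisor` without dropping more than 10% of the original value. A minimal sketch mirroring the upstream tf-repo implementation:

    from typing import Optional

    def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
        if min_value is None:
            min_value = divisor
        # Round to the nearest multiple of divisor, but never below min_value.
        new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
        # Make sure rounding down does not remove more than 10% of the channels.
        if new_v < 0.9 * v:
            new_v += divisor
        return new_v

    print(_make_divisible(32 * 1.15, 8))  # 40: 36.8 rounds up to a multiple of 8
    print(_make_divisible(33.6, 8))       # 32: rounding down here loses < 10%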
@@ -31,7 +33,15 @@ def _make_divisible(v, divisor, min_value=None):
 
 
 class ConvBNReLU(nn.Sequential):
-    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, norm_layer=None):
+    def __init__(
+        self,
+        in_planes: int,
+        out_planes: int,
+        kernel_size: int = 3,
+        stride: int = 1,
+        groups: int = 1,
+        norm_layer: Optional[Callable[..., nn.Module]] = None
+    ) -> None:
         padding = (kernel_size - 1) // 2
         if norm_layer is None:
             norm_layer = nn.BatchNorm2d
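The `Callable[..., nn.Module]` annotation accepts any callable that returns a module, not only a class, so a pre-configured factory also type-checks. A usage sketch (the `GroupNorm` swap is illustrative, not part of this commit):

    from functools import partial
    from torch import nn
    from torchvision.models.mobilenet import ConvBNReLU

    # A class is a callable returning an nn.Module...
    conv = ConvBNReLU(3, 32, norm_layer=nn.BatchNorm2d)
    # ...and so is a partial that pre-binds extra arguments:
    # ConvBNReLU calls norm_layer(out_planes), here nn.GroupNorm(8, 32).
    conv = ConvBNReLU(3, 32, norm_layer=partial(nn.GroupNorm, 8))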
@@ -43,7 +53,14 @@ def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, norm_layer=None):
 
 
 class InvertedResidual(nn.Module):
-    def __init__(self, inp, oup, stride, expand_ratio, norm_layer=None):
+    def __init__(
+        self,
+        inp: int,
+        oup: int,
+        stride: int,
+        expand_ratio: int,
+        norm_layer: Optional[Callable[..., nn.Module]] = None
+    ) -> None:
         super(InvertedResidual, self).__init__()
         self.stride = stride
         assert stride in [1, 2]
@@ -54,7 +71,7 @@ def __init__(self, inp, oup, stride, expand_ratio, norm_layer=None):
         hidden_dim = int(round(inp * expand_ratio))
         self.use_res_connect = self.stride == 1 and inp == oup
 
-        layers = []
+        layers: List[nn.Module] = []
         if expand_ratio != 1:
             # pw
             layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer))
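This hunk is the "fix: Fixed list typing" from the commit message: a checker cannot infer an element type from a bare `[]`, and inferring one from the first `append` would be too narrow, since the list later holds both `ConvBNReLU` and plain `nn.Conv2d` entries. The pattern in isolation:

    from typing import List
    from torch import nn

    layers: List[nn.Module] = []            # element type declared up front
    layers.append(nn.Conv2d(3, 8, 3))       # OK: Conv2d is an nn.Module
    layers.append(nn.ReLU6(inplace=True))   # OK: heterogeneous module types allowed
    block = nn.Sequential(*layers)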
@@ -67,21 +84,23 @@ def __init__(self, inp, oup, stride, expand_ratio, norm_layer=None):
         ])
         self.conv = nn.Sequential(*layers)
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         if self.use_res_connect:
             return x + self.conv(x)
         else:
             return self.conv(x)
 
 
 class MobileNetV2(nn.Module):
-    def __init__(self,
-                 num_classes=1000,
-                 width_mult=1.0,
-                 inverted_residual_setting=None,
-                 round_nearest=8,
-                 block=None,
-                 norm_layer=None):
+    def __init__(
+        self,
+        num_classes: int = 1000,
+        width_mult: float = 1.0,
+        inverted_residual_setting: Optional[List[List[int]]] = None,
+        round_nearest: int = 8,
+        block: Optional[Callable[..., nn.Module]] = None,
+        norm_layer: Optional[Callable[..., nn.Module]] = None
+    ) -> None:
         """
         MobileNet V2 main class
 
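The `Tensor` annotations on `forward` are what the new `from torch import Tensor` import serves: mypy can now check call sites against a concrete input and return type. A minimal smoke test of the typed block (shapes illustrative):

    import torch
    from torch import Tensor
    from torchvision.models.mobilenet import InvertedResidual

    block = InvertedResidual(inp=32, oup=32, stride=1, expand_ratio=6)
    x = torch.randn(1, 32, 56, 56)
    y: Tensor = block(x)        # stride == 1 and inp == oup, so the residual path is taken
    assert y.shape == x.shape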
@@ -126,7 +145,7 @@ def __init__(self,
         # building first layer
         input_channel = _make_divisible(input_channel * width_mult, round_nearest)
         self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
-        features = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)]
+        features: List[nn.Module] = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)]
         # building inverted residual blocks
         for t, c, n, s in inverted_residual_setting:
             output_channel = _make_divisible(c * width_mult, round_nearest)
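The `Optional[List[List[int]]]` annotation on `inverted_residual_setting` matches the shape the loop above unpacks: each inner list is `[t, c, n, s]` (expansion factor, output channels, repeat count, stride). For reference, the upstream default looks like:

    inverted_residual_setting = [
        # t, c, n, s
        [1, 16, 1, 1],
        [6, 24, 2, 2],
        [6, 32, 3, 2],
        [6, 64, 4, 2],
        [6, 96, 3, 1],
        [6, 160, 3, 2],
        [6, 320, 1, 1],
    ]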
@@ -158,20 +177,20 @@
                 nn.init.normal_(m.weight, 0, 0.01)
                 nn.init.zeros_(m.bias)
 
-    def _forward_impl(self, x):
+    def _forward_impl(self, x: Tensor) -> Tensor:
         # This exists since TorchScript doesn't support inheritance, so the superclass method
         # (this one) needs to have a name other than `forward` that can be accessed in a subclass
         x = self.features(x)
         # Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0]
-        x = nn.functional.adaptive_avg_pool2d(x, 1).reshape(x.shape[0], -1)
+        x = nn.functional.adaptive_avg_pool2d(x, (1, 1)).reshape(x.shape[0], -1)
         x = self.classifier(x)
         return x
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         return self._forward_impl(x)
 
 
-def mobilenet_v2(pretrained=False, progress=True, **kwargs):
+def mobilenet_v2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV2:
     """
     Constructs a MobileNetV2 architecture from
     `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.
