from torch import nn
from torch import Tensor
from .utils import load_state_dict_from_url
from typing import Callable, Any, Optional, List


__all__ = ['MobileNetV2', 'mobilenet_v2']


model_urls = {
    'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
}
def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8.
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py

    :param v: the raw channel count, typically already scaled by the width multiplier
    :param divisor: the value the result must be divisible by
    :param min_value: a lower bound on the result; defaults to ``divisor``
    :return: the channel count rounded to the nearest multiple of ``divisor``
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that rounding down does not shrink the channel count by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v
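
# Hand-worked examples of the rounding rule above (illustration only, not part
# of the original file):
#
#   _make_divisible(24.0, 8)  -> 24   # already a multiple of 8
#   _make_divisible(14.4, 8)  -> 16   # nearest multiple of 8
#   _make_divisible(11.0, 8)  -> 16   # 8 would lose more than 10% of 11, so bump up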


class ConvBNReLU(nn.Sequential):
    def __init__(
        self,
        in_planes: int,
        out_planes: int,
        kernel_size: int = 3,
        stride: int = 1,
        groups: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        padding = (kernel_size - 1) // 2
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        super(ConvBNReLU, self).__init__(
            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
            norm_layer(out_planes),
            nn.ReLU6(inplace=True)
        )
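
# For reference (hand-worked, not in the original file): padding = (kernel_size - 1) // 2
# gives "same"-style padding for odd kernel sizes, so ConvBNReLU(3, 32, stride=2)
# expands to Conv2d(3, 32, 3, 2, 1, bias=False) -> BatchNorm2d(32) -> ReLU6 and
# halves the spatial resolution.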


class InvertedResidual(nn.Module):
    def __init__(
        self,
        inp: int,
        oup: int,
        stride: int,
        expand_ratio: int,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = self.stride == 1 and inp == oup

        layers: List[nn.Module] = []
        if expand_ratio != 1:
            # pw
            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer))
        layers.extend([
            # dw
            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer),
            # pw-linear
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            norm_layer(oup),
        ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)
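
# A worked example of the block above (hand-derived, not in the original file):
# InvertedResidual(inp=24, oup=24, stride=1, expand_ratio=6) sets hidden_dim=144
# and builds 1x1 ConvBNReLU (24->144, "pw") -> 3x3 depthwise ConvBNReLU
# (groups=144, "dw") -> 1x1 linear Conv2d (144->24, "pw-linear"); because
# stride == 1 and inp == oup, the input is added back as a residual.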


class MobileNetV2(nn.Module):
    def __init__(
        self,
        num_classes: int = 1000,
        width_mult: float = 1.0,
        inverted_residual_setting: Optional[List[List[int]]] = None,
        round_nearest: int = 8,
        block: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        """
        MobileNet V2 main class

        Args:
            num_classes (int): Number of classes
            width_mult (float): Width multiplier - adjusts the number of channels in each layer by this factor
            inverted_residual_setting: Network structure
            round_nearest (int): Round the number of channels in each layer to be a multiple of this number.
                Set to 1 to turn off rounding
            block: Module specifying the inverted residual building block for MobileNet
            norm_layer: Module specifying the normalization layer to use

        """
        super(MobileNetV2, self).__init__()

        if block is None:
            block = InvertedResidual

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        input_channel = 32
        last_channel = 1280

        if inverted_residual_setting is None:
            inverted_residual_setting = [
                # t, c, n, s
                [1, 16, 1, 1],
                [6, 24, 2, 2],
                [6, 32, 3, 2],
                [6, 64, 4, 2],
                [6, 96, 3, 1],
                [6, 160, 3, 2],
                [6, 320, 1, 1],
            ]

        # only check the first element, assuming user knows t,c,n,s are required
        if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
            raise ValueError("inverted_residual_setting should be a non-empty "
                             "list of 4-element lists, got {}".format(inverted_residual_setting))

        # building first layer
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
        features: List[nn.Module] = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)]
        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            for i in range(n):
                stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
                input_channel = output_channel
        # building last several layers
        features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer))
        # make it nn.Sequential
        self.features = nn.Sequential(*features)

        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.last_channel, num_classes),
        )

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # This exists since TorchScript doesn't support inheritance, so the superclass method
        # (this one) needs to have a name other than `forward` that can be accessed in a subclass
        x = self.features(x)
        # Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0]
        x = nn.functional.adaptive_avg_pool2d(x, (1, 1)).reshape(x.shape[0], -1)
        x = self.classifier(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)
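
# Shape sketch (hand-derived, not part of the original file): for a 224x224 input
# the network downsamples by 32 overall (the stride-2 stem plus four stride-2
# stages), so at width_mult=1.0 self.features yields a (N, 1280, 7, 7) map;
# adaptive average pooling then feeds a 1280-dim vector to the classifier.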


def mobilenet_v2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV2:
    """
    Constructs a MobileNetV2 architecture from
    `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    model = MobileNetV2(**kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model
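
# A minimal usage sketch (illustration only, not part of the original file):
# builds the model without pretrained weights and checks the output shape.
if __name__ == "__main__":
    import torch

    model = mobilenet_v2(pretrained=False)
    model.eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 224, 224))
    print(out.shape)  # expected: torch.Size([1, 1000])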


# The diff replaces the entire file above with a single re-export line:
#
#     from .mobilenetv2 import MobileNetV2, mobilenet_v2