
Commit ff126ae (parent: 13bd09d)

Replace MobileNetV3's SqueezeExcitation with EfficientNet's one (#4487)

* Reuse the EfficientNet SE layer.
* Deprecate the mobilenetv3.SqueezeExcitation layer.
* Pass the right activation on quantization.
* Make strict a named param.
* Set default params if missing.
* Fix typos.
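
For context, here is a minimal sketch of the shared Squeeze-and-Excitation interface that both diffs below rely on. It is a reconstruction from the calls visible in this commit (fc1, activation, fc2, scale_activation, a _scale helper, and channel-wise rescaling), not a verbatim copy of torchvision's efficientnet.SqueezeExcitation:

import torch
from torch import nn, Tensor
from torch.nn import functional as F


class SqueezeExcitation(nn.Module):
    """Squeeze-and-Excitation block: globally pool, bottleneck through two
    1x1 convs, and rescale the input channel-wise with the resulting gates."""

    def __init__(self, input_channels: int, squeeze_channels: int,
                 activation=nn.ReLU, scale_activation=nn.Sigmoid):
        super().__init__()
        self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1)
        self.fc2 = nn.Conv2d(squeeze_channels, input_channels, 1)
        self.activation = activation()
        self.scale_activation = scale_activation()

    def _scale(self, input: Tensor) -> Tensor:
        scale = F.adaptive_avg_pool2d(input, 1)  # squeeze: NxCxHxW -> NxCx1x1
        scale = self.fc1(scale)
        scale = self.activation(scale)
        scale = self.fc2(scale)
        return self.scale_activation(scale)      # per-channel gates in [0, 1]

    def forward(self, input: Tensor) -> Tensor:
        return self._scale(input) * input         # excitation: rescale channels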

2 files changed: +59, -28 lines

torchvision/models/mobilenetv3.py (15 additions, 22 deletions)
@@ -1,12 +1,13 @@
+import warnings
 import torch
 
 from functools import partial
 from torch import nn, Tensor
-from torch.nn import functional as F
-from typing import Any, Callable, Dict, List, Optional, Sequence
+from typing import Any, Callable, List, Optional, Sequence
 
 from .._internally_replaced_utils import load_state_dict_from_url
-from torchvision.models.mobilenetv2 import _make_divisible, ConvBNActivation
+from .efficientnet import SqueezeExcitation as SElayer
+from .mobilenetv2 import _make_divisible, ConvBNActivation
 
 
 __all__ = ["MobileNetV3", "mobilenet_v3_large", "mobilenet_v3_small"]
@@ -18,25 +19,16 @@
 }
 
 
-class SqueezeExcitation(nn.Module):
-    # Implemented as described at Figure 4 of the MobileNetV3 paper
+class SqueezeExcitation(SElayer):
+    """DEPRECATED
+    """
     def __init__(self, input_channels: int, squeeze_factor: int = 4):
-        super().__init__()
         squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8)
-        self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1)
-        self.relu = nn.ReLU(inplace=True)
-        self.fc2 = nn.Conv2d(squeeze_channels, input_channels, 1)
-
-    def _scale(self, input: Tensor, inplace: bool) -> Tensor:
-        scale = F.adaptive_avg_pool2d(input, 1)
-        scale = self.fc1(scale)
-        scale = self.relu(scale)
-        scale = self.fc2(scale)
-        return F.hardsigmoid(scale, inplace=inplace)
-
-    def forward(self, input: Tensor) -> Tensor:
-        scale = self._scale(input, True)
-        return scale * input
+        super().__init__(input_channels, squeeze_channels, scale_activation=nn.Hardsigmoid)
+        self.relu = self.activation
+        delattr(self, 'activation')
+        warnings.warn(
+            "This SqueezeExcitation class is deprecated and will be removed in future versions.", FutureWarning)
 
 
 class InvertedResidualConfig:
@@ -60,7 +52,7 @@ def adjust_channels(channels: int, width_mult: float):
 class InvertedResidual(nn.Module):
     # Implemented as described at section 5 of MobileNetV3 paper
     def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Module],
-                 se_layer: Callable[..., nn.Module] = SqueezeExcitation):
+                 se_layer: Callable[..., nn.Module] = partial(SElayer, scale_activation=nn.Hardsigmoid)):
        super().__init__()
        if not (1 <= cnf.stride <= 2):
            raise ValueError('illegal stride value')
@@ -81,7 +73,8 @@ def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Mod
                                        stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels,
                                        norm_layer=norm_layer, activation_layer=activation_layer))
         if cnf.use_se:
-            layers.append(se_layer(cnf.expanded_channels))
+            squeeze_channels = _make_divisible(cnf.expanded_channels // 4, 8)
+            layers.append(se_layer(cnf.expanded_channels, squeeze_channels))
 
         # project
         layers.append(ConvBNActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer,
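
To make the substitution concrete, here is a hypothetical usage sketch (import paths as used by this commit; the channel count is illustrative) showing that the new default se_layer, together with the explicit squeeze_channels computation, reproduces the old squeeze_factor=4 sizing:

from functools import partial

import torch
from torch import nn
from torchvision.models.efficientnet import SqueezeExcitation as SElayer
from torchvision.models.mobilenetv2 import _make_divisible

# The new default used by InvertedResidual: EfficientNet's SE block with the
# Hardsigmoid gate that MobileNetV3 expects.
se_layer = partial(SElayer, scale_activation=nn.Hardsigmoid)

expanded_channels = 96                                         # illustrative expansion width
squeeze_channels = _make_divisible(expanded_channels // 4, 8)  # 96 // 4 = 24, already a multiple of 8
se = se_layer(expanded_channels, squeeze_channels)

x = torch.randn(1, expanded_channels, 14, 14)
assert se(x).shape == x.shape  # SE only rescales channels; the shape is unchanged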

torchvision/models/quantization/mobilenetv3.py (44 additions, 6 deletions)
@@ -1,8 +1,9 @@
 import torch
 from torch import nn, Tensor
 from ..._internally_replaced_utils import load_state_dict_from_url
-from torchvision.models.mobilenetv3 import InvertedResidual, InvertedResidualConfig, ConvBNActivation, MobileNetV3,\
-    SqueezeExcitation, model_urls, _mobilenet_v3_conf
+from ..efficientnet import SqueezeExcitation as SElayer
+from ..mobilenetv3 import InvertedResidual, InvertedResidualConfig, ConvBNActivation, MobileNetV3,\
+    model_urls, _mobilenet_v3_conf
 from torch.quantization import QuantStub, DeQuantStub, fuse_modules
 from typing import Any, List, Optional
 from .utils import _replace_relu
@@ -16,16 +17,53 @@
 }
 
 
-class QuantizableSqueezeExcitation(SqueezeExcitation):
+class QuantizableSqueezeExcitation(SElayer):
+    _version = 2
+
     def __init__(self, *args: Any, **kwargs: Any) -> None:
+        kwargs["scale_activation"] = nn.Hardsigmoid
         super().__init__(*args, **kwargs)
         self.skip_mul = nn.quantized.FloatFunctional()
 
     def forward(self, input: Tensor) -> Tensor:
-        return self.skip_mul.mul(self._scale(input, False), input)
+        return self.skip_mul.mul(self._scale(input), input)
 
     def fuse_model(self) -> None:
-        fuse_modules(self, ['fc1', 'relu'], inplace=True)
+        fuse_modules(self, ['fc1', 'activation'], inplace=True)
+
+    def _load_from_state_dict(
+        self,
+        state_dict,
+        prefix,
+        local_metadata,
+        strict,
+        missing_keys,
+        unexpected_keys,
+        error_msgs,
+    ):
+        version = local_metadata.get("version", None)
+
+        if version is None or version < 2:
+            default_state_dict = {
+                "scale_activation.activation_post_process.scale": torch.tensor([1.]),
+                "scale_activation.activation_post_process.zero_point": torch.tensor([0], dtype=torch.int32),
+                "scale_activation.activation_post_process.fake_quant_enabled": torch.tensor([1]),
+                "scale_activation.activation_post_process.observer_enabled": torch.tensor([1]),
+            }
+            for k, v in default_state_dict.items():
+                full_key = prefix + k
+                if full_key not in state_dict:
+                    state_dict[full_key] = v
+
+        super()._load_from_state_dict(
+            state_dict,
+            prefix,
+            local_metadata,
+            strict,
+            missing_keys,
+            unexpected_keys,
+            error_msgs,
+        )
 
 
 class QuantizableInvertedResidual(InvertedResidual):
@@ -78,7 +116,7 @@ def _load_weights(
     arch: str,
     model: QuantizableMobileNetV3,
     model_url: Optional[str],
-    progress: bool,
+    progress: bool
 ) -> None:
     if model_url is None:
         raise ValueError("No checkpoint is available for {}".format(arch))
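
The _version bump and the _load_from_state_dict override exist so that checkpoints saved before this commit, which lack the scale_activation observer entries, still load with strict=True. Here is a toy sketch of that standard nn.Module versioning mechanism; the Toy class and its "gate" buffer are hypothetical:

import torch
from torch import nn


class Toy(nn.Module):
    """Hypothetical module illustrating state_dict versioning."""

    _version = 2  # bumped because version-1 checkpoints did not store "gate"

    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("gate", torch.tensor([1.]))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, *args):
        version = local_metadata.get("version", None)
        if version is None or version < 2:
            # Back-fill the key that old checkpoints are missing, so that
            # loading with strict=True does not report it as missing.
            state_dict.setdefault(prefix + "gate", torch.tensor([1.]))
        super()._load_from_state_dict(state_dict, prefix, local_metadata, *args)


old_checkpoint = {}                 # pretend this was saved by the version-1 class
m = Toy()
m.load_state_dict(old_checkpoint)   # succeeds: the missing buffer is defaulted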
