+ import warnings

import torch
from functools import partial
from torch import nn, Tensor
- from torch.nn import functional as F
- from typing import Any, Callable, Dict, List, Optional, Sequence
+ from typing import Any, Callable, List, Optional, Sequence

from .._internally_replaced_utils import load_state_dict_from_url
- from torchvision.models.mobilenetv2 import _make_divisible, ConvBNActivation
+ from .efficientnet import SqueezeExcitation as SElayer
+ from .mobilenetv2 import _make_divisible, ConvBNActivation


__all__ = ["MobileNetV3", "mobilenet_v3_large", "mobilenet_v3_small"]
@@ -18,25 +19,16 @@
}


- class SqueezeExcitation(nn.Module):
-     # Implemented as described at Figure 4 of the MobileNetV3 paper
+ class SqueezeExcitation(SElayer):
+     """DEPRECATED
+     """
    def __init__(self, input_channels: int, squeeze_factor: int = 4):
-         super().__init__()
        squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8)
-         self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1)
-         self.relu = nn.ReLU(inplace=True)
-         self.fc2 = nn.Conv2d(squeeze_channels, input_channels, 1)
-
-     def _scale(self, input: Tensor, inplace: bool) -> Tensor:
-         scale = F.adaptive_avg_pool2d(input, 1)
-         scale = self.fc1(scale)
-         scale = self.relu(scale)
-         scale = self.fc2(scale)
-         return F.hardsigmoid(scale, inplace=inplace)
-
-     def forward(self, input: Tensor) -> Tensor:
-         scale = self._scale(input, True)
-         return scale * input
+         super().__init__(input_channels, squeeze_channels, scale_activation=nn.Hardsigmoid)
+         self.relu = self.activation
+         delattr(self, 'activation')
+         warnings.warn(
+             "This SqueezeExcitation class is deprecated and will be removed in future versions.", FutureWarning)


class InvertedResidualConfig:
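Note, not part of the diff: the shim above keeps old attribute access and checkpoints working by aliasing the parent's `activation` submodule to the legacy name `relu` and deleting the original entry, while construction now emits a FutureWarning. A minimal sanity check, assuming a torchvision build that includes this change:

import warnings
from torchvision.models.mobilenetv3 import SqueezeExcitation

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    se = SqueezeExcitation(input_channels=72)  # squeeze_factor defaults to 4

assert any(issubclass(w.category, FutureWarning) for w in caught)  # deprecation fires on construction
assert hasattr(se, "relu") and not hasattr(se, "activation")       # legacy attribute name is preserved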
@@ -60,7 +52,7 @@ def adjust_channels(channels: int, width_mult: float):
class InvertedResidual(nn.Module):
    # Implemented as described at section 5 of MobileNetV3 paper
    def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Module],
-                  se_layer: Callable[..., nn.Module] = SqueezeExcitation):
+                  se_layer: Callable[..., nn.Module] = partial(SElayer, scale_activation=nn.Hardsigmoid)):
        super().__init__()
        if not (1 <= cnf.stride <= 2):
            raise ValueError('illegal stride value')
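Not part of the diff, a sketch of what the new default buys: `partial` pre-binds only `scale_activation`, so the block keeps MobileNetV3's Hardsigmoid gating while reusing the shared SE implementation, and the two channel counts are now passed at call time. This assumes SElayer instantiates and stores `scale_activation` as an attribute, as the deprecation shim above suggests it does for `activation`:

from functools import partial

from torch import nn
from torchvision.models.efficientnet import SqueezeExcitation as SElayer

se_layer = partial(SElayer, scale_activation=nn.Hardsigmoid)

# se_layer(96, 24) is now equivalent to spelling the bound keyword out in full:
se = se_layer(96, 24)
same = SElayer(96, 24, scale_activation=nn.Hardsigmoid)
assert isinstance(se.scale_activation, nn.Hardsigmoid)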
@@ -81,7 +73,8 @@ def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Mod
                                       stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels,
                                       norm_layer=norm_layer, activation_layer=activation_layer))
        if cnf.use_se:
-             layers.append(se_layer(cnf.expanded_channels))
+             squeeze_channels = _make_divisible(cnf.expanded_channels // 4, 8)
+             layers.append(se_layer(cnf.expanded_channels, squeeze_channels))

        # project
        layers.append(ConvBNActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer,