
Commit 4502083

Implement MBConv.
1 parent e173b8f commit 4502083

File tree

1 file changed: +54 −10 lines changed


torchvision/models/efficientnet.py

Lines changed: 54 additions & 10 deletions
@@ -2,7 +2,7 @@

 from torch import nn, Tensor
 from torch.nn import functional as F
-from typing import Any, Optional
+from typing import Any, Callable, List, Optional

 from .._internally_replaced_utils import load_state_dict_from_url

@@ -11,33 +11,77 @@
 from torchvision.models.mobilenetv3 import SqueezeExcitation


-__all__ = []
+__all__ = ["EfficientNet"]


-model_urls = {}
+model_urls = {
+    "efficientnet_b0": "",  # TODO: Add weights
+}
+
+
+def drop_connect(x: Tensor, rate: float):
+    keep = torch.rand(size=(x.size(0), ), dtype=x.dtype, device=x.device) > rate
+    keep = keep[(None, ) * (x.ndim - 1)].T
+    return (x / (1.0 - rate)) * keep


 class MBConvConfig:
-    # TODO: Add dilation for supporting detection and segmentation pipelines
     def __init__(self,
-                 kernel: int, stride: int,
-                 input_channels: int, out_channels: int, expand_ratio: float, se_ratio: float,
-                 skip: bool, width_mult: float):
+                 kernel: int, stride: int, dilation: int,
+                 input_channels: int, out_channels: int, expand_ratio: float,
+                 width_mult: float):
         self.kernel = kernel
         self.stride = stride
+        self.dilation = dilation
         self.input_channels = self.adjust_channels(input_channels, width_mult)
         self.out_channels = self.adjust_channels(out_channels, width_mult)
         self.expanded_channels = self.adjust_channels(input_channels, expand_ratio * width_mult)
-        self.se_channels = self.adjust_channels(input_channels, se_ratio * width_mult, 1)
-        self.skip = skip

     @staticmethod
     def adjust_channels(channels: int, width_mult: float, min_value: Optional[int] = None):
         return _make_divisible(channels * width_mult, 8, min_value)


 class MBConv(nn.Module):
-    pass
+    def __init__(self, cnf: MBConvConfig, norm_layer: Callable[..., nn.Module],
+                 se_layer: Callable[..., nn.Module] = SqueezeExcitation):
+        super().__init__()
+        if not (1 <= cnf.stride <= 2):
+            raise ValueError('illegal stride value')
+
+        self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels
+
+        layers: List[nn.Module] = []
+        activation_layer = nn.SiLU
+
+        # expand
+        if cnf.expanded_channels != cnf.input_channels:
+            layers.append(ConvBNActivation(cnf.input_channels, cnf.expanded_channels, kernel_size=1,
+                                           norm_layer=norm_layer, activation_layer=activation_layer))
+
+        # depthwise
+        stride = 1 if cnf.dilation > 1 else cnf.stride
+        layers.append(ConvBNActivation(cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel,
+                                       stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels,
+                                       norm_layer=norm_layer, activation_layer=activation_layer))
+
+        # squeeze and excitation
+        layers.append(se_layer(cnf.expanded_channels, min_value=1, activation_fn=F.sigmoid))
+
+        # project
+        layers.append(ConvBNActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer,
+                                       activation_layer=nn.Identity))
+
+        self.block = nn.Sequential(*layers)
+        self.out_channels = cnf.out_channels
+
+    def forward(self, input: Tensor, drop_connect_rate: Optional[float] = None) -> Tensor:
+        result = self.block(input)
+        if self.use_res_connect:
+            if self.training and drop_connect_rate:
+                result = drop_connect(result, drop_connect_rate)
+            result += input
+        return result


 class EfficientNet(nn.Module):
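Note on drop_connect: the helper added above applies per-sample stochastic depth, zeroing an entire example in the batch with probability rate and rescaling the survivors so the expected value is unchanged. The mask starts with shape (N,), is expanded to (1, ..., 1, N) by the (None, ) * (x.ndim - 1) index, then transposed to (N, 1, ..., 1) so it broadcasts over channels and spatial dims. A minimal, self-contained sketch of that behavior; the input shape is illustrative, not taken from the commit:

import torch

x = torch.ones(4, 8, 2, 2)   # (N, C, H, W), illustrative
rate = 0.5
# per-sample boolean mask, shape (4,)
keep = torch.rand(size=(x.size(0), ), dtype=x.dtype, device=x.device) > rate
# reshape to (4, 1, 1, 1) so it broadcasts against x
keep = keep[(None, ) * (x.ndim - 1)].T
out = (x / (1.0 - rate)) * keep
print(out[:, 0, 0, 0])   # each sample is either 0.0 (dropped) or 2.0 (kept and rescaled)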

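Note on MBConvConfig.adjust_channels: it delegates to _make_divisible with a divisor of 8, so width-scaled channel counts stay multiples of 8. A quick sketch, assuming the helper follows the same rounding rule used by torchvision's other mobile models (round to the nearest multiple, never dropping below 90% of the requested value); the example width multipliers are illustrative:

def _make_divisible(v, divisor, min_value=None):
    # Round v to the nearest multiple of divisor, but never below min_value
    # and never below 90% of v (avoids shrinking a layer too aggressively).
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

print(_make_divisible(32 * 1.2, 8))   # 40, i.e. adjust_channels(32, width_mult=1.2)
print(_make_divisible(16 * 1.1, 8))   # 16, i.e. adjust_channels(16, width_mult=1.1)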
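Note on the MBConv wiring: the block is expand (1x1) -> depthwise (kxk) -> squeeze-and-excitation -> project (1x1, no activation), with a residual connection only when the stride is 1 and the input and output channel counts match. The shape-level sketch below uses plain nn.Conv2d / nn.BatchNorm2d stand-ins rather than torchvision's ConvBNActivation and SqueezeExcitation so it stays self-contained; the SE step is omitted and the channel numbers are illustrative, not taken from the commit:

import torch
from torch import nn

in_ch, expand_ratio = 32, 6
expanded = in_ch * expand_ratio

block = nn.Sequential(
    # expand: 1x1 conv raises 32 -> 192 channels
    nn.Conv2d(in_ch, expanded, kernel_size=1, bias=False),
    nn.BatchNorm2d(expanded),
    nn.SiLU(),
    # depthwise: 3x3 conv with one filter per channel (groups=expanded)
    nn.Conv2d(expanded, expanded, kernel_size=3, padding=1, groups=expanded, bias=False),
    nn.BatchNorm2d(expanded),
    nn.SiLU(),
    # project: 1x1 conv back to 32 channels, no activation (nn.Identity in the commit)
    nn.Conv2d(expanded, in_ch, kernel_size=1, bias=False),
    nn.BatchNorm2d(in_ch),
)

x = torch.randn(1, in_ch, 56, 56)
out = block(x) + x   # stride 1 and in_ch == out_ch, so use_res_connect would be True
print(out.shape)     # torch.Size([1, 32, 56, 56])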