
Commit b94a401

Making protected params of MobileNetV3 public (#3828)
* Converting private parameters to public.
* Add kwargs to handle extra params.
* Add another kwargs.
* Add arguments in _mobilenet_extractor.
1 parent f5aa5f5 commit b94a401

File tree: 4 files changed, +14 / -16 lines


torchvision/models/detection/ssdlite.py

Lines changed: 3 additions & 5 deletions
@@ -95,11 +95,9 @@ def __init__(self, in_channels: List[int], num_anchors: List[int], norm_layer: C
 
 
 class SSDLiteFeatureExtractorMobileNet(nn.Module):
-    def __init__(self, backbone: nn.Module, c4_pos: int, norm_layer: Callable[..., nn.Module], **kwargs: Any):
+    def __init__(self, backbone: nn.Module, c4_pos: int, norm_layer: Callable[..., nn.Module], width_mult: float = 1.0,
+                 min_depth: int = 16, **kwargs: Any):
         super().__init__()
-        # non-public config parameters
-        min_depth = kwargs.pop('_min_depth', 16)
-        width_mult = kwargs.pop('_width_mult', 1.0)
 
         assert not backbone[c4_pos].use_res_connect
         self.features = nn.Sequential(
@@ -197,7 +195,7 @@ def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = Tru
     norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)
 
     backbone = _mobilenet_extractor("mobilenet_v3_large", progress, pretrained_backbone, trainable_backbone_layers,
-                                    norm_layer, _reduced_tail=reduce_tail, _width_mult=1.0)
+                                    norm_layer, reduced_tail=reduce_tail, **kwargs)
 
     size = (320, 320)
     anchor_generator = DefaultBoxGenerator([[2, 3] for _ in range(6)], min_ratio=0.2, max_ratio=0.95)
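
For context, the extractor can now be configured directly through these public parameters. A minimal sketch, assuming a plain mobilenet_v3_large backbone and a hand-picked c4_pos of 13 (the first strided block of the last stage); the normal path builds this through _mobilenet_extractor instead:

from functools import partial

from torch import nn
from torchvision.models import mobilenet_v3_large
from torchvision.models.detection.ssdlite import SSDLiteFeatureExtractorMobileNet

backbone = mobilenet_v3_large(pretrained=False).features
norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)
extractor = SSDLiteFeatureExtractorMobileNet(
    backbone,
    c4_pos=13,            # assumed index of a strided block (use_res_connect is False)
    norm_layer=norm_layer,
    width_mult=1.0,       # previously kwargs.pop('_width_mult', 1.0)
    min_depth=16,         # previously kwargs.pop('_min_depth', 16)
)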

torchvision/models/mobilenetv3.py

Lines changed: 8 additions & 8 deletions
@@ -106,7 +106,8 @@ def __init__(
             last_channel: int,
             num_classes: int = 1000,
             block: Optional[Callable[..., nn.Module]] = None,
-            norm_layer: Optional[Callable[..., nn.Module]] = None
+            norm_layer: Optional[Callable[..., nn.Module]] = None,
+            **kwargs: Any
     ) -> None:
         """
         MobileNet V3 main class
@@ -184,11 +185,10 @@ def forward(self, x: Tensor) -> Tensor:
         return self._forward_impl(x)
 
 
-def _mobilenet_v3_conf(arch: str, params: Dict[str, Any]):
-    # non-public config parameters
-    reduce_divider = 2 if params.pop('_reduced_tail', False) else 1
-    dilation = 2 if params.pop('_dilated', False) else 1
-    width_mult = params.pop('_width_mult', 1.0)
+def _mobilenet_v3_conf(arch: str, width_mult: float = 1.0, reduced_tail: bool = False, dilated: bool = False,
+                       **kwargs: Any):
+    reduce_divider = 2 if reduced_tail else 1
+    dilation = 2 if dilated else 1
 
     bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
     adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)
@@ -260,7 +260,7 @@ def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, **kwargs
         progress (bool): If True, displays a progress bar of the download to stderr
     """
     arch = "mobilenet_v3_large"
-    inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, kwargs)
+    inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, **kwargs)
     return _mobilenet_v3_model(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs)
 
 
@@ -274,5 +274,5 @@ def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs
         progress (bool): If True, displays a progress bar of the download to stderr
     """
     arch = "mobilenet_v3_small"
-    inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, kwargs)
+    inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, **kwargs)
     return _mobilenet_v3_model(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs)
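
Since the classification builders forward **kwargs to _mobilenet_v3_conf, and the new **kwargs on MobileNetV3.__init__ absorbs whatever the constructor itself does not use, the configuration can now be passed under its public names. A small usage sketch with illustrative values (untrained model, so the non-default tail and dilation do not conflict with pretrained weights):

import torch
from torchvision.models import mobilenet_v3_large

# Public keyword arguments replace the old private ones
# (_width_mult / _reduced_tail / _dilated).
model = mobilenet_v3_large(pretrained=False, width_mult=1.0,
                           reduced_tail=True, dilated=True)
model.eval()
with torch.no_grad():
    logits = model(torch.rand(1, 3, 224, 224))   # -> shape (1, 1000)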

torchvision/models/quantization/mobilenetv3.py

Lines changed: 1 addition & 1 deletion
@@ -127,5 +127,5 @@ def mobilenet_v3_large(pretrained=False, progress=True, quantize=False, **kwargs
         quantize (bool): If True, returns a quantized model, else returns a float model
     """
     arch = "mobilenet_v3_large"
-    inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, kwargs)
+    inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, **kwargs)
     return _mobilenet_v3_model(arch, inverted_residual_setting, last_channel, pretrained, progress, quantize, **kwargs)
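
The quantizable builder forwards its keyword arguments the same way, so the public names work there too. A brief sketch, kept as a float model (quantize=False) for simplicity:

import torch
from torchvision.models import quantization

# reduced_tail travels through **kwargs to _mobilenet_v3_conf exactly as in
# the float builder; quantize=False returns the float QuantizableMobileNetV3.
model = quantization.mobilenet_v3_large(pretrained=False, quantize=False,
                                        reduced_tail=True)
model.eval()
with torch.no_grad():
    out = model(torch.rand(1, 3, 224, 224))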

torchvision/models/segmentation/segmentation.py

Lines changed: 2 additions & 2 deletions
@@ -32,7 +32,7 @@ def _segm_model(name, backbone_name, num_classes, aux, pretrained_backbone=True)
         aux_layer = 'layer3'
         aux_inplanes = 1024
     elif 'mobilenet_v3' in backbone_name:
-        backbone = mobilenetv3.__dict__[backbone_name](pretrained=pretrained_backbone, _dilated=True).features
+        backbone = mobilenetv3.__dict__[backbone_name](pretrained=pretrained_backbone, dilated=True).features
 
         # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
         # The first and last blocks are always included because they are the C0 (conv1) and Cn.
@@ -87,7 +87,7 @@ def _load_weights(model, arch_type, backbone, progress):
 
 
 def _segm_lraspp_mobilenetv3(backbone_name, num_classes, pretrained_backbone=True):
-    backbone = mobilenetv3.__dict__[backbone_name](pretrained=pretrained_backbone, _dilated=True).features
+    backbone = mobilenetv3.__dict__[backbone_name](pretrained=pretrained_backbone, dilated=True).features
 
     # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
     # The first and last blocks are always included because they are the C0 (conv1) and Cn.
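
On the segmentation side only the flag's spelling changes: the helpers now request dilated=True instead of the private _dilated=True. The equivalent public call, sketched with pretrained=False to avoid a weight download:

import torch
from torchvision.models import mobilenetv3

# Same backbone call the segmentation helpers make, using the public flag.
backbone = mobilenetv3.mobilenet_v3_large(pretrained=False, dilated=True).features
with torch.no_grad():
    feats = backbone(torch.rand(1, 3, 224, 224))
# With dilation the last stage keeps a higher spatial resolution than the
# classification default, which is what the LR-ASPP / DeepLab heads expect.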
