diff --git a/docs/source/models.rst b/docs/source/models.rst index f27a555befe..167203b031e 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -24,6 +24,7 @@ architectures for image classification: - `ShuffleNet`_ v2 - `MobileNet`_ v2 - `ResNeXt`_ +- `MNASNet`_ You can construct a model with random weights by calling its constructor: @@ -40,6 +41,7 @@ You can construct a model with random weights by calling its constructor: shufflenet = models.shufflenet_v2_x1_0() mobilenet = models.mobilenet_v2() resnext50_32x4d = models.resnext50_32x4d() + mnasnet = models.mnasnet1_0() We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`. These can be constructed by passing ``pretrained=True``: @@ -57,6 +59,7 @@ These can be constructed by passing ``pretrained=True``: shufflenet = models.shufflenet_v2_x1_0(pretrained=True) mobilenet = models.mobilenet_v2(pretrained=True) resnext50_32x4d = models.resnext50_32x4d(pretrained=True) + mnasnet = models.mnasnet1_0(pretrained=True) Instancing a pre-trained model will download its weights to a cache directory. This directory can be set using the `TORCH_MODEL_ZOO` environment variable. See @@ -111,6 +114,7 @@ ShuffleNet V2 30.64 11.68 MobileNet V2 28.12 9.71 ResNeXt-50-32x4d 22.38 6.30 ResNeXt-101-32x8d 20.69 5.47 +MNASNet 1.0 26.49 8.456 ================================ ============= ============= @@ -124,6 +128,7 @@ ResNeXt-101-32x8d 20.69 5.47 .. _ShuffleNet: https://arxiv.org/abs/1807.11164 .. _MobileNet: https://arxiv.org/abs/1801.04381 .. _ResNeXt: https://arxiv.org/abs/1611.05431 +.. _MNASNet: https://arxiv.org/abs/1807.11626 .. currentmodule:: torchvision.models @@ -197,6 +202,14 @@ ResNext .. autofunction:: resnext50_32x4d .. autofunction:: resnext101_32x8d +MNASNet +-------- + +.. autofunction:: mnasnet0_5 +.. autofunction:: mnasnet0_75 +.. autofunction:: mnasnet1_0 +.. 
import math

import torch
import torch.nn as nn
from .utils import load_state_dict_from_url

__all__ = ['MNASNet', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3']

# Checkpoint URL for each constructor name. A value of None means no
# checkpoint is registered for that variant.
_MODEL_URLS = {
    "mnasnet0_5":
    "https://github.com/1e100/mnasnet_trainer/releases/download/v0.1/mnasnet0.5_top1_67.592-7c6cb539b9.pth",
    "mnasnet0_75": None,
    "mnasnet1_0":
    "https://github.com/1e100/mnasnet_trainer/releases/download/v0.1/mnasnet1.0_top1_73.512-f206786ef8.pth",
    "mnasnet1_3": None
}

# Paper suggests 0.9997 momentum, for TensorFlow. Equivalent PyTorch momentum is
# 1.0 - tensorflow.
_BN_MOMENTUM = 1 - 0.9997


class _InvertedResidual(nn.Module):
    """ Inverted residual block.

    Applies a 1x1 expansion conv, a depthwise conv and a linear 1x1 projection,
    each followed by batch norm (ReLU after the first two only). The block
    input is added to the output only when the block keeps both the channel
    count and the resolution (``in_ch == out_ch`` and ``stride == 1``).

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        kernel_size: depthwise kernel size, 3 or 5.
        stride: depthwise stride, 1 or 2.
        expansion_factor: multiplier applied to ``in_ch`` to get the width of
            the intermediate (expanded) representation.
        bn_momentum: momentum passed to every ``nn.BatchNorm2d`` in the block.
    """

    def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor,
                 bn_momentum=0.1):
        super(_InvertedResidual, self).__init__()
        assert stride in [1, 2]
        assert kernel_size in [3, 5]
        mid_ch = in_ch * expansion_factor
        self.apply_residual = (in_ch == out_ch and stride == 1)
        self.layers = nn.Sequential(
            # Pointwise
            nn.Conv2d(in_ch, mid_ch, 1, bias=False),
            nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
            nn.ReLU(inplace=True),
            # Depthwise
            nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=kernel_size // 2,
                      stride=stride, groups=mid_ch, bias=False),
            nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
            nn.ReLU(inplace=True),
            # Linear pointwise. Note that there's no activation.
            nn.Conv2d(mid_ch, out_ch, 1, bias=False),
            nn.BatchNorm2d(out_ch, momentum=bn_momentum))

    def forward(self, input):
        out = self.layers(input)
        if self.apply_residual:
            out = out + input
        return out


def _stack(in_ch, out_ch, kernel_size, stride, exp_factor, repeats,
           bn_momentum):
    """ Creates a stack of ``repeats`` inverted residual blocks.

    Only the first block uses ``stride`` and maps ``in_ch`` to ``out_ch``; the
    remaining ``repeats - 1`` blocks map ``out_ch`` to ``out_ch`` with stride 1.
    """
    assert repeats >= 1
    # The first block is the one that may change channel count / resolution;
    # in that case it gets no skip connection.
    first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor,
                              bn_momentum=bn_momentum)
    remaining = [
        _InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor,
                          bn_momentum=bn_momentum)
        for _ in range(1, repeats)
    ]
    return nn.Sequential(first, *remaining)


def _round_to_multiple_of(val, divisor, round_up_bias=0.9):
    """ Asymmetric rounding to make `val` divisible by `divisor`.

    `val` is first rounded to the nearest multiple of `divisor` (never below
    `divisor` itself). If that result is smaller than `round_up_bias * val`,
    i.e. rounding down lost more than 10% with the default bias, one more
    `divisor` is added. Examples with divisor 8: 83 -> 80, 84 -> 88, and
    11 -> 16 (nearest multiple 8 is below 0.9 * 11, so it is bumped up). """
    assert 0.0 < round_up_bias < 1.0
    new_val = max(divisor, int(val + divisor / 2) // divisor * divisor)
    return new_val if new_val >= round_up_bias * val else new_val + divisor


def _scale_depths(depths, alpha):
    """ Scales tensor depths as in reference MobileNet code, prefers rounding
    up rather than down. Every scaled depth is a multiple of 8. """
    return [_round_to_multiple_of(depth * alpha, 8) for depth in depths]


class MNASNet(nn.Module):
    """ MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf.

    Args:
        alpha: depth multiplier applied to the channel counts of the inverted
            residual stacks (the stem and the final 1280-channel layer are not
            scaled).
        num_classes: size of the classifier output.
        dropout: dropout probability applied before the final linear layer.

    >>> model = MNASNet(1.0, num_classes=1000)
    >>> x = torch.rand(1, 3, 224, 224)
    >>> y = model(x)
    >>> y.dim()
    2
    >>> y.nelement()
    1000
    """

    def __init__(self, alpha, num_classes=1000, dropout=0.2):
        super(MNASNet, self).__init__()
        depths = _scale_depths([24, 40, 80, 96, 192, 320], alpha)
        layers = [
            # First layer: regular conv.
            nn.Conv2d(3, 32, 3, padding=1, stride=2, bias=False),
            nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
            # Depthwise separable, no skip.
            nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32, bias=False),
            nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False),
            nn.BatchNorm2d(16, momentum=_BN_MOMENTUM),
            # MNASNet blocks: stacks of inverted residuals.
            # Arguments: in_ch, out_ch, kernel_size, stride, exp_factor,
            # repeats, bn_momentum.
            _stack(16, depths[0], 3, 2, 3, 3, _BN_MOMENTUM),
            _stack(depths[0], depths[1], 5, 2, 3, 3, _BN_MOMENTUM),
            _stack(depths[1], depths[2], 5, 2, 6, 3, _BN_MOMENTUM),
            _stack(depths[2], depths[3], 3, 1, 6, 2, _BN_MOMENTUM),
            _stack(depths[3], depths[4], 5, 2, 6, 4, _BN_MOMENTUM),
            _stack(depths[4], depths[5], 3, 1, 6, 1, _BN_MOMENTUM),
            # Final mapping to classifier input.
            nn.Conv2d(depths[5], 1280, 1, padding=0, stride=1, bias=False),
            nn.BatchNorm2d(1280, momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
        ]
        self.layers = nn.Sequential(*layers)
        self.classifier = nn.Sequential(nn.Dropout(p=dropout, inplace=True),
                                        nn.Linear(1280, num_classes))
        self._initialize_weights()

    def forward(self, x):
        x = self.layers(x)
        # Equivalent to global avgpool and removing H and W dimensions.
        x = x.mean([2, 3])
        return self.classifier(x)

    def _initialize_weights(self):
        """ Initializes conv, batch norm and linear parameters in place. """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out",
                                        nonlinearity="relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                # normal_ takes (tensor, mean, std): a bare positional 0.01
                # would set the mean and leave std at its default of 1.0, so
                # name both arguments explicitly.
                nn.init.normal_(m.weight, mean=0.0, std=0.01)
                nn.init.zeros_(m.bias)


def _load_pretrained(model_name, model):
    """ Loads the checkpoint registered for `model_name` into `model`.

    Raises:
        ValueError: if `model_name` is unknown or has no checkpoint URL.
    """
    checkpoint_url = _MODEL_URLS.get(model_name)
    if checkpoint_url is None:
        raise ValueError(
            "No checkpoint is available for model type {}".format(model_name))
    model.load_state_dict(load_state_dict_from_url(checkpoint_url))


def mnasnet0_5(pretrained=False, **kwargs):
    """ MNASNet with depth multiplier of 0.5.

    Args:
        pretrained (bool): if True, loads the registered checkpoint into the
            model after construction.
        **kwargs: forwarded to the :class:`MNASNet` constructor.
    """
    model = MNASNet(0.5, **kwargs)
    if pretrained:
        _load_pretrained("mnasnet0_5", model)
    return model


def mnasnet0_75(pretrained=False, **kwargs):
    """ MNASNet with depth multiplier of 0.75.

    Args:
        pretrained (bool): no checkpoint is registered for this variant, so
            passing True raises ``ValueError``.
        **kwargs: forwarded to the :class:`MNASNet` constructor.
    """
    model = MNASNet(0.75, **kwargs)
    if pretrained:
        _load_pretrained("mnasnet0_75", model)
    return model


def mnasnet1_0(pretrained=False, **kwargs):
    """ MNASNet with depth multiplier of 1.0.

    Args:
        pretrained (bool): if True, loads the registered checkpoint into the
            model after construction.
        **kwargs: forwarded to the :class:`MNASNet` constructor.
    """
    model = MNASNet(1.0, **kwargs)
    if pretrained:
        _load_pretrained("mnasnet1_0", model)
    return model


def mnasnet1_3(pretrained=False, **kwargs):
    """ MNASNet with depth multiplier of 1.3.

    Args:
        pretrained (bool): no checkpoint is registered for this variant, so
            passing True raises ``ValueError``.
        **kwargs: forwarded to the :class:`MNASNet` constructor.
    """
    model = MNASNet(1.3, **kwargs)
    if pretrained:
        _load_pretrained("mnasnet1_3", model)
    return model