Implementation of the MNASNet family of models #829

Merged · 23 commits · Jun 24, 2019
13 changes: 13 additions & 0 deletions docs/source/models.rst
@@ -24,6 +24,7 @@ architectures for image classification:
- `ShuffleNet`_ v2
- `MobileNet`_ v2
- `ResNeXt`_
- `MNASNet`_

You can construct a model with random weights by calling its constructor:

@@ -40,6 +41,7 @@ You can construct a model with random weights by calling its constructor:
shufflenet = models.shufflenet_v2_x1_0()
mobilenet = models.mobilenet_v2()
resnext50_32x4d = models.resnext50_32x4d()
mnasnet = models.mnasnet1_0()

We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
These can be constructed by passing ``pretrained=True``:
@@ -57,6 +59,7 @@ These can be constructed by passing ``pretrained=True``:
shufflenet = models.shufflenet_v2_x1_0(pretrained=True)
mobilenet = models.mobilenet_v2(pretrained=True)
resnext50_32x4d = models.resnext50_32x4d(pretrained=True)
mnasnet = models.mnasnet1_0(pretrained=True)

Instancing a pre-trained model will download its weights to a cache directory.
This directory can be set using the `TORCH_MODEL_ZOO` environment variable. See
@@ -111,6 +114,7 @@
ShuffleNet V2                    30.64         11.68
MobileNet V2                     28.12         9.71
ResNeXt-50-32x4d                 22.38         6.30
ResNeXt-101-32x8d                20.69         5.47
MNASNet 1.0                      26.49         8.456
fmassa (Member):
Quick question:

I've tried testing the models with the references/train.py file, and got

 * Acc@1 73.456 Acc@5 91.510

which corresponds to error rates of 26.54 and 8.49.

I wonder if there is a difference in our data, or if the model that I downloaded is not the same?

depthwise:

I'll verify and report back (and train a better model if needed). It could be something as simple as a different version of Pillow. To aid the investigation, please answer the following questions (a version-check sketch follows the list):

  • Are you using Pillow-SIMD?
  • Are you compiling it from source?
  • Are you using libjpeg-turbo8?
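
For reference, here is a minimal sketch for checking the above from Python. Two assumptions are baked in: Pillow-SIMD advertises itself with a ".postN" version suffix, and the "libjpeg_turbo" feature flag only exists in newer Pillow releases (older ones raise ValueError for it):

import PIL
from PIL import features

print("Pillow version:", PIL.__version__)
# Heuristic: Pillow-SIMD releases carry a ".postN" suffix, e.g. "6.0.0.post0".
print("Looks like Pillow-SIMD:", ".post" in PIL.__version__)
try:
    # The feature name is only known to newer Pillow builds.
    print("libjpeg-turbo:", features.check_feature("libjpeg_turbo"))
except ValueError:
    print("This Pillow cannot report libjpeg-turbo support directly.")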

depthwise (Jun 24, 2019):
Also, did you change the number of epochs when training? I think that particular model actually used more epochs. (I'm 1e100, just under another account.)

rwightman (Contributor):
FWIW, I trained this 'b1' variant of MNASNet to 74.66 and the 'a1' SE variant to 75.45. Both took over 400 epochs, using RMSprop with Google-like hyperparameters. Using an EMA of the model weights was necessary to match/surpass the paper's numbers...

Models only: https://github.com/rwightman/gen-efficientnet-pytorch#pretrained

Training code w/ models: https://github.com/rwightman/pytorch-image-models
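
For readers unfamiliar with the technique, a minimal sketch of weight EMA follows. It illustrates the idea only; it is not rwightman's implementation (that lives in the linked repos), and the `alpha=` keyword of `add_` assumes a reasonably recent PyTorch:

import copy

import torch


class ModelEMA:
    """ Keeps an exponential moving average of a model's weights. """

    def __init__(self, model, decay=0.9999):
        # Shadow copy that accumulates the averaged weights.
        self.ema = copy.deepcopy(model).eval()
        self.decay = decay
        for p in self.ema.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model):
        # Call once per optimizer step: ema = decay * ema + (1 - decay) * new.
        for ema_v, v in zip(self.ema.state_dict().values(),
                            model.state_dict().values()):
            if ema_v.dtype.is_floating_point:
                ema_v.mul_(self.decay).add_(v, alpha=1.0 - self.decay)
            else:
                ema_v.copy_(v)  # integer buffers, e.g. num_batches_tracked

At validation time you would evaluate `ema.ema` rather than the live model.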

fmassa (Member):
@depthwise I don't think that Pillow / the difference in image processing is the cause. I've done extensive experiments in the past with multiple models, and they were all fairly insensitive to the choice of PIL vs. OpenCV. If Pillow really were the cause here, I'd be surprised, and it would indicate that the model is very fragile to small perturbations.

For completeness, I used Pillow 5.4.1, from pip, using libjpeg-turbo8

@rwightman interesting. And what happens if you do not apply EMA? Do you recall what accuracies you get?

1e100 (Contributor Author):
So to investigate this I wrote a simple eval script, which I pushed to https://github.com/1e100/mnasnet_trainer/blob/master/eval.py.

The results with Pillow-SIMD/libjpeg-turbo-8 are as follows:

Dev/mnasnet_trainer % ./eval.py
Evaluating pretrained mnasnet1_0
1.0769143749256522 [('prec1', 73.490265), ('prec5', 91.53294)]
Evaluating pretrained mnasnet0_5
1.3720299355229553 [('prec1', 67.59815), ('prec5', 87.51842)]

Neither matches the published numbers exactly: MNASNet 1.0 is slightly worse than the doc says, and MNASNet 0.5 is slightly better than the checkpoint name would imply (67.598% top-1 vs. 67.592%).

The results with "plain" pillow 6.0.0 from PyPl are as follows:

% ./eval.py
Evaluating pretrained mnasnet1_0
1.0772113243536072 [('prec1', 73.46037), ('prec5', 91.52099)]
Evaluating pretrained mnasnet0_5
1.372453243756781 [('prec1', 67.606026), ('prec5', 87.50845)]

So top-1 for 1.0 gets a bit worse, and for 0.5 it gets a bit better. I've observed this kind of sensitivity with other "efficient" models in the past. In particular, the resize algorithm (which is different in Pillow-SIMD) seems to make a noticeable difference, and its effect on smaller models is easily measurable. Something as mundane as a different JPEG decoder affects them, and so do the software versions: CUDA/cuDNN, PyTorch, etc. A number of these are different between then and now.

Just to be on the safe side, though, I kicked off another run for 1.0. I should have the results sometime over the weekend, one way or the other.
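
For context, the linked eval.py boils down to a standard ImageNet validation loop. A sketch is below; it is not the actual script, and the dataset path, batch size, and CUDA usage are placeholder assumptions:

import torch
from torchvision import datasets, models, transforms

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
# Placeholder path; point this at your ImageNet validation set.
val_set = datasets.ImageFolder("/data/imagenet/val", preprocess)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=256,
                                         num_workers=8)

model = models.mnasnet1_0(pretrained=True).cuda().eval()
correct = total = 0
with torch.no_grad():
    for images, targets in val_loader:
        preds = model(images.cuda()).argmax(dim=1).cpu()
        correct += (preds == targets).sum().item()
        total += targets.numel()
print("top-1 accuracy: {:.3f}%".format(100.0 * correct / total))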

fmassa (Member):
@1e100 if you are running the training again, would you mind using the code from references/classification, and maybe only changing the lr step?

For MobileNet V2, I used the following commands, on 8 GPUs

--model mobilenet_v2 --epochs 300 --lr 0.045 --wd 0.00004 --lr-step-size 1 --lr-gamma 0.98

The most important thing we are trying to do here is to have a simple path for reproducible research, so having slightly worse accuracy (e.g., ~0.3-0.5%) but a reproducible script available in torchvision would be preferable, I'd say.

This way, we can indeed compare apples to apples.
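
In code, those flags translate roughly to the setup below. This is a sketch: momentum=0.9 is assumed from the reference script's defaults, and whether these hyperparameters suit MNASNet is exactly what is being tested here:

import torch.optim as optim
from torchvision import models

model = models.mnasnet1_0()
# --lr 0.045 --wd 0.00004; momentum 0.9 is an assumed script default.
optimizer = optim.SGD(model.parameters(), lr=0.045, momentum=0.9,
                      weight_decay=4e-5)
# --lr-step-size 1 --lr-gamma 0.98: multiply the LR by 0.98 every epoch.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.98)

for epoch in range(300):  # --epochs 300
    # train_one_epoch(model, optimizer, ...)  # as in references/classification
    scheduler.step()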

1e100 (Contributor Author):
@fmassa what version of Pillow should I use for this? I use custom-compiled Pillow-SIMD for my own training runs, but if we want things to be more repeatable, I could use the slower, default install of Pillow-SIMD.

Here's how I compile it on my machines:

CC="cc -march=native -mtune=native -O3" pip3 install \
    --force-reinstall --user pillow-simd

fmassa (Member):
@1e100 sorry, I missed your message.

You can use the current version you have available; don't worry about potentially small differences.

1e100 (Contributor Author):
Still training. The first training run wasn't too good (I tweaked the training regime, and this model is sensitive to that kind of thing), so I'm using it as initialization for the second run. I'll update (and send a PR) when I get a good result.

================================ ============= =============


@@ -124,6 +128,7 @@ ResNeXt-101-32x8d 20.69 5.47
.. _ShuffleNet: https://arxiv.org/abs/1807.11164
.. _MobileNet: https://arxiv.org/abs/1801.04381
.. _ResNeXt: https://arxiv.org/abs/1611.05431
.. _MNASNet: https://arxiv.org/abs/1807.11626

.. currentmodule:: torchvision.models

@@ -197,6 +202,14 @@ ResNext
.. autofunction:: resnext50_32x4d
.. autofunction:: resnext101_32x8d

MNASNet
--------

.. autofunction:: mnasnet0_5
.. autofunction:: mnasnet0_75
.. autofunction:: mnasnet1_0
.. autofunction:: mnasnet1_3


Semantic Segmentation
=====================
1 change: 1 addition & 0 deletions torchvision/models/__init__.py
@@ -6,6 +6,7 @@
from .densenet import *
from .googlenet import *
from .mobilenet import *
from .mnasnet import *
from .shufflenetv2 import *
from . import segmentation
from . import detection
183 changes: 183 additions & 0 deletions torchvision/models/mnasnet.py
@@ -0,0 +1,183 @@
import math

import torch
import torch.nn as nn
from .utils import load_state_dict_from_url

__all__ = ['MNASNet', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3']

_MODEL_URLS = {
"mnasnet0_5":
"https://github.com/1e100/mnasnet_trainer/releases/download/v0.1/mnasnet0.5_top1_67.592-7c6cb539b9.pth",
fmassa (Member):
I have a follow-up commit that updates this

"mnasnet0_75": None,
"mnasnet1_0":
"https://github.com/1e100/mnasnet_trainer/releases/download/v0.1/mnasnet1.0_top1_73.512-f206786ef8.pth",
"mnasnet1_3": None
}

# The paper suggests a BatchNorm momentum of 0.9997, in TensorFlow's
# convention. PyTorch defines momentum as 1 - TF momentum, i.e. 0.0003.
_BN_MOMENTUM = 1 - 0.9997


class _InvertedResidual(nn.Module):

def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor,
bn_momentum=0.1):
super(_InvertedResidual, self).__init__()
assert stride in [1, 2]
assert kernel_size in [3, 5]
mid_ch = in_ch * expansion_factor
self.apply_residual = (in_ch == out_ch and stride == 1)
self.layers = nn.Sequential(
# Pointwise
nn.Conv2d(in_ch, mid_ch, 1, bias=False),
nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
nn.ReLU(inplace=True),
# Depthwise
nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=kernel_size // 2,
stride=stride, groups=mid_ch, bias=False),
nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
nn.ReLU(inplace=True),
# Linear pointwise. Note that there's no activation.
nn.Conv2d(mid_ch, out_ch, 1, bias=False),
nn.BatchNorm2d(out_ch, momentum=bn_momentum))

def forward(self, input):
if self.apply_residual:
return self.layers(input) + input
else:
return self.layers(input)


def _stack(in_ch, out_ch, kernel_size, stride, exp_factor, repeats,
bn_momentum):
""" Creates a stack of inverted residuals. """
assert repeats >= 1
# First one has no skip, because feature map size changes.
first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor,
bn_momentum=bn_momentum)
remaining = []
for _ in range(1, repeats):
remaining.append(
_InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor,
bn_momentum=bn_momentum))
return nn.Sequential(first, *remaining)


def _round_to_multiple_of(val, divisor, round_up_bias=0.9):
""" Asymmetric rounding to make `val` divisible by `divisor`. With default
bias, will round up, unless the number is no more than 10% greater than the
smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88. """
assert 0.0 < round_up_bias < 1.0
new_val = max(divisor, int(val + divisor / 2) // divisor * divisor)
return new_val if new_val >= round_up_bias * val else new_val + divisor


def _scale_depths(depths, alpha):
""" Scales tensor depths as in reference MobileNet code, prefers rouding up
rather than down. """
return [_round_to_multiple_of(depth * alpha, 8) for depth in depths]


class MNASNet(torch.nn.Module):
""" MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf.
>>> model = MNASNet(alpha=1.0)
>>> x = torch.rand(1, 3, 224, 224)
>>> y = model(x)
>>> y.dim()
2
>>> y.nelement()
1000
"""

def __init__(self, alpha, num_classes=1000, dropout=0.2):
super(MNASNet, self).__init__()
depths = _scale_depths([24, 40, 80, 96, 192, 320], alpha)
layers = [
# First layer: regular conv.
nn.Conv2d(3, 32, 3, padding=1, stride=2, bias=False),
nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
nn.ReLU(inplace=True),
# Depthwise separable, no skip.
nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32, bias=False),
nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
nn.ReLU(inplace=True),
nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False),
nn.BatchNorm2d(16, momentum=_BN_MOMENTUM),
# MNASNet blocks: stacks of inverted residuals.
_stack(16, depths[0], 3, 2, 3, 3, _BN_MOMENTUM),
_stack(depths[0], depths[1], 5, 2, 3, 3, _BN_MOMENTUM),
_stack(depths[1], depths[2], 5, 2, 6, 3, _BN_MOMENTUM),
_stack(depths[2], depths[3], 3, 1, 6, 2, _BN_MOMENTUM),
_stack(depths[3], depths[4], 5, 2, 6, 4, _BN_MOMENTUM),
_stack(depths[4], depths[5], 3, 1, 6, 1, _BN_MOMENTUM),
# Final mapping to classifier input.
nn.Conv2d(depths[5], 1280, 1, padding=0, stride=1, bias=False),
nn.BatchNorm2d(1280, momentum=_BN_MOMENTUM),
nn.ReLU(inplace=True),
]
self.layers = nn.Sequential(*layers)
self.classifier = nn.Sequential(nn.Dropout(p=dropout, inplace=True),
nn.Linear(1280, num_classes))
self._initialize_weights()

def forward(self, x):
x = self.layers(x)
# Equivalent to global avgpool and removing H and W dimensions.
x = x.mean([2, 3])
return self.classifier(x)

def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out",
nonlinearity="relu")
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0.0, 0.01)  # mean 0, std 0.01
nn.init.zeros_(m.bias)


def _load_pretrained(model_name, model):
if model_name not in _MODEL_URLS or _MODEL_URLS[model_name] is None:
raise ValueError(
"No checkpoint is available for model type {}".format(model_name))
checkpoint_url = _MODEL_URLS[model_name]
model.load_state_dict(load_state_dict_from_url(checkpoint_url))


def mnasnet0_5(pretrained=False, **kwargs):
""" MNASNet with depth multiplier of 0.5. """
model = MNASNet(0.5, **kwargs)
if pretrained:
_load_pretrained("mnasnet0_5", model)
return model


def mnasnet0_75(pretrained=False, **kwargs):
""" MNASNet with depth multiplier of 0.75. """
model = MNASNet(0.75, **kwargs)
if pretrained:
_load_pretrained("mnasnet0_75", model)
return model


def mnasnet1_0(pretrained=False, **kwargs):
""" MNASNet with depth multiplier of 1.0. """
model = MNASNet(1.0, **kwargs)
if pretrained:
_load_pretrained("mnasnet1_0", model)
return model


def mnasnet1_3(pretrained=False, **kwargs):
""" MNASNet with depth multiplier of 1.3. """
model = MNASNet(1.3, **kwargs)
if pretrained:
_load_pretrained("mnasnet1_3", model)
return model
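
A usage note on the factory functions above: only the 0.5 and 1.0 depth multipliers have checkpoint URLs in `_MODEL_URLS` at this point, so `pretrained=True` fails loudly for the other two. A small sketch:

from torchvision import models

# Downloads the 73.512 top-1 checkpoint listed in _MODEL_URLS.
model = models.mnasnet1_0(pretrained=True)

try:
    models.mnasnet1_3(pretrained=True)
except ValueError as err:
    print(err)  # No checkpoint is available for model type mnasnet1_3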