Skip to content

Commit 367e851

Browse files
1e100 authored and fmassa committed
Bugfix for MNASNet (#1224)
* Add initial mnasnet impl
* Remove all type hints, comply with PyTorch overall style
* Expose models
* Remove avgpool from features() and add separately
* Fix python3-only stuff, replace subclasses with functions
* fix __all__
* Fix typo
* Remove conditional dropout
* Make dropout functional
* Addressing @fmassa's feedback, round 1
* Replaced adaptive avgpool with mean on H and W to prevent collapsing the batch dimension
* Partially address feedback
* YAPF
* Removed redundant class vars
* Update urls to releases
* Add information to models.rst
* Replace init with kaiming_normal_ in fan-out mode
* Use load_state_dict_from_url
* Fix depth scaling on first 2 layers
* Restore initialization
* Match reference implementation initialization for dense layer
* Meant to use Kaiming
* Remove spurious relu
* Point to the newest 0.5 checkpoint
* Latest pretrained checkpoint
* Restore 1.0 checkpoint
* YAPF
* Implement backwards compat as suggested by Soumith
* Update checkpoint URL
* Move warnings up
* Record a couple more function parameters
* Update comment
* Set the correct version such that if the BC-patched model is saved, it could be reloaded with BC patching again
* Set a member var, not class var
* Update mnasnet.py: Remove unused member var as per review.
* Update the path to weights
1 parent 3394c0f commit 367e851

File tree

1 file changed

+70
-19
lines changed

1 file changed

+70
-19
lines changed

torchvision/models/mnasnet.py

Lines changed: 70 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
import math
2+
import warnings
23

34
import torch
45
import torch.nn as nn
@@ -8,7 +9,7 @@
89

910
_MODEL_URLS = {
1011
"mnasnet0_5":
11-
"https://download.pytorch.org/models/mnasnet0.5_top1_67.592-7c6cb539b9.pth",
12+
"https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth",
1213
"mnasnet0_75": None,
1314
"mnasnet1_0":
1415
"https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth",
@@ -74,14 +75,16 @@ def _round_to_multiple_of(val, divisor, round_up_bias=0.9):
7475
return new_val if new_val >= round_up_bias * val else new_val + divisor
7576

7677

77-
def _scale_depths(depths, alpha):
78+
def _get_depths(alpha):
7879
""" Scales tensor depths as in reference MobileNet code, prefers rounding up
7980
rather than down. """
81+
depths = [32, 16, 24, 40, 80, 96, 192, 320]
8082
return [_round_to_multiple_of(depth * alpha, 8) for depth in depths]
8183

8284

8385
class MNASNet(torch.nn.Module):
84-
""" MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf.
86+
""" MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. This
87+
implements the B1 variant of the model.
8588
>>> model = MNASNet(1000, 1.0)
8689
>>> x = torch.rand(1, 3, 224, 224)
8790
>>> y = model(x)
@@ -90,30 +93,36 @@ class MNASNet(torch.nn.Module):
9093
>>> y.nelement()
9194
1000
9295
"""
96+
# Version 2 adds depth scaling in the initial stages of the network.
97+
_version = 2
9398

9499
def __init__(self, alpha, num_classes=1000, dropout=0.2):
95100
super(MNASNet, self).__init__()
96-
depths = _scale_depths([24, 40, 80, 96, 192, 320], alpha)
101+
assert alpha > 0.0
102+
self.alpha = alpha
103+
self.num_classes = num_classes
104+
depths = _get_depths(alpha)
97105
layers = [
98106
# First layer: regular conv.
99-
nn.Conv2d(3, 32, 3, padding=1, stride=2, bias=False),
100-
nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
107+
nn.Conv2d(3, depths[0], 3, padding=1, stride=2, bias=False),
108+
nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM),
101109
nn.ReLU(inplace=True),
102110
# Depthwise separable, no skip.
103-
nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32, bias=False),
104-
nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
111+
nn.Conv2d(depths[0], depths[0], 3, padding=1, stride=1,
112+
groups=depths[0], bias=False),
113+
nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM),
105114
nn.ReLU(inplace=True),
106-
nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False),
107-
nn.BatchNorm2d(16, momentum=_BN_MOMENTUM),
115+
nn.Conv2d(depths[0], depths[1], 1, padding=0, stride=1, bias=False),
116+
nn.BatchNorm2d(depths[1], momentum=_BN_MOMENTUM),
108117
# MNASNet blocks: stacks of inverted residuals.
109-
_stack(16, depths[0], 3, 2, 3, 3, _BN_MOMENTUM),
110-
_stack(depths[0], depths[1], 5, 2, 3, 3, _BN_MOMENTUM),
111-
_stack(depths[1], depths[2], 5, 2, 6, 3, _BN_MOMENTUM),
112-
_stack(depths[2], depths[3], 3, 1, 6, 2, _BN_MOMENTUM),
113-
_stack(depths[3], depths[4], 5, 2, 6, 4, _BN_MOMENTUM),
114-
_stack(depths[4], depths[5], 3, 1, 6, 1, _BN_MOMENTUM),
118+
_stack(depths[1], depths[2], 3, 2, 3, 3, _BN_MOMENTUM),
119+
_stack(depths[2], depths[3], 5, 2, 3, 3, _BN_MOMENTUM),
120+
_stack(depths[3], depths[4], 5, 2, 6, 3, _BN_MOMENTUM),
121+
_stack(depths[4], depths[5], 3, 1, 6, 2, _BN_MOMENTUM),
122+
_stack(depths[5], depths[6], 5, 2, 6, 4, _BN_MOMENTUM),
123+
_stack(depths[6], depths[7], 3, 1, 6, 1, _BN_MOMENTUM),
115124
# Final mapping to classifier input.
116-
nn.Conv2d(depths[5], 1280, 1, padding=0, stride=1, bias=False),
125+
nn.Conv2d(depths[7], 1280, 1, padding=0, stride=1, bias=False),
117126
nn.BatchNorm2d(1280, momentum=_BN_MOMENTUM),
118127
nn.ReLU(inplace=True),
119128
]
@@ -139,16 +148,58 @@ def _initialize_weights(self):
139148
nn.init.ones_(m.weight)
140149
nn.init.zeros_(m.bias)
141150
elif isinstance(m, nn.Linear):
142-
nn.init.normal_(m.weight, 0.01)
151+
nn.init.kaiming_uniform_(m.weight, mode="fan_out",
152+
nonlinearity="sigmoid")
143153
nn.init.zeros_(m.bias)
144154

155+
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
156+
missing_keys, unexpected_keys, error_msgs):
157+
version = local_metadata.get("version", None)
158+
assert version in [1, 2]
159+
160+
if version == 1 and not self.alpha == 1.0:
161+
# In the initial version of the model (v1), stem was fixed-size.
162+
# All other layer configurations were the same. This will patch
163+
# the model so that it's identical to v1. Model with alpha 1.0 is
164+
# unaffected.
165+
depths = _get_depths(self.alpha)
166+
v1_stem = [
167+
nn.Conv2d(3, 32, 3, padding=1, stride=2, bias=False),
168+
nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
169+
nn.ReLU(inplace=True),
170+
nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32,
171+
bias=False),
172+
nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
173+
nn.ReLU(inplace=True),
174+
nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False),
175+
nn.BatchNorm2d(16, momentum=_BN_MOMENTUM),
176+
_stack(16, depths[2], 3, 2, 3, 3, _BN_MOMENTUM),
177+
]
178+
for idx, layer in enumerate(v1_stem):
179+
self.layers[idx] = layer
180+
181+
# The model is now identical to v1, and must be saved as such.
182+
self._version = 1
183+
warnings.warn(
184+
"A new version of MNASNet model has been implemented. "
185+
"Your checkpoint was saved using the previous version. "
186+
"This checkpoint will load and work as before, but "
187+
"you may want to upgrade by training a newer model or "
188+
"transfer learning from an updated ImageNet checkpoint.",
189+
UserWarning)
190+
191+
super(MNASNet, self)._load_from_state_dict(
192+
state_dict, prefix, local_metadata, strict, missing_keys,
193+
unexpected_keys, error_msgs)
194+
145195

146196
def _load_pretrained(model_name, model, progress):
147197
if model_name not in _MODEL_URLS or _MODEL_URLS[model_name] is None:
148198
raise ValueError(
149199
"No checkpoint is available for model type {}".format(model_name))
150200
checkpoint_url = _MODEL_URLS[model_name]
151-
model.load_state_dict(load_state_dict_from_url(checkpoint_url, progress=progress))
201+
model.load_state_dict(
202+
load_state_dict_from_url(checkpoint_url, progress=progress))
152203

153204

154205
def mnasnet0_5(pretrained=False, progress=True, **kwargs):

0 commit comments

Comments
 (0)