diff --git a/torchvision/models/mnasnet.py b/torchvision/models/mnasnet.py
index 774dd8174b0..59677427f1e 100644
--- a/torchvision/models/mnasnet.py
+++ b/torchvision/models/mnasnet.py
@@ -1,4 +1,5 @@
 import math
+import warnings
 
 import torch
 import torch.nn as nn
@@ -8,7 +9,7 @@ _MODEL_URLS = {
     "mnasnet0_5":
-    "https://download.pytorch.org/models/mnasnet0.5_top1_67.592-7c6cb539b9.pth",
+    "https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth",
     "mnasnet0_75": None,
     "mnasnet1_0":
     "https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth",
@@ -74,14 +75,16 @@ def _round_to_multiple_of(val, divisor, round_up_bias=0.9):
     return new_val if new_val >= round_up_bias * val else new_val + divisor
 
 
-def _scale_depths(depths, alpha):
+def _get_depths(alpha):
     """ Scales tensor depths as in reference MobileNet code, prefers rounding up
     rather than down. """
+    depths = [32, 16, 24, 40, 80, 96, 192, 320]
     return [_round_to_multiple_of(depth * alpha, 8) for depth in depths]
 
 
 class MNASNet(torch.nn.Module):
-    """ MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf.
+    """ MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. This
+    implements the B1 variant of the model.
     >>> model = MNASNet(1000, 1.0)
     >>> x = torch.rand(1, 3, 224, 224)
     >>> y = model(x)
@@ -90,30 +93,36 @@ class MNASNet(torch.nn.Module):
     >>> y.nelement()
     1000
     """
+    # Version 2 adds depth scaling in the initial stages of the network.
+    _version = 2
 
     def __init__(self, alpha, num_classes=1000, dropout=0.2):
         super(MNASNet, self).__init__()
-        depths = _scale_depths([24, 40, 80, 96, 192, 320], alpha)
+        assert alpha > 0.0
+        self.alpha = alpha
+        self.num_classes = num_classes
+        depths = _get_depths(alpha)
         layers = [
             # First layer: regular conv.
-            nn.Conv2d(3, 32, 3, padding=1, stride=2, bias=False),
-            nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
+            nn.Conv2d(3, depths[0], 3, padding=1, stride=2, bias=False),
+            nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM),
             nn.ReLU(inplace=True),
             # Depthwise separable, no skip.
-            nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32, bias=False),
-            nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
+            nn.Conv2d(depths[0], depths[0], 3, padding=1, stride=1,
+                      groups=depths[0], bias=False),
+            nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM),
             nn.ReLU(inplace=True),
-            nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False),
-            nn.BatchNorm2d(16, momentum=_BN_MOMENTUM),
+            nn.Conv2d(depths[0], depths[1], 1, padding=0, stride=1, bias=False),
+            nn.BatchNorm2d(depths[1], momentum=_BN_MOMENTUM),
             # MNASNet blocks: stacks of inverted residuals.
-            _stack(16, depths[0], 3, 2, 3, 3, _BN_MOMENTUM),
-            _stack(depths[0], depths[1], 5, 2, 3, 3, _BN_MOMENTUM),
-            _stack(depths[1], depths[2], 5, 2, 6, 3, _BN_MOMENTUM),
-            _stack(depths[2], depths[3], 3, 1, 6, 2, _BN_MOMENTUM),
-            _stack(depths[3], depths[4], 5, 2, 6, 4, _BN_MOMENTUM),
-            _stack(depths[4], depths[5], 3, 1, 6, 1, _BN_MOMENTUM),
+            _stack(depths[1], depths[2], 3, 2, 3, 3, _BN_MOMENTUM),
+            _stack(depths[2], depths[3], 5, 2, 3, 3, _BN_MOMENTUM),
+            _stack(depths[3], depths[4], 5, 2, 6, 3, _BN_MOMENTUM),
+            _stack(depths[4], depths[5], 3, 1, 6, 2, _BN_MOMENTUM),
+            _stack(depths[5], depths[6], 5, 2, 6, 4, _BN_MOMENTUM),
+            _stack(depths[6], depths[7], 3, 1, 6, 1, _BN_MOMENTUM),
             # Final mapping to classifier input.
-            nn.Conv2d(depths[5], 1280, 1, padding=0, stride=1, bias=False),
+            nn.Conv2d(depths[7], 1280, 1, padding=0, stride=1, bias=False),
             nn.BatchNorm2d(1280, momentum=_BN_MOMENTUM),
             nn.ReLU(inplace=True),
         ]
@@ -139,16 +148,58 @@ def _initialize_weights(self):
                 nn.init.ones_(m.weight)
                 nn.init.zeros_(m.bias)
             elif isinstance(m, nn.Linear):
-                nn.init.normal_(m.weight, 0.01)
+                nn.init.kaiming_uniform_(m.weight, mode="fan_out",
+                                         nonlinearity="sigmoid")
                 nn.init.zeros_(m.bias)
 
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        version = local_metadata.get("version", None)
+        assert version in [1, 2]
+
+        if version == 1 and not self.alpha == 1.0:
+            # In the initial version of the model (v1), stem was fixed-size.
+            # All other layer configurations were the same. This will patch
+            # the model so that it's identical to v1. Model with alpha 1.0 is
+            # unaffected.
+            depths = _get_depths(self.alpha)
+            v1_stem = [
+                nn.Conv2d(3, 32, 3, padding=1, stride=2, bias=False),
+                nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
+                nn.ReLU(inplace=True),
+                nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32,
+                          bias=False),
+                nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
+                nn.ReLU(inplace=True),
+                nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False),
+                nn.BatchNorm2d(16, momentum=_BN_MOMENTUM),
+                _stack(16, depths[2], 3, 2, 3, 3, _BN_MOMENTUM),
+            ]
+            for idx, layer in enumerate(v1_stem):
+                self.layers[idx] = layer
+
+            # The model is now identical to v1, and must be saved as such.
+            self._version = 1
+            warnings.warn(
+                "A new version of MNASNet model has been implemented. "
+                "Your checkpoint was saved using the previous version. "
+                "This checkpoint will load and work as before, but "
+                "you may want to upgrade by training a newer model or "
+                "transfer learning from an updated ImageNet checkpoint.",
+                UserWarning)
+
+        super(MNASNet, self)._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict, missing_keys,
+            unexpected_keys, error_msgs)
+
 
 def _load_pretrained(model_name, model, progress):
     if model_name not in _MODEL_URLS or _MODEL_URLS[model_name] is None:
         raise ValueError(
             "No checkpoint is available for model type {}".format(model_name))
     checkpoint_url = _MODEL_URLS[model_name]
-    model.load_state_dict(load_state_dict_from_url(checkpoint_url, progress=progress))
+    model.load_state_dict(
+        load_state_dict_from_url(checkpoint_url, progress=progress))
 
 
 def mnasnet0_5(pretrained=False, progress=True, **kwargs):
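
For reference, the depth scaling that version 2 introduces can be checked in isolation. A minimal sketch follows; the body of _round_to_multiple_of is reproduced from the torchvision source (only its final return line appears as hunk context above), and the printed lists were computed by hand:

    # Standalone sketch of the v2 depth computation. _round_to_multiple_of is
    # taken from torchvision; _get_depths is copied from the diff above.
    def _round_to_multiple_of(val, divisor, round_up_bias=0.9):
        # Round to the nearest multiple of `divisor`, then bump up one step
        # if the rounded value fell more than 10% below `val`.
        assert 0.0 < round_up_bias < 1.0
        new_val = max(divisor, int(val + divisor / 2) // divisor * divisor)
        return new_val if new_val >= round_up_bias * val else new_val + divisor

    def _get_depths(alpha):
        # Stem depths (32, 16) are now scaled along with the rest; the v1
        # implementation kept them fixed, which the compat shim handles.
        depths = [32, 16, 24, 40, 80, 96, 192, 320]
        return [_round_to_multiple_of(depth * alpha, 8) for depth in depths]

    print(_get_depths(1.0))  # [32, 16, 24, 40, 80, 96, 192, 320] -- unchanged
    print(_get_depths(0.5))  # [16, 8, 16, 24, 40, 48, 96, 160]

Because every base depth is already a multiple of 8, alpha == 1.0 reproduces the v1 stem exactly, which is why the loader only patches models with alpha != 1.0.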
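The compatibility shim rides on nn.Module's built-in state-dict versioning: state_dict() records each module's _version in the dict's _metadata attribute, and _load_from_state_dict receives it back as local_metadata. A hypothetical round trip, assuming the patched class is importable as torchvision.models.mnasnet.MNASNet:

    from torchvision.models.mnasnet import MNASNet

    model = MNASNet(alpha=0.75)
    sd = model.state_dict()
    # nn.Module.state_dict stores per-module versions keyed by prefix; ""
    # is the top-level module, so a v2 model records version 2 here.
    print(sd._metadata[""]["version"])  # 2

    # Re-loading a version-2 dict takes the unpatched path. A checkpoint
    # whose metadata says version 1 would, for alpha != 1.0, rebuild the
    # old fixed-size stem in-place and emit the UserWarning defined above.
    model.load_state_dict(sd)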