From 50bfbe640ec1834e013295df2527c5c66d99b12f Mon Sep 17 00:00:00 2001 From: 1e100 <38598618+1e100@users.noreply.github.com> Date: Mon, 1 Apr 2019 23:12:25 -0700 Subject: [PATCH 01/18] Add initial mnasnet impl --- torchvision/models/mnasnet.py | 167 ++++++++++++++++++++++++++++++++++ 1 file changed, 167 insertions(+) create mode 100644 torchvision/models/mnasnet.py diff --git a/torchvision/models/mnasnet.py b/torchvision/models/mnasnet.py new file mode 100644 index 00000000000..9d6b7f714a7 --- /dev/null +++ b/torchvision/models/mnasnet.py @@ -0,0 +1,167 @@ +import math + +import torch +import torch.nn as nn + +# Paper suggests 0.9997 momentum, for TensFlow. Equivalent PyTorch +# momentum is 1.0 - tensorflow. +_BN_MOMENTUM = 1 - 0.9997 + +class _InvertedResidual(nn.Module): + + def __init__(self, in_ch: int, out_ch: int, kernel_size: int, stride: int, + expansion_factor: int, bn_momentum: float = 0.1) -> None: + super().__init__() + assert stride in [1, 2] + assert kernel_size in [3, 5] + mid_ch = in_ch * expansion_factor + self.apply_residual = (in_ch == out_ch and stride == 1) + self.layers = nn.Sequential( + # Pointwise + nn.Conv2d(in_ch, mid_ch, 1, bias=False), + nn.BatchNorm2d(mid_ch, momentum=bn_momentum), + nn.ReLU(inplace=True), + # Depthwise + nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=kernel_size // 2, + stride=stride, groups=mid_ch, bias=False), + nn.BatchNorm2d(mid_ch, momentum=bn_momentum), + nn.ReLU(inplace=True), + # Linear pointwise. Note that there's no activation. + nn.Conv2d(mid_ch, out_ch, 1, bias=False), + nn.BatchNorm2d(out_ch, momentum=bn_momentum)) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if self.apply_residual: + return self.layers.forward(input) + input + else: + return self.layers.forward(input) + + +def _stack(in_ch: int, out_ch: int, kernel_size: int, stride: int, + exp_factor: int, repeats: int, bn_momentum: float) -> nn.Sequential: + """ Creates a stack of inverted residuals as seen in e.g. MobileNetV2 or + MNasNet. """ + assert repeats >= 1 + # First one has no skip, because feature map size changes. + first = InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor, + bn_momentum=bn_momentum) + remaining = [] + for _ in range(1, repeats): + remaining.append( + InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor, + bn_momentum=bn_momentum)) + return nn.Sequential(first, *remaining) + + +def _round_to_multiple_of(val: float, divisor: int, + round_up_bias: float = 0.9) -> int: + """ Asymmetric rounding to make `val` divisible by `divisor`. With default + bias, will round up, unless the number is no more than 10% greater than the + smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88. """ + assert 0.0 < round_up_bias < 1.0 + new_val = max(divisor, int(val + divisor / 2) // divisor * divisor) + return new_val if new_val >= round_up_bias * val else new_val + divisor + + +def _scale_depths(depths: List[int], alpha: float) -> List[int]: + """ Scales tensor depths as in reference MobileNet code, prefers rouding up + rather than down. """ + return [_round_to_multiple_of(depth * alpha, 8) for depth in depths] + + +class MNasNet(torch.nn.Module): + """ MNasNet, as described in https://arxiv.org/pdf/1807.11626.pdf. 
+ >>> model = MNasNet(1000, 1.0) + >>> x = torch.rand(1, 3, 224, 224) + >>> y = model.forward(x) + >>> y.dim() + 1 + >>> y.nelement() + 1000 + """ + + def __init__(self, num_classes: int, alpha: float, dropout:float=0.2) -> None: + super().__init__() + self.alpha = alpha + self.num_classes = num_classes + depths = _scale_depths([24, 40, 80, 96, 192, 320], alpha) + layers = [ + # First layer: regular conv. + nn.Conv2d(3, 32, 3, padding=1, stride=2, bias=False), + nn.BatchNorm2d(32, momentum=_BN_MOMENTUM), + nn.ReLU(inplace=True), + # Depthwise separable, no skip. + nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32, bias=False), + nn.BatchNorm2d(32, momentum=_BN_MOMENTUM), + nn.ReLU(inplace=True), + nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False), + nn.BatchNorm2d(16, momentum=_BN_MOMENTUM), + # MNasNet blocks: stacks of inverted residuals. + _stack(16, depths[0], 3, 2, 3, 3, _BN_MOMENTUM), + _stack(depths[0], depths[1], 5, 2, 3, 3, _BN_MOMENTUM), + _stack(depths[1], depths[2], 5, 2, 6, 3, _BN_MOMENTUM), + _stack(depths[2], depths[3], 3, 1, 6, 2, _BN_MOMENTUM), + _stack(depths[3], depths[4], 5, 2, 6, 4, _BN_MOMENTUM), + _stack(depths[4], depths[5], 3, 1, 6, 1, _BN_MOMENTUM), + # Final mapping to classifier input. + nn.Conv2d(depths[5], 1280, 1, padding=0, stride=1, bias=False), + nn.BatchNorm2d(1280, momentum=_BN_MOMENTUM), + nn.ReLU(inplace=True), + nn.AdaptiveAvgPool2d(1) + ] + self.layers = nn.Sequential(*layers) + if dropout > 0.0: + self.classifier = nn.Sequential( + nn.Dropout(inplace=True, p=0.2), nn.Linear(1280, self.num_classes)) + else: + self.classifier = nn.Linear(1280, self.num_classes) + + self._initialize_weights() + + def features(self, x): + return self.layers.forward(x).squeeze() + + def forward(self, x): + return self.classifier(self.features(x)) + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2.0 / n)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + n = m.weight.size(1) + m.weight.data.normal_(0, 0.01) + m.bias.data.zero_() + +class MNasNet0_5(MNasNet): + """ MNasNet with depth multiplier of 0.5. """ + + def __init__(self, num_classes: int) -> None: + super().__init__(num_classes, 0.5) + +class MNasNet0_75(MNasNet): + """ MNasNet with depth multiplier of 0.75. """ + + def __init__(self, num_classes: int) -> None: + super().__init__(num_classes, 0.75) + +class MNasNet1_0(MNasNet): + """ MNasNet with depth multiplier of 1.0. """ + + def __init__(self, num_classes: int) -> None: + super().__init__(num_classes, 1.0) + + +class MNasNet1_3(MNasNet): + """ MNasNet with depth multiplier of 1.3. 
""" + + def __init__(self, num_classes: int) -> None: + super().__init__(num_classes, 1.3) + + From e1c55063be9ba6dff47f28588cde9739915d4613 Mon Sep 17 00:00:00 2001 From: 1e100 <38598618+1e100@users.noreply.github.com> Date: Mon, 1 Apr 2019 23:25:47 -0700 Subject: [PATCH 02/18] Remove all type hints, comply with PyTorch overall style --- torchvision/models/mnasnet.py | 77 +++++++++++++++++++---------------- 1 file changed, 42 insertions(+), 35 deletions(-) diff --git a/torchvision/models/mnasnet.py b/torchvision/models/mnasnet.py index 9d6b7f714a7..421016b82b8 100644 --- a/torchvision/models/mnasnet.py +++ b/torchvision/models/mnasnet.py @@ -3,14 +3,20 @@ import torch import torch.nn as nn -# Paper suggests 0.9997 momentum, for TensFlow. Equivalent PyTorch -# momentum is 1.0 - tensorflow. + +__all__ = ['MNASNet', 'MNASNet0_5', 'MNASNet0_75', 'MNASNet1_0', 'MNASNet1_3'] + +# Paper suggests 0.9997 momentum, for TensFlow. Equivalent PyTorch momentum is +# 1.0 - tensorflow. _BN_MOMENTUM = 1 - 0.9997 + class _InvertedResidual(nn.Module): + """ Inverted residual block from MobileNetV2 and MNASNet papers. This can + be used to implement MobileNet V2, if ReLU is replaced with ReLU6. """ - def __init__(self, in_ch: int, out_ch: int, kernel_size: int, stride: int, - expansion_factor: int, bn_momentum: float = 0.1) -> None: + def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor, + bn_momentum=0.1): super().__init__() assert stride in [1, 2] assert kernel_size in [3, 5] @@ -30,31 +36,30 @@ def __init__(self, in_ch: int, out_ch: int, kernel_size: int, stride: int, nn.Conv2d(mid_ch, out_ch, 1, bias=False), nn.BatchNorm2d(out_ch, momentum=bn_momentum)) - def forward(self, input: torch.Tensor) -> torch.Tensor: + def forward(self, input): if self.apply_residual: return self.layers.forward(input) + input else: return self.layers.forward(input) -def _stack(in_ch: int, out_ch: int, kernel_size: int, stride: int, - exp_factor: int, repeats: int, bn_momentum: float) -> nn.Sequential: +def _stack(in_ch, out_ch, kernel_size, stride, exp_factor, repeats, + bn_momentum): """ Creates a stack of inverted residuals as seen in e.g. MobileNetV2 or - MNasNet. """ + MNASNet. """ assert repeats >= 1 # First one has no skip, because feature map size changes. - first = InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor, - bn_momentum=bn_momentum) + first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor, + bn_momentum=bn_momentum) remaining = [] for _ in range(1, repeats): remaining.append( - InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor, - bn_momentum=bn_momentum)) + _InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor, + bn_momentum=bn_momentum)) return nn.Sequential(first, *remaining) -def _round_to_multiple_of(val: float, divisor: int, - round_up_bias: float = 0.9) -> int: +def _round_to_multiple_of(val, divisor, round_up_bias=0.9): """ Asymmetric rounding to make `val` divisible by `divisor`. With default bias, will round up, unless the number is no more than 10% greater than the smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88. """ @@ -63,15 +68,15 @@ def _round_to_multiple_of(val: float, divisor: int, return new_val if new_val >= round_up_bias * val else new_val + divisor -def _scale_depths(depths: List[int], alpha: float) -> List[int]: +def _scale_depths(depths, alpha): """ Scales tensor depths as in reference MobileNet code, prefers rouding up rather than down. 
""" return [_round_to_multiple_of(depth * alpha, 8) for depth in depths] -class MNasNet(torch.nn.Module): - """ MNasNet, as described in https://arxiv.org/pdf/1807.11626.pdf. - >>> model = MNasNet(1000, 1.0) +class MNASNet(torch.nn.Module): + """ MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. + >>> model = MNASNet(1000, 1.0) >>> x = torch.rand(1, 3, 224, 224) >>> y = model.forward(x) >>> y.dim() @@ -80,7 +85,7 @@ class MNasNet(torch.nn.Module): 1000 """ - def __init__(self, num_classes: int, alpha: float, dropout:float=0.2) -> None: + def __init__(self, num_classes, alpha, dropout=0.2): super().__init__() self.alpha = alpha self.num_classes = num_classes @@ -96,7 +101,7 @@ def __init__(self, num_classes: int, alpha: float, dropout:float=0.2) -> None: nn.ReLU(inplace=True), nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False), nn.BatchNorm2d(16, momentum=_BN_MOMENTUM), - # MNasNet blocks: stacks of inverted residuals. + # MNASNet blocks: stacks of inverted residuals. _stack(16, depths[0], 3, 2, 3, 3, _BN_MOMENTUM), _stack(depths[0], depths[1], 5, 2, 3, 3, _BN_MOMENTUM), _stack(depths[1], depths[2], 5, 2, 6, 3, _BN_MOMENTUM), @@ -112,7 +117,8 @@ def __init__(self, num_classes: int, alpha: float, dropout:float=0.2) -> None: self.layers = nn.Sequential(*layers) if dropout > 0.0: self.classifier = nn.Sequential( - nn.Dropout(inplace=True, p=0.2), nn.Linear(1280, self.num_classes)) + nn.Dropout(inplace=True, p=0.2), + nn.Linear(1280, self.num_classes)) else: self.classifier = nn.Linear(1280, self.num_classes) @@ -139,29 +145,30 @@ def _initialize_weights(self): m.weight.data.normal_(0, 0.01) m.bias.data.zero_() -class MNasNet0_5(MNasNet): - """ MNasNet with depth multiplier of 0.5. """ - def __init__(self, num_classes: int) -> None: +class MNASNet0_5(MNASNet): + """ MNASNet with depth multiplier of 0.5. """ + + def __init__(self, num_classes): super().__init__(num_classes, 0.5) -class MNasNet0_75(MNasNet): - """ MNasNet with depth multiplier of 0.75. """ - def __init__(self, num_classes: int) -> None: +class MNASNet0_75(MNASNet): + """ MNASNet with depth multiplier of 0.75. """ + + def __init__(self, num_classes): super().__init__(num_classes, 0.75) -class MNasNet1_0(MNasNet): - """ MNasNet with depth multiplier of 1.0. """ - def __init__(self, num_classes: int) -> None: +class MNASNet1_0(MNASNet): + """ MNASNet with depth multiplier of 1.0. """ + + def __init__(self, num_classes): super().__init__(num_classes, 1.0) -class MNasNet1_3(MNasNet): - """ MNasNet with depth multiplier of 1.3. """ +class MNASNet1_3(MNASNet): + """ MNASNet with depth multiplier of 1.3. 
""" - def __init__(self, num_classes: int) -> None: + def __init__(self, num_classes): super().__init__(num_classes, 1.3) - - From 0d77accb5fcd88be9823ad2ce37ffe02ca71191f Mon Sep 17 00:00:00 2001 From: 1e100 <38598618+1e100@users.noreply.github.com> Date: Tue, 2 Apr 2019 01:03:23 -0700 Subject: [PATCH 03/18] Expose models --- torchvision/models/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py index 7437c51597f..839f45301ee 100644 --- a/torchvision/models/__init__.py +++ b/torchvision/models/__init__.py @@ -5,3 +5,4 @@ from .inception import * from .densenet import * from .googlenet import * +from .mnasnet import * From c41aaab80fec8f2ddc67e3372db8e362e0aa270d Mon Sep 17 00:00:00 2001 From: 1e100 <38598618+1e100@users.noreply.github.com> Date: Tue, 2 Apr 2019 01:36:45 -0700 Subject: [PATCH 04/18] Remove avgpool from features() and add separately --- torchvision/models/mnasnet.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torchvision/models/mnasnet.py b/torchvision/models/mnasnet.py index 421016b82b8..413b71c26c3 100644 --- a/torchvision/models/mnasnet.py +++ b/torchvision/models/mnasnet.py @@ -112,9 +112,9 @@ def __init__(self, num_classes, alpha, dropout=0.2): nn.Conv2d(depths[5], 1280, 1, padding=0, stride=1, bias=False), nn.BatchNorm2d(1280, momentum=_BN_MOMENTUM), nn.ReLU(inplace=True), - nn.AdaptiveAvgPool2d(1) ] self.layers = nn.Sequential(*layers) + self.avgpool = nn.AdaptiveAvgPool2d(1) if dropout > 0.0: self.classifier = nn.Sequential( nn.Dropout(inplace=True, p=0.2), @@ -125,10 +125,12 @@ def __init__(self, num_classes, alpha, dropout=0.2): self._initialize_weights() def features(self, x): - return self.layers.forward(x).squeeze() + return self.layers.forward(x) def forward(self, x): - return self.classifier(self.features(x)) + x = self.features(x) + x = self.avgpool(x).squeeze() + return self.classifier(x) def _initialize_weights(self): for m in self.modules(): From 568bd5083c7552884ef608ed08ace16a4d21ef74 Mon Sep 17 00:00:00 2001 From: 1e100 <38598618+1e100@users.noreply.github.com> Date: Fri, 12 Apr 2019 18:31:33 -0700 Subject: [PATCH 05/18] Fix python3-only stuff, replace subclasses with functions --- torchvision/models/mnasnet.py | 28 ++++++++++------------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/torchvision/models/mnasnet.py b/torchvision/models/mnasnet.py index 413b71c26c3..364c8d363af 100644 --- a/torchvision/models/mnasnet.py +++ b/torchvision/models/mnasnet.py @@ -17,7 +17,7 @@ class _InvertedResidual(nn.Module): def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor, bn_momentum=0.1): - super().__init__() + super(_InvertedResidual, self).__init__() assert stride in [1, 2] assert kernel_size in [3, 5] mid_ch = in_ch * expansion_factor @@ -86,7 +86,7 @@ class MNASNet(torch.nn.Module): """ def __init__(self, num_classes, alpha, dropout=0.2): - super().__init__() + super(MNASNet, self).__init__() self.alpha = alpha self.num_classes = num_classes depths = _scale_depths([24, 40, 80, 96, 192, 320], alpha) @@ -148,29 +148,21 @@ def _initialize_weights(self): m.bias.data.zero_() -class MNASNet0_5(MNASNet): +def mnasnet0_5(num_classes): """ MNASNet with depth multiplier of 0.5. """ + return MNASNet(num_classes, alpha=0.5) - def __init__(self, num_classes): - super().__init__(num_classes, 0.5) - -class MNASNet0_75(MNASNet): +def mnasnet0_75(num_classes): """ MNASNet with depth multiplier of 0.75. 
""" - - def __init__(self, num_classes): - super().__init__(num_classes, 0.75) + return MNASNet(num_classes, alpha=0.75) -class MNASNet1_0(MNASNet): +def mnasnet1_0(num_classes): """ MNASNet with depth multiplier of 1.0. """ + return MNASNet(num_classes, alpha=1.0) - def __init__(self, num_classes): - super().__init__(num_classes, 1.0) - -class MNASNet1_3(MNASNet): +def mnasnet1_3(num_classes): """ MNASNet with depth multiplier of 1.3. """ - - def __init__(self, num_classes): - super().__init__(num_classes, 1.3) + return MNASNet(num_classes, alpha=1.3) From 5617b8e38fae97166d001af4d19f4b20b5e954a3 Mon Sep 17 00:00:00 2001 From: 1e100 <38598618+1e100@users.noreply.github.com> Date: Fri, 12 Apr 2019 18:34:22 -0700 Subject: [PATCH 06/18] fix __all__ --- torchvision/models/mnasnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/mnasnet.py b/torchvision/models/mnasnet.py index 364c8d363af..695d1830f76 100644 --- a/torchvision/models/mnasnet.py +++ b/torchvision/models/mnasnet.py @@ -4,7 +4,7 @@ import torch.nn as nn -__all__ = ['MNASNet', 'MNASNet0_5', 'MNASNet0_75', 'MNASNet1_0', 'MNASNet1_3'] +__all__ = ['MNASNet', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3'] # Paper suggests 0.9997 momentum, for TensFlow. Equivalent PyTorch momentum is # 1.0 - tensorflow. From ba0ad4d80512bd357fccc4a7fe3aa8fd87ac4aa2 Mon Sep 17 00:00:00 2001 From: Dmitry Belenko <38598618+1e100@users.noreply.github.com> Date: Sat, 13 Apr 2019 01:47:06 -0700 Subject: [PATCH 07/18] Fix typo --- torchvision/models/mnasnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/mnasnet.py b/torchvision/models/mnasnet.py index 695d1830f76..f235d12989b 100644 --- a/torchvision/models/mnasnet.py +++ b/torchvision/models/mnasnet.py @@ -6,7 +6,7 @@ __all__ = ['MNASNet', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3'] -# Paper suggests 0.9997 momentum, for TensFlow. Equivalent PyTorch momentum is +# Paper suggests 0.9997 momentum, for TensorFlow. Equivalent PyTorch momentum is # 1.0 - tensorflow. 
_BN_MOMENTUM = 1 - 0.9997 From bd4836b6746331115388bcb07415e7d3d9dfc684 Mon Sep 17 00:00:00 2001 From: 1e100 <38598618+1e100@users.noreply.github.com> Date: Sat, 13 Apr 2019 19:03:10 -0700 Subject: [PATCH 08/18] Remove conditional dropout --- torchvision/models/mnasnet.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/torchvision/models/mnasnet.py b/torchvision/models/mnasnet.py index 695d1830f76..cd3538032dd 100644 --- a/torchvision/models/mnasnet.py +++ b/torchvision/models/mnasnet.py @@ -115,12 +115,9 @@ def __init__(self, num_classes, alpha, dropout=0.2): ] self.layers = nn.Sequential(*layers) self.avgpool = nn.AdaptiveAvgPool2d(1) - if dropout > 0.0: - self.classifier = nn.Sequential( - nn.Dropout(inplace=True, p=0.2), - nn.Linear(1280, self.num_classes)) - else: - self.classifier = nn.Linear(1280, self.num_classes) + self.classifier = nn.Sequential( + nn.Dropout(inplace=True, p=dropout), + nn.Linear(1280, self.num_classes)) self._initialize_weights() From 102ba553b7b075675bdbc983b0e5ba28e73186fe Mon Sep 17 00:00:00 2001 From: 1e100 <38598618+1e100@users.noreply.github.com> Date: Sun, 14 Apr 2019 20:48:10 -0700 Subject: [PATCH 09/18] Make dropout functional --- torchvision/models/mnasnet.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/torchvision/models/mnasnet.py b/torchvision/models/mnasnet.py index 85c2808d768..41855582311 100644 --- a/torchvision/models/mnasnet.py +++ b/torchvision/models/mnasnet.py @@ -3,7 +3,6 @@ import torch import torch.nn as nn - __all__ = ['MNASNet', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3'] # Paper suggests 0.9997 momentum, for TensorFlow. Equivalent PyTorch momentum is @@ -89,6 +88,7 @@ def __init__(self, num_classes, alpha, dropout=0.2): super(MNASNet, self).__init__() self.alpha = alpha self.num_classes = num_classes + self.dropout = dropout depths = _scale_depths([24, 40, 80, 96, 192, 320], alpha) layers = [ # First layer: regular conv. @@ -115,9 +115,7 @@ def __init__(self, num_classes, alpha, dropout=0.2): ] self.layers = nn.Sequential(*layers) self.avgpool = nn.AdaptiveAvgPool2d(1) - self.classifier = nn.Sequential( - nn.Dropout(inplace=True, p=dropout), - nn.Linear(1280, self.num_classes)) + self.classifier = nn.Linear(1280, self.num_classes) self._initialize_weights() @@ -127,6 +125,9 @@ def features(self, x): def forward(self, x): x = self.features(x) x = self.avgpool(x).squeeze() + if self.dropout > 0.0: + x = nn.functional.dropout(x, p=self.dropout, training=self.training, + inplace=True) return self.classifier(x) def _initialize_weights(self): From 9c8b827b37e7c18f4a44ba9740fd771bd6a6d122 Mon Sep 17 00:00:00 2001 From: 1e100 <38598618+1e100@users.noreply.github.com> Date: Mon, 15 Apr 2019 23:00:19 -0700 Subject: [PATCH 10/18] Addressing @fmassa's feedback, round 1 --- torchvision/models/mnasnet.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/torchvision/models/mnasnet.py b/torchvision/models/mnasnet.py index 41855582311..c37c73a67e6 100644 --- a/torchvision/models/mnasnet.py +++ b/torchvision/models/mnasnet.py @@ -11,8 +11,6 @@ class _InvertedResidual(nn.Module): - """ Inverted residual block from MobileNetV2 and MNASNet papers. This can - be used to implement MobileNet V2, if ReLU is replaced with ReLU6. 
""" def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor, bn_momentum=0.1): @@ -37,15 +35,14 @@ def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor, def forward(self, input): if self.apply_residual: - return self.layers.forward(input) + input + return self.layers(input) + input else: - return self.layers.forward(input) + return self.layers(input) def _stack(in_ch, out_ch, kernel_size, stride, exp_factor, repeats, bn_momentum): - """ Creates a stack of inverted residuals as seen in e.g. MobileNetV2 or - MNASNet. """ + """ Creates a stack of inverted residuals. """ assert repeats >= 1 # First one has no skip, because feature map size changes. first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor, From 2872b1fa9ee6c846a21c4b6ec0c71649036f51e6 Mon Sep 17 00:00:00 2001 From: 1e100 <38598618+1e100@users.noreply.github.com> Date: Mon, 15 Apr 2019 23:07:49 -0700 Subject: [PATCH 11/18] Replaced adaptive avgpool with mean on H and W to prevent collapsing the batch dimension --- torchvision/models/mnasnet.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchvision/models/mnasnet.py b/torchvision/models/mnasnet.py index c37c73a67e6..42591e77868 100644 --- a/torchvision/models/mnasnet.py +++ b/torchvision/models/mnasnet.py @@ -111,17 +111,17 @@ def __init__(self, num_classes, alpha, dropout=0.2): nn.ReLU(inplace=True), ] self.layers = nn.Sequential(*layers) - self.avgpool = nn.AdaptiveAvgPool2d(1) self.classifier = nn.Linear(1280, self.num_classes) self._initialize_weights() def features(self, x): - return self.layers.forward(x) + return self.layers(x) def forward(self, x): x = self.features(x) - x = self.avgpool(x).squeeze() + # Equivalent to global avgpool and removing H and W dimensions. + x = x.mean([2, 3]) if self.dropout > 0.0: x = nn.functional.dropout(x, p=self.dropout, training=self.training, inplace=True) From 05b387b8bda3e686485890a8a55f8bf5e3484439 Mon Sep 17 00:00:00 2001 From: 1e100 <38598618+1e100@users.noreply.github.com> Date: Fri, 3 May 2019 03:40:52 -0700 Subject: [PATCH 12/18] Partially address feedback --- torchvision/models/mnasnet.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/torchvision/models/mnasnet.py b/torchvision/models/mnasnet.py index 42591e77868..7e41a505750 100644 --- a/torchvision/models/mnasnet.py +++ b/torchvision/models/mnasnet.py @@ -74,7 +74,7 @@ class MNASNet(torch.nn.Module): """ MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. >>> model = MNASNet(1000, 1.0) >>> x = torch.rand(1, 3, 224, 224) - >>> y = model.forward(x) + >>> y = model(x) >>> y.dim() 1 >>> y.nelement() @@ -111,20 +111,14 @@ def __init__(self, num_classes, alpha, dropout=0.2): nn.ReLU(inplace=True), ] self.layers = nn.Sequential(*layers) - self.classifier = nn.Linear(1280, self.num_classes) - + self.classifier = nn.Sequential(nn.Dropout(p=self.dropout, inplace=True), + nn.Linear(1280, self.num_classes)) self._initialize_weights() - def features(self, x): - return self.layers(x) - def forward(self, x): - x = self.features(x) + x = self.layers(x) # Equivalent to global avgpool and removing H and W dimensions. 
x = x.mean([2, 3]) - if self.dropout > 0.0: - x = nn.functional.dropout(x, p=self.dropout, training=self.training, - inplace=True) return self.classifier(x) def _initialize_weights(self): From 2d397976c000548dca28d8295d63b4ee12ded32c Mon Sep 17 00:00:00 2001 From: 1e100 <38598618+1e100@users.noreply.github.com> Date: Fri, 3 May 2019 03:41:54 -0700 Subject: [PATCH 13/18] YAPF --- torchvision/models/mnasnet.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchvision/models/mnasnet.py b/torchvision/models/mnasnet.py index 7e41a505750..9405dd844ca 100644 --- a/torchvision/models/mnasnet.py +++ b/torchvision/models/mnasnet.py @@ -111,8 +111,9 @@ def __init__(self, num_classes, alpha, dropout=0.2): nn.ReLU(inplace=True), ] self.layers = nn.Sequential(*layers) - self.classifier = nn.Sequential(nn.Dropout(p=self.dropout, inplace=True), - nn.Linear(1280, self.num_classes)) + self.classifier = nn.Sequential( + nn.Dropout(p=self.dropout, inplace=True), + nn.Linear(1280, self.num_classes)) self._initialize_weights() def forward(self, x): From 8b5f7b91d31ff9e5ab3200d5064387cb5a2094c4 Mon Sep 17 00:00:00 2001 From: 1e100 <38598618+1e100@users.noreply.github.com> Date: Fri, 3 May 2019 03:46:33 -0700 Subject: [PATCH 14/18] Removed redundant class vars --- torchvision/models/mnasnet.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/torchvision/models/mnasnet.py b/torchvision/models/mnasnet.py index 9405dd844ca..bdb7f000317 100644 --- a/torchvision/models/mnasnet.py +++ b/torchvision/models/mnasnet.py @@ -83,9 +83,6 @@ class MNASNet(torch.nn.Module): def __init__(self, num_classes, alpha, dropout=0.2): super(MNASNet, self).__init__() - self.alpha = alpha - self.num_classes = num_classes - self.dropout = dropout depths = _scale_depths([24, 40, 80, 96, 192, 320], alpha) layers = [ # First layer: regular conv. @@ -112,8 +109,8 @@ def __init__(self, num_classes, alpha, dropout=0.2): ] self.layers = nn.Sequential(*layers) self.classifier = nn.Sequential( - nn.Dropout(p=self.dropout, inplace=True), - nn.Linear(1280, self.num_classes)) + nn.Dropout(p=dropout, inplace=True), + nn.Linear(1280, num_classes)) self._initialize_weights() def forward(self, x): From 40471ac6f090a929eb83136a95a62108bf40bad7 Mon Sep 17 00:00:00 2001 From: 1e100 <38598618+1e100@users.noreply.github.com> Date: Mon, 6 May 2019 02:40:09 -0700 Subject: [PATCH 15/18] Update urls to releases --- torchvision/models/mnasnet.py | 53 +++++++++++++++++++++++++++-------- 1 file changed, 41 insertions(+), 12 deletions(-) diff --git a/torchvision/models/mnasnet.py b/torchvision/models/mnasnet.py index bdb7f000317..ee107d2518e 100644 --- a/torchvision/models/mnasnet.py +++ b/torchvision/models/mnasnet.py @@ -2,9 +2,19 @@ import torch import torch.nn as nn +from .utils import load_state_dict_from_url __all__ = ['MNASNet', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3'] +_MODEL_URLS = { + "mnasnet0_5": + "https://github.com/1e100/mnasnet_trainer/releases/download/v0.1/mnasnet0.5_top1_67.592-7c6cb539b9.pth", + "mnasnet0_75": None, + "mnasnet1_0": + "https://github.com/1e100/mnasnet_trainer/releases/download/v0.1/mnasnet1.0_top1_73.512-f206786ef8.pth", + "mnasnet1_3": None +} + # Paper suggests 0.9997 momentum, for TensorFlow. Equivalent PyTorch momentum is # 1.0 - tensorflow. 
_BN_MOMENTUM = 1 - 0.9997 @@ -81,7 +91,7 @@ class MNASNet(torch.nn.Module): 1000 """ - def __init__(self, num_classes, alpha, dropout=0.2): + def __init__(self, alpha, num_classes=1000, dropout=0.2): super(MNASNet, self).__init__() depths = _scale_depths([24, 40, 80, 96, 192, 320], alpha) layers = [ @@ -108,9 +118,8 @@ def __init__(self, num_classes, alpha, dropout=0.2): nn.ReLU(inplace=True), ] self.layers = nn.Sequential(*layers) - self.classifier = nn.Sequential( - nn.Dropout(p=dropout, inplace=True), - nn.Linear(1280, num_classes)) + self.classifier = nn.Sequential(nn.Dropout(p=dropout, inplace=True), + nn.Linear(1280, num_classes)) self._initialize_weights() def forward(self, x): @@ -135,21 +144,41 @@ def _initialize_weights(self): m.bias.data.zero_() -def mnasnet0_5(num_classes): +def _load_pretrained(model_name, model): + if model_name not in _MODEL_URLS or _MODEL_URLS[model_name] is None: + raise ValueError( + "No checkpoint is available for model type {}".format(model_name)) + checkpoint_url = _MODEL_URLS[model_name] + model.load_state_dict(torch.utils.model_zoo.load_url(checkpoint_url)) + + +def mnasnet0_5(pretrained=False, **kwargs): """ MNASNet with depth multiplier of 0.5. """ - return MNASNet(num_classes, alpha=0.5) + model = MNASNet(0.5, **kwargs) + if pretrained: + _load_pretrained("mnasnet0_5", model) + return model -def mnasnet0_75(num_classes): +def mnasnet0_75(pretrained=False, **kwargs): """ MNASNet with depth multiplier of 0.75. """ - return MNASNet(num_classes, alpha=0.75) + model = MNASNet(0.75, **kwargs) + if pretrained: + _load_pretrained("mnasnet0_75", model) + return model -def mnasnet1_0(num_classes): +def mnasnet1_0(pretrained=False, **kwargs): """ MNASNet with depth multiplier of 1.0. """ - return MNASNet(num_classes, alpha=1.0) + model = MNASNet(1.0, **kwargs) + if pretrained: + _load_pretrained("mnasnet1_0", model) + return model -def mnasnet1_3(num_classes): +def mnasnet1_3(pretrained=False, **kwargs): """ MNASNet with depth multiplier of 1.3. """ - return MNASNet(num_classes, alpha=1.3) + model = MNASNet(1.3, **kwargs) + if pretrained: + _load_pretrained("mnasnet1_3", model) + return model From b1d54ec62b7f5c355f8316e0b2f5a0c63de5812d Mon Sep 17 00:00:00 2001 From: 1e100 <38598618+1e100@users.noreply.github.com> Date: Mon, 6 May 2019 02:51:54 -0700 Subject: [PATCH 16/18] Add information to models.rst --- docs/source/models.rst | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/docs/source/models.rst b/docs/source/models.rst index 66bb60e2004..216c0f9b79b 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -12,6 +12,7 @@ architectures: - `Inception`_ v3 - `GoogLeNet`_ - `ShuffleNet`_ v2 +- `MNASNet`_ You can construct a model with random weights by calling its constructor: @@ -26,6 +27,7 @@ You can construct a model with random weights by calling its constructor: inception = models.inception_v3() googlenet = models.googlenet() shufflenet = models.shufflenetv2() + mnasnet = models.mnasnet1_0() We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`. These can be constructed by passing ``pretrained=True``: @@ -41,6 +43,7 @@ These can be constructed by passing ``pretrained=True``: inception = models.inception_v3(pretrained=True) googlenet = models.googlenet(pretrained=True) shufflenet = models.shufflenetv2(pretrained=True) + mnasnet = models.mnasnet1_0(pretrained=True) Instancing a pre-trained model will download its weights to a cache directory. 
This directory can be set using the `TORCH_MODEL_ZOO` environment variable. See @@ -92,6 +95,7 @@ Densenet-161 22.35 6.20 Inception v3 22.55 6.44 GoogleNet 30.22 10.47 ShuffleNet V2 30.64 11.68 +MNASNet 1.0 26.49 8.456 ================================ ============= ============= @@ -103,6 +107,7 @@ ShuffleNet V2 30.64 11.68 .. _Inception: https://arxiv.org/abs/1512.00567 .. _GoogLeNet: https://arxiv.org/abs/1409.4842 .. _ShuffleNet: https://arxiv.org/abs/1807.11164 +.. _MNASNet: https://arxiv.org/abs/1807.11626 .. currentmodule:: torchvision.models @@ -162,3 +167,10 @@ ShuffleNet v2 .. autofunction:: shufflenet +MNASNet +------------- + +.. autofunction:: mnasnet0_5 +.. autofunction:: mnasnet0_75 +.. autofunction:: mnasnet1_0 +.. autofunction:: mnasnet1_3 From ec717d03042fbf6e84273dfc70f1855efb19a0be Mon Sep 17 00:00:00 2001 From: 1e100 <38598618+1e100@users.noreply.github.com> Date: Sat, 11 May 2019 16:57:50 -0700 Subject: [PATCH 17/18] Replace init with kaiming_normal_ in fan-out mode --- torchvision/models/mnasnet.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/torchvision/models/mnasnet.py b/torchvision/models/mnasnet.py index ee107d2518e..26fc68879c8 100644 --- a/torchvision/models/mnasnet.py +++ b/torchvision/models/mnasnet.py @@ -131,17 +131,16 @@ def forward(self, x): def _initialize_weights(self): for m in self.modules(): if isinstance(m, nn.Conv2d): - n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - m.weight.data.normal_(0, math.sqrt(2.0 / n)) + nn.init.kaiming_normal_(m.weight, mode="fan_out", + nonlinearity="relu") if m.bias is not None: - m.bias.data.zero_() + nn.init.zeros_(m.bias) elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1.0) - m.bias.data.zero_() + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) elif isinstance(m, nn.Linear): - n = m.weight.size(1) - m.weight.data.normal_(0, 0.01) - m.bias.data.zero_() + nn.init.normal_(m.weight, 0.01) + nn.init.zeros_(m.bias) def _load_pretrained(model_name, model): From 8b2dba9d62e25fe0d8c5db9558b348472dc0432c Mon Sep 17 00:00:00 2001 From: 1e100 <38598618+1e100@users.noreply.github.com> Date: Sat, 11 May 2019 17:13:11 -0700 Subject: [PATCH 18/18] Use load_state_dict_from_url --- torchvision/models/mnasnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/mnasnet.py b/torchvision/models/mnasnet.py index 26fc68879c8..5deb87c2ad1 100644 --- a/torchvision/models/mnasnet.py +++ b/torchvision/models/mnasnet.py @@ -148,7 +148,7 @@ def _load_pretrained(model_name, model): raise ValueError( "No checkpoint is available for model type {}".format(model_name)) checkpoint_url = _MODEL_URLS[model_name] - model.load_state_dict(torch.utils.model_zoo.load_url(checkpoint_url)) + model.load_state_dict(load_state_dict_from_url(checkpoint_url)) def mnasnet0_5(pretrained=False, **kwargs):
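
Usage sketch — a minimal, hedged example of exercising the API this patch series adds, assuming the series is merged into torchvision as-is. The entry points mnasnet0_5/mnasnet0_75/mnasnet1_0/mnasnet1_3, the pretrained= flag, and the num_classes/dropout keywords all come from the diffs above; per _MODEL_URLS in PATCH 15, pretrained checkpoints exist only for the 0.5 and 1.0 depth multipliers, and _load_pretrained raises ValueError for the others. Everything else below is standard PyTorch.

    import torch
    import torchvision.models as models

    # The depth multiplier (alpha) is fixed by the entry point; num_classes
    # defaults to 1000 and dropout to 0.2 (see MNASNet.__init__ in PATCH 15).
    model = models.mnasnet1_0(num_classes=1000)

    # Pretrained ImageNet weights are available for mnasnet0_5 and mnasnet1_0:
    # model = models.mnasnet1_0(pretrained=True)

    # BatchNorm note from the implementation: TensorFlow's 0.9997 is the
    # fraction of the running statistic kept per step, whereas PyTorch's
    # momentum weights the *new* observation, hence _BN_MOMENTUM = 1 - 0.9997.

    model.eval()
    x = torch.rand(1, 3, 224, 224)
    with torch.no_grad():
        # conv stem -> inverted-residual stacks -> mean([2, 3]) -> classifier
        y = model(x)
    print(y.shape)  # torch.Size([1, 1000])

    # Channel scaling rounds to multiples of 8 with a 10% round-up bias, as in
    # _round_to_multiple_of: e.g. 83 -> 80 but 84 -> 88.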