diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py
index 319f5cf298e..cb0d28ab1a5 100644
--- a/torchvision/models/__init__.py
+++ b/torchvision/models/__init__.py
@@ -4,6 +4,7 @@
 - `AlexNet`_
 - `VGG`_
 - `ResNet`_
+- `SqueezeNet`_
 
 You can construct a model with random weights by calling its constructor:
 
@@ -12,6 +13,7 @@
     import torchvision.models as models
     resnet18 = models.resnet18()
     alexnet = models.alexnet()
+    squeezenet = models.squeezenet1_0()
 
 We provide pre-trained models for the ResNet variants and AlexNet, using the
 PyTorch :mod:`torch.utils.model_zoo`. These can be constructed by passing
@@ -26,8 +28,10 @@
 .. _AlexNet: https://arxiv.org/abs/1404.5997
 .. _VGG: https://arxiv.org/abs/1409.1556
 .. _ResNet: https://arxiv.org/abs/1512.03385
+.. _SqueezeNet: https://arxiv.org/abs/1602.07360
 """
 
 from .alexnet import *
 from .resnet import *
 from .vgg import *
+from .squeezenet import *
diff --git a/torchvision/models/squeezenet.py b/torchvision/models/squeezenet.py
new file mode 100644
index 00000000000..13c7f8ff628
--- /dev/null
+++ b/torchvision/models/squeezenet.py
@@ -0,0 +1,130 @@
+import math
+import torch
+import torch.nn as nn
+import torch.utils.model_zoo as model_zoo
+
+
+__all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1']
+
+
+model_urls = {
+    'squeezenet1_0': 'https://s3.amazonaws.com/pytorch/models/squeezenet1_0-a815701f.pth',
+    'squeezenet1_1': 'https://s3.amazonaws.com/pytorch/models/squeezenet1_1-f364aa15.pth',
+}
+
+
+class Fire(nn.Module):
+    def __init__(self, inplanes, squeeze_planes,
+                 expand1x1_planes, expand3x3_planes):
+        super(Fire, self).__init__()
+        self.inplanes = inplanes
+        self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
+        self.squeeze_activation = nn.ReLU(inplace=True)
+        self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes,
+                                   kernel_size=1)
+        self.expand1x1_activation = nn.ReLU(inplace=True)
+        self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes,
+                                   kernel_size=3, padding=1)
+        self.expand3x3_activation = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        x = self.squeeze_activation(self.squeeze(x))
+        return torch.cat([
+            self.expand1x1_activation(self.expand1x1(x)),
+            self.expand3x3_activation(self.expand3x3(x))
+        ], 1)
+
+
+class SqueezeNet(nn.Module):
+    def __init__(self, version=1.0, num_classes=1000):
+        super(SqueezeNet, self).__init__()
+        if version not in [1.0, 1.1]:
+            raise ValueError("Unsupported SqueezeNet version {version}: "
+                             "1.0 or 1.1 expected".format(version=version))
+        self.num_classes = num_classes
+        if version == 1.0:
+            self.features = nn.Sequential(
+                nn.Conv2d(3, 96, kernel_size=7, stride=2),
+                nn.ReLU(inplace=True),
+                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
+                Fire(96, 16, 64, 64),
+                Fire(128, 16, 64, 64),
+                Fire(128, 32, 128, 128),
+                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
+                Fire(256, 32, 128, 128),
+                Fire(256, 48, 192, 192),
+                Fire(384, 48, 192, 192),
+                Fire(384, 64, 256, 256),
+                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
+                Fire(512, 64, 256, 256),
+            )
+        else:
+            self.features = nn.Sequential(
+                nn.Conv2d(3, 64, kernel_size=3, stride=2),
+                nn.ReLU(inplace=True),
+                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
+                Fire(64, 16, 64, 64),
+                Fire(128, 16, 64, 64),
+                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
+                Fire(128, 32, 128, 128),
+                Fire(256, 32, 128, 128),
+                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
+                Fire(256, 48, 192, 192),
+                Fire(384, 48, 192, 192),
+                Fire(384, 64, 256, 256),
+                Fire(512, 64, 256, 256),
+            )
+        # Final convolution is initialized differently from the rest
+        final_conv = nn.Conv2d(512, num_classes, kernel_size=1)
+        self.classifier = nn.Sequential(
+            nn.Dropout(p=0.5),
+            final_conv,
+            nn.ReLU(inplace=True),
+            nn.AvgPool2d(13)
+        )
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                gain = 2.0
+                if m is final_conv:
+                    m.weight.data.normal_(0, 0.01)
+                else:
+                    fan_in = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
+                    u = math.sqrt(3.0 * gain / fan_in)
+                    m.weight.data.uniform_(-u, u)
+                if m.bias is not None:
+                    m.bias.data.zero_()
+
+    def forward(self, x):
+        x = self.features(x)
+        x = self.classifier(x)
+        return x.view(x.size(0), self.num_classes)
+
+
+def squeezenet1_0(pretrained=False):
+    r"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level
+    accuracy with 50x fewer parameters and <0.5MB model size"
+    <https://arxiv.org/abs/1602.07360>`_ paper.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    """
+    model = SqueezeNet(version=1.0)
+    if pretrained:
+        model.load_state_dict(model_zoo.load_url(model_urls['squeezenet1_0']))
+    return model
+
+
+def squeezenet1_1(pretrained=False):
+    r"""SqueezeNet 1.1 model from the `official SqueezeNet repo
+    <https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>`_.
+    SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters
+    than SqueezeNet 1.0, without sacrificing accuracy.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    """
+    model = SqueezeNet(version=1.1)
+    if pretrained:
+        model.load_state_dict(model_zoo.load_url(model_urls['squeezenet1_1']))
+    return model
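
As a quick smoke test for the new entry points, the snippet below is a minimal sketch that builds both SqueezeNet variants added by this diff and runs a dummy forward pass. The 224x224 input size (needed so the final nn.AvgPool2d(13) sees a 13x13 feature map) and the Variable wrapper are assumptions based on the PyTorch conventions this code targets; pretrained=True additionally assumes the S3 weight URLs in model_urls are reachable.

    import torch
    from torch.autograd import Variable
    import torchvision.models as models

    # Randomly initialized models; no download needed.
    squeezenet1_0 = models.squeezenet1_0()
    squeezenet1_1 = models.squeezenet1_1()

    # Pre-trained weights would be fetched through torch.utils.model_zoo
    # (assumes the S3 URLs in model_urls are reachable):
    # squeezenet1_0 = models.squeezenet1_0(pretrained=True)

    # Dummy batch of one 3x224x224 image; at this resolution the last
    # pooling stage yields the 13x13 map that nn.AvgPool2d(13) expects.
    x = Variable(torch.randn(1, 3, 224, 224))
    out = squeezenet1_0(x)
    print(out.size())  # (1, 1000): one score per ImageNet class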
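
One detail worth double-checking in review is the channel bookkeeping between Fire modules: each block concatenates its expand1x1 and expand3x3 branches, so its output has expand1x1_planes + expand3x3_planes channels, and that sum must equal the inplanes of the next layer (e.g. Fire(96, 16, 64, 64) emits 64 + 64 = 128 channels, feeding Fire(128, 16, 64, 64)). The following is a small sketch that verifies this for a single block, assuming the diff is applied so that Fire is importable from torchvision.models.squeezenet; the 54x54 spatial size is just the map size after the first conv and max-pool of the 1.0 variant.

    import torch
    from torch.autograd import Variable
    from torchvision.models.squeezenet import Fire

    # Fire(96, 16, 64, 64): squeeze 96 -> 16 channels, then expand back to
    # 64 (1x1) + 64 (3x3, padding=1) = 128 channels; spatial size unchanged.
    fire = Fire(96, 16, 64, 64)
    x = Variable(torch.randn(1, 96, 54, 54))
    y = fire(x)
    print(y.size())  # (1, 128, 54, 54)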