
Commit d44273b

Maratyszcza authored and apaszke committed
SqueezeNet 1.0 and 1.1 models (#49)
* Add SqueezeNet 1.0 and 1.1 models
* Selectively avoid inplace in SqueezeNet
* Use Glorot uniform initialization in SqueezeNet
* Make all ReLU in SqueezeNet in-place
* Add pretrained SqueezeNet 1.0 and 1.1
* Minor fixes in SqueezeNet models
1 parent 98f59eb commit d44273b

File tree

2 files changed: 134 additions, 0 deletions


torchvision/models/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -4,6 +4,7 @@
 - `AlexNet`_
 - `VGG`_
 - `ResNet`_
+- `SqueezeNet`_
 
 You can construct a model with random weights by calling its constructor:
 
@@ -12,6 +13,7 @@
     import torchvision.models as models
     resnet18 = models.resnet18()
     alexnet = models.alexnet()
+    squeezenet = models.squeezenet1_0()
 
 We provide pre-trained models for the ResNet variants and AlexNet, using the
 PyTorch :mod:`torch.utils.model_zoo`. These can be constructed by passing
@@ -26,8 +28,10 @@
 .. _AlexNet: https://arxiv.org/abs/1404.5997
 .. _VGG: https://arxiv.org/abs/1409.1556
 .. _ResNet: https://arxiv.org/abs/1512.03385
+.. _SqueezeNet: https://arxiv.org/abs/1602.07360
 """
 
 from .alexnet import *
 from .resnet import *
 from .vgg import *
+from .squeezenet import *
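
For context, the constructors documented above all follow the same calling convention; a minimal usage sketch for the SqueezeNet entries added here (the pre-trained variant downloads its weights through torch.utils.model_zoo on first use):

    import torchvision.models as models

    # randomly initialized weights, as in the docstring example above
    squeezenet = models.squeezenet1_0()

    # ImageNet weights from the URLs registered in squeezenet.py below
    squeezenet_pretrained = models.squeezenet1_1(pretrained=True)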

torchvision/models/squeezenet.py

Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
import math
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo


__all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1']


model_urls = {
    'squeezenet1_0': 'https://s3.amazonaws.com/pytorch/models/squeezenet1_0-a815701f.pth',
    'squeezenet1_1': 'https://s3.amazonaws.com/pytorch/models/squeezenet1_1-f364aa15.pth',
}


class Fire(nn.Module):
    def __init__(self, inplanes, squeeze_planes,
                 expand1x1_planes, expand3x3_planes):
        super(Fire, self).__init__()
        self.inplanes = inplanes
        self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
        self.squeeze_activation = nn.ReLU(inplace=True)
        self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes,
                                   kernel_size=1)
        self.expand1x1_activation = nn.ReLU(inplace=True)
        self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes,
                                   kernel_size=3, padding=1)
        self.expand3x3_activation = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.squeeze_activation(self.squeeze(x))
        return torch.cat([
            self.expand1x1_activation(self.expand1x1(x)),
            self.expand3x3_activation(self.expand3x3(x))
        ], 1)


class SqueezeNet(nn.Module):
    def __init__(self, version=1.0, num_classes=1000):
        super(SqueezeNet, self).__init__()
        if version not in [1.0, 1.1]:
            raise ValueError("Unsupported SqueezeNet version {version}: "
                             "1.0 or 1.1 expected".format(version=version))
        self.num_classes = num_classes
        if version == 1.0:
            self.features = nn.Sequential(
                nn.Conv2d(3, 96, kernel_size=7, stride=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(96, 16, 64, 64),
                Fire(128, 16, 64, 64),
                Fire(128, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(256, 32, 128, 128),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(512, 64, 256, 256),
            )
        else:
            self.features = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3, stride=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(64, 16, 64, 64),
                Fire(128, 16, 64, 64),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(128, 32, 128, 128),
                Fire(256, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                Fire(512, 64, 256, 256),
            )
        # Final convolution is initialized differently from the rest
        final_conv = nn.Conv2d(512, num_classes, kernel_size=1)
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            final_conv,
            nn.ReLU(inplace=True),
            nn.AvgPool2d(13)
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                gain = 2.0
                if m is final_conv:
                    m.weight.data.normal_(0, 0.01)
                else:
                    fan_in = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
                    u = math.sqrt(3.0 * gain / fan_in)
                    m.weight.data.uniform_(-u, u)
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x.view(x.size(0), self.num_classes)


def squeezenet1_0(pretrained=False):
    r"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level
    accuracy with 50x fewer parameters and <0.5MB model size"
    <https://arxiv.org/abs/1602.07360>`_ paper.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = SqueezeNet(version=1.0)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['squeezenet1_0']))
    return model


def squeezenet1_1(pretrained=False):
    r"""SqueezeNet 1.1 model from the `official SqueezeNet repo
    <https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>`_.
    SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters
    than SqueezeNet 1.0, without sacrificing accuracy.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = SqueezeNet(version=1.1)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['squeezenet1_1']))
    return model
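
As a quick sanity check of the architecture above, a minimal sketch that instantiates both versions and pushes a dummy ImageNet-sized batch through them (illustrative only; on the PyTorch releases contemporary with this commit the input tensor would additionally need to be wrapped in torch.autograd.Variable):

    import torch
    from torchvision.models import squeezenet1_0, squeezenet1_1

    for build in (squeezenet1_0, squeezenet1_1):
        model = build()                    # random weights
        x = torch.randn(2, 3, 224, 224)    # dummy batch of two 224x224 RGB images
        out = model(x)
        print(out.size())                  # torch.Size([2, 1000])

The stride-2 stem plus the three max-pool layers reduce a 224x224 input to a 13x13 feature map, which the final nn.AvgPool2d(13) collapses to 1x1 before the view in forward, so the output is (batch, num_classes). Each Fire(inplanes, squeeze, e1, e3) block concatenates its 1x1 and 3x3 expand branches and therefore outputs e1 + e3 channels; that is why, for example, Fire(96, 16, 64, 64) is followed by Fire(128, 16, 64, 64).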

0 commit comments
