Skip to content

Commit b9b69ba

Browse files
committed
Add GoogLeNet (Inception v1)
1 parent e489abc commit b9b69ba

File tree

3 files changed

+176
-0
lines changed

3 files changed

+176
-0
lines changed

docs/source/models.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ architectures:
1010
- `SqueezeNet`_
1111
- `DenseNet`_
1212
- `Inception`_ v3
13+
- `GoogLeNet`_
1314

1415
You can construct a model with random weights by calling its constructor:
1516

@@ -22,6 +23,7 @@ You can construct a model with random weights by calling its constructor:
2223
squeezenet = models.squeezenet1_0()
2324
densenet = models.densenet161()
2425
inception = models.inception_v3()
26+
googlenet = models.googlenet()
2527
2628
We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
2729
These can be constructed by passing ``pretrained=True``:
@@ -35,6 +37,7 @@ These can be constructed by passing ``pretrained=True``:
3537
vgg16 = models.vgg16(pretrained=True)
3638
densenet = models.densenet161(pretrained=True)
3739
inception = models.inception_v3(pretrained=True)
40+
googlenet = models.googlenet(pretrained=True)
3841
3942
Instancing a pre-trained model will download its weights to a cache directory.
4043
This directory can be set using the `TORCH_MODEL_ZOO` environment variable. See
@@ -93,6 +96,7 @@ Inception v3 22.55 6.44
9396
.. _SqueezeNet: https://arxiv.org/abs/1602.07360
9497
.. _DenseNet: https://arxiv.org/abs/1608.06993
9598
.. _Inception: https://arxiv.org/abs/1512.00567
99+
.. _GoogLeNet: https://arxiv.org/abs/1409.4842
96100

97101
.. currentmodule:: torchvision.models
98102

@@ -142,3 +146,8 @@ Inception v3
142146

143147
.. autofunction:: inception_v3
144148

149+
GoogLeNet
150+
------------
151+
152+
.. autofunction:: googlenet
153+

torchvision/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
from .squeezenet import *
55
from .inception import *
66
from .densenet import *
7+
from .googlenet import *

torchvision/models/googlenet.py

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
from torch.utils import model_zoo
5+
6+
__all__ = ['GoogLeNet', 'googlenet']
7+
8+
model_urls = {
9+
'googlenet': ''
10+
}
11+
12+
13+
def googlenet(pretrained=False, **kwargs):
14+
r"""GoogLeNet (Inception v1) model architecture from
15+
`"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.
16+
Args:
17+
pretrained (bool): If True, returns a model pre-trained on ImageNet
18+
"""
19+
if pretrained:
20+
model = GoogLeNet(**kwargs)
21+
model.load_state_dict(model_zoo.load_url(model_urls['googlenet']))
22+
return model
23+
24+
return GoogLeNet(**kwargs)
25+
26+
27+
class GoogLeNet(nn.Module):
28+
29+
def __init__(self, num_classes=1000, aux_logits=True):
30+
super(GoogLeNet, self).__init__()
31+
self.aux_logits = aux_logits
32+
33+
self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
34+
self.maxpool1 = nn.MaxPool2d(3, stride=2, padding=1)
35+
self.lrn1 = nn.LocalResponseNorm(5, alpha=0.0001)
36+
self.conv2 = BasicConv2d(64, 64, kernel_size=1)
37+
self.conv3 = BasicConv2d(64, 192, kernel_size=3, stride=1, padding=1)
38+
self.lrn2 = nn.LocalResponseNorm(5, alpha=0.0001)
39+
self.maxpool2 = nn.MaxPool2d(3, stride=2, padding=1)
40+
41+
self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
42+
self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
43+
self.maxpool3 = nn.MaxPool2d(3, stride=2)
44+
45+
self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
46+
self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
47+
self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
48+
self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
49+
self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
50+
self.maxpool4 = nn.MaxPool2d(3, stride=2, padding=1)
51+
52+
self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
53+
self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)
54+
if aux_logits:
55+
self.aux1 = InceptionAux(512, num_classes)
56+
self.aux2 = InceptionAux(528, num_classes)
57+
self.avgpool = nn.AvgPool2d(7, stride=1)
58+
self.dropout = nn.Dropout(0.4)
59+
self.fc = nn.Linear(1024, num_classes)
60+
61+
for m in self.modules():
62+
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
63+
import scipy.stats as stats
64+
X = stats.truncnorm(-2, 2, scale=0.01)
65+
values = torch.Tensor(X.rvs(m.weight.numel()))
66+
values = values.view(m.weight.size())
67+
m.weight.data.copy_(values)
68+
69+
def forward(self, x):
70+
x = self.conv1(x)
71+
x = self.maxpool1(x)
72+
x = self.lrn1(x)
73+
x = self.conv2(x)
74+
x = self.conv3(x)
75+
x = self.lrn2(x)
76+
x = self.maxpool2(x)
77+
78+
x = self.inception3a(x)
79+
x = self.inception3b(x)
80+
x = self.maxpool3(x)
81+
x = self.inception4a(x)
82+
if self.training and self.aux_logits:
83+
aux1 = self.aux1(x)
84+
85+
x = self.inception4b(x)
86+
x = self.inception4c(x)
87+
x = self.inception4d(x)
88+
if self.training and self.aux_logits:
89+
aux2 = self.aux2(x)
90+
91+
x = self.inception4e(x)
92+
x = self.maxpool4(x)
93+
x = self.inception5a(x)
94+
x = self.inception5b(x)
95+
96+
x = self.avgpool(x)
97+
x = x.view(x.size(0), -1)
98+
x = self.dropout(x)
99+
x = self.fc(x)
100+
if self.training and self.aux_logits:
101+
return aux1, aux2, x
102+
return x
103+
104+
105+
class Inception(nn.Module):
106+
107+
def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
108+
super(Inception, self).__init__()
109+
110+
self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)
111+
112+
self.branch2 = nn.Sequential(
113+
BasicConv2d(in_channels, ch3x3red, kernel_size=1, stride=1),
114+
BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)
115+
)
116+
117+
self.branch3 = nn.Sequential(
118+
BasicConv2d(in_channels, ch5x5red, kernel_size=1),
119+
BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2)
120+
)
121+
122+
self.branch4 = nn.Sequential(
123+
nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
124+
BasicConv2d(in_channels, pool_proj, kernel_size=1)
125+
)
126+
127+
def forward(self, x):
128+
branch1 = self.branch1(x)
129+
branch2 = self.branch2(x)
130+
branch3 = self.branch3(x)
131+
branch4 = self.branch4(x)
132+
133+
outputs = [branch1, branch2, branch3, branch4]
134+
return torch.cat(outputs, 1)
135+
136+
137+
class InceptionAux(nn.Module):
138+
139+
def __init__(self, in_channels, num_classes):
140+
super(InceptionAux, self).__init__()
141+
self.conv = BasicConv2d(in_channels, 128, kernel_size=1)
142+
143+
self.fc1 = nn.Linear(128 * 3 * 3, 1024)
144+
self.fc2 = nn.Linear(1024, num_classes)
145+
146+
def forward(self, x):
147+
x = F.avg_pool2d(x, kernel_size=5, stride=3)
148+
149+
x = self.conv(x)
150+
x = x.view(x.size(0), -1)
151+
x = self.fc1(x)
152+
x = F.dropout(x, 0.7, training=self.training)
153+
x = self.fc2(x)
154+
155+
return x
156+
157+
158+
class BasicConv2d(nn.Module):
159+
160+
def __init__(self, in_channels, out_channels, **kwargs):
161+
super(BasicConv2d, self).__init__()
162+
self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
163+
164+
def forward(self, x):
165+
x = self.conv(x)
166+
return F.relu(x, inplace=True)

0 commit comments

Comments
 (0)