Skip to content

Commit d302005

Browse files
committed
adjust network to match TensorFlow
1 parent 4ecf860 commit d302005

File tree

2 files changed

+41
-58
lines changed

2 files changed

+41
-58
lines changed

docs/source/models.rst

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,5 +150,4 @@ GoogLeNet
150150
------------
151151

152152
.. autofunction:: googlenet
153-
.. autofunction:: googlenet_bn
154153

torchvision/models/googlenet.py

Lines changed: 41 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
import torch.nn.functional as F
44
from torch.utils import model_zoo
55

6-
__all__ = ['GoogLeNet', 'googlenet', 'googlenet_bn']
6+
__all__ = ['GoogLeNet', 'googlenet']
77

88
model_urls = {
9-
'googlenet': '',
10-
'googlenet_bn': ''
9+
# GoogLeNet ported from TensorFlow
10+
'googlenet': 'https://github.com/TheCodez/vision/releases/download/1.0/googlenet-1378be20.pth',
1111
}
1212

1313

@@ -18,6 +18,8 @@ def googlenet(pretrained=False, **kwargs):
1818
pretrained (bool): If True, returns a model pre-trained on ImageNet
1919
"""
2020
if pretrained:
21+
if 'transform_input' not in kwargs:
22+
kwargs['transform_input'] = True
2123
kwargs['init_weights'] = False
2224
model = GoogLeNet(**kwargs)
2325
model.load_state_dict(model_zoo.load_url(model_urls['googlenet']))
@@ -26,51 +28,35 @@ def googlenet(pretrained=False, **kwargs):
2628
return GoogLeNet(**kwargs)
2729

2830

29-
def googlenet_bn(pretrained=False, **kwargs):
30-
r"""GoogLeNet (Inception v1) model architecture with batch normalization from
31-
`"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.
32-
Args:
33-
pretrained (bool): If True, returns a model pre-trained on ImageNet
34-
"""
35-
if pretrained:
36-
kwargs['init_weights'] = False
37-
model = GoogLeNet(batch_norm=True, **kwargs)
38-
model.load_state_dict(model_zoo.load_url(model_urls['googlenet_bn']))
39-
return model
40-
41-
return GoogLeNet(batch_norm=True, **kwargs)
42-
43-
4431
class GoogLeNet(nn.Module):
4532

46-
def __init__(self, num_classes=1000, aux_logits=True, batch_norm=False, init_weights=True):
33+
def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, init_weights=True):
4734
super(GoogLeNet, self).__init__()
4835
self.aux_logits = aux_logits
36+
self.transform_input = transform_input
4937

50-
self.conv1 = BasicConv2d(3, 64, batch_norm, kernel_size=7, stride=2, padding=3)
38+
self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
5139
self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
52-
self.lrn1 = nn.LocalResponseNorm(5, alpha=0.0001, beta=0.75)
53-
self.conv2 = BasicConv2d(64, 64, batch_norm, kernel_size=1)
54-
self.conv3 = BasicConv2d(64, 192, batch_norm, kernel_size=3, padding=1)
55-
self.lrn2 = nn.LocalResponseNorm(5, alpha=0.0001, beta=0.75)
40+
self.conv2 = BasicConv2d(64, 64, kernel_size=1)
41+
self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
5642
self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
5743

58-
self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32, batch_norm)
59-
self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64, batch_norm)
44+
self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
45+
self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
6046
self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
6147

62-
self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64, batch_norm)
63-
self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64, batch_norm)
64-
self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64, batch_norm)
65-
self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64, batch_norm)
66-
self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128, batch_norm)
48+
self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
49+
self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
50+
self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
51+
self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
52+
self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
6753
self.maxpool4 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
6854

69-
self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128, batch_norm)
70-
self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128, batch_norm)
55+
self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
56+
self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)
7157
if aux_logits:
72-
self.aux1 = InceptionAux(512, num_classes, batch_norm)
73-
self.aux2 = InceptionAux(528, num_classes, batch_norm)
58+
self.aux1 = InceptionAux(512, num_classes)
59+
self.aux2 = InceptionAux(528, num_classes)
7460
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
7561
self.dropout = nn.Dropout(0.4)
7662
self.fc = nn.Linear(1024, num_classes)
@@ -92,12 +78,16 @@ def _initialize_weights(self):
9278
nn.init.constant_(m.bias, 0)
9379

9480
def forward(self, x):
81+
if self.transform_input:
82+
x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
83+
x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
84+
x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
85+
x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
86+
9587
x = self.conv1(x)
9688
x = self.maxpool1(x)
97-
x = self.lrn1(x)
9889
x = self.conv2(x)
9990
x = self.conv3(x)
100-
x = self.lrn2(x)
10191
x = self.maxpool2(x)
10292

10393
x = self.inception3a(x)
@@ -129,24 +119,24 @@ def forward(self, x):
129119

130120
class Inception(nn.Module):
131121

132-
def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj, batch_norm=False):
122+
def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
133123
super(Inception, self).__init__()
134124

135-
self.branch1 = BasicConv2d(in_channels, ch1x1, batch_norm, kernel_size=1)
125+
self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)
136126

137127
self.branch2 = nn.Sequential(
138-
BasicConv2d(in_channels, ch3x3red, batch_norm, kernel_size=1),
139-
BasicConv2d(ch3x3red, ch3x3, batch_norm, kernel_size=3, padding=1)
128+
BasicConv2d(in_channels, ch3x3red, kernel_size=1),
129+
BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)
140130
)
141131

142132
self.branch3 = nn.Sequential(
143-
BasicConv2d(in_channels, ch5x5red, batch_norm, kernel_size=1),
144-
BasicConv2d(ch5x5red, ch5x5, batch_norm, kernel_size=5, padding=2)
133+
BasicConv2d(in_channels, ch5x5red, kernel_size=1),
134+
BasicConv2d(ch5x5red, ch5x5, kernel_size=3, padding=1)
145135
)
146136

147137
self.branch4 = nn.Sequential(
148138
nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True),
149-
BasicConv2d(in_channels, pool_proj, batch_norm, kernel_size=1)
139+
BasicConv2d(in_channels, pool_proj, kernel_size=1)
150140
)
151141

152142
def forward(self, x):
@@ -161,11 +151,11 @@ def forward(self, x):
161151

162152
class InceptionAux(nn.Module):
163153

164-
def __init__(self, in_channels, num_classes, batch_norm=False):
154+
def __init__(self, in_channels, num_classes):
165155
super(InceptionAux, self).__init__()
166-
self.conv = BasicConv2d(in_channels, 128, batch_norm, kernel_size=1)
156+
self.conv = BasicConv2d(in_channels, 128, kernel_size=1)
167157

168-
self.fc1 = nn.Linear(128 * 4 * 4, 1024)
158+
self.fc1 = nn.Linear(2048, 1024)
169159
self.fc2 = nn.Linear(1024, num_classes)
170160

171161
def forward(self, x):
@@ -182,18 +172,12 @@ def forward(self, x):
182172

183173
class BasicConv2d(nn.Module):
184174

185-
def __init__(self, in_channels, out_channels, batch_norm=False, **kwargs):
175+
def __init__(self, in_channels, out_channels, **kwargs):
186176
super(BasicConv2d, self).__init__()
187-
self.batch_norm = batch_norm
188-
189-
if self.batch_norm:
190-
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
191-
self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
192-
else:
193-
self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
177+
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
178+
self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
194179

195180
def forward(self, x):
196181
x = self.conv(x)
197-
if self.batch_norm:
198-
x = self.bn(x)
182+
x = self.bn(x)
199183
return F.relu(x, inplace=True)

0 commit comments

Comments
 (0)