Skip to content

Commit 2c8caab

Browse files
committed
Use ceil_mode instead of padding and initialize weights using "xavier"
1 parent 6c0d34d commit 2c8caab

File tree

1 file changed

+8
-10
lines changed

1 file changed

+8
-10
lines changed

torchvision/models/googlenet.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -46,23 +46,23 @@ def __init__(self, num_classes=1000, aux_logits=True, batch_norm=False):
4646
self.aux_logits = aux_logits
4747

4848
self.conv1 = BasicConv2d(3, 64, batch_norm, kernel_size=7, stride=2, padding=3)
49-
self.maxpool1 = nn.MaxPool2d(3, stride=2, padding=1)
49+
self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
5050
self.lrn1 = nn.LocalResponseNorm(5, alpha=0.0001)
5151
self.conv2 = BasicConv2d(64, 64, batch_norm, kernel_size=1)
5252
self.conv3 = BasicConv2d(64, 192, batch_norm, kernel_size=3, stride=1, padding=1)
5353
self.lrn2 = nn.LocalResponseNorm(5, alpha=0.0001)
54-
self.maxpool2 = nn.MaxPool2d(3, stride=2, padding=1)
54+
self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
5555

5656
self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32, batch_norm)
5757
self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64, batch_norm)
58-
self.maxpool3 = nn.MaxPool2d(3, stride=2, padding=1)
58+
self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
5959

6060
self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64, batch_norm)
6161
self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64, batch_norm)
6262
self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64, batch_norm)
6363
self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64, batch_norm)
6464
self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128, batch_norm)
65-
self.maxpool4 = nn.MaxPool2d(3, stride=2, padding=1)
65+
self.maxpool4 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
6666

6767
self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128, batch_norm)
6868
self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128, batch_norm)
@@ -75,11 +75,9 @@ def __init__(self, num_classes=1000, aux_logits=True, batch_norm=False):
7575

7676
for m in self.modules():
7777
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
78-
import scipy.stats as stats
79-
X = stats.truncnorm(-2, 2, scale=0.01)
80-
values = torch.Tensor(X.rvs(m.weight.numel()))
81-
values = values.view(m.weight.size())
82-
m.weight.data.copy_(values)
78+
nn.init.xavier_uniform_(m.weight)
79+
if m.bias is not None:
80+
nn.init.constant_(m.bias, 0.2)
8381

8482
def forward(self, x):
8583
x = self.conv1(x)
@@ -135,7 +133,7 @@ def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_pr
135133
)
136134

137135
self.branch4 = nn.Sequential(
138-
nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
136+
nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True),
139137
BasicConv2d(in_channels, pool_proj, batch_norm, kernel_size=1)
140138
)
141139

0 commit comments

Comments
 (0)