Skip to content

Commit f640996

Browse files
committed
use adaptive avg pool
1 parent 664758e commit f640996

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

torchvision/models/googlenet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def __init__(self, num_classes=1000, aux_logits=True, batch_norm=False, init_wei
7171
if aux_logits:
7272
self.aux1 = InceptionAux(512, num_classes, batch_norm)
7373
self.aux2 = InceptionAux(528, num_classes, batch_norm)
74-
self.avgpool = nn.AvgPool2d(7, stride=1)
74+
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
7575
self.dropout = nn.Dropout(0.4)
7676
self.fc = nn.Linear(1024, num_classes)
7777

@@ -169,7 +169,7 @@ def __init__(self, in_channels, num_classes, batch_norm=False):
169169
self.fc2 = nn.Linear(1024, num_classes)
170170

171171
def forward(self, x):
172-
x = F.avg_pool2d(x, kernel_size=5, stride=3)
172+
x = F.adaptive_avg_pool2d(x, (4, 4))
173173

174174
x = self.conv(x)
175175
x = x.view(x.size(0), -1)

0 commit comments

Comments
 (0)