@@ -46,23 +46,23 @@ def __init__(self, num_classes=1000, aux_logits=True, batch_norm=False):
46
46
self .aux_logits = aux_logits
47
47
48
48
self .conv1 = BasicConv2d (3 , 64 , batch_norm , kernel_size = 7 , stride = 2 , padding = 3 )
49
- self .maxpool1 = nn .MaxPool2d (3 , stride = 2 , padding = 1 )
49
+ self .maxpool1 = nn .MaxPool2d (3 , stride = 2 , ceil_mode = True )
50
50
self .lrn1 = nn .LocalResponseNorm (5 , alpha = 0.0001 )
51
51
self .conv2 = BasicConv2d (64 , 64 , batch_norm , kernel_size = 1 )
52
52
self .conv3 = BasicConv2d (64 , 192 , batch_norm , kernel_size = 3 , stride = 1 , padding = 1 )
53
53
self .lrn2 = nn .LocalResponseNorm (5 , alpha = 0.0001 )
54
- self .maxpool2 = nn .MaxPool2d (3 , stride = 2 , padding = 1 )
54
+ self .maxpool2 = nn .MaxPool2d (3 , stride = 2 , ceil_mode = True )
55
55
56
56
self .inception3a = Inception (192 , 64 , 96 , 128 , 16 , 32 , 32 , batch_norm )
57
57
self .inception3b = Inception (256 , 128 , 128 , 192 , 32 , 96 , 64 , batch_norm )
58
- self .maxpool3 = nn .MaxPool2d (3 , stride = 2 , padding = 1 )
58
+ self .maxpool3 = nn .MaxPool2d (3 , stride = 2 , ceil_mode = True )
59
59
60
60
self .inception4a = Inception (480 , 192 , 96 , 208 , 16 , 48 , 64 , batch_norm )
61
61
self .inception4b = Inception (512 , 160 , 112 , 224 , 24 , 64 , 64 , batch_norm )
62
62
self .inception4c = Inception (512 , 128 , 128 , 256 , 24 , 64 , 64 , batch_norm )
63
63
self .inception4d = Inception (512 , 112 , 144 , 288 , 32 , 64 , 64 , batch_norm )
64
64
self .inception4e = Inception (528 , 256 , 160 , 320 , 32 , 128 , 128 , batch_norm )
65
- self .maxpool4 = nn .MaxPool2d (3 , stride = 2 , padding = 1 )
65
+ self .maxpool4 = nn .MaxPool2d (3 , stride = 2 , ceil_mode = True )
66
66
67
67
self .inception5a = Inception (832 , 256 , 160 , 320 , 32 , 128 , 128 , batch_norm )
68
68
self .inception5b = Inception (832 , 384 , 192 , 384 , 48 , 128 , 128 , batch_norm )
@@ -75,11 +75,9 @@ def __init__(self, num_classes=1000, aux_logits=True, batch_norm=False):
75
75
76
76
for m in self .modules ():
77
77
if isinstance (m , nn .Conv2d ) or isinstance (m , nn .Linear ):
78
- import scipy .stats as stats
79
- X = stats .truncnorm (- 2 , 2 , scale = 0.01 )
80
- values = torch .Tensor (X .rvs (m .weight .numel ()))
81
- values = values .view (m .weight .size ())
82
- m .weight .data .copy_ (values )
78
+ nn .init .xavier_uniform_ (m .weight )
79
+ if m .bias is not None :
80
+ nn .init .constant_ (m .bias , 0.2 )
83
81
84
82
def forward (self , x ):
85
83
x = self .conv1 (x )
@@ -135,7 +133,7 @@ def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_pr
135
133
)
136
134
137
135
self .branch4 = nn .Sequential (
138
- nn .MaxPool2d (kernel_size = 3 , stride = 1 , padding = 1 ),
136
+ nn .MaxPool2d (kernel_size = 3 , stride = 1 , padding = 1 , ceil_mode = True ),
139
137
BasicConv2d (in_channels , pool_proj , batch_norm , kernel_size = 1 )
140
138
)
141
139
0 commit comments