import torch.nn.functional as F
from torch.utils import model_zoo

-__all__ = ['GoogLeNet', 'googlenet', 'googlenet_bn']
+__all__ = ['GoogLeNet', 'googlenet']

model_urls = {
-    'googlenet': '',
-    'googlenet_bn': ''
+    # GoogLeNet ported from TensorFlow
+    'googlenet': 'https://github.com/TheCodez/vision/releases/download/1.0/googlenet-1378be20.pth',
}

@@ -18,6 +18,8 @@ def googlenet(pretrained=False, **kwargs):
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
+        if 'transform_input' not in kwargs:
+            kwargs['transform_input'] = True
        kwargs['init_weights'] = False
        model = GoogLeNet(**kwargs)
        model.load_state_dict(model_zoo.load_url(model_urls['googlenet']))
@@ -26,51 +28,35 @@ def googlenet(pretrained=False, **kwargs):
    return GoogLeNet(**kwargs)


-def googlenet_bn(pretrained=False, **kwargs):
-    r"""GoogLeNet (Inception v1) model architecture with batch normalization from
-    `"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
-    """
-    if pretrained:
-        kwargs['init_weights'] = False
-        model = GoogLeNet(batch_norm=True, **kwargs)
-        model.load_state_dict(model_zoo.load_url(model_urls['googlenet_bn']))
-        return model
-
-    return GoogLeNet(batch_norm=True, **kwargs)
-
-
class GoogLeNet(nn.Module):

-    def __init__(self, num_classes=1000, aux_logits=True, batch_norm=False, init_weights=True):
+    def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, init_weights=True):
        super(GoogLeNet, self).__init__()
        self.aux_logits = aux_logits
+        self.transform_input = transform_input

-        self.conv1 = BasicConv2d(3, 64, batch_norm, kernel_size=7, stride=2, padding=3)
+        self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
-        self.lrn1 = nn.LocalResponseNorm(5, alpha=0.0001, beta=0.75)
-        self.conv2 = BasicConv2d(64, 64, batch_norm, kernel_size=1)
-        self.conv3 = BasicConv2d(64, 192, batch_norm, kernel_size=3, padding=1)
-        self.lrn2 = nn.LocalResponseNorm(5, alpha=0.0001, beta=0.75)
+        self.conv2 = BasicConv2d(64, 64, kernel_size=1)
+        self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

-        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32, batch_norm)
-        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64, batch_norm)
+        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
+        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

-        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64, batch_norm)
-        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64, batch_norm)
-        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64, batch_norm)
-        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64, batch_norm)
-        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128, batch_norm)
+        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
+        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
+        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
+        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
+        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

-        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128, batch_norm)
-        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128, batch_norm)
+        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
+        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)
        if aux_logits:
-            self.aux1 = InceptionAux(512, num_classes, batch_norm)
-            self.aux2 = InceptionAux(528, num_classes, batch_norm)
+            self.aux1 = InceptionAux(512, num_classes)
+            self.aux2 = InceptionAux(528, num_classes)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(1024, num_classes)
@@ -92,12 +78,16 @@ def _initialize_weights(self):
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
+        if self.transform_input:
+            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
+            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
+            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
+            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
+
        x = self.conv1(x)
        x = self.maxpool1(x)
-        x = self.lrn1(x)
        x = self.conv2(x)
        x = self.conv3(x)
-        x = self.lrn2(x)
        x = self.maxpool2(x)

        x = self.inception3a(x)
@@ -129,24 +119,24 @@ def forward(self, x):

class Inception(nn.Module):

-    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj, batch_norm=False):
+    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
        super(Inception, self).__init__()

-        self.branch1 = BasicConv2d(in_channels, ch1x1, batch_norm, kernel_size=1)
+        self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)

        self.branch2 = nn.Sequential(
-            BasicConv2d(in_channels, ch3x3red, batch_norm, kernel_size=1),
-            BasicConv2d(ch3x3red, ch3x3, batch_norm, kernel_size=3, padding=1)
+            BasicConv2d(in_channels, ch3x3red, kernel_size=1),
+            BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)
        )

        self.branch3 = nn.Sequential(
-            BasicConv2d(in_channels, ch5x5red, batch_norm, kernel_size=1),
-            BasicConv2d(ch5x5red, ch5x5, batch_norm, kernel_size=5, padding=2)
+            BasicConv2d(in_channels, ch5x5red, kernel_size=1),
+            BasicConv2d(ch5x5red, ch5x5, kernel_size=3, padding=1)
        )

        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True),
-            BasicConv2d(in_channels, pool_proj, batch_norm, kernel_size=1)
+            BasicConv2d(in_channels, pool_proj, kernel_size=1)
        )

    def forward(self, x):
@@ -161,11 +151,11 @@ def forward(self, x):

class InceptionAux(nn.Module):

-    def __init__(self, in_channels, num_classes, batch_norm=False):
+    def __init__(self, in_channels, num_classes):
        super(InceptionAux, self).__init__()
-        self.conv = BasicConv2d(in_channels, 128, batch_norm, kernel_size=1)
+        self.conv = BasicConv2d(in_channels, 128, kernel_size=1)

-        self.fc1 = nn.Linear(128 * 4 * 4, 1024)
+        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
@@ -182,18 +172,12 @@ def forward(self, x):

class BasicConv2d(nn.Module):

-    def __init__(self, in_channels, out_channels, batch_norm=False, **kwargs):
+    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
-        self.batch_norm = batch_norm
-
-        if self.batch_norm:
-            self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
-            self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
-        else:
-            self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
+        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
+        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        x = self.conv(x)
-        if self.batch_norm:
-            x = self.bn(x)
+        x = self.bn(x)
        return F.relu(x, inplace=True)
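A note on the `transform_input` block added to `forward`: the ported TensorFlow weights were trained on inputs normalized per channel with mean 0.5 and std 0.5, while torchvision's standard preprocessing uses the ImageNet statistics, so the per-channel remapping `x * (std / 0.5) + (mean - 0.5) / 0.5` converts the latter normalization into the former. (The switch of `branch3` from 5x5 to 3x3 kernels likewise appears to match the TensorFlow checkpoint being ported rather than the paper.) A minimal sketch checking that algebra; the batch shape and tolerance are illustrative:

```python
import torch

# ImageNet per-channel statistics used by torchvision preprocessing.
mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

pixels = torch.rand(2, 3, 224, 224)   # raw pixel values in [0, 1]
x = (pixels - mean) / std             # standard ImageNet normalization

# The transform_input remapping, written with broadcasting instead of the
# per-channel unsqueeze/cat in the diff; both compute the same result.
x = x * (std / 0.5) + (mean - 0.5) / 0.5

# Matches the (pixels - 0.5) / 0.5 normalization the TF weights expect.
assert torch.allclose(x, (pixels - 0.5) / 0.5, atol=1e-5)
```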
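And a usage sketch, assuming this file is importable as a module named `googlenet` (the import path is hypothetical; the rest follows the diff):

```python
import torch
from googlenet import googlenet  # hypothetical module path for this file

# pretrained=True downloads the TF-ported weights and, per the change in
# googlenet(), defaults transform_input to True so that inputs preprocessed
# with the usual ImageNet mean/std are renormalized internally.
model = googlenet(pretrained=True)
model.eval()

x = torch.randn(1, 3, 224, 224)  # stand-in for a preprocessed image batch
with torch.no_grad():
    logits = model(x)  # aux classifiers are typically only active in training
print(logits.shape)    # expected: torch.Size([1, 1000])
```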