'''
import torch
import torch.nn as nn
-import torch.nn.init as init
import torch.nn.functional as F

from torch.autograd import Variable
@@ -14,20 +13,21 @@ class Block(nn.Module):
    '''Grouped convolution block.'''
    expansion = 2

-    def __init__(self, in_planes, planes, stride=1, cardinality=32):
+    def __init__(self, in_planes, cardinality=32, bottleneck_width=4, stride=1):
        super(Block, self).__init__()
-        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
-        self.bn1 = nn.BatchNorm2d(planes)
-        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False)
-        self.bn2 = nn.BatchNorm2d(planes)
-        self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
-        self.bn3 = nn.BatchNorm2d(self.expansion * planes)
+        group_width = cardinality * bottleneck_width
+        self.conv1 = nn.Conv2d(in_planes, group_width, kernel_size=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(group_width)
+        self.conv2 = nn.Conv2d(group_width, group_width, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False)
+        self.bn2 = nn.BatchNorm2d(group_width)
+        self.conv3 = nn.Conv2d(group_width, self.expansion * group_width, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(self.expansion * group_width)

        self.shortcut = nn.Sequential()
-        if stride != 1 or in_planes != self.expansion * planes:
+        if stride != 1 or in_planes != self.expansion * group_width:
            self.shortcut = nn.Sequential(
-                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
-                nn.BatchNorm2d(self.expansion * planes)
+                nn.Conv2d(in_planes, self.expansion * group_width, kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(self.expansion * group_width)
            )

    def forward(self, x):
@@ -40,26 +40,28 @@ def forward(self, x):


class ResNeXt(nn.Module):
-    def __init__(self, block, num_blocks, cardinality=32, num_classes=10):
+    def __init__(self, num_blocks, cardinality, bottleneck_width, num_classes=10):
        super(ResNeXt, self).__init__()
+        self.cardinality = cardinality
+        self.bottleneck_width = bottleneck_width
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
-        self.layer1 = self._make_layer(block, 64, num_blocks[0], 1, cardinality)
-        self.layer2 = self._make_layer(block, 128, num_blocks[1], 2, cardinality)
-        self.layer3 = self._make_layer(block, 256, num_blocks[2], 2, cardinality)
-        # self.layer4 = self._make_layer(block, 512, num_blocks[3], 2, cardinality)
-        self.linear = nn.Linear(512, num_classes)
+        self.layer1 = self._make_layer(num_blocks[0], 1)
+        self.layer2 = self._make_layer(num_blocks[1], 2)
+        self.layer3 = self._make_layer(num_blocks[2], 2)
+        # self.layer4 = self._make_layer(num_blocks[3], 2)
+        self.linear = nn.Linear(cardinality * bottleneck_width * 8, num_classes)

-        self.init_params()
-
-    def _make_layer(self, block, planes, num_blocks, stride, cardinality):
+    def _make_layer(self, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
-            layers.append(block(self.in_planes, planes, stride, cardinality))
-            self.in_planes = planes * block.expansion
+            layers.append(Block(self.in_planes, self.cardinality, self.bottleneck_width, stride))
+            self.in_planes = Block.expansion * self.cardinality * self.bottleneck_width
+            # Double bottleneck_width after each stage.
+            self.bottleneck_width *= 2
        return nn.Sequential(*layers)

    def forward(self, x):
@@ -73,27 +75,23 @@ def forward(self, x):
        out = self.linear(out)
        return out

-    def init_params(self):
-        '''Init layer parameters.'''
-        for m in self.modules():
-            if isinstance(m, nn.Conv2d):
-                init.kaiming_normal(m.weight, mode='fan_out')
-                if m.bias:
-                    init.constant(m.bias, 0)
-            elif isinstance(m, nn.BatchNorm2d):
-                init.constant(m.weight, 1)
-                init.constant(m.bias, 0)
-            elif isinstance(m, nn.Linear):
-                init.normal(m.weight, std=1e-3)
-                if m.bias:
-                    init.constant(m.bias, 0)
-
-
-def ResNeXt29():
-    return ResNeXt(Block, [3,3,3])
-
-
-# net = resnext_cifar()
-# x = torch.randn(1,3,32,32)
-# y = net(Variable(x))
-# print(y.size())
+
+def ResNeXt29_2x64d():
+    return ResNeXt(num_blocks=[3,3,3], cardinality=2, bottleneck_width=64)
+
+def ResNeXt29_4x64d():
+    return ResNeXt(num_blocks=[3,3,3], cardinality=4, bottleneck_width=64)
+
+def ResNeXt29_8x64d():
+    return ResNeXt(num_blocks=[3,3,3], cardinality=8, bottleneck_width=64)
+
+def ResNeXt29_32x4d():
+    return ResNeXt(num_blocks=[3,3,3], cardinality=32, bottleneck_width=4)
+
+def test_resnext():
+    net = ResNeXt29_2x64d()
+    x = torch.randn(1,3,32,32)
+    y = net(Variable(x))
+    print(y.size())
+
+# test_resnext()
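
Note (not part of the commit): under the new parameterization each Block outputs expansion * cardinality * bottleneck_width channels, and _make_layer doubles bottleneck_width after every stage, so after the three stages the classifier sees cardinality * bottleneck_width * 8 features, which is exactly what the new nn.Linear expects. A minimal sketch of that arithmetic, using a hypothetical helper name:

def linear_in_features(cardinality, bottleneck_width, num_stages=3, expansion=2):
    # Track the channel count produced by each stage under the new scheme.
    width = bottleneck_width
    in_planes = 64  # stem width, as in ResNeXt.__init__
    for _ in range(num_stages):
        in_planes = expansion * cardinality * width  # stage output channels
        width *= 2                                   # bottleneck width doubles per stage
    return in_planes

# Matches nn.Linear(cardinality * bottleneck_width * 8, num_classes) for the CIFAR variants:
assert linear_in_features(2, 64) == 2 * 64 * 8    # ResNeXt29_2x64d -> 1024
assert linear_in_features(32, 4) == 32 * 4 * 8    # ResNeXt29_32x4d -> 1024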