Skip to content

Commit 8d580a1

Browse files
ShahriarSSfmassa
authored andcommitted
Updated some stuff in models (#1115)
1 parent d84fee6 commit 8d580a1

File tree

9 files changed

+108
-27
lines changed

9 files changed

+108
-27
lines changed

test/test_cpp_models.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,12 @@ def test_resnext50_32x4d(self):
8787
def test_resnext101_32x8d(self):
8888
process_model(models.resnext101_32x8d(), self.image, _C_tests.forward_resnext101_32x8d, 'ResNext101_32x8d')
8989

90+
def test_wide_resnet50_2(self):
91+
process_model(models.wide_resnet50_2(), self.image, _C_tests.forward_wide_resnet50_2, 'WideResNet50_2')
92+
93+
def test_wide_resnet101_2(self):
94+
process_model(models.wide_resnet101_2(), self.image, _C_tests.forward_wide_resnet101_2, 'WideResNet101_2')
95+
9096
def test_squeezenet1_0(self):
9197
process_model(models.squeezenet1_0(self.pretrained), self.image,
9298
_C_tests.forward_squeezenet1_0, 'Squeezenet1.0')

test/test_models.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,16 @@ torch::Tensor forward_resnext101_32x8d(
7373
torch::Tensor x) {
7474
return forward_model<ResNext101_32x8d>(input_path, x);
7575
}
76+
torch::Tensor forward_wide_resnet50_2(
77+
const std::string& input_path,
78+
torch::Tensor x) {
79+
return forward_model<WideResNet50_2>(input_path, x);
80+
}
81+
torch::Tensor forward_wide_resnet101_2(
82+
const std::string& input_path,
83+
torch::Tensor x) {
84+
return forward_model<WideResNet101_2>(input_path, x);
85+
}
7686

7787
torch::Tensor forward_squeezenet1_0(
7888
const std::string& input_path,
@@ -168,6 +178,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
168178
"forward_resnext101_32x8d",
169179
&forward_resnext101_32x8d,
170180
"forward_resnext101_32x8d");
181+
m.def(
182+
"forward_wide_resnet50_2",
183+
&forward_wide_resnet50_2,
184+
"forward_wide_resnet50_2");
185+
m.def(
186+
"forward_wide_resnet101_2",
187+
&forward_wide_resnet101_2,
188+
"forward_wide_resnet101_2");
171189

172190
m.def(
173191
"forward_squeezenet1_0", &forward_squeezenet1_0, "forward_squeezenet1_0");

torchvision/csrc/convert_models/convert_models.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ int main(int argc, const char* argv[]) {
4141
"resnext50_32x4d_python.pt", "resnext50_32x4d_cpp.pt");
4242
convert_and_save_model<ResNext101_32x8d>(
4343
"resnext101_32x8d_python.pt", "resnext101_32x8d_cpp.pt");
44+
convert_and_save_model<WideResNet50_2>(
45+
"wide_resnet50_2_python.pt", "wide_resnet50_2_cpp.pt");
46+
convert_and_save_model<WideResNet101_2>(
47+
"wide_resnet101_2_python.pt", "wide_resnet101_2_cpp.pt");
4448

4549
convert_and_save_model<SqueezeNet1_0>(
4650
"squeezenet1_0_python.pt", "squeezenet1_0_cpp.pt");

torchvision/csrc/models/mobilenet.cpp

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,19 @@ namespace vision {
66
namespace models {
77
using Options = torch::nn::Conv2dOptions;
88

9+
int64_t make_divisible(
10+
double value,
11+
int64_t divisor,
12+
c10::optional<int64_t> min_value = {}) {
13+
if (!min_value.has_value())
14+
min_value = divisor;
15+
auto new_value = std::max(
16+
min_value.value(), (int64_t(value + divisor / 2) / divisor) * divisor);
17+
if (new_value < .9 * value)
18+
new_value += divisor;
19+
return new_value;
20+
}
21+
922
struct ConvBNReLUImpl : torch::nn::SequentialImpl {
1023
ConvBNReLUImpl(
1124
int64_t in_planes,
@@ -69,28 +82,40 @@ struct MobileNetInvertedResidualImpl : torch::nn::Module {
6982

7083
TORCH_MODULE(MobileNetInvertedResidual);
7184

72-
MobileNetV2Impl::MobileNetV2Impl(int64_t num_classes, double width_mult) {
85+
MobileNetV2Impl::MobileNetV2Impl(
86+
int64_t num_classes,
87+
double width_mult,
88+
std::vector<std::vector<int64_t>> inverted_residual_settings,
89+
int64_t round_nearest) {
7390
using Block = MobileNetInvertedResidual;
7491
int64_t input_channel = 32;
7592
int64_t last_channel = 1280;
7693

77-
std::vector<std::vector<int64_t>> inverted_residual_settings = {
78-
// t, c, n, s
79-
{1, 16, 1, 1},
80-
{6, 24, 2, 2},
81-
{6, 32, 3, 2},
82-
{6, 64, 4, 2},
83-
{6, 96, 3, 1},
84-
{6, 160, 3, 2},
85-
{6, 320, 1, 1},
86-
};
87-
88-
input_channel = int64_t(input_channel * width_mult);
89-
this->last_channel = int64_t(last_channel * std::max(1.0, width_mult));
94+
if (inverted_residual_settings.empty())
95+
inverted_residual_settings = {
96+
// t, c, n, s
97+
{1, 16, 1, 1},
98+
{6, 24, 2, 2},
99+
{6, 32, 3, 2},
100+
{6, 64, 4, 2},
101+
{6, 96, 3, 1},
102+
{6, 160, 3, 2},
103+
{6, 320, 1, 1},
104+
};
105+
106+
if (inverted_residual_settings[0].size() != 4) {
107+
std::cerr << "inverted_residual_settings should contain 4-element vectors";
108+
assert(false);
109+
}
110+
111+
input_channel = make_divisible(input_channel * width_mult, round_nearest);
112+
this->last_channel =
113+
make_divisible(last_channel * std::max(1.0, width_mult), round_nearest);
90114
features->push_back(ConvBNReLU(3, input_channel, 3, 2));
91115

92116
for (auto setting : inverted_residual_settings) {
93-
auto output_channel = int64_t(setting[1] * width_mult);
117+
auto output_channel =
118+
make_divisible(setting[1] * width_mult, round_nearest);
94119

95120
for (int64_t i = 0; i < setting[2]; ++i) {
96121
auto stride = i == 0 ? setting[3] : 1;

torchvision/csrc/models/mobilenet.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,11 @@ struct VISION_API MobileNetV2Impl : torch::nn::Module {
1010
int64_t last_channel;
1111
torch::nn::Sequential features, classifier;
1212

13-
MobileNetV2Impl(int64_t num_classes = 1000, double width_mult = 1.0);
13+
MobileNetV2Impl(
14+
int64_t num_classes = 1000,
15+
double width_mult = 1.0,
16+
std::vector<std::vector<int64_t>> inverted_residual_settings = {},
17+
int64_t round_nearest = 8);
1418

1519
torch::Tensor forward(torch::Tensor x);
1620
};

torchvision/csrc/models/resnet.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,5 +145,15 @@ ResNext101_32x8dImpl::ResNext101_32x8dImpl(
145145
bool zero_init_residual)
146146
: ResNetImpl({3, 4, 23, 3}, num_classes, zero_init_residual, 32, 8) {}
147147

148+
WideResNet50_2Impl::WideResNet50_2Impl(
149+
int64_t num_classes,
150+
bool zero_init_residual)
151+
: ResNetImpl({3, 4, 6, 3}, num_classes, zero_init_residual, 1, 64 * 2) {}
152+
153+
WideResNet101_2Impl::WideResNet101_2Impl(
154+
int64_t num_classes,
155+
bool zero_init_residual)
156+
: ResNetImpl({3, 4, 23, 3}, num_classes, zero_init_residual, 1, 64 * 2) {}
157+
148158
} // namespace models
149159
} // namespace vision

torchvision/csrc/models/resnet.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,18 @@ struct VISION_API ResNext101_32x8dImpl : ResNetImpl<_resnetimpl::Bottleneck> {
217217
bool zero_init_residual = false);
218218
};
219219

220+
struct VISION_API WideResNet50_2Impl : ResNetImpl<_resnetimpl::Bottleneck> {
221+
WideResNet50_2Impl(
222+
int64_t num_classes = 1000,
223+
bool zero_init_residual = false);
224+
};
225+
226+
struct VISION_API WideResNet101_2Impl : ResNetImpl<_resnetimpl::Bottleneck> {
227+
WideResNet101_2Impl(
228+
int64_t num_classes = 1000,
229+
bool zero_init_residual = false);
230+
};
231+
220232
template <typename Block>
221233
struct VISION_API ResNet : torch::nn::ModuleHolder<ResNetImpl<Block>> {
222234
using torch::nn::ModuleHolder<ResNetImpl<Block>>::ModuleHolder;
@@ -229,6 +241,8 @@ TORCH_MODULE(ResNet101);
229241
TORCH_MODULE(ResNet152);
230242
TORCH_MODULE(ResNext50_32x4d);
231243
TORCH_MODULE(ResNext101_32x8d);
244+
TORCH_MODULE(WideResNet50_2);
245+
TORCH_MODULE(WideResNet101_2);
232246

233247
} // namespace models
234248
} // namespace vision

torchvision/csrc/models/squeezenet.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ SqueezeNetImpl::SqueezeNetImpl(double version, int64_t num_classes)
6565
Fire(384, 64, 256, 256),
6666
Fire(512, 64, 256, 256));
6767
} else {
68-
std::cerr << "Wrong version number is passed th SqueeseNet constructor!"
69-
<< std::endl;
68+
std::cerr << "Unsupported SqueezeNet version " << version
69+
<< ". 1_0 or 1_1 expected" << std::endl;
7070
assert(false);
7171
}
7272

torchvision/csrc/models/vgg.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -79,36 +79,36 @@ torch::Tensor VGGImpl::forward(torch::Tensor x) {
7979
}
8080

8181
// clang-format off
82-
static std::unordered_map<char, std::vector<int>> cfg = {
82+
static std::unordered_map<char, std::vector<int>> cfgs = {
8383
{'A', {64, -1, 128, -1, 256, 256, -1, 512, 512, -1, 512, 512, -1}},
8484
{'B', {64, 64, -1, 128, 128, -1, 256, 256, -1, 512, 512, -1, 512, 512, -1}},
8585
{'D', {64, 64, -1, 128, 128, -1, 256, 256, 256, -1, 512, 512, 512, -1, 512, 512, 512, -1}},
8686
{'E', {64, 64, -1, 128, 128, -1, 256, 256, 256, 256, -1, 512, 512, 512, 512, -1, 512, 512, 512, 512, -1}}};
8787
// clang-format on
8888

8989
VGG11Impl::VGG11Impl(int64_t num_classes, bool initialize_weights)
90-
: VGGImpl(makeLayers(cfg['A']), num_classes, initialize_weights) {}
90+
: VGGImpl(makeLayers(cfgs['A']), num_classes, initialize_weights) {}
9191

9292
VGG13Impl::VGG13Impl(int64_t num_classes, bool initialize_weights)
93-
: VGGImpl(makeLayers(cfg['B']), num_classes, initialize_weights) {}
93+
: VGGImpl(makeLayers(cfgs['B']), num_classes, initialize_weights) {}
9494

9595
VGG16Impl::VGG16Impl(int64_t num_classes, bool initialize_weights)
96-
: VGGImpl(makeLayers(cfg['D']), num_classes, initialize_weights) {}
96+
: VGGImpl(makeLayers(cfgs['D']), num_classes, initialize_weights) {}
9797

9898
VGG19Impl::VGG19Impl(int64_t num_classes, bool initialize_weights)
99-
: VGGImpl(makeLayers(cfg['E']), num_classes, initialize_weights) {}
99+
: VGGImpl(makeLayers(cfgs['E']), num_classes, initialize_weights) {}
100100

101101
VGG11BNImpl::VGG11BNImpl(int64_t num_classes, bool initialize_weights)
102-
: VGGImpl(makeLayers(cfg['A'], true), num_classes, initialize_weights) {}
102+
: VGGImpl(makeLayers(cfgs['A'], true), num_classes, initialize_weights) {}
103103

104104
VGG13BNImpl::VGG13BNImpl(int64_t num_classes, bool initialize_weights)
105-
: VGGImpl(makeLayers(cfg['B'], true), num_classes, initialize_weights) {}
105+
: VGGImpl(makeLayers(cfgs['B'], true), num_classes, initialize_weights) {}
106106

107107
VGG16BNImpl::VGG16BNImpl(int64_t num_classes, bool initialize_weights)
108-
: VGGImpl(makeLayers(cfg['D'], true), num_classes, initialize_weights) {}
108+
: VGGImpl(makeLayers(cfgs['D'], true), num_classes, initialize_weights) {}
109109

110110
VGG19BNImpl::VGG19BNImpl(int64_t num_classes, bool initialize_weights)
111-
: VGGImpl(makeLayers(cfg['E'], true), num_classes, initialize_weights) {}
111+
: VGGImpl(makeLayers(cfgs['E'], true), num_classes, initialize_weights) {}
112112

113113
} // namespace models
114114
} // namespace vision

0 commit comments

Comments
 (0)