Skip to content

Commit 23880d4

Browse files
fmigneault authored and facebook-github-bot committed
replace torch 1.5.0 items flagged with deprecation warnings (fix #1906) (#1918) (#2435)
Summary: Pull Request resolved: #2435 Reviewed By: lw Differential Revision: D22438546 Pulled By: fmassa fbshipit-source-id: 8200da87e3459ddaddf089d7d99f4535b5049743
1 parent e2825e8 commit 23880d4

File tree

11 files changed

+48
-48
lines changed

11 files changed

+48
-48
lines changed

torchvision/csrc/models/densenet.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@ struct _DenseLayerImpl : torch::nn::SequentialImpl {
1515
int64_t bn_size,
1616
double drop_rate)
1717
: drop_rate(drop_rate) {
18-
push_back("norm1", torch::nn::BatchNorm(num_input_features));
18+
push_back("norm1", torch::nn::BatchNorm2d(num_input_features));
1919
push_back("relu1", torch::nn::Functional(modelsimpl::relu_));
2020
push_back(
2121
"conv1",
2222
torch::nn::Conv2d(Options(num_input_features, bn_size * growth_rate, 1)
2323
.stride(1)
2424
.bias(false)));
25-
push_back("norm2", torch::nn::BatchNorm(bn_size * growth_rate));
25+
push_back("norm2", torch::nn::BatchNorm2d(bn_size * growth_rate));
2626
push_back("relu2", torch::nn::Functional(modelsimpl::relu_));
2727
push_back(
2828
"conv2",
@@ -69,7 +69,7 @@ TORCH_MODULE(_DenseBlock);
6969

7070
struct _TransitionImpl : torch::nn::SequentialImpl {
7171
_TransitionImpl(int64_t num_input_features, int64_t num_output_features) {
72-
push_back("norm", torch::nn::BatchNorm(num_input_features));
72+
push_back("norm", torch::nn::BatchNorm2d(num_input_features));
7373
push_back("relu ", torch::nn::Functional(modelsimpl::relu_));
7474
push_back(
7575
"conv",
@@ -102,7 +102,7 @@ DenseNetImpl::DenseNetImpl(
102102
torch::nn::Conv2d(
103103
Options(3, num_init_features, 7).stride(2).padding(3).bias(false)));
104104

105-
features->push_back("norm0", torch::nn::BatchNorm(num_init_features));
105+
features->push_back("norm0", torch::nn::BatchNorm2d(num_init_features));
106106
features->push_back("relu0", torch::nn::Functional(modelsimpl::relu_));
107107
features->push_back(
108108
"pool0", torch::nn::Functional(torch::max_pool2d, 3, 2, 1, 1, false));
@@ -125,7 +125,7 @@ DenseNetImpl::DenseNetImpl(
125125
}
126126

127127
// Final batch norm
128-
features->push_back("norm5", torch::nn::BatchNorm(num_features));
128+
features->push_back("norm5", torch::nn::BatchNorm2d(num_features));
129129
// Linear layer
130130
classifier = torch::nn::Linear(num_features, num_classes);
131131

@@ -136,7 +136,7 @@ DenseNetImpl::DenseNetImpl(
136136
for (auto& module : modules(/*include_self=*/false)) {
137137
if (auto M = dynamic_cast<torch::nn::Conv2dImpl*>(module.get()))
138138
torch::nn::init::kaiming_normal_(M->weight);
139-
else if (auto M = dynamic_cast<torch::nn::BatchNormImpl*>(module.get())) {
139+
else if (auto M = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
140140
torch::nn::init::constant_(M->weight, 1);
141141
torch::nn::init::constant_(M->bias, 0);
142142
} else if (auto M = dynamic_cast<torch::nn::LinearImpl*>(module.get()))

torchvision/csrc/models/googlenet.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ namespace _googlenetimpl {
1111
BasicConv2dImpl::BasicConv2dImpl(torch::nn::Conv2dOptions options) {
1212
options.bias(false);
1313
conv = torch::nn::Conv2d(options);
14-
bn = torch::nn::BatchNorm(
14+
bn = torch::nn::BatchNorm2d(
1515
torch::nn::BatchNormOptions(options.out_channels()).eps(0.001));
1616

1717
register_module("conv", conv);
@@ -155,7 +155,7 @@ void GoogLeNetImpl::_initialize_weights() {
155155
else if (auto M = dynamic_cast<torch::nn::LinearImpl*>(module.get()))
156156
torch::nn::init::normal_(M->weight); // Note: used instead of truncated
157157
// normal initialization
158-
else if (auto M = dynamic_cast<torch::nn::BatchNormImpl*>(module.get())) {
158+
else if (auto M = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
159159
torch::nn::init::ones_(M->weight);
160160
torch::nn::init::zeros_(M->bias);
161161
}

torchvision/csrc/models/googlenet.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ namespace models {
1010
namespace _googlenetimpl {
1111
struct VISION_API BasicConv2dImpl : torch::nn::Module {
1212
torch::nn::Conv2d conv{nullptr};
13-
torch::nn::BatchNorm bn{nullptr};
13+
torch::nn::BatchNorm2d bn{nullptr};
1414

1515
BasicConv2dImpl(torch::nn::Conv2dOptions options);
1616

torchvision/csrc/models/inception.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ BasicConv2dImpl::BasicConv2dImpl(
1111
double std_dev) {
1212
options.bias(false);
1313
conv = torch::nn::Conv2d(options);
14-
bn = torch::nn::BatchNorm(
14+
bn = torch::nn::BatchNorm2d(
1515
torch::nn::BatchNormOptions(options.out_channels()).eps(0.001));
1616

1717
register_module("conv", conv);

torchvision/csrc/models/inception.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ namespace models {
99
namespace _inceptionimpl {
1010
struct VISION_API BasicConv2dImpl : torch::nn::Module {
1111
torch::nn::Conv2d conv{nullptr};
12-
torch::nn::BatchNorm bn{nullptr};
12+
torch::nn::BatchNorm2d bn{nullptr};
1313

1414
BasicConv2dImpl(torch::nn::Conv2dOptions options, double std_dev = 0.1);
1515

torchvision/csrc/models/mnasnet.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ struct MNASNetInvertedResidualImpl : torch::nn::Module {
2424
apply_residual = input == output && stride == 1;
2525

2626
layers->push_back(torch::nn::Conv2d(Options(input, mid, 1).bias(false)));
27-
layers->push_back(torch::nn::BatchNorm(
27+
layers->push_back(torch::nn::BatchNorm2d(
2828
torch::nn::BatchNormOptions(mid).momentum(bn_momentum)));
2929
layers->push_back(
3030
torch::nn::Functional(torch::nn::Functional(modelsimpl::relu_)));
@@ -34,12 +34,12 @@ struct MNASNetInvertedResidualImpl : torch::nn::Module {
3434
.stride(stride)
3535
.groups(mid)
3636
.bias(false))));
37-
layers->push_back(torch::nn::BatchNorm(
37+
layers->push_back(torch::nn::BatchNorm2d(
3838
torch::nn::BatchNormOptions(mid).momentum(bn_momentum)));
3939
layers->push_back(
4040
torch::nn::Functional(torch::nn::Functional(modelsimpl::relu_)));
4141
layers->push_back(torch::nn::Conv2d(Options(mid, output, 1).bias(false)));
42-
layers->push_back(torch::nn::BatchNorm(
42+
layers->push_back(torch::nn::BatchNorm2d(
4343
torch::nn::BatchNormOptions(output).momentum(bn_momentum)));
4444

4545
register_module("layers", layers);
@@ -109,9 +109,9 @@ void MNASNetImpl::_initialize_weights() {
109109
torch::nn::init::kaiming_normal_(
110110
M->weight,
111111
0,
112-
torch::nn::init::FanMode::FanOut,
113-
torch::nn::init::Nonlinearity::ReLU);
114-
else if (auto M = dynamic_cast<torch::nn::BatchNormImpl*>(module.get())) {
112+
torch::kFanOut,
113+
torch::kReLU);
114+
else if (auto M = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
115115
torch::nn::init::ones_(M->weight);
116116
torch::nn::init::zeros_(M->bias);
117117
} else if (auto M = dynamic_cast<torch::nn::LinearImpl*>(module.get())) {
@@ -128,17 +128,17 @@ MNASNetImpl::MNASNetImpl(double alpha, int64_t num_classes, double dropout) {
128128

129129
layers->push_back(
130130
torch::nn::Conv2d(Options(3, 32, 3).padding(1).stride(2).bias(false)));
131-
layers->push_back(torch::nn::BatchNorm(
131+
layers->push_back(torch::nn::BatchNorm2d(
132132
torch::nn::BatchNormOptions(32).momentum(BN_MOMENTUM)));
133133
layers->push_back(torch::nn::Functional(modelsimpl::relu_));
134134
layers->push_back(torch::nn::Conv2d(
135135
Options(32, 32, 3).padding(1).stride(1).groups(32).bias(false)));
136-
layers->push_back(torch::nn::BatchNorm(
136+
layers->push_back(torch::nn::BatchNorm2d(
137137
torch::nn::BatchNormOptions(32).momentum(BN_MOMENTUM)));
138138
layers->push_back(torch::nn::Functional(modelsimpl::relu_));
139139
layers->push_back(
140140
torch::nn::Conv2d(Options(32, 16, 1).padding(0).stride(1).bias(false)));
141-
layers->push_back(torch::nn::BatchNorm(
141+
layers->push_back(torch::nn::BatchNorm2d(
142142
torch::nn::BatchNormOptions(16).momentum(BN_MOMENTUM)));
143143

144144
layers->push_back(stack(16, depths[0], 3, 2, 3, 3, BN_MOMENTUM));
@@ -150,7 +150,7 @@ MNASNetImpl::MNASNetImpl(double alpha, int64_t num_classes, double dropout) {
150150

151151
layers->push_back(torch::nn::Conv2d(
152152
Options(depths[5], 1280, 1).padding(0).stride(1).bias(false)));
153-
layers->push_back(torch::nn::BatchNorm(
153+
layers->push_back(torch::nn::BatchNorm2d(
154154
torch::nn::BatchNormOptions(1280).momentum(BN_MOMENTUM)));
155155
layers->push_back(torch::nn::Functional(modelsimpl::relu_));
156156

torchvision/csrc/models/mobilenet.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ struct ConvBNReLUImpl : torch::nn::SequentialImpl {
3333
.padding(padding)
3434
.groups(groups)
3535
.bias(false)));
36-
push_back(torch::nn::BatchNorm(out_planes));
36+
push_back(torch::nn::BatchNorm2d(out_planes));
3737
push_back(torch::nn::Functional(modelsimpl::relu6_));
3838
}
3939

@@ -68,7 +68,7 @@ struct MobileNetInvertedResidualImpl : torch::nn::Module {
6868
conv->push_back(ConvBNReLU(hidden_dim, hidden_dim, 3, stride, hidden_dim));
6969
conv->push_back(torch::nn::Conv2d(
7070
Options(hidden_dim, output, 1).stride(1).padding(0).bias(false)));
71-
conv->push_back(torch::nn::BatchNorm(output));
71+
conv->push_back(torch::nn::BatchNorm2d(output));
7272

7373
register_module("conv", conv);
7474
}
@@ -135,10 +135,10 @@ MobileNetV2Impl::MobileNetV2Impl(
135135
for (auto& module : modules(/*include_self=*/false)) {
136136
if (auto M = dynamic_cast<torch::nn::Conv2dImpl*>(module.get())) {
137137
torch::nn::init::kaiming_normal_(
138-
M->weight, 0, torch::nn::init::FanMode::FanOut);
138+
M->weight, 0, torch::kFanOut);
139139
if (M->options.bias())
140140
torch::nn::init::zeros_(M->bias);
141-
} else if (auto M = dynamic_cast<torch::nn::BatchNormImpl*>(module.get())) {
141+
} else if (auto M = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
142142
torch::nn::init::ones_(M->weight);
143143
torch::nn::init::zeros_(M->bias);
144144
} else if (auto M = dynamic_cast<torch::nn::LinearImpl*>(module.get())) {

torchvision/csrc/models/resnet.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ BasicBlock::BasicBlock(
4040
conv1 = conv3x3(inplanes, planes, stride);
4141
conv2 = conv3x3(planes, planes);
4242

43-
bn1 = torch::nn::BatchNorm(planes);
44-
bn2 = torch::nn::BatchNorm(planes);
43+
bn1 = torch::nn::BatchNorm2d(planes);
44+
bn2 = torch::nn::BatchNorm2d(planes);
4545

4646
register_module("conv1", conv1);
4747
register_module("conv2", conv2);
@@ -68,9 +68,9 @@ Bottleneck::Bottleneck(
6868
conv2 = conv3x3(width, width, stride, groups);
6969
conv3 = conv1x1(width, planes * expansion);
7070

71-
bn1 = torch::nn::BatchNorm(width);
72-
bn2 = torch::nn::BatchNorm(width);
73-
bn3 = torch::nn::BatchNorm(planes * expansion);
71+
bn1 = torch::nn::BatchNorm2d(width);
72+
bn2 = torch::nn::BatchNorm2d(width);
73+
bn3 = torch::nn::BatchNorm2d(planes * expansion);
7474

7575
register_module("conv1", conv1);
7676
register_module("conv2", conv2);

torchvision/csrc/models/resnet.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ struct VISION_API BasicBlock : torch::nn::Module {
2828
torch::nn::Sequential downsample;
2929

3030
torch::nn::Conv2d conv1{nullptr}, conv2{nullptr};
31-
torch::nn::BatchNorm bn1{nullptr}, bn2{nullptr};
31+
torch::nn::BatchNorm2d bn1{nullptr}, bn2{nullptr};
3232

3333
static int expansion;
3434

@@ -51,7 +51,7 @@ struct VISION_API Bottleneck : torch::nn::Module {
5151
torch::nn::Sequential downsample;
5252

5353
torch::nn::Conv2d conv1{nullptr}, conv2{nullptr}, conv3{nullptr};
54-
torch::nn::BatchNorm bn1{nullptr}, bn2{nullptr}, bn3{nullptr};
54+
torch::nn::BatchNorm2d bn1{nullptr}, bn2{nullptr}, bn3{nullptr};
5555

5656
static int expansion;
5757

@@ -71,7 +71,7 @@ template <typename Block>
7171
struct ResNetImpl : torch::nn::Module {
7272
int64_t groups, base_width, inplanes;
7373
torch::nn::Conv2d conv1;
74-
torch::nn::BatchNorm bn1;
74+
torch::nn::BatchNorm2d bn1;
7575
torch::nn::Sequential layer1, layer2, layer3, layer4;
7676
torch::nn::Linear fc;
7777

@@ -99,7 +99,7 @@ torch::nn::Sequential ResNetImpl<Block>::_make_layer(
9999
if (stride != 1 || inplanes != planes * Block::expansion) {
100100
downsample = torch::nn::Sequential(
101101
_resnetimpl::conv1x1(inplanes, planes * Block::expansion, stride),
102-
torch::nn::BatchNorm(planes * Block::expansion));
102+
torch::nn::BatchNorm2d(planes * Block::expansion));
103103
}
104104

105105
torch::nn::Sequential layers;
@@ -146,9 +146,9 @@ ResNetImpl<Block>::ResNetImpl(
146146
torch::nn::init::kaiming_normal_(
147147
M->weight,
148148
/*a=*/0,
149-
torch::nn::init::FanMode::FanOut,
150-
torch::nn::init::Nonlinearity::ReLU);
151-
else if (auto M = dynamic_cast<torch::nn::BatchNormImpl*>(module.get())) {
149+
torch::kFanOut,
150+
torch::kReLU);
151+
else if (auto M = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
152152
torch::nn::init::constant_(M->weight, 1);
153153
torch::nn::init::constant_(M->bias, 0);
154154
}

torchvision/csrc/models/shufflenetv2.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,20 +49,20 @@ struct ShuffleNetV2InvertedResidualImpl : torch::nn::Module {
4949
if (stride > 1) {
5050
branch1 = torch::nn::Sequential(
5151
conv33(inp, inp, stride),
52-
torch::nn::BatchNorm(inp),
52+
torch::nn::BatchNorm2d(inp),
5353
conv11(inp, branch_features),
54-
torch::nn::BatchNorm(branch_features),
54+
torch::nn::BatchNorm2d(branch_features),
5555
torch::nn::Functional(modelsimpl::relu_));
5656
}
5757

5858
branch2 = torch::nn::Sequential(
5959
conv11(stride > 1 ? inp : branch_features, branch_features),
60-
torch::nn::BatchNorm(branch_features),
60+
torch::nn::BatchNorm2d(branch_features),
6161
torch::nn::Functional(modelsimpl::relu_),
6262
conv33(branch_features, branch_features, stride),
63-
torch::nn::BatchNorm(branch_features),
63+
torch::nn::BatchNorm2d(branch_features),
6464
conv11(branch_features, branch_features),
65-
torch::nn::BatchNorm(branch_features),
65+
torch::nn::BatchNorm2d(branch_features),
6666
torch::nn::Functional(modelsimpl::relu_));
6767

6868
if (!branch1.is_empty())
@@ -108,7 +108,7 @@ ShuffleNetV2Impl::ShuffleNetV2Impl(
108108
.stride(2)
109109
.padding(1)
110110
.bias(false)),
111-
torch::nn::BatchNorm(output_channels),
111+
torch::nn::BatchNorm2d(output_channels),
112112
torch::nn::Functional(modelsimpl::relu_));
113113

114114
input_channels = output_channels;
@@ -135,7 +135,7 @@ ShuffleNetV2Impl::ShuffleNetV2Impl(
135135
.stride(1)
136136
.padding(0)
137137
.bias(false)),
138-
torch::nn::BatchNorm(output_channels),
138+
torch::nn::BatchNorm2d(output_channels),
139139
torch::nn::Functional(modelsimpl::relu_));
140140

141141
fc = torch::nn::Linear(output_channels, num_classes);

torchvision/csrc/models/vgg.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ torch::nn::Sequential makeLayers(
1919
torch::nn::Conv2dOptions(channels, V, 3).padding(1)));
2020

2121
if (batch_norm)
22-
seq->push_back(torch::nn::BatchNorm(V));
22+
seq->push_back(torch::nn::BatchNorm2d(V));
2323
seq->push_back(torch::nn::Functional(modelsimpl::relu_));
2424

2525
channels = V;
@@ -35,10 +35,10 @@ void VGGImpl::_initialize_weights() {
3535
torch::nn::init::kaiming_normal_(
3636
M->weight,
3737
/*a=*/0,
38-
torch::nn::init::FanMode::FanOut,
39-
torch::nn::init::Nonlinearity::ReLU);
38+
torch::kFanOut,
39+
torch::kReLU);
4040
torch::nn::init::constant_(M->bias, 0);
41-
} else if (auto M = dynamic_cast<torch::nn::BatchNormImpl*>(module.get())) {
41+
} else if (auto M = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
4242
torch::nn::init::constant_(M->weight, 1);
4343
torch::nn::init::constant_(M->bias, 0);
4444
} else if (auto M = dynamic_cast<torch::nn::LinearImpl*>(module.get())) {

0 commit comments

Comments
 (0)