12 changes: 6 additions & 6 deletions torchvision/csrc/models/densenet.cpp
@@ -15,14 +15,14 @@ struct _DenseLayerImpl : torch::nn::SequentialImpl {
int64_t bn_size,
double drop_rate)
: drop_rate(drop_rate) {
push_back("norm1", torch::nn::BatchNorm(num_input_features));
push_back("norm1", torch::nn::BatchNorm2d(num_input_features));
push_back("relu1", torch::nn::Functional(modelsimpl::relu_));
push_back(
"conv1",
torch::nn::Conv2d(Options(num_input_features, bn_size * growth_rate, 1)
.stride(1)
.bias(false)));
push_back("norm2", torch::nn::BatchNorm(bn_size * growth_rate));
push_back("norm2", torch::nn::BatchNorm2d(bn_size * growth_rate));
push_back("relu2", torch::nn::Functional(modelsimpl::relu_));
push_back(
"conv2",
@@ -69,7 +69,7 @@ TORCH_MODULE(_DenseBlock);

struct _TransitionImpl : torch::nn::SequentialImpl {
_TransitionImpl(int64_t num_input_features, int64_t num_output_features) {
push_back("norm", torch::nn::BatchNorm(num_input_features));
push_back("norm", torch::nn::BatchNorm2d(num_input_features));
push_back("relu ", torch::nn::Functional(modelsimpl::relu_));
push_back(
"conv",
@@ -102,7 +102,7 @@ DenseNetImpl::DenseNetImpl(
torch::nn::Conv2d(
Options(3, num_init_features, 7).stride(2).padding(3).bias(false)));

features->push_back("norm0", torch::nn::BatchNorm(num_init_features));
features->push_back("norm0", torch::nn::BatchNorm2d(num_init_features));
features->push_back("relu0", torch::nn::Functional(modelsimpl::relu_));
features->push_back(
"pool0", torch::nn::Functional(torch::max_pool2d, 3, 2, 1, 1, false));
@@ -125,7 +125,7 @@
}

// Final batch norm
features->push_back("norm5", torch::nn::BatchNorm(num_features));
features->push_back("norm5", torch::nn::BatchNorm2d(num_features));
// Linear layer
classifier = torch::nn::Linear(num_features, num_classes);

@@ -136,7 +136,7 @@ DenseNetImpl::DenseNetImpl(
for (auto& module : modules(/*include_self=*/false)) {
if (auto M = dynamic_cast<torch::nn::Conv2dImpl*>(module.get()))
torch::nn::init::kaiming_normal_(M->weight);
-    else if (auto M = dynamic_cast<torch::nn::BatchNormImpl*>(module.get())) {
+    else if (auto M = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
torch::nn::init::constant_(M->weight, 1);
torch::nn::init::constant_(M->bias, 0);
} else if (auto M = dynamic_cast<torch::nn::LinearImpl*>(module.get()))
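The recurring change in this PR is the same in every file: the deprecated `torch::nn::BatchNorm` module becomes the dimension-specific `torch::nn::BatchNorm2d` (and `BatchNormImpl` becomes `BatchNorm2dImpl` in the `dynamic_cast` dispatch), mirroring Python's `nn.BatchNorm2d`. For reference, a minimal standalone sketch of the new construction; the channel counts and `eps` value here are illustrative, not taken from this PR:

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // Conv -> BatchNorm2d -> ReLU, using the new dimension-specific module.
  // The options struct is unchanged: the hunks below keep passing
  // torch::nn::BatchNormOptions, which recent libtorch accepts for BatchNorm2d.
  torch::nn::Sequential block(
      torch::nn::Conv2d(
          torch::nn::Conv2dOptions(3, 64, 3).padding(1).bias(false)),
      torch::nn::BatchNorm2d(torch::nn::BatchNormOptions(64).eps(0.001)),
      torch::nn::Functional(torch::relu));

  auto out = block->forward(torch::randn({1, 3, 32, 32}));
  std::cout << out.sizes() << '\n'; // [1, 64, 32, 32]
}
```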
4 changes: 2 additions & 2 deletions torchvision/csrc/models/googlenet.cpp
@@ -11,7 +11,7 @@ namespace _googlenetimpl {
BasicConv2dImpl::BasicConv2dImpl(torch::nn::Conv2dOptions options) {
options.bias(false);
conv = torch::nn::Conv2d(options);
-  bn = torch::nn::BatchNorm(
+  bn = torch::nn::BatchNorm2d(
torch::nn::BatchNormOptions(options.out_channels()).eps(0.001));

register_module("conv", conv);
@@ -155,7 +155,7 @@ void GoogLeNetImpl::_initialize_weights() {
else if (auto M = dynamic_cast<torch::nn::LinearImpl*>(module.get()))
torch::nn::init::normal_(M->weight); // Note: used instead of truncated
// normal initialization
-    else if (auto M = dynamic_cast<torch::nn::BatchNormImpl*>(module.get())) {
+    else if (auto M = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
torch::nn::init::ones_(M->weight);
torch::nn::init::zeros_(M->bias);
}
2 changes: 1 addition & 1 deletion torchvision/csrc/models/googlenet.h
@@ -10,7 +10,7 @@ namespace models {
namespace _googlenetimpl {
struct VISION_API BasicConv2dImpl : torch::nn::Module {
torch::nn::Conv2d conv{nullptr};
-  torch::nn::BatchNorm bn{nullptr};
+  torch::nn::BatchNorm2d bn{nullptr};

BasicConv2dImpl(torch::nn::Conv2dOptions options);

2 changes: 1 addition & 1 deletion torchvision/csrc/models/inception.cpp
@@ -11,7 +11,7 @@ BasicConv2dImpl::BasicConv2dImpl(
double std_dev) {
options.bias(false);
conv = torch::nn::Conv2d(options);
-  bn = torch::nn::BatchNorm(
+  bn = torch::nn::BatchNorm2d(
torch::nn::BatchNormOptions(options.out_channels()).eps(0.001));

register_module("conv", conv);
2 changes: 1 addition & 1 deletion torchvision/csrc/models/inception.h
@@ -9,7 +9,7 @@ namespace models {
namespace _inceptionimpl {
struct VISION_API BasicConv2dImpl : torch::nn::Module {
torch::nn::Conv2d conv{nullptr};
-  torch::nn::BatchNorm bn{nullptr};
+  torch::nn::BatchNorm2d bn{nullptr};

BasicConv2dImpl(torch::nn::Conv2dOptions options, double std_dev = 0.1);

20 changes: 10 additions & 10 deletions torchvision/csrc/models/mnasnet.cpp
@@ -24,7 +24,7 @@ struct MNASNetInvertedResidualImpl : torch::nn::Module {
apply_residual = input == output && stride == 1;

layers->push_back(torch::nn::Conv2d(Options(input, mid, 1).bias(false)));
-    layers->push_back(torch::nn::BatchNorm(
+    layers->push_back(torch::nn::BatchNorm2d(
torch::nn::BatchNormOptions(mid).momentum(bn_momentum)));
layers->push_back(
torch::nn::Functional(torch::nn::Functional(modelsimpl::relu_)));
@@ -34,12 +34,12 @@
.stride(stride)
.groups(mid)
.bias(false))));
-    layers->push_back(torch::nn::BatchNorm(
+    layers->push_back(torch::nn::BatchNorm2d(
torch::nn::BatchNormOptions(mid).momentum(bn_momentum)));
layers->push_back(
torch::nn::Functional(torch::nn::Functional(modelsimpl::relu_)));
layers->push_back(torch::nn::Conv2d(Options(mid, output, 1).bias(false)));
-    layers->push_back(torch::nn::BatchNorm(
+    layers->push_back(torch::nn::BatchNorm2d(
torch::nn::BatchNormOptions(output).momentum(bn_momentum)));

register_module("layers", layers);
@@ -109,9 +109,9 @@ void MNASNetImpl::_initialize_weights() {
torch::nn::init::kaiming_normal_(
M->weight,
0,
-          torch::nn::init::FanMode::FanOut,
-          torch::nn::init::Nonlinearity::ReLU);
-    else if (auto M = dynamic_cast<torch::nn::BatchNormImpl*>(module.get())) {
+          torch::kFanOut,
+          torch::kReLU);
+    else if (auto M = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
torch::nn::init::ones_(M->weight);
torch::nn::init::zeros_(M->bias);
} else if (auto M = dynamic_cast<torch::nn::LinearImpl*>(module.get())) {
@@ -128,17 +128,17 @@ MNASNetImpl::MNASNetImpl(double alpha, int64_t num_classes, double dropout) {

layers->push_back(
torch::nn::Conv2d(Options(3, 32, 3).padding(1).stride(2).bias(false)));
-  layers->push_back(torch::nn::BatchNorm(
+  layers->push_back(torch::nn::BatchNorm2d(
torch::nn::BatchNormOptions(32).momentum(BN_MOMENTUM)));
layers->push_back(torch::nn::Functional(modelsimpl::relu_));
layers->push_back(torch::nn::Conv2d(
Options(32, 32, 3).padding(1).stride(1).groups(32).bias(false)));
-  layers->push_back(torch::nn::BatchNorm(
+  layers->push_back(torch::nn::BatchNorm2d(
torch::nn::BatchNormOptions(32).momentum(BN_MOMENTUM)));
layers->push_back(torch::nn::Functional(modelsimpl::relu_));
layers->push_back(
torch::nn::Conv2d(Options(32, 16, 1).padding(0).stride(1).bias(false)));
-  layers->push_back(torch::nn::BatchNorm(
+  layers->push_back(torch::nn::BatchNorm2d(
torch::nn::BatchNormOptions(16).momentum(BN_MOMENTUM)));

layers->push_back(stack(16, depths[0], 3, 2, 3, 3, BN_MOMENTUM));
@@ -150,7 +150,7 @@

layers->push_back(torch::nn::Conv2d(
Options(depths[5], 1280, 1).padding(0).stride(1).bias(false)));
-  layers->push_back(torch::nn::BatchNorm(
+  layers->push_back(torch::nn::BatchNorm2d(
torch::nn::BatchNormOptions(1280).momentum(BN_MOMENTUM)));
layers->push_back(torch::nn::Functional(modelsimpl::relu_));

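The `_initialize_weights` hunks here (and in resnet.h and vgg.cpp below) also pick up the reworked init API: the old `torch::nn::init::FanMode` and `torch::nn::init::Nonlinearity` enum classes are replaced by variant-style enums. A self-contained sketch of the new call, with an arbitrary conv-shaped weight:

```cpp
#include <torch/torch.h>

int main() {
  // A conv-shaped weight: 64 output channels, 32 input channels, 3x3 kernel.
  auto w = torch::empty({64, 32, 3, 3});

  // New-style call: torch::kFanOut / torch::kReLU replace the removed
  // torch::nn::init::FanMode::FanOut / torch::nn::init::Nonlinearity::ReLU.
  torch::nn::init::kaiming_normal_(w, /*a=*/0, torch::kFanOut, torch::kReLU);
}
```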
8 changes: 4 additions & 4 deletions torchvision/csrc/models/mobilenet.cpp
@@ -33,7 +33,7 @@ struct ConvBNReLUImpl : torch::nn::SequentialImpl {
.padding(padding)
.groups(groups)
.bias(false)));
-    push_back(torch::nn::BatchNorm(out_planes));
+    push_back(torch::nn::BatchNorm2d(out_planes));
push_back(torch::nn::Functional(modelsimpl::relu6_));
}

@@ -68,7 +68,7 @@ struct MobileNetInvertedResidualImpl : torch::nn::Module {
conv->push_back(ConvBNReLU(hidden_dim, hidden_dim, 3, stride, hidden_dim));
conv->push_back(torch::nn::Conv2d(
Options(hidden_dim, output, 1).stride(1).padding(0).bias(false)));
-    conv->push_back(torch::nn::BatchNorm(output));
+    conv->push_back(torch::nn::BatchNorm2d(output));

register_module("conv", conv);
}
@@ -135,10 +135,10 @@ MobileNetV2Impl::MobileNetV2Impl(
for (auto& module : modules(/*include_self=*/false)) {
if (auto M = dynamic_cast<torch::nn::Conv2dImpl*>(module.get())) {
torch::nn::init::kaiming_normal_(
-          M->weight, 0, torch::nn::init::FanMode::FanOut);
+          M->weight, 0, torch::kFanOut);
if (M->options.bias())
torch::nn::init::zeros_(M->bias);
-    } else if (auto M = dynamic_cast<torch::nn::BatchNormImpl*>(module.get())) {
+    } else if (auto M = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
torch::nn::init::ones_(M->weight);
torch::nn::init::zeros_(M->bias);
} else if (auto M = dynamic_cast<torch::nn::LinearImpl*>(module.get())) {
10 changes: 5 additions & 5 deletions torchvision/csrc/models/resnet.cpp
@@ -40,8 +40,8 @@ BasicBlock::BasicBlock(
conv1 = conv3x3(inplanes, planes, stride);
conv2 = conv3x3(planes, planes);

-  bn1 = torch::nn::BatchNorm(planes);
-  bn2 = torch::nn::BatchNorm(planes);
+  bn1 = torch::nn::BatchNorm2d(planes);
+  bn2 = torch::nn::BatchNorm2d(planes);

register_module("conv1", conv1);
register_module("conv2", conv2);
@@ -68,9 +68,9 @@ Bottleneck::Bottleneck(
conv2 = conv3x3(width, width, stride, groups);
conv3 = conv1x1(width, planes * expansion);

-  bn1 = torch::nn::BatchNorm(width);
-  bn2 = torch::nn::BatchNorm(width);
-  bn3 = torch::nn::BatchNorm(planes * expansion);
+  bn1 = torch::nn::BatchNorm2d(width);
+  bn2 = torch::nn::BatchNorm2d(width);
+  bn3 = torch::nn::BatchNorm2d(planes * expansion);

register_module("conv1", conv1);
register_module("conv2", conv2);
14 changes: 7 additions & 7 deletions torchvision/csrc/models/resnet.h
@@ -28,7 +28,7 @@ struct VISION_API BasicBlock : torch::nn::Module {
torch::nn::Sequential downsample;

torch::nn::Conv2d conv1{nullptr}, conv2{nullptr};
-  torch::nn::BatchNorm bn1{nullptr}, bn2{nullptr};
+  torch::nn::BatchNorm2d bn1{nullptr}, bn2{nullptr};

static int expansion;

@@ -51,7 +51,7 @@ struct VISION_API Bottleneck : torch::nn::Module {
torch::nn::Sequential downsample;

torch::nn::Conv2d conv1{nullptr}, conv2{nullptr}, conv3{nullptr};
-  torch::nn::BatchNorm bn1{nullptr}, bn2{nullptr}, bn3{nullptr};
+  torch::nn::BatchNorm2d bn1{nullptr}, bn2{nullptr}, bn3{nullptr};

static int expansion;

@@ -71,7 +71,7 @@ template <typename Block>
struct ResNetImpl : torch::nn::Module {
int64_t groups, base_width, inplanes;
torch::nn::Conv2d conv1;
-  torch::nn::BatchNorm bn1;
+  torch::nn::BatchNorm2d bn1;
torch::nn::Sequential layer1, layer2, layer3, layer4;
torch::nn::Linear fc;

@@ -99,7 +99,7 @@ torch::nn::Sequential ResNetImpl<Block>::_make_layer(
if (stride != 1 || inplanes != planes * Block::expansion) {
downsample = torch::nn::Sequential(
_resnetimpl::conv1x1(inplanes, planes * Block::expansion, stride),
-        torch::nn::BatchNorm(planes * Block::expansion));
+        torch::nn::BatchNorm2d(planes * Block::expansion));
}

torch::nn::Sequential layers;
@@ -146,9 +146,9 @@ ResNetImpl<Block>::ResNetImpl(
torch::nn::init::kaiming_normal_(
M->weight,
/*a=*/0,
-          torch::nn::init::FanMode::FanOut,
-          torch::nn::init::Nonlinearity::ReLU);
-    else if (auto M = dynamic_cast<torch::nn::BatchNormImpl*>(module.get())) {
+          torch::kFanOut,
+          torch::kReLU);
+    else if (auto M = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
torch::nn::init::constant_(M->weight, 1);
torch::nn::init::constant_(M->bias, 0);
}
14 changes: 7 additions & 7 deletions torchvision/csrc/models/shufflenetv2.cpp
@@ -49,20 +49,20 @@ struct ShuffleNetV2InvertedResidualImpl : torch::nn::Module {
if (stride > 1) {
branch1 = torch::nn::Sequential(
conv33(inp, inp, stride),
-        torch::nn::BatchNorm(inp),
+        torch::nn::BatchNorm2d(inp),
conv11(inp, branch_features),
-        torch::nn::BatchNorm(branch_features),
+        torch::nn::BatchNorm2d(branch_features),
torch::nn::Functional(modelsimpl::relu_));
}

branch2 = torch::nn::Sequential(
conv11(stride > 1 ? inp : branch_features, branch_features),
-        torch::nn::BatchNorm(branch_features),
+        torch::nn::BatchNorm2d(branch_features),
torch::nn::Functional(modelsimpl::relu_),
conv33(branch_features, branch_features, stride),
-        torch::nn::BatchNorm(branch_features),
+        torch::nn::BatchNorm2d(branch_features),
conv11(branch_features, branch_features),
-        torch::nn::BatchNorm(branch_features),
+        torch::nn::BatchNorm2d(branch_features),
torch::nn::Functional(modelsimpl::relu_));

if (!branch1.is_empty())
@@ -108,7 +108,7 @@ ShuffleNetV2Impl::ShuffleNetV2Impl(
.stride(2)
.padding(1)
.bias(false)),
-      torch::nn::BatchNorm(output_channels),
+      torch::nn::BatchNorm2d(output_channels),
torch::nn::Functional(modelsimpl::relu_));

input_channels = output_channels;
@@ -135,7 +135,7 @@
.stride(1)
.padding(0)
.bias(false)),
-      torch::nn::BatchNorm(output_channels),
+      torch::nn::BatchNorm2d(output_channels),
torch::nn::Functional(modelsimpl::relu_));

fc = torch::nn::Linear(output_channels, num_classes);
8 changes: 4 additions & 4 deletions torchvision/csrc/models/vgg.cpp
@@ -19,7 +19,7 @@ torch::nn::Sequential makeLayers(
torch::nn::Conv2dOptions(channels, V, 3).padding(1)));

if (batch_norm)
-        seq->push_back(torch::nn::BatchNorm(V));
+        seq->push_back(torch::nn::BatchNorm2d(V));
seq->push_back(torch::nn::Functional(modelsimpl::relu_));

channels = V;
@@ -35,10 +35,10 @@ void VGGImpl::_initialize_weights() {
torch::nn::init::kaiming_normal_(
M->weight,
/*a=*/0,
-          torch::nn::init::FanMode::FanOut,
-          torch::nn::init::Nonlinearity::ReLU);
+          torch::kFanOut,
+          torch::kReLU);
torch::nn::init::constant_(M->bias, 0);
-    } else if (auto M = dynamic_cast<torch::nn::BatchNormImpl*>(module.get())) {
+    } else if (auto M = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
torch::nn::init::constant_(M->weight, 1);
torch::nn::init::constant_(M->bias, 0);
} else if (auto M = dynamic_cast<torch::nn::LinearImpl*>(module.get())) {
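As an end-to-end smoke test of the migrated modules, something along these lines should work; this is a sketch, assuming the `vision::models` targets from this repo are built and linked, and the include path may differ per setup:

```cpp
#include <torch/torch.h>
#include <iostream>

#include "torchvision/csrc/models/resnet.h" // include path assumed; adjust to your build

int main() {
  vision::models::ResNet18 model;
  model->eval(); // use running BatchNorm2d statistics rather than batch stats

  auto out = model->forward(torch::randn({1, 3, 224, 224}));
  std::cout << out.sizes() << '\n'; // expect [1, 1000]
}
```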