diff --git a/torchvision/csrc/models/mnasnet.cpp b/torchvision/csrc/models/mnasnet.cpp index 75b63c9f5c5..83d2a0bb18c 100644 --- a/torchvision/csrc/models/mnasnet.cpp +++ b/torchvision/csrc/models/mnasnet.cpp @@ -109,10 +109,7 @@ void MNASNetImpl::_initialize_weights() { for (auto& module : modules(/*include_self=*/false)) { if (auto M = dynamic_cast(module.get())) torch::nn::init::kaiming_normal_( - M->weight, - 0, - torch::nn::init::FanMode::FanOut, - torch::nn::init::Nonlinearity::ReLU); + M->weight, 0, torch::kFanOut, torch::kReLU); else if (auto M = dynamic_cast(module.get())) { torch::nn::init::ones_(M->weight); torch::nn::init::zeros_(M->bias); diff --git a/torchvision/csrc/models/mobilenet.cpp b/torchvision/csrc/models/mobilenet.cpp index 2b49c844977..e80512b8971 100644 --- a/torchvision/csrc/models/mobilenet.cpp +++ b/torchvision/csrc/models/mobilenet.cpp @@ -134,8 +134,7 @@ MobileNetV2Impl::MobileNetV2Impl( for (auto& module : modules(/*include_self=*/false)) { if (auto M = dynamic_cast(module.get())) { - torch::nn::init::kaiming_normal_( - M->weight, 0, torch::nn::init::FanMode::FanOut); + torch::nn::init::kaiming_normal_(M->weight, 0, torch::kFanOut); if (M->options.with_bias()) torch::nn::init::zeros_(M->bias); } else if (auto M = dynamic_cast(module.get())) { diff --git a/torchvision/csrc/models/resnet.h b/torchvision/csrc/models/resnet.h index ae9f4613ebe..c03c5ce41ce 100644 --- a/torchvision/csrc/models/resnet.h +++ b/torchvision/csrc/models/resnet.h @@ -146,8 +146,8 @@ ResNetImpl::ResNetImpl( torch::nn::init::kaiming_normal_( M->weight, /*a=*/0, - torch::nn::init::FanMode::FanOut, - torch::nn::init::Nonlinearity::ReLU); + torch::kFanOut, + torch::kReLU); else if (auto M = dynamic_cast(module.get())) { torch::nn::init::constant_(M->weight, 1); torch::nn::init::constant_(M->bias, 0); diff --git a/torchvision/csrc/models/vgg.cpp b/torchvision/csrc/models/vgg.cpp index c3677d6dd60..77d181a5751 100644 --- a/torchvision/csrc/models/vgg.cpp +++ b/torchvision/csrc/models/vgg.cpp @@ -35,8 +35,8 @@ void VGGImpl::_initialize_weights() { torch::nn::init::kaiming_normal_( M->weight, /*a=*/0, - torch::nn::init::FanMode::FanOut, - torch::nn::init::Nonlinearity::ReLU); + torch::kFanOut, + torch::kReLU); torch::nn::init::constant_(M->bias, 0); } else if (auto M = dynamic_cast(module.get())) { torch::nn::init::constant_(M->weight, 1);