diff --git a/core/conversion/converters/impl/batch_norm.cpp b/core/conversion/converters/impl/batch_norm.cpp
index fc19cb282e..252f51dae0 100644
--- a/core/conversion/converters/impl/batch_norm.cpp
+++ b/core/conversion/converters/impl/batch_norm.cpp
@@ -10,61 +10,171 @@ namespace converters {
 namespace impl {
 namespace {
-auto batch_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().pattern({
-    R"SIG(aten::batch_norm(Tensor input, Tensor? gamma, Tensor? beta,
+void _batch_norm(
+    ConversionCtx* ctx,
+    const torch::jit::Node* n,
+    nvinfer1::ITensor* input,
+    const nvinfer1::Dims32& orig_shape,
+    const torch::Tensor& gamma,
+    const torch::Tensor& beta,
+    const torch::Tensor& mean,
+    const torch::Tensor& var,
+    const float eps) {
+  auto scale = gamma / torch::sqrt(var + eps);
+  auto bias = beta - mean * scale;
+  LOG_DEBUG("_batch_norm Tensor Scale : " << scale.sizes());
+  LOG_DEBUG("_batch_norm Tensor bias : " << bias.sizes());
+
+  auto scale_weights = Weights(ctx, scale);
+  auto bias_weights = Weights(ctx, bias);
+
+  auto power = Weights(ctx, at::ones_like(scale));
+  auto bn =
+      ctx->net->addScaleNd(*input, nvinfer1::ScaleMode::kCHANNEL, bias_weights.data, scale_weights.data, power.data, 1);
+  bn->setName(util::node_info(n).c_str());
+
+  // Un-pad bn output if needed
+  auto out_tensor = addUnpadding(ctx, n, bn->getOutput(0), orig_shape.nbDims);
+  ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
+  LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
+}
+
+auto batch_norm_registrations TRTORCH_UNUSED =
+    RegisterNodeConversionPatterns()
+        .pattern({
+            R"SIG(aten::batch_norm(Tensor input, Tensor? gamma, Tensor? beta,
                         Tensor? mean, Tensor? var,
                         bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor))SIG",
-    [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-      auto input = args[0].ITensor(); // assumes non-static input Tensor
-      auto orig_shape = input->getDimensions();
-      auto shape = util::toVec(orig_shape);
-      auto tensor_type = util::TRTDataTypeToScalarType(input->getType());
-      auto options = torch::TensorOptions().dtype(tensor_type);
-
-      torch::Tensor gamma, beta, mean, var;
-
-      if (ctx->input_is_dynamic) {
-        gamma = args[1].unwrapToTensor();
-        beta = args[2].unwrapToTensor();
-        mean = args[3].unwrapToTensor();
-        var = args[4].unwrapToTensor();
-      } else {
-        gamma = args[1].unwrapToTensor(at::full({shape}, 1, {options}));
-        beta = args[2].unwrapToTensor(at::full({shape}, 1, {options}));
-        mean = args[3].unwrapToTensor(at::full({shape}, 0, {options}));
-        var = args[4].unwrapToTensor(at::full({shape}, 0, {options}));
-      }
-
-      auto eps = args[7].unwrapToDouble(1e-5f);
-
-      LOG_DEBUG("momentum disregarded");
-      LOG_DEBUG("training disregarded");
-      LOG_DEBUG("cudnn disregarded");
-      TRTORCH_CHECK(orig_shape.nbDims > 2, "Unable to create batch normalization layer from node: " << *n);
-
-      // Expand spatial dims from 1D to 2D if needed
-      bool expandDims = (orig_shape.nbDims < 4);
-
-      if (expandDims) {
-        input = addPadding(ctx, n, input, 4);
-      }
-
-      auto scale = gamma / torch::sqrt(var + eps);
-      auto bias = beta - mean * scale;
-
-      auto scale_weights = Weights(ctx, scale);
-      auto bias_weights = Weights(ctx, bias);
-
-      auto power = Weights(ctx, at::ones_like(scale));
-      auto bn = ctx->net->addScaleNd(
-          *input, nvinfer1::ScaleMode::kCHANNEL, bias_weights.data, scale_weights.data, power.data, 1);
-      bn->setName(util::node_info(n).c_str());
-      // Un-pad bn output if needed
-      auto out_tensor = addUnpadding(ctx, n, bn->getOutput(0), orig_shape.nbDims);
-      ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
-      return true;
-    }});
+            [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+              auto input = args[0].ITensor(); // assumes non-static input Tensor
+              auto orig_shape = input->getDimensions();
+              auto shape = util::toVec(orig_shape);
+              auto tensor_type = util::TRTDataTypeToScalarType(input->getType());
+              auto options = torch::TensorOptions().dtype(tensor_type);
+
+              torch::Tensor gamma, beta, mean, var;
+
+              if (ctx->input_is_dynamic) {
+                gamma = args[1].unwrapToTensor();
+                beta = args[2].unwrapToTensor();
+                mean = args[3].unwrapToTensor();
+                var = args[4].unwrapToTensor();
+              } else {
+                gamma = args[1].unwrapToTensor(at::full({shape}, 1, {options}));
+                beta = args[2].unwrapToTensor(at::full({shape}, 1, {options}));
+                mean = args[3].unwrapToTensor(at::full({shape}, 0, {options}));
+                var = args[4].unwrapToTensor(at::full({shape}, 0, {options}));
+              }
+
+              auto eps = static_cast<float>(args[7].unwrapToDouble(1e-5f));
+
+              LOG_DEBUG("momentum disregarded");
+              LOG_DEBUG("training disregarded");
+              LOG_DEBUG("cudnn disregarded");
+              TRTORCH_CHECK(orig_shape.nbDims > 2, "Unable to create batch normalization layer from node: " << *n);
+
+              // Expand spatial dims from 1D to 2D if needed
+              bool expandDims = (orig_shape.nbDims < 4);
+              if (expandDims) {
+                input = addPadding(ctx, n, input, 4);
+              }
+
+              _batch_norm(ctx, n, input, orig_shape, gamma, beta, mean, var, eps);
+
+              return true;
+            }})
+        .pattern({
+            R"SIG(aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias,
+                        Tensor? running_mean, Tensor? running_var,
+                        bool use_input_stats, float momentum, float eps,
+                        bool cudnn_enabled) -> (Tensor))SIG",
+            [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+              auto input = args[0].ITensorOrFreeze(ctx);
+              auto orig_shape = input->getDimensions();
+              auto shape = util::toVec(orig_shape);
+              auto tensor_type = util::TRTDataTypeToScalarType(input->getType());
+              auto options = torch::TensorOptions().dtype(tensor_type);
+
+              LOG_DEBUG("Input :" << orig_shape << "/" << input->getType());
+              // affine=True
+              LOG_DEBUG("Args[1] weight : " << args[1].isIValue() << " / " << args[1].IValue()->isNone());
+              LOG_DEBUG("Args[2] bias : " << args[2].isIValue() << " / " << args[2].IValue()->isNone());
+              // track_running_stats=True
+              LOG_DEBUG("Args[3] running_mean : " << args[3].isIValue() << " / " << args[3].IValue()->isNone());
+              LOG_DEBUG("Args[4] running_var : " << args[4].isIValue() << " / " << args[4].IValue()->isNone());
+              LOG_DEBUG("use_input_stats, momentum, cudnn_enabled disregarded");
+              LOG_DEBUG("ctx->input_is_dynamic : " << ctx->input_is_dynamic);
+
+              // Expand spatial dims from 1D to 2D if needed
+              bool expandDims = (orig_shape.nbDims < 4);
+              if (expandDims) {
+                input = addPadding(ctx, n, input, 4);
+              }
+
+              auto eps = static_cast<float>(args[7].unwrapToDouble(1e-5f));
+
+              auto scales = args[1].unwrapToTensor(at::ones(shape[1], options)).cpu().contiguous();
+              auto bias = args[2].unwrapToTensor(at::zeros(shape[1], options)).cpu().contiguous();
+
+              // track_running_stats=True
+              if (!args[3].IValue()->isNone() || !args[4].IValue()->isNone()) {
+                auto running_mean = args[3].unwrapToTensor();
+                auto running_var = args[4].unwrapToTensor();
+                _batch_norm(
+                    ctx,
+                    n,
+                    input,
+                    orig_shape,
+                    scales.to(running_mean.options()),
+                    bias.to(running_mean.options()),
+                    running_mean,
+                    running_var,
+                    eps);
+                return true;
+              }
+
+              const int relu = 0;
+              const float alpha = 0;
+              LOG_DEBUG("Set parameter `relu` and `alpha` to 0");
+              /*
+              https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/namespacenvinfer1.html
+              https://github.com/NVIDIA/TensorRT/tree/8.0.1/plugin/instanceNormalizationPlugin
+              Type       Parameter  Description
+              float      epsilon    A small number to prevent being divided by zero during normalization.
+              Weights *  scale      A pointer to weights which contains information about scale factors for
+                                    normalization. The definition of Weights can be found in the NvInfer.h header.
+              Weights *  bias       A pointer to weights which contains information about the bias values for
+                                    normalization. The definition of Weights can be found in the NvInfer.h header.
+              int        relu       A value used to enable leaky relu activation
+              float      alpha      A small negative slope for the leaky relu activation
+              */
+              std::vector<nvinfer1::PluginField> f;
+              f.emplace_back(nvinfer1::PluginField("epsilon", &eps, nvinfer1::PluginFieldType::kFLOAT32, 1));
+              f.emplace_back(nvinfer1::PluginField(
+                  "scales", scales.data_ptr(), nvinfer1::PluginFieldType::kFLOAT32, scales.numel()));
+              f.emplace_back(nvinfer1::PluginField(
+                  "bias", bias.data_ptr(), nvinfer1::PluginFieldType::kFLOAT32, bias.numel()));
+              f.emplace_back(nvinfer1::PluginField("relu", &relu, nvinfer1::PluginFieldType::kINT32, 1));
+              f.emplace_back(nvinfer1::PluginField("alpha", &alpha, nvinfer1::PluginFieldType::kFLOAT32, 1));
+
+              nvinfer1::PluginFieldCollection fc;
+              fc.nbFields = f.size();
+              fc.fields = f.data();
+
+              auto creator = getPluginRegistry()->getPluginCreator("InstanceNormalization_TRT", "1", "");
+              auto instance_norm_plugin = creator->createPlugin("instance_norm", &fc);
+
+              TRTORCH_CHECK(
+                  instance_norm_plugin, "Unable to create instance_norm plugin from TensorRT plugin registry" << *n);
+
+              auto new_layer =
+                  ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&input), 1, *instance_norm_plugin);
+              new_layer->setName(util::node_info(n).c_str());
+              auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
+              LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
+              return true;
+            }});
 } // namespace
 } // namespace impl
 } // namespace converters
diff --git a/tests/core/conversion/converters/BUILD b/tests/core/conversion/converters/BUILD
index 1b583875da..7d33949af2 100644
--- a/tests/core/conversion/converters/BUILD
+++ b/tests/core/conversion/converters/BUILD
@@ -15,6 +15,10 @@ converter_test(
     name = "test_batch_norm",
 )
 
+converter_test(
+    name = "test_instance_norm",
+)
+
 converter_test(
     name = "test_cast",
 )
@@ -128,6 +132,7 @@ test_suite(
     tests = [
         ":test_activation",
         ":test_batch_norm",
+        ":test_instance_norm",
         ":test_cast",
         ":test_clone",
         ":test_concat",
diff --git a/tests/core/conversion/converters/test_instance_norm.cpp b/tests/core/conversion/converters/test_instance_norm.cpp
new file mode 100644
index 0000000000..1df8d45d66
--- /dev/null
+++ b/tests/core/conversion/converters/test_instance_norm.cpp
@@ -0,0 +1,103 @@
+#include <string>
+#include "core/compiler.h"
+#include "gtest/gtest.h"
+#include "tests/util/util.h"
+#include "torch/csrc/jit/ir/irparser.h"
+
+// Tensor instance_norm(
+//     const Tensor& input,
+//     const c10::optional<Tensor>& weight_opt /* optional */,
+//     const c10::optional<Tensor>& bias_opt /* optional */,
+//     const c10::optional<Tensor>& running_mean_opt /* optional */,
+//     const c10::optional<Tensor>& running_var_opt /* optional */,
+//     bool use_input_stats, double momentum, double eps, bool cudnn_enabled)
+constexpr auto graph = R"IR(
+  graph(%input.1 : Tensor,
+        %weight.1 : Tensor?,
+        %bias.1 : Tensor?,
+        %running_mean.1 : Tensor?,
+        %running_var.1 : Tensor?,
+        %use_input_stats.1 : bool):
+    %cudnn_enabled.1 : bool = prim::Constant[value=1]()
+    %momentum.1 : float = prim::Constant[value=0.10000000000000001]()
+    %eps.1 : float = prim::Constant[value=1.0000000000000001e-05]()
+    %4 : Tensor = aten::instance_norm(%input.1,
+        %weight.1, %bias.1,
+        %running_mean.1, %running_var.1,
+        %use_input_stats.1, %momentum.1, %eps.1, %cudnn_enabled.1)
+    return (%4)
+)IR";
+
+TEST(Converters, ATenInstanceNormConvertsCorrectly) {
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});
+  torch::jit::IValue weight, bias, mean, var; // NoneType
+  // https://github.com/pytorch/pytorch/blob/79693bb86a3f601a5c0d3da52d99acec95bb48c1/torch/nn/modules/instancenorm.py#L59
+  const bool use_input_stats = true;
+
+  auto trt_in = at::clone(in);
+  torch::jit::IValue trt_weight, trt_bias, trt_mean, trt_var;
+
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {weight, bias, mean, var, use_input_stats});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
+
+  params = trtorch::core::conversion::get_named_params(
+      g->inputs(), {trt_weight, trt_bias, trt_mean, trt_var, use_input_stats});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
+
+TEST(Converters, ATenInstanceNormAffineConvertsCorrectly) {
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});
+
+  auto weight = at::randn({in.size(1)}).to(at::kCUDA);
+  auto bias = at::randn({in.size(1)}).to(at::kCUDA);
+
+  torch::jit::IValue mean, var; // NoneType
+  const bool use_input_stats = true;
+
+  auto trt_in = at::clone(in);
+  auto trt_weight = at::clone(weight);
+  auto trt_bias = at::clone(bias);
+  torch::jit::IValue trt_mean, trt_var;
+
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {weight, bias, mean, var, use_input_stats});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
+
+  params = trtorch::core::conversion::get_named_params(
+      g->inputs(), {trt_weight, trt_bias, trt_mean, trt_var, use_input_stats});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
+
+TEST(Converters, ATenInstanceNormRunningStatsConvertsCorrectly) {
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randn({1, 5, 5, 5}, {at::kCUDA});
+
+  torch::jit::IValue weight, bias;
+  auto mean = at::zeros({in.size(1)}, {at::kCUDA});
+  auto var = at::ones({in.size(1)}, {at::kCUDA});
+  const bool use_input_stats = false;
+
+  auto trt_in = at::clone(in);
+  torch::jit::IValue trt_weight, trt_bias;
+  auto trt_mean = at::clone(mean);
+  auto trt_var = at::clone(var);
+
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {weight, bias, mean, var, use_input_stats});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
+
+  params = trtorch::core::conversion::get_named_params(
+      g->inputs(), {trt_weight, trt_bias, trt_mean, trt_var, use_input_stats});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}