Add instance norm #573

Merged: 5 commits, Aug 20, 2021
214 changes: 162 additions & 52 deletions core/conversion/converters/impl/batch_norm.cpp
@@ -10,61 +10,171 @@ namespace converters {
namespace impl {
namespace {

auto batch_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().pattern({
R"SIG(aten::batch_norm(Tensor input, Tensor? gamma, Tensor? beta,
void _batch_norm(
ConversionCtx* ctx,
const torch::jit::Node* n,
nvinfer1::ITensor* input,
const nvinfer1::Dims32& orig_shape,
const torch::Tensor& gamma,
const torch::Tensor& beta,
const torch::Tensor& mean,
const torch::Tensor& var,
const float eps) {
auto scale = gamma / torch::sqrt(var + eps);
auto bias = beta - mean * scale;
LOG_DEBUG("_batch_norm Tensor Scale : " << scale.sizes());
LOG_DEBUG("_batch_norm Tensor bias : " << bias.sizes());

auto scale_weights = Weights(ctx, scale);
auto bias_weights = Weights(ctx, bias);

auto power = Weights(ctx, at::ones_like(scale));
auto bn =
ctx->net->addScaleNd(*input, nvinfer1::ScaleMode::kCHANNEL, bias_weights.data, scale_weights.data, power.data, 1);
bn->setName(util::node_info(n).c_str());

// Un-pad bn output if needed
auto out_tensor = addUnpadding(ctx, n, bn->getOutput(0), orig_shape.nbDims);
ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
}

auto batch_norm_registrations TRTORCH_UNUSED =
RegisterNodeConversionPatterns()
.pattern({
R"SIG(aten::batch_norm(Tensor input, Tensor? gamma, Tensor? beta,
Tensor? mean, Tensor? var,
bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor))SIG",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto input = args[0].ITensor(); // assumes non-static input Tensor
auto orig_shape = input->getDimensions();
auto shape = util::toVec(orig_shape);
auto tensor_type = util::TRTDataTypeToScalarType(input->getType());
auto options = torch::TensorOptions().dtype(tensor_type);

torch::Tensor gamma, beta, mean, var;

if (ctx->input_is_dynamic) {
gamma = args[1].unwrapToTensor();
beta = args[2].unwrapToTensor();
mean = args[3].unwrapToTensor();
var = args[4].unwrapToTensor();
} else {
gamma = args[1].unwrapToTensor(at::full({shape}, 1, {options}));
beta = args[2].unwrapToTensor(at::full({shape}, 1, {options}));
mean = args[3].unwrapToTensor(at::full({shape}, 0, {options}));
var = args[4].unwrapToTensor(at::full({shape}, 0, {options}));
}

auto eps = args[7].unwrapToDouble(1e-5f);

LOG_DEBUG("momentum disregarded");
LOG_DEBUG("training disregarded");
LOG_DEBUG("cudnn disregarded");
TRTORCH_CHECK(orig_shape.nbDims > 2, "Unable to create batch normalization layer from node: " << *n);

// Expand spatial dims from 1D to 2D if needed
bool expandDims = (orig_shape.nbDims < 4);

if (expandDims) {
input = addPadding(ctx, n, input, 4);
}

auto scale = gamma / torch::sqrt(var + eps);
auto bias = beta - mean * scale;

auto scale_weights = Weights(ctx, scale);
auto bias_weights = Weights(ctx, bias);

auto power = Weights(ctx, at::ones_like(scale));
auto bn = ctx->net->addScaleNd(
*input, nvinfer1::ScaleMode::kCHANNEL, bias_weights.data, scale_weights.data, power.data, 1);
bn->setName(util::node_info(n).c_str());
// Un-pad bn output if needed
auto out_tensor = addUnpadding(ctx, n, bn->getOutput(0), orig_shape.nbDims);
ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
return true;
}});
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto input = args[0].ITensor(); // assumes non-static input Tensor
auto orig_shape = input->getDimensions();
auto shape = util::toVec(orig_shape);
auto tensor_type = util::TRTDataTypeToScalarType(input->getType());
auto options = torch::TensorOptions().dtype(tensor_type);

torch::Tensor gamma, beta, mean, var;

if (ctx->input_is_dynamic) {
gamma = args[1].unwrapToTensor();
beta = args[2].unwrapToTensor();
mean = args[3].unwrapToTensor();
var = args[4].unwrapToTensor();
} else {
gamma = args[1].unwrapToTensor(at::full({shape}, 1, {options}));
beta = args[2].unwrapToTensor(at::full({shape}, 1, {options}));
mean = args[3].unwrapToTensor(at::full({shape}, 0, {options}));
var = args[4].unwrapToTensor(at::full({shape}, 0, {options}));
}

auto eps = static_cast<float>(args[7].unwrapToDouble(1e-5f));

LOG_DEBUG("momentum disregarded");
LOG_DEBUG("training disregarded");
LOG_DEBUG("cudnn disregarded");
TRTORCH_CHECK(orig_shape.nbDims > 2, "Unable to create batch normalization layer from node: " << *n);

// Expand spatial dims from 1D to 2D if needed
bool expandDims = (orig_shape.nbDims < 4);
if (expandDims) {
input = addPadding(ctx, n, input, 4);
}

_batch_norm(ctx, n, input, orig_shape, gamma, beta, mean, var, eps);

return true;
}})
.pattern({
R"SIG(aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias,
Tensor? running_mean, Tensor? running_var,
bool use_input_stats, float momentum, float eps,
bool cudnn_enabled) -> (Tensor))SIG",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto input = args[0].ITensorOrFreeze(ctx);
auto orig_shape = input->getDimensions();
auto shape = util::toVec(orig_shape);
auto tensor_type = util::TRTDataTypeToScalarType(input->getType());
auto options = torch::TensorOptions().dtype(tensor_type);

LOG_DEBUG("Input :" << orig_shape << "/" << input->getType());
// affine=True
LOG_DEBUG("Args[1] weight : " << args[1].isIValue() << " / " << args[1].IValue()->isNone());
LOG_DEBUG("Args[2] bias : " << args[2].isIValue() << " / " << args[2].IValue()->isNone());
// track_running_stats=True
LOG_DEBUG("Args[3] running_mean : " << args[3].isIValue() << " / " << args[3].IValue()->isNone());
LOG_DEBUG("Args[4] running_var : " << args[4].isIValue() << " / " << args[4].IValue()->isNone());
LOG_DEBUG("use_input_stats, momemtum, cudnn_enabled disregarded");
LOG_DEBUG("ctx->input_is_dynamic : " << ctx->input_is_dynamic);

// Expand spatial dims from 1D to 2D if needed
bool expandDims = (orig_shape.nbDims < 4);
if (expandDims) {
input = addPadding(ctx, n, input, 4);
}

auto eps = static_cast<float>(args[7].unwrapToDouble(1e-5f));

auto scales = args[1].unwrapToTensor(at::ones(shape[1], options)).cpu().contiguous();
auto bias = args[2].unwrapToTensor(at::zeros(shape[1], options)).cpu().contiguous();

// track_running_stats=True
if (!args[3].IValue()->isNone() || !args[4].IValue()->isNone()) {
auto running_mean = args[3].unwrapToTensor();
auto running_var = args[4].unwrapToTensor();
_batch_norm(
ctx,
n,
input,
orig_shape,
scales.to(running_mean.options()),
bias.to(running_mean.options()),
running_mean,
running_var,
eps);
return true;
}

const int relu = 0;
const float alpha = 0;
LOG_DEBUG("Set parameter `relu` and `alpha` to 0");
/*
https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/namespacenvinfer1.html
https://github.com/NVIDIA/TensorRT/tree/8.0.1/plugin/instanceNormalizationPlugin
Type Parameter Description
float epsilon A small number to prevent being divided by zero during normalization.
Weights * scale A pointer to weights which contains information about scale factors for
normalization. The definition of Weights can be found in the NvInfer.h header.
Weights * bias A pointer to weights which contains information about the bias values for
normalization. The definition of Weights can be found in the NvInfer.h header.
int relu A value used to enable leaky relu activation
float alpha A small negative slope for the leaky relu activation
*/
std::vector<nvinfer1::PluginField> f;
f.emplace_back(nvinfer1::PluginField("epsilon", &eps, nvinfer1::PluginFieldType::kFLOAT32, 1));
f.emplace_back(nvinfer1::PluginField(
"scales", scales.data_ptr<float>(), nvinfer1::PluginFieldType::kFLOAT32, scales.numel()));
f.emplace_back(nvinfer1::PluginField(
"bias", bias.data_ptr<float>(), nvinfer1::PluginFieldType::kFLOAT32, bias.numel()));
f.emplace_back(nvinfer1::PluginField("relu", &relu, nvinfer1::PluginFieldType::kINT32, 1));
f.emplace_back(nvinfer1::PluginField("alpha", &alpha, nvinfer1::PluginFieldType::kFLOAT32, 1));

nvinfer1::PluginFieldCollection fc;
fc.nbFields = f.size();
fc.fields = f.data();

auto creator = getPluginRegistry()->getPluginCreator("InstanceNormalization_TRT", "1", "");
auto instance_norm_plugin = creator->createPlugin("instance_norm", &fc);

TRTORCH_CHECK(
instance_norm_plugin, "Unable to create instance_norm plugin from TensorRT plugin registry" << *n);

auto new_layer =
ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&input), 1, *instance_norm_plugin);

new_layer->setName(util::node_info(n).c_str());
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
return true;
}});
} // namespace
} // namespace impl
} // namespace converters
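For reference, the _batch_norm helper above folds the normalization statistics and the affine parameters into a single per-channel scale and shift, which addScaleNd then applies as scale * x + shift (the power weights are all ones, so the exponent is a no-op). A minimal standalone libtorch sketch of that algebra follows; it is not part of the PR, and the shapes, values, and eps are illustrative only:

#include <iostream>
#include <torch/torch.h>

// Standalone sketch (not part of the PR): verify that the folded
// scale/shift used by _batch_norm matches textbook batch norm.
int main() {
  const float eps = 1e-5f;
  auto x = torch::randn({2, 4, 8, 8});     // NCHW input
  auto gamma = torch::randn({4});          // per-channel affine weight
  auto beta = torch::randn({4});           // per-channel affine bias
  auto mean = torch::randn({4});           // running mean
  auto var = torch::rand({4}) + 0.5f;      // running variance, kept positive

  // Folded per-channel parameters, mirroring _batch_norm
  auto scale = gamma / torch::sqrt(var + eps);
  auto bias = beta - mean * scale;

  // Textbook batch norm, broadcast over the channel dimension
  auto ref = (x - mean.view({1, 4, 1, 1})) / torch::sqrt(var.view({1, 4, 1, 1}) + eps) *
                 gamma.view({1, 4, 1, 1}) +
             beta.view({1, 4, 1, 1});
  // Same computation expressed as one scale/shift, as the TensorRT layer sees it
  auto folded = x * scale.view({1, 4, 1, 1}) + bias.view({1, 4, 1, 1});

  std::cout << torch::allclose(ref, folded, 1e-5, 1e-6) << std::endl;  // expect 1
  return 0;
}

The new instance_norm pattern reuses the same helper whenever running_mean/running_var are provided; only the stat-free case falls back to the InstanceNormalization_TRT plugin.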
5 changes: 5 additions & 0 deletions tests/core/conversion/converters/BUILD
@@ -15,6 +15,10 @@ converter_test(
name = "test_batch_norm",
)

converter_test(
name = "test_instance_norm",
)

converter_test(
name = "test_cast",
)
@@ -128,6 +132,7 @@ test_suite(
tests = [
":test_activation",
":test_batch_norm",
":test_instance_norm",
":test_cast",
":test_clone",
":test_concat",
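With these BUILD additions, the new converter test should be runnable in isolation, presumably as bazel test //tests/core/conversion/converters:test_instance_norm (target path inferred from the BUILD file location), as well as through the updated converter test_suite.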
103 changes: 103 additions & 0 deletions tests/core/conversion/converters/test_instance_norm.cpp
@@ -0,0 +1,103 @@
#include <string>
#include "core/compiler.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h"

// Tensor instance_norm(
// const Tensor& input,
// const c10::optional<Tensor>& weight_opt /* optional */,
// const c10::optional<Tensor>& bias_opt /* optional */,
// const c10::optional<Tensor>& running_mean_opt /* optional */,
// const c10::optional<Tensor>& running_var_opt /* optional */,
// bool use_input_stats, double momentum, double eps, bool cudnn_enabled)
constexpr auto graph = R"IR(
graph(%input.1 : Tensor,
%weight.1 : Tensor?,
%bias.1 : Tensor?,
%running_mean.1 : Tensor?,
%running_var.1 : Tensor?,
%use_input_stats.1 : bool):
%cudnn_enabled.1 : bool = prim::Constant[value=1]()
%momentum.1 : float = prim::Constant[value=0.10000000000000001]()
%eps.1 : float = prim::Constant[value=1.0000000000000001e-05]()
%4 : Tensor = aten::instance_norm(%input.1,
%weight.1, %bias.1,
%running_mean.1, %running_var.1,
%use_input_stats.1, %momentum.1, %eps.1, %cudnn_enabled.1)
return (%4)
)IR";

TEST(Converters, ATenInstanceNormConvertsCorrectly) {
auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});
torch::jit::IValue weight, bias, mean, var; // NoneType
// https://github.com/pytorch/pytorch/blob/79693bb86a3f601a5c0d3da52d99acec95bb48c1/torch/nn/modules/instancenorm.py#L59
const bool use_input_stats = true;

auto trt_in = at::clone(in);
torch::jit::IValue trt_weight, trt_bias, trt_mean, trt_var;

auto params = trtorch::core::conversion::get_named_params(g->inputs(), {weight, bias, mean, var, use_input_stats});
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

params = trtorch::core::conversion::get_named_params(
g->inputs(), {trt_weight, trt_bias, trt_mean, trt_var, use_input_stats});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

TEST(Converters, ATenInstanceNormAffineConvertsCorrectly) {
auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});

auto weight = at::randn({in.size(1)}).to(at::kCUDA);
auto bias = at::randn({in.size(1)}).to(at::kCUDA);

torch::jit::IValue mean, var; // NoneType
const bool use_input_stats = true;

auto trt_in = at::clone(in);
auto trt_weight = at::clone(weight);
auto trt_bias = at::clone(bias);
torch::jit::IValue trt_mean, trt_var;

auto params = trtorch::core::conversion::get_named_params(g->inputs(), {weight, bias, mean, var, use_input_stats});
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

params = trtorch::core::conversion::get_named_params(
g->inputs(), {trt_weight, trt_bias, trt_mean, trt_var, use_input_stats});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

TEST(Converters, ATenInstanceNormRunningStatsConvertsCorrectly) {
auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto in = at::randn({1, 5, 5, 5}, {at::kCUDA});

torch::jit::IValue weight, bias;
auto mean = at::zeros({in.size(1)}, {at::kCUDA});
auto var = at::ones({in.size(1)}, {at::kCUDA});
const bool use_input_stats = false;

auto trt_in = at::clone(in);
torch::jit::IValue trt_weight, trt_bias;
auto trt_mean = at::clone(mean);
auto trt_var = at::clone(var);

auto params = trtorch::core::conversion::get_named_params(g->inputs(), {weight, bias, mean, var, use_input_stats});
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

params = trtorch::core::conversion::get_named_params(
g->inputs(), {trt_weight, trt_bias, trt_mean, trt_var, use_input_stats});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}