Skip to content

Commit dcc289e

Browse files
committed
add instance norm
1 parent 8f1f9cd commit dcc289e

File tree

3 files changed

+245
-52
lines changed

3 files changed

+245
-52
lines changed

core/conversion/converters/impl/batch_norm.cpp

Lines changed: 154 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -10,61 +10,163 @@ namespace converters {
1010
namespace impl {
1111
namespace {
1212

13-
auto batch_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().pattern({
14-
R"SIG(aten::batch_norm(Tensor input, Tensor? gamma, Tensor? beta,
13+
void _batch_norm(
14+
ConversionCtx* ctx,
15+
const torch::jit::Node* n,
16+
nvinfer1::ITensor* input,
17+
const nvinfer1::Dims32& orig_shape,
18+
const torch::Tensor& gamma,
19+
const torch::Tensor& beta,
20+
const torch::Tensor& mean,
21+
const torch::Tensor& var,
22+
const float eps) {
23+
auto scale = gamma / torch::sqrt(var + eps);
24+
auto bias = beta - mean * scale;
25+
LOG_DEBUG("_batch_norm Tensor Scale : " << scale.sizes());
26+
LOG_DEBUG("_batch_norm Tensor bias : " << bias.sizes());
27+
28+
auto scale_weights = Weights(ctx, scale);
29+
auto bias_weights = Weights(ctx, bias);
30+
31+
auto power = Weights(ctx, at::ones_like(scale));
32+
auto bn =
33+
ctx->net->addScaleNd(*input, nvinfer1::ScaleMode::kCHANNEL, bias_weights.data, scale_weights.data, power.data, 1);
34+
bn->setName(util::node_info(n).c_str());
35+
36+
// Un-pad bn output if needed
37+
auto out_tensor = addUnpadding(ctx, n, bn->getOutput(0), orig_shape.nbDims);
38+
ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
39+
}
40+
41+
auto batch_norm_registrations TRTORCH_UNUSED =
42+
RegisterNodeConversionPatterns()
43+
.pattern({
44+
R"SIG(aten::batch_norm(Tensor input, Tensor? gamma, Tensor? beta,
1545
Tensor? mean, Tensor? var,
1646
bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor))SIG",
17-
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
18-
auto input = args[0].ITensor(); // assumes non-static input Tensor
19-
auto orig_shape = input->getDimensions();
20-
auto shape = util::toVec(orig_shape);
21-
auto tensor_type = util::TRTDataTypeToScalarType(input->getType());
22-
auto options = torch::TensorOptions().dtype(tensor_type);
23-
24-
torch::Tensor gamma, beta, mean, var;
25-
26-
if (ctx->input_is_dynamic) {
27-
gamma = args[1].unwrapToTensor();
28-
beta = args[2].unwrapToTensor();
29-
mean = args[3].unwrapToTensor();
30-
var = args[4].unwrapToTensor();
31-
} else {
32-
gamma = args[1].unwrapToTensor(at::full({shape}, 1, {options}));
33-
beta = args[2].unwrapToTensor(at::full({shape}, 1, {options}));
34-
mean = args[3].unwrapToTensor(at::full({shape}, 0, {options}));
35-
var = args[4].unwrapToTensor(at::full({shape}, 0, {options}));
36-
}
37-
38-
auto eps = args[7].unwrapToDouble(1e-5f);
39-
40-
LOG_DEBUG("momentum disregarded");
41-
LOG_DEBUG("training disregarded");
42-
LOG_DEBUG("cudnn disregarded");
43-
TRTORCH_CHECK(orig_shape.nbDims > 2, "Unable to create batch normalization layer from node: " << *n);
44-
45-
// Expand spatial dims from 1D to 2D if needed
46-
bool expandDims = (orig_shape.nbDims < 4);
47-
48-
if (expandDims) {
49-
input = addPadding(ctx, n, input, 4);
50-
}
51-
52-
auto scale = gamma / torch::sqrt(var + eps);
53-
auto bias = beta - mean * scale;
54-
55-
auto scale_weights = Weights(ctx, scale);
56-
auto bias_weights = Weights(ctx, bias);
57-
58-
auto power = Weights(ctx, at::ones_like(scale));
59-
auto bn = ctx->net->addScaleNd(
60-
*input, nvinfer1::ScaleMode::kCHANNEL, bias_weights.data, scale_weights.data, power.data, 1);
61-
bn->setName(util::node_info(n).c_str());
62-
// Un-pad bn output if needed
63-
auto out_tensor = addUnpadding(ctx, n, bn->getOutput(0), orig_shape.nbDims);
64-
ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
65-
return true;
66-
}});
47+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
48+
auto input = args[0].ITensor(); // assumes non-static input Tensor
49+
auto orig_shape = input->getDimensions();
50+
auto shape = util::toVec(orig_shape);
51+
auto tensor_type = util::TRTDataTypeToScalarType(input->getType());
52+
auto options = torch::TensorOptions().dtype(tensor_type);
53+
54+
torch::Tensor gamma, beta, mean, var;
55+
56+
if (ctx->input_is_dynamic) {
57+
gamma = args[1].unwrapToTensor();
58+
beta = args[2].unwrapToTensor();
59+
mean = args[3].unwrapToTensor();
60+
var = args[4].unwrapToTensor();
61+
} else {
62+
gamma = args[1].unwrapToTensor(at::full({shape}, 1, {options}));
63+
beta = args[2].unwrapToTensor(at::full({shape}, 1, {options}));
64+
mean = args[3].unwrapToTensor(at::full({shape}, 0, {options}));
65+
var = args[4].unwrapToTensor(at::full({shape}, 0, {options}));
66+
}
67+
68+
auto eps = static_cast<float>(args[7].unwrapToDouble(1e-5f));
69+
70+
LOG_DEBUG("momentum disregarded");
71+
LOG_DEBUG("training disregarded");
72+
LOG_DEBUG("cudnn disregarded");
73+
TRTORCH_CHECK(orig_shape.nbDims > 2, "Unable to create batch normalization layer from node: " << *n);
74+
75+
// Expand spatial dims from 1D to 2D if needed
76+
bool expandDims = (orig_shape.nbDims < 4);
77+
if (expandDims) {
78+
input = addPadding(ctx, n, input, 4);
79+
}
80+
81+
_batch_norm(ctx, n, input, orig_shape, gamma, beta, mean, var, eps);
82+
83+
return true;
84+
}})
85+
.pattern({
86+
R"SIG(aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias,
87+
Tensor? running_mean, Tensor? running_var,
88+
bool use_input_stats, float momentum, float eps,
89+
bool cudnn_enabled) -> (Tensor))SIG",
90+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
91+
auto input = args[0].ITensorOrFreeze(ctx);
92+
auto orig_shape = input->getDimensions();
93+
auto shape = util::toVec(orig_shape);
94+
auto tensor_type = util::TRTDataTypeToScalarType(input->getType());
95+
auto options = torch::TensorOptions().dtype(tensor_type);
96+
97+
LOG_DEBUG("Input :" << orig_shape << "/" << input->getType());
98+
// affine=True
99+
LOG_DEBUG("Args[1] weight : " << args[1].isIValue() << " / " << args[1].IValue()->isNone());
100+
LOG_DEBUG("Args[2] bias : " << args[2].isIValue() << " / " << args[2].IValue()->isNone());
101+
// track_running_stats=True
102+
LOG_DEBUG("Args[3] running_mean : " << args[3].isIValue() << " / " << args[3].IValue()->isNone());
103+
LOG_DEBUG("Args[4] running_var : " << args[4].isIValue() << " / " << args[4].IValue()->isNone());
104+
105+
LOG_DEBUG("use_input_stats, momentum, cudnn_enabled disregarded");
106+
LOG_DEBUG("ctx->input_is_dynamic : " << ctx->input_is_dynamic);
107+
108+
// Expand spatial dims from 1D to 2D if needed
109+
bool expandDims = (orig_shape.nbDims < 4);
110+
if (expandDims) {
111+
input = addPadding(ctx, n, input, 4);
112+
}
113+
114+
auto eps = static_cast<float>(args[7].unwrapToDouble(1e-5f));
115+
116+
auto scales = args[1].unwrapToTensor(at::ones(shape[1], options)).cpu().contiguous();
117+
auto bias = args[2].unwrapToTensor(at::zeros(shape[1], options)).cpu().contiguous();
118+
LOG_DEBUG("Scales : " << scales);
119+
LOG_DEBUG("bias : " << bias);
120+
121+
// track_running_stats=True
122+
if (!args[3].IValue()->isNone() || !args[4].IValue()->isNone()) {
123+
auto running_mean = args[3].unwrapToTensor().cpu().contiguous();
124+
auto running_var = args[4].unwrapToTensor().cpu().contiguous();
125+
_batch_norm(ctx, n, input, orig_shape, scales, bias, running_mean, running_var, eps);
126+
return true;
127+
}
128+
129+
const int relu = 0;
130+
const float alpha = 0;
131+
LOG_DEBUG("Set parameter `relu` and `alpha` to 0");
132+
/*
133+
https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/namespacenvinfer1.html
134+
https://github.com/NVIDIA/TensorRT/tree/8.0.1/plugin/instanceNormalizationPlugin
135+
Type Parameter Description
136+
float epsilon A small number to prevent being divided by zero during normalization.
137+
Weights * scale A pointer to weights which contains information about scale factors for
138+
normalization. The definition of Weights can be found in the NvInfer.h header.
139+
Weights * bias A pointer to weights which contains information about the bias values for
140+
normalization. The definition of Weights can be found in the NvInfer.h header.
141+
int relu A value used to enable leaky relu activation
142+
float alpha A small negative slope for the leaky relu activation
143+
*/
144+
std::vector<nvinfer1::PluginField> f;
145+
f.emplace_back(nvinfer1::PluginField("epsilon", &eps, nvinfer1::PluginFieldType::kFLOAT32, 1));
146+
f.emplace_back(nvinfer1::PluginField(
147+
"scales", scales.data_ptr<float>(), nvinfer1::PluginFieldType::kFLOAT32, scales.numel()));
148+
f.emplace_back(nvinfer1::PluginField(
149+
"bias", bias.data_ptr<float>(), nvinfer1::PluginFieldType::kFLOAT32, bias.numel()));
150+
f.emplace_back(nvinfer1::PluginField("relu", &relu, nvinfer1::PluginFieldType::kINT32, 1));
151+
f.emplace_back(nvinfer1::PluginField("alpha", &alpha, nvinfer1::PluginFieldType::kFLOAT32, 1));
152+
153+
nvinfer1::PluginFieldCollection fc;
154+
fc.nbFields = f.size();
155+
fc.fields = f.data();
156+
157+
auto creator = getPluginRegistry()->getPluginCreator("InstanceNormalization_TRT", "1", "");
158+
auto instance_norm_plugin = creator->createPlugin("instance_norm", &fc);
159+
160+
TRTORCH_CHECK(
161+
instance_norm_plugin, "Unable to create instance_norm plugin from TensorRT plugin registry" << *n);
162+
163+
auto new_layer =
164+
ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&input), 1, *instance_norm_plugin);
67165

166+
new_layer->setName(util::node_info(n).c_str());
167+
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
168+
return true;
169+
}});
68170
} // namespace
69171
} // namespace impl
70172
} // namespace converters

tests/core/conversion/converters/BUILD

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ converter_test(
1515
name = "test_batch_norm",
1616
)
1717

18+
converter_test(
19+
name = "test_instance_norm",
20+
)
21+
1822
converter_test(
1923
name = "test_clone",
2024
)
@@ -120,6 +124,7 @@ test_suite(
120124
tests = [
121125
":test_activation",
122126
":test_batch_norm",
127+
":test_instance_norm",
123128
":test_clone",
124129
":test_concat",
125130
":test_constant_pad",
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
#include <string>
2+
#include "core/compiler.h"
3+
#include "gtest/gtest.h"
4+
#include "tests/util/util.h"
5+
#include "torch/csrc/jit/ir/irparser.h"
6+
7+
// Tensor instance_norm(
8+
// const Tensor& input,
9+
// const c10::optional<Tensor>& weight_opt /* optional */,
10+
// const c10::optional<Tensor>& bias_opt /* optional */,
11+
// const c10::optional<Tensor>& running_mean_opt /* optional */,
12+
// const c10::optional<Tensor>& running_var_opt /* optional */,
13+
// bool use_input_stats, double momentum, double eps, bool cudnn_enabled)
14+
inline constexpr auto graph = R"IR(
15+
graph(%0 : Tensor,
16+
%1 : Tensor?,
17+
%2 : Tensor?,
18+
%3 : Tensor?,
19+
%4 : Tensor?,
20+
%5 : bool):
21+
%9 : bool = prim::Constant[value=0]()
22+
%6 : float = prim::Constant[value=0.10000000000000001]()
23+
%7 : float = prim::Constant[value=1.0000000000000001e-05]()
24+
%8 : Tensor = aten::instance_norm(%0, %1, %2, %3, %4, %5, %6, %7, %9)
25+
return (%8)
26+
)IR";
27+
28+
TEST(Converters, ATenInstanceNormConvertsCorrectly) {
29+
auto g = std::make_shared<torch::jit::Graph>();
30+
torch::jit::parseIR(graph, g.get());
31+
32+
auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});
33+
torch::jit::IValue weight, bias, mean, var; // NoneType
34+
bool use_input_stats = true;
35+
36+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {weight, bias, mean, var, use_input_stats});
37+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
38+
39+
params = trtorch::core::conversion::get_named_params(g->inputs(), {weight, bias, mean, var, use_input_stats});
40+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
41+
42+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
43+
}
44+
45+
TEST(Converters, ATenInstanceNormAffineConvertsCorrectly) {
46+
auto g = std::make_shared<torch::jit::Graph>();
47+
torch::jit::parseIR(graph, g.get());
48+
49+
auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});
50+
51+
auto weight = at::randn({in.size(1)}).to(at::kCUDA);
52+
auto bias = at::randn({in.size(1)}).to(at::kCUDA);
53+
54+
torch::jit::IValue mean, var; // NoneType
55+
bool use_input_stats = true;
56+
57+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {weight, bias, mean, var, use_input_stats});
58+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
59+
60+
params = trtorch::core::conversion::get_named_params(g->inputs(), {weight, bias, mean, var, use_input_stats});
61+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
62+
63+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
64+
}
65+
66+
67+
TEST(Converters, ATenInstanceNormRunningStatsConvertsCorrectly) {
68+
auto g = std::make_shared<torch::jit::Graph>();
69+
torch::jit::parseIR(graph, g.get());
70+
71+
auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});
72+
73+
torch::jit::IValue weight, bias; // NoneType
74+
75+
auto mean = at::randn({in.size(1)}).to(at::kCUDA);
76+
auto var = at::randn({in.size(1)}).to(at::kCUDA);
77+
bool use_input_stats = false;
78+
79+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {weight, bias, mean, var, use_input_stats});
80+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
81+
82+
params = trtorch::core::conversion::get_named_params(g->inputs(), {weight, bias, mean, var, use_input_stats});
83+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
84+
85+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
86+
}

0 commit comments

Comments
 (0)