@@ -10,61 +10,163 @@ namespace converters {
namespace impl {
namespace {

-auto batch_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().pattern({
-    R"SIG(aten::batch_norm(Tensor input, Tensor? gamma, Tensor? beta,
+void _batch_norm(
+    ConversionCtx* ctx,
+    const torch::jit::Node* n,
+    nvinfer1::ITensor* input,
+    const nvinfer1::Dims32& orig_shape,
+    const torch::Tensor& gamma,
+    const torch::Tensor& beta,
+    const torch::Tensor& mean,
+    const torch::Tensor& var,
+    const float eps) {
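+  // Batch norm folds into one per-channel affine transform:
+  //   y = x * scale + bias, with scale = gamma / sqrt(var + eps) and bias = beta - mean * scale.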
+  auto scale = gamma / torch::sqrt(var + eps);
+  auto bias = beta - mean * scale;
+  LOG_DEBUG("_batch_norm Tensor Scale : " << scale.sizes());
+  LOG_DEBUG("_batch_norm Tensor bias : " << bias.sizes());
+
+  auto scale_weights = Weights(ctx, scale);
+  auto bias_weights = Weights(ctx, bias);
+
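+  // TensorRT's scale layer computes (x * scale + shift)^power per channel,
+  // so power is pinned to ones and the folded parameters go in scale/shift.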
+  auto power = Weights(ctx, at::ones_like(scale));
+  auto bn =
+      ctx->net->addScaleNd(*input, nvinfer1::ScaleMode::kCHANNEL, bias_weights.data, scale_weights.data, power.data, 1);
+  bn->setName(util::node_info(n).c_str());
+
+  // Un-pad bn output if needed
+  auto out_tensor = addUnpadding(ctx, n, bn->getOutput(0), orig_shape.nbDims);
+  ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
+}
+
+auto batch_norm_registrations TRTORCH_UNUSED =
+    RegisterNodeConversionPatterns()
+        .pattern({
+            R"SIG(aten::batch_norm(Tensor input, Tensor? gamma, Tensor? beta,
                        Tensor? mean, Tensor? var,
                        bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor))SIG",
-    [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-      auto input = args[0].ITensor(); // assumes non-static input Tensor
-      auto orig_shape = input->getDimensions();
-      auto shape = util::toVec(orig_shape);
-      auto tensor_type = util::TRTDataTypeToScalarType(input->getType());
-      auto options = torch::TensorOptions().dtype(tensor_type);
-
-      torch::Tensor gamma, beta, mean, var;
-
-      if (ctx->input_is_dynamic) {
-        gamma = args[1].unwrapToTensor();
-        beta = args[2].unwrapToTensor();
-        mean = args[3].unwrapToTensor();
-        var = args[4].unwrapToTensor();
-      } else {
-        gamma = args[1].unwrapToTensor(at::full({shape}, 1, {options}));
-        beta = args[2].unwrapToTensor(at::full({shape}, 1, {options}));
-        mean = args[3].unwrapToTensor(at::full({shape}, 0, {options}));
-        var = args[4].unwrapToTensor(at::full({shape}, 0, {options}));
-      }
-
-      auto eps = args[7].unwrapToDouble(1e-5f);
-
-      LOG_DEBUG("momentum disregarded");
-      LOG_DEBUG("training disregarded");
-      LOG_DEBUG("cudnn disregarded");
-      TRTORCH_CHECK(orig_shape.nbDims > 2, "Unable to create batch normalization layer from node: " << *n);
-
-      // Expand spatial dims from 1D to 2D if needed
-      bool expandDims = (orig_shape.nbDims < 4);
-
-      if (expandDims) {
-        input = addPadding(ctx, n, input, 4);
-      }
-
-      auto scale = gamma / torch::sqrt(var + eps);
-      auto bias = beta - mean * scale;
-
-      auto scale_weights = Weights(ctx, scale);
-      auto bias_weights = Weights(ctx, bias);
-
-      auto power = Weights(ctx, at::ones_like(scale));
-      auto bn = ctx->net->addScaleNd(
-          *input, nvinfer1::ScaleMode::kCHANNEL, bias_weights.data, scale_weights.data, power.data, 1);
-      bn->setName(util::node_info(n).c_str());
-      // Un-pad bn output if needed
-      auto out_tensor = addUnpadding(ctx, n, bn->getOutput(0), orig_shape.nbDims);
-      ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
-      return true;
-    }});
+            [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+              auto input = args[0].ITensor(); // assumes non-static input Tensor
+              auto orig_shape = input->getDimensions();
+              auto shape = util::toVec(orig_shape);
+              auto tensor_type = util::TRTDataTypeToScalarType(input->getType());
+              auto options = torch::TensorOptions().dtype(tensor_type);
+
+              torch::Tensor gamma, beta, mean, var;
+
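+              // With dynamic input shapes the concrete dimensions are unknown here, so
+              // default gamma/beta/mean/var tensors cannot be built from the shape and
+              // must be supplied by the node.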
+              if (ctx->input_is_dynamic) {
+                gamma = args[1].unwrapToTensor();
+                beta = args[2].unwrapToTensor();
+                mean = args[3].unwrapToTensor();
+                var = args[4].unwrapToTensor();
+              } else {
+                gamma = args[1].unwrapToTensor(at::full({shape}, 1, {options}));
+                beta = args[2].unwrapToTensor(at::full({shape}, 1, {options}));
+                mean = args[3].unwrapToTensor(at::full({shape}, 0, {options}));
+                var = args[4].unwrapToTensor(at::full({shape}, 0, {options}));
+              }
+
+              auto eps = static_cast<float>(args[7].unwrapToDouble(1e-5f));
+
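+              // Conversion targets inference, so the training-only arguments below are ignored.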
+              LOG_DEBUG("momentum disregarded");
+              LOG_DEBUG("training disregarded");
+              LOG_DEBUG("cudnn disregarded");
+              TRTORCH_CHECK(orig_shape.nbDims > 2, "Unable to create batch normalization layer from node: " << *n);
+
+              // Expand spatial dims from 1D to 2D if needed
+              bool expandDims = (orig_shape.nbDims < 4);
+              if (expandDims) {
+                input = addPadding(ctx, n, input, 4);
+              }
+
+              _batch_norm(ctx, n, input, orig_shape, gamma, beta, mean, var, eps);
+
+              return true;
+            }})
+        .pattern({
+            R"SIG(aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias,
+                        Tensor? running_mean, Tensor? running_var,
+                        bool use_input_stats, float momentum, float eps,
+                        bool cudnn_enabled) -> (Tensor))SIG",
+            [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+              auto input = args[0].ITensorOrFreeze(ctx);
+              auto orig_shape = input->getDimensions();
+              auto shape = util::toVec(orig_shape);
+              auto tensor_type = util::TRTDataTypeToScalarType(input->getType());
+              auto options = torch::TensorOptions().dtype(tensor_type);
+
+              LOG_DEBUG("Input : " << orig_shape << " / " << input->getType());
+              // affine=True
+              LOG_DEBUG("Args[1] weight : " << args[1].isIValue() << " / " << args[1].IValue()->isNone());
+              LOG_DEBUG("Args[2] bias : " << args[2].isIValue() << " / " << args[2].IValue()->isNone());
+              // track_running_stats=True
+              LOG_DEBUG("Args[3] running_mean : " << args[3].isIValue() << " / " << args[3].IValue()->isNone());
+              LOG_DEBUG("Args[4] running_var : " << args[4].isIValue() << " / " << args[4].IValue()->isNone());
+
+              LOG_DEBUG("use_input_stats, momentum, cudnn_enabled disregarded");
+              LOG_DEBUG("ctx->input_is_dynamic : " << ctx->input_is_dynamic);
+
+              // Expand spatial dims from 1D to 2D if needed
+              bool expandDims = (orig_shape.nbDims < 4);
+              if (expandDims) {
+                input = addPadding(ctx, n, input, 4);
+              }
+
+              auto eps = static_cast<float>(args[7].unwrapToDouble(1e-5f));
+
+              auto scales = args[1].unwrapToTensor(at::ones(shape[1], options)).cpu().contiguous();
+              auto bias = args[2].unwrapToTensor(at::zeros(shape[1], options)).cpu().contiguous();
+              LOG_DEBUG("Scales : " << scales);
+              LOG_DEBUG("bias : " << bias);
+
+              // track_running_stats=True
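+              // With recorded running statistics, instance norm reduces to the same
+              // per-channel scale-layer folding used for batch norm above.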
+              if (!args[3].IValue()->isNone() || !args[4].IValue()->isNone()) {
+                auto running_mean = args[3].unwrapToTensor().cpu().contiguous();
+                auto running_var = args[4].unwrapToTensor().cpu().contiguous();
+                _batch_norm(ctx, n, input, orig_shape, scales, bias, running_mean, running_var, eps);
+                return true;
+              }
+
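+              // No running stats: mean/var must be computed per sample at runtime, which the
+              // scale layer cannot express, so fall back to TensorRT's InstanceNormalization_TRT plugin.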
+              const int relu = 0;
+              const float alpha = 0;
+              LOG_DEBUG("Set parameter `relu` and `alpha` to 0");
+              /*
+              https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/namespacenvinfer1.html
+              https://github.com/NVIDIA/TensorRT/tree/8.0.1/plugin/instanceNormalizationPlugin
+              Type      Parameter  Description
+              float     epsilon    A small number to prevent being divided by zero during normalization.
+              Weights*  scale      A pointer to weights which contains information about scale factors for
+                                   normalization. The definition of Weights can be found in the NvInfer.h header.
+              Weights*  bias       A pointer to weights which contains information about the bias values for
+                                   normalization. The definition of Weights can be found in the NvInfer.h header.
+              int       relu       A value used to enable leaky relu activation
+              float     alpha      A small negative slope for the leaky relu activation
+              */
+              std::vector<nvinfer1::PluginField> f;
+              f.emplace_back(nvinfer1::PluginField("epsilon", &eps, nvinfer1::PluginFieldType::kFLOAT32, 1));
+              f.emplace_back(nvinfer1::PluginField(
+                  "scales", scales.data_ptr<float>(), nvinfer1::PluginFieldType::kFLOAT32, scales.numel()));
+              f.emplace_back(nvinfer1::PluginField(
+                  "bias", bias.data_ptr<float>(), nvinfer1::PluginFieldType::kFLOAT32, bias.numel()));
+              f.emplace_back(nvinfer1::PluginField("relu", &relu, nvinfer1::PluginFieldType::kINT32, 1));
+              f.emplace_back(nvinfer1::PluginField("alpha", &alpha, nvinfer1::PluginFieldType::kFLOAT32, 1));
+
+              nvinfer1::PluginFieldCollection fc;
+              fc.nbFields = f.size();
+              fc.fields = f.data();
+
+              auto creator = getPluginRegistry()->getPluginCreator("InstanceNormalization_TRT", "1", "");
+              auto instance_norm_plugin = creator->createPlugin("instance_norm", &fc);
+
+              TRTORCH_CHECK(
+                  instance_norm_plugin, "Unable to create instance_norm plugin from TensorRT plugin registry" << *n);
+
+              auto new_layer =
+                  ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&input), 1, *instance_norm_plugin);

+              new_layer->setName(util::node_info(n).c_str());
+              auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
+              return true;
+            }});
} // namespace
} // namespace impl
} // namespace converters