6
6
#include " torch_tensorrt/torch_tensorrt.h"
7
7
8
8
9
+ TEST (CppAPITests, TestCollectionNormalInput) {
10
+
11
+ std::string path =
12
+ " /root/Torch-TensorRT/normal_model.ts" ;
13
+ torch::Tensor in0 = torch::randn ({1 , 3 , 512 , 512 }, torch::kCUDA ).to (torch::kHalf );
14
+ std::vector<at::Tensor> inputs;
15
+ inputs.push_back (in0);
16
+ inputs.push_back (in0);
17
+
18
+ torch::jit::Module mod;
19
+ try {
20
+ // Deserialize the ScriptModule from a file using torch::jit::load().
21
+ mod = torch::jit::load (path);
22
+ } catch (const c10::Error& e) {
23
+ std::cerr << " error loading the model\n " ;
24
+ }
25
+ mod.eval ();
26
+ mod.to (torch::kCUDA );
27
+
28
+
29
+ std::vector<torch::jit::IValue> inputs_;
30
+
31
+ for (auto in : inputs) {
32
+ inputs_.push_back (torch::jit::IValue (in.clone ()));
33
+ }
34
+
35
+ auto out = mod.forward (inputs_);
36
+ LOG_DEBUG (" Finish torchscirpt forward" );
37
+
38
+ std::vector<torch_tensorrt::Input> input_range;
39
+ input_range.push_back ({in0.sizes (), torch::kF16 });
40
+ input_range.push_back ({in0.sizes (), torch::kF16 });
41
+ torch_tensorrt::ts::CompileSpec compile_settings (input_range);
42
+ compile_settings.require_full_compilation = true ;
43
+ compile_settings.min_block_size = 1 ;
44
+
45
+ // // FP16 execution
46
+ compile_settings.enabled_precisions = {torch::kHalf };
47
+ // // Compile module
48
+ auto trt_mod = torch_tensorrt::torchscript::compile (mod, compile_settings);
49
+ LOG_DEBUG (" Finish compile" );
50
+ auto trt_out = trt_mod.forward (inputs_);
51
+
52
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (out.toTensor (), trt_out.toTensor (), 1e-5 ));
53
+ }
54
+
9
55
TEST (CppAPITests, TestCollectionTupleInput) {
10
56
11
57
std::string path =
@@ -81,14 +127,13 @@ TEST(CppAPITests, TestCollectionTupleInput) {
81
127
}
82
128
83
129
84
- TEST (CppAPITests, TestCollectionNormalInput ) {
130
+ TEST (CppAPITests, TestCollectionListInput ) {
85
131
86
132
std::string path =
87
- " /root/Torch-TensorRT/normal_model .ts" ;
133
+ " /root/Torch-TensorRT/list_input .ts" ;
88
134
torch::Tensor in0 = torch::randn ({1 , 3 , 512 , 512 }, torch::kCUDA ).to (torch::kHalf );
89
135
std::vector<at::Tensor> inputs;
90
136
inputs.push_back (in0);
91
- inputs.push_back (in0);
92
137
93
138
torch::jit::Module mod;
94
139
try {
@@ -107,32 +152,136 @@ TEST(CppAPITests, TestCollectionNormalInput) {
107
152
inputs_.push_back (torch::jit::IValue (in.clone ()));
108
153
}
109
154
110
- auto out = mod.forward (inputs_);
155
+ std::vector<torch::jit::IValue> complex_inputs;
156
+ auto input_list = c10::impl::GenericList (c10::TensorType::get ());
157
+ input_list.push_back (inputs_[0 ]);
158
+ input_list.push_back (inputs_[0 ]);
159
+
160
+ torch::jit::IValue input_list_ivalue = torch::jit::IValue (input_list);
161
+
162
+ complex_inputs.push_back (input_list_ivalue);
163
+
164
+
165
+ auto out = mod.forward (complex_inputs);
111
166
LOG_DEBUG (" Finish torchscirpt forward" );
112
167
113
- std::vector<torch_tensorrt::Input> input_range;
114
- input_range.push_back ({in0.sizes (), torch::kF16 });
115
- input_range.push_back ({in0.sizes (), torch::kF16 });
116
- torch_tensorrt::ts::CompileSpec compile_settings (input_range);
117
- compile_settings.require_full_compilation = true ;
168
+
169
+ // auto input_shape = torch_tensorrt::Input(in0.sizes(), torch_tensorrt::DataType::kUnknown);
170
+ auto input_shape = torch_tensorrt::Input (in0.sizes (), torch_tensorrt::DataType::kHalf );
171
+
172
+ auto input_shape_ivalue = torch::jit::IValue (std::move (c10::make_intrusive<torch_tensorrt::Input>(input_shape)));
173
+
174
+
175
+ c10::TypePtr elementType = input_shape_ivalue.type ();
176
+ auto list = c10::impl::GenericList (elementType);
177
+ list.push_back (input_shape_ivalue);
178
+ list.push_back (input_shape_ivalue);
179
+
180
+
181
+ torch::jit::IValue complex_input_shape (list);
182
+ std::tuple<torch::jit::IValue> input_tuple2 (complex_input_shape);
183
+ torch::jit::IValue complex_input_shape2 (input_tuple2);
184
+
185
+ auto compile_settings = torch_tensorrt::ts::CompileSpec (complex_input_shape2);
186
+ compile_settings.require_full_compilation = false ;
118
187
compile_settings.min_block_size = 1 ;
188
+ compile_settings.torch_executed_ops .push_back (" aten::__getitem__" );
119
189
120
190
// // FP16 execution
121
191
compile_settings.enabled_precisions = {torch::kHalf };
122
192
// // Compile module
123
193
auto trt_mod = torch_tensorrt::torchscript::compile (mod, compile_settings);
124
194
LOG_DEBUG (" Finish compile" );
125
- auto trt_out = trt_mod.forward (inputs_);
195
+ auto trt_out = trt_mod.forward (complex_inputs);
196
+ // auto trt_out = trt_mod.forward(complex_inputs_list);
126
197
198
+ // std::cout << out.toTensor() << std::endl;
127
199
ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (out.toTensor (), trt_out.toTensor (), 1e-5 ));
128
200
}
129
201
130
202
203
+ TEST (CppAPITests, TestCollectionTupleInputOutput) {
131
204
132
- TEST (CppAPITests, TestCollectionListInput) {
205
+ std::string path =
206
+ " /root/Torch-TensorRT/tuple_input_output.ts" ;
207
+ // torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kFloat);
208
+ torch::Tensor in0 = torch::randn ({1 , 3 , 512 , 512 }, torch::kCUDA ).to (torch::kHalf );
209
+ // std::vector<at::Tensor> inputs;
210
+ // inputs.push_back(in0);
211
+
212
+ torch::jit::Module mod;
213
+ try {
214
+ // Deserialize the ScriptModule from a file using torch::jit::load().
215
+ mod = torch::jit::load (path);
216
+ } catch (const c10::Error& e) {
217
+ std::cerr << " error loading the model\n " ;
218
+ }
219
+ mod.eval ();
220
+ mod.to (torch::kCUDA );
221
+
222
+
223
+ // std::vector<torch::jit::IValue> inputs_;
224
+
225
+ // for (auto in : inputs) {
226
+ // inputs_.push_back(torch::jit::IValue(in.clone()));
227
+ // }
228
+
229
+
230
+ std::vector<torch::jit::IValue> complex_inputs, complex_inputs_list;
231
+ // std::vector<torch::jit::IValue> tuple;
232
+ std::tuple<torch::jit::IValue, torch::jit::IValue> input_tuple (in0, in0);
233
+ // auto input_list = c10::impl::GenericList(c10::TensorType::get());
234
+ // input_list.push_back(inputs_[0]);
235
+ // input_list.push_back(inputs_[0]);
236
+
237
+ // torch::jit::IValue input_list_ivalue = torch::jit::IValue(input_list);
238
+
239
+ complex_inputs.push_back (input_tuple);
240
+
241
+ auto out = mod.forward (complex_inputs);
242
+ LOG_DEBUG (" Finish torchscirpt forward" );
243
+
244
+ // auto input_shape = torch_tensorrt::Input(in0.sizes(), torch_tensorrt::DataType::kUnknown);
245
+ auto input_shape = torch_tensorrt::Input (in0.sizes (), torch_tensorrt::DataType::kHalf );
246
+
247
+ auto input_shape_ivalue = torch::jit::IValue (std::move (c10::make_intrusive<torch_tensorrt::Input>(input_shape)));
248
+
249
+
250
+ // c10::TypePtr elementType = input_shape_ivalue.type();
251
+ // auto list = c10::impl::GenericList(elementType);
252
+ // list.push_back(input_shape_ivalue);
253
+ // list.push_back(input_shape_ivalue);
254
+
255
+ std::tuple<torch::jit::IValue, torch::jit::IValue> input_shape_tuple (input_shape_ivalue, input_shape_ivalue);
256
+
257
+ torch::jit::IValue complex_input_shape (input_shape_tuple);
258
+ std::tuple<torch::jit::IValue> input_tuple2 (complex_input_shape);
259
+ torch::jit::IValue complex_input_shape2 (input_tuple2);
260
+ // torch::jit::IValue complex_input_shape(list);
261
+
262
+ auto compile_settings = torch_tensorrt::ts::CompileSpec (complex_input_shape2);
263
+ compile_settings.require_full_compilation = false ;
264
+ compile_settings.min_block_size = 1 ;
265
+
266
+ // compile_settings.torch_executed_ops.push_back("prim::TupleConstruct");
267
+
268
+ // // FP16 execution
269
+ compile_settings.enabled_precisions = {torch::kHalf };
270
+ // // Compile module
271
+ auto trt_mod = torch_tensorrt::torchscript::compile (mod, compile_settings);
272
+ LOG_DEBUG (" Finish compile" );
273
+ auto trt_out = trt_mod.forward (complex_inputs);
274
+ // std::cout << out.toTensor() << std::endl;
275
+
276
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (out.toTuple ()->elements ()[0 ].toTensor (), trt_out.toTuple ()->elements ()[0 ].toTensor (), 1e-5 ));
277
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (out.toTuple ()->elements ()[1 ].toTensor (), trt_out.toTuple ()->elements ()[1 ].toTensor (), 1e-5 ));
278
+ }
279
+
280
+
281
+ TEST (CppAPITests, TestCollectionListInputOutput) {
133
282
134
283
std::string path =
135
- " /root/Torch-TensorRT/list_input .ts" ;
284
+ " /root/Torch-TensorRT/list_input_output .ts" ;
136
285
torch::Tensor in0 = torch::randn ({1 , 3 , 512 , 512 }, torch::kCUDA ).to (torch::kHalf );
137
286
std::vector<at::Tensor> inputs;
138
287
inputs.push_back (in0);
@@ -187,7 +336,10 @@ TEST(CppAPITests, TestCollectionListInput) {
187
336
auto compile_settings = torch_tensorrt::ts::CompileSpec (complex_input_shape2);
188
337
compile_settings.require_full_compilation = false ;
189
338
compile_settings.min_block_size = 1 ;
339
+
340
+ // Need to skip the conversion of __getitem__ and ListConstruct
190
341
compile_settings.torch_executed_ops .push_back (" aten::__getitem__" );
342
+ compile_settings.torch_executed_ops .push_back (" prim::ListConstruct" );
191
343
192
344
// // FP16 execution
193
345
compile_settings.enabled_precisions = {torch::kHalf };
@@ -198,5 +350,7 @@ TEST(CppAPITests, TestCollectionListInput) {
198
350
// auto trt_out = trt_mod.forward(complex_inputs_list);
199
351
200
352
// std::cout << out.toTensor() << std::endl;
201
- ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (out.toTensor (), trt_out.toTensor (), 1e-5 ));
353
+
354
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (out.toList ().vec ()[0 ].toTensor (), trt_out.toList ().vec ()[0 ].toTensor (), 1e-5 ));
355
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (out.toList ().vec ()[1 ].toTensor (), trt_out.toList ().vec ()[1 ].toTensor (), 1e-5 ));
202
356
}
0 commit comments