@@ -33,7 +33,7 @@ TEST(CppAPITests, TestCollectionTupleInput) {


std::vector<torch::jit::IValue> complex_inputs, complex_inputs_list;
- std::vector<torch::jit::IValue> tuple;
+ // std::vector<torch::jit::IValue> tuple;
std::tuple<torch::jit::IValue, torch::jit::IValue> input_tuple(in0, in0);
// auto input_list = c10::impl::GenericList(c10::TensorType::get());
// input_list.push_back(inputs_[0]);
@@ -42,8 +42,8 @@ TEST(CppAPITests, TestCollectionTupleInput) {
// torch::jit::IValue input_list_ivalue = torch::jit::IValue(input_list);

complex_inputs.push_back(input_tuple);
- complex_inputs_list.push_back(in0);
- complex_inputs_list.push_back(in0);
+ // complex_inputs_list.push_back(in0);
+ // complex_inputs_list.push_back(in0);



@@ -56,10 +56,10 @@ TEST(CppAPITests, TestCollectionTupleInput) {
auto input_shape_ivalue = torch::jit::IValue(std::move(c10::make_intrusive<torch_tensorrt::Input>(input_shape)));


- c10::TypePtr elementType = input_shape_ivalue.type();
- auto list = c10::impl::GenericList(elementType);
- list.push_back(input_shape_ivalue);
- list.push_back(input_shape_ivalue);
+ // c10::TypePtr elementType = input_shape_ivalue.type();
+ // auto list = c10::impl::GenericList(elementType);
+ // list.push_back(input_shape_ivalue);
+ // list.push_back(input_shape_ivalue);


std::tuple<torch::jit::IValue, torch::jit::IValue> input_shape_tuple(input_shape_ivalue, input_shape_ivalue);

@@ -73,10 +73,6 @@ TEST(CppAPITests, TestCollectionTupleInput) {
compile_settings.require_full_compilation = false;
compile_settings.min_block_size = 1;

- // compile_settings.torch_executed_modules.push_back("model1");
- // compile_settings.torch_executed_ops.push_back("aten::sub");
-
-
// // FP16 execution
// compile_settings.enabled_precisions = {torch::kHalf};
// // Compile module
@@ -133,5 +129,78 @@ TEST(CppAPITests, TestCollectionNormalInput) {
LOG_DEBUG("Finish compile");
auto trt_out = trt_mod.forward(inputs_);

+ ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(out.toTensor(), trt_out.toTensor(), 1e-5));
+ }
+
+
+
+ TEST(CppAPITests, TestCollectionListInput) {
+
+ std::string path =
+     "/root/Torch-TensorRT/list_input.ts";
+ torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kFloat);
+ std::vector<at::Tensor> inputs;
+ inputs.push_back(in0);
+
+ torch::jit::Module mod;
+ try {
+   // Deserialize the ScriptModule from a file using torch::jit::load().
+   mod = torch::jit::load(path);
+ } catch (const c10::Error& e) {
+   std::cerr << "error loading the model\n";
+ }
+ mod.eval();
+ mod.to(torch::kCUDA);
+
+
+ std::vector<torch::jit::IValue> inputs_;
+
+ for (auto in : inputs) {
+   inputs_.push_back(torch::jit::IValue(in.clone()));
+ }
+
+ std::vector<torch::jit::IValue> complex_inputs;
+ auto input_list = c10::impl::GenericList(c10::TensorType::get());
+ input_list.push_back(inputs_[0]);
+ input_list.push_back(inputs_[0]);
+
+ torch::jit::IValue input_list_ivalue = torch::jit::IValue(input_list);
+
+ complex_inputs.push_back(input_list_ivalue);
+
+
+ auto out = mod.forward(complex_inputs);
+ LOG_DEBUG("Finish torchscript forward");
+
+
+ auto input_shape = torch_tensorrt::Input(in0.sizes(), torch_tensorrt::DataType::kUnknown);
+
+ auto input_shape_ivalue = torch::jit::IValue(std::move(c10::make_intrusive<torch_tensorrt::Input>(input_shape)));
+
+
+ c10::TypePtr elementType = input_shape_ivalue.type();
+ auto list = c10::impl::GenericList(elementType);
+ list.push_back(input_shape_ivalue);
+ list.push_back(input_shape_ivalue);
+
+
+ torch::jit::IValue complex_input_shape(list);
+ std::tuple<torch::jit::IValue> input_tuple2(complex_input_shape);
+ torch::jit::IValue complex_input_shape2(input_tuple2);
+
+ auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2);
+ compile_settings.require_full_compilation = false;
+ compile_settings.min_block_size = 1;
+ compile_settings.torch_executed_ops.push_back("aten::__getitem__");
+
+ // // FP16 execution
+ // compile_settings.enabled_precisions = {torch::kHalf};
+ // // Compile module
+ auto trt_mod = torch_tensorrt::torchscript::compile(mod, compile_settings);
+ LOG_DEBUG("Finish compile");
+ auto trt_out = trt_mod.forward(complex_inputs);
+ // auto trt_out = trt_mod.forward(complex_inputs_list);
+
+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(out.toTensor(), trt_out.toTensor(), 1e-5));
}
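
The new TestCollectionListInput test above exercises the collections path of the C++ API: the runtime input is a single List[Tensor] argument, the compile spec mirrors that structure with a GenericList of torch_tensorrt::Input wrapped in a one-element tuple, and aten::__getitem__ is forced to run in Torch because full compilation is disabled. Below is a minimal standalone sketch of that flow, distilled from the test as shown in the diff; the module path "module_with_list_input.ts" and the main() wrapper are assumptions for illustration, not part of the patch.

#include <tuple>
#include <vector>

#include "torch/script.h"
#include "torch_tensorrt/torch_tensorrt.h"

int main() {
  // Hypothetical TorchScript module whose forward() takes a List[Tensor].
  auto mod = torch::jit::load("module_with_list_input.ts");
  mod.eval();
  mod.to(torch::kCUDA);

  // Runtime input: one IValue holding a list of two CUDA tensors.
  torch::Tensor t = torch::randn({1, 3, 512, 512}, torch::kCUDA);
  auto tensor_list = c10::impl::GenericList(c10::TensorType::get());
  tensor_list.push_back(torch::jit::IValue(t));
  tensor_list.push_back(torch::jit::IValue(t));
  std::vector<torch::jit::IValue> run_inputs{torch::jit::IValue(tensor_list)};

  // Compile-time spec mirroring that structure: a list of torch_tensorrt::Input
  // wrapped in a one-element tuple that stands in for the forward() signature.
  auto shape = torch::jit::IValue(c10::make_intrusive<torch_tensorrt::Input>(
      torch_tensorrt::Input(t.sizes(), torch_tensorrt::DataType::kUnknown)));
  auto shape_list = c10::impl::GenericList(shape.type());
  shape_list.push_back(shape);
  shape_list.push_back(shape);
  std::tuple<torch::jit::IValue> signature(torch::jit::IValue(shape_list));

  auto spec = torch_tensorrt::ts::CompileSpec(torch::jit::IValue(signature));
  spec.require_full_compilation = false;                   // collection inputs need partial compilation
  spec.min_block_size = 1;
  spec.torch_executed_ops.push_back("aten::__getitem__");  // keep list indexing in Torch

  auto trt_mod = torch_tensorrt::torchscript::compile(mod, spec);
  auto trt_out = trt_mod.forward(run_inputs);
  return 0;
}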