Skip to content

Commit fdf81c2

Browse files
committed
feat: [collection] support list input type
Signed-off-by: inocsin <[email protected]>
1 parent ced2b9c commit fdf81c2

File tree

3 files changed

+98
-15
lines changed

3 files changed

+98
-15
lines changed

core/partitioning/shape_analysis.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,16 +62,23 @@ std::unordered_map<const torch::jit::Value*, torch::jit::IValue> generateRandomI
6262
std::vector<torch::jit::IValue> list;
6363
c10::TypePtr elementType = c10::TensorType::get();
6464
auto generic_list = c10::impl::GenericList(elementType);
65+
LOG_DEBUG("generateRandomInputs, 0");
6566
for (int i = 0; i < input.second.size(); i++) {
66-
auto in = generateSingleInput(input.second[i], types[input.first][i]);
67+
// types for list is {}
68+
// auto in = generateSingleInput(input.second[i], types[input.first][i]);
69+
// TODO: need to decide the input type of list elements in ir.cpp
70+
c10::optional<at::ScalarType> type_opt = {};
71+
auto in = generateSingleInput(input.second[i], type_opt);
6772
// list.push_back(in.clone());
6873
generic_list.push_back(in.clone());
74+
LOG_DEBUG("generateRandomInputs, 1");
6975
}
7076
// c10::TypePtr elementType = list[0].type();
71-
77+
LOG_DEBUG("generateRandomInputs, 2");
7278
// generic_list.append(list);
73-
ivalue_map[input.first] = generic_list;
79+
ivalue_map[input.first] = c10::IValue(generic_list);
7480
// jit_inputs_ivalues.push_back(list);
81+
LOG_DEBUG("generateRandomInputs, finish generate random input of list type");
7582
} else if (input.first->type()->kind() == torch::jit::TypeKind::TupleType) {
7683
// create tuple
7784
// auto tuple = torch::jit::Tuple::create(ivalues_maps[input]);

tests/cpp/test_collection.cpp

Lines changed: 80 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ TEST(CppAPITests, TestCollectionTupleInput) {
3333

3434

3535
std::vector<torch::jit::IValue> complex_inputs, complex_inputs_list;
36-
std::vector<torch::jit::IValue> tuple;
36+
// std::vector<torch::jit::IValue> tuple;
3737
std::tuple<torch::jit::IValue, torch::jit::IValue> input_tuple(in0, in0);
3838
// auto input_list = c10::impl::GenericList(c10::TensorType::get());
3939
// input_list.push_back(inputs_[0]);
@@ -42,8 +42,8 @@ TEST(CppAPITests, TestCollectionTupleInput) {
4242
// torch::jit::IValue input_list_ivalue = torch::jit::IValue(input_list);
4343

4444
complex_inputs.push_back(input_tuple);
45-
complex_inputs_list.push_back(in0);
46-
complex_inputs_list.push_back(in0);
45+
// complex_inputs_list.push_back(in0);
46+
// complex_inputs_list.push_back(in0);
4747

4848

4949

@@ -56,10 +56,10 @@ TEST(CppAPITests, TestCollectionTupleInput) {
5656
auto input_shape_ivalue = torch::jit::IValue(std::move(c10::make_intrusive<torch_tensorrt::Input>(input_shape)));
5757

5858

59-
c10::TypePtr elementType = input_shape_ivalue.type();
60-
auto list = c10::impl::GenericList(elementType);
61-
list.push_back(input_shape_ivalue);
62-
list.push_back(input_shape_ivalue);
59+
// c10::TypePtr elementType = input_shape_ivalue.type();
60+
// auto list = c10::impl::GenericList(elementType);
61+
// list.push_back(input_shape_ivalue);
62+
// list.push_back(input_shape_ivalue);
6363

6464
std::tuple<torch::jit::IValue, torch::jit::IValue> input_shape_tuple(input_shape_ivalue, input_shape_ivalue);
6565

@@ -73,10 +73,6 @@ TEST(CppAPITests, TestCollectionTupleInput) {
7373
compile_settings.require_full_compilation = false;
7474
compile_settings.min_block_size = 1;
7575

76-
// compile_settings.torch_executed_modules.push_back("model1");
77-
// compile_settings.torch_executed_ops.push_back("aten::sub");
78-
79-
8076
// // FP16 execution
8177
// compile_settings.enabled_precisions = {torch::kHalf};
8278
// // Compile module
@@ -133,5 +129,78 @@ TEST(CppAPITests, TestCollectionNormalInput) {
133129
LOG_DEBUG("Finish compile");
134130
auto trt_out = trt_mod.forward(inputs_);
135131

132+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(out.toTensor(), trt_out.toTensor(), 1e-5));
133+
}
134+
135+
136+
137+
// Verifies that a TorchScript module taking a List[Tensor] input can be
// compiled by Torch-TensorRT and produces the same output as eager
// TorchScript execution.
//
// The input spec mirrors the runtime input: a list of two Input shapes,
// wrapped in a single-element tuple, handed to ts::CompileSpec.
TEST(CppAPITests, TestCollectionListInput) {
  std::string path = "/root/Torch-TensorRT/list_input.ts";
  torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kFloat);
  std::vector<at::Tensor> inputs;
  inputs.push_back(in0);

  torch::jit::Module mod;
  try {
    // Deserialize the ScriptModule from a file using torch::jit::load().
    mod = torch::jit::load(path);
  } catch (const c10::Error& e) {
    std::cerr << "error loading the model\n";
    // Abort the test here: continuing with an uninitialized module would
    // crash in mod.eval()/forward with a misleading error.
    FAIL() << "could not load TorchScript module from " << path;
  }
  mod.eval();
  mod.to(torch::kCUDA);

  std::vector<torch::jit::IValue> inputs_;
  // const ref: avoid copying each tensor handle; clone() still gives the
  // module its own tensor.
  for (const auto& in : inputs) {
    inputs_.push_back(torch::jit::IValue(in.clone()));
  }

  // Build the List[Tensor] runtime input: the same tensor twice.
  std::vector<torch::jit::IValue> complex_inputs;
  auto input_list = c10::impl::GenericList(c10::TensorType::get());
  input_list.push_back(inputs_[0]);
  input_list.push_back(inputs_[0]);

  torch::jit::IValue input_list_ivalue = torch::jit::IValue(input_list);

  complex_inputs.push_back(input_list_ivalue);

  // Reference output from eager TorchScript execution.
  auto out = mod.forward(complex_inputs);
  LOG_DEBUG("Finish torchscript forward");

  auto input_shape = torch_tensorrt::Input(in0.sizes(), torch_tensorrt::DataType::kUnknown);

  // make_intrusive already yields an rvalue; no std::move needed.
  auto input_shape_ivalue = torch::jit::IValue(c10::make_intrusive<torch_tensorrt::Input>(input_shape));

  // Input spec list matching the List[Tensor] argument (two elements, same shape).
  c10::TypePtr elementType = input_shape_ivalue.type();
  auto list = c10::impl::GenericList(elementType);
  list.push_back(input_shape_ivalue);
  list.push_back(input_shape_ivalue);

  // CompileSpec expects the full signature as a tuple of per-argument specs.
  torch::jit::IValue complex_input_shape(list);
  std::tuple<torch::jit::IValue> input_tuple2(complex_input_shape);
  torch::jit::IValue complex_input_shape2(input_tuple2);

  auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2);
  compile_settings.require_full_compilation = false;
  compile_settings.min_block_size = 1;
  // List indexing is not converted; keep it in Torch.
  compile_settings.torch_executed_ops.push_back("aten::__getitem__");

  // // FP16 execution
  // compile_settings.enabled_precisions = {torch::kHalf};
  // // Compile module
  auto trt_mod = torch_tensorrt::torchscript::compile(mod, compile_settings);
  LOG_DEBUG("Finish compile");
  auto trt_out = trt_mod.forward(complex_inputs);

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(out.toTensor(), trt_out.toTensor(), 1e-5));
}

tests/py/test_collection.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,4 +75,11 @@ def forward(self, z: List[torch.Tensor]):
7575
print(tuple_input_ts.graph)
7676
result = tuple_input_ts((input_data, input_data))
7777
tuple_input_ts.to("cuda").eval()
78-
torch.jit.save(tuple_input_ts, "./tuple_input.ts")
78+
torch.jit.save(tuple_input_ts, "./tuple_input.ts")
79+
80+
list_input = ListInput()
81+
list_input_ts = torch.jit.script(list_input)
82+
print(list_input_ts.graph)
83+
result = list_input_ts([input_data, input_data])
84+
list_input_ts.to("cuda").eval()
85+
torch.jit.save(list_input_ts, "./list_input.ts")

0 commit comments

Comments
 (0)