Skip to content

Commit ccee7f5

Browse files
committed
feat: [collection] support output type of list and tuple
Signed-off-by: inocsin <[email protected]>
1 parent dbbf5cc commit ccee7f5

File tree

3 files changed

+202
-15
lines changed

3 files changed

+202
-15
lines changed

core/conversion/conversion.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,7 @@ std::string ConvertBlockToEngine(
481481
std::unordered_map<c10::OperatorName, std::string> GetUnsupportedOpsInBlock(const torch::jit::Block* b) {
482482
std::unordered_map<c10::OperatorName, std::string> unsupported_ops;
483483
for (const auto n : b->nodes()) {
484-
if (n->kind() != torch::jit::prim::Loop && n->kind() != torch::jit::prim::If && !OpSupported(n)) {
484+
if (n->kind() != torch::jit::prim::Loop && n->kind() != torch::jit::prim::If && !OpSupported(n) && n->kind() != torch::jit::prim::TupleConstruct) {
485485
auto schema = n->maybeSchema();
486486
TORCHTRT_CHECK(
487487
schema,

tests/cpp/test_collection.cpp

Lines changed: 167 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,52 @@
66
#include "torch_tensorrt/torch_tensorrt.h"
77

88

9+
TEST(CppAPITests, TestCollectionNormalInput) {
10+
11+
std::string path =
12+
"/root/Torch-TensorRT/normal_model.ts";
13+
torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);
14+
std::vector<at::Tensor> inputs;
15+
inputs.push_back(in0);
16+
inputs.push_back(in0);
17+
18+
torch::jit::Module mod;
19+
try {
20+
// Deserialize the ScriptModule from a file using torch::jit::load().
21+
mod = torch::jit::load(path);
22+
} catch (const c10::Error& e) {
23+
std::cerr << "error loading the model\n";
24+
}
25+
mod.eval();
26+
mod.to(torch::kCUDA);
27+
28+
29+
std::vector<torch::jit::IValue> inputs_;
30+
31+
for (auto in : inputs) {
32+
inputs_.push_back(torch::jit::IValue(in.clone()));
33+
}
34+
35+
auto out = mod.forward(inputs_);
36+
LOG_DEBUG("Finish torchscript forward");
37+
38+
std::vector<torch_tensorrt::Input> input_range;
39+
input_range.push_back({in0.sizes(), torch::kF16});
40+
input_range.push_back({in0.sizes(), torch::kF16});
41+
torch_tensorrt::ts::CompileSpec compile_settings(input_range);
42+
compile_settings.require_full_compilation = true;
43+
compile_settings.min_block_size = 1;
44+
45+
// // FP16 execution
46+
compile_settings.enabled_precisions = {torch::kHalf};
47+
// // Compile module
48+
auto trt_mod = torch_tensorrt::torchscript::compile(mod, compile_settings);
49+
LOG_DEBUG("Finish compile");
50+
auto trt_out = trt_mod.forward(inputs_);
51+
52+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(out.toTensor(), trt_out.toTensor(), 1e-5));
53+
}
54+
955
TEST(CppAPITests, TestCollectionTupleInput) {
1056

1157
std::string path =
@@ -81,14 +127,13 @@ TEST(CppAPITests, TestCollectionTupleInput) {
81127
}
82128

83129

84-
TEST(CppAPITests, TestCollectionNormalInput) {
130+
TEST(CppAPITests, TestCollectionListInput) {
85131

86132
std::string path =
87-
"/root/Torch-TensorRT/normal_model.ts";
133+
"/root/Torch-TensorRT/list_input.ts";
88134
torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);
89135
std::vector<at::Tensor> inputs;
90136
inputs.push_back(in0);
91-
inputs.push_back(in0);
92137

93138
torch::jit::Module mod;
94139
try {
@@ -107,32 +152,136 @@ TEST(CppAPITests, TestCollectionNormalInput) {
107152
inputs_.push_back(torch::jit::IValue(in.clone()));
108153
}
109154

110-
auto out = mod.forward(inputs_);
155+
std::vector<torch::jit::IValue> complex_inputs;
156+
auto input_list = c10::impl::GenericList(c10::TensorType::get());
157+
input_list.push_back(inputs_[0]);
158+
input_list.push_back(inputs_[0]);
159+
160+
torch::jit::IValue input_list_ivalue = torch::jit::IValue(input_list);
161+
162+
complex_inputs.push_back(input_list_ivalue);
163+
164+
165+
auto out = mod.forward(complex_inputs);
111166
LOG_DEBUG("Finish torchscript forward");
112167

113-
std::vector<torch_tensorrt::Input> input_range;
114-
input_range.push_back({in0.sizes(), torch::kF16});
115-
input_range.push_back({in0.sizes(), torch::kF16});
116-
torch_tensorrt::ts::CompileSpec compile_settings(input_range);
117-
compile_settings.require_full_compilation = true;
168+
169+
// auto input_shape = torch_tensorrt::Input(in0.sizes(), torch_tensorrt::DataType::kUnknown);
170+
auto input_shape = torch_tensorrt::Input(in0.sizes(), torch_tensorrt::DataType::kHalf);
171+
172+
auto input_shape_ivalue = torch::jit::IValue(std::move(c10::make_intrusive<torch_tensorrt::Input>(input_shape)));
173+
174+
175+
c10::TypePtr elementType = input_shape_ivalue.type();
176+
auto list = c10::impl::GenericList(elementType);
177+
list.push_back(input_shape_ivalue);
178+
list.push_back(input_shape_ivalue);
179+
180+
181+
torch::jit::IValue complex_input_shape(list);
182+
std::tuple<torch::jit::IValue> input_tuple2(complex_input_shape);
183+
torch::jit::IValue complex_input_shape2(input_tuple2);
184+
185+
auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2);
186+
compile_settings.require_full_compilation = false;
118187
compile_settings.min_block_size = 1;
188+
compile_settings.torch_executed_ops.push_back("aten::__getitem__");
119189

120190
// // FP16 execution
121191
compile_settings.enabled_precisions = {torch::kHalf};
122192
// // Compile module
123193
auto trt_mod = torch_tensorrt::torchscript::compile(mod, compile_settings);
124194
LOG_DEBUG("Finish compile");
125-
auto trt_out = trt_mod.forward(inputs_);
195+
auto trt_out = trt_mod.forward(complex_inputs);
196+
// auto trt_out = trt_mod.forward(complex_inputs_list);
126197

198+
// std::cout << out.toTensor() << std::endl;
127199
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(out.toTensor(), trt_out.toTensor(), 1e-5));
128200
}
129201

130202

203+
TEST(CppAPITests, TestCollectionTupleInputOutput) {
131204

132-
TEST(CppAPITests, TestCollectionListInput) {
205+
std::string path =
206+
"/root/Torch-TensorRT/tuple_input_output.ts";
207+
// torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kFloat);
208+
torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);
209+
// std::vector<at::Tensor> inputs;
210+
// inputs.push_back(in0);
211+
212+
torch::jit::Module mod;
213+
try {
214+
// Deserialize the ScriptModule from a file using torch::jit::load().
215+
mod = torch::jit::load(path);
216+
} catch (const c10::Error& e) {
217+
std::cerr << "error loading the model\n";
218+
}
219+
mod.eval();
220+
mod.to(torch::kCUDA);
221+
222+
223+
// std::vector<torch::jit::IValue> inputs_;
224+
225+
// for (auto in : inputs) {
226+
// inputs_.push_back(torch::jit::IValue(in.clone()));
227+
// }
228+
229+
230+
std::vector<torch::jit::IValue> complex_inputs, complex_inputs_list;
231+
// std::vector<torch::jit::IValue> tuple;
232+
std::tuple<torch::jit::IValue, torch::jit::IValue> input_tuple(in0, in0);
233+
// auto input_list = c10::impl::GenericList(c10::TensorType::get());
234+
// input_list.push_back(inputs_[0]);
235+
// input_list.push_back(inputs_[0]);
236+
237+
// torch::jit::IValue input_list_ivalue = torch::jit::IValue(input_list);
238+
239+
complex_inputs.push_back(input_tuple);
240+
241+
auto out = mod.forward(complex_inputs);
242+
LOG_DEBUG("Finish torchscript forward");
243+
244+
// auto input_shape = torch_tensorrt::Input(in0.sizes(), torch_tensorrt::DataType::kUnknown);
245+
auto input_shape = torch_tensorrt::Input(in0.sizes(), torch_tensorrt::DataType::kHalf);
246+
247+
auto input_shape_ivalue = torch::jit::IValue(std::move(c10::make_intrusive<torch_tensorrt::Input>(input_shape)));
248+
249+
250+
// c10::TypePtr elementType = input_shape_ivalue.type();
251+
// auto list = c10::impl::GenericList(elementType);
252+
// list.push_back(input_shape_ivalue);
253+
// list.push_back(input_shape_ivalue);
254+
255+
std::tuple<torch::jit::IValue, torch::jit::IValue> input_shape_tuple(input_shape_ivalue, input_shape_ivalue);
256+
257+
torch::jit::IValue complex_input_shape(input_shape_tuple);
258+
std::tuple<torch::jit::IValue> input_tuple2(complex_input_shape);
259+
torch::jit::IValue complex_input_shape2(input_tuple2);
260+
// torch::jit::IValue complex_input_shape(list);
261+
262+
auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2);
263+
compile_settings.require_full_compilation = false;
264+
compile_settings.min_block_size = 1;
265+
266+
// compile_settings.torch_executed_ops.push_back("prim::TupleConstruct");
267+
268+
// // FP16 execution
269+
compile_settings.enabled_precisions = {torch::kHalf};
270+
// // Compile module
271+
auto trt_mod = torch_tensorrt::torchscript::compile(mod, compile_settings);
272+
LOG_DEBUG("Finish compile");
273+
auto trt_out = trt_mod.forward(complex_inputs);
274+
// std::cout << out.toTensor() << std::endl;
275+
276+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(out.toTuple()->elements()[0].toTensor(), trt_out.toTuple()->elements()[0].toTensor(), 1e-5));
277+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(out.toTuple()->elements()[1].toTensor(), trt_out.toTuple()->elements()[1].toTensor(), 1e-5));
278+
}
279+
280+
281+
TEST(CppAPITests, TestCollectionListInputOutput) {
133282

134283
std::string path =
135-
"/root/Torch-TensorRT/list_input.ts";
284+
"/root/Torch-TensorRT/list_input_output.ts";
136285
torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);
137286
std::vector<at::Tensor> inputs;
138287
inputs.push_back(in0);
@@ -187,7 +336,10 @@ TEST(CppAPITests, TestCollectionListInput) {
187336
auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2);
188337
compile_settings.require_full_compilation = false;
189338
compile_settings.min_block_size = 1;
339+
340+
// Need to skip the conversion of __getitem__ and ListConstruct
190341
compile_settings.torch_executed_ops.push_back("aten::__getitem__");
342+
compile_settings.torch_executed_ops.push_back("prim::ListConstruct");
191343

192344
// // FP16 execution
193345
compile_settings.enabled_precisions = {torch::kHalf};
@@ -198,5 +350,7 @@ TEST(CppAPITests, TestCollectionListInput) {
198350
// auto trt_out = trt_mod.forward(complex_inputs_list);
199351

200352
// std::cout << out.toTensor() << std::endl;
201-
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(out.toTensor(), trt_out.toTensor(), 1e-5));
353+
354+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(out.toList().vec()[0].toTensor(), trt_out.toList().vec()[0].toTensor(), 1e-5));
355+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(out.toList().vec()[1].toTensor(), trt_out.toList().vec()[1].toTensor(), 1e-5));
202356
}

tests/py/test_collection.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,25 @@ def forward(self, z: List[torch.Tensor]):
5959
r = z[0] + z[1]
6060
return r
6161

62+
class TupleInputOutput(nn.Module):
63+
def __init__(self):
64+
super(TupleInputOutput, self).__init__()
65+
66+
def forward(self, z: Tuple[torch.Tensor, torch.Tensor]):
67+
r1 = z[0] + z[1]
68+
r2 = z[0] - z[1]
69+
r = (r1, r2)
70+
return r
71+
72+
class ListInputOutput(nn.Module):
73+
def __init__(self):
74+
super(ListInputOutput, self).__init__()
75+
76+
def forward(self, z: List[torch.Tensor]):
77+
r1 = z[0] + z[1]
78+
r2 = z[0] - z[1]
79+
r = [r1, r2]
80+
return r
6281

6382
input_data = torch.randn((16, 3, 32, 32))
6483
input_data = input_data.float().to("cuda")
@@ -82,4 +101,18 @@ def forward(self, z: List[torch.Tensor]):
82101
print(list_input_ts.graph)
83102
result = list_input_ts([input_data, input_data])
84103
list_input_ts.to("cuda").eval()
85-
torch.jit.save(list_input_ts, "./list_input.ts")
104+
torch.jit.save(list_input_ts, "./list_input.ts")
105+
106+
tuple_input = TupleInputOutput()
107+
tuple_input_ts = torch.jit.script(tuple_input)
108+
print(tuple_input_ts.graph)
109+
result = tuple_input_ts((input_data, input_data))
110+
tuple_input_ts.to("cuda").eval()
111+
torch.jit.save(tuple_input_ts, "./tuple_input_output.ts")
112+
113+
list_input = ListInputOutput()
114+
list_input_ts = torch.jit.script(list_input)
115+
print(list_input_ts.graph)
116+
result = list_input_ts([input_data, input_data])
117+
list_input_ts.to("cuda").eval()
118+
torch.jit.save(list_input_ts, "./list_input_output.ts")

0 commit comments

Comments
 (0)