Commit dbbf5cc

feat: [collection] support user defined input data type

Signed-off-by: inocsin <[email protected]>
1 parent fdf81c2 commit dbbf5cc

File tree

7 files changed: +51 −55 lines changed
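Taken together, these changes let callers pin the dtype of each element of a collection input instead of leaving it as kUnknown and relying on graph inference. A minimal sketch of the tuple-input flow, assembled from the updated tests/cpp/test_collection.cpp below; the model path, the Half-precision settings, and the IValue-taking CompileSpec constructor are assumptions carried over from this branch and may differ elsewhere:

// Sketch: compile a TorchScript module whose forward takes a tuple of two
// tensors, pinning each element's input dtype to Half instead of kUnknown.
// Model path and precision settings mirror tests/cpp/test_collection.cpp.
torch::jit::Module mod = torch::jit::load("/root/Torch-TensorRT/tuple_input.ts");
mod.eval();
mod.to(torch::kCUDA);

torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);

// Per-element input spec with a user defined dtype
auto input_shape = torch_tensorrt::Input(in0.sizes(), torch_tensorrt::DataType::kHalf);
auto input_shape_ivalue = torch::jit::IValue(c10::make_intrusive<torch_tensorrt::Input>(input_shape));

// Nest the specs the same way the model nests its inputs: ((in0, in0),)
std::tuple<torch::jit::IValue, torch::jit::IValue> input_shape_tuple(input_shape_ivalue, input_shape_ivalue);
torch::jit::IValue complex_input_shape(input_shape_tuple);
std::tuple<torch::jit::IValue> input_tuple2(complex_input_shape);
torch::jit::IValue complex_input_shape2(input_tuple2);

// Assumed: CompileSpec constructor taking the nested input-signature IValue
auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2);
compile_settings.min_block_size = 1;
compile_settings.enabled_precisions = {torch::kHalf};
auto trt_mod = torch_tensorrt::torchscript::compile(mod, compile_settings);

When the graph itself cannot establish an element's dtype, core/compiler.cpp now warns and seeds the first-use type map with this user setting instead of failing verification, as the first diff below shows.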

core/compiler.cpp

Lines changed: 14 additions & 2 deletions

@@ -331,10 +331,22 @@ void MapInputsAndDetermineDTypes(
       spec[i].dtype = nvinfer1::DataType::kFLOAT;
     } else if (spec[i].dtype_is_user_defined && cfg.partition_info.enabled) {
       if (!est_type_opt[i]) {
-        LOG_INFO("Cannot infer input tensor dtype in graph, unable to verify user input dtype settings");
+        LOG_INFO("Cannot infer input tensor dtype in graph, compiler is going to use the user setting");
+        // TODO set input data type
+
+        std::stringstream ss;
+        ss << "For input " << in->debugName() << ", found user specified input dtype as ";
+        ss << cfg.convert_info.collection_inputs.find(in)->second[i].dtype;
+        // ss << cfg.convert_info.inputs.find(in)->second.dtype;
+        ss << ". The compiler is going to use the user setting " << cfg.convert_info.collection_inputs.find(in)->second[i].dtype;
+        auto warn_str = ss.str();
+        LOG_WARNING(warn_str);
+        // Overwrite type map with user settings
+        first_use_type_map[in][i] = {util::TRTDataTypeToScalarType(cfg.convert_info.collection_inputs.find(in)->second[i].dtype)};
+
       } else {
         // if (util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype) != est_type_opt.value()) {
-        if (util::TRTDataTypeToScalarType(cfg.convert_info.collection_inputs.find(in)->second[i].dtype) != est_type_opt[i].value()) {
+        if (util::TRTDataTypeToScalarType(cfg.convert_info.collection_inputs.find(in)->second[i].dtype) != est_type_opt[i].value()) {
           std::stringstream ss;
           ss << "For input " << in->debugName() << ", found user specified input dtype as ";
           ss << cfg.convert_info.collection_inputs.find(in)->second[i].dtype;

core/ir/ir.cpp

Lines changed: 10 additions & 4 deletions

@@ -246,19 +246,25 @@ CollectionTypeMap get_block_first_calc_dtypes_opt_collection(torch::jit::Block*
       LOG_DEBUG("get_block_first_calc_dtypes_opt_collection TupleType");
       // TODO: to evaluate the data type of tuple element
       // make sure very time get the same ptr
+      c10::optional<at::ScalarType> tp = get_value_first_calc_dtype_opt(b, i);
       at::ArrayRef<torch::jit::Value*> unpack_tuple = torch::jit::createTupleUnpack(i);
       LOG_DEBUG("get_block_first_calc_dtypes_opt_collection: tuple size " << unpack_tuple.size());
-      std::vector<c10::optional<at::ScalarType>> empty_dytpes(unpack_tuple.size());
-      types.insert({i, empty_dytpes}); // insert an empty
+      // Assume all tuple has the same datatype
+      // std::vector<c10::optional<at::ScalarType>> dytpes(unpack_tuple.size(), tp);
+      std::vector<c10::optional<at::ScalarType>> dytpes(unpack_tuple.size());
+      types.insert({i, dytpes}); // insert an empty
       // for (auto item: unpack_tuple) {
       //   torch::jit::Value* in = item;
       //   types.insert({in, get_value_first_calc_dtype_opt(b, i)});
       // }

     } else if(i->type()->kind() == torch::jit::TypeKind::ListType) {
       // TODO: to decide the size of list and type of list element
-      LOG_DEBUG("get_block_first_calc_dtypes_opt ListType");
-      types.insert({i, {}}); // insert an empty
+      LOG_DEBUG("get_block_first_calc_dtypes_opt ListType: use size " << i->uses().size());
+      c10::optional<at::ScalarType> tp = get_value_first_calc_dtype_opt(b, i);
+      // std::vector<c10::optional<at::ScalarType>> dytpes(i->uses().size());
+      std::vector<c10::optional<at::ScalarType>> dytpes(i->uses().size(), tp);
+      types.insert({i, dytpes}); // insert an empty

     }
   }

core/ir/ir.h

Lines changed: 1 addition & 1 deletion

@@ -52,7 +52,7 @@ struct GraphInputs {
   //   // TODO construct the IValue
   // }
   torch::jit::IValue input_signature; // nested Input, full input spec
-  std::vector<Input> flattened_inputs; // flattend Input, can be removed
+  std::vector<Input> flattened_inputs; // flattend Input
   std::vector<std::vector<Input>> collection_inputs; // only support two layer nesting, e.g. ((a, b), [c, d], e)
 };


core/partitioning/shape_analysis.cpp

Lines changed: 3 additions & 2 deletions

@@ -67,8 +67,9 @@ std::unordered_map<const torch::jit::Value*, torch::jit::IValue> generateRandomI
         // types for list is {}
         // auto in = generateSingleInput(input.second[i], types[input.first][i]);
         // TODO: need to decide the input type of list elements in ir.cpp
-        c10::optional<at::ScalarType> type_opt = {};
-        auto in = generateSingleInput(input.second[i], type_opt);
+        // c10::optional<at::ScalarType> type_opt = {};
+        // auto in = generateSingleInput(input.second[i], type_opt);
+        auto in = generateSingleInput(input.second[i], types[input.first][i]);
         // list.push_back(in.clone());
         generic_list.push_back(in.clone());
         LOG_DEBUG("generateRandomInputs, 1");

cpp/src/compile_spec.cpp

Lines changed: 1 addition & 19 deletions

@@ -90,25 +90,6 @@ void flatten_dfs(std::vector<torchtrt::core::ir::Input>& flattened_inputs, std::
 torch_tensorrt::core::ir::GraphInputs to_internal_graph_inputs(GraphInputs external_graph_input) {
   torch_tensorrt::core::ir::GraphInputs internal_graph_input;

-  // // flattened version
-  // if (external_graph_input.flattened_inputs.size() > 0) {
-  //   // std::vector<torch::jit::IValue> input_shape_list;
-  //   auto empty_ivalue = torch::jit::IValue(c10::make_intrusive<torchtrt::core::ir::Input>(torchtrt::core::ir::Input()));
-  //   c10::TypePtr type = empty_ivalue.type();
-  //   auto input_shape_list = c10::impl::GenericList(type);
-  //   std::vector<torchtrt::core::ir::Input> internal_input = to_vec_internal_inputs(external_graph_input.flattened_inputs);
-  //   for (auto input_shape: internal_input) {
-  //     auto input_shape_ivalue = torch::jit::IValue(std::move(c10::make_intrusive<torchtrt::core::ir::Input>(input_shape)));
-  //     input_shape_list.push_back(input_shape_ivalue);
-  //   }
-
-  //   torch::jit::IValue input_signature(input_shape_list);
-  //   internal_graph_input.flattened_inputs = internal_input;
-  //   internal_graph_input.input_signature = input_signature;
-
-  // }
-  // // nested version
-  // else {
   std::vector<torchtrt::core::ir::Input> flattened_inputs;
   std::vector<std::vector<torchtrt::core::ir::Input>> collection_inputs;

@@ -134,6 +115,7 @@ torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {
     internal.graph_inputs.collection_inputs.resize(internal.inputs.size());
     for (int i = 0; i < internal.inputs.size(); i++) {
       internal.graph_inputs.collection_inputs[i].push_back(internal.inputs[i]);
+      internal.graph_inputs.flattened_inputs = internal.inputs;
     }
   }


cpp/src/torch_tensorrt.cpp

Lines changed: 0 additions & 1 deletion

@@ -30,7 +30,6 @@ torch::jit::script::Module compile(const torch::jit::script::Module& module, Com
   LOG_DEBUG(get_build_info());
   // Want to export a much simpler (non TRT header dependent) API so doing the
   // type conversion here
-  printf("in torch_tensorrt::ts::compile\n");
   return torch_tensorrt::core::CompileGraph(module, to_internal_compile_spec(info));
 }


tests/cpp/test_collection.cpp

Lines changed: 22 additions & 26 deletions

@@ -10,9 +10,10 @@ TEST(CppAPITests, TestCollectionTupleInput) {

   std::string path =
       "/root/Torch-TensorRT/tuple_input.ts";
-  torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kFloat);
-  std::vector<at::Tensor> inputs;
-  inputs.push_back(in0);
+  // torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kFloat);
+  torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);
+  // std::vector<at::Tensor> inputs;
+  // inputs.push_back(in0);

   torch::jit::Module mod;
   try {
@@ -23,13 +24,13 @@ TEST(CppAPITests, TestCollectionTupleInput) {
   }
   mod.eval();
   mod.to(torch::kCUDA);
-

-  std::vector<torch::jit::IValue> inputs_;

-  for (auto in : inputs) {
-    inputs_.push_back(torch::jit::IValue(in.clone()));
-  }
+  // std::vector<torch::jit::IValue> inputs_;
+
+  // for (auto in : inputs) {
+  //   inputs_.push_back(torch::jit::IValue(in.clone()));
+  // }


   std::vector<torch::jit::IValue> complex_inputs, complex_inputs_list;
@@ -42,16 +43,12 @@ TEST(CppAPITests, TestCollectionTupleInput) {
   // torch::jit::IValue input_list_ivalue = torch::jit::IValue(input_list);

   complex_inputs.push_back(input_tuple);
-  // complex_inputs_list.push_back(in0);
-  // complex_inputs_list.push_back(in0);
-
-

   auto out = mod.forward(complex_inputs);
   LOG_DEBUG("Finish torchscirpt forward");

-
-  auto input_shape = torch_tensorrt::Input(in0.sizes(), torch_tensorrt::DataType::kUnknown);
+  // auto input_shape = torch_tensorrt::Input(in0.sizes(), torch_tensorrt::DataType::kUnknown);
+  auto input_shape = torch_tensorrt::Input(in0.sizes(), torch_tensorrt::DataType::kHalf);

   auto input_shape_ivalue = torch::jit::IValue(std::move(c10::make_intrusive<torch_tensorrt::Input>(input_shape)));

@@ -63,7 +60,6 @@ TEST(CppAPITests, TestCollectionTupleInput) {

   std::tuple<torch::jit::IValue, torch::jit::IValue> input_shape_tuple(input_shape_ivalue, input_shape_ivalue);

-
   torch::jit::IValue complex_input_shape(input_shape_tuple);
   std::tuple<torch::jit::IValue> input_tuple2(complex_input_shape);
   torch::jit::IValue complex_input_shape2(input_tuple2);
@@ -74,13 +70,12 @@ TEST(CppAPITests, TestCollectionTupleInput) {
   compile_settings.min_block_size = 1;

   // // FP16 execution
-  // compile_settings.enabled_precisions = {torch::kHalf};
+  compile_settings.enabled_precisions = {torch::kHalf};
   // // Compile module
   auto trt_mod = torch_tensorrt::torchscript::compile(mod, compile_settings);
   LOG_DEBUG("Finish compile");
   auto trt_out = trt_mod.forward(complex_inputs);
-  // auto trt_out = trt_mod.forward(complex_inputs_list);
-
+  // std::cout << out.toTensor() << std::endl;

   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(out.toTensor(), trt_out.toTensor(), 1e-5));
 }
@@ -90,7 +85,7 @@ TEST(CppAPITests, TestCollectionNormalInput) {

   std::string path =
       "/root/Torch-TensorRT/normal_model.ts";
-  torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kFloat);
+  torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);
   std::vector<at::Tensor> inputs;
   inputs.push_back(in0);
   inputs.push_back(in0);
@@ -116,14 +111,14 @@ TEST(CppAPITests, TestCollectionNormalInput) {
   LOG_DEBUG("Finish torchscirpt forward");

   std::vector<torch_tensorrt::Input> input_range;
-  input_range.push_back({in0.sizes(), torch::kF32});
-  input_range.push_back({in0.sizes(), torch::kF32});
+  input_range.push_back({in0.sizes(), torch::kF16});
+  input_range.push_back({in0.sizes(), torch::kF16});
   torch_tensorrt::ts::CompileSpec compile_settings(input_range);
   compile_settings.require_full_compilation = true;
   compile_settings.min_block_size = 1;

   // // FP16 execution
-  // compile_settings.enabled_precisions = {torch::kHalf};
+  compile_settings.enabled_precisions = {torch::kHalf};
   // // Compile module
   auto trt_mod = torch_tensorrt::torchscript::compile(mod, compile_settings);
   LOG_DEBUG("Finish compile");
@@ -138,7 +133,7 @@ TEST(CppAPITests, TestCollectionListInput) {

   std::string path =
       "/root/Torch-TensorRT/list_input.ts";
-  torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kFloat);
+  torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);
   std::vector<at::Tensor> inputs;
   inputs.push_back(in0);

@@ -173,7 +168,8 @@ TEST(CppAPITests, TestCollectionListInput) {
   LOG_DEBUG("Finish torchscirpt forward");


-  auto input_shape = torch_tensorrt::Input(in0.sizes(), torch_tensorrt::DataType::kUnknown);
+  // auto input_shape = torch_tensorrt::Input(in0.sizes(), torch_tensorrt::DataType::kUnknown);
+  auto input_shape = torch_tensorrt::Input(in0.sizes(), torch_tensorrt::DataType::kHalf);

   auto input_shape_ivalue = torch::jit::IValue(std::move(c10::make_intrusive<torch_tensorrt::Input>(input_shape)));

@@ -194,13 +190,13 @@ TEST(CppAPITests, TestCollectionListInput) {
   compile_settings.torch_executed_ops.push_back("aten::__getitem__");

   // // FP16 execution
-  // compile_settings.enabled_precisions = {torch::kHalf};
+  compile_settings.enabled_precisions = {torch::kHalf};
   // // Compile module
   auto trt_mod = torch_tensorrt::torchscript::compile(mod, compile_settings);
   LOG_DEBUG("Finish compile");
   auto trt_out = trt_mod.forward(complex_inputs);
   // auto trt_out = trt_mod.forward(complex_inputs_list);

-
+  // std::cout << out.toTensor() << std::endl;
   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(out.toTensor(), trt_out.toTensor(), 1e-5));
 }
