@@ -56,54 +56,59 @@ int main(int argc, const char* argv[]) {
   }
 
   auto compile_spec = trtorch::CompileSpec(dims);
+  // compile_spec.torch_fallback = trtorch::CompileSpec::TorchFallback(true);
   compile_spec.workspace_size = 1 << 24;
-
-  std::cout << "Checking operator support" << std::endl;
-  if (!trtorch::CheckMethodOperatorSupport(mod, "forward")) {
-    std::cerr << "Method is not currently supported by TRTorch" << std::endl;
-    return -1;
-  }
-
-  std::cout << "Compiling graph to save as TRT engine (/tmp/engine_converted_from_jit.trt)" << std::endl;
+  compile_spec.op_precision = torch::kChar;
+  // compile_spec.input_dtypes = {torch::kInt32, torch::kInt32};
+  // std::cout << "===Compile Spec: " << compile_spec << std::endl;
+  // compile_spec.torch_fallback = trtorch::CompileSpec::TorchFallback(true);
+  // compile_spec.torch_fallback.min_block_size = 1;
+  // std::cout << "Checking operator support" << std::endl;
+  // if (!trtorch::CheckMethodOperatorSupport(mod, "forward")) {
+  //   std::cerr << "Method is not currently supported by TRTorch" << std::endl;
+  //   return -1;
+  // }
+  //
+  // std::cout << "Compiling graph to save as TRT engine (/tmp/engine_converted_from_jit.trt)" << std::endl;
   auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", compile_spec);
   std::ofstream out("/tmp/engine_converted_from_jit.trt");
   out << engine;
   out.close();
 
-  std::vector<torch::jit::IValue> jit_inputs_ivalues;
-  std::vector<torch::jit::IValue> trt_inputs_ivalues;
-  auto in = at::randint(5, dims[0], {at::kCUDA});
-  jit_inputs_ivalues.push_back(in.clone());
-  trt_inputs_ivalues.push_back(in.clone());
-
-  torch::jit::IValue jit_results_ivalues = mod.forward(jit_inputs_ivalues);
-  std::vector<at::Tensor> jit_results;
-  if (jit_results_ivalues.isTensor()) {
-    jit_results.push_back(jit_results_ivalues.toTensor());
-  } else {
-    auto results = jit_results_ivalues.toTuple()->elements();
-    for (auto r : results) {
-      jit_results.push_back(r.toTensor());
-    }
-  }
+  // std::vector<torch::jit::IValue> jit_inputs_ivalues;
+  // std::vector<torch::jit::IValue> trt_inputs_ivalues;
+  // auto in = at::randint(5, dims[0], {at::kCUDA});
+  // jit_inputs_ivalues.push_back(in.clone());
+  // trt_inputs_ivalues.push_back(in.clone());
+  // //
+  // torch::jit::IValue jit_results_ivalues = mod.forward(jit_inputs_ivalues);
+  // std::vector<at::Tensor> jit_results;
+  // if (jit_results_ivalues.isTensor()) {
+  //   jit_results.push_back(jit_results_ivalues.toTensor());
+  // } else {
+  //   auto results = jit_results_ivalues.toTuple()->elements();
+  //   for (auto r : results) {
+  //     jit_results.push_back(r.toTensor());
+  //   }
+  // }
 
   std::cout << "Compiling graph as module" << std::endl;
   auto trt_mod = trtorch::CompileGraph(mod, compile_spec);
-  std::cout << "Running TRT module" << std::endl;
-  torch::jit::IValue trt_results_ivalues = trt_mod.forward(trt_inputs_ivalues);
-  std::vector<at::Tensor> trt_results;
-  if (trt_results_ivalues.isTensor()) {
-    trt_results.push_back(trt_results_ivalues.toTensor());
-  } else {
-    auto results = trt_results_ivalues.toTuple()->elements();
-    for (auto r : results) {
-      trt_results.push_back(r.toTensor());
-    }
-  }
-
-  for (size_t i = 0; i < trt_results.size(); i++) {
-    almostEqual(jit_results[i], trt_results[i].reshape_as(jit_results[i]));
-  }
+  // std::cout << "Running TRT module" << std::endl;
+  // torch::jit::IValue trt_results_ivalues = trt_mod.forward(trt_inputs_ivalues);
+  // std::vector<at::Tensor> trt_results;
+  // if (trt_results_ivalues.isTensor()) {
+  //   trt_results.push_back(trt_results_ivalues.toTensor());
+  // } else {
+  //   auto results = trt_results_ivalues.toTuple()->elements();
+  //   for (auto r : results) {
+  //     trt_results.push_back(r.toTensor());
+  //   }
+  // }
+  //
+  // for (size_t i = 0; i < trt_results.size(); i++) {
+  //   almostEqual(jit_results[i], trt_results[i].reshape_as(jit_results[i]));
+  // }
 
   std::cout << "Converted Engine saved to /tmp/engine_converted_from_jit.trt" << std::endl;
 
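For reference, the verification flow this commit comments out can be condensed as below. This is a minimal sketch assembled only from calls that appear in this diff; it assumes `mod` (the loaded torch::jit::Module), `dims` (the input shape list), and the `almostEqual` comparison helper are defined earlier in the example, and that the module takes a single tensor input and returns a single tensor (the original code also handles tuple outputs):

// Minimal sketch (assumptions noted above): run the same random input
// through the original JIT module and the TRTorch-compiled module,
// then compare the outputs.
auto compile_spec = trtorch::CompileSpec(dims);
compile_spec.workspace_size = 1 << 24;

// Skip compilation early if the graph uses unsupported operators.
if (!trtorch::CheckMethodOperatorSupport(mod, "forward")) {
  std::cerr << "Method is not currently supported by TRTorch" << std::endl;
  return -1;
}

auto in = at::randint(5, dims[0], {at::kCUDA});
auto jit_out = mod.forward({in.clone()}).toTensor();

auto trt_mod = trtorch::CompileGraph(mod, compile_spec);
auto trt_out = trt_mod.forward({in.clone()}).toTensor();

// Reshape before comparing, mirroring the original code's reshape_as call,
// since the two backends may not report identical output shapes.
almostEqual(jit_out, trt_out.reshape_as(jit_out));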