From 05bf80cf65b13320b527ebaf590aca3778991485 Mon Sep 17 00:00:00 2001 From: Bo Wang Date: Fri, 19 Aug 2022 18:15:42 -0700 Subject: [PATCH 1/5] feat: rewriting param to a Constant if it's an introduced input Signed-off-by: Bo Wang --- core/compiler.cpp | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/core/compiler.cpp b/core/compiler.cpp index 7b58dbb2c1..9acdd5ab4c 100644 --- a/core/compiler.cpp +++ b/core/compiler.cpp @@ -11,6 +11,7 @@ #include "torch/csrc/jit/frontend/function_schema_parser.h" #include "torch/csrc/jit/ir/ir.h" +#include "torch/csrc/jit/ir/constants.h" #include "torch/csrc/jit/ir/ir_views.h" #include "torch/csrc/jit/passes/graph_fuser.h" #include "torch/csrc/jit/passes/loop_unrolling.h" @@ -28,6 +29,22 @@ namespace torch_tensorrt { namespace core { +void RewriteInputsWithParams(std::shared_ptr g, std::vector params) { + auto input_size = g->inputs().size(); + auto param_it = params.rbegin(); + for (int i = input_size - 1; i >= 0; --i) { + if (g->inputs()[i]->type() != c10::TensorType::get() && g->inputs()[i]->type()->kind() != torch::jit::TypeKind::TupleType && + g->inputs()[i]->type()->kind() != torch::jit::TypeKind::ListType && param_it != params.rend()) { + auto new_constant = torch::jit::tryInsertConstant(*g, *param_it); + ++param_it; + if (new_constant) { + g->inputs()[i]->replaceAllUsesWith(*new_constant); + g->eraseInput(i); + } + } + } +} + void AddEngineToGraph( torch::jit::script::Module mod, std::shared_ptr& g, @@ -434,6 +451,9 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg) (!(cfg.lower_info.forced_fallback_modules.size() == 0 && cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible) || outputIsCollection)) { + if (!static_params.empty()) { + RewriteInputsWithParams(g, params); + } std::unordered_map fallback_nodes; auto collection_input_ivalues_map = partitioning::generateRandomInputs(cfg.convert_info.collection_input_spec_map, first_use_types); From 
ab977f5bdc666304717b3f3cbd5af037ae41e3ef Mon Sep 17 00:00:00 2001 From: Bo Wang Date: Thu, 8 Sep 2022 18:06:01 -0700 Subject: [PATCH 2/5] fix: deal with edge cases when introduced value is Tensor with gradient Signed-off-by: Bo Wang --- core/compiler.cpp | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/core/compiler.cpp b/core/compiler.cpp index 9acdd5ab4c..6eb9ce4e7e 100644 --- a/core/compiler.cpp +++ b/core/compiler.cpp @@ -10,8 +10,8 @@ #include "ATen/core/jit_type.h" #include "torch/csrc/jit/frontend/function_schema_parser.h" -#include "torch/csrc/jit/ir/ir.h" #include "torch/csrc/jit/ir/constants.h" +#include "torch/csrc/jit/ir/ir.h" #include "torch/csrc/jit/ir/ir_views.h" #include "torch/csrc/jit/passes/graph_fuser.h" #include "torch/csrc/jit/passes/loop_unrolling.h" @@ -33,9 +33,18 @@ void RewriteInputsWithParams(std::shared_ptr g, std::vectorinputs().size(); auto param_it = params.rbegin(); for (int i = input_size - 1; i >= 0; --i) { - if (g->inputs()[i]->type() != c10::TensorType::get() && g->inputs()[i]->type()->kind() != torch::jit::TypeKind::TupleType && + if (g->inputs()[i]->type() != c10::TensorType::get() && + g->inputs()[i]->type()->kind() != torch::jit::TypeKind::TupleType && g->inputs()[i]->type()->kind() != torch::jit::TypeKind::ListType && param_it != params.rend()) { - auto new_constant = torch::jit::tryInsertConstant(*g, *param_it); + auto val = *param_it; + if (val.isTensor()) { + at::Tensor val_tensor = val.toTensor(); + if (val_tensor.requires_grad()) { + val_tensor.set_requires_grad(false); + val = val_tensor; + } + } + auto new_constant = torch::jit::tryInsertConstant(*g, val); ++param_it; if (new_constant) { g->inputs()[i]->replaceAllUsesWith(*new_constant); From f2ef0eb387a6aa7b8ecce2b8175f8dea24366c33 Mon Sep 17 00:00:00 2001 From: Bo Wang Date: Wed, 14 Sep 2022 16:49:38 -0700 Subject: [PATCH 3/5] refactor: refactor RewriteInputsWithParams() to a lowering pass Signed-off-by: Bo Wang --- 
core/compiler.cpp | 29 ------------- core/lowering/lowering.cpp | 5 ++- core/lowering/passes/BUILD | 1 + core/lowering/passes/CMakeLists.txt | 1 + core/lowering/passes/passes.h | 1 + .../passes/rewrite_inputs_with_params.cpp | 41 +++++++++++++++++++ 6 files changed, 47 insertions(+), 31 deletions(-) create mode 100644 core/lowering/passes/rewrite_inputs_with_params.cpp diff --git a/core/compiler.cpp b/core/compiler.cpp index 6eb9ce4e7e..7b58dbb2c1 100644 --- a/core/compiler.cpp +++ b/core/compiler.cpp @@ -10,7 +10,6 @@ #include "ATen/core/jit_type.h" #include "torch/csrc/jit/frontend/function_schema_parser.h" -#include "torch/csrc/jit/ir/constants.h" #include "torch/csrc/jit/ir/ir.h" #include "torch/csrc/jit/ir/ir_views.h" #include "torch/csrc/jit/passes/graph_fuser.h" @@ -29,31 +28,6 @@ namespace torch_tensorrt { namespace core { -void RewriteInputsWithParams(std::shared_ptr g, std::vector params) { - auto input_size = g->inputs().size(); - auto param_it = params.rbegin(); - for (int i = input_size - 1; i >= 0; --i) { - if (g->inputs()[i]->type() != c10::TensorType::get() && - g->inputs()[i]->type()->kind() != torch::jit::TypeKind::TupleType && - g->inputs()[i]->type()->kind() != torch::jit::TypeKind::ListType && param_it != params.rend()) { - auto val = *param_it; - if (val.isTensor()) { - at::Tensor val_tensor = val.toTensor(); - if (val_tensor.requires_grad()) { - val_tensor.set_requires_grad(false); - val = val_tensor; - } - } - auto new_constant = torch::jit::tryInsertConstant(*g, val); - ++param_it; - if (new_constant) { - g->inputs()[i]->replaceAllUsesWith(*new_constant); - g->eraseInput(i); - } - } - } -} - void AddEngineToGraph( torch::jit::script::Module mod, std::shared_ptr& g, @@ -460,9 +434,6 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg) (!(cfg.lower_info.forced_fallback_modules.size() == 0 && cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible) || outputIsCollection)) { - if 
(!static_params.empty()) { - RewriteInputsWithParams(g, params); - } std::unordered_map fallback_nodes; auto collection_input_ivalues_map = partitioning::generateRandomInputs(cfg.convert_info.collection_input_spec_map, first_use_types); diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp index 8bbae296c3..fedd311ac7 100644 --- a/core/lowering/lowering.cpp +++ b/core/lowering/lowering.cpp @@ -25,7 +25,7 @@ void LowerBlock(torch::jit::Block* b) { DropUnusedNodes(b); } -void LowerGraph(std::shared_ptr& g, LowerInfo lower_info) { +void LowerGraph(std::shared_ptr& g, std::vector& params, LowerInfo lower_info) { torch::jit::EliminateRedundantGuards(g); torch::jit::RemoveListMutation(g); torch::jit::RemoveTensorMutation(g); @@ -66,6 +66,7 @@ void LowerGraph(std::shared_ptr& g, LowerInfo lower_info) { passes::SiluToSigmoidMultipication(g); passes::RemoveSingleUse0DTensors(g); passes::RemoveUnnecessaryCasts(g); + passes::RewriteInputsWithParams(g, params); LOG_GRAPH(*g); } @@ -99,7 +100,7 @@ std::pair, std::vector> L // In quantization aware trained (QAT) models, weights are passed through quantize and // dequantize nodes which should not be folded. So unfreeze_module is set to True for QAT models. LOG_GRAPH("Torch-TensorRT.TorchScript Graph Lowering"); - lowering::LowerGraph(graph_and_ivalues.first, lower_info); + lowering::LowerGraph(graph_and_ivalues.first, graph_and_ivalues.second, lower_info); // Is this necessary? 
// lowering::LowerBlock(g->block()); diff --git a/core/lowering/passes/BUILD b/core/lowering/passes/BUILD index 1f6a0cde8f..b2a90ffb18 100644 --- a/core/lowering/passes/BUILD +++ b/core/lowering/passes/BUILD @@ -27,6 +27,7 @@ cc_library( "remove_dropout.cpp", "remove_nops.cpp", "remove_unnecessary_casts.cpp", + "rewrite_inputs_with_params.cpp", "silu_to_sigmoid_multiplication.cpp", "unpack_addmm.cpp", "unpack_batch_norm.cpp", diff --git a/core/lowering/passes/CMakeLists.txt b/core/lowering/passes/CMakeLists.txt index a8cda65e71..fd34ff8d6c 100644 --- a/core/lowering/passes/CMakeLists.txt +++ b/core/lowering/passes/CMakeLists.txt @@ -22,6 +22,7 @@ target_sources(${lib_name} "${CMAKE_CURRENT_SOURCE_DIR}/unpack_std.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/unpack_var.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/view_to_reshape.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/rewrite_inputs_with_params.cpp" ) set(HEADER_FILES diff --git a/core/lowering/passes/passes.h b/core/lowering/passes/passes.h index 73bd9f61d7..6668be7812 100644 --- a/core/lowering/passes/passes.h +++ b/core/lowering/passes/passes.h @@ -38,6 +38,7 @@ void UnpackVar(std::shared_ptr& graph); void AliasOperators(std::shared_ptr& graph); void SiluToSigmoidMultipication(std::shared_ptr& graph); void UnpackHardSwish(std::shared_ptr& graph); +void RewriteInputsWithParams(std::shared_ptr& g, std::vector& params); } // namespace passes } // namespace lowering diff --git a/core/lowering/passes/rewrite_inputs_with_params.cpp b/core/lowering/passes/rewrite_inputs_with_params.cpp new file mode 100644 index 0000000000..a05a9050c2 --- /dev/null +++ b/core/lowering/passes/rewrite_inputs_with_params.cpp @@ -0,0 +1,41 @@ +#include "torch/csrc/jit/ir/constants.h" +#include "core/util/prelude.h" + + +namespace torch_tensorrt { +namespace core { +namespace lowering { +namespace passes { + + +void RewriteInputsWithParams(std::shared_ptr& g, std::vector& params) { + auto input_size = g->inputs().size(); + auto param_it = params.rbegin(); + for (int i 
= input_size - 1; i >= 0; --i) { + if (g->inputs()[i]->type() != c10::TensorType::get() && + g->inputs()[i]->type()->kind() != torch::jit::TypeKind::TupleType && + g->inputs()[i]->type()->kind() != torch::jit::TypeKind::ListType && param_it != params.rend()) { + auto val = *param_it; + if (val.isTensor()) { + at::Tensor val_tensor = val.toTensor(); + if (val_tensor.requires_grad()) { + val_tensor.set_requires_grad(false); + val = val_tensor; + } + } + auto new_constant = torch::jit::tryInsertConstant(*g, val); + ++param_it; + if (new_constant) { + g->inputs()[i]->replaceAllUsesWith(*new_constant); + g->eraseInput(i); + // erase an iterator, should be safe + params.erase(param_it.base()); + } + } + } +} + +} // namespace passes +} // namespace lowering +} // namespace core +} // namespace torch_tensorrt From 318281ddbce070fa4c1f33b9ab8bbdd3ea603d24 Mon Sep 17 00:00:00 2001 From: Bo Wang Date: Fri, 14 Oct 2022 00:03:03 -0700 Subject: [PATCH 4/5] test: add test case for rewrite input with params pass Signed-off-by: Bo Wang --- .../passes/rewrite_inputs_with_params.cpp | 5 ++- tests/core/lowering/BUILD | 5 +++ .../test_rewrite_inputs_with_params.cpp | 33 +++++++++++++++++++ 3 files changed, 40 insertions(+), 3 deletions(-) create mode 100644 tests/core/lowering/test_rewrite_inputs_with_params.cpp diff --git a/core/lowering/passes/rewrite_inputs_with_params.cpp b/core/lowering/passes/rewrite_inputs_with_params.cpp index a05a9050c2..13b71014dc 100644 --- a/core/lowering/passes/rewrite_inputs_with_params.cpp +++ b/core/lowering/passes/rewrite_inputs_with_params.cpp @@ -1,13 +1,11 @@ -#include "torch/csrc/jit/ir/constants.h" #include "core/util/prelude.h" - +#include "torch/csrc/jit/ir/constants.h" namespace torch_tensorrt { namespace core { namespace lowering { namespace passes { - void RewriteInputsWithParams(std::shared_ptr& g, std::vector& params) { auto input_size = g->inputs().size(); auto param_it = params.rbegin(); @@ -33,6 +31,7 @@ void 
RewriteInputsWithParams(std::shared_ptr& g, std::vector< } } } + LOG_GRAPH("After RewriteInputsWithParams: " << *g); } } // namespace passes diff --git a/tests/core/lowering/BUILD b/tests/core/lowering/BUILD index 75ae818905..4782f235c0 100644 --- a/tests/core/lowering/BUILD +++ b/tests/core/lowering/BUILD @@ -87,6 +87,10 @@ lowering_test( name = "test_unpack_reduce_ops", ) +lowering_test( + name = "test_rewrite_inputs_with_params", +) + test_suite( name = "lowering_tests", tests = [ @@ -102,6 +106,7 @@ test_suite( ":test_remove_detach_pass", ":test_remove_dropout_pass", ":test_remove_unnecessary_casts", + ":test_rewrite_inputs_with_params", ":test_unpack_hardsigmoid", ":test_unpack_hardswish", ":test_unpack_reduce_ops", diff --git a/tests/core/lowering/test_rewrite_inputs_with_params.cpp b/tests/core/lowering/test_rewrite_inputs_with_params.cpp new file mode 100644 index 0000000000..2f0341cabb --- /dev/null +++ b/tests/core/lowering/test_rewrite_inputs_with_params.cpp @@ -0,0 +1,33 @@ +#include +#include "core/compiler.h" +#include "core/lowering/passes/passes.h" +#include "gtest/gtest.h" +#include "tests/util/util.h" +#include "torch/csrc/jit/ir/irparser.h" +#include "torch/csrc/jit/ir/subgraph_matcher.h" + +TEST(LoweringPasses, RewriteInputsWithParamsCorrectly) { + std::string source_graph = R"IR( + graph(%x: Tensor, %y: Tensor, %1 : Int(1)): + %out: Tensor = aten::sub(%x, %y, %1) + return (%out))IR"; + std::string target_graph = R"IR( + graph(%x: Tensor, %y : Tensor): + %2 : int = prim::Constant[value=0]() + %out: Tensor = aten::sub(%x, %y, %2) + return (%out))IR"; + + torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level( + torch_tensorrt::core::util::logging::LogLevel::kGRAPH); + auto sg = std::make_shared(); + torch::jit::parseIR(source_graph, &*sg); + + torch::jit::IValue param0 = torch::jit::IValue(0); + std::vector params{param0}; + torch_tensorrt::core::lowering::passes::RewriteInputsWithParams(sg, params); + + auto tg = 
std::make_shared(); + torch::jit::parseIR(target_graph, &*tg); + + ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty()); +} \ No newline at end of file From ce63baaa533c3c5c8a5fa4ee58d01f432477272c Mon Sep 17 00:00:00 2001 From: Bo Wang Date: Fri, 14 Oct 2022 00:07:05 -0700 Subject: [PATCH 5/5] chore: apply linting Signed-off-by: Bo Wang --- tests/core/lowering/test_rewrite_inputs_with_params.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/core/lowering/test_rewrite_inputs_with_params.cpp b/tests/core/lowering/test_rewrite_inputs_with_params.cpp index 2f0341cabb..b0d8045e7b 100644 --- a/tests/core/lowering/test_rewrite_inputs_with_params.cpp +++ b/tests/core/lowering/test_rewrite_inputs_with_params.cpp @@ -25,7 +25,7 @@ TEST(LoweringPasses, RewriteInputsWithParamsCorrectly) { torch::jit::IValue param0 = torch::jit::IValue(0); std::vector params{param0}; torch_tensorrt::core::lowering::passes::RewriteInputsWithParams(sg, params); - + auto tg = std::make_shared(); torch::jit::parseIR(target_graph, &*tg);