diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp
index 669cd5ed3c..18893e1631 100644
--- a/core/lowering/lowering.cpp
+++ b/core/lowering/lowering.cpp
@@ -26,7 +26,7 @@ void LowerBlock(torch::jit::Block* b) {
   DropUnusedNodes(b);
 }
 
-void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
+void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::IValue>& params, LowerInfo lower_info) {
   torch::jit::EliminateRedundantGuards(g);
   torch::jit::RemoveListMutation(g);
   torch::jit::RemoveTensorMutation(g);
@@ -70,6 +70,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
   passes::SiluToSigmoidMultipication(g);
   passes::RemoveSingleUse0DTensors(g);
   passes::RemoveUnnecessaryCasts(g);
+  passes::RewriteInputsWithParams(g, params);
   LOG_GRAPH(*g);
 }
 
@@ -103,7 +104,7 @@ std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> L
   // In quantization aware trained (QAT) models, weights are passed through quantize and
   // dequantize nodes which should not be folded. So unfreeze_module is set to True for QAT models.
   LOG_GRAPH("Torch-TensorRT.TorchScript Graph Lowering");
-  lowering::LowerGraph(graph_and_ivalues.first, lower_info);
+  lowering::LowerGraph(graph_and_ivalues.first, graph_and_ivalues.second, lower_info);
 
   // Is this necessary?
   // lowering::LowerBlock(g->block());
diff --git a/core/lowering/passes/BUILD b/core/lowering/passes/BUILD
index 436b748b90..99708b6fe4 100644
--- a/core/lowering/passes/BUILD
+++ b/core/lowering/passes/BUILD
@@ -27,6 +27,7 @@ cc_library(
         "remove_dropout.cpp",
         "remove_nops.cpp",
         "remove_unnecessary_casts.cpp",
+        "rewrite_inputs_with_params.cpp",
         "silu_to_sigmoid_multiplication.cpp",
         "unpack_addmm.cpp",
         "unpack_batch_norm.cpp",
diff --git a/core/lowering/passes/CMakeLists.txt b/core/lowering/passes/CMakeLists.txt
index 040931fe39..291fc03cb8 100644
--- a/core/lowering/passes/CMakeLists.txt
+++ b/core/lowering/passes/CMakeLists.txt
@@ -24,6 +24,7 @@ target_sources(${lib_name}
     "${CMAKE_CURRENT_SOURCE_DIR}/unpack_std.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/unpack_var.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/view_to_reshape.cpp"
+    "${CMAKE_CURRENT_SOURCE_DIR}/rewrite_inputs_with_params.cpp"
 )
 
 set(HEADER_FILES
diff --git a/core/lowering/passes/passes.h b/core/lowering/passes/passes.h
index 196abffd11..b1346714ec 100644
--- a/core/lowering/passes/passes.h
+++ b/core/lowering/passes/passes.h
@@ -39,6 +39,7 @@ void UnpackVar(std::shared_ptr<torch::jit::Graph>& graph);
 void AliasOperators(std::shared_ptr<torch::jit::Graph>& graph);
 void SiluToSigmoidMultipication(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackHardSwish(std::shared_ptr<torch::jit::Graph>& graph);
+void RewriteInputsWithParams(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::IValue>& params);
 void UnpackHardSigmoid(std::shared_ptr<torch::jit::Graph>& graph);
 
 } // namespace passes
diff --git a/core/lowering/passes/rewrite_inputs_with_params.cpp b/core/lowering/passes/rewrite_inputs_with_params.cpp
new file mode 100644
index 0000000000..13b71014dc
--- /dev/null
+++ b/core/lowering/passes/rewrite_inputs_with_params.cpp
@@ -0,0 +1,40 @@
+#include "core/util/prelude.h"
+#include "torch/csrc/jit/ir/constants.h"
+
+namespace torch_tensorrt {
+namespace core {
+namespace lowering {
+namespace passes {
+
+void RewriteInputsWithParams(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::IValue>& params) {
+  auto input_size = g->inputs().size();
+  auto param_it = params.rbegin();
+  for (int i = input_size - 1; i >= 0; --i) {
+    if (g->inputs()[i]->type() != c10::TensorType::get() &&
+        g->inputs()[i]->type()->kind() != torch::jit::TypeKind::TupleType &&
+        g->inputs()[i]->type()->kind() != torch::jit::TypeKind::ListType && param_it != params.rend()) {
+      auto val = *param_it;
+      if (val.isTensor()) {
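+        // Graph constants must not require grad (tryInsertConstant below
+        // would refuse such a tensor), so strip the flag before folding
+        // the value into the graph.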
+        at::Tensor val_tensor = val.toTensor();
+        if (val_tensor.requires_grad()) {
+          val_tensor.set_requires_grad(false);
+          val = val_tensor;
+        }
+      }
+      auto new_constant = torch::jit::tryInsertConstant(*g, val);
+      ++param_it;
+      if (new_constant) {
+        g->inputs()[i]->replaceAllUsesWith(*new_constant);
+        g->eraseInput(i);
+        // erase an iterator, should be safe
+        params.erase(param_it.base());
+      }
+    }
+  }
+  LOG_GRAPH("After RewriteInputsWithParams: " << *g);
+}
+
+} // namespace passes
+} // namespace lowering
+} // namespace core
+} // namespace torch_tensorrt
diff --git a/tests/core/lowering/BUILD b/tests/core/lowering/BUILD
index 75ae818905..4782f235c0 100644
--- a/tests/core/lowering/BUILD
+++ b/tests/core/lowering/BUILD
@@ -87,6 +87,10 @@ lowering_test(
     name = "test_unpack_reduce_ops",
 )
 
+lowering_test(
+    name = "test_rewrite_inputs_with_params",
+)
+
 test_suite(
     name = "lowering_tests",
     tests = [
@@ -102,6 +106,7 @@ test_suite(
         ":test_remove_detach_pass",
         ":test_remove_dropout_pass",
         ":test_remove_unnecessary_casts",
+        ":test_rewrite_inputs_with_params",
         ":test_unpack_hardsigmoid",
         ":test_unpack_hardswish",
         ":test_unpack_reduce_ops",
diff --git a/tests/core/lowering/test_rewrite_inputs_with_params.cpp b/tests/core/lowering/test_rewrite_inputs_with_params.cpp
new file mode 100644
index 0000000000..b0d8045e7b
--- /dev/null
+++ b/tests/core/lowering/test_rewrite_inputs_with_params.cpp
@@ -0,0 +1,33 @@
+#include <string>
+#include "core/compiler.h"
+#include "core/lowering/passes/passes.h"
+#include "gtest/gtest.h"
+#include "tests/util/util.h"
+#include "torch/csrc/jit/ir/irparser.h"
+#include "torch/csrc/jit/ir/subgraph_matcher.h"
+
+TEST(LoweringPasses, RewriteInputsWithParamsCorrectly) {
+  std::string source_graph = R"IR(
+    graph(%x: Tensor, %y: Tensor, %1 : Int(1)):
+      %out: Tensor = aten::sub(%x, %y, %1)
+      return (%out))IR";
+  std::string target_graph = R"IR(
+    graph(%x: Tensor, %y : Tensor):
+      %2 : int = prim::Constant[value=0]()
+      %out: Tensor = aten::sub(%x, %y, %2)
+      return (%out))IR";
+
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, &*sg);
+
+  torch::jit::IValue param0 = torch::jit::IValue(0);
+  std::vector<torch::jit::IValue> params{param0};
+  torch_tensorrt::core::lowering::passes::RewriteInputsWithParams(sg, params);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, &*tg);
+
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+}
\ No newline at end of file
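
A note on the one subtle line in the new pass, `params.erase(param_it.base());`: after `++param_it`, the reverse iterator's `base()` is a forward iterator positioned exactly at the element that was just folded into the graph, so the erase removes that element while `param_it` continues to address the next candidate. The pass also assumes the trailing non-tensor graph inputs line up with the trailing entries of `params`, which is why both are walked back-to-front. Below is a minimal standalone sketch of the erase idiom, with made-up values; it is illustrative only and not part of the patch.

// Sketch of the reverse-iterator erase idiom used in
// RewriteInputsWithParams (illustrative, not part of the patch).
#include <cassert>
#include <vector>

int main() {
  std::vector<int> params{10, 20, 30};

  auto it = params.rbegin();  // *it == 30, the trailing element
  int folded = *it;           // "fold" it into the graph
  ++it;                       // *it == 20; it.base() now points at 30

  // Erase exactly the element just consumed. Pedantically, erase
  // invalidates iterators at and after the erased slot, but the reverse
  // iterator only ever dereferences the slot *before* its base, which is
  // why the in-code comment says "should be safe".
  params.erase(it.base());

  assert(folded == 30);
  assert(params.size() == 2 && params.back() == 20);
  assert(*it == 20);  // still usable for the rest of the reverse walk
  return 0;
}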