diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp
index b25af1b8de..70439fb127 100644
--- a/core/lowering/lowering.cpp
+++ b/core/lowering/lowering.cpp
@@ -25,6 +25,7 @@ void LowerBlock(torch::jit::Block* b) {
 }
 
 void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
+  passes::UnpackHardSwish(g);
   torch::jit::EliminateRedundantGuards(g);
   torch::jit::RemoveListMutation(g);
   torch::jit::RemoveTensorMutation(g);
diff --git a/core/lowering/passes/BUILD b/core/lowering/passes/BUILD
index 5fd09ce328..d615af06bc 100644
--- a/core/lowering/passes/BUILD
+++ b/core/lowering/passes/BUILD
@@ -24,6 +24,7 @@ cc_library(
         "unpack_addmm.cpp",
         "unpack_batch_norm.cpp",
         "unpack_log_softmax.cpp",
+        "unpack_hardswish.cpp"
     ],
     hdrs = [
         "passes.h",
diff --git a/core/lowering/passes/passes.h b/core/lowering/passes/passes.h
index df204df918..db2bcb2ebe 100644
--- a/core/lowering/passes/passes.h
+++ b/core/lowering/passes/passes.h
@@ -21,6 +21,7 @@ void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);
 void AliasOperators(std::shared_ptr<torch::jit::Graph>& graph);
 void SiluToSigmoidMultipication(std::shared_ptr<torch::jit::Graph>& graph);
+void UnpackHardSwish(std::shared_ptr<torch::jit::Graph>& graph);
 
 } // namespace passes
 } // namespace lowering
diff --git a/core/lowering/passes/unpack_hardswish.cpp b/core/lowering/passes/unpack_hardswish.cpp
new file mode 100644
index 0000000000..7e460b67ec
--- /dev/null
+++ b/core/lowering/passes/unpack_hardswish.cpp
@@ -0,0 +1,44 @@
+#include "torch/csrc/jit/passes/subgraph_rewrite.h"
+
+#include "core/util/prelude.h"
+
+namespace trtorch {
+namespace core {
+namespace lowering {
+namespace passes {
+
+void UnpackHardSwish(std::shared_ptr<torch::jit::Graph>& graph) {
+  std::string hardswish_pattern = R"IR(
+        graph(%input):
+            %result = aten::hardswish(%input)
+            return (%result))IR";
+
+  std::string hardswish_pattern_inplace = R"IR(
+        graph(%input):
+            %result = aten::hardswish_(%input)
+            return (%result))IR";
+
+  std::string new_pattern = R"IR(
+        graph(%input):
+            %1 : Scalar = prim::Constant[value=3.]()
+            %2 : Scalar = prim::Constant[value=1.]()
+            %3 = aten::add(%input, %1, %2)
+            %4 : Scalar = prim::Constant[value=0.]()
+            %5 : Scalar = prim::Constant[value=6.]()
+            %6 = aten::hardtanh(%3, %4, %5)
+            %7 = aten::div(%6, %5)
+            %8 = aten::mul(%input, %7)
+            return (%8))IR";
+
+  torch::jit::SubgraphRewriter rewriter;
+  rewriter.RegisterRewritePattern(hardswish_pattern, new_pattern);
+  rewriter.RegisterRewritePattern(hardswish_pattern_inplace, new_pattern);
+  rewriter.runOnGraph(graph);
+
+  LOG_GRAPH("Post unpack hardswish: " << *graph);
+}
+
+} // namespace passes
+} // namespace lowering
+} // namespace core
+} // namespace trtorch
diff --git a/tests/core/lowering/BUILD b/tests/core/lowering/BUILD
index fb5e0cbd5a..7dfdd4cbe8 100644
--- a/tests/core/lowering/BUILD
+++ b/tests/core/lowering/BUILD
@@ -35,6 +35,10 @@ lowering_test(
     name = "test_silu_to_sigmoid_multiplication",
 )
 
+lowering_test(
+    name = "test_unpack_hardswish",
+)
+
 test_suite(
     name = "lowering_tests",
     tests = [
@@ -44,5 +48,6 @@ test_suite(
         ":test_remove_detach_pass",
         ":test_remove_dropout_pass",
         ":test_remove_to",
+        ":test_unpack_hardswish"
     ],
 )
diff --git a/tests/core/lowering/test_unpack_hardswish.cpp b/tests/core/lowering/test_unpack_hardswish.cpp
new file mode 100644
index 0000000000..d60d3d0ae9
--- /dev/null
+++ b/tests/core/lowering/test_unpack_hardswish.cpp
@@ -0,0 +1,87 @@
+#include <string>
+#include "core/compiler.h"
+#include "core/lowering/passes/passes.h"
+#include "gtest/gtest.h"
+#include "tests/util/util.h"
+#include "torch/csrc/jit/ir/irparser.h"
+#include "torch/csrc/jit/ir/subgraph_matcher.h"
+
+TEST(LoweringPasses, UnpackHardSwish) {
+  std::string source_graph = R"IR(
+        graph(%input):
+            %result = aten::hardswish(%input)
+            return (%result))IR";
+
+  std::string target_graph = R"IR(
+        graph(%input):
+            %1 : Scalar = prim::Constant[value=3.]()
+            %2 : Scalar = prim::Constant[value=1.]()
+            %3 = aten::add(%input, %1, %2)
+            %4 : Scalar = prim::Constant[value=0.]()
+            %5 : Scalar = prim::Constant[value=6.]()
+            %6 = aten::hardtanh(%3, %4, %5)
+            %7 = aten::div(%6, %5)
+            %8 = aten::mul(%input, %7)
+            return (%8))IR";
+
+  trtorch::core::util::logging::get_logger().set_reportable_log_level(trtorch::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, &*sg);
+
+  auto in = at::rand({10, 100}, {at::kCUDA});
+  auto sg_params = trtorch::core::conversion::get_named_params(sg->inputs(), {});
+  auto sg_results = trtorch::tests::util::RunGraph(sg, sg_params, {in});
+
+  trtorch::core::lowering::passes::UnpackHardSwish(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, &*tg);
+
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+
+  in = at::clone(in);
+  auto tg_params = trtorch::core::conversion::get_named_params(tg->inputs(), {});
+  auto tg_results = trtorch::tests::util::RunGraph(tg, tg_params, {in});
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(sg_results[0], tg_results[0], 2e-6));
+}
+
+TEST(LoweringPasses, UnpackHardInplaceSwish) {
+  std::string source_graph = R"IR(
+        graph(%input):
+            %result = aten::hardswish_(%input)
+            return (%result))IR";
+
+  std::string target_graph = R"IR(
+        graph(%input):
+            %1 : Scalar = prim::Constant[value=3.]()
+            %2 : Scalar = prim::Constant[value=1.]()
+            %3 = aten::add(%input, %1, %2)
+            %4 : Scalar = prim::Constant[value=0.]()
+            %5 : Scalar = prim::Constant[value=6.]()
+            %6 = aten::hardtanh(%3, %4, %5)
+            %7 = aten::div(%6, %5)
+            %8 = aten::mul(%input, %7)
+            return (%8))IR";
+
+  trtorch::core::util::logging::get_logger().set_reportable_log_level(trtorch::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, &*sg);
+
+  auto in = at::rand({10, 100}, {at::kCUDA});
+  auto sg_params = trtorch::core::conversion::get_named_params(sg->inputs(), {});
+  auto sg_results = trtorch::tests::util::RunGraph(sg, sg_params, {in});
+
+  trtorch::core::lowering::passes::UnpackHardSwish(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, &*tg);
+
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+
+  in = at::clone(in);
+  auto tg_params = trtorch::core::conversion::get_named_params(tg->inputs(), {});
+  auto tg_results = trtorch::tests::util::RunGraph(tg, tg_params, {in});
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(sg_results[0], tg_results[0], 2e-6));
+}
\ No newline at end of file