diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp
index 5442440422..0c0ed5ceef 100644
--- a/core/lowering/lowering.cpp
+++ b/core/lowering/lowering.cpp
@@ -60,6 +60,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
   passes::UnpackAddMM(g);
   // passes::UnpackBatchNorm(g);
   passes::UnpackLogSoftmax(g);
+  passes::UnpackRsqrt(g);
   passes::UnpackStd(g);
   passes::UnpackVar(g);
   passes::RemoveNOPs(g);
diff --git a/core/lowering/passes/BUILD b/core/lowering/passes/BUILD
index d5f3616f8d..436b748b90 100644
--- a/core/lowering/passes/BUILD
+++ b/core/lowering/passes/BUILD
@@ -33,6 +33,7 @@ cc_library(
         "unpack_hardsigmoid.cpp",
         "unpack_hardswish.cpp",
         "unpack_log_softmax.cpp",
+        "unpack_rsqrt.cpp",
         "unpack_std.cpp",
         "unpack_var.cpp",
         "view_to_reshape.cpp",
diff --git a/core/lowering/passes/CMakeLists.txt b/core/lowering/passes/CMakeLists.txt
index 48e644a70d..040931fe39 100644
--- a/core/lowering/passes/CMakeLists.txt
+++ b/core/lowering/passes/CMakeLists.txt
@@ -20,6 +20,7 @@ target_sources(${lib_name}
     "${CMAKE_CURRENT_SOURCE_DIR}/unpack_hardsigmoid.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/unpack_hardswish.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/unpack_log_softmax.cpp"
+    "${CMAKE_CURRENT_SOURCE_DIR}/unpack_rsqrt.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/unpack_std.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/unpack_var.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/view_to_reshape.cpp"
diff --git a/core/lowering/passes/passes.h b/core/lowering/passes/passes.h
index 3b946593e2..196abffd11 100644
--- a/core/lowering/passes/passes.h
+++ b/core/lowering/passes/passes.h
@@ -33,6 +33,7 @@ void RemoveUnnecessaryCasts(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);
+void UnpackRsqrt(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackStd(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackVar(std::shared_ptr<torch::jit::Graph>& graph);
 void AliasOperators(std::shared_ptr<torch::jit::Graph>& graph);
diff --git a/core/lowering/passes/unpack_rsqrt.cpp b/core/lowering/passes/unpack_rsqrt.cpp
new file mode 100644
index 0000000000..f21f1c7ea4
--- /dev/null
+++ b/core/lowering/passes/unpack_rsqrt.cpp
@@ -0,0 +1,30 @@
+#include "torch/csrc/jit/passes/subgraph_rewrite.h"
+
+#include "core/util/prelude.h"
+
+namespace torch_tensorrt {
+namespace core {
+namespace lowering {
+namespace passes {
+
+void UnpackRsqrt(std::shared_ptr<torch::jit::Graph>& graph) {
+  std::string rsqrt_pattern = R"IR(
+    graph(%1):
+      %out: Tensor = aten::rsqrt(%1)
+      return (%out))IR";
+  std::string unpacked_pattern = R"IR(
+    graph(%1):
+      %intermediate: Tensor = aten::sqrt(%1)
+      %out: Tensor = aten::reciprocal(%intermediate)
+      return (%out))IR";
+
+  torch::jit::SubgraphRewriter rsqrt_rewriter;
+  rsqrt_rewriter.RegisterRewritePattern(rsqrt_pattern, unpacked_pattern);
+  rsqrt_rewriter.runOnGraph(graph);
+  LOG_GRAPH("Post unpack rsqrt: " << *graph);
+}
+
+} // namespace passes
+} // namespace lowering
+} // namespace core
+} // namespace torch_tensorrt
diff --git a/tests/core/lowering/test_unpack_reduce_ops.cpp b/tests/core/lowering/test_unpack_reduce_ops.cpp
index 146e49891a..8082a4437a 100644
--- a/tests/core/lowering/test_unpack_reduce_ops.cpp
+++ b/tests/core/lowering/test_unpack_reduce_ops.cpp
@@ -202,3 +202,45 @@ TEST(LoweringPasses, UnpackStdUnbiasedKeepDimsLowersCorrectly) {
   ASSERT_TRUE(
       torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
 }
+
+TEST(LoweringPasses, UnpackRsqrtLowersCorrectly) {
+  const auto graph = R"IR(
+    graph(%x.1 : Tensor):
+      %2 : Tensor = aten::rsqrt(%x.1)
+      return (%2))IR";
+
+  // Use range [0.01, 1.01] to keep inputs positive and avoid NaN from sqrt of negatives
+  auto in = at::rand({2, 3, 5, 7}, {at::kCUDA}) + 0.01;
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
+  torch_tensorrt::core::lowering::passes::UnpackRsqrt(g);
+  torch::jit::EliminateCommonSubexpression(g);
+  auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
+}
+
+TEST(LoweringPasses, UnpackRsqrtIntLowersCorrectly) {
+  const auto graph = R"IR(
+    graph(%x.1 : Tensor):
+      %2 : Tensor = aten::rsqrt(%x.1)
+      return (%2))IR";
+
+  // Use an integer input range of [1, 10]
+  auto in = at::randint(1, 11, {2, 3, 5, 7}, {at::kCUDA});
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
+  torch_tensorrt::core::lowering::passes::UnpackRsqrt(g);
+  torch::jit::EliminateCommonSubexpression(g);
+  auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
+}
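
For reference, a minimal before/after sketch of what the new pass does to a graph containing aten::rsqrt (the value names after rewriting are illustrative; the subgraph rewriter assigns its own):

    graph(%x.1 : Tensor):
      %2 : Tensor = aten::rsqrt(%x.1)
      return (%2)

becomes, after UnpackRsqrt:

    graph(%x.1 : Tensor):
      %3 : Tensor = aten::sqrt(%x.1)
      %2 : Tensor = aten::reciprocal(%3)
      return (%2)

That is, rsqrt is decomposed into sqrt followed by reciprocal, so downstream conversion can handle it through the existing paths for those two ops.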