diff --git a/core/conversion/conversionctx/BUILD b/core/conversion/conversionctx/BUILD
index bf76b9b905..6626ae4457 100644
--- a/core/conversion/conversionctx/BUILD
+++ b/core/conversion/conversionctx/BUILD
@@ -21,6 +21,7 @@ cc_library(
     deps = [
         "@tensorrt//:nvinfer",
         "//core/util:prelude",
+        "//core/ir",
     ] + select({
         ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
         "//conditions:default": ["@libtorch//:libtorch"],
diff --git a/core/conversion/conversionctx/ConversionCtx.h b/core/conversion/conversionctx/ConversionCtx.h
index 5f8d6e955b..f7ed5c89d2 100644
--- a/core/conversion/conversionctx/ConversionCtx.h
+++ b/core/conversion/conversionctx/ConversionCtx.h
@@ -9,20 +9,13 @@
 #include "torch/csrc/jit/ir/ir.h"
 
 #include
+#include "core/ir/ir.h"
 #include "core/util/prelude.h"
 
 namespace torch_tensorrt {
 namespace core {
 namespace conversion {
 
-struct Device {
-  nvinfer1::DeviceType device_type;
-  int64_t gpu_id;
-  int64_t dla_core;
-  bool allow_gpu_fallback;
-  Device() : device_type(nvinfer1::DeviceType::kGPU), gpu_id(0), dla_core(0), allow_gpu_fallback(false) {}
-};
-
 struct BuilderSettings {
   std::set enabled_precisions = {};
   bool sparse_weights = false;
@@ -30,7 +23,7 @@ struct BuilderSettings {
   bool refit = false;
   bool debug = false;
   bool truncate_long_and_double = false;
-  Device device;
+  ir::Device device;
   nvinfer1::EngineCapability capability = TRT_ENGINE_CAPABILITY_STANDARD;
   nvinfer1::IInt8Calibrator* calibrator = nullptr;
   uint64_t num_avg_timing_iters = 1;
diff --git a/core/ir/ir.h b/core/ir/ir.h
index a5225daa25..6c78908d5b 100644
--- a/core/ir/ir.h
+++ b/core/ir/ir.h
@@ -11,6 +11,14 @@
 namespace torch_tensorrt {
 namespace core {
 namespace ir {
+struct Device {
+  nvinfer1::DeviceType device_type;
+  int64_t gpu_id;
+  int64_t dla_core;
+  bool allow_gpu_fallback;
+  Device() : device_type(nvinfer1::DeviceType::kGPU), gpu_id(0), dla_core(0), allow_gpu_fallback(false) {}
+};
+
 struct Input : torch::CustomClassHolder {
   Input(){};
   Input(
diff --git a/core/lowering/BUILD b/core/lowering/BUILD
index ae0f39032a..9c08752bb5 100644
--- a/core/lowering/BUILD
+++ b/core/lowering/BUILD
@@ -24,6 +24,7 @@ cc_library(
     deps = [
         "//core/lowering/passes",
         "//core/util:prelude",
+        "//core/ir",
     ] + select({
         ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
         "//conditions:default": ["@libtorch//:libtorch"],
diff --git a/core/lowering/CMakeLists.txt b/core/lowering/CMakeLists.txt
index 445f627473..af9323784c 100644
--- a/core/lowering/CMakeLists.txt
+++ b/core/lowering/CMakeLists.txt
@@ -15,6 +15,8 @@ set(HEADER_FILES
 target_sources(${lib_name}
     PRIVATE
         ${CXX_SRCS}
+    PUBLIC
+        $
         $
 )
 
@@ -25,8 +27,9 @@ target_include_directories(${lib_name}
 
 target_link_libraries(${lib_name}
     PUBLIC
+        TensorRT::nvinfer
         torch
-    PRIVATE
+        core_ir
         core_util
 )
diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp
index 18893e1631..4d665b390a 100644
--- a/core/lowering/lowering.cpp
+++ b/core/lowering/lowering.cpp
@@ -70,6 +70,10 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector
diff --git a/core/lowering/lowering.h b/core/lowering/lowering.h
--- a/core/lowering/lowering.h
+++ b/core/lowering/lowering.h
+#include "core/ir/ir.h"
 #include "torch/csrc/jit/ir/ir.h"
 
 namespace torch_tensorrt {
@@ -15,8 +16,13 @@ struct LowerInfo {
   // Since these QDQ nodes will be identical as they share same input, one of them is eliminated due to CSE lowering
   // pass. Disable this in order to not disturb TensorRT's QAT optimizations.
   bool disable_cse = false;
+  ir::Device target_device;
   std::vector forced_fallback_modules;
   friend std::ostream& operator<<(std::ostream& os, const LowerInfo& l);
+
+  std::string getGPUDeviceString() {
+    return "cuda:" + std::to_string(target_device.gpu_id);
+  };
 };
 
 void LowerBlock(torch::jit::Block* b);
diff --git a/core/lowering/passes/BUILD b/core/lowering/passes/BUILD
index 99708b6fe4..f1fbf60ef7 100644
--- a/core/lowering/passes/BUILD
+++ b/core/lowering/passes/BUILD
@@ -14,6 +14,7 @@ cc_library(
     name = "passes",
     srcs = [
         "convNd_to_convolution.cpp",
+        "device_casting.cpp",
         "exception_elimination.cpp",
         "fuse_addmm_branches.cpp",
         "linear_to_addmm.cpp",
diff --git a/core/lowering/passes/CMakeLists.txt b/core/lowering/passes/CMakeLists.txt
index 291fc03cb8..4c9ebc7efa 100644
--- a/core/lowering/passes/CMakeLists.txt
+++ b/core/lowering/passes/CMakeLists.txt
@@ -1,5 +1,6 @@
 target_sources(${lib_name} PRIVATE
     "${CMAKE_CURRENT_SOURCE_DIR}/convNd_to_convolution.cpp"
+    "${CMAKE_CURRENT_SOURCE_DIR}/device_casting.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/exception_elimination.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/fuse_addmm_branches.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/linear_to_addmm.cpp"
diff --git a/core/lowering/passes/device_casting.cpp b/core/lowering/passes/device_casting.cpp
new file mode 100644
index 0000000000..eee27336d4
--- /dev/null
+++ b/core/lowering/passes/device_casting.cpp
@@ -0,0 +1,121 @@
+#include "torch/csrc/jit/ir/constants.h"
+#include "torch/csrc/jit/passes/subgraph_rewrite.h"
+
+#include "core/util/prelude.h"
+
+namespace torch_tensorrt {
+namespace core {
+namespace lowering {
+namespace passes {
+
+void UnpackAndCastMaskedFill(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name) {
+  std::string masked_fill_pattern = R"IR(
+    graph(%self, %mask, %value):
+      %out: Tensor = aten::masked_fill_(%self, %mask, %value)
+      return (%out))IR";
+
+  // Calls to masked_fill_ often utilize CPU tensors, and as such
+  // should be moved to gpu to avoid device mismatch errors
+
+  // Separate string into portions to insert device name
+  std::string clean_pattern_part_1 = R"IR(
+    graph(%self, %mask, %value):
+      %device: Device = prim::Constant[value=")IR";
+
+  std::string clean_pattern_part_2 = R"IR("]()
+      %dtype: NoneType = prim::Constant()
+      %false: bool = prim::Constant[value=0]()
+      %mask_cuda: Tensor = aten::to(%mask, %device, %dtype, %false, %false)
+      %self_cuda: Tensor = aten::to(%self, %device, %dtype, %false, %false)
+      %out: Tensor = aten::masked_fill(%self_cuda, %mask_cuda, %value)
+      return (%out))IR";
+
+  auto unpacked_pattern = clean_pattern_part_1 + target_device_name + clean_pattern_part_2;
+
+  torch::jit::SubgraphRewriter masked_fill_rewriter;
+  masked_fill_rewriter.RegisterRewritePattern(masked_fill_pattern, unpacked_pattern);
+  masked_fill_rewriter.runOnGraph(graph);
+  LOG_GRAPH("After unpack and cast masked_fill_: " << *graph);
+}
+
+void UnpackAndCastNumToTensor(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name) {
+  std::string num_to_tensor_cast_pattern = R"IR(
+    graph(%1: Scalar):
+      %2: Tensor = prim::NumToTensor(%1)
+      return (%2))IR";
+
+  // 0D Tensors are initialized on cpu, and need to be moved to gpu
+  // to avoid device mismatch issues
+
+  // Separate string into portions to insert device name
+  std::string clean_pattern_part_1 = R"IR(
+    graph(%1: Scalar):
+      %2: Tensor = prim::NumToTensor(%1)
+      %device: Device = prim::Constant[value=")IR";
+
+  std::string clean_pattern_part_2 = R"IR("]()
+      %dtype: NoneType = prim::Constant()
+      %false: bool = prim::Constant[value=0]()
+      %3: Tensor = aten::to(%2, %device, %dtype, %false, %false)
+      return (%3))IR";
+
+  auto num_to_tensor_clean_pattern = clean_pattern_part_1 + target_device_name + clean_pattern_part_2;
+
+  torch::jit::SubgraphRewriter num_to_tensor_cast_rewriter;
+  num_to_tensor_cast_rewriter.RegisterRewritePattern(num_to_tensor_cast_pattern, num_to_tensor_clean_pattern);
+  num_to_tensor_cast_rewriter.runOnGraph(graph);
+
+  LOG_GRAPH("After unpack and cast NumToTensor: " << *graph);
+}
+
+void UnpackAndCastFull(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name) {
+  std::string full_cast_pattern = R"IR(
+    graph(%1, %2, %3, %4, %5, %6):
+      %out: Tensor = aten::full(%1, %2, %3, %4, %5, %6)
+      return (%out))IR";
+
+  // Tensors created via aten::full are initialized on cpu, and need to be casted to gpu
+  // to avoid device mismatch issues
+
+  // Separate string into portions to insert device name
+  std::string clean_pattern_part_1 = R"IR(
+    graph(%1, %2, %3, %4, %5, %6):
+      %device: Device = prim::Constant[value=")IR";
+
+  std::string clean_pattern_part_2 = R"IR("]()
+      %out: Tensor = aten::full(%1, %2, %3, %4, %device, %6)
+      return (%out))IR";
+
+  auto full_clean_pattern = clean_pattern_part_1 + target_device_name + clean_pattern_part_2;
+
+  torch::jit::SubgraphRewriter full_cast_rewriter;
+  full_cast_rewriter.RegisterRewritePattern(full_cast_pattern, full_clean_pattern);
+  full_cast_rewriter.runOnGraph(graph);
+
+  LOG_GRAPH("After unpack and cast full: " << *graph);
+}
+
+void ReplaceScalarImplicit(std::shared_ptr<torch::jit::Graph>& graph) {
+  std::string scalar_implicit_cast_pattern = R"IR(
+    graph(%1: Tensor):
+      %2: Scalar = aten::ScalarImplicit(%1)
+      return (%2))IR";
+
+  // ScalarImplicit can only unpack 0D tensors, whereas Tensors operated on by
+  // TensorRT are padded to 1 dimension. aten::item() resolves this conflict
+  std::string scalar_implicit_clean_pattern = R"IR(
+    graph(%1: Tensor):
+      %2: Scalar = aten::item(%1)
+      return (%2))IR";
+
+  torch::jit::SubgraphRewriter scalar_implicit_cast_rewriter;
+  scalar_implicit_cast_rewriter.RegisterRewritePattern(scalar_implicit_cast_pattern, scalar_implicit_clean_pattern);
+  scalar_implicit_cast_rewriter.runOnGraph(graph);
+
+  LOG_GRAPH("After replace scalar implicit: " << *graph);
+}
+
+} // namespace passes
+} // namespace lowering
+} // namespace core
+} // namespace torch_tensorrt
diff --git a/core/lowering/passes/passes.h b/core/lowering/passes/passes.h
index b1346714ec..713894e0c7 100644
--- a/core/lowering/passes/passes.h
+++ b/core/lowering/passes/passes.h
@@ -41,6 +41,10 @@ void SiluToSigmoidMultipication(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackHardSwish(std::shared_ptr<torch::jit::Graph>& graph);
 void RewriteInputsWithParams(std::shared_ptr<torch::jit::Graph>& g, std::vector& params);
 void UnpackHardSigmoid(std::shared_ptr<torch::jit::Graph>& graph);
+void UnpackAndCastMaskedFill(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name);
+void UnpackAndCastNumToTensor(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name);
+void UnpackAndCastFull(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name);
+void ReplaceScalarImplicit(std::shared_ptr<torch::jit::Graph>& graph);
 
 } // namespace passes
 } // namespace lowering
diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp
index 18924cccd6..7eeafdffdd 100644
--- a/core/runtime/execute_engine.cpp
+++ b/core/runtime/execute_engine.cpp
@@ -63,16 +63,40 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
   CudaDevice curr_device = get_current_device();
   LOG_DEBUG("Current Device: " << curr_device);
 
+  // Generic Target Device Prefix
+  std::string target_device = "cuda:";
+
   if (is_switch_required(curr_device, compiled_engine->device_info)) {
     // Scan through available CUDA devices and set the CUDA device context correctly
     CudaDevice device = select_cuda_device(compiled_engine->device_info);
     set_cuda_device(device);
 
-    std::string target_device = "cuda:" + std::to_string(device.id);
+    // Target device is new device
+    target_device += std::to_string(device.id);
 
     for (auto& in : inputs) {
       in = in.to(torch::Device(target_device));
     }
+  } else {
+    // Target device is current device
+    target_device += std::to_string(curr_device.id);
+  }
+
+  // For each input, ensure its current device is the desired target device
+  for (size_t i = 0; i < inputs.size(); i++) {
+    at::Tensor* in = &inputs[i];
+    std::string current_tensor_device = in->device().str();
+
+    // If current device string does not match target device, display warning and move tensor accordingly
+    if (current_tensor_device != target_device) {
+      LOG_WARNING(
+          "Input " << i << " of engine " << compiled_engine->name << " was found to be on " << current_tensor_device
+                   << " but should be on " << target_device << ". This tensor is being moved by the runtime but "
+                   << "for performance considerations, ensure your inputs are all on GPU "
+                   << "and open an issue here (https://github.com/pytorch/TensorRT/issues) if this "
+                   << "warning persists.");
+      *in = in->to(torch::Device(target_device));
+    }
   }
 
   std::vector gpu_handles;
diff --git a/cpp/src/compile_spec.cpp b/cpp/src/compile_spec.cpp
index 3d7d9b15d3..daf05fd495 100644
--- a/cpp/src/compile_spec.cpp
+++ b/cpp/src/compile_spec.cpp
@@ -110,6 +110,7 @@ torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {
   internal.convert_info.engine_settings.debug = external.debug;
   internal.convert_info.engine_settings.truncate_long_and_double = external.truncate_long_and_double;
   internal.convert_info.engine_settings.device.allow_gpu_fallback = external.device.allow_gpu_fallback;
+  internal.lower_info.target_device.allow_gpu_fallback = external.device.allow_gpu_fallback;
 
   TORCHTRT_CHECK(
       !(external.require_full_compilation && (external.torch_executed_ops.size() > 0)),
@@ -130,10 +131,12 @@ torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {
   switch (external.device.device_type) {
     case Device::DeviceType::kDLA:
       internal.convert_info.engine_settings.device.device_type = nvinfer1::DeviceType::kDLA;
+      internal.lower_info.target_device.device_type = nvinfer1::DeviceType::kDLA;
       break;
     case Device::DeviceType::kGPU:
     default:
       internal.convert_info.engine_settings.device.device_type = nvinfer1::DeviceType::kGPU;
+      internal.lower_info.target_device.device_type = nvinfer1::DeviceType::kGPU;
   }
 
   switch (external.capability) {
@@ -150,6 +153,8 @@ torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {
 
   internal.convert_info.engine_settings.device.gpu_id = external.device.gpu_id;
   internal.convert_info.engine_settings.device.dla_core = external.device.dla_core;
+  internal.lower_info.target_device.gpu_id = external.device.gpu_id;
+  internal.lower_info.target_device.dla_core = external.device.dla_core;
   internal.convert_info.engine_settings.num_avg_timing_iters = external.num_avg_timing_iters;
   internal.convert_info.engine_settings.workspace_size = external.workspace_size;
   internal.convert_info.engine_settings.dla_sram_size = external.dla_sram_size;
diff --git a/tests/core/lowering/BUILD b/tests/core/lowering/BUILD
index 4782f235c0..7f4e53d8a6 100644
--- a/tests/core/lowering/BUILD
+++ b/tests/core/lowering/BUILD
@@ -31,6 +31,10 @@ lowering_test(
     name = "test_conv1d_pass",
 )
 
+lowering_test(
+    name = "test_device_casting",
+)
+
 lowering_test(
     name = "test_exception_elimination_pass",
 )
@@ -95,6 +99,7 @@ test_suite(
     name = "lowering_tests",
     tests = [
         ":test_conv1d_pass",
+        ":test_device_casting",
         ":test_exception_elimination_pass",
         ":test_linear_to_addmm",
         ":test_module_fallback_passes",
diff --git a/tests/core/lowering/test_device_casting.cpp b/tests/core/lowering/test_device_casting.cpp
new file mode 100644
index 0000000000..cc76f94ad6
--- /dev/null
+++ b/tests/core/lowering/test_device_casting.cpp
@@ -0,0 +1,194 @@
+#include
+#include "core/compiler.h"
+#include "core/lowering/passes/passes.h"
+#include "core/util/prelude.h"
+#include "gtest/gtest.h"
+#include "tests/util/util.h"
+#include "torch/csrc/jit/ir/irparser.h"
+#include "torch/csrc/jit/ir/subgraph_matcher.h"
+#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
+#include "torch/torch.h"
+
+TEST(LoweringPasses, UnpackAndCastMaskedFillLowersCorrectly) {
+  const auto graph = R"IR(
+    graph(%x.1: Tensor, %x.2: Tensor, %x.3: float):
+      %2 : Tensor = aten::masked_fill_(%x.1, %x.2, %x.3)
+      return (%2))IR";
+
+  auto in = at::rand({2, 3, 5, 7}, {at::kCUDA});
+  auto in2 = at::rand({2, 3, 5, 7}, {at::kCUDA}).to(torch::kBool);
+  auto in3 = 7.3;
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in, in2, in3});
+  torch_tensorrt::core::lowering::passes::UnpackAndCastMaskedFill(g, "cuda:0");
+  torch::jit::EliminateCommonSubexpression(g);
+  auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in, in2, in3});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
+}
+
+TEST(LoweringPasses, UnpackAndCastNumToTensorLowersIntCorrectly) {
+  const auto graph = R"IR(
+    graph(%x.1: int):
+      %2 : Tensor = prim::NumToTensor(%x.1)
+      return (%2))IR";
+
+  auto in = 1;
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
+  torch_tensorrt::core::lowering::passes::UnpackAndCastNumToTensor(g, "cuda:0");
+  torch::jit::EliminateCommonSubexpression(g);
+  auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
+}
+
+TEST(LoweringPasses, UnpackAndCastNumToTensorLowersFloatCorrectly) {
+  const auto graph = R"IR(
+    graph(%x.1: float):
+      %2 : Tensor = prim::NumToTensor(%x.1)
+      return (%2))IR";
+
+  auto in = 78.1;
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
+  torch_tensorrt::core::lowering::passes::UnpackAndCastNumToTensor(g, "cuda:0");
+  torch::jit::EliminateCommonSubexpression(g);
+  auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
+}
+
+TEST(LoweringPasses, UnpackAndCastFullIntLowersCorrectly) {
+  const auto graph = R"IR(
+    graph(%x.1: int):
+      %5 : NoneType = prim::Constant()
+      %2 : int = prim::Constant[value=3]()
+      %10 : int[] = prim::ListConstruct(%2, %2)
+      %out : Tensor = aten::full(%10, %x.1, %5, %5, %5, %5)
+      return (%out))IR";
+
+  auto in = 4;
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
+  torch_tensorrt::core::lowering::passes::UnpackAndCastFull(g, "cuda:0");
+  torch::jit::EliminateCommonSubexpression(g);
+  auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(
+      jit_pre_results[0].toTensor(), jit_post_results[0].toTensor().cpu(), 2e-6));
+}
+
+TEST(LoweringPasses, UnpackAndCastFullFloatLowersCorrectly) {
+  const auto graph = R"IR(
+    graph(%x.1: float):
+      %5 : NoneType = prim::Constant()
+      %2 : int = prim::Constant[value=5]()
+      %3 : int = prim::Constant[value=4]()
+      %10 : int[] = prim::ListConstruct(%2, %3)
+      %out : Tensor = aten::full(%10, %x.1, %5, %5, %5, %5)
+      return (%out))IR";
+
+  auto in = 54.1;
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
+  torch_tensorrt::core::lowering::passes::UnpackAndCastFull(g, "cuda:0");
+  torch::jit::EliminateCommonSubexpression(g);
+  auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(
+      jit_pre_results[0].toTensor(), jit_post_results[0].toTensor().cpu(), 2e-6));
+}
+
+TEST(LoweringPasses, ReplaceScalarImplicitLowersCorrectly) {
+  const auto graph = R"IR(
+    graph(%x.1: Tensor):
+      %5 : int = prim::Constant[value=0]()
+      %false : bool = prim::Constant[value=0]()
+      %none : NoneType = prim::Constant()
+      %cuda : Device = prim::Constant[value="cuda:0"]()
+      %3 : int = aten::size(%x.1, %5)
+      %y.2 : Tensor = prim::NumToTensor(%3)
+      %y.1 : Tensor = aten::to(%y.2, %cuda, %none, %false, %false)
+      %19 : Tensor[] = prim::ListConstruct(%x.1, %y.1)
+      %21 : Tensor, %22 : Tensor = prim::ListUnpack(%19)
+      %2 : Scalar = aten::ScalarImplicit(%22)
+      %out : Tensor = prim::NumToTensor(%2)
+      return (%out))IR";
+
+  auto in = at::rand({2, 3, 5, 7}, {at::kCUDA});
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
+  torch_tensorrt::core::lowering::passes::ReplaceScalarImplicit(g);
+  torch::jit::EliminateCommonSubexpression(g);
+  auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
+}
+
+TEST(LoweringPasses, ReplaceScalarImplicitIntNumToTensorLowersCorrectly) {
+  const auto graph = R"IR(
+    graph(%x.1: int):
+      %1 : Tensor = prim::NumToTensor(%x.1)
+      %2 : Scalar = aten::ScalarImplicit(%1)
+      %3 : Tensor = prim::NumToTensor(%2)
+      return (%3))IR";
+
+  auto in = 25;
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
+  torch_tensorrt::core::lowering::passes::UnpackAndCastNumToTensor(g, "cuda:0");
+  torch_tensorrt::core::lowering::passes::ReplaceScalarImplicit(g);
+  torch::jit::EliminateCommonSubexpression(g);
+  auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
+}
+
+TEST(LoweringPasses, ReplaceScalarImplicitFloatLowersCorrectly) {
+  const auto graph = R"IR(
+    graph(%x.1: float):
+      %1 : Tensor = prim::NumToTensor(%x.1)
+      %2 : Scalar = aten::ScalarImplicit(%1)
+      %3 : Tensor = prim::NumToTensor(%2)
+      return (%3))IR";
+
+  auto in = 2.5;
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
+  torch_tensorrt::core::lowering::passes::ReplaceScalarImplicit(g);
+  torch::jit::EliminateCommonSubexpression(g);
+  auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
+}
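
The body of the core/lowering/lowering.cpp hunk is not shown above. As a rough sketch of how the four new passes are presumably wired into lowering, based only on the declarations in passes.h and the new LowerInfo::getGPUDeviceString() helper, it likely resembles the following; the helper name ApplyDeviceCastingPasses and the exact call sites and ordering inside LowerGraph are assumptions, not taken from this diff.

// Sketch only (assumed wiring, not from this diff): apply the new
// device-casting lowering passes to a TorchScript graph using the target
// device recorded on LowerInfo.
#include "core/lowering/lowering.h"
#include "core/lowering/passes/passes.h"

namespace torch_tensorrt {
namespace core {
namespace lowering {

void ApplyDeviceCastingPasses(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
  // "cuda:<gpu_id>", built from the ir::Device stored on LowerInfo
  std::string target_device = lower_info.getGPUDeviceString();

  // Rewrite ops whose results would otherwise be created on the CPU so that
  // they are explicitly moved to the target GPU before conversion
  passes::UnpackAndCastMaskedFill(g, target_device);
  passes::UnpackAndCastNumToTensor(g, target_device);
  passes::UnpackAndCastFull(g, target_device);

  // aten::ScalarImplicit only accepts 0-D tensors; aten::item also handles
  // the 1-D tensors TensorRT produces
  passes::ReplaceScalarImplicit(g);
}

} // namespace lowering
} // namespace core
} // namespace torch_tensorrt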