From 1d5712d415f3f45fe24e6b8d3233377e7c1a3eea Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Mon, 24 Oct 2022 22:33:35 -0700 Subject: [PATCH 1/2] fix: Device casting issues with certain `aten` operators - Investigated issue arising with BART-base model (https://huggingface.co/facebook/bart-base) where certain tensor inputs to TensorRT were on the cpu, despite users explicitly casting all inputs properly - Traced issue to internally-generated 0D tensors, mask tensors, and operations returning CPU tensors passed between Torch and Torch-TensorRT engines - Added lowering passes to ensure function edge cases are appropriately dealt with, tensors are located on the proper device at runtime, and added validation check in runtime to avoid models crashing at runtime due to device mismatches - Added testing for lowering passes to ensure output values are accurate --- core/lowering/lowering.cpp | 4 + core/lowering/passes/BUILD | 1 + core/lowering/passes/CMakeLists.txt | 1 + core/lowering/passes/device_casting.cpp | 103 +++++++++++ core/lowering/passes/passes.h | 4 + core/runtime/execute_engine.cpp | 27 ++- tests/core/lowering/BUILD | 5 + tests/core/lowering/test_device_casting.cpp | 194 ++++++++++++++++++++ 8 files changed, 338 insertions(+), 1 deletion(-) create mode 100644 core/lowering/passes/device_casting.cpp create mode 100644 tests/core/lowering/test_device_casting.cpp diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp index 18893e1631..305be64185 100644 --- a/core/lowering/lowering.cpp +++ b/core/lowering/lowering.cpp @@ -70,6 +70,10 @@ void LowerGraph(std::shared_ptr& g, std::vector& graph) { + std::string masked_fill_pattern = R"IR( + graph(%self, %mask, %value): + %out: Tensor = aten::masked_fill_(%self, %mask, %value) + return (%out))IR"; + + // Calls to masked_fill_ often utilize CPU tensors, and as such + // should be casted to CUDA to avoid device mismatch errors + std::string unpacked_pattern = R"IR( + graph(%self, %mask, %value): + %device: Device = prim::Constant[value="cuda"]() + %dtype: NoneType = prim::Constant() + %false: bool = prim::Constant[value=0]() + %mask_cuda: Tensor = aten::to(%mask, %device, %dtype, %false, %false) + %self_cuda: Tensor = aten::to(%self, %device, %dtype, %false, %false) + %out: Tensor = aten::masked_fill_(%self_cuda, %mask_cuda, %value) + return (%out))IR"; + + torch::jit::SubgraphRewriter masked_fill_rewriter; + masked_fill_rewriter.RegisterRewritePattern(masked_fill_pattern, unpacked_pattern); + masked_fill_rewriter.runOnGraph(graph); + LOG_GRAPH("After unpack and cast masked_fill_: " << *graph); +} + +void UnpackAndCastNumToTensor(std::shared_ptr& graph) { + std::string num_to_tensor_cast_pattern = R"IR( + graph(%1: Scalar): + %2: Tensor = prim::NumToTensor(%1) + return (%2))IR"; + + // 0D Tensors are initialized on cpu, and need to be casted to CUDA + // to avoid device mismatch issues + std::string num_to_tensor_clean_pattern = R"IR( + graph(%1: Scalar): + %2: Tensor = prim::NumToTensor(%1) + %device: Device = prim::Constant[value="cuda"]() + %dtype: NoneType = prim::Constant() + %false: bool = prim::Constant[value=0]() + %3: Tensor = aten::to(%2, %device, %dtype, %false, %false) + return (%3))IR"; + + torch::jit::SubgraphRewriter num_to_tensor_cast_rewriter; + num_to_tensor_cast_rewriter.RegisterRewritePattern(num_to_tensor_cast_pattern, num_to_tensor_clean_pattern); + num_to_tensor_cast_rewriter.runOnGraph(graph); + + LOG_GRAPH("After unpack and cast NumToTensor: " << *graph); +} + +void UnpackAndCastFull(std::shared_ptr& graph) { + std::string full_cast_pattern = R"IR( + graph(%1, %2, %3, %4, %5, %6): + %out: Tensor = aten::full(%1, %2, %3, %4, %5, %6) + return (%out))IR"; + + // Tensors created via aten::full are initialized on cpu, and need to be casted to CUDA + // to avoid device mismatch issues + std::string full_clean_pattern = R"IR( + graph(%1, %2, %3, %4, %5, %6): + %cuda: Device = prim::Constant[value="cuda"]() + %out: Tensor = aten::full(%1, %2, %3, %4, %cuda, %6) + return (%out))IR"; + + torch::jit::SubgraphRewriter full_cast_rewriter; + full_cast_rewriter.RegisterRewritePattern(full_cast_pattern, full_clean_pattern); + full_cast_rewriter.runOnGraph(graph); + + LOG_GRAPH("After unpack and cast full: " << *graph); +} + +void ReplaceScalarImplicit(std::shared_ptr& graph) { + std::string scalar_implicit_cast_pattern = R"IR( + graph(%1: Tensor): + %2: Scalar = aten::ScalarImplicit(%1) + return (%2))IR"; + + // ScalarImplicit can only unpack 0D tensors, whereas Tensors operated on by + // TensorRT are padded to 1 dimension. aten::item() resolves this conflict + std::string scalar_implicit_clean_pattern = R"IR( + graph(%1: Tensor): + %2: Scalar = aten::item(%1) + return (%2))IR"; + + torch::jit::SubgraphRewriter scalar_implicit_cast_rewriter; + scalar_implicit_cast_rewriter.RegisterRewritePattern(scalar_implicit_cast_pattern, scalar_implicit_clean_pattern); + scalar_implicit_cast_rewriter.runOnGraph(graph); + + LOG_GRAPH("After unpack and cast full: " << *graph); +} + +} // namespace passes +} // namespace lowering +} // namespace core +} // namespace torch_tensorrt diff --git a/core/lowering/passes/passes.h b/core/lowering/passes/passes.h index b1346714ec..1e02656294 100644 --- a/core/lowering/passes/passes.h +++ b/core/lowering/passes/passes.h @@ -41,6 +41,10 @@ void SiluToSigmoidMultipication(std::shared_ptr& graph); void UnpackHardSwish(std::shared_ptr& graph); void RewriteInputsWithParams(std::shared_ptr& g, std::vector& params); void UnpackHardSigmoid(std::shared_ptr& graph); +void UnpackAndCastMaskedFill(std::shared_ptr& graph); +void UnpackAndCastNumToTensor(std::shared_ptr& graph); +void UnpackAndCastFull(std::shared_ptr& graph); +void ReplaceScalarImplicit(std::shared_ptr& graph); } // namespace passes } // namespace lowering diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index 18924cccd6..65b5ca6880 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -63,16 +63,41 @@ std::vector execute_engine(std::vector inputs, c10::intr CudaDevice curr_device = get_current_device(); LOG_DEBUG("Current Device: " << curr_device); + // Generic Target Device Prefix + std::string target_device = "cuda:"; + if (is_switch_required(curr_device, compiled_engine->device_info)) { // Scan through available CUDA devices and set the CUDA device context correctly CudaDevice device = select_cuda_device(compiled_engine->device_info); set_cuda_device(device); - std::string target_device = "cuda:" + std::to_string(device.id); + // Target device is new device + target_device += std::to_string(device.id); for (auto& in : inputs) { in = in.to(torch::Device(target_device)); } + } else { + // Target device is current device + target_device += std::to_string(curr_device.id); + + // For each input, ensure its current device is the desired target device + for (size_t i = 0; i < inputs.size(); i++) { + at::Tensor* in = &inputs[i]; + std::string current_tensor_device = in->device().str(); + + // If current device string does not match target device, display warning and move tensor accordingly + if (current_tensor_device != target_device) { + LOG_WARNING( + "Input " << i << " of engine " << compiled_engine->name << " was found to be on " << current_tensor_device + << " but should be on " << target_device + << ". This tensor is being moved manually by the runtime but " + << "for performance considerations, ensure your inputs are all on GPU " + << "and open an issue here (https://github.com/pytorch/TensorRT/issues) if this " + << "warning persists."); + *in = in->to(torch::Device(target_device)); + } + } } std::vector gpu_handles; diff --git a/tests/core/lowering/BUILD b/tests/core/lowering/BUILD index 4782f235c0..7f4e53d8a6 100644 --- a/tests/core/lowering/BUILD +++ b/tests/core/lowering/BUILD @@ -31,6 +31,10 @@ lowering_test( name = "test_conv1d_pass", ) +lowering_test( + name = "test_device_casting", +) + lowering_test( name = "test_exception_elimination_pass", ) @@ -95,6 +99,7 @@ test_suite( name = "lowering_tests", tests = [ ":test_conv1d_pass", + ":test_device_casting", ":test_exception_elimination_pass", ":test_linear_to_addmm", ":test_module_fallback_passes", diff --git a/tests/core/lowering/test_device_casting.cpp b/tests/core/lowering/test_device_casting.cpp new file mode 100644 index 0000000000..ab5c04be1a --- /dev/null +++ b/tests/core/lowering/test_device_casting.cpp @@ -0,0 +1,194 @@ +#include +#include "core/compiler.h" +#include "core/lowering/passes/passes.h" +#include "core/util/prelude.h" +#include "gtest/gtest.h" +#include "tests/util/util.h" +#include "torch/csrc/jit/ir/irparser.h" +#include "torch/csrc/jit/ir/subgraph_matcher.h" +#include "torch/csrc/jit/passes/common_subexpression_elimination.h" +#include "torch/torch.h" + +TEST(LoweringPasses, UnpackAndCastMaskedFillLowersCorrectly) { + const auto graph = R"IR( + graph(%x.1: Tensor, %x.2: Tensor, %x.3: float): + %2 : Tensor = aten::masked_fill_(%x.1, %x.2, %x.3) + return (%2))IR"; + + auto in = at::rand({2, 3, 5, 7}, {at::kCUDA}); + auto in2 = at::rand({2, 3, 5, 7}, {at::kCUDA}).to(torch::kBool); + auto in3 = 7.3; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in, in2, in3}); + torch_tensorrt::core::lowering::passes::UnpackAndCastMaskedFill(g); + torch::jit::EliminateCommonSubexpression(g); + auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in, in2, in3}); + + ASSERT_TRUE( + torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6)); +} + +TEST(LoweringPasses, UnpackAndCastNumToTensorLowersIntCorrectly) { + const auto graph = R"IR( + graph(%x.1: int): + %2 : Tensor = prim::NumToTensor(%x.1) + return (%2))IR"; + + auto in = 1; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in}); + torch_tensorrt::core::lowering::passes::UnpackAndCastNumToTensor(g); + torch::jit::EliminateCommonSubexpression(g); + auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in}); + + ASSERT_TRUE( + torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6)); +} + +TEST(LoweringPasses, UnpackAndCastNumToTensorLowersFloatCorrectly) { + const auto graph = R"IR( + graph(%x.1: float): + %2 : Tensor = prim::NumToTensor(%x.1) + return (%2))IR"; + + auto in = 78.1; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in}); + torch_tensorrt::core::lowering::passes::UnpackAndCastNumToTensor(g); + torch::jit::EliminateCommonSubexpression(g); + auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in}); + + ASSERT_TRUE( + torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6)); +} + +TEST(LoweringPasses, UnpackAndCastFullIntLowersCorrectly) { + const auto graph = R"IR( + graph(%x.1: int): + %5 : NoneType = prim::Constant() + %2 : int = prim::Constant[value=3]() + %10 : int[] = prim::ListConstruct(%2, %2) + %out : Tensor = aten::full(%10, %x.1, %5, %5, %5, %5) + return (%out))IR"; + + auto in = 4; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in}); + torch_tensorrt::core::lowering::passes::UnpackAndCastFull(g); + torch::jit::EliminateCommonSubexpression(g); + auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual( + jit_pre_results[0].toTensor(), jit_post_results[0].toTensor().cpu(), 2e-6)); +} + +TEST(LoweringPasses, UnpackAndCastFullFloatLowersCorrectly) { + const auto graph = R"IR( + graph(%x.1: float): + %5 : NoneType = prim::Constant() + %2 : int = prim::Constant[value=5]() + %3 : int = prim::Constant[value=4]() + %10 : int[] = prim::ListConstruct(%2, %3) + %out : Tensor = aten::full(%10, %x.1, %5, %5, %5, %5) + return (%out))IR"; + + auto in = 54.1; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in}); + torch_tensorrt::core::lowering::passes::UnpackAndCastFull(g); + torch::jit::EliminateCommonSubexpression(g); + auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual( + jit_pre_results[0].toTensor(), jit_post_results[0].toTensor().cpu(), 2e-6)); +} + +TEST(LoweringPasses, ReplaceScalarImplicitLowersCorrectly) { + const auto graph = R"IR( + graph(%x.1: Tensor): + %5 : int = prim::Constant[value=0]() + %false : bool = prim::Constant[value=0]() + %none : NoneType = prim::Constant() + %cuda : Device = prim::Constant[value="cuda"]() + %3 : int = aten::size(%x.1, %5) + %y.2 : Tensor = prim::NumToTensor(%3) + %y.1 : Tensor = aten::to(%y.2, %cuda, %none, %false, %false) + %19 : Tensor[] = prim::ListConstruct(%x.1, %y.1) + %21 : Tensor, %22 : Tensor = prim::ListUnpack(%19) + %2 : Scalar = aten::ScalarImplicit(%22) + %out : Tensor = prim::NumToTensor(%2) + return (%out))IR"; + + auto in = at::rand({2, 3, 5, 7}, {at::kCUDA}); + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in}); + torch_tensorrt::core::lowering::passes::ReplaceScalarImplicit(g); + torch::jit::EliminateCommonSubexpression(g); + auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in}); + + ASSERT_TRUE( + torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6)); +} + +TEST(LoweringPasses, ReplaceScalarImplicitIntNumToTensorLowersCorrectly) { + const auto graph = R"IR( + graph(%x.1: int): + %1 : Tensor = prim::NumToTensor(%x.1) + %2 : Scalar = aten::ScalarImplicit(%1) + %3 : Tensor = prim::NumToTensor(%2) + return (%3))IR"; + + auto in = 25; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in}); + torch_tensorrt::core::lowering::passes::UnpackAndCastNumToTensor(g); + torch_tensorrt::core::lowering::passes::ReplaceScalarImplicit(g); + torch::jit::EliminateCommonSubexpression(g); + auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in}); + + ASSERT_TRUE( + torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6)); +} + +TEST(LoweringPasses, ReplaceScalarImplicitFloatLowersCorrectly) { + const auto graph = R"IR( + graph(%x.1: float): + %1 : Tensor = prim::NumToTensor(%x.1) + %2 : Scalar = aten::ScalarImplicit(%1) + %3 : Tensor = prim::NumToTensor(%2) + return (%3))IR"; + + auto in = 2.5; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in}); + torch_tensorrt::core::lowering::passes::ReplaceScalarImplicit(g); + torch::jit::EliminateCommonSubexpression(g); + auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in}); + + ASSERT_TRUE( + torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6)); +} From 8583a4c00fa18f45706076fce3068df0e00f18a4 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Tue, 1 Nov 2022 10:46:04 -0700 Subject: [PATCH 2/2] fix: Update paradigm for device casting to depend on user-specified device - Adde field to LowerInfo to hold device information - Update internal Device struct location to allow streamlined imports - Update BUILD files - Build strings in lowering phase using user-specified target device - Update CMakeLists to reflect IR dependency in lowering - Update runtime device location code to run regardless of whether a switch is required or not. --- core/conversion/conversionctx/BUILD | 1 + core/conversion/conversionctx/ConversionCtx.h | 11 +---- core/ir/ir.h | 8 ++++ core/lowering/BUILD | 1 + core/lowering/CMakeLists.txt | 5 +- core/lowering/lowering.cpp | 6 +-- core/lowering/lowering.h | 6 +++ core/lowering/passes/device_casting.cpp | 46 +++++++++++++------ core/lowering/passes/passes.h | 6 +-- core/runtime/execute_engine.cpp | 31 ++++++------- cpp/src/compile_spec.cpp | 5 ++ tests/core/lowering/test_device_casting.cpp | 14 +++--- 12 files changed, 87 insertions(+), 53 deletions(-) diff --git a/core/conversion/conversionctx/BUILD b/core/conversion/conversionctx/BUILD index bf76b9b905..6626ae4457 100644 --- a/core/conversion/conversionctx/BUILD +++ b/core/conversion/conversionctx/BUILD @@ -21,6 +21,7 @@ cc_library( deps = [ "@tensorrt//:nvinfer", "//core/util:prelude", + "//core/ir", ] + select({ ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"], "//conditions:default": ["@libtorch//:libtorch"], diff --git a/core/conversion/conversionctx/ConversionCtx.h b/core/conversion/conversionctx/ConversionCtx.h index 5f8d6e955b..f7ed5c89d2 100644 --- a/core/conversion/conversionctx/ConversionCtx.h +++ b/core/conversion/conversionctx/ConversionCtx.h @@ -9,20 +9,13 @@ #include "torch/csrc/jit/ir/ir.h" #include +#include "core/ir/ir.h" #include "core/util/prelude.h" namespace torch_tensorrt { namespace core { namespace conversion { -struct Device { - nvinfer1::DeviceType device_type; - int64_t gpu_id; - int64_t dla_core; - bool allow_gpu_fallback; - Device() : device_type(nvinfer1::DeviceType::kGPU), gpu_id(0), dla_core(0), allow_gpu_fallback(false) {} -}; - struct BuilderSettings { std::set enabled_precisions = {}; bool sparse_weights = false; @@ -30,7 +23,7 @@ struct BuilderSettings { bool refit = false; bool debug = false; bool truncate_long_and_double = false; - Device device; + ir::Device device; nvinfer1::EngineCapability capability = TRT_ENGINE_CAPABILITY_STANDARD; nvinfer1::IInt8Calibrator* calibrator = nullptr; uint64_t num_avg_timing_iters = 1; diff --git a/core/ir/ir.h b/core/ir/ir.h index a5225daa25..6c78908d5b 100644 --- a/core/ir/ir.h +++ b/core/ir/ir.h @@ -11,6 +11,14 @@ namespace torch_tensorrt { namespace core { namespace ir { +struct Device { + nvinfer1::DeviceType device_type; + int64_t gpu_id; + int64_t dla_core; + bool allow_gpu_fallback; + Device() : device_type(nvinfer1::DeviceType::kGPU), gpu_id(0), dla_core(0), allow_gpu_fallback(false) {} +}; + struct Input : torch::CustomClassHolder { Input(){}; Input( diff --git a/core/lowering/BUILD b/core/lowering/BUILD index ae0f39032a..9c08752bb5 100644 --- a/core/lowering/BUILD +++ b/core/lowering/BUILD @@ -24,6 +24,7 @@ cc_library( deps = [ "//core/lowering/passes", "//core/util:prelude", + "//core/ir", ] + select({ ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"], "//conditions:default": ["@libtorch//:libtorch"], diff --git a/core/lowering/CMakeLists.txt b/core/lowering/CMakeLists.txt index 445f627473..af9323784c 100644 --- a/core/lowering/CMakeLists.txt +++ b/core/lowering/CMakeLists.txt @@ -15,6 +15,8 @@ set(HEADER_FILES target_sources(${lib_name} PRIVATE ${CXX_SRCS} + PUBLIC + $ $ ) @@ -25,8 +27,9 @@ target_include_directories(${lib_name} target_link_libraries(${lib_name} PUBLIC + TensorRT::nvinfer torch - PRIVATE + core_ir core_util ) diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp index 305be64185..4d665b390a 100644 --- a/core/lowering/lowering.cpp +++ b/core/lowering/lowering.cpp @@ -70,9 +70,9 @@ void LowerGraph(std::shared_ptr& g, std::vector +#include "core/ir/ir.h" #include "torch/csrc/jit/ir/ir.h" namespace torch_tensorrt { @@ -15,8 +16,13 @@ struct LowerInfo { // Since these QDQ nodes will be identical as they share same input, one of them is eliminated due to CSE lowering // pass. Disable this in order to not disturb TensorRT's QAT optimizations. bool disable_cse = false; + ir::Device target_device; std::vector forced_fallback_modules; friend std::ostream& operator<<(std::ostream& os, const LowerInfo& l); + + std::string getGPUDeviceString() { + return "cuda:" + std::to_string(target_device.gpu_id); + }; }; void LowerBlock(torch::jit::Block* b); diff --git a/core/lowering/passes/device_casting.cpp b/core/lowering/passes/device_casting.cpp index 2428270607..eee27336d4 100644 --- a/core/lowering/passes/device_casting.cpp +++ b/core/lowering/passes/device_casting.cpp @@ -8,47 +8,59 @@ namespace core { namespace lowering { namespace passes { -void UnpackAndCastMaskedFill(std::shared_ptr& graph) { +void UnpackAndCastMaskedFill(std::shared_ptr& graph, std::string target_device_name) { std::string masked_fill_pattern = R"IR( graph(%self, %mask, %value): %out: Tensor = aten::masked_fill_(%self, %mask, %value) return (%out))IR"; // Calls to masked_fill_ often utilize CPU tensors, and as such - // should be casted to CUDA to avoid device mismatch errors - std::string unpacked_pattern = R"IR( + // should be moved to gpu to avoid device mismatch errors + + // Separate string into portions to insert device name + std::string clean_pattern_part_1 = R"IR( graph(%self, %mask, %value): - %device: Device = prim::Constant[value="cuda"]() + %device: Device = prim::Constant[value=")IR"; + + std::string clean_pattern_part_2 = R"IR("]() %dtype: NoneType = prim::Constant() %false: bool = prim::Constant[value=0]() %mask_cuda: Tensor = aten::to(%mask, %device, %dtype, %false, %false) %self_cuda: Tensor = aten::to(%self, %device, %dtype, %false, %false) - %out: Tensor = aten::masked_fill_(%self_cuda, %mask_cuda, %value) + %out: Tensor = aten::masked_fill(%self_cuda, %mask_cuda, %value) return (%out))IR"; + auto unpacked_pattern = clean_pattern_part_1 + target_device_name + clean_pattern_part_2; + torch::jit::SubgraphRewriter masked_fill_rewriter; masked_fill_rewriter.RegisterRewritePattern(masked_fill_pattern, unpacked_pattern); masked_fill_rewriter.runOnGraph(graph); LOG_GRAPH("After unpack and cast masked_fill_: " << *graph); } -void UnpackAndCastNumToTensor(std::shared_ptr& graph) { +void UnpackAndCastNumToTensor(std::shared_ptr& graph, std::string target_device_name) { std::string num_to_tensor_cast_pattern = R"IR( graph(%1: Scalar): %2: Tensor = prim::NumToTensor(%1) return (%2))IR"; - // 0D Tensors are initialized on cpu, and need to be casted to CUDA + // 0D Tensors are initialized on cpu, and need to be moved to gpu // to avoid device mismatch issues - std::string num_to_tensor_clean_pattern = R"IR( + + // Separate string into portions to insert device name + std::string clean_pattern_part_1 = R"IR( graph(%1: Scalar): %2: Tensor = prim::NumToTensor(%1) - %device: Device = prim::Constant[value="cuda"]() + %device: Device = prim::Constant[value=")IR"; + + std::string clean_pattern_part_2 = R"IR("]() %dtype: NoneType = prim::Constant() %false: bool = prim::Constant[value=0]() %3: Tensor = aten::to(%2, %device, %dtype, %false, %false) return (%3))IR"; + auto num_to_tensor_clean_pattern = clean_pattern_part_1 + target_device_name + clean_pattern_part_2; + torch::jit::SubgraphRewriter num_to_tensor_cast_rewriter; num_to_tensor_cast_rewriter.RegisterRewritePattern(num_to_tensor_cast_pattern, num_to_tensor_clean_pattern); num_to_tensor_cast_rewriter.runOnGraph(graph); @@ -56,20 +68,26 @@ void UnpackAndCastNumToTensor(std::shared_ptr& graph) { LOG_GRAPH("After unpack and cast NumToTensor: " << *graph); } -void UnpackAndCastFull(std::shared_ptr& graph) { +void UnpackAndCastFull(std::shared_ptr& graph, std::string target_device_name) { std::string full_cast_pattern = R"IR( graph(%1, %2, %3, %4, %5, %6): %out: Tensor = aten::full(%1, %2, %3, %4, %5, %6) return (%out))IR"; - // Tensors created via aten::full are initialized on cpu, and need to be casted to CUDA + // Tensors created via aten::full are initialized on cpu, and need to be casted to gpu // to avoid device mismatch issues - std::string full_clean_pattern = R"IR( + + // Separate string into portions to insert device name + std::string clean_pattern_part_1 = R"IR( graph(%1, %2, %3, %4, %5, %6): - %cuda: Device = prim::Constant[value="cuda"]() - %out: Tensor = aten::full(%1, %2, %3, %4, %cuda, %6) + %device: Device = prim::Constant[value=")IR"; + + std::string clean_pattern_part_2 = R"IR("]() + %out: Tensor = aten::full(%1, %2, %3, %4, %device, %6) return (%out))IR"; + auto full_clean_pattern = clean_pattern_part_1 + target_device_name + clean_pattern_part_2; + torch::jit::SubgraphRewriter full_cast_rewriter; full_cast_rewriter.RegisterRewritePattern(full_cast_pattern, full_clean_pattern); full_cast_rewriter.runOnGraph(graph); diff --git a/core/lowering/passes/passes.h b/core/lowering/passes/passes.h index 1e02656294..713894e0c7 100644 --- a/core/lowering/passes/passes.h +++ b/core/lowering/passes/passes.h @@ -41,9 +41,9 @@ void SiluToSigmoidMultipication(std::shared_ptr& graph); void UnpackHardSwish(std::shared_ptr& graph); void RewriteInputsWithParams(std::shared_ptr& g, std::vector& params); void UnpackHardSigmoid(std::shared_ptr& graph); -void UnpackAndCastMaskedFill(std::shared_ptr& graph); -void UnpackAndCastNumToTensor(std::shared_ptr& graph); -void UnpackAndCastFull(std::shared_ptr& graph); +void UnpackAndCastMaskedFill(std::shared_ptr& graph, std::string target_device_name); +void UnpackAndCastNumToTensor(std::shared_ptr& graph, std::string target_device_name); +void UnpackAndCastFull(std::shared_ptr& graph, std::string target_device_name); void ReplaceScalarImplicit(std::shared_ptr& graph); } // namespace passes diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index 65b5ca6880..7eeafdffdd 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -80,23 +80,22 @@ std::vector execute_engine(std::vector inputs, c10::intr } else { // Target device is current device target_device += std::to_string(curr_device.id); + } + + // For each input, ensure its current device is the desired target device + for (size_t i = 0; i < inputs.size(); i++) { + at::Tensor* in = &inputs[i]; + std::string current_tensor_device = in->device().str(); - // For each input, ensure its current device is the desired target device - for (size_t i = 0; i < inputs.size(); i++) { - at::Tensor* in = &inputs[i]; - std::string current_tensor_device = in->device().str(); - - // If current device string does not match target device, display warning and move tensor accordingly - if (current_tensor_device != target_device) { - LOG_WARNING( - "Input " << i << " of engine " << compiled_engine->name << " was found to be on " << current_tensor_device - << " but should be on " << target_device - << ". This tensor is being moved manually by the runtime but " - << "for performance considerations, ensure your inputs are all on GPU " - << "and open an issue here (https://github.com/pytorch/TensorRT/issues) if this " - << "warning persists."); - *in = in->to(torch::Device(target_device)); - } + // If current device string does not match target device, display warning and move tensor accordingly + if (current_tensor_device != target_device) { + LOG_WARNING( + "Input " << i << " of engine " << compiled_engine->name << " was found to be on " << current_tensor_device + << " but should be on " << target_device << ". This tensor is being moved by the runtime but " + << "for performance considerations, ensure your inputs are all on GPU " + << "and open an issue here (https://github.com/pytorch/TensorRT/issues) if this " + << "warning persists."); + *in = in->to(torch::Device(target_device)); } } diff --git a/cpp/src/compile_spec.cpp b/cpp/src/compile_spec.cpp index 3d7d9b15d3..daf05fd495 100644 --- a/cpp/src/compile_spec.cpp +++ b/cpp/src/compile_spec.cpp @@ -110,6 +110,7 @@ torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) { internal.convert_info.engine_settings.debug = external.debug; internal.convert_info.engine_settings.truncate_long_and_double = external.truncate_long_and_double; internal.convert_info.engine_settings.device.allow_gpu_fallback = external.device.allow_gpu_fallback; + internal.lower_info.target_device.allow_gpu_fallback = external.device.allow_gpu_fallback; TORCHTRT_CHECK( !(external.require_full_compilation && (external.torch_executed_ops.size() > 0)), @@ -130,10 +131,12 @@ torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) { switch (external.device.device_type) { case Device::DeviceType::kDLA: internal.convert_info.engine_settings.device.device_type = nvinfer1::DeviceType::kDLA; + internal.lower_info.target_device.device_type = nvinfer1::DeviceType::kDLA; break; case Device::DeviceType::kGPU: default: internal.convert_info.engine_settings.device.device_type = nvinfer1::DeviceType::kGPU; + internal.lower_info.target_device.device_type = nvinfer1::DeviceType::kGPU; } switch (external.capability) { @@ -150,6 +153,8 @@ torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) { internal.convert_info.engine_settings.device.gpu_id = external.device.gpu_id; internal.convert_info.engine_settings.device.dla_core = external.device.dla_core; + internal.lower_info.target_device.gpu_id = external.device.gpu_id; + internal.lower_info.target_device.dla_core = external.device.dla_core; internal.convert_info.engine_settings.num_avg_timing_iters = external.num_avg_timing_iters; internal.convert_info.engine_settings.workspace_size = external.workspace_size; internal.convert_info.engine_settings.dla_sram_size = external.dla_sram_size; diff --git a/tests/core/lowering/test_device_casting.cpp b/tests/core/lowering/test_device_casting.cpp index ab5c04be1a..cc76f94ad6 100644 --- a/tests/core/lowering/test_device_casting.cpp +++ b/tests/core/lowering/test_device_casting.cpp @@ -23,7 +23,7 @@ TEST(LoweringPasses, UnpackAndCastMaskedFillLowersCorrectly) { torch::jit::parseIR(graph, g.get()); auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in, in2, in3}); - torch_tensorrt::core::lowering::passes::UnpackAndCastMaskedFill(g); + torch_tensorrt::core::lowering::passes::UnpackAndCastMaskedFill(g, "cuda:0"); torch::jit::EliminateCommonSubexpression(g); auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in, in2, in3}); @@ -43,7 +43,7 @@ TEST(LoweringPasses, UnpackAndCastNumToTensorLowersIntCorrectly) { torch::jit::parseIR(graph, g.get()); auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in}); - torch_tensorrt::core::lowering::passes::UnpackAndCastNumToTensor(g); + torch_tensorrt::core::lowering::passes::UnpackAndCastNumToTensor(g, "cuda:0"); torch::jit::EliminateCommonSubexpression(g); auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in}); @@ -63,7 +63,7 @@ TEST(LoweringPasses, UnpackAndCastNumToTensorLowersFloatCorrectly) { torch::jit::parseIR(graph, g.get()); auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in}); - torch_tensorrt::core::lowering::passes::UnpackAndCastNumToTensor(g); + torch_tensorrt::core::lowering::passes::UnpackAndCastNumToTensor(g, "cuda:0"); torch::jit::EliminateCommonSubexpression(g); auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in}); @@ -86,7 +86,7 @@ TEST(LoweringPasses, UnpackAndCastFullIntLowersCorrectly) { torch::jit::parseIR(graph, g.get()); auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in}); - torch_tensorrt::core::lowering::passes::UnpackAndCastFull(g); + torch_tensorrt::core::lowering::passes::UnpackAndCastFull(g, "cuda:0"); torch::jit::EliminateCommonSubexpression(g); auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in}); @@ -110,7 +110,7 @@ TEST(LoweringPasses, UnpackAndCastFullFloatLowersCorrectly) { torch::jit::parseIR(graph, g.get()); auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in}); - torch_tensorrt::core::lowering::passes::UnpackAndCastFull(g); + torch_tensorrt::core::lowering::passes::UnpackAndCastFull(g, "cuda:0"); torch::jit::EliminateCommonSubexpression(g); auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in}); @@ -124,7 +124,7 @@ TEST(LoweringPasses, ReplaceScalarImplicitLowersCorrectly) { %5 : int = prim::Constant[value=0]() %false : bool = prim::Constant[value=0]() %none : NoneType = prim::Constant() - %cuda : Device = prim::Constant[value="cuda"]() + %cuda : Device = prim::Constant[value="cuda:0"]() %3 : int = aten::size(%x.1, %5) %y.2 : Tensor = prim::NumToTensor(%3) %y.1 : Tensor = aten::to(%y.2, %cuda, %none, %false, %false) @@ -162,7 +162,7 @@ TEST(LoweringPasses, ReplaceScalarImplicitIntNumToTensorLowersCorrectly) { torch::jit::parseIR(graph, g.get()); auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in}); - torch_tensorrt::core::lowering::passes::UnpackAndCastNumToTensor(g); + torch_tensorrt::core::lowering::passes::UnpackAndCastNumToTensor(g, "cuda:0"); torch_tensorrt::core::lowering::passes::ReplaceScalarImplicit(g); torch::jit::EliminateCommonSubexpression(g); auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});