diff --git a/cpp/bin/torchtrtc/BUILD b/cpp/bin/torchtrtc/BUILD index 9265948b97..9d58e3211b 100644 --- a/cpp/bin/torchtrtc/BUILD +++ b/cpp/bin/torchtrtc/BUILD @@ -19,6 +19,9 @@ cc_binary( "parser_util.h", "parser_util.cpp" ], + linkopts = [ + "-l:libdl.so" + ], deps = [ "//third_party/args", "//cpp:torch_tensorrt", diff --git a/cpp/bin/torchtrtc/CMakeLists.txt b/cpp/bin/torchtrtc/CMakeLists.txt index 0ebfd87609..b12461e12a 100644 --- a/cpp/bin/torchtrtc/CMakeLists.txt +++ b/cpp/bin/torchtrtc/CMakeLists.txt @@ -10,7 +10,7 @@ add_executable(${executable_name} if (MSVC) target_link_libraries(${executable_name} PRIVATE torch torchtrt) else() - target_link_libraries(${executable_name} PRIVATE torch "-Wl,--no-as-needed" torchtrt "-Wl,--as-needed") + target_link_libraries(${executable_name} PRIVATE torch "-Wl,--no-as-needed -ldl" torchtrt "-Wl,--as-needed") set_target_properties( ${executable_name} PROPERTIES INSTALL_RPATH_USE_LINK_PATH FALSE # diff --git a/cpp/bin/torchtrtc/README.md b/cpp/bin/torchtrtc/README.md index 498f25ea17..6466ada390 100644 --- a/cpp/bin/torchtrtc/README.md +++ b/cpp/bin/torchtrtc/README.md @@ -108,6 +108,8 @@ torchtrtc [input_file_path] [output_file_path] TorchScript program, save the created engine to the path specified as the output path + --custom-torch-ops=[lib] (repeatable) Shared object/DLL containing custom torch operators + --custom-converters=[lib] (repeatable) Shared object/DLL containing custom converters input_file_path Path to input TorchScript file output_file_path Path for compiled TorchScript (or TensorRT engine) file @@ -131,3 +133,14 @@ e.g. 
``` torchtrtc tests/modules/ssd_traced.jit.pt ssd_trt.ts "[(1,3,300,300); (1,3,512,512); (1, 3, 1024, 1024)]@fp16%contiguous" -p f16 ``` + + +To run with custom torch operators +``` +torchtrtc tests/modules/ssd_traced.jit.pt ssd_trt.ts --custom-torch-ops=<path-to-custom-library> "[(1,3,300,300); (1,3,512,512); (1, 3, 1024, 1024)]@fp16%contiguous" -p f16 +``` + +To run with custom converters +``` +torchtrtc tests/modules/ssd_traced.jit.pt ssd_trt.ts --custom-converters=<path-to-custom-library> "[(1,3,300,300); (1,3,512,512); (1, 3, 1024, 1024)]@fp16%contiguous" -p f16 +``` \ No newline at end of file diff --git a/cpp/bin/torchtrtc/main.cpp b/cpp/bin/torchtrtc/main.cpp index 6c207d78da..f98ed848de 100644 --- a/cpp/bin/torchtrtc/main.cpp +++ b/cpp/bin/torchtrtc/main.cpp @@ -15,6 +15,33 @@ #include "luts.h" #include "parser_util.h" +#if defined(_WIN32) +#include +#else +#include +#endif + +void* load_library(std::string& custom_lib) { + void* handle = {nullptr}; +#if defined(_WIN32) + handle = LoadLibrary(custom_lib.c_str()); +#else + handle = dlopen(custom_lib.c_str(), RTLD_LAZY); +#endif + return handle; +} + +bool unload_library(void* custom_lib) { + bool success = false; +#if defined(_WIN32) + // Returns status non-zero for success + success = FreeLibrary(custom_lib) ? true : false; +#else + success = dlclose(custom_lib) ? 
false : true; +#endif + return success; +} + int main(int argc, char** argv) { torchtrt::logging::set_is_colored_output_on(true); torchtrt::logging::set_reportable_log_level(torchtrt::logging::Level::kWARNING); @@ -117,8 +144,7 @@ int main(int argc, char** argv) { parser, "num_iters", "Number of averaging timing iterations used to select kernels", {"num-avg-timing-iters"}); args::ValueFlag workspace_size( parser, "workspace_size", "Maximum size of workspace given to TensorRT", {"workspace-size"}); - args::ValueFlag dla_sram_size( - parser, "dla_sram_size", "DLA managed SRAM size", {"dla-sram-size"}); + args::ValueFlag dla_sram_size(parser, "dla_sram_size", "DLA managed SRAM size", {"dla-sram-size"}); args::ValueFlag dla_local_dram_size( parser, "dla_local_dram_size", "DLA Local DRAM size", {"dla-local-dram-size"}); args::ValueFlag dla_global_dram_size( @@ -147,6 +173,18 @@ int main(int argc, char** argv) { "save_engine", "Instead of compiling a full a TorchScript program, save the created engine to the path specified as the output path", {"save-engine"}); + args::ValueFlagList custom_torch_ops( + parser, + "custom-torch-ops", + "(repeatable) Shared object/DLL containing custom torch operators", + {"custom-torch-ops"}); + + args::ValueFlagList custom_converters( + parser, + "custom-converters", + "(repeatable) Shared object/DLL containing custom converters", + {"custom-converters"}); + args::Positional input_path(parser, "input_file_path", "Path to input TorchScript file"); args::Positional output_path( parser, "output_file_path", "Path for compiled TorchScript (or TensorRT engine) file"); @@ -174,6 +212,34 @@ int main(int argc, char** argv) { torchtrt::logging::set_reportable_log_level(torchtrt::logging::Level::kERROR); } + std::vector> custom_torch_op, custom_converter_op; + if (custom_torch_ops) { + for (auto& op : args::get(custom_torch_ops)) { + auto* handle = load_library(op); + if (handle == nullptr) { + torchtrt::logging::log( + 
torchtrt::logging::Level::kERROR, std::string("Could not load custom_torch_ops library " + op)); + } else { + torchtrt::logging::log(torchtrt::logging::Level::kINFO, std::string("Loaded custom_torch_ops library " + op)); + + custom_torch_op.push_back({op, handle}); + } + } + } + + if (custom_converters) { + for (auto& op : args::get(custom_converters)) { + auto* handle = load_library(op); + if (handle == nullptr) { + torchtrt::logging::log( + torchtrt::logging::Level::kERROR, std::string("Could not load custom_converter library " + op)); + } else { + torchtrt::logging::log(torchtrt::logging::Level::kINFO, std::string("Loaded custom_converter library " + op)); + custom_converter_op.push_back({op, handle}); + } + } + } + auto real_input_path = torchtrtc::fileio::resolve_path(args::get(input_path)); if (check_method_op_support) { @@ -189,7 +255,7 @@ int main(int argc, char** argv) { auto method = args::get(check_method_op_support); auto result = torchtrt::ts::check_method_operator_support(mod, method); if (result) { - std::cout << "The method is supported end to end by Torch-TensorRT" << std::endl; + torchtrt::logging::log(torchtrt::logging::Level::kINFO, "The method is supported end to end by Torch-TensorRT"); return 0; } else { torchtrt::logging::log(torchtrt::logging::Level::kERROR, "Method is not currently supported by Torch-TensorRT"); @@ -477,5 +543,29 @@ int main(int argc, char** argv) { trt_mod.save(real_output_path); } + if (custom_torch_ops) { + for (auto& p : custom_torch_op) { + auto status = unload_library(p.second); + if (status) { + torchtrt::logging::log(torchtrt::logging::Level::kINFO, std::string("Unloaded custom library " + p.first)); + } else { + torchtrt::logging::log( + torchtrt::logging::Level::kERROR, std::string("Could not unload custom library " + p.first)); + } + } + } + + if (custom_converters) { + for (auto& p : custom_converter_op) { + auto status = unload_library(p.second); + if (status) { + 
torchtrt::logging::log(torchtrt::logging::Level::kINFO, std::string("Unloaded custom library " + p.first)); + } else { + torchtrt::logging::log( + torchtrt::logging::Level::kERROR, std::string("Could not unload custom library " + p.first)); + } + } + } + return 0; } diff --git a/docsrc/tutorials/torchtrtc.rst b/docsrc/tutorials/torchtrtc.rst index 5a2808bb8d..68f599a5cd 100644 --- a/docsrc/tutorials/torchtrtc.rst +++ b/docsrc/tutorials/torchtrtc.rst @@ -111,6 +111,8 @@ to standard TorchScript. Load with ``torch.jit.load()`` and run like you would r TorchScript program, save the created engine to the path specified as the output path + --custom-torch-ops=[lib] (repeatable) Shared object/DLL containing custom torch operators + --custom-converters=[lib] (repeatable) Shared object/DLL containing custom converters input_file_path Path to input TorchScript file output_file_path Path for compiled TorchScript (or TensorRT engine) file @@ -132,3 +134,17 @@ e.g. .. code-block:: shell torchtrtc tests/modules/ssd_traced.jit.pt ssd_trt.ts "[(1,3,300,300); (1,3,512,512); (1, 3, 1024, 1024)]@f16%contiguous" -p f16 + + +To run with custom torch operators + +.. code-block:: shell + + torchtrtc tests/modules/ssd_traced.jit.pt ssd_trt.ts --custom-torch-ops=<path-to-custom-library> "[(1,3,300,300); (1,3,512,512); (1, 3, 1024, 1024)]@fp16%contiguous" -p f16 + + +To run with custom converters + +.. code-block:: shell + + torchtrtc tests/modules/ssd_traced.jit.pt ssd_trt.ts --custom-converters=<path-to-custom-library> "[(1,3,300,300); (1,3,512,512); (1, 3, 1024, 1024)]@fp16%contiguous" -p f16