diff --git a/core/compiler.cpp b/core/compiler.cpp
index 025c8023dc..db20003640 100644
--- a/core/compiler.cpp
+++ b/core/compiler.cpp
@@ -46,7 +46,7 @@ c10::FunctionSchema GenerateGraphSchema(
 void AddEngineToGraph(
     torch::jit::script::Module mod,
     std::shared_ptr<torch::jit::Graph>& g,
-    std::string& serialized_engine) {
+    const std::string& serialized_engine) {
   auto engine_ptr = c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name(), serialized_engine);
   // Get required metadata about the engine out
   auto num_io = engine_ptr->num_io;
@@ -173,6 +173,20 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
   return new_mod;
 }
 
+torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine) {
+  std::ostringstream engine_id;
+  engine_id << reinterpret_cast<const int*>(&engine);
+  torch::jit::script::Module new_mod("tensorrt_engine_mod_" + engine_id.str());
+  auto new_g = std::make_shared<torch::jit::Graph>();
+  AddEngineToGraph(new_mod, new_g, engine);
+  auto new_method = new_mod._ivalue()->compilation_unit()->create_function("forward", new_g);
+  auto schema = GenerateGraphSchema(new_mod, new_method->name(), new_g);
+  new_mod.type()->addMethod(new_method);
+  new_method->setSchema(schema);
+
+  return new_mod;
+}
+
 void set_device(const int gpu_id) {
   TRTORCH_ASSERT(cudaSetDevice(gpu_id) == cudaSuccess, "Unable to set CUDA device: " << gpu_id);
 }
diff --git a/core/compiler.h b/core/compiler.h
index 512b30123f..a7d16c6b8d 100644
--- a/core/compiler.h
+++ b/core/compiler.h
@@ -19,6 +19,8 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
 
 torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, CompileSpec cfg);
 
+torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine);
+
 void set_device(const int gpu_id);
 
 } // namespace core
diff --git a/cpp/api/include/trtorch/trtorch.h b/cpp/api/include/trtorch/trtorch.h
index 51e170b769..b033b45629 100644
--- a/cpp/api/include/trtorch/trtorch.h
+++ b/cpp/api/include/trtorch/trtorch.h
@@ -480,6 +480,21 @@ TRTORCH_API std::string ConvertGraphToTRTEngine(
     const torch::jit::Module& module,
     std::string method_name,
     CompileSpec info);
+
+/**
+ * @brief Take a previously created TensorRT engine and embed it in
+ * a TorchScript module
+ *
+ * @param engine: std::string - Pre-built serialized TensorRT engine
+ *
+ * Takes a pre-built serialized TensorRT engine and embeds it in a TorchScript
+ * module. Registers execution of the engine as the forward method of the module.
+ * Forward is defined as: forward(Tensor[]) -> Tensor[]
+ *
+ * @return: A new module targeting a TensorRT engine
+ */
+TRTORCH_API torch::jit::Module EmbedEngineInNewModule(const std::string& engine);
+
 /**
  * @brief Set gpu device id
  *
diff --git a/cpp/api/src/trtorch.cpp b/cpp/api/src/trtorch.cpp
index d13d34f105..1a5083fc90 100644
--- a/cpp/api/src/trtorch.cpp
+++ b/cpp/api/src/trtorch.cpp
@@ -31,6 +31,10 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module
   return core::CompileGraph(module, to_internal_compile_spec(info));
 }
 
+torch::jit::Module EmbedEngineInNewModule(const std::string& engine) {
+  return core::EmbedEngineInNewModule(engine);
+}
+
 std::string get_build_info() {
   auto info = core::util::get_build_info();
   return std::string("TRTorch Version: ") + TRTORCH_VERSION + '\n' + info;
diff --git a/py/trtorch/_compiler.py b/py/trtorch/_compiler.py
index 65c91732e6..183644a065 100644
--- a/py/trtorch/_compiler.py
+++ b/py/trtorch/_compiler.py
@@ -124,6 +124,26 @@ def convert_method_to_trt_engine(module: torch.jit.ScriptModule, method_name: st
     return trtorch._C.convert_graph_to_trt_engine(module._c, method_name, _parse_compile_spec(compile_spec))
 
 
+def embed_engine_in_new_module(serialized_engine: bytes) -> torch.jit.ScriptModule:
+    """Takes a pre-built serialized TensorRT engine and embeds it within a TorchScript module
+
+    Takes a pre-built serialized TensorRT engine (as bytes) and embeds it within a TorchScript module.
+    Registers the forward method to execute the TensorRT engine with the function signature:
+
+    forward(Tensor[]) -> Tensor[]
+
+    Module can be saved with engine embedded with torch.jit.save and moved / loaded according to TRTorch portability rules
+
+    Args:
+        serialized_engine (bytes): Serialized TensorRT engine from either TRTorch or TensorRT APIs
+
+    Returns:
+        torch.jit.ScriptModule: New TorchScript module with engine embedded
+    """
+    cpp_mod = trtorch._C.embed_engine_in_new_module(serialized_engine)
+    return torch.jit._recursive.wrap_cpp_module(cpp_mod)
+
+
 def check_method_op_support(module: torch.jit.ScriptModule, method_name: str) -> bool:
     """Checks to see if a method is fully supported by TRTorch
 
diff --git a/py/trtorch/csrc/trtorch_py.cpp b/py/trtorch/csrc/trtorch_py.cpp
index cb3d1d4e39..74c38f5d73 100644
--- a/py/trtorch/csrc/trtorch_py.cpp
+++ b/py/trtorch/csrc/trtorch_py.cpp
@@ -119,6 +119,10 @@ bool CheckMethodOperatorSupport(const torch::jit::Module& module, const std::str
   return core::CheckMethodOperatorSupport(module, method_name);
 }
 
+torch::jit::Module EmbedEngineInNewModule(const py::bytes& engine) {
+  return core::EmbedEngineInNewModule(engine);
+}
+
 std::string get_build_info() {
   auto info = core::util::get_build_info();
   return info;
@@ -270,6 +274,10 @@ PYBIND11_MODULE(_C, m) {
   m.def(
       "check_method_op_support",
       &trtorch::pyapi::CheckMethodOperatorSupport,
      "Takes a module and a method name and checks if the method graph contains purely convertable operators");
+  m.def(
+      "embed_engine_in_new_module",
+      &trtorch::pyapi::EmbedEngineInNewModule,
+      "Takes a serialized TensorRT engine and wraps it in the forward method of a new TorchScript module");
   m.def("get_build_info", &get_build_info, "Returns build info about the compiler as a string");
   m.def("_get_logging_prefix", &logging::get_logging_prefix, "Get the current prefix for the logging output");
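
For context, a minimal sketch of how the new Python API composes end to end. It mirrors the TestPTtoTRTtoPT case added to tests/py/test_api.py below; the model choice and save path are illustrative only, not part of this patch:

# Sketch only: round trip a TorchScript module through a serialized
# TensorRT engine and back into a standalone TorchScript module.
import torch
import torchvision.models as models
import trtorch

model = models.mobilenet_v2(pretrained=True).eval().to("cuda")  # illustrative model
ts_model = torch.jit.script(model)

compile_spec = {
    "input_shapes": [(1, 3, 224, 224)],
    "device": {
        "device_type": trtorch.DeviceType.GPU,
        "gpu_id": 0,
        "dla_core": 0,
        "allow_gpu_fallback": False,
        "disable_tf32": False
    }
}

# Lower forward() to a serialized TensorRT engine (bytes) ...
trt_engine = trtorch.convert_method_to_trt_engine(ts_model, "forward", compile_spec)
# ... then wrap that engine in a fresh TorchScript module whose
# forward method executes the engine
trt_mod = trtorch.embed_engine_in_new_module(trt_engine)

# The embedded module serializes like any other TorchScript module
torch.jit.save(trt_mod, "trt_mod.ts")  # hypothetical path
reloaded = torch.jit.load("trt_mod.ts")
out = reloaded(torch.randn((1, 3, 224, 224), device="cuda"))
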
diff --git a/tests/modules/test_modules_as_engines.cpp b/tests/modules/test_modules_as_engines.cpp
index 6e6408575c..5fb1cf5862 100644
--- a/tests/modules/test_modules_as_engines.cpp
+++ b/tests/modules/test_modules_as_engines.cpp
@@ -16,6 +16,34 @@ TEST_P(ModuleTests, ModuleAsEngineIsClose) {
   ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-5));
 }
 
+TEST_P(ModuleTests, ModuleToEngineToModuleIsClose) {
+  std::vector<at::Tensor> inputs;
+  std::vector<torch::jit::IValue> inputs_ivalues;
+  for (auto in_shape : input_shapes) {
+    inputs.push_back(at::randint(5, in_shape, {at::kCUDA}));
+    inputs_ivalues.push_back(inputs[inputs.size() - 1].clone());
+  }
+
+  torch::jit::IValue jit_results_ivalues = trtorch::tests::util::RunModuleForward(mod, inputs_ivalues);
+  std::vector<at::Tensor> jit_results;
+  jit_results.push_back(jit_results_ivalues.toTensor());
+
+  auto forward_graph = mod.get_method("forward");
+  std::vector<std::vector<int64_t>> input_ranges;
+  for (auto in : inputs) {
+    input_ranges.push_back(in.sizes());
+  }
+
+  auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", input_ranges);
+  auto trt_mod = trtorch::EmbedEngineInNewModule(engine);
+
+  torch::jit::IValue trt_results_ivalues = trtorch::tests::util::RunModuleForward(trt_mod, inputs_ivalues);
+  std::vector<at::Tensor> trt_results;
+  trt_results.push_back(trt_results_ivalues.toTensor());
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-5));
+}
+
 INSTANTIATE_TEST_SUITE_P(
     ModuleAsEngineForwardIsCloseSuite,
     ModuleTests,
diff --git a/tests/py/BUILD b/tests/py/BUILD
index 510b3f681e..65a424466e 100644
--- a/tests/py/BUILD
+++ b/tests/py/BUILD
@@ -30,7 +30,7 @@ py_test(
     srcs = [
         "test_ptq_dataloader_calibrator.py",
         "model_test_case.py"
-    ]
+    ],
     deps = [
         requirement("torchvision")
     ]
@@ -43,7 +43,7 @@ py_test(
     srcs = [
         "test_ptq_trt_calibrator.py",
         "model_test_case.py"
-    ]
+    ],
     deps = [
         requirement("torchvision")
     ]
@@ -56,8 +56,6 @@ py_test(
         "test_multi_gpu.py",
         "model_test_case.py"
     ],
-    "//conditions:default" : []
-    }),
     deps = [
         requirement("torchvision")
     ]
@@ -74,12 +72,23 @@ py_test(
     ]
 )
 
+py_test(
+    name = "test_trt_intercompatability",
+    srcs = [
+        "test_trt_intercompatability.py",
+        "model_test_case.py"
+    ],
+    deps = [
+        requirement("torchvision")
+    ]
+)
+
 py_test(
     name = "test_ptq_to_backend",
     srcs = [
         "test_ptq_to_backend.py",
         "model_test_case.py"
-    ]
+    ],
     deps = [
         requirement("torchvision")
     ]
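
With the BUILD entry above, the new intercompatability check runs as a standard Bazel py_test target, e.g. `bazel test //tests/py:test_trt_intercompatability`.
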
diff --git a/tests/py/test_api.py b/tests/py/test_api.py
index a21385f6e1..31892e7a81 100644
--- a/tests/py/test_api.py
+++ b/tests/py/test_api.py
@@ -46,6 +46,30 @@ def test_compile_script(self):
         self.assertTrue(same < 2e-3)
 
 
+class TestPTtoTRTtoPT(ModelTestCase):
+
+    def setUp(self):
+        self.input = torch.randn((1, 3, 224, 224)).to("cuda")
+        self.ts_model = torch.jit.script(self.model)
+
+    def test_pt_to_trt_to_pt(self):
+        compile_spec = {
+            "input_shapes": [self.input.shape],
+            "device": {
+                "device_type": trtorch.DeviceType.GPU,
+                "gpu_id": 0,
+                "dla_core": 0,
+                "allow_gpu_fallback": False,
+                "disable_tf32": False
+            }
+        }
+
+        trt_engine = trtorch.convert_method_to_trt_engine(self.ts_model, "forward", compile_spec)
+        trt_mod = trtorch.embed_engine_in_new_module(trt_engine)
+        same = (trt_mod(self.input) - self.ts_model(self.input)).abs().max()
+        self.assertTrue(same < 2e-3)
+
+
 class TestCheckMethodOpSupport(unittest.TestCase):
 
     def setUp(self):
@@ -59,13 +83,13 @@ def test_check_support(self):
 class TestLoggingAPIs(unittest.TestCase):
 
     def test_logging_prefix(self):
-        new_prefix = "TEST"
+        new_prefix = "Python API Test: "
         trtorch.logging.set_logging_prefix(new_prefix)
         logging_prefix = trtorch.logging.get_logging_prefix()
         self.assertEqual(new_prefix, logging_prefix)
 
     def test_reportable_log_level(self):
-        new_level = trtorch.logging.Level.Warning
+        new_level = trtorch.logging.Level.Error
         trtorch.logging.set_reportable_log_level(new_level)
         level = trtorch.logging.get_reportable_log_level()
         self.assertEqual(new_level, level)
@@ -78,10 +102,11 @@ def test_is_colored_output_on(self):
 
 def test_suite():
     suite = unittest.TestSuite()
+    suite.addTest(unittest.makeSuite(TestLoggingAPIs))
     suite.addTest(TestCompile.parametrize(TestCompile, model=models.resnet18(pretrained=True)))
     suite.addTest(TestCompile.parametrize(TestCompile, model=models.mobilenet_v2(pretrained=True)))
+    suite.addTest(TestPTtoTRTtoPT.parametrize(TestPTtoTRTtoPT, model=models.mobilenet_v2(pretrained=True)))
     suite.addTest(unittest.makeSuite(TestCheckMethodOpSupport))
-    suite.addTest(unittest.makeSuite(TestLoggingAPIs))
 
     return suite
diff --git a/tests/py/test_trt_intercompatability.py b/tests/py/test_trt_intercompatability.py
new file mode 100644
index 0000000000..ffc4cb7217
--- /dev/null
+++ b/tests/py/test_trt_intercompatability.py
@@ -0,0 +1,55 @@
+import unittest
+import trtorch
+import torch
+import torchvision.models as models
+import tensorrt as trt
+
+from model_test_case import ModelTestCase
+
+
+class TestPyTorchToTRTEngine(ModelTestCase):
+
+    def setUp(self):
+        self.input = torch.randn((1, 3, 224, 224)).to("cuda:0")
+        self.ts_model = torch.jit.script(self.model)
+
+    def test_pt_to_trt(self):
+        compile_spec = {
+            "input_shapes": [self.input.shape],
+            "device": {
+                "device_type": trtorch.DeviceType.GPU,
+                "gpu_id": 0,
+                "dla_core": 0,
+                "allow_gpu_fallback": False,
+                "disable_tf32": False
+            }
+        }
+
+        trt_engine = trtorch.convert_method_to_trt_engine(self.ts_model, "forward", compile_spec)
+
+        TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
+        with trt.Runtime(TRT_LOGGER) as rt:
+            engine = rt.deserialize_cuda_engine(trt_engine)
+            with engine.create_execution_context() as ctx:
+                out = torch.empty(size=tuple(engine.get_binding_shape(1))).to("cuda:0")
+                bindings = [self.input.contiguous().data_ptr(), out.contiguous().data_ptr()]
+                ctx.execute_async(batch_size=1,
+                                  bindings=bindings,
+                                  stream_handle=torch.cuda.current_stream(device='cuda:0').cuda_stream)
+                same = (out - self.ts_model(self.input)).abs().max()
+                self.assertTrue(same < 2e-3)
+
+
+def test_suite():
+    suite = unittest.TestSuite()
+    suite.addTest(TestPyTorchToTRTEngine.parametrize(TestPyTorchToTRTEngine, model=models.resnet18(pretrained=True)))
+
+    return suite
+
+
+suite = test_suite()
+
+runner = unittest.TextTestRunner()
+result = runner.run(suite)
+
+exit(int(not result.wasSuccessful()))
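
The embed API's docstring also covers the other direction: engines serialized directly by the TensorRT APIs, not just by convert_method_to_trt_engine. A hedged sketch of that path (the plan file path is hypothetical, and per TRTorch portability rules the engine must have been built for the device it will run on):

# Sketch only: wrap an engine built outside of TRTorch, e.g. a plan
# file saved by trtexec or the TensorRT builder APIs, as a TorchScript
# module with forward(Tensor[]) -> Tensor[]
import torch
import trtorch

with open("model.engine", "rb") as f:  # hypothetical plan file
    serialized_engine = f.read()

trt_mod = trtorch.embed_engine_in_new_module(serialized_engine)
torch.jit.save(trt_mod, "trt_mod.ts")
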