16 changes: 15 additions & 1 deletion core/compiler.cpp
@@ -46,7 +46,7 @@ c10::FunctionSchema GenerateGraphSchema(
 void AddEngineToGraph(
     torch::jit::script::Module mod,
     std::shared_ptr<torch::jit::Graph>& g,
-    std::string& serialized_engine) {
+    const std::string& serialized_engine) {
   auto engine_ptr = c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name(), serialized_engine);
   // Get required metadata about the engine out
   auto num_io = engine_ptr->num_io;
@@ -173,6 +173,20 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
   return new_mod;
 }
 
+torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine) {
+  std::ostringstream engine_id;
+  engine_id << reinterpret_cast<const int*>(&engine);
+  torch::jit::script::Module new_mod("tensorrt_engine_mod_" + engine_id.str());
+  auto new_g = std::make_shared<torch::jit::Graph>();
+  AddEngineToGraph(new_mod, new_g, engine);
+  auto new_method = new_mod._ivalue()->compilation_unit()->create_function("forward", new_g);
+  auto schema = GenerateGraphSchema(new_mod, new_method->name(), new_g);
+  new_mod.type()->addMethod(new_method);
+  new_method->setSchema(schema);
+
+  return new_mod;
+}
+
 void set_device(const int gpu_id) {
   TRTORCH_ASSERT(cudaSetDevice(gpu_id) == cudaSuccess, "Unable to set CUDA device: " << gpu_id);
 }
2 changes: 2 additions & 0 deletions core/compiler.h
@@ -19,6 +19,8 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
 
 torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, CompileSpec cfg);
 
+torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine);
+
 void set_device(const int gpu_id);
 
 } // namespace core
15 changes: 15 additions & 0 deletions cpp/api/include/trtorch/trtorch.h
@@ -480,6 +480,21 @@ TRTORCH_API std::string ConvertGraphToTRTEngine(
     const torch::jit::Module& module,
     std::string method_name,
     CompileSpec info);
+
+/**
+ * @brief Take a previously created TensorRT engine and embed it
+ * in a TorchScript module
+ *
+ * @param engine: std::string - Pre-built serialized TensorRT engine
+ *
+ * Takes a pre-built serialized TensorRT engine and embeds it in a TorchScript
+ * module. Registers execution of the engine as the forward method of the module.
+ * Forward is defined as: forward(Tensor[]) -> Tensor[]
+ *
+ * @return: A new module targeting a TensorRT engine
+ */
+TRTORCH_API torch::jit::Module EmbedEngineInNewModule(const std::string& engine);
+
 /**
  * @brief Set gpu device id
  *
4 changes: 4 additions & 0 deletions cpp/api/src/trtorch.cpp
@@ -31,6 +31,10 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module
   return core::CompileGraph(module, to_internal_compile_spec(info));
 }
 
+torch::jit::Module EmbedEngineInNewModule(const std::string& engine) {
+  return core::EmbedEngineInNewModule(engine);
+}
+
 std::string get_build_info() {
   auto info = core::util::get_build_info();
   return std::string("TRTorch Version: ") + TRTORCH_VERSION + '\n' + info;
20 changes: 20 additions & 0 deletions py/trtorch/_compiler.py
@@ -124,6 +124,26 @@ def convert_method_to_trt_engine(module: torch.jit.ScriptModule, method_name: st
     return trtorch._C.convert_graph_to_trt_engine(module._c, method_name, _parse_compile_spec(compile_spec))
 
 
+def embed_engine_in_new_module(serialized_engine: bytes) -> torch.jit.ScriptModule:
+    """Takes a pre-built serialized TensorRT engine and embeds it within a TorchScript module
+
+    Takes a pre-built serialized TensorRT engine (as bytes) and embeds it within a TorchScript module.
+    Registers the forward method to execute the TensorRT engine with the function signature:
+
+        forward(Tensor[]) -> Tensor[]
+
+    The module can be saved with the engine embedded using torch.jit.save and moved / loaded according to TRTorch portability rules
+
+    Args:
+        serialized_engine (bytes): Serialized TensorRT engine from either TRTorch or TensorRT APIs
+
+    Returns:
+        torch.jit.ScriptModule: New TorchScript module with engine embedded
+    """
+    cpp_mod = trtorch._C.embed_engine_in_new_module(serialized_engine)
+    return torch.jit._recursive.wrap_cpp_module(cpp_mod)
+
+
 def check_method_op_support(module: torch.jit.ScriptModule, method_name: str) -> bool:
     """Checks to see if a method is fully supported by TRTorch
 
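To make the new API concrete, here is a minimal usage sketch in Python. The function names and the forward(Tensor[]) -> Tensor[] behavior come from this PR; the model choice, the exact `compile_spec` contents, and the assumption that `convert_method_to_trt_engine` returns engine bytes directly consumable by `embed_engine_in_new_module` are inferred from the docstrings and tests below.

```python
import torch
import torchvision.models as models
import trtorch

# Script a model and serialize its forward method as a TensorRT engine
ts_model = torch.jit.script(models.resnet18(pretrained=True).eval().cuda())
spec = {"input_shapes": [(1, 3, 224, 224)]}
engine_bytes = trtorch.convert_method_to_trt_engine(ts_model, "forward", spec)

# Embed the raw engine in a fresh TorchScript module; its forward executes the engine
trt_mod = trtorch.embed_engine_in_new_module(engine_bytes)
out = trt_mod(torch.randn((1, 3, 224, 224), device="cuda"))
```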
8 changes: 8 additions & 0 deletions py/trtorch/csrc/trtorch_py.cpp
@@ -119,6 +119,10 @@ bool CheckMethodOperatorSupport(const torch::jit::Module& module, const std::str
   return core::CheckMethodOperatorSupport(module, method_name);
 }
 
+torch::jit::Module EmbedEngineInNewModule(const py::bytes& engine) {
+  return core::EmbedEngineInNewModule(engine);
+}
+
 std::string get_build_info() {
   auto info = core::util::get_build_info();
   return info;
@@ -270,6 +274,10 @@ PYBIND11_MODULE(_C, m) {
       "check_method_op_support",
       &trtorch::pyapi::CheckMethodOperatorSupport,
       "Takes a module and a method name and checks if the method graph contains purely convertible operators");
+  m.def(
+      "embed_engine_in_new_module",
+      &trtorch::pyapi::EmbedEngineInNewModule,
+      "Takes a serialized TensorRT engine and wraps it in the forward method of a new TorchScript module");
   m.def("get_build_info", &get_build_info, "Returns build info about the compiler as a string");
 
   m.def("_get_logging_prefix", &logging::get_logging_prefix, "Get the current prefix for the logging output");
28 changes: 28 additions & 0 deletions tests/modules/test_modules_as_engines.cpp
@@ -16,6 +16,34 @@ TEST_P(ModuleTests, ModuleAsEngineIsClose) {
   ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-5));
 }
 
+TEST_P(ModuleTests, ModuleToEngineToModuleIsClose) {
+  std::vector<at::Tensor> inputs;
+  std::vector<torch::jit::IValue> inputs_ivalues;
+  for (auto in_shape : input_shapes) {
+    inputs.push_back(at::randint(5, in_shape, {at::kCUDA}));
+    inputs_ivalues.push_back(inputs[inputs.size() - 1].clone());
+  }
+
+  torch::jit::IValue jit_results_ivalues = trtorch::tests::util::RunModuleForward(mod, inputs_ivalues);
+  std::vector<at::Tensor> jit_results;
+  jit_results.push_back(jit_results_ivalues.toTensor());
+
+  auto forward_graph = mod.get_method("forward");
+  std::vector<c10::ArrayRef<int64_t>> input_ranges;
+  for (auto in : inputs) {
+    input_ranges.push_back(in.sizes());
+  }
+
+  auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", input_ranges);
+  auto trt_mod = trtorch::EmbedEngineInNewModule(engine);
+
+  torch::jit::IValue trt_results_ivalues = trtorch::tests::util::RunModuleForward(trt_mod, inputs_ivalues);
+  std::vector<at::Tensor> trt_results;
+  trt_results.push_back(trt_results_ivalues.toTensor());
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-5));
+}
+
 INSTANTIATE_TEST_SUITE_P(
     ModuleAsEngineForwardIsCloseSuite,
     ModuleTests,
19 changes: 14 additions & 5 deletions tests/py/BUILD
@@ -30,7 +30,7 @@ py_test(
     srcs = [
         "test_ptq_dataloader_calibrator.py",
         "model_test_case.py"
-    ]
+    ],
     deps = [
        requirement("torchvision")
    ]
@@ -43,7 +43,7 @@
     srcs = [
         "test_ptq_trt_calibrator.py",
         "model_test_case.py"
-    ]
+    ],
     deps = [
         requirement("torchvision")
     ]
@@ -56,8 +56,6 @@
         "test_multi_gpu.py",
         "model_test_case.py"
     ],
-        "//conditions:default" : []
-    }),
     deps = [
         requirement("torchvision")
     ]
@@ -74,12 +72,23 @@
     ]
 )
 
+py_test(
+    name = "test_trt_intercompatability",
+    srcs = [
+        "test_trt_intercompatability.py",
+        "model_test_case.py"
+    ],
+    deps = [
+        requirement("torchvision")
+    ]
+)
+
 py_test(
     name = "test_ptq_to_backend",
     srcs = [
         "test_ptq_to_backend.py",
         "model_test_case.py"
-    ]
+    ],
     deps = [
         requirement("torchvision")
     ]
31 changes: 28 additions & 3 deletions tests/py/test_api.py
@@ -46,6 +46,30 @@ def test_compile_script(self):
         self.assertTrue(same < 2e-3)
 
 
+class TestPTtoTRTtoPT(ModelTestCase):
+
+    def setUp(self):
+        self.input = torch.randn((1, 3, 224, 224)).to("cuda")
+        self.ts_model = torch.jit.script(self.model)
+
+    def test_pt_to_trt_to_pt(self):
+        compile_spec = {
+            "input_shapes": [self.input.shape],
+            "device": {
+                "device_type": trtorch.DeviceType.GPU,
+                "gpu_id": 0,
+                "dla_core": 0,
+                "allow_gpu_fallback": False,
+                "disable_tf32": False
+            }
+        }
+
+        trt_engine = trtorch.convert_method_to_trt_engine(self.ts_model, "forward", compile_spec)
+        trt_mod = trtorch.embed_engine_in_new_module(trt_engine)
+        same = (trt_mod(self.input) - self.ts_model(self.input)).abs().max()
+        self.assertTrue(same < 2e-3)
+
+
 class TestCheckMethodOpSupport(unittest.TestCase):
 
     def setUp(self):
@@ -59,13 +83,13 @@ def test_check_support(self):
 class TestLoggingAPIs(unittest.TestCase):
 
     def test_logging_prefix(self):
-        new_prefix = "TEST"
+        new_prefix = "Python API Test: "
         trtorch.logging.set_logging_prefix(new_prefix)
         logging_prefix = trtorch.logging.get_logging_prefix()
         self.assertEqual(new_prefix, logging_prefix)
 
     def test_reportable_log_level(self):
-        new_level = trtorch.logging.Level.Warning
+        new_level = trtorch.logging.Level.Error
        trtorch.logging.set_reportable_log_level(new_level)
        level = trtorch.logging.get_reportable_log_level()
        self.assertEqual(new_level, level)
@@ -78,10 +102,11 @@ def test_is_colored_output_on(self):
 
 def test_suite():
     suite = unittest.TestSuite()
-    suite.addTest(unittest.makeSuite(TestLoggingAPIs))
     suite.addTest(TestCompile.parametrize(TestCompile, model=models.resnet18(pretrained=True)))
     suite.addTest(TestCompile.parametrize(TestCompile, model=models.mobilenet_v2(pretrained=True)))
+    suite.addTest(TestPTtoTRTtoPT.parametrize(TestPTtoTRTtoPT, model=models.mobilenet_v2(pretrained=True)))
     suite.addTest(unittest.makeSuite(TestCheckMethodOpSupport))
+    suite.addTest(unittest.makeSuite(TestLoggingAPIs))
 
     return suite
 
55 changes: 55 additions & 0 deletions tests/py/test_trt_intercompatability.py
@@ -0,0 +1,55 @@
+import unittest
+import trtorch
+import torch
+import torchvision.models as models
+import tensorrt as trt
+
+from model_test_case import ModelTestCase
+
+
+class TestPyTorchToTRTEngine(ModelTestCase):
+
+    def setUp(self):
+        self.input = torch.randn((1, 3, 224, 224)).to("cuda:0")
+        self.ts_model = torch.jit.script(self.model)
+
+    def test_pt_to_trt(self):
+        compile_spec = {
+            "input_shapes": [self.input.shape],
+            "device": {
+                "device_type": trtorch.DeviceType.GPU,
+                "gpu_id": 0,
+                "dla_core": 0,
+                "allow_gpu_fallback": False,
+                "disable_tf32": False
+            }
+        }
+
+        trt_engine = trtorch.convert_method_to_trt_engine(self.ts_model, "forward", compile_spec)
+
+        TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
+        with trt.Runtime(TRT_LOGGER) as rt:
+            engine = rt.deserialize_cuda_engine(trt_engine)
+            with engine.create_execution_context() as ctx:
+                out = torch.empty(size=tuple(engine.get_binding_shape(1))).to("cuda:0")
+                bindings = [self.input.contiguous().data_ptr(), out.contiguous().data_ptr()]
+                ctx.execute_async(batch_size=1,
+                                  bindings=bindings,
+                                  stream_handle=torch.cuda.current_stream(device='cuda:0').cuda_stream)
+                same = (out - self.ts_model(self.input)).abs().max()
+                self.assertTrue(same < 2e-3)
+
+
+def test_suite():
+    suite = unittest.TestSuite()
+    suite.addTest(TestPyTorchToTRTEngine.parametrize(TestPyTorchToTRTEngine, model=models.resnet18(pretrained=True)))
+
+    return suite
+
+
+suite = test_suite()
+
+runner = unittest.TextTestRunner()
+result = runner.run(suite)
+
+exit(int(not result.wasSuccessful()))
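Rounding out the picture, a brief sketch of the portability claim in the `_compiler.py` docstring: a module with an embedded engine should round-trip through standard TorchScript serialization. This is a sketch only; the model, file path, and tolerance are hypothetical, and `torch.jit.save` / `torch.jit.load` are stock PyTorch APIs rather than part of this PR.

```python
import torch
import torchvision.models as models
import trtorch

# Build and embed an engine, mirroring TestPTtoTRTtoPT above
ts_model = torch.jit.script(models.mobilenet_v2(pretrained=True).eval().cuda())
engine = trtorch.convert_method_to_trt_engine(ts_model, "forward",
                                              {"input_shapes": [(1, 3, 224, 224)]})
trt_mod = trtorch.embed_engine_in_new_module(engine)

# Save / reload via standard TorchScript serialization, then compare outputs
torch.jit.save(trt_mod, "mobilenet_trt.ts")  # hypothetical path
loaded = torch.jit.load("mobilenet_trt.ts")

x = torch.randn((1, 3, 224, 224), device="cuda")
assert (loaded(x) - trt_mod(x)).abs().max() < 2e-3
```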