@@ -11,7 +11,7 @@ namespace trtorch {
1111namespace core {
1212namespace runtime {
1313
14- typedef enum { ABI_TARGET_IDX = 0 , DEVICE_IDX, ENGINE_IDX } SerializedInfoIndex;
14+ typedef enum { ABI_TARGET_IDX = 0 , NAME_IDX, DEVICE_IDX, ENGINE_IDX } SerializedInfoIndex;
1515
1616std::string slugify (std::string s) {
1717 std::replace (s.begin (), s.end (), ' .' , ' _' );
@@ -37,8 +37,8 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
3737 TRTORCH_CHECK (
3838 serialized_info[ABI_TARGET_IDX] == ABI_VERSION,
3939 " Program to be deserialized targets a different TRTorch ABI Version ("
40- << serialized_info[ABI_TARGET_IDX] << " ) than the TRTorch Runtime ABI (" << ABI_VERSION << " )" );
41- std::string _name = " deserialized_trt " ;
40+ << serialized_info[ABI_TARGET_IDX] << " ) than the TRTorch Runtime ABI Version (" << ABI_VERSION << " )" );
41+ std::string _name = serialized_info[NAME_IDX] ;
4242 std::string engine_info = serialized_info[ENGINE_IDX];
4343
4444 CudaDevice cuda_device = deserialize_device (serialized_info[DEVICE_IDX]);
@@ -55,7 +55,7 @@ TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CudaDe
5555
5656 rt = nvinfer1::createInferRuntime (logger);
5757
58- name = slugify (mod_name) + " _engine " ;
58+ name = slugify (mod_name);
5959
6060 cuda_engine = rt->deserializeCudaEngine (serialized_engine.c_str (), serialized_engine.size ());
6161 TRTORCH_CHECK ((cuda_engine != nullptr ), " Unable to deserialize the TensorRT engine" );
@@ -70,8 +70,8 @@ TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CudaDe
7070 uint64_t outputs = 0 ;
7171
7272 for (int64_t x = 0 ; x < cuda_engine->getNbBindings (); x++) {
73- std::string name = cuda_engine->getBindingName (x);
74- std::string idx_s = name .substr (name .find (" _" ) + 1 );
73+ std::string bind_name = cuda_engine->getBindingName (x);
74+ std::string idx_s = bind_name .substr (bind_name .find (" _" ) + 1 );
7575 uint64_t idx = static_cast <uint64_t >(std::stoi (idx_s));
7676
7777 if (cuda_engine->bindingIsInput (x)) {
@@ -124,9 +124,12 @@ static auto TRTORCH_UNUSED TRTEngineTSRegistrtion =
124124 auto trt_engine = std::string ((const char *)serialized_trt_engine->data (), serialized_trt_engine->size ());
125125
126126 std::vector<std::string> serialize_info;
127- serialize_info.push_back (ABI_VERSION);
128- serialize_info.push_back (serialize_device (self->device_info ));
129- serialize_info.push_back (trt_engine);
127+ serialize_info.resize (ENGINE_IDX + 1 );
128+
129+ serialize_info[ABI_TARGET_IDX] = ABI_VERSION;
130+ serialize_info[NAME_IDX] = self->name ;
131+ serialize_info[DEVICE_IDX] = serialize_device (self->device_info );
132+ serialize_info[ENGINE_IDX] = trt_engine;
130133 return serialize_info;
131134 },
132135 [](std::vector<std::string> seralized_info) -> c10::intrusive_ptr<TRTEngine> {
0 commit comments