@@ -41,18 +41,14 @@ TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CudaDe
41
41
device_info = cuda_device;
42
42
set_cuda_device (device_info);
43
43
44
- rt = nvinfer1::createInferRuntime (util::logging::get_logger ());
44
+ rt = std::shared_ptr< nvinfer1::IRuntime>( nvinfer1:: createInferRuntime (util::logging::get_logger () ));
45
45
46
46
name = slugify (mod_name) + " _engine" ;
47
47
48
- cuda_engine = rt->deserializeCudaEngine (serialized_engine.c_str (), serialized_engine.size ());
48
+ cuda_engine = std::shared_ptr<nvinfer1::ICudaEngine>( rt->deserializeCudaEngine (serialized_engine.c_str (), serialized_engine.size () ));
49
49
TRTORCH_CHECK ((cuda_engine != nullptr ), " Unable to deserialize the TensorRT engine" );
50
50
51
- // Easy way to get a unique name for each engine, maybe there is a more
52
- // descriptive way (using something associated with the graph maybe)
53
- id = reinterpret_cast <EngineID>(cuda_engine);
54
-
55
- exec_ctx = cuda_engine->createExecutionContext ();
51
+ exec_ctx = std::shared_ptr<nvinfer1::IExecutionContext>(cuda_engine->createExecutionContext ());
56
52
57
53
uint64_t inputs = 0 ;
58
54
uint64_t outputs = 0 ;
@@ -74,7 +70,6 @@ TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CudaDe
74
70
}
75
71
76
72
TRTEngine& TRTEngine::operator =(const TRTEngine& other) {
77
- id = other.id ;
78
73
rt = other.rt ;
79
74
cuda_engine = other.cuda_engine ;
80
75
device_info = other.device_info ;
@@ -83,12 +78,6 @@ TRTEngine& TRTEngine::operator=(const TRTEngine& other) {
83
78
return (*this );
84
79
}
85
80
86
- TRTEngine::~TRTEngine () {
87
- delete exec_ctx;
88
- delete cuda_engine;
89
- delete rt;
90
- }
91
-
92
81
// TODO: Implement a call method
93
82
// c10::List<at::Tensor> TRTEngine::Run(c10::List<at::Tensor> inputs) {
94
83
// auto input_vec = inputs.vec();
0 commit comments