Skip to content

Commit e336630

Browse files
committed
feat: Using shared_ptrs to manage TRT resources in runtime
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 2cd9fad commit e336630

File tree

2 files changed

+8
-19
lines changed

2 files changed

+8
-19
lines changed

core/runtime/TRTEngine.cpp

Lines changed: 3 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -41,18 +41,14 @@ TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CudaDe
4141
device_info = cuda_device;
4242
set_cuda_device(device_info);
4343

44-
rt = nvinfer1::createInferRuntime(util::logging::get_logger());
44+
rt = std::shared_ptr<nvinfer1::IRuntime>(nvinfer1::createInferRuntime(util::logging::get_logger()));
4545

4646
name = slugify(mod_name) + "_engine";
4747

48-
cuda_engine = rt->deserializeCudaEngine(serialized_engine.c_str(), serialized_engine.size());
48+
cuda_engine = std::shared_ptr<nvinfer1::ICudaEngine>(rt->deserializeCudaEngine(serialized_engine.c_str(), serialized_engine.size()));
4949
TRTORCH_CHECK((cuda_engine != nullptr), "Unable to deserialize the TensorRT engine");
5050

51-
// Easy way to get a unique name for each engine, maybe there is a more
52-
// descriptive way (using something associated with the graph maybe)
53-
id = reinterpret_cast<EngineID>(cuda_engine);
54-
55-
exec_ctx = cuda_engine->createExecutionContext();
51+
exec_ctx = std::shared_ptr<nvinfer1::IExecutionContext>(cuda_engine->createExecutionContext());
5652

5753
uint64_t inputs = 0;
5854
uint64_t outputs = 0;
@@ -74,7 +70,6 @@ TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CudaDe
7470
}
7571

7672
TRTEngine& TRTEngine::operator=(const TRTEngine& other) {
77-
id = other.id;
7873
rt = other.rt;
7974
cuda_engine = other.cuda_engine;
8075
device_info = other.device_info;
@@ -83,12 +78,6 @@ TRTEngine& TRTEngine::operator=(const TRTEngine& other) {
8378
return (*this);
8479
}
8580

86-
TRTEngine::~TRTEngine() {
87-
delete exec_ctx;
88-
delete cuda_engine;
89-
delete rt;
90-
}
91-
9281
// TODO: Implement a call method
9382
// c10::List<at::Tensor> TRTEngine::Run(c10::List<at::Tensor> inputs) {
9483
// auto input_vec = inputs.vec();

core/runtime/runtime.h

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
#pragma once
22
#include <map>
3+
#include <memory>
34
#include <utility>
45
#include "ATen/core/function_schema.h"
56
#include "NvInfer.h"
@@ -37,18 +38,17 @@ CudaDevice deserialize_device(std::string device_info);
3738

3839
struct TRTEngine : torch::CustomClassHolder {
3940
// Each engine needs it's own runtime object
40-
nvinfer1::IRuntime* rt;
41-
nvinfer1::ICudaEngine* cuda_engine;
42-
nvinfer1::IExecutionContext* exec_ctx;
41+
std::shared_ptr<nvinfer1::IRuntime> rt;
42+
std::shared_ptr<nvinfer1::ICudaEngine> cuda_engine;
43+
std::shared_ptr<nvinfer1::IExecutionContext> exec_ctx;
4344
std::pair<uint64_t, uint64_t> num_io;
44-
EngineID id;
4545
std::string name;
4646
CudaDevice device_info;
4747

4848
std::unordered_map<uint64_t, uint64_t> in_binding_map;
4949
std::unordered_map<uint64_t, uint64_t> out_binding_map;
5050

51-
~TRTEngine();
51+
~TRTEngine() = default;
5252
TRTEngine(std::string serialized_engine, CudaDevice cuda_device);
5353
TRTEngine(std::vector<std::string> serialized_info);
5454
TRTEngine(std::string mod_name, std::string serialized_engine, CudaDevice cuda_device);

0 commit comments

Comments (0)