@@ -30,13 +30,37 @@ std::vector<std::string> split(const std::string& str, char delim) {
3030 return strings;
3131}
3232
33+ DynamicOutputAllocator::DynamicOutputAllocator (const std::unordered_map<std::string, at::ScalarType>& output_dtypes)
34+ : dtypes(output_dtypes) {}
35+
36+ void * DynamicOutputAllocator::reallocateOutputAsync (
37+ char const * tensorName,
38+ void * currentMemory,
39+ uint64_t size,
40+ uint64_t alignment,
41+ cudaStream_t stream) {
42+ std::vector<int64_t > shape = {static_cast <int64_t >(size)};
43+ auto it = buffers.find (tensorName);
44+ if (it == buffers.end () || it->second .sizes () != shape) {
45+ buffers[tensorName] = at::empty (shape, at::TensorOptions ().dtype (dtypes.at (tensorName)).device (at::kCUDA ));
46+ return buffers[tensorName].data_ptr ();
47+ } else {
48+ return it->second .data_ptr ();
49+ }
50+ }
51+
52+ void DynamicOutputAllocator::notifyShape (char const * tensorName, nvinfer1::Dims const & dims) noexcept {
53+ shapes[tensorName] = dims;
54+ }
55+
3356TRTEngine::TRTEngine (
3457 const std::string& serialized_engine,
3558 const RTDevice& cuda_device,
3659 const std::vector<std::string>& _in_binding_names,
3760 const std::vector<std::string>& _out_binding_names,
3861 const Platform& target_platform,
3962 bool hardware_compatible,
63+ bool requires_output_allocator,
4064 const std::string& serialized_metadata)
4165 : TRTEngine(
4266 " deserialized_trt" ,
@@ -46,6 +70,7 @@ TRTEngine::TRTEngine(
4670 _out_binding_names,
4771 target_platform,
4872 hardware_compatible,
73+ requires_output_allocator,
4974 serialized_metadata) {}
5075
5176TRTEngine::TRTEngine (std::vector<std::string> serialized_info)
@@ -57,6 +82,7 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
5782 split(serialized_info[OUTPUT_BINDING_NAMES_IDX], BINDING_DELIM),
5883 Platform(serialized_info[TARGET_PLATFORM_IDX]),
5984 static_cast<bool>(std::stoi(serialized_info[HW_COMPATIBLE_IDX])),
85+ static_cast<bool>(std::stoi(serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX])),
6086 serialized_info[SERIALIZED_METADATA_IDX]) {}
6187
6288TRTEngine::TRTEngine (
@@ -67,6 +93,7 @@ TRTEngine::TRTEngine(
6793 const std::vector<std::string>& _out_binding_names,
6894 const Platform& target_platform,
6995 bool hardware_compatible,
96+ bool requires_output_allocator,
7097 const std::string& serialized_metadata) {
7198 TORCHTRT_CHECK (
7299 is_supported_on_current_platform (target_platform),
@@ -79,6 +106,7 @@ TRTEngine::TRTEngine(
79106 TORCHTRT_CHECK (most_compatible_device, " No compatible device was found for instantiating TensorRT engine" );
80107
81108 this ->serialized_metadata = serialized_metadata;
109+ this ->requires_output_allocator = requires_output_allocator;
82110 device_info = most_compatible_device.value ();
83111 multi_gpu_device_check ();
84112 set_rt_device (device_info);
@@ -397,6 +425,7 @@ FlattenedState TRTEngine::__obj_flatten__() {
397425 std::tuple (" out_binding_names" , serialized_info[OUTPUT_BINDING_NAMES_IDX]),
398426 std::tuple (" hardware_compatible" , serialized_info[HW_COMPATIBLE_IDX]),
399427 std::tuple (" serialized_metadata" , serialized_info[SERIALIZED_METADATA_IDX]),
428+ std::tuple (" requires_output_allocator" , serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX]),
400429 std::tuple (" target_platform" , serialized_info[TARGET_PLATFORM_IDX]));
401430}
402431
@@ -417,6 +446,7 @@ std::vector<std::string> TRTEngine::serialize() {
417446 serialized_info[INPUT_BINDING_NAMES_IDX] = serialize_bindings (this ->in_binding_names );
418447 serialized_info[OUTPUT_BINDING_NAMES_IDX] = serialize_bindings (this ->out_binding_names );
419448 serialized_info[HW_COMPATIBLE_IDX] = this ->hardware_compatible ? " 1" : " 0" ;
449+ serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX] = this ->requires_output_allocator ? " 1" : " 0" ;
420450 serialized_info[SERIALIZED_METADATA_IDX] = this ->serialized_metadata ;
421451 serialized_info[TARGET_PLATFORM_IDX] = this ->target_platform .serialize ();
422452
0 commit comments