2323logger = logging .getLogger (__name__ )
2424
2525
26+ class DynamicOutputAllocator (trt .IOutputAllocator ): # type: ignore[misc]
27+ def __init__ (self , output_dtypes : Dict [str , torch .dtype ]) -> None :
28+ trt .IOutputAllocator .__init__ (self )
29+ self .buffers : Dict [str , torch .Tensor ] = {}
30+ self .shapes : Dict [str , Tuple [int , ...]] = {}
31+ self .dtypes : Dict [str , torch .dtype ] = output_dtypes
32+
33+ def reallocate_output_async (
34+ self ,
35+ tensor_name : str ,
36+ memory : int ,
37+ size : int ,
38+ alignment : int ,
39+ stream : torch .cuda .Stream ,
40+ ) -> Any :
41+ shape = (size ,)
42+ if tensor_name not in self .buffers :
43+ self .buffers [tensor_name ] = torch .empty (
44+ shape ,
45+ dtype = self .dtypes [tensor_name ],
46+ device = torch .cuda .current_device (),
47+ )
48+ else :
49+ if self .buffers [tensor_name ].shape != shape :
50+ self .buffers [tensor_name ] = torch .empty (
51+ shape ,
52+ dtype = self .dtypes [tensor_name ],
53+ device = torch .cuda .current_device (),
54+ )
55+ return self .buffers [tensor_name ].data_ptr ()
56+
57+ def notify_shape (self , tensor_name : str , shape : Tuple [int , ...]) -> None :
58+ self .shapes [tensor_name ] = tuple (shape )
59+
60+
2661class TorchTRTRuntimeStates :
2762 def __init__ (self , new_cudagraphs : bool ):
2863 # Indicates whether CUDAGraphs were enabled in the previous execute_engine
@@ -164,8 +199,11 @@ def __init__(
164199 self .runtime_states = TorchTRTRuntimeStates (
165200 torch_tensorrt .runtime .get_cudagraphs_mode ()
166201 )
202+
203+ self .contains_dds_layer = False
167204 self .pre_allocated_outputs : List [torch .Tensor ] = []
168205 self .use_pre_allocated_outputs = False
206+ self .output_allocator : Optional [DynamicOutputAllocator ] = None
169207
170208 if self .serialized_engine is not None and not self .settings .lazy_engine_init :
171209 self .setup_engine ()
@@ -238,9 +276,19 @@ def setup_engine(self) -> None:
238276 for output_name in self .output_names
239277 ]
240278
279+ self .contains_dds_layer = self ._check_dds_layer ()
280+ if self .contains_dds_layer :
281+ self .setup_output_allocator ()
282+
241283 if torch_tensorrt .runtime .get_cudagraphs_mode ():
242284 self .cudagraph = torch .cuda .CUDAGraph ()
243285
286+ def _check_dds_layer (self ) -> bool :
287+ layer_info = self .get_layer_info ()
288+ if "trainStation" in layer_info : # contains dds layer
289+ return True
290+ return False
291+
244292 def _check_initialized (self ) -> None :
245293 if not self .initialized :
246294 raise RuntimeError ("PythonTorchTensorRTModule is not initialized." )
@@ -358,19 +406,22 @@ def create_output_tensors(self) -> List[torch.Tensor]:
358406 def set_pre_allocated_outputs (self , enable : bool ) -> None :
359407 self .use_pre_allocated_outputs = enable
360408
409+ def setup_output_allocator (self ) -> None :
410+ if self .output_allocator is None :
411+ output_dtypes_dict = {}
412+ for o , output_name in enumerate (self .output_names ):
413+ output_dtypes_dict [output_name ] = self .output_dtypes [o ]
414+ self .output_allocator = DynamicOutputAllocator (output_dtypes_dict )
415+
416+ for output_name in self .output_names :
417+ if not self .context .set_output_allocator (
418+ output_name , self .output_allocator
419+ ):
420+ raise RuntimeError (f"Failed to set output allocator for { output_name } " )
421+
361422 def forward (self , * inputs : torch .Tensor ) -> torch .Tensor | Tuple [torch .Tensor , ...]:
362- # Ensure inputs are available in all scopes and cast symbolic integers to Tensors
363- contiguous_inputs : List [torch .Tensor ] = [
364- (i .contiguous () if isinstance (i , torch .Tensor ) else torch .tensor (i ).cuda ())
365- for i in inputs
366- ]
367- with (
368- torch .autograd .profiler .record_function ("PythonTorchTensorRTModule:Forward" )
369- if self .profiling_enabled
370- else nullcontext ()
371- ):
372- self ._check_initialized ()
373423
424+ def run_cuda_graph () -> torch .Tensor | Tuple [torch .Tensor , ...]:
374425 cudagraphs_enabled = torch_tensorrt .runtime .get_cudagraphs_mode ()
375426 shape_changed = self .validate_input_shapes (inputs )
376427 (
@@ -389,38 +440,6 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
389440 self ._input_buffers = [None ] * len (self .input_names )
390441 self ._output_buffers = [None ] * len (self .output_names )
391442
392- # If in safe mode, check at each iteration for whether a switch is required
393- if (
394- torch_tensorrt .runtime ._multi_device_safe_mode ._PY_RT_MULTI_DEVICE_SAFE_MODE
395- ):
396- curr_device_id = torch .cuda .current_device ()
397- curr_device_properties = torch .cuda .get_device_properties (
398- curr_device_id
399- )
400- logger .debug (f"Current Device: cuda:{ curr_device_id } " )
401-
402- # If a switch is required, move all inputs to new device and set as active device
403- if _is_switch_required (
404- curr_device_id ,
405- self .target_device_id ,
406- curr_device_properties ,
407- self .target_device_properties ,
408- ):
409- device_id , _ = _select_rt_device (
410- curr_device_id ,
411- self .target_device_id ,
412- self .target_device_properties ,
413- )
414-
415- # Update current device
416- device = torch .device (device_id )
417- torch .cuda .set_device (device_id )
418-
419- contiguous_inputs = [
420- tensor .to (device ) for tensor in contiguous_inputs
421- ]
422- logger .warning (f"Moved all input Tensors to cuda:{ device_id } " )
423-
424443 with (
425444 torch .autograd .profiler .record_function (
426445 "PythonTorchTensorRTModule:ProcessInputs"
@@ -536,6 +555,118 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
536555
537556 return outputs
538557
558+ def run_output_allocator () -> torch .Tensor | Tuple [torch .Tensor , ...]:
559+ with (
560+ torch .autograd .profiler .record_function (
561+ "PythonTorchTensorRTModule:ProcessInputs"
562+ )
563+ if self .profiling_enabled
564+ else nullcontext ()
565+ ):
566+ assert len (contiguous_inputs ) == len (
567+ self .input_names
568+ ), f"Wrong number of inputs, expect { len (self .input_names )} get { len (contiguous_inputs )} ."
569+
570+ self .setup_input_tensors (contiguous_inputs , False , False )
571+
572+ with (
573+ torch .autograd .profiler .record_function (
574+ "PythonTorchTensorRTModule:TensorRTRuntime"
575+ )
576+ if self .profiling_enabled
577+ else nullcontext ()
578+ ):
579+ self ._caller_stream = torch .cuda .current_stream ()
580+ if (
581+ self ._engine_stream == torch .cuda .default_stream ()
582+ or self ._engine_stream is None
583+ ):
584+ self ._engine_stream = torch .cuda .Stream ()
585+
586+ self ._engine_stream .wait_stream (self ._caller_stream )
587+
588+ with torch .cuda .stream (self ._engine_stream ):
589+ self .context .execute_async_v3 (
590+ self ._engine_stream .cuda_stream
591+ ) # The OutputAllocator is called by execute_async_v3()
592+
593+ self ._caller_stream .wait_stream (self ._engine_stream )
594+
595+ with (
596+ torch .autograd .profiler .record_function (
597+ "PythonTorchTensorRTModule:ProcessOutputs"
598+ )
599+ if self .profiling_enabled
600+ else nullcontext ()
601+ ):
602+ outputs = []
603+ assert self .output_allocator is not None
604+ for o , output_name in enumerate (self .output_names ):
605+ shape = self .output_allocator .shapes .get (output_name , None )
606+ dtype = self .output_dtypes [o ]
607+ output = (
608+ self .output_allocator .buffers .get (output_name , None )
609+ .clone ()
610+ .detach ()
611+ )
612+ prod = int (torch .prod (torch .tensor (shape )))
613+ output = output .reshape (- 1 ).view (dtype )[:prod ].reshape (shape )
614+ outputs .append (output )
615+
616+ if len (outputs ) == 1 :
617+ return outputs [0 ]
618+
619+ return outputs
620+
621+ # Run forward function
622+ contiguous_inputs : List [torch .Tensor ] = [
623+ (i .contiguous () if isinstance (i , torch .Tensor ) else torch .tensor (i ).cuda ())
624+ for i in inputs
625+ ]
626+ with (
627+ torch .autograd .profiler .record_function ("PythonTorchTensorRTModule:Forward" )
628+ if self .profiling_enabled
629+ else nullcontext ()
630+ ):
631+ self ._check_initialized ()
632+
633+ # If in safe mode, check at each iteration for whether a switch is required
634+ if (
635+ torch_tensorrt .runtime ._multi_device_safe_mode ._PY_RT_MULTI_DEVICE_SAFE_MODE
636+ ):
637+ curr_device_id = torch .cuda .current_device ()
638+ curr_device_properties = torch .cuda .get_device_properties (
639+ curr_device_id
640+ )
641+ logger .debug (f"Current Device: cuda:{ curr_device_id } " )
642+
643+ # If a switch is required, move all inputs to new device and set as active device
644+ if _is_switch_required (
645+ curr_device_id ,
646+ self .target_device_id ,
647+ curr_device_properties ,
648+ self .target_device_properties ,
649+ ):
650+ device_id , _ = _select_rt_device (
651+ curr_device_id ,
652+ self .target_device_id ,
653+ self .target_device_properties ,
654+ )
655+
656+ # Update current device
657+ device = torch .device (device_id )
658+ torch .cuda .set_device (device_id )
659+
660+ contiguous_inputs = [
661+ tensor .to (device ) for tensor in contiguous_inputs
662+ ]
663+ logger .warning (f"Moved all input Tensors to cuda:{ device_id } " )
664+
665+ if self .contains_dds_layer :
666+ return run_output_allocator ()
667+ else :
668+ return run_cuda_graph ()
669+
539670 def enable_profiling (self , profiler : "trt.IProfiler" = None ) -> None :
540671 """
541672 Enable TensorRT profiling. After calling this function, TensorRT will report
0 commit comments