 from __future__ import annotations
 
 import logging
+from contextlib import nullcontext
 from typing import Any, Dict, List, Optional, Sequence, Tuple
 
 import tensorrt as trt
 import torch
 from torch.nn import Module
+from torch_tensorrt._Device import Device
+from torch_tensorrt.dynamo.runtime.tools import _is_switch_required, _select_rt_device
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
 
+import torch_tensorrt
+
 logger = logging.getLogger(__name__)
 
 
@@ -23,13 +28,22 @@ def __init__(
         engine: trt.ICudaEngine,
         input_names: Optional[List[str]] = None,
         output_names: Optional[List[str]] = None,
+        target_device: Device = Device._current_device(),
+        profiling_enabled: Optional[bool] = None,
     ):
         super(PythonTorchTensorRTModule, self).__init__()
         self._register_state_dict_hook(PythonTorchTensorRTModule._on_state_dict)
         self.engine = engine
         self.input_names = input_names if input_names is not None else []
         self.output_names = output_names if output_names is not None else []
         self.initialized = False
+        self.target_device_id = target_device.gpu_id
+        self.target_device_properties = torch.cuda.get_device_properties(
+            self.target_device_id
+        )
+        self.profiling_enabled = (
+            profiling_enabled if profiling_enabled is not None else False
+        )
         self._initialize()
 
     def _initialize(self) -> None:
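For reference, a minimal sketch of how the new constructor arguments might be passed. Here `engine` and the literal values are hypothetical stand-ins; `Device` is the class imported above:

```python
# Hypothetical sketch: `engine` stands in for an already-deserialized
# trt.ICudaEngine; parameter names follow the constructor in this diff.
from torch_tensorrt._Device import Device

module = PythonTorchTensorRTModule(
    engine,
    input_names=["x"],
    output_names=["output0"],
    target_device=Device(gpu_id=0),  # device the engine is expected to run on
    profiling_enabled=False,         # opt in to the record_function scopes below
)
```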
@@ -141,15 +155,41 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
         if self.engine:
             self.context = self.engine.create_execution_context()
 
-    def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
+    def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
         with torch.autograd.profiler.record_function(
             "PythonTorchTensorRTModule:Forward"
-        ):
+        ) if self.profiling_enabled else nullcontext():
             self._check_initialized()
 
+            # If in safe mode, check at each iteration whether a device switch is required
+            if torch_tensorrt._compile.SAFE_MODE:
+                curr_device_id = torch.cuda.current_device()
+                curr_device_properties = torch.cuda.get_device_properties(
+                    curr_device_id
+                )
+                logger.debug(f"Current Device: cuda:{curr_device_id}")
+
+                # If a switch is required, move all inputs to the new device and set it as the active device
+                if _is_switch_required(
+                    curr_device_id,
+                    self.target_device_id,
+                    curr_device_properties,
+                    self.target_device_properties,
+                ):
+                    device_id, _ = _select_rt_device(
+                        curr_device_id,
+                        self.target_device_id,
+                        self.target_device_properties,
+                    )
+                    device = torch.device(device_id)
+                    torch.cuda.set_device(device_id)
+
+                    inputs = tuple(tensor.to(device) for tensor in inputs)
+                    logger.warning(f"Moved all input Tensors to cuda:{device_id}")
+
             with torch.autograd.profiler.record_function(
                 "PythonTorchTensorRTModule:ProcessInputs"
-            ):
+            ) if self.profiling_enabled else nullcontext():
                 assert len(inputs) == len(
                     self.input_names
                 ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."
@@ -162,22 +202,24 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 )
 
                 for i, input_name in enumerate(self.input_names):
-                    if not contiguous_inputs[i].is_cuda:
-                        logger.warning(
-                            f"Detected input {input_name} of engine {self.engine.name} is not on a cuda device. "
-                            "This tensor is being moved by the runtime but for performance considerations, "
-                            "ensure your inputs are all on GPU and open an issue here "
-                            "(https://github.com/pytorch/TensorRT/issues) if this warning persists."
-                        )
-                        contiguous_inputs = (
-                            contiguous_inputs[:i]
-                            + [contiguous_inputs[i].cuda()]
-                            + contiguous_inputs[i + 1 :]
-                        )
-
-                    assert (
-                        contiguous_inputs[i].dtype == self.input_dtypes[i]
-                    ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."
+                    # Check that the inputs are on cuda and have the correct data type if in safe mode
+                    if torch_tensorrt._compile.SAFE_MODE:
+                        if not contiguous_inputs[i].is_cuda:
+                            logger.warning(
+                                f"Detected input {input_name} of engine {self.engine.name} is not on a cuda device. "
+                                "This tensor is being moved by the runtime but for performance considerations, "
+                                "ensure your inputs are all on GPU and open an issue here "
+                                "(https://github.com/pytorch/TensorRT/issues) if this warning persists."
+                            )
+                            contiguous_inputs = (
+                                contiguous_inputs[:i]
+                                + [contiguous_inputs[i].cuda()]
+                                + contiguous_inputs[i + 1 :]
+                            )
+
+                        assert (
+                            contiguous_inputs[i].dtype == self.input_dtypes[i]
+                        ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."
 
                     idx = self.input_binding_indices_in_order[i]
                     bindings[idx] = contiguous_inputs[i].data_ptr()
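The checks above read `torch_tensorrt._compile.SAFE_MODE` on every call. Assuming the flag is a plain module-level boolean, as these reads suggest, a caller could opt in like this (a sketch, not a confirmed public API):

```python
# Assumption: SAFE_MODE is a module-level boolean consulted on each forward
# call; setting it enables the per-iteration device checks and input
# validation introduced in this diff.
import torch_tensorrt

torch_tensorrt._compile.SAFE_MODE = True
```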
@@ -188,7 +230,7 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
 
             with torch.autograd.profiler.record_function(
                 "PythonTorchTensorRTModule:ProcessOutputs"
-            ):
+            ) if self.profiling_enabled else nullcontext():
                 # create output tensors
                 outputs: List[torch.Tensor] = []
 
@@ -215,7 +257,7 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
 
             with torch.autograd.profiler.record_function(
                 "PythonTorchTensorRTModule:TensorRTRuntime"
-            ):
+            ) if self.profiling_enabled else nullcontext():
                 self.context.execute_async_v2(
                     bindings, torch.cuda.current_stream().cuda_stream
                 )
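The `record_function(...) if self.profiling_enabled else nullcontext()` expression used in each block above swaps in a no-op context manager when profiling is off, so the hot path pays no profiler overhead. A standalone sketch of the same idiom:

```python
from contextlib import nullcontext

import torch

profiling_enabled = False  # toggled at runtime, as in the module above

# With profiling off, nullcontext() is entered instead of the profiler scope;
# with it on, the block is annotated in torch profiler traces.
with torch.autograd.profiler.record_function(
    "Example:Scope"
) if profiling_enabled else nullcontext():
    result = torch.ones(2, 2) @ torch.ones(2, 2)  # stand-in workload
```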
@@ -235,6 +277,8 @@ def enable_profiling(self, profiler: "trt.IProfiler" = None) -> None:
         if not self.context.profiler:
             self.context.profiler = trt.Profiler() if profiler is None else profiler
 
+        self.profiling_enabled = True
+
     def disable_profiling(self) -> None:
         """
         Disable TensorRT profiling.
@@ -244,6 +288,7 @@ def disable_profiling(self) -> None:
         torch.cuda.synchronize()
         del self.context
         self.context = self.engine.create_execution_context()
+        self.profiling_enabled = False
 
     def get_layer_info(self) -> str:
         """
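Tying the toggles together, a hedged usage sketch in which `module` and `x` are hypothetical stand-ins for an initialized `PythonTorchTensorRTModule` and a matching CUDA input tensor:

```python
# Hypothetical names: `module` is an initialized PythonTorchTensorRTModule,
# `x` a CUDA tensor matching the engine's expected input shape and dtype.
module.enable_profiling()   # attaches a trt.Profiler and sets profiling_enabled
out = module(x)             # forward pass now runs inside record_function scopes
module.disable_profiling()  # rebuilds the execution context and clears the flag
```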