 import torch
 from torch.nn import Module
 from torch_tensorrt._Device import Device
-from torch_tensorrt.dynamo.runtime.tools import _is_switch_required, _select_rt_device
+from torch_tensorrt.dynamo.runtime.tools import (
+    _is_switch_required,
+    _select_rt_device,
+    multi_gpu_device_check,
+)
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
 
 import torch_tensorrt
@@ -33,6 +37,10 @@ def __init__(
     ):
         super(PythonTorchTensorRTModule, self).__init__()
         self._register_state_dict_hook(PythonTorchTensorRTModule._on_state_dict)
+
+        # Run multi-gpu device check to validate engine instantiation
+        multi_gpu_device_check()
+
         self.engine = engine
         self.input_names = input_names if input_names is not None else []
         self.output_names = output_names if output_names is not None else []
@@ -133,6 +141,9 @@ def _load_from_state_dict(
     ) -> None:
         engine_bytes = state_dict[prefix + "engine"]
 
+        # Run multi-gpu device check to validate engine instantiation
+        multi_gpu_device_check()
+
         logger = trt.Logger()
        runtime = trt.Runtime(logger)
         self.engine = runtime.deserialize_cuda_engine(engine_bytes)
@@ -162,7 +173,9 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             self._check_initialized()
 
             # If in safe mode, check at each iteration for for whether a switch is required
-            if torch_tensorrt._compile.SAFE_MODE:
+            if (
+                torch_tensorrt.runtime.multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE
+            ):
                 curr_device_id = torch.cuda.current_device()
                 curr_device_properties = torch.cuda.get_device_properties(
                     curr_device_id
@@ -202,24 +215,22 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 )
 
                 for i, input_name in enumerate(self.input_names):
-                    # Check that the inputs are on cuda and have the correct data type if in safe mode
-                    if torch_tensorrt._compile.SAFE_MODE:
-                        if not contiguous_inputs[i].is_cuda:
-                            logger.warning(
-                                f"Detected input {input_name} of engine {self.engine.name} is not on a cuda device. "
-                                "This tensor is being moved by the runtime but for performance considerations, "
-                                "ensure your inputs are all on GPU and open an issue here "
-                                "(https://github.com/pytorch/TensorRT/issues) if this warning persists."
-                            )
-                            contiguous_inputs = (
-                                contiguous_inputs[:i]
-                                + [contiguous_inputs[i].cuda()]
-                                + contiguous_inputs[i + 1 :]
-                            )
-
-                        assert (
-                            contiguous_inputs[i].dtype == self.input_dtypes[i]
-                        ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."
+                    if not contiguous_inputs[i].is_cuda:
+                        logger.warning(
+                            f"Detected input {input_name} of engine {self.engine.name} is not on a cuda device. "
+                            "This tensor is being moved by the runtime but for performance considerations, "
+                            "ensure your inputs are all on GPU and open an issue here "
+                            "(https://github.com/pytorch/TensorRT/issues) if this warning persists."
+                        )
+                        contiguous_inputs = (
+                            contiguous_inputs[:i]
+                            + [contiguous_inputs[i].cuda()]
+                            + contiguous_inputs[i + 1 :]
+                        )
+
+                    assert (
+                        contiguous_inputs[i].dtype == self.input_dtypes[i]
+                    ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."
 
                     idx = self.input_binding_indices_in_order[i]
                     bindings[idx] = contiguous_inputs[i].data_ptr()
0 commit comments