 import torch
 from torch.nn import Module
 from torch_tensorrt._Device import Device
-from torch_tensorrt.dynamo.runtime.tools import _is_switch_required, _select_rt_device
+from torch_tensorrt.dynamo.runtime.tools import multi_gpu_device_check
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter

-import torch_tensorrt
-
 logger = logging.getLogger(__name__)


@@ -33,6 +31,10 @@ def __init__(
     ):
         super(PythonTorchTensorRTModule, self).__init__()
         self._register_state_dict_hook(PythonTorchTensorRTModule._on_state_dict)
+
+        # Run multi-gpu device check to validate engine instantiation
+        multi_gpu_device_check()
+
         self.engine = engine
         self.input_names = input_names if input_names is not None else []
         self.output_names = output_names if output_names is not None else []
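The implementation of multi_gpu_device_check is not part of this diff; it lives in torch_tensorrt.dynamo.runtime.tools. A minimal sketch of what such a check plausibly does, assuming it only inspects the visible CUDA devices and warns (the real helper may differ):

    import logging

    import torch

    logger = logging.getLogger(__name__)


    def multi_gpu_device_check() -> None:
        # Hypothetical sketch of the imported helper: warn when more than
        # one CUDA device is visible, since the engine is bound to whichever
        # device is active when it is built or deserialized.
        if torch.cuda.device_count() > 1:
            logger.warning(
                "Multiple GPUs detected; the TensorRT engine will target "
                f"cuda:{torch.cuda.current_device()}. Ensure the same device "
                "is active whenever the engine is used."
            )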
@@ -133,6 +135,9 @@ def _load_from_state_dict(
     ) -> None:
         engine_bytes = state_dict[prefix + "engine"]
 
+        # Run multi-gpu device check to validate engine instantiation
+        multi_gpu_device_check()
+
         logger = trt.Logger()
         runtime = trt.Runtime(logger)
         self.engine = runtime.deserialize_cuda_engine(engine_bytes)
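Because the engine is held as serialized bytes in the state dict, the module round-trips through the standard PyTorch save/load path, and the device check above re-runs on every reload. A usage sketch (the module variable and file name are hypothetical):

    import torch

    # trt_module is a compiled PythonTorchTensorRTModule (hypothetical name)
    torch.save(trt_module.state_dict(), "trt_module.pt")

    # Reloading invokes _load_from_state_dict, which re-runs the multi-GPU
    # check and deserializes the engine via trt.Runtime
    trt_module.load_state_dict(torch.load("trt_module.pt"))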
@@ -161,32 +166,6 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
         ) if self.profiling_enabled else nullcontext():
             self._check_initialized()
 
-            # If in safe mode, check at each iteration whether a switch is required
-            if torch_tensorrt._compile.SAFE_MODE:
-                curr_device_id = torch.cuda.current_device()
-                curr_device_properties = torch.cuda.get_device_properties(
-                    curr_device_id
-                )
-                logger.debug(f"Current Device: cuda:{curr_device_id}")
-
-                # If a switch is required, move all inputs to new device and set as active device
-                if _is_switch_required(
-                    curr_device_id,
-                    self.target_device_id,
-                    curr_device_properties,
-                    self.target_device_properties,
-                ):
-                    device_id, _ = _select_rt_device(
-                        curr_device_id,
-                        self.target_device_id,
-                        self.target_device_properties,
-                    )
-                    device = torch.device(device_id)
-                    torch.cuda.set_device(device_id)
-
-                    inputs = tuple([tensor.to(device) for tensor in inputs])
-                    logger.warning(f"Moved all input Tensors to cuda:{device_id}")
-
             with torch.autograd.profiler.record_function(
                 "PythonTorchTensorRTModule:ProcessInputs"
             ) if self.profiling_enabled else nullcontext():
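With this removal, forward no longer switches devices on the caller's behalf in safe mode; the caller is expected to run on the device the engine targets. A usage sketch (trt_module and the device index are hypothetical):

    import torch

    # Hypothetical: trt_module is a PythonTorchTensorRTModule built on cuda:0.
    # Since forward no longer re-selects a device, pin it explicitly and
    # place inputs there before calling the module.
    torch.cuda.set_device(0)
    x = torch.randn(1, 3, 224, 224, device="cuda:0")
    y = trt_module(x)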
@@ -202,24 +181,22 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 )
 
                 for i, input_name in enumerate(self.input_names):
-                    # Check that the inputs are on cuda and have the correct data type if in safe mode
-                    if torch_tensorrt._compile.SAFE_MODE:
-                        if not contiguous_inputs[i].is_cuda:
-                            logger.warning(
-                                f"Detected input {input_name} of engine {self.engine.name} is not on a cuda device. "
-                                "This tensor is being moved by the runtime but for performance considerations, "
-                                "ensure your inputs are all on GPU and open an issue here "
-                                "(https://github.com/pytorch/TensorRT/issues) if this warning persists."
-                            )
-                            contiguous_inputs = (
-                                contiguous_inputs[:i]
-                                + [contiguous_inputs[i].cuda()]
-                                + contiguous_inputs[i + 1 :]
-                            )
-
-                        assert (
-                            contiguous_inputs[i].dtype == self.input_dtypes[i]
-                        ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."
+                    if not contiguous_inputs[i].is_cuda:
+                        logger.warning(
+                            f"Detected input {input_name} of engine {self.engine.name} is not on a cuda device. "
+                            "This tensor is being moved by the runtime but for performance considerations, "
+                            "ensure your inputs are all on GPU and open an issue here "
+                            "(https://github.com/pytorch/TensorRT/issues) if this warning persists."
+                        )
+                        contiguous_inputs = (
+                            contiguous_inputs[:i]
+                            + [contiguous_inputs[i].cuda()]
+                            + contiguous_inputs[i + 1 :]
+                        )
+
+                    assert (
+                        contiguous_inputs[i].dtype == self.input_dtypes[i]
+                    ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."
 
                     idx = self.input_binding_indices_in_order[i]
                     bindings[idx] = contiguous_inputs[i].data_ptr()
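These checks now run unconditionally rather than only in safe mode: CPU inputs are moved to the GPU with a warning, while a dtype mismatch fails the assert. A usage sketch (trt_module and the shapes are hypothetical):

    import torch

    # Hypothetical: trt_module was compiled for a single float32 input.
    x_cpu = torch.randn(1, 3, 224, 224)
    y = trt_module(x_cpu)  # accepted, but moved to GPU with a warning

    x_f64 = torch.randn(1, 3, 224, 224, device="cuda", dtype=torch.float64)
    # trt_module(x_f64) would fail: "Dtype mismatch for 0th input..."
    y = trt_module(x_f64.to(torch.float32))  # cast to the compiled dtype first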