1414from .lower_setting import LowerSetting
1515from .passes .lower_pass_manager_builder import LowerPassManagerBuilder
1616from .passes .pass_utils import PassFunc , validate_inference
17+ from ..common_utils import use_python_runtime_parser
1718from torch_tensorrt .fx .tools .timing_cache_utils import TimingCacheManager
1819from torch_tensorrt .fx .tools .trt_splitter import TRTSplitter , TRTSplitterSetting
1920
@@ -48,7 +49,7 @@ def compile(
4849 save_timing_cache = False ,
4950 cuda_graph_batch_size = - 1 ,
5051 is_aten = False ,
51- use_experimental_fx_rt = False ,
52+ use_python_runtime = None ,
5253 max_aux_streams = None ,
5354 version_compatible = False ,
5455 optimization_level = None ,
@@ -70,7 +71,9 @@ def compile(
7071 timing_cache_prefix: Timing cache file name for timing cache used by fx2trt.
7172 save_timing_cache: Update timing cache with current timing cache data if set to True.
7273 cuda_graph_batch_size: Cuda graph batch size, default to be -1.
73- use_experimental_fx_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
74+ use_python_runtime: Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
75+ based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
76+ argument as None
7477 max_aux_streams: max number of aux stream to use
7578 version_compatible: enable version compatible feature
7679 optimization_level: builder optimization level
@@ -111,6 +114,9 @@ def compile(
111114 "Invalid device provided. Supported options: torch.device | torch_tensorrt.Device"
112115 )
113116
117+ # Parse user-specification of which runtime to use
118+ use_python_runtime = use_python_runtime_parser (use_python_runtime )
119+
114120 lower_setting = LowerSetting (
115121 device = device ,
116122 min_block_size = min_block_size ,
@@ -123,7 +129,7 @@ def compile(
123129 save_timing_cache = save_timing_cache ,
124130 cuda_graph_batch_size = cuda_graph_batch_size ,
125131 is_aten = is_aten ,
126- use_experimental_rt = use_experimental_fx_rt ,
132+ use_python_runtime = use_python_runtime ,
127133 max_aux_streams = max_aux_streams ,
128134 version_compatible = version_compatible ,
129135 optimization_level = optimization_level ,
@@ -202,7 +208,7 @@ def default_split_function(
202208 splitter_setting = TRTSplitterSetting ()
203209 splitter_setting .use_implicit_batch_dim = False
204210 splitter_setting .min_block_size = lower_setting .min_block_size
205- splitter_setting .use_experimental_rt = lower_setting .use_experimental_rt
211+ splitter_setting .use_experimental_rt = not lower_setting .use_python_runtime
206212 splitter = TRTSplitter (model , inputs , settings = splitter_setting )
207213 splitter .node_support_preview ()
208214 return splitter .generate_split_results ()
@@ -224,9 +230,17 @@ def lower_pass(
224230 """
225231 interpreter = create_trt_interpreter (lower_setting )
226232 interp_res : TRTInterpreterResult = interpreter (mod , input , module_name )
227- if lower_setting .use_experimental_rt :
228- import io
233+ if lower_setting .use_python_runtime :
234+ trt_module = TRTModule (
235+ engine = interp_res .engine ,
236+ input_names = interp_res .input_names ,
237+ output_names = interp_res .output_names ,
238+ cuda_graph_batch_size = lower_setting .cuda_graph_batch_size ,
239+ )
240+ return trt_module
229241
242+ else :
243+ import io
230244 from torch_tensorrt ._Device import Device
231245 from torch_tensorrt .dynamo ._TorchTensorRTModule import TorchTensorRTModule
232246
@@ -240,16 +254,6 @@ def lower_pass(
240254 input_binding_names = interp_res .input_names ,
241255 output_binding_names = interp_res .output_names ,
242256 target_device = Device (f"cuda:{ torch .cuda .current_device ()} " ),
243- # cuda_graph_batch_size=lower_setting.cuda_graph_batch_size, # NOTE: Not sure what this is supposed to do
244- )
245- return trt_module
246-
247- else :
248- trt_module = TRTModule (
249- engine = interp_res .engine ,
250- input_names = interp_res .input_names ,
251- output_names = interp_res .output_names ,
252- cuda_graph_batch_size = lower_setting .cuda_graph_batch_size ,
253257 )
254258 return trt_module
255259
0 commit comments