@@ -13,7 +13,13 @@
 from torch.fx.passes.shape_prop import TensorMetadata
 from torch.utils._python_dispatch import _disable_current_modes
 from torch_tensorrt._Input import Input
-from torch_tensorrt.dynamo.conversion.converter_utils import get_node_name
+from torch_tensorrt.dynamo._settings import CompilationSettings
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+from torch_tensorrt.dynamo.conversion.converter_registry import CallingConvention
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    get_node_name,
+    get_trt_tensor,
+)
 from torch_tensorrt.fx.observer import Observer
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
 
@@ -46,6 +52,7 @@ def __init__(
         input_specs: List[Input],
         logger_level: trt.ILogger.Severity = trt.ILogger.Severity.WARNING,
         output_dtypes: Optional[List[torch.dtype]] = None,
+        compilation_settings: CompilationSettings = CompilationSettings(),
     ):
         super().__init__(module)
 
@@ -59,7 +66,9 @@ def __init__(
         EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
         flag |= EXPLICIT_BATCH
 
-        self.network = self.builder.create_network(flag)
+        self.ctx = ConversionContext(
+            self.builder.create_network(flag), compilation_settings
+        )
 
         missing_ops = self.validate_conversion()
         if missing_ops:
@@ -95,14 +104,14 @@ def validate_conversion(self) -> Set[str]:
         missing_converters: Set[str] = set()
 
         for node in self.module.graph.nodes:
-            if node.op == "call_function" and not CONVERTERS.get(node):
+            if node.op == "call_function" and CONVERTERS.get(node) is None:
                 missing_converters.add(f"{node.op} {_get_qualified_name(node.target)}")
-            elif node.op == "call_method" and not CONVERTERS.get(node):
+            elif node.op == "call_method" and CONVERTERS.get(node) is None:
                 missing_converters.add(f"{node.op} torch.Tensor.{node.target}")
             elif node.op == "call_module":
                 submod = self.fetch_attr(node.target)
                 submod_type = getattr(submod, "_base_class_origin", type(submod))
-                if not CONVERTERS.get(node):
+                if CONVERTERS.get(node) is None:
                     missing_converters.add(f"{node.op} {torch.typename(submod_type)}")
 
         return missing_converters
@@ -221,7 +230,7 @@ def run(
         if tactic_sources is not None:
             builder_config.set_tactic_sources(tactic_sources=tactic_sources)
 
-        engine = self.builder.build_engine(self.network, builder_config)
+        engine = self.builder.build_engine(self.ctx.net, builder_config)
         assert engine
 
         serialized_cache = (
@@ -291,7 +300,7 @@ def placeholder(self, target: str, args: Any, kwargs: Any) -> trt.ITensor:
                 f"Unable to access shape spec for input: {target} (got: {current_input})"
             )
 
-        return self.network.add_input(
+        return self.ctx.net.add_input(
             name=target,
             shape=tuple(shape),
             dtype=unified_dtype_converter(current_input.torch_dtype, Frameworks.TRT),
@@ -303,30 +312,40 @@ def call_module(
         assert isinstance(target, str)
         submod = self.fetch_attr(target)
         submod_type = getattr(submod, "_base_class_origin", type(submod))
-        converter = CONVERTERS.get(self._cur_node)
+        converter_packet = CONVERTERS.get(self._cur_node)
 
-        if not converter:
+        if converter_packet is None:
             raise UnsupportedOperatorException(
                 f"Conversion of module of type {submod_type} not currently supported!"
             )
 
+        converter, calling_convention = converter_packet
+
         assert self._cur_node_name is not None
-        return converter(self.network, submod, args, kwargs, self._cur_node_name)
+        if calling_convention is CallingConvention.LEGACY:
+            return converter(self.ctx.net, submod, args, kwargs, self._cur_node_name)
+        else:
+            return converter(self.ctx, submod, args, kwargs, self._cur_node_name)
 
     def call_function(self, target: str, args: Any, kwargs: Any) -> Any:
         # TODO: Why is this stateful? We should be able to take in the inputs
-        converter = CONVERTERS.get(self._cur_node)
-        if not converter:
+        converter_packet = CONVERTERS.get(self._cur_node)
+        if converter_packet is None:
             raise UnsupportedOperatorException(
                 f"Conversion of function {torch.typename(target)} not currently supported!"
             )
 
+        converter, calling_convention = converter_packet
+
         assert self._cur_node_name is not None
-        return converter(self.network, target, args, kwargs, self._cur_node_name)
+        if calling_convention is CallingConvention.LEGACY:
+            return converter(self.ctx.net, target, args, kwargs, self._cur_node_name)
+        else:
+            return converter(self.ctx, target, args, kwargs, self._cur_node_name)
 
     def get_attr(self, target: str, args: Any, kwargs: Any) -> np.ndarray:
         with _disable_current_modes():
-            from torch_tensorrt.fx.converters import to_numpy
+            from torch_tensorrt.dynamo.conversion.converter_utils import to_numpy
 
             frozen_attr = self.fetch_attr(target)
 
@@ -341,15 +360,19 @@ def get_attr(self, target: str, args: Any, kwargs: Any) -> np.ndarray:
 
     def call_method(self, target: str, args: Any, kwargs: Any) -> Any:
         assert isinstance(target, str)
-        converter = CONVERTERS.get(self._cur_node)
+        converter_packet = CONVERTERS.get(self._cur_node)
 
-        if not converter:
+        if converter_packet is None:
             raise UnsupportedOperatorException(
                 f"Conversion of method {target} not currently supported!"
             )
+        converter, calling_convention = converter_packet
 
         assert self._cur_node_name is not None
-        return converter(self.network, target, args, kwargs, self._cur_node_name)
+        if calling_convention is CallingConvention.LEGACY:
+            return converter(self.ctx.net, target, args, kwargs, self._cur_node_name)
+        else:
+            return converter(self.ctx, target, args, kwargs, self._cur_node_name)
 
     def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
         assert len(args) == 1
@@ -361,12 +384,10 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
             outputs = (args[0],)
 
         for output_idx in range(len(outputs)):
-            from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
-
             output = outputs[output_idx]
 
             if not isinstance(output, trt.tensorrt.ITensor):
-                new_output = get_trt_tensor(self.network, output, target)
+                new_output = get_trt_tensor(self.ctx, output, target)
                 outputs = (
                     outputs[:output_idx] + (new_output,) + outputs[output_idx + 1 :]
                 )
@@ -400,7 +421,7 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
                 output_bool = False
             name = f"output{i}"
             output.name = name
-            self.network.mark_output(output)
+            self.ctx.net.mark_output(output)
             if output_bool:
                 output.dtype = trt.bool
             elif self.output_dtypes is not None:
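
For readers tracking the `self.network` → `self.ctx.net` migration: the diff only shows `ConversionContext` being constructed positionally with a network and the new `compilation_settings`, and read back via `ctx.net`. The sketch below is therefore an assumption about its minimal shape, not the actual class (which lives in `torch_tensorrt.dynamo.conversion._ConversionContext` and may carry more state).

```python
# Minimal sketch of what ConversionContext plausibly holds, inferred from
# its two uses in this diff (positional construction with a network plus
# settings, and `ctx.net` reads). The field types here are stand-ins.
from dataclasses import dataclass
from typing import Any


@dataclass
class ConversionContext:
    net: Any  # trt.INetworkDefinition in the real interpreter
    compilation_settings: Any = None  # CompilationSettings in the real code
```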
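The same packet-unpacking branch appears three times above (`call_module`, `call_function`, `call_method`): the registry now hands back a pair of converter and calling convention, and legacy converters keep receiving the bare network while new-style converters get the full context. A self-contained sketch of the pattern, with `ConverterPacket` and the `dispatch` helper as illustrative stand-ins rather than the real types in `torch_tensorrt.dynamo.conversion.converter_registry`:

```python
# Sketch of the two-calling-convention dispatch used in call_module /
# call_function / call_method. ConverterPacket and dispatch are assumed
# names for illustration only.
from enum import Enum, auto
from typing import Any, Callable, NamedTuple


class CallingConvention(Enum):
    LEGACY = auto()  # converter(network, target, args, kwargs, name)
    CTX = auto()     # converter(ctx, target, args, kwargs, name)


class ConverterPacket(NamedTuple):
    converter: Callable[..., Any]
    calling_convention: CallingConvention


def dispatch(
    packet: ConverterPacket, ctx: Any, target: Any, args: Any, kwargs: Any, name: str
) -> Any:
    # Legacy converters still receive the bare network; new-style
    # converters receive the whole context (settings included).
    converter, convention = packet
    if convention is CallingConvention.LEGACY:
        return converter(ctx.net, target, args, kwargs, name)
    return converter(ctx, target, args, kwargs, name)
```

This keeps every existing converter working unchanged while letting migrated converters reach the compilation settings through the context, which is presumably why the branch is duplicated at each call site instead of converting the legacy signatures wholesale.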