@@ -105,6 +105,8 @@ def __init__(
105105 [dtype ._from (o ) for o in output_dtypes ] if output_dtypes else None
106106 )
107107
108+ _LOGGER .debug (f"Graph to be compiled to TensorRT: { self .module .graph } " )
109+
108110 def validate_conversion (self ) -> Set [str ]:
109111 missing_converters : Set [str ] = set ()
110112
@@ -121,6 +123,18 @@ def validate_conversion(self) -> Set[str]:
121123
122124 return missing_converters
123125
126+ @staticmethod
127+ def _args_str (args : List [Any ]) -> str :
128+ args_ = [
129+ (
130+ f"ITensor { a .name } (shape: { a .shape } , dtype: { a .dtype } )"
131+ if isinstance (a , trt .ITensor )
132+ else a
133+ )
134+ for a in args
135+ ]
136+ return str (tuple (args_ ))
137+
124138 @staticmethod
125139 def _all_precisions_supported (enabled_precisions : Set [dtype ]) -> bool :
126140 return enabled_precisions .issubset (_defaults .SUPPORTED_KERNEL_PRECISIONS )
@@ -359,10 +373,14 @@ def placeholder(self, target: str, args: Any, kwargs: Any) -> trt.ITensor:
359373 f"Unable to access shape spec for input: { target } (got: { current_input } )"
360374 )
361375
376+ trt_input_dtype = current_input .dtype .to (trt .DataType , use_default = True )
377+ _LOGGER .debug (
378+ f"Adding input to in-progress INetwork: { target } (shape={ shape } , dtype={ trt_input_dtype } )"
379+ )
362380 return self .ctx .net .add_input (
363381 name = target ,
364382 shape = tuple (shape ),
365- dtype = current_input . dtype . to ( trt . DataType , use_default = True ) ,
383+ dtype = trt_input_dtype ,
366384 )
367385
368386 def call_module (
@@ -381,6 +399,9 @@ def call_module(
381399 converter , calling_convention = converter_packet
382400
383401 assert self ._cur_node_name is not None
402+ _LOGGER .debug (
403+ f"Converting node { self ._cur_node_name } (kind: { target } , args: { TRTInterpreter ._args_str (args )} )"
404+ )
384405 if calling_convention is CallingConvention .LEGACY :
385406 return converter (self .ctx .net , submod , args , kwargs , self ._cur_node_name )
386407 else :
@@ -397,6 +418,9 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any:
397418 converter , calling_convention = converter_packet
398419
399420 assert self ._cur_node_name is not None
421+ _LOGGER .debug (
422+ f"Converting node { self ._cur_node_name } (kind: { target } , args: { TRTInterpreter ._args_str (args )} )"
423+ )
400424 if calling_convention is CallingConvention .LEGACY :
401425 return converter (self .ctx .net , target , args , kwargs , self ._cur_node_name )
402426 else :
@@ -428,6 +452,9 @@ def call_method(self, target: str, args: Any, kwargs: Any) -> Any:
428452 converter , calling_convention = converter_packet
429453
430454 assert self ._cur_node_name is not None
455+ _LOGGER .debug (
456+ f"Converting node { self ._cur_node_name } (kind: { target } , args: { TRTInterpreter ._args_str (args )} )"
457+ )
431458 if calling_convention is CallingConvention .LEGACY :
432459 return converter (self .ctx .net , target , args , kwargs , self ._cur_node_name )
433460 else :
@@ -485,8 +512,10 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
485512 output .dtype = trt .DataType .BOOL
486513 elif self .output_dtypes is not None :
487514 output .dtype = self .output_dtypes [i ].to (trt .DataType )
488- elif self .output_fp16 and output .dtype == trt .DataType .FLOAT :
489- output .dtype = trt .DataType .HALF
515+
490516 self ._output_names .append (name )
517+ _LOGGER .debug (
518+ f"Marking output { name } (shape: { output .shape } , dtype: { output .dtype } )"
519+ )
491520
492521 return list (outputs )
0 commit comments