11import logging
22import math
3+ import operator
4+ import os
35from dataclasses import dataclass , field
4- from typing import Any , Dict , List
6+ from typing import Any , Dict , List , Union
57
8+ import torch
69from torch_tensorrt .dynamo ._settings import CompilationSettings
10+ from torch_tensorrt .dynamo .conversion ._ConverterRegistry import ConverterRegistry
11+ from torch_tensorrt .dynamo .conversion .converter_utils import get_node_name
712
813logger = logging .getLogger (__name__ )
914
@@ -44,6 +49,7 @@ class DryRunTracker:
4449 tensorrt_graph_count (int): Number of TensorRT engines to be generated
4550 compilation_settings (CompilationSettings): User Compilation Settings
4651 unsupported_ops (Dict[str, int]): Set of operators not supported in TRT
52+ to_run_in_torch (List[str]): Set of nodes to run in Torch
4753 """
4854
4955 total_ops_in_graph : int = 0
@@ -58,9 +64,12 @@ class DryRunTracker:
5864 default_factory = CompilationSettings
5965 )
6066 unsupported_ops : Dict [str , int ] = field (default_factory = dict )
67+ to_run_in_torch : List [str ] = field (default_factory = list )
6168
6269
63- def dryrun_stats_display (dryrun_tracker : DryRunTracker , dryrun_enabled : bool ) -> None :
70+ def dryrun_stats_display (
71+ dryrun_tracker : DryRunTracker , dryrun_enabled : Union [bool , str ]
72+ ) -> None :
6473 """Displays statistics about the dryrun either to debug logs or stdout"""
6574 formatted_stats = "\n "
6675
@@ -71,7 +80,19 @@ def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) ->
7180 f"of which { dryrun_tracker .supported_ops_in_graph } operators are supported, "
7281 f"{ round (dryrun_tracker .supported_ops_in_graph * 100 / dryrun_tracker .total_ops_in_graph , 2 )} % coverage\n \n "
7382 )
74- formatted_stats += f"The following ops are currently unsupported and set to run in Torch: { dryrun_tracker .unsupported_ops } \n \n "
83+ if dryrun_tracker .unsupported_ops :
84+ parsed_ops = "\n " .join (
85+ [f"{ str (k )} : { str (v )} " for k , v in dryrun_tracker .unsupported_ops .items ()]
86+ )
87+ formatted_stats += f"The following ops are currently unsupported or excluded from conversion, and are listed with their op-count in the graph:\n { parsed_ops } \n \n "
88+
89+ if dryrun_tracker .to_run_in_torch :
90+ formatted_nodes = "\n " .join (dryrun_tracker .to_run_in_torch )
91+ formatted_stats += (
92+ f"The following nodes are currently set to run in Torch:\n { formatted_nodes } \n "
93+ "Note: Some of the above nodes may be supported, but were not included in a TRT graph by the partitioner\n \n "
94+ )
95+
7596 formatted_stats += f"Compiled with: { dryrun_tracker .compilation_settings } \n \n "
7697
7798 assert len (dryrun_tracker .per_subgraph_data ) == dryrun_tracker .tensorrt_graph_count
@@ -184,8 +205,17 @@ def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) ->
184205 )
185206
186207 # If user specified "dryrun=True", print to stdout, else debug
208+ # If user specified a filepath, save the output to the path as well
187209 if dryrun_enabled :
188210 print (formatted_stats )
211+ if isinstance (dryrun_enabled , str ):
212+ if os .path .exists (dryrun_enabled ):
213+ logger .warning (
214+ f"File already exists at path { dryrun_enabled } , not saving dryrun output"
215+ )
216+ else :
217+ with open (dryrun_enabled , "w+" ) as f :
218+ f .write (formatted_stats )
189219 else :
190220 logger .debug (formatted_stats )
191221
@@ -225,3 +255,23 @@ def input_formatter_helper(shapes: Any, dtypes: Any) -> str:
225255 )
226256
227257 return input_formatter_helper (shapes , dtypes )[:- 2 ]
258+
259+
260+ def parse_non_trt_nodes (graph_module : torch .fx .GraphModule ) -> List [str ]:
261+ """Parses call_function and call_method nodes from a GraphModule
262+ Excludes getitem nodes
263+
264+ Returns a string representation of the nodes
265+ """
266+ to_run_in_torch = []
267+ for node in graph_module .graph .nodes :
268+ # getitem nodes are excluded since they are a Tensor-collection op
269+ if (
270+ node .op in ("call_function" , "call_method" )
271+ and node .target != operator .getitem
272+ ):
273+ to_run_in_torch .append (
274+ f"Node: { ConverterRegistry .qualified_name_or_str (node .target )} , "
275+ f"with layer location: { get_node_name (node )} "
276+ )
277+ return to_run_in_torch
0 commit comments