11import logging
22import math
33from dataclasses import dataclass , field
4- from typing import List , Tuple
4+ from typing import Any , Dict , List
55
6- import torch
6+ from torch_tensorrt . dynamo . _settings import CompilationSettings
77
88logger = logging .getLogger (__name__ )
99
@@ -15,18 +15,18 @@ class PerSubgraphData:
1515 Args:
1616 subgraph_name (str): Name of the subgraph in the GraphModule
1717 subgraph_op_count (int): Number of operations in the subgraph
18- subgraph_input_shapes (List[Tuple[int, ...]] ): Shapes of input Tensors of the subgraph
19- subgraph_input_dtypes (List[torch.device] ): Input data types of the subgraph
20- subgraph_output_shapes (List[Tuple[int, ...]] ): Shapes of output Tensors of the subgraph
21- subgraph_output_dtypes (List[torch.device] ): Output data types of the subgraph
18+ subgraph_input_shapes (Any ): Shapes of input Tensors of the subgraph
19+ subgraph_input_dtypes (Any ): Input data types of the subgraph
20+ subgraph_output_shapes (Any ): Shapes of output Tensors of the subgraph
21+ subgraph_output_dtypes (Any ): Output data types of the subgraph
2222 """
2323
2424 subgraph_name : str = ""
2525 subgraph_op_count : int = 0
26- subgraph_input_shapes : List [ Tuple [ int , ...]] = field (default_factory = list )
27- subgraph_input_dtypes : List [ torch . device ] = field (default_factory = list )
28- subgraph_output_shapes : List [ Tuple [ int , ...]] = field (default_factory = list )
29- subgraph_output_dtypes : List [ torch . device ] = field (default_factory = list )
26+ subgraph_input_shapes : Any = field (default_factory = list )
27+ subgraph_input_dtypes : Any = field (default_factory = list )
28+ subgraph_output_shapes : Any = field (default_factory = list )
29+ subgraph_output_dtypes : Any = field (default_factory = list )
3030
3131
3232@dataclass
@@ -36,95 +36,86 @@ class DryRunTracker:
3636 Args:
3737 total_ops_in_graph (int): Total number of operators in graph
3838 supported_ops_in_graph (int): Number of supported operators in graph
39- graph_input_shapes (List[Tuple[int, ...]] ): Shapes of input Tensors of the graph
40- graph_input_dtypes (List[torch.device] ): Input data types of the graph
41- graph_output_shapes (List[Tuple[int, ...]] ): Shapes of output Tensors of the graph
42- graph_output_dtypes (List[torch.device] ): Output data types of the graph
39+ graph_input_shapes (Any ): Shapes of input Tensors of the graph
40+ graph_input_dtypes (Any ): Input data types of the graph
41+ graph_output_shapes (Any ): Shapes of output Tensors of the graph
42+ graph_output_dtypes (Any ): Output data types of the graph
4343 per_subgraph_data (List[PerSubgraphData]): Per-subgraph data, see above class
4444 tensorrt_graph_count (int): Number of TensorRT engines to be generated
45- truncated_long_and_double (bool): Whether truncate_long_and_double was enabled
45+ compilation_settings (CompilationSettings): User Compilation Settings
46+ unsupported_ops (Dict[str, int]): Set of operators not supported in TRT
4647 """
4748
4849 total_ops_in_graph : int = 0
4950 supported_ops_in_graph : int = 0
50- graph_input_shapes : List [ Tuple [ int , ...]] = field (default_factory = list )
51- graph_input_dtypes : List [ torch . device ] = field (default_factory = list )
52- graph_output_shapes : List [ Tuple [ int , ...]] = field (default_factory = list )
53- graph_output_dtypes : List [ torch . device ] = field (default_factory = list )
51+ graph_input_shapes : Any = field (default_factory = list )
52+ graph_input_dtypes : Any = field (default_factory = list )
53+ graph_output_shapes : Any = field (default_factory = list )
54+ graph_output_dtypes : Any = field (default_factory = list )
5455 per_subgraph_data : List [PerSubgraphData ] = field (default_factory = list )
5556 tensorrt_graph_count : int = 0
56- truncated_long_and_double : bool = False
57+ compilation_settings : CompilationSettings = field (
58+ default_factory = CompilationSettings
59+ )
60+ unsupported_ops : Dict [str , int ] = field (default_factory = dict )
5761
5862
5963def dryrun_stats_display (dryrun_tracker : DryRunTracker , dryrun_enabled : bool ) -> None :
60- """Displays statistics about the dryrun either to debug logs or info logs"""
61- # If user specified "dryrun=True", print to info logs, else debug
62- if dryrun_enabled :
63- dryrun_logger = logger .info
64- else :
65- dryrun_logger = logger .debug
66-
64+ """Displays statistics about the dryrun either to debug logs or stdout"""
6765 formatted_stats = "\n "
6866
6967 # Print overall stats about the graph, operator counts, etc.
70- formatted_stats += "+" * 50 + " Dry-Run Results for Graph " + "+" * 50 + "\n "
68+ formatted_stats += "+" * 50 + " Dry-Run Results for Graph " + "+" * 50 + "\n \n "
7169 formatted_stats += (
7270 f"The graph consists of { dryrun_tracker .total_ops_in_graph } Total Operators, "
7371 f"of which { dryrun_tracker .supported_ops_in_graph } operators are supported, "
74- f"{ round (dryrun_tracker .supported_ops_in_graph * 100 / dryrun_tracker .total_ops_in_graph , 2 )} % coverage\n "
75- )
76- formatted_stats += f"Long and double inputs were { '' if dryrun_tracker .truncated_long_and_double else 'not' } truncated (truncate_long_and_double={ dryrun_tracker .truncated_long_and_double } )\n "
77- formatted_stats += (
78- f"{ dryrun_tracker .tensorrt_graph_count } TRT Engine(s) were generated\n "
72+ f"{ round (dryrun_tracker .supported_ops_in_graph * 100 / dryrun_tracker .total_ops_in_graph , 2 )} % coverage\n \n "
7973 )
74+ formatted_stats += f"The following ops are currently unsupported and set to run in Torch: { dryrun_tracker .unsupported_ops } \n \n "
75+ formatted_stats += f"Compiled with: { dryrun_tracker .compilation_settings } \n \n "
8076
8177 assert len (dryrun_tracker .per_subgraph_data ) == dryrun_tracker .tensorrt_graph_count
8278
8379 # Print schematic of the graph structure, as in:
8480 #
85- # Inputs: [Tensor: (1, 3, 224, 224)@float32]
81+ # Inputs: List [Tensor: (1, 3, 224, 224)@float32]
8682 # ...
87- # TRT Engine #1: _run_on_acc_0
88- # Engine Inputs: [Tensor: (1, 3, 224, 224)@float32]
89- # Number of Operators in Engine: 1
90- # Engine Outputs: [ Tensor: (1, 64, 112, 112)@float32]
83+ # TRT Engine #1 - Submodule name : _run_on_acc_0
84+ # Engine Inputs: List [Tensor: (1, 3, 224, 224)@float32]
85+ # Number of Operators in Engine: 1
86+ # Engine Outputs: Tensor: (1, 64, 112, 112)@float32
9187 # ...
92- # Outputs: [Tensor: (1, 1000)@float32]
88+ # Outputs: List [Tensor: (1, 1000)@float32]
9389 #
9490 formatted_stats += " " * 2 + "Graph Structure:\n \n "
9591 formatted_stats += (
9692 " " * 3
97- + f"Inputs: [ { input_formatter (dryrun_tracker .graph_input_shapes , dryrun_tracker .graph_input_dtypes )} ] \n "
93+ + f"Inputs: { input_formatter (dryrun_tracker .graph_input_shapes , dryrun_tracker .graph_input_dtypes )} \n "
9894 )
9995
10096 for i , trt_subgraph_data in enumerate (dryrun_tracker .per_subgraph_data ):
101- assert len (trt_subgraph_data .subgraph_input_dtypes ) == len (
102- trt_subgraph_data .subgraph_input_shapes
103- )
104- assert len (trt_subgraph_data .subgraph_output_dtypes ) == len (
105- trt_subgraph_data .subgraph_output_shapes
106- )
10797 formatted_stats += " " * 4 + "...\n "
10898 formatted_stats += (
109- " " * 4 + f"TRT Engine #{ i + 1 } : { trt_subgraph_data .subgraph_name } \n "
99+ " " * 4
100+ + f"TRT Engine #{ i + 1 } - Submodule name: { trt_subgraph_data .subgraph_name } \n "
110101 )
111102 formatted_stats += (
112103 " " * 5
113- + f"Engine Inputs: [ { input_formatter (trt_subgraph_data .subgraph_input_shapes , trt_subgraph_data .subgraph_input_dtypes )} ] \n "
104+ + f"Engine Inputs: { input_formatter (trt_subgraph_data .subgraph_input_shapes , trt_subgraph_data .subgraph_input_dtypes )} \n "
114105 )
115106 formatted_stats += (
116107 " " * 5
117108 + f"Number of Operators in Engine: { trt_subgraph_data .subgraph_op_count } \n "
118109 )
119110 formatted_stats += (
120111 " " * 5
121- + f"Engine Outputs: [ { input_formatter (trt_subgraph_data .subgraph_output_shapes , trt_subgraph_data .subgraph_output_dtypes )} ] \n "
112+ + f"Engine Outputs: { input_formatter (trt_subgraph_data .subgraph_output_shapes , trt_subgraph_data .subgraph_output_dtypes )} \n "
122113 )
123114
124115 formatted_stats += " " * 4 + "...\n "
125116 formatted_stats += (
126117 " " * 3
127- + f"Outputs: [ { input_formatter (dryrun_tracker .graph_output_shapes , dryrun_tracker .graph_output_dtypes )} ] \n "
118+ + f"Outputs: { input_formatter (dryrun_tracker .graph_output_shapes , dryrun_tracker .graph_output_dtypes )} \n "
128119 )
129120
130121 # Print aggregate statistics about the graph structure, including recommended "min_block_size" options
@@ -167,23 +158,23 @@ def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) ->
167158 + " " * 3
168159 + "- For minimal graph segmentation, select min_block_size="
169160 + f"{ most_ops_in_an_engine } which would generate "
170- + f"{ len ([1 for trt_subgraph in dryrun_tracker .per_subgraph_data if trt_subgraph .subgraph_op_count >= most_ops_in_an_engine ])} TRT engines "
161+ + f"{ len ([1 for trt_subgraph in dryrun_tracker .per_subgraph_data if trt_subgraph .subgraph_op_count >= most_ops_in_an_engine ])} TRT engine(s) "
171162 )
172163 if math .ceil (avg_ops_per_engine ) != most_ops_in_an_engine :
173164 formatted_stats += (
174165 "\n "
175166 + " " * 3
176167 + "- For moderate graph segmentation, select min_block_size="
177168 + f"{ math .ceil (avg_ops_per_engine )} which would generate "
178- + f"{ len ([1 for trt_subgraph in dryrun_tracker .per_subgraph_data if trt_subgraph .subgraph_op_count >= math .ceil (avg_ops_per_engine )])} TRT engines "
169+ + f"{ len ([1 for trt_subgraph in dryrun_tracker .per_subgraph_data if trt_subgraph .subgraph_op_count >= math .ceil (avg_ops_per_engine )])} TRT engine(s) "
179170 )
180171
181172 formatted_stats += (
182173 "\n "
183174 + " " * 3
184175 + "- The current level of graph segmentation is equivalent to selecting min_block_size="
185176 + f"{ min_ops_in_an_engine } which generates "
186- + f"{ len ([1 for trt_subgraph in dryrun_tracker .per_subgraph_data if trt_subgraph .subgraph_op_count >= min_ops_in_an_engine ])} TRT engines "
177+ + f"{ len ([1 for trt_subgraph in dryrun_tracker .per_subgraph_data if trt_subgraph .subgraph_op_count >= min_ops_in_an_engine ])} TRT engine(s) "
187178 )
188179 else :
189180 formatted_stats += (
@@ -192,14 +183,45 @@ def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) ->
192183 + "Aggregate stats not available since no TRT Engines were generated."
193184 )
194185
195- dryrun_logger (formatted_stats )
186+ # If user specified "dryrun=True", print to stdout, else debug
187+ if dryrun_enabled :
188+ print (formatted_stats )
189+ else :
190+ logger .debug (formatted_stats )
196191
197192
def input_formatter(shapes: Any, dtypes: Any) -> str:
    """Format shapes and dtypes of input Tensors into a readable string.

    ``shapes`` and ``dtypes`` are walked in parallel and must therefore have
    the same nesting structure.

    Args:
        shapes: Either a single shape (a tuple containing only ints) or a
            list / tuple / dict, arbitrarily nested, whose leaves are shapes.
        dtypes: Structure mirroring ``shapes`` that holds one dtype per shape.
            For dict inputs, ``dtypes`` is indexed with the keys of ``shapes``.

    Returns:
        A single-line description such as
        ``List[Tensor: (1, 3, 224, 224)@float32]``. Lists render as
        ``List[...]``, tuples as ``Tuple(...)`` and dicts as ``Dict{...}``
        (dict keys are not included in the output). Empty containers render
        as ``List[]``, ``Tuple(...)``-style empties, or ``Dict{}``.

    Raises:
        ValueError: If ``shapes`` (or any nested element) is not a tuple,
            list or dict.
        KeyError: If ``shapes`` is a dict with a key missing from ``dtypes``.
    """
    # Base case - single shape, single dtype.
    # NOTE: an empty tuple also lands here (all() over nothing is True) and is
    # rendered as the shape "()".
    if isinstance(shapes, tuple) and all(isinstance(elt, int) for elt in shapes):
        # The first 6 characters of the dtype's string form are dropped;
        # presumably this strips the "torch." prefix (e.g. "torch.float32"
        # -> "float32") - confirm if other dtype types are ever passed in.
        return f"Tensor: {shapes}@{str(dtypes)[6:]}"

    # Shapes is a sequence: format each (shape, dtype) pair recursively.
    # zip() stops at the shorter of the two sequences.
    if isinstance(shapes, (list, tuple)):
        opener, closer = ("List[", "]") if isinstance(shapes, list) else ("Tuple(", ")")
        joined = ", ".join(
            input_formatter(shape, dtype) for shape, dtype in zip(shapes, dtypes)
        )
        return f"{opener}{joined}{closer}"

    # Shapes is a dictionary: look up the matching dtype by key.
    if isinstance(shapes, dict):
        joined = ", ".join(
            input_formatter(shape, dtypes[key]) for key, shape in shapes.items()
        )
        return "Dict{" + joined + "}"

    raise ValueError(
        f"Invalid input type {type(shapes)} encountered in input_formatter parsing."
    )
0 commit comments