@@ -4,6 +4,7 @@
 from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set
 
 import numpy as np
+import tensorrt as trt
 import torch
 import torch.fx
 from torch.fx.node import _get_qualified_name
@@ -23,8 +24,6 @@
 from torch_tensorrt.fx.observer import Observer
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
 
-# @manual=//deeplearning/trt/python:py_tensorrt
-import tensorrt as trt
 from packaging import version
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -96,6 +95,7 @@ def __init__(
         self._itensor_to_tensor_meta: Dict[
             trt.tensorrt.ITensor, TensorMetadata
         ] = dict()
+        self.compilation_settings = compilation_settings
 
         # Data types for TRT Module output Tensors
         self.output_dtypes = output_dtypes
@@ -118,40 +118,25 @@ def validate_conversion(self) -> Set[str]:
 
     def run(
         self,
-        workspace_size: int = 0,
-        precision: torch.dtype = torch.float32,  # TODO: @peri044 Needs to be expanded to set
-        sparse_weights: bool = False,
-        disable_tf32: bool = False,
         force_fp32_output: bool = False,
         strict_type_constraints: bool = False,
         algorithm_selector: Optional[trt.IAlgorithmSelector] = None,
         timing_cache: Optional[trt.ITimingCache] = None,
-        profiling_verbosity: Optional[trt.ProfilingVerbosity] = None,
         tactic_sources: Optional[int] = None,
-        max_aux_streams: Optional[int] = None,
-        version_compatible: bool = False,
-        optimization_level: Optional[int] = None,
     ) -> TRTInterpreterResult:
         """
         Build TensorRT engine with some configs.
         Args:
-            workspace_size: Amount of memory used by TensorRT to store intermediate buffers within an operation.
-            precision: the precision model layers are running on (TensorRT will choose the best perforamnce precision).
-            sparse_weights: allow the builder to examine weights and use optimized functions when weights have suitable sparsity
             force_fp32_output: force output to be fp32
             strict_type_constraints: Usually we should set it to False unless we want to control the precision of certain layer for numeric reasons.
             algorithm_selector: set up algorithm selection for certain layer
             timing_cache: enable timing cache for TensorRT
-            profiling_verbosity: TensorRT logging level
-            max_aux_streams: Maximum number of allowed auxiliary TRT streams for each engine
-            version_compatible: Provide version forward-compatibility for engine plan files
-            optimization_level: Builder optimization 0-5, higher levels imply longer build time,
-                searching for more optimization options. TRT defaults to 3
         Return:
             TRTInterpreterResult
         """
         TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module)
 
+        precision = self.compilation_settings.precision
         # For float outputs, we set their dtype to fp16 only if precision == torch.float16 and
         # force_fp32_output=False. Overriden by specifying output_dtypes
         self.output_fp16 = not force_fp32_output and precision == torch.float16
@@ -172,9 +157,9 @@ def run(
 
         builder_config = self.builder.create_builder_config()
 
-        if workspace_size != 0:
+        if self.compilation_settings.workspace_size != 0:
             builder_config.set_memory_pool_limit(
-                trt.MemoryPoolType.WORKSPACE, workspace_size
+                trt.MemoryPoolType.WORKSPACE, self.compilation_settings.workspace_size
             )
 
         cache = None
@@ -187,34 +172,66 @@ def run(
 
         if version.parse(trt.__version__) >= version.parse("8.2"):
             builder_config.profiling_verbosity = (
-                profiling_verbosity
-                if profiling_verbosity
+                trt.ProfilingVerbosity.VERBOSE
+                if self.compilation_settings.debug
                 else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
             )
 
         if version.parse(trt.__version__) >= version.parse("8.6"):
-            if max_aux_streams is not None:
-                _LOGGER.info(f"Setting max aux streams to {max_aux_streams}")
-                builder_config.max_aux_streams = max_aux_streams
-            if version_compatible:
+            if self.compilation_settings.max_aux_streams is not None:
+                _LOGGER.info(
+                    f"Setting max aux streams to {self.compilation_settings.max_aux_streams}"
+                )
+                builder_config.max_aux_streams = (
+                    self.compilation_settings.max_aux_streams
+                )
+            if self.compilation_settings.version_compatible:
                 _LOGGER.info("Using version compatible")
                 builder_config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE)
-            if optimization_level is not None:
-                _LOGGER.info(f"Using optimization level {optimization_level}")
-                builder_config.builder_optimization_level = optimization_level
+            if self.compilation_settings.optimization_level is not None:
+                _LOGGER.info(
+                    f"Using optimization level {self.compilation_settings.optimization_level}"
+                )
+                builder_config.builder_optimization_level = (
+                    self.compilation_settings.optimization_level
+                )
+
+        builder_config.engine_capability = self.compilation_settings.engine_capability
+        builder_config.avg_timing_iterations = (
+            self.compilation_settings.num_avg_timing_iters
+        )
+
+        if self.compilation_settings.device.device_type == trt.DeviceType.DLA:
+            builder_config.DLA_core = self.compilation_settings.device.dla_core
+            _LOGGER.info(f"Using DLA core {self.compilation_settings.device.dla_core}")
+            builder_config.set_memory_pool_limit(
+                trt.MemoryPoolType.DLA_MANAGED_SRAM,
+                self.compilation_settings.dla_sram_size,
+            )
+            builder_config.set_memory_pool_limit(
+                trt.MemoryPoolType.DLA_LOCAL_DRAM,
+                self.compilation_settings.dla_local_dram_size,
+            )
+            builder_config.set_memory_pool_limit(
+                trt.MemoryPoolType.DLA_GLOBAL_DRAM,
+                self.compilation_settings.dla_global_dram_size,
+            )
 
         if precision == torch.float16:
             builder_config.set_flag(trt.BuilderFlag.FP16)
 
         if precision == torch.int8:
             builder_config.set_flag(trt.BuilderFlag.INT8)
 
-        if sparse_weights:
+        if self.compilation_settings.sparse_weights:
             builder_config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS)
 
-        if disable_tf32:
+        if self.compilation_settings.disable_tf32:
             builder_config.clear_flag(trt.BuilderFlag.TF32)
 
+        if self.compilation_settings.refit:
+            builder_config.set_flag(trt.BuilderFlag.REFIT)
+
         if strict_type_constraints:
             builder_config.set_flag(trt.BuilderFlag.STRICT_TYPES)
 
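Note on the change above: build-time options no longer travel into run() as individual arguments; they are read from the compilation_settings object attached to the interpreter. The sketch below illustrates that call pattern. It is only an approximation: the field names mirror the attributes referenced in the diff, but this dataclass and the TRTInterpreter constructor call are assumptions for illustration, not the library's exact definitions.

from dataclasses import dataclass
from typing import Optional

import torch


# Hypothetical stand-in for the settings object consumed in run(); the real
# CompilationSettings in torch_tensorrt carries additional fields referenced
# above (device, dla_*_size, engine_capability, num_avg_timing_iters, ...).
@dataclass
class CompilationSettings:
    precision: torch.dtype = torch.float32
    debug: bool = False
    workspace_size: int = 0
    max_aux_streams: Optional[int] = None
    version_compatible: bool = False
    optimization_level: Optional[int] = None
    sparse_weights: bool = False
    disable_tf32: bool = False
    refit: bool = False


# Engine flags now come from the settings object rather than run() kwargs.
settings = CompilationSettings(precision=torch.float16, workspace_size=1 << 30)

# Assumed construction/usage -- the actual TRTInterpreter signature may differ:
# interpreter = TRTInterpreter(module, input_specs, compilation_settings=settings)
# result = interpreter.run()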