1515# Importing supported Backends
1616import torch
1717import torch_tensorrt as torchtrt
18+ import torch_tensorrt .fx .tracer .acc_tracer .acc_tracer as acc_tracer
19+ from torch_tensorrt .fx import InputTensorSpec , TRTInterpreter
20+ from torch_tensorrt .fx import TRTModule
1821import tensorrt as trt
22+ from utils import parse_inputs , parse_backends , precision_to_dtype , BENCHMARK_MODELS
1923
2024WARMUP_ITER = 10
2125results = []
@@ -71,7 +75,7 @@ def run_torch(model, input_tensors, params, precision, batch_size):
7175
7276# Runs inference using Torch-TensorRT backend
7377def run_torch_tensorrt (model , input_tensors , params , precision , truncate_long_and_double , batch_size ):
74- print ("Running Torch-TensorRT" )
78+ print ("Running Torch-TensorRT for precision: " , precision )
7579 # Compiling Torch-TensorRT model
7680 compile_settings = {
7781 "inputs" : input_tensors ,
@@ -82,8 +86,8 @@ def run_torch_tensorrt(model, input_tensors, params, precision, truncate_long_an
8286 if precision == 'int8' :
8387 compile_settings .update ({"calib" : params .get ('calibration_cache' )})
8488
85-
86- model = torchtrt .compile (model , ** compile_settings )
89+ with torchtrt . logging . errors ():
90+ model = torchtrt .compile (model , ** compile_settings )
8791
8892 iters = params .get ('iterations' , 20 )
8993 # Warm up
@@ -106,6 +110,55 @@ def run_torch_tensorrt(model, input_tensors, params, precision, truncate_long_an
106110
107111 printStats ("Torch-TensorRT" , timings , precision , batch_size )
108112
# Runs inference using FX2TRT backend
def run_fx2trt(model, input_tensors, params, precision, batch_size):
    """Benchmark ``model`` through the Torch-TensorRT FX (fx2trt) path.

    The model is traced with ``acc_tracer``, lowered to a TensorRT engine by
    ``TRTInterpreter`` and wrapped in a ``TRTModule``. That module is then
    called ``WARMUP_ITER`` times untimed, followed by ``params['iterations']``
    timed calls (20 when the key is absent). Timings are handed to
    ``printStats`` under the backend label "FX-TensorRT".

    Args:
        model: Module to trace and lower.
        input_tensors: Sequence of example tensors; used for tracing, for
            deriving the input specs, and as the benchmark inputs.
        params: Mapping of run settings; only ``'iterations'`` is read here.
        precision: Either a ``LowerPrecision`` member or one of the strings
            ``'fp32'``, ``'fp16'``, ``'half'`` (alias of fp16) or ``'int8'``.
        batch_size: Forwarded as ``max_batch_size`` to the engine build and
            to ``printStats``.

    Raises:
        ValueError: If ``precision`` is not one of the accepted values.
    """
    # Function-scope import so the block brings its own dependency into scope.
    from torch_tensorrt.fx.utils import LowerPrecision

    print("Running FX2TRT for precision: ", precision)

    # TRTInterpreter.run takes a LowerPrecision enum. Forwarding the raw
    # string would never match the enum members inside the interpreter, so
    # the requested reduced precision would not be applied.
    if isinstance(precision, LowerPrecision):
        lower_precision = precision
    else:
        precision_by_name = {
            'fp32': LowerPrecision.FP32,
            'fp16': LowerPrecision.FP16,
            'half': LowerPrecision.FP16,
            'int8': LowerPrecision.INT8,
        }
        if precision not in precision_by_name:
            raise ValueError(
                "Unsupported precision for FX2TRT: {!r} (expected one of {})".format(
                    precision, ", ".join(sorted(precision_by_name))
                )
            )
        lower_precision = precision_by_name[precision]

    # Trace the model with acc_tracer.
    acc_mod = acc_tracer.trace(model, input_tensors)
    # Generate input specs
    input_specs = InputTensorSpec.from_tensors(input_tensors)
    # Build a TRT interpreter. Set explicit_batch_dimension accordingly.
    interpreter = TRTInterpreter(
        acc_mod, input_specs, explicit_batch_dimension=True
    )
    trt_interpreter_result = interpreter.run(
        max_batch_size=batch_size,
        lower_precision=lower_precision,
        max_workspace_size=1 << 25,
        sparse_weights=False,
        force_fp32_output=False,
        strict_type_constraints=False,
        algorithm_selector=None,
        timing_cache=None,
        profiling_verbosity=None,
    )

    trt_module = TRTModule(
        trt_interpreter_result.engine,
        trt_interpreter_result.input_names,
        trt_interpreter_result.output_names,
    )

    iters = params.get('iterations', 20)
    # Warm up
    with torch.no_grad():
        for _ in range(WARMUP_ITER):
            trt_module(*input_tensors)

    torch.cuda.synchronize()

    timings = []
    with torch.no_grad():
        for i in range(iters):
            start_time = timeit.default_timer()
            trt_module(*input_tensors)
            # Wait for the GPU work to finish before stopping the clock.
            torch.cuda.synchronize()
            end_time = timeit.default_timer()
            meas_time = end_time - start_time
            timings.append(meas_time)
            print("Iteration {}: {:.6f} s".format(i, meas_time))

    printStats("FX-TensorRT", timings, precision, batch_size)
161+
109162def torch_dtype_from_trt (dtype ):
110163 if dtype == trt .int8 :
111164 return torch .int8
@@ -141,19 +194,18 @@ def run_tensorrt(model, input_tensors, params, precision, is_trt_engine=False, b
141194 }
142195
143196 print ("Converting method to TensorRT engine..." )
144- with torch .no_grad ():
197+ with torch .no_grad (), torchtrt . logging . errors () :
145198 model = torchtrt .ts .convert_method_to_trt_engine (model , "forward" , ** compile_settings )
146199
147200 # Deserialize the TensorRT engine
148201 with trt .Logger () as logger , trt .Runtime (logger ) as runtime :
149202 engine = runtime .deserialize_cuda_engine (model )
150203
151- print ("Running TensorRT" )
204+ print ("Running TensorRT for precision: " , precision )
152205 iters = params .get ('iterations' , 20 )
153206
154207 # Compiling the bindings
155208 bindings = engine .num_bindings * [None ]
156- # import pdb; pdb.set_trace()
157209 k = 0
158210 for idx ,_ in enumerate (bindings ):
159211 dtype = torch_dtype_from_trt (engine .get_binding_dtype (idx ))
@@ -171,12 +223,12 @@ def run_tensorrt(model, input_tensors, params, precision, is_trt_engine=False, b
171223 timings = []
172224 with engine .create_execution_context () as context :
173225 for i in range (WARMUP_ITER ):
174- context .execute_async ( 1 , bindings , torch .cuda .current_stream ().cuda_stream )
226+ context .execute_async_v2 ( bindings , torch .cuda .current_stream ().cuda_stream )
175227 torch .cuda .synchronize ()
176228
177229 for i in range (iters ):
178230 start_time = timeit .default_timer ()
179- context .execute_async ( 1 , bindings , torch .cuda .current_stream ().cuda_stream )
231+ context .execute_async_v2 ( bindings , torch .cuda .current_stream ().cuda_stream )
180232 torch .cuda .synchronize ()
181233 end_time = timeit .default_timer ()
182234 meas_time = end_time - start_time
@@ -186,9 +238,8 @@ def run_tensorrt(model, input_tensors, params, precision, is_trt_engine=False, b
186238 printStats ("TensorRT" , timings , precision , batch_size )
187239
188240# Deploys inference run for different backend configurations
189- def run (model , input_tensors , params , precision , truncate_long_and_double = False , batch_size = 1 , is_trt_engine = False ):
190- for backend in params .get ('backend' ):
191-
241+ def run (model , backends , input_tensors , params , precision , truncate_long_and_double = False , batch_size = 1 , is_trt_engine = False ):
242+ for backend in backends :
192243 if precision == 'int8' :
193244 if backend == 'all' or backend == 'torch' :
194245 print ("int8 precision is not supported for torch runtime in this script yet" )
@@ -201,7 +252,6 @@ def run(model, input_tensors, params, precision, truncate_long_and_double = Fals
201252 if backend == 'all' :
202253 run_torch (model , input_tensors , params , precision , batch_size )
203254 run_torch_tensorrt (model , input_tensors , params , precision , truncate_long_and_double , batch_size )
204- # import pdb; pdb.set_trace()
205255 run_tensorrt (model , input_tensors , params , precision , is_trt_engine , batch_size )
206256
207257 elif backend == "torch" :
@@ -210,6 +260,9 @@ def run(model, input_tensors, params, precision, truncate_long_and_double = Fals
210260 elif backend == "torch_tensorrt" :
211261 run_torch_tensorrt (model , input_tensors , params , precision , truncate_long_and_double , batch_size )
212262
263+ elif backend == "fx2trt" :
264+ run_fx2trt (model , input_tensors , params , precision , batch_size )
265+
213266 elif backend == "tensorrt" :
214267 run_tensorrt (model , input_tensors , params , precision , is_trt_engine , batch_size )
215268
@@ -246,14 +299,6 @@ def printStats(backend, timings, precision, batch_size = 1):
246299 }
247300 results .append (meas )
248301
249- def precision_to_dtype (pr ):
250- if pr == 'fp32' :
251- return torch .float
252- elif pr == 'fp16' or pr == 'half' :
253- return torch .half
254- else :
255- return torch .int8
256-
257302def load_model (params ):
258303 model = None
259304 is_trt_engine = False
@@ -272,47 +317,68 @@ def load_model(params):
272317
273318if __name__ == '__main__' :
274319 arg_parser = argparse .ArgumentParser (description = "Run inference on a model with random input values" )
275- arg_parser .add_argument ("--config" , help = "Load YAML based configuration file to run the inference. If this is used other params will be ignored" )
320+ arg_parser .add_argument ("--config" , type = str , help = "Load YAML based configuration file to run the inference. If this is used other params will be ignored" )
321+ # The following options are manual user provided settings
322+ arg_parser .add_argument ("--backends" , type = str , help = "Comma separated string of backends. Eg: torch,torch_tensorrt,tensorrt" )
323+ arg_parser .add_argument ("--model" , type = str , help = "Name of the model file" )
324+ arg_parser .add_argument ("--inputs" , type = str , help = "List of input shapes. Eg: (1, 3, 224, 224)@fp32 for Resnet or (1, 128)@int32;(1, 128)@int32 for BERT" )
325+ arg_parser .add_argument ("--batch_size" , type = int , default = 1 , help = "Batch size" )
326+ arg_parser .add_argument ("--precision" , default = "fp32" , type = str , help = "Precision of TensorRT engine" )
327+ arg_parser .add_argument ("--device" , type = int , help = "device id" )
328+ arg_parser .add_argument ("--truncate" , action = 'store_true' , help = "Truncate long and double weights in the network" )
329+ arg_parser .add_argument ("--is_trt_engine" , action = 'store_true' , help = "Boolean flag to determine if the user provided model is a TRT engine or not" )
276330 args = arg_parser .parse_args ()
277331
278- parser = ConfigParser (args .config )
279- # Load YAML params
280- params = parser .read_config ()
281- print ("Loading model: " , params .get ('model' ).get ('filename' ))
282-
283- model = None
284-
285- # Default device is set to 0. Configurable using yaml config file.
286- torch .cuda .set_device (params .get ('runtime' ).get ('device' , 0 ))
287-
288- # Load the model file from disk. If the loaded file is TensorRT engine then is_trt_engine is returned as True
289- model , is_trt_engine = load_model (params )
290332 cudnn .benchmark = True
291-
292333 # Create random input tensor of certain size
293334 torch .manual_seed (12345 )
294335
295- num_input = params .get ('input' ).get ('num_inputs' )
296- truncate_long_and_double = params .get ('runtime' ).get ('truncate_long_and_double' , False )
297- batch_size = params .get ('input' ).get ('batch_size' , 1 )
298- for precision in params .get ('runtime' ).get ('precision' , 'fp32' ):
299- input_tensors = []
300- num_input = params .get ('input' ).get ('num_inputs' , 1 )
301- for i in range (num_input ):
302- inp_tensor = params .get ('input' ).get ('input' + str (i ))
303- input_tensors .append (torch .randint (0 , 2 , tuple (d for d in inp_tensor ), dtype = precision_to_dtype (precision )).cuda ())
304-
305- if is_trt_engine :
306- print ("Warning, TensorRT engine file is configured. Please make sure the precision matches with the TRT engine for reliable results" )
307-
308- if not is_trt_engine and precision == "fp16" or precision == "half" :
309- # If model is TensorRT serialized engine then model.half will report failure
310- model = model .half ()
311-
336+ if args .config :
337+ parser = ConfigParser (args .config )
338+ # Load YAML params
339+ params = parser .read_config ()
340+ print ("Loading model: " , params .get ('model' ).get ('filename' ))
341+ model_file = params .get ('model' ).get ('filename' )
342+ # Default device is set to 0. Configurable using yaml config file.
343+ torch .cuda .set_device (params .get ('runtime' ).get ('device' , 0 ))
344+
345+ num_input = params .get ('input' ).get ('num_inputs' )
346+ truncate_long_and_double = params .get ('runtime' ).get ('truncate_long_and_double' , False )
347+ batch_size = params .get ('input' ).get ('batch_size' , 1 )
348+ for precision in params .get ('runtime' ).get ('precision' , 'fp32' ):
349+ input_tensors = []
350+ num_input = params .get ('input' ).get ('num_inputs' , 1 )
351+ for i in range (num_input ):
352+ inp_tensor = params .get ('input' ).get ('input' + str (i ))
353+ input_tensors .append (torch .randint (0 , 2 , tuple (d for d in inp_tensor ), dtype = precision_to_dtype (precision )).cuda ())
354+
355+ if is_trt_engine :
356+ print ("Warning, TensorRT engine file is configured. Please make sure the precision matches with the TRT engine for reliable results" )
357+
358+ if not is_trt_engine and precision == "fp16" or precision == "half" :
359+ # If model is TensorRT serialized engine then model.half will report failure
360+ model = model .half ()
361+ backends = params .get ('backend' )
362+ # Run inference
363+ status = run (model , backends , input_tensors , params , precision , truncate_long_and_double , batch_size , is_trt_engine )
364+ else :
365+ params = vars (args )
366+ model_name = params ['model' ]
367+ if os .path .exists (model_name ):
368+ print ("Loading user provided model: " , model_name )
369+ model = torch .jit .load (model_name ).cuda ().eval ()
370+ elif model_name in BENCHMARK_MODELS :
371+ model = BENCHMARK_MODELS [model_name ]['model' ].eval ().cuda ()
372+ else :
373+ raise ValueError ("Invalid model name. Please provide a torchscript model file or model name (among the following options vgg16|resnet50|efficientnet_b0|vit)" )
374+ precision = params ['precision' ]
375+ input_tensors = parse_inputs (params ['inputs' ])
376+ backends = parse_backends (params ['backends' ])
377+ truncate_long_and_double = params .get ('truncate' , False )
378+ batch_size = params ['batch_size' ]
379+ is_trt_engine = params ['is_trt_engine' ]
312380 # Run inference
313- status = run (model , input_tensors , params , precision , truncate_long_and_double , batch_size , is_trt_engine )
314- if status == False :
315- continue
381+ status = run (model , backends , input_tensors , params , precision , truncate_long_and_double , batch_size , is_trt_engine )
316382
317383 # Generate report
318384 print ('Model Summary:' )
0 commit comments