 
 # Importing supported Backends
 import torch
+import torch_tensorrt as torchtrt
 from utils import (
     BENCHMARK_MODELS,
     parse_backends,
     precision_to_dtype,
 )
 
-import torch_tensorrt as torchtrt
-
 WARMUP_ITER = 10
 results = []
 
@@ -294,29 +293,30 @@ def run_tensorrt(
     input_tensors,
     params,
     precision,
-    is_trt_engine=False,
     batch_size=1,
 ):
-    engine = None
-
-    # If the model file is a TensorRT engine then directly deserialize and run inference
-    # else convert the torch module to a TensorRT engine first and then run inference
-    if not is_trt_engine:
-        compile_settings = {
-            "inputs": input_tensors,
-            "enabled_precisions": {precision_to_dtype(precision)},
-            "truncate_long_and_double": params.get("truncate", False),
-        }
-
-        print("Converting method to TensorRT engine...")
-        with torch.no_grad(), torchtrt.logging.errors():
-            model = torchtrt.ts.convert_method_to_trt_engine(
-                model, "forward", **compile_settings
-            )
-
+    # Export an ONNX model and convert to TRT
+    torch.onnx.export(model.eval().cuda(), tuple(input_tensors), "./tmp.onnx")
+    logger = trt.Logger(trt.Logger.WARNING)
+    builder = trt.Builder(logger)
+    network = builder.create_network(
+        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+    )
+    parser = trt.OnnxParser(network, logger)
+    success = parser.parse_from_file("./tmp.onnx")
+    if not success:
+        raise ValueError("ONNX conversion failed")
+
+    config = builder.create_builder_config()
+    if precision == "fp16":
+        config.set_flag(trt.BuilderFlag.FP16)
+    start_compile = time.time_ns()
+    serialized_engine = builder.build_serialized_network(network, config)
+    end_compile = time.time_ns()
+    compile_time_s = (end_compile - start_compile) / 1e9
     # Deserialize the TensorRT engine
-    with trt.Logger() as logger, trt.Runtime(logger) as runtime:
-        engine = runtime.deserialize_cuda_engine(model)
+    with trt.Runtime(logger) as runtime:
+        engine = runtime.deserialize_cuda_engine(serialized_engine)
 
     print("Running TensorRT for precision: ", precision, " batch_size : ", batch_size)
     iters = params.get("iterations", 20)
@@ -351,7 +351,7 @@ def run_tensorrt(
         meas_time = end_time - start_time
         timings.append(meas_time)
 
-    recordStats("TensorRT", timings, precision, batch_size)
+    recordStats("TensorRT", timings, precision, batch_size, compile_time_s)
 
 
 # Deploys inference run for different backend configurations
@@ -427,11 +427,10 @@ def run(
         )
     elif backend == "tensorrt":
         run_tensorrt(
-            model,
+            model_torch,
             input_tensors,
             params,
             precision,
-            is_trt_engine,
             batch_size,
         )
     elif backend == "dynamo":
@@ -440,9 +439,6 @@ def run(
     elif backend == "torch_compile":
         run_torch_compile(model_torch, input_tensors, params, precision, batch_size)
 
-    elif backend == "torch_compile":
-        run_torch_compile(model_torch, input_tensors, params, precision, batch_size)
-
     elif backend == "inductor":
         run_inductor(model_torch, input_tensors, params, precision, batch_size)
 
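
For context on how the deserialized engine is exercised inside the timing loop (the part of run_tensorrt this diff leaves unchanged), here is a minimal sketch of a single inference pass, assuming the pre-TensorRT-10 bindings API (num_bindings, binding_is_input, execute_v2) and PyTorch CUDA tensors supplying the device pointers. The helper name trt_infer_once and the fp16/fp32 output handling are illustrative assumptions, not code from this PR.

import tensorrt as trt
import torch

def trt_infer_once(engine, input_tensors):
    # Hypothetical helper (not from this PR): one forward pass on a
    # deserialized ICudaEngine using the TensorRT 8.x bindings API.
    context = engine.create_execution_context()
    bindings, outputs = [], []
    in_idx = 0
    for idx in range(engine.num_bindings):
        if engine.binding_is_input(idx):
            # Inputs are bound in engine order and must already live on the GPU.
            bindings.append(int(input_tensors[in_idx].data_ptr()))
            in_idx += 1
        else:
            # Allocate an output buffer matching the engine's shape and dtype.
            shape = tuple(engine.get_binding_shape(idx))
            dtype = (
                torch.float16
                if engine.get_binding_dtype(idx) == trt.DataType.HALF
                else torch.float32
            )
            out = torch.empty(shape, dtype=dtype, device="cuda")
            outputs.append(out)
            bindings.append(int(out.data_ptr()))
    context.execute_v2(bindings)
    # Synchronize so wall-clock timings around this call measure real GPU work.
    torch.cuda.synchronize()
    return outputs

With a helper along these lines, the benchmark loop can time trt_infer_once(engine, input_tensors) directly, and the newly measured compile_time_s from build_serialized_network travels alongside those timings into recordStats.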