@@ -293,29 +293,30 @@ def run_tensorrt(
293293 input_tensors ,
294294 params ,
295295 precision ,
296- is_trt_engine = False ,
297296 batch_size = 1 ,
298297):
299- engine = None
300-
301- # If the model file is a TensorRT engine then directly deserialize and run inference
302- # else convert the torch module to a TensorRT engine first and then run inference
303- if not is_trt_engine :
304- compile_settings = {
305- "inputs" : input_tensors ,
306- "enabled_precisions" : {precision_to_dtype (precision )},
307- "truncate_long_and_double" : params .get ("truncate" , False ),
308- }
309-
310- print ("Converting method to TensorRT engine..." )
311- with torch .no_grad (), torchtrt .logging .errors ():
312- model = torchtrt .ts .convert_method_to_trt_engine (
313- model , "forward" , ** compile_settings
314- )
315-
298+ # Export an ONNX model and convert to TRT
299+ torch .onnx .export (model .eval ().cuda (), tuple (input_tensors ), "./tmp.onnx" )
300+ logger = trt .Logger (trt .Logger .WARNING )
301+ builder = trt .Builder (logger )
302+ network = builder .create_network (
303+ 1 << int (trt .NetworkDefinitionCreationFlag .EXPLICIT_BATCH )
304+ )
305+ parser = trt .OnnxParser (network , logger )
306+ success = parser .parse_from_file ("./tmp.onnx" )
307+ if not success :
308+ raise ValueError ("ONNX conversion failed" )
309+
310+ config = builder .create_builder_config ()
311+ if precision == "fp16" :
312+ config .set_flag (trt .BuilderFlag .FP16 )
313+ start_compile = time .time_ns ()
314+ serialized_engine = builder .build_serialized_network (network , config )
315+ end_compile = time .time_ns ()
316+ compile_time_s = (end_compile - start_compile ) / 1e9
316317 # Deserialize the TensorRT engine
317- with trt .Logger () as logger , trt . Runtime (logger ) as runtime :
318- engine = runtime .deserialize_cuda_engine (model )
318+ with trt .Runtime (logger ) as runtime :
319+ engine = runtime .deserialize_cuda_engine (serialized_engine )
319320
320321 print ("Running TensorRT for precision: " , precision , " batch_size : " , batch_size )
321322 iters = params .get ("iterations" , 20 )
@@ -350,7 +351,7 @@ def run_tensorrt(
350351 meas_time = end_time - start_time
351352 timings .append (meas_time )
352353
353- recordStats ("TensorRT" , timings , precision , batch_size )
354+ recordStats ("TensorRT" , timings , precision , batch_size , compile_time_s )
354355
355356
356357# Deploys inference run for different backend configurations
@@ -426,11 +427,10 @@ def run(
426427 )
427428 elif backend == "tensorrt" :
428429 run_tensorrt (
429- model ,
430+ model_torch ,
430431 input_tensors ,
431432 params ,
432433 precision ,
433- is_trt_engine ,
434434 batch_size ,
435435 )
436436 elif backend == "dynamo" :
@@ -439,9 +439,6 @@ def run(
439439 elif backend == "torch_compile" :
440440 run_torch_compile (model_torch , input_tensors , params , precision , batch_size )
441441
442- elif backend == "torch_compile" :
443- run_torch_compile (model_torch , input_tensors , params , precision , batch_size )
444-
445442 elif backend == "inductor" :
446443 run_inductor (model_torch , input_tensors , params , precision , batch_size )
447444
0 commit comments