diff --git a/torchbenchmark/util/backends/trt.py b/torchbenchmark/util/backends/trt.py
index 98bbaf1942..8091769bd3 100644
--- a/torchbenchmark/util/backends/trt.py
+++ b/torchbenchmark/util/backends/trt.py
@@ -1,26 +1,68 @@
 from typing import List
 import torch
+import argparse
 
-from torchbenchmark.util.backends import create_backend 
+from torchbenchmark.util.backends import create_backend
 from torchbenchmark.util.env_check import is_hf_model
 
+
+def parse_torch_trt_args(backend_args: List[str]):
+    """Parses CLI-provided backend arguments to extract Torch-TRT keywords
+
+    Returns kwargs dictionary and remainder arguments which were unrecognized
+    """
+    arg_parser = argparse.ArgumentParser()
+    arg_parser.add_argument(
+        "--truncate_long_and_double",
+        default=None,
+        action="store_true",
+        help="Whether to automatically truncate long and double operations",
+    )
+    arg_parser.add_argument(
+        "--workspace_size", type=int, help="Size of workspace allotted to TensorRT"
+    )
+    arg_parser.add_argument(
+        "--min_block_size",
+        type=int,
+        help="Minimum number of operations in an accelerated TRT block",
+    )
+    arg_parser.add_argument(
+        "--ir",
+        type=str,
+        help="Which internal representation to use: {'ts', 'dynamo_compile', 'fx_ts_compat', ...}",
+    )
+    args, unknown = arg_parser.parse_known_args(backend_args)
+
+    # Remove unspecified arguments from the args dictionary
+    # (Only pass through user-specified args)
+    parsed_args = vars(args)
+    for key in list(parsed_args.keys()):
+        if parsed_args[key] is None:
+            del parsed_args[key]
+
+    return parsed_args, unknown
+
+
 @create_backend
-def fx2trt(model: 'torchbenchmark.util.model.BenchmarkModel', backend_args: List[str]):
+def fx2trt(model: "torchbenchmark.util.model.BenchmarkModel", backend_args: List[str]):
     FP16 = True if model.dargs.precision == "fp16" else False
     HF_MODEL = True if is_hf_model(model) else False
-    assert model.device == "cuda" and model.test == "eval", f"fx2trt only works on CUDA inference tests."
+    assert (
+        model.device == "cuda" and model.test == "eval"
+    ), f"fx2trt only works on CUDA inference tests."
+
     def _fx2trt():
         from torch_tensorrt.fx import compile
         from torch_tensorrt.fx.utils import LowerPrecision
+
         module, example_inputs = model.get_module()
         precision = LowerPrecision.FP16 if FP16 else LowerPrecision.FP32
 
         if HF_MODEL:
             from transformers.utils.fx import symbolic_trace as hf_symbolic_trace
+
             traced_model = hf_symbolic_trace(
-                module,
-                batch_size = model.batch_size,
-                sequence_lenghth = model.max_length
+                module, batch_size=model.batch_size, sequence_length=model.max_length
             )
             trt_model = compile(
                 traced_model,
@@ -31,27 +73,58 @@ def _fx2trt():
                 max_workspace_size=20 << 30,
             )
         else:
-            trt_model = compile(module=module,
-                                input=example_inputs,
-                                max_batch_size=model.batch_size,
-                                lower_precision=precision)
+            trt_model = compile(
+                module=module,
+                input=example_inputs,
+                max_batch_size=model.batch_size,
+                lower_precision=precision,
+            )
         model.set_module(trt_model)
+
     return _fx2trt, backend_args
 
+
 @create_backend
-def torch_trt(model: 'torchbenchmark.util.model.BenchmarkModel', backend_args: List[str]):
+def torch_trt(
+    model: "torchbenchmark.util.model.BenchmarkModel", backend_args: List[str]
+):
+    """Backend for Torch-TRT
+
+    Can be directly invoked from the command line, for example via:
+    python run.py resnet18 -d cuda -t eval --backend torch_trt --precision fp32 --truncate_long_and_double
+
+    Options include:
+        --truncate_long_and_double: Whether to automatically truncate long and double operations
+        --min_block_size: Minimum number of operations in an accelerated TRT block
+        --workspace_size: Size of workspace allotted to TensorRT
+        --ir: Which internal representation to use: {"ts", "dynamo_compile", "fx_ts_compat", ...}
+    """
     FP16 = True if model.dargs.precision == "fp16" else False
-    assert model.device == "cuda" and model.test == "eval", f"fx2trt only works on CUDA inference tests."
+    assert (
+        model.device == "cuda" and model.test == "eval"
+    ), f"Torch-TRT only works on CUDA inference tests."
+
+    # Extract relevant Torch-TRT arguments from the provided CLI arguments
+    torch_trt_kwargs, backend_args = parse_torch_trt_args(backend_args)
+
     def _torch_trt():
+        """Helper function for invoking Torch-TRT"""
         import torch_tensorrt
+
         module, example_inputs = model.get_module()
-        if FP16:
-            torchtrt_dtype = torch_tensorrt.dtype.half
-            torch_dtype = torch.half
-        else:
-            torchtrt_dtype = torch_tensorrt.dtype.float
-            torch_dtype = torch.float32
-        trt_input = [torch_tensorrt.Input(shape=example_inputs[0].shape, dtype=torch_dtype)]
-        trt_module = torch_tensorrt.compile(module, inputs=trt_input, enabled_precisions=torchtrt_dtype)
+        torch_dtype_precision = torch.half if FP16 else torch.float32
+
+        print(
+            f"Compiling {model.name} with batch size {model.batch_size}, precision {model.dargs.precision}, "
+            + f"and {'default' if 'ir' not in torch_trt_kwargs else torch_trt_kwargs['ir']} IR"
+        )
+
+        trt_module = torch_tensorrt.compile(
+            module,
+            inputs=example_inputs,
+            enabled_precisions={torch_dtype_precision},
+            **torch_trt_kwargs,
+        )
         model.set_module(trt_module)
+
     return _torch_trt, backend_args
diff --git a/userbenchmark/torch_trt/__init__.py b/userbenchmark/torch_trt/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/userbenchmark/torch_trt/run.py b/userbenchmark/torch_trt/run.py
new file mode 100644
index 0000000000..8547bee176
--- /dev/null
+++ b/userbenchmark/torch_trt/run.py
@@ -0,0 +1,223 @@
+import argparse
+import traceback
+import torch
+
+import numpy as np
+
+import json
+import os
+import time
+from datetime import datetime
+from typing import List
+
+from torchbenchmark import (
+    load_canary_model_by_name,
+    load_model_by_name,
+    list_models,
+    ModelNotFoundError,
+)
+
+
+def cli(args: List[str]):
+    """Parse input arguments, extracting model specification and batch size"""
+    arg_parser = argparse.ArgumentParser(args)
+    arg_parser.add_argument(
+        "--model",
+        help="Full or partial name of a model to run. If partial, picks the first match.",
+        default="",
+        type=str,
+    )
+    arg_parser.add_argument(
+        "--bs",
+        help="Input batch size to test.",
+        default=1,
+        type=int,
+    )
+    arg_parser.add_argument(
+        "--num_warmup",
+        help="Number of inference warmup iterations.",
+        default=10,
+        type=int,
+    )
+    arg_parser.add_argument(
+        "--num_iter",
+        help="Number of inference iterations for benchmarking.",
+        default=100,
+        type=int,
+    )
+    parsed_args, unknown = arg_parser.parse_known_args()
+
+    return vars(parsed_args), unknown
+
+
+def save_metrics(metrics):
+    """Save metrics to a JSON file with formatted filename"""
+    metrics_json = {
+        "name": "torch_trt",
+        "environ": {
+            "metrics_version": "v0.1",
+            "pytorch_git_version": torch.version.git_version,
+        },
+        "metrics": metrics,
+    }
+
+    # Obtain target save directory for JSON metrics from current save directory
+    current_dir = os.path.dirname(os.path.abspath(__file__))
+    target_dir = os.path.normpath(
+        os.path.join(current_dir, "../../.userbenchmark/torch_trt/")
+    )
+
+    os.makedirs(target_dir, exist_ok=True)
+
+    # Format filename and path to save metrics
+    metrics_file = "metrics-{}.json".format(
+        datetime.fromtimestamp(time.time()).strftime("%Y%m%d%H%M%S")
+    )
+    metrics_save_path = os.path.join(target_dir, metrics_file)
+
+    with open(metrics_save_path, "w") as f:
+        json.dump(metrics_json, f, indent=4)
+
+
+def run_single_model(
+    Model,
+    batch_size: int,
+    extra_args: List[str],
+    selected_ir: str,
+    num_warmup: int,
+    num_iter: int,
+):
+    """Run inference benchmarking on a single model"""
+    # Build TorchBench model instance, with backend having the userbenchmark name
+    # This invokes the torch_trt backend functionality directly
+    model = Model(
+        device="cuda",
+        test="eval",
+        jit=False,
+        batch_size=batch_size,
+        extra_args=[
+            "--backend",
+        ]
+        + extra_args,
+    )
+
+    metrics = run_one_step(model.invoke, model, num_warmup, num_iter, selected_ir)
+
+    # Record dynamo compilation metrics, if there are any.
+    try:
+        if model.pt2_compilation_time:
+            metrics[
+                f"{model.name}.bs_{model.batch_size}.precision_{model.dargs.precision}."
+                + f"ir_{selected_ir}.pt2_compilation_time"
+            ] = model.pt2_compilation_time
+    except:
+        pass
+
+    return metrics
+
+
+def run_one_step(func, model, num_warmup, num_iter, selected_ir):
+    # Warmup model inference
+    for _ in range(num_warmup):
+        func()
+
+    result_summary = []
+
+    # Run inference for the specified number of iterations
+    for _ in range(num_iter):
+        torch.cuda.synchronize()
+        start_event = torch.cuda.Event(enable_timing=True)
+        end_event = torch.cuda.Event(enable_timing=True)
+
+        # Use time_ns() instead of time(), which does not provide better precision than 1
+        # second according to https://docs.python.org/3/library/time.html#time.time.
+        t0 = time.time_ns()
+        start_event.record()
+        func()
+        end_event.record()
+        torch.cuda.synchronize()
+        t1 = time.time_ns()
+        result_summary.append(
+            (start_event.elapsed_time(end_event), (t1 - t0) / 1_000_000)
+        )
+
+    # Get median times for GPU time and CPU walltime
+    gpu_time = np.median(list(map(lambda x: x[0], result_summary)))
+    cpu_walltime = np.median(list(map(lambda x: x[1], result_summary)))
+
+    if hasattr(model, "NUM_BATCHES"):
+        median_gpu_time_per_batch = gpu_time / model.NUM_BATCHES
+        median_cpu_walltime_per_batch = cpu_walltime / model.NUM_BATCHES
+    else:
+        median_gpu_time_per_batch = gpu_time
+        median_cpu_walltime_per_batch = cpu_walltime
+
+    metrics = {
+        f"{model.name}.bs_{model.batch_size}.precision_{model.dargs.precision}."
+ + f"ir_{selected_ir}.median_gpu_time_per_batch": median_gpu_time_per_batch, + f"{model.name}.bs_{model.batch_size}.precision_{model.dargs.precision}." + + f"ir_{selected_ir}.median_cpu_walltime_per_batch": median_cpu_walltime_per_batch, + } + + return metrics + + +def run(args: List[str]): + """Run inference and extract requested metrics""" + parsed_args, unknown_args = cli(args) + + # Attempt to extract specified IR for logging purposes + try: + ir_idx = unknown_args.index("--ir") + selected_ir = unknown_args[ir_idx + 1] + except (ValueError, IndexError): + selected_ir = "default" + + # Parse model string if specified, otherwise run all models + # Adapted from benchmark/run.py + if parsed_args["model"]: + try: + Model = load_model_by_name(parsed_args["model"]) + except ModuleNotFoundError: + traceback.print_exc() + exit(-1) + except ModelNotFoundError: + print( + f"Warning: The model {parsed_args['model']} cannot be found at core set." + ) + if not Model: + try: + Model = load_canary_model_by_name(parsed_args["model"]) + except ModuleNotFoundError: + traceback.print_exc() + exit(-1) + except ModelNotFoundError: + print( + f"Error: The model {parsed_args['model']} cannot be found at either core or canary model set." + ) + exit(-1) + + all_metrics = run_single_model( + Model, + parsed_args["bs"], + unknown_args, + selected_ir, + parsed_args["num_warmup"], + parsed_args["num_iter"], + ) + + else: + all_metrics = {} + + for Model in list_models(): + metrics = run_single_model( + Model, + parsed_args["bs"], + unknown_args, + selected_ir, + parsed_args["num_warmup"], + parsed_args["num_iter"], + ) + all_metrics = {**all_metrics, **metrics} + + save_metrics(all_metrics)