diff --git a/tools/perf/README.md b/tools/perf/README.md index 4c4a58bfd0..45630b4f29 100644 --- a/tools/perf/README.md +++ b/tools/perf/README.md @@ -4,7 +4,9 @@ This is a comprehensive Python benchmark suite to run perf runs using different 1. Torch 2. Torch-TensorRT -3. TensorRT +3. FX-TRT +4. TensorRT + Note: Please note that for ONNX models, user can convert the ONNX model to TensorRT serialized engine and then use this package. @@ -25,21 +27,35 @@ Benchmark scripts depends on following Python packages in addition to requiremen │ └── vgg16.yml ├── models ├── perf_run.py +├── hub.py +├── custom_models.py +├── requirements.txt +├── benchmark.sh └── README.md ``` -Please save your configuration files at config directory. Similarly, place your model files at models path. + + +* `config` - Directory which contains sample yaml configuration files for VGG network. +* `models` - Model directory +* `perf_run.py` - Performance benchmarking script which supports torch, torch_tensorrt, fx2trt, tensorrt backends +* `hub.py` - Script to download torchscript models for VGG16, Resnet50, EfficientNet-B0, VIT, HF-BERT +* `custom_models.py` - Script which includes custom models other than torchvision and timm (eg: HF BERT) +* `utils.py` - utility functions script +* `benchmark.sh` - This is used for internal performance testing of VGG16, Resnet50, EfficientNet-B0, VIT, HF-BERT. ## Usage +There are two ways you can run a performance benchmark. + +### Using YAML config files + To run the benchmark for a given configuration file: -``` +```python python perf_run.py --config=config/vgg16.yml ``` -## Configuration - There are two sample configuration files added. * vgg16.yml demonstrates a configuration with all the supported backends (Torch, Torch-TensorRT, TensorRT) @@ -48,23 +64,17 @@ There are two sample configuration files added. ### Supported fields -| Name | Supported Values | Description | -| --- | --- | --- | -| backend | all, torch, torch_tensorrt, tensorrt | Supported backends for inference. | -| input | - | Input binding names. Expected to list shapes of each input bindings | -| model | - | Configure the model filename and name | -| filename | - | Model file name to load from disk. | -| name | - | Model name | -| runtime | - | Runtime configurations | -| device | 0 | Target device ID to run inference. Range depends on available GPUs | -| precision | fp32, fp16 or half, int8 | Target precision to run inference. int8 cannot be used with 'all' backend | -| calibration_cache | - | Calibration cache file expected for torch_tensorrt runtime in int8 precision | - -Note: -1. Please note that torch runtime perf is not supported for int8 yet. -2. Torchscript module filename should end with .jit.pt otherwise it will be treated as a TensorRT engine. - - +| Name | Supported Values | Description | +| ----------------- | ------------------------------------ | ------------------------------------------------------------ | +| backend | all, torch, torch_tensorrt, tensorrt | Supported backends for inference. | +| input | - | Input binding names. Expected to list shapes of each input bindings | +| model | - | Configure the model filename and name | +| filename | - | Model file name to load from disk. | +| name | - | Model name | +| runtime | - | Runtime configurations | +| device | 0 | Target device ID to run inference. Range depends on available GPUs | +| precision | fp32, fp16 or half, int8 | Target precision to run inference. 
int8 cannot be used with 'all' backend | +| calibration_cache | - | Calibration cache file expected for torch_tensorrt runtime in int8 precision | Additional sample use case: @@ -88,3 +98,41 @@ runtime: - fp32 - fp16 ``` + +Note: + +1. Please note that measuring INT8 performance is only supported via a `calibration cache` file or QAT mode for `torch_tensorrt` backend. +2. TensorRT engine filename should end with `.plan` otherwise it will be treated as Torchscript module. + +### Using CompileSpec options via CLI + +Here are the list of `CompileSpec` options that can be provided directly to compile the pytorch module + +* `--backends` : Comma separated string of backends. Eg: torch,torch_tensorrt, tensorrt or fx2trt +* `--model` : Name of the model file (Can be a torchscript module or a tensorrt engine (ending in `.plan` extension)). If the backend is `fx2trt`, the input should be a Pytorch module (instead of a torchscript module) and the options for model are (`vgg16` | `resnet50` | `efficientnet_b0`) +* `--inputs` : List of input shapes & dtypes. Eg: (1, 3, 224, 224)@fp32 for Resnet or (1, 128)@int32;(1, 128)@int32 for BERT +* `--batch_size` : Batch size +* `--precision` : Comma separated list of precisions to build TensorRT engine Eg: fp32,fp16 +* `--device` : Device ID +* `--truncate` : Truncate long and double weights in the network in Torch-TensorRT +* `--is_trt_engine` : Boolean flag to be enabled if the model file provided is a TensorRT engine. +* `--report` : Path of the output file where performance summary is written. + +Eg: + +``` + python perf_run.py --model ${MODELS_DIR}/vgg16_scripted.jit.pt \ + --precision fp32,fp16 --inputs="(1, 3, 224, 224)@fp32" \ + --batch_size 1 \ + --backends torch,torch_tensorrt,tensorrt \ + --report "vgg_perf_bs1.txt" +``` + +### Example models + +This tool benchmarks any pytorch model or torchscript module. As an example, we provide VGG16, Resnet50, EfficientNet-B0, VIT, HF-BERT models in `hub.py` that we internally test for performance. +The torchscript modules for these models can be generated by running +``` +python hub.py +``` +You can refer to `benchmark.sh` on how we run/benchmark these models. 
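As noted above, the `fx2trt` backend takes a model name from the built-in model list rather than a torchscript file. The invocation below is an illustrative sketch of such a run (the report filename is arbitrary and `vgg16` is used only as an example model):

```
 python perf_run.py --model vgg16 \
     --precision fp32,fp16 --inputs="(1, 3, 224, 224)@fp32" \
     --batch_size 1 \
     --backends fx2trt \
     --report "vgg16_fx2trt_bs1.txt"
```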
diff --git a/tools/perf/benchmark.sh b/tools/perf/benchmark.sh new file mode 100644 index 0000000000..b84061025d --- /dev/null +++ b/tools/perf/benchmark.sh @@ -0,0 +1,64 @@ +#!/bin/bash + +MODELS_DIR="models" + +# Download the Torchscript models +python hub.py + +batch_sizes=(1 2 4 8 16 32 64 128 256) + +#Benchmark VGG16 model +echo "Benchmarking VGG16 model" +for bs in ${batch_sizes[@]} +do + python perf_run.py --model ${MODELS_DIR}/vgg16_scripted.jit.pt \ + --precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \ + --batch_size ${bs} \ + --backends torch,torch_tensorrt,tensorrt \ + --report "vgg_perf_bs${bs}.txt" +done + +# Benchmark Resnet50 model +echo "Benchmarking Resnet50 model" +for bs in ${batch_sizes[@]} +do + python perf_run.py --model ${MODELS_DIR}/resnet50_scripted.jit.pt \ + --precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \ + --batch_size ${bs} \ + --backends torch,torch_tensorrt,tensorrt \ + --report "rn50_perf_bs${bs}.txt" +done + +# Benchmark VIT model +echo "Benchmarking VIT model" +for bs in ${batch_sizes[@]} +do + python perf_run.py --model ${MODELS_DIR}/vit_scripted.jit.pt \ + --precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \ + --batch_size ${bs} \ + --backends torch,torch_tensorrt,tensorrt \ + --report "vit_perf_bs${bs}.txt" +done + +# Benchmark EfficientNet-B0 model +echo "Benchmarking EfficientNet-B0 model" +for bs in ${batch_sizes[@]} +do + python perf_run.py --model ${MODELS_DIR}/efficientnet_b0_scripted.jit.pt \ + --precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \ + --batch_size ${bs} \ + --backends torch,torch_tensorrt,tensorrt \ + --report "eff_b0_perf_bs${bs}.txt" +done + +# Benchmark BERT model +echo "Benchmarking Huggingface BERT base model" +for bs in ${batch_sizes[@]} +do + python perf_run.py --model ${MODELS_DIR}/bert_base_uncased_traced.jit.pt \ + --precision fp32 --inputs="(${bs}, 128)@int32;(${bs}, 128)@int32" \ + --batch_size ${bs} \ + --backends torch,torch_tensorrt \ + --truncate \ + --report "bert_base_perf_bs${bs}.txt" +done diff --git a/tools/perf/config/vgg16.yml b/tools/perf/config/vgg16.yml index 458dc1b1f6..d88d489458 100755 --- a/tools/perf/config/vgg16.yml +++ b/tools/perf/config/vgg16.yml @@ -8,8 +8,9 @@ input: - 224 - 224 num_inputs: 1 + batch_size: 1 model: - filename: models/vgg16_traced.jit.pt + filename: models/vgg16_scripted.jit.pt name: vgg16 runtime: device: 0 diff --git a/tools/perf/custom_models.py b/tools/perf/custom_models.py new file mode 100644 index 0000000000..a8b8a5dae0 --- /dev/null +++ b/tools/perf/custom_models.py @@ -0,0 +1,30 @@ +import torch +import torch.nn as nn +from transformers import BertModel, BertTokenizer, BertConfig +import torch.nn.functional as F + + +def BertModule(): + model_name = "bert-base-uncased" + enc = BertTokenizer.from_pretrained(model_name) + text = "[CLS] Who was Jim Henson ? 
[SEP] Jim Henson was a puppeteer [SEP]" + tokenized_text = enc.tokenize(text) + masked_index = 8 + tokenized_text[masked_index] = "[MASK]" + indexed_tokens = enc.convert_tokens_to_ids(tokenized_text) + segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1] + tokens_tensor = torch.tensor([indexed_tokens]) + segments_tensors = torch.tensor([segments_ids]) + config = BertConfig( + vocab_size_or_config_json_file=32000, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + torchscript=True, + ) + model = BertModel(config) + model.eval() + model = BertModel.from_pretrained(model_name, torchscript=True) + traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors]) + return traced_model diff --git a/tools/perf/hub.py b/tools/perf/hub.py new file mode 100644 index 0000000000..e54734f8a1 --- /dev/null +++ b/tools/perf/hub.py @@ -0,0 +1,132 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as models +import timm +from transformers import BertModel, BertTokenizer, BertConfig +import os +import json +import custom_models as cm + +torch.hub._validate_not_a_forked_repo = lambda a, b, c: True + +torch_version = torch.__version__ + +# Detect case of no GPU before deserialization of models on GPU +if not torch.cuda.is_available(): + raise Exception( + "No GPU found. Please check if installed torch version is compatible with CUDA version" + ) + +# Downloads all model files again if manifest file is not present +MANIFEST_FILE = "model_manifest.json" + +BENCHMARK_MODELS = { + "vgg16": {"model": models.vgg16(weights=None), "path": "script"}, + "resnet50": {"model": models.resnet50(weights=None), "path": "script"}, + "efficientnet_b0": { + "model": timm.create_model("efficientnet_b0", pretrained=True), + "path": "script", + }, + "vit": { + "model": timm.create_model("vit_base_patch16_224", pretrained=True), + "path": "script", + }, + "bert_base_uncased": {"model": cm.BertModule(), "path": "trace"}, +} + + +def get(n, m, manifest): + print("Downloading {}".format(n)) + traced_filename = "models/" + n + "_traced.jit.pt" + script_filename = "models/" + n + "_scripted.jit.pt" + x = torch.ones((1, 3, 300, 300)).cuda() + if n == "bert-base-uncased": + traced_model = m["model"] + torch.jit.save(traced_model, traced_filename) + manifest.update({n: [traced_filename]}) + else: + m["model"] = m["model"].eval().cuda() + if m["path"] == "both" or m["path"] == "trace": + trace_model = torch.jit.trace(m["model"], [x]) + torch.jit.save(trace_model, traced_filename) + manifest.update({n: [traced_filename]}) + if m["path"] == "both" or m["path"] == "script": + script_model = torch.jit.script(m["model"]) + torch.jit.save(script_model, script_filename) + if n in manifest.keys(): + files = list(manifest[n]) if type(manifest[n]) != list else manifest[n] + files.append(script_filename) + manifest.update({n: files}) + else: + manifest.update({n: [script_filename]}) + return manifest + + +def download_models(version_matches, manifest): + # Download all models if torch version is different than model version + if not version_matches: + for n, m in BENCHMARK_MODELS.items(): + manifest = get(n, m, manifest) + else: + for n, m in BENCHMARK_MODELS.items(): + scripted_filename = "models/" + n + "_scripted.jit.pt" + traced_filename = "models/" + n + "_traced.jit.pt" + # Check if model file exists on disk + if ( + ( + m["path"] == "both" + and os.path.exists(scripted_filename) + and os.path.exists(traced_filename) + ) + or (m["path"] == 
"script" and os.path.exists(scripted_filename)) + or (m["path"] == "trace" and os.path.exists(traced_filename)) + ): + print("Skipping {} ".format(n)) + continue + manifest = get(n, m, manifest) + + +def main(): + manifest = None + version_matches = False + manifest_exists = False + + # Check if Manifest file exists or is empty + if not os.path.exists(MANIFEST_FILE) or os.stat(MANIFEST_FILE).st_size == 0: + manifest = {"version": torch_version} + + # Creating an empty manifest file for overwriting post setup + os.system("touch {}".format(MANIFEST_FILE)) + else: + manifest_exists = True + + # Load manifest if already exists + with open(MANIFEST_FILE, "r") as f: + manifest = json.load(f) + if manifest["version"] == torch_version: + version_matches = True + else: + print( + "Torch version: {} mismatches \ + with manifest's version: {}. Re-downloading \ + all models".format( + torch_version, manifest["version"] + ) + ) + + # Overwrite the manifest version as current torch version + manifest["version"] = torch_version + + download_models(version_matches, manifest) + + # Write updated manifest file to disk + with open(MANIFEST_FILE, "r+") as f: + data = f.read() + f.seek(0) + record = json.dumps(manifest) + f.write(record) + f.truncate() + + +main() diff --git a/tools/perf/perf_run.py b/tools/perf/perf_run.py index f0386f4e5a..fbdf3b6c40 100644 --- a/tools/perf/perf_run.py +++ b/tools/perf/perf_run.py @@ -15,7 +15,17 @@ # Importing supported Backends import torch import torch_tensorrt as torchtrt +from torch_tensorrt.fx.lower import compile +from torch_tensorrt.fx.utils import LowerPrecision + import tensorrt as trt +from utils import ( + parse_inputs, + parse_backends, + precision_to_dtype, + parse_precisions, + BENCHMARK_MODELS, +) WARMUP_ITER = 10 results = [] @@ -49,8 +59,8 @@ def get(self, key, default_value=None): # Runs inference using Torch backend -def run_torch(model, input_tensors, params, precision): - print("Running Torch for precision: ", precision) +def run_torch(model, input_tensors, params, precision, batch_size): + print("Running Torch for precision: ", precision, " batch_size : ", batch_size) iters = params.get("iterations", 20) # Warm up @@ -69,19 +79,25 @@ def run_torch(model, input_tensors, params, precision): end_time = timeit.default_timer() meas_time = end_time - start_time timings.append(meas_time) - print("Iteration {}: {:.6f} s".format(i, end_time - start_time)) - printStats("Torch", timings, precision) + recordStats("Torch", timings, precision, batch_size) # Runs inference using Torch-TensorRT backend -def run_torch_tensorrt(model, input_tensors, params, precision): - print("Running Torch-TensorRT") - +def run_torch_tensorrt( + model, input_tensors, params, precision, truncate_long_and_double, batch_size +): + print( + "Running Torch-TensorRT for precision: ", + precision, + " batch_size : ", + batch_size, + ) # Compiling Torch-TensorRT model compile_settings = { "inputs": input_tensors, "enabled_precisions": {precision_to_dtype(precision)}, + "truncate_long_and_double": truncate_long_and_double, } if precision == "int8": @@ -106,9 +122,47 @@ def run_torch_tensorrt(model, input_tensors, params, precision): end_time = timeit.default_timer() meas_time = end_time - start_time timings.append(meas_time) - print("Iteration {}: {:.6f} s".format(i, end_time - start_time)) - printStats("Torch-TensorRT", timings, precision) + recordStats("Torch-TensorRT", timings, precision, batch_size) + + +# Runs inference using FX2TRT backend +def run_fx2trt(model, input_tensors, params, 
precision, batch_size): + print("Running FX2TRT for precision: ", precision, " batch_size : ", batch_size) + if precision == "fp32": + precision = LowerPrecision.FP32 + elif precision == "fp16": + precision = LowerPrecision.FP16 + model.half() + input_tensors = [tensor.half() for tensor in input_tensors] + # Run lowering eager mode benchmark + model = compile( + model, + input_tensors, + max_batch_size=batch_size, + lower_precision=precision, + verbose_log=False, + ) + + iters = params.get("iterations", 20) + # Warm up + with torch.no_grad(): + for _ in range(WARMUP_ITER): + features = model(*input_tensors) + + torch.cuda.synchronize() + + timings = [] + with torch.no_grad(): + for i in range(iters): + start_time = timeit.default_timer() + features = model(*input_tensors) + torch.cuda.synchronize() + end_time = timeit.default_timer() + meas_time = end_time - start_time + timings.append(meas_time) + + recordStats("FX-TensorRT", timings, precision, batch_size) def torch_dtype_from_trt(dtype): @@ -135,7 +189,15 @@ def torch_device_from_trt(device): return TypeError("%s is not supported by torch" % device) -def run_tensorrt(model, input_tensors, params, precision, is_trt_engine=False): +def run_tensorrt( + model, + input_tensors, + params, + precision, + truncate_long_and_double=False, + is_trt_engine=False, + batch_size=1, +): engine = None # If the model file is a TensorRT engine then directly deserialize and run inference @@ -144,10 +206,11 @@ def run_tensorrt(model, input_tensors, params, precision, is_trt_engine=False): compile_settings = { "inputs": input_tensors, "enabled_precisions": {precision_to_dtype(precision)}, + "truncate_long_and_double": truncate_long_and_double, } print("Converting method to TensorRT engine...") - with torch.no_grad(): + with torch.no_grad(), torchtrt.logging.errors(): model = torchtrt.ts.convert_method_to_trt_engine( model, "forward", **compile_settings ) @@ -156,17 +219,15 @@ def run_tensorrt(model, input_tensors, params, precision, is_trt_engine=False): with trt.Logger() as logger, trt.Runtime(logger) as runtime: engine = runtime.deserialize_cuda_engine(model) - print("Running TensorRT") + print("Running TensorRT for precision: ", precision, " batch_size : ", batch_size) iters = params.get("iterations", 20) - batch_size = params.get("batch", 1) # Compiling the bindings bindings = engine.num_bindings * [None] - k = 0 for idx, _ in enumerate(bindings): dtype = torch_dtype_from_trt(engine.get_binding_dtype(idx)) - shape = (batch_size,) + tuple(engine.get_binding_shape(idx)) + shape = tuple(engine.get_binding_shape(idx)) device = torch_device_from_trt(engine.get_location(idx)) if not engine.binding_is_input(idx): # Output bindings @@ -180,29 +241,32 @@ def run_tensorrt(model, input_tensors, params, precision, is_trt_engine=False): timings = [] with engine.create_execution_context() as context: for i in range(WARMUP_ITER): - context.execute_async( - batch_size, bindings, torch.cuda.current_stream().cuda_stream - ) + context.execute_async_v2(bindings, torch.cuda.current_stream().cuda_stream) torch.cuda.synchronize() for i in range(iters): start_time = timeit.default_timer() - context.execute_async( - batch_size, bindings, torch.cuda.current_stream().cuda_stream - ) + context.execute_async_v2(bindings, torch.cuda.current_stream().cuda_stream) torch.cuda.synchronize() end_time = timeit.default_timer() meas_time = end_time - start_time timings.append(meas_time) - print("Iterations {}: {:.6f} s".format(i, end_time - start_time)) - printStats("TensorRT", timings, 
precision) + recordStats("TensorRT", timings, precision, batch_size) # Deploys inference run for different backend configurations -def run(model, input_tensors, params, precision, is_trt_engine=False): - for backend in params.get("backend"): - +def run( + model, + backends, + input_tensors, + params, + precision, + truncate_long_and_double=False, + batch_size=1, + is_trt_engine=False, +): + for backend in backends: if precision == "int8": if backend == "all" or backend == "torch": print( @@ -219,22 +283,55 @@ def run(model, input_tensors, params, precision, is_trt_engine=False): return False if backend == "all": - run_torch(model, input_tensors, params, precision) - run_torch_tensorrt(model, input_tensors, params, precision) - run_tensorrt(model, input_tensors, params, precision, is_trt_engine) + run_torch(model, input_tensors, params, precision, batch_size) + run_torch_tensorrt( + model, + input_tensors, + params, + precision, + truncate_long_and_double, + batch_size, + ) + run_tensorrt( + model, + input_tensors, + params, + precision, + truncate_long_and_double, + is_trt_engine, + batch_size, + ) elif backend == "torch": - run_torch(model, input_tensors, params, precision) + run_torch(model, input_tensors, params, precision, batch_size) elif backend == "torch_tensorrt": - run_torch_tensorrt(model, input_tensors, params, precision) + run_torch_tensorrt( + model, + input_tensors, + params, + precision, + truncate_long_and_double, + batch_size, + ) + + elif backend == "fx2trt": + run_fx2trt(model, input_tensors, params, precision, batch_size) elif backend == "tensorrt": - run_tensorrt(model, input_tensors, params, precision, is_trt_engine) + run_tensorrt( + model, + input_tensors, + params, + precision, + truncate_long_and_double, + is_trt_engine, + batch_size, + ) # Generate report -def printStats(backend, timings, precision, batch_size=1): +def recordStats(backend, timings, precision, batch_size=1): times = np.array(timings) steps = len(times) speeds = batch_size / times @@ -245,43 +342,16 @@ def printStats(backend, timings, precision, batch_size=1): speed_mean = np.mean(speeds) speed_med = np.median(speeds) - msg = ( - "\n%s =================================\n" - "batch size=%d, num iterations=%d\n" - " Median FPS: %.1f, mean: %.1f\n" - " Median latency: %.6f, mean: %.6f, 99th_p: %.6f, std_dev: %.6f\n" - ) % ( - backend, - batch_size, - steps, - speed_med, - speed_mean, - time_med, - time_mean, - time_99th, - time_std, - ) - print(msg) - meas = { + stats = { "Backend": backend, - "precision": precision, + "Precision": precision, + "Batch size": batch_size, "Median(FPS)": speed_med, "Mean(FPS)": speed_mean, - "Median-Latency(ms)": time_med, - "Mean-Latency(ms)": time_mean, - "99th_p": time_99th, - "std_dev": time_std, + "Median-Latency(ms)": time_med * 1000, + "Mean-Latency(ms)": time_mean * 1000, } - results.append(meas) - - -def precision_to_dtype(pr): - if pr == "fp32": - return torch.float - elif pr == "fp16" or pr == "half": - return torch.half - else: - return torch.int8 + results.append(stats) def load_model(params): @@ -289,15 +359,21 @@ def load_model(params): is_trt_engine = False # Load torch model traced/scripted model_file = params.get("model").get("filename") + try: + model_name = params.get("model").get("name") + except: + model_name = model_file - if model_file.endswith(".jit.pt"): - model = torch.jit.load(model_file).cuda() - else: + print("Loading model: ", model_file) + if model_file.endswith(".plan"): is_trt_engine = True # Read the TensorRT engine file with 
open(model_file, "rb") as fin: model = fin.read() - return model, is_trt_engine + else: + model = torch.jit.load(model_file).cuda() + + return model, model_name, is_trt_engine if __name__ == "__main__": @@ -306,57 +382,147 @@ def load_model(params): ) arg_parser.add_argument( "--config", + type=str, help="Load YAML based configuration file to run the inference. If this is used other params will be ignored", ) + # The following options are manual user provided settings + arg_parser.add_argument( + "--backends", + type=str, + help="Comma separated string of backends. Eg: torch,torch_tensorrt,fx2trt,tensorrt", + ) + arg_parser.add_argument("--model", type=str, help="Name of the model file") + arg_parser.add_argument( + "--inputs", + type=str, + help="List of input shapes. Eg: (1, 3, 224, 224)@fp32 for Resnet or (1, 128)@int32;(1, 128)@int32 for BERT", + ) + arg_parser.add_argument( + "--batch_size", type=int, default=1, help="Batch size to build and run" + ) + arg_parser.add_argument( + "--precision", + default="fp32", + type=str, + help="Comma separated list of precisions to build TensorRT engine Eg: fp32,fp16", + ) + arg_parser.add_argument( + "--calibration_cache", type=str, help="Name of the calibration cache file" + ) + arg_parser.add_argument("--device", type=int, help="device id") + arg_parser.add_argument( + "--truncate", + action="store_true", + help="Truncate long and double weights in the network in Torch-TensorRT", + ) + arg_parser.add_argument( + "--is_trt_engine", + action="store_true", + help="Boolean flag to determine if the user provided model is a TRT engine or not", + ) + arg_parser.add_argument( + "--report", + type=str, + help="Path of the output file where performance summary is written.", + ) args = arg_parser.parse_args() - parser = ConfigParser(args.config) - # Load YAML params - params = parser.read_config() - print("Loading model: ", params.get("model").get("filename")) - - model = None - - # Default device is set to 0. Configurable using yaml config file. - torch.cuda.set_device(params.get("runtime").get("device", 0)) - - # Load the model file from disk. If the loaded file is TensorRT engine then is_trt_engine is returned as True - model, is_trt_engine = load_model(params) cudnn.benchmark = True - # Create random input tensor of certain size torch.manual_seed(12345) + model_name = "Model" + if args.config: + parser = ConfigParser(args.config) + # Load YAML params + params = parser.read_config() + model, model_name, is_trt_engine = load_model(params) + + # Default device is set to 0. Configurable using yaml config file. 
+ torch.cuda.set_device(params.get("runtime").get("device", 0)) + + num_input = params.get("input").get("num_inputs") + truncate_long_and_double = params.get("runtime").get( + "truncate_long_and_double", False + ) + batch_size = params.get("input").get("batch_size", 1) + for precision in params.get("runtime").get("precision", "fp32"): + input_tensors = [] + num_input = params.get("input").get("num_inputs", 1) + for i in range(num_input): + inp_tensor = params.get("input").get("input" + str(i)) + input_tensors.append( + torch.randint( + 0, + 2, + tuple(d for d in inp_tensor), + dtype=precision_to_dtype(precision), + ).cuda() + ) - num_input = params.get("input").get("num_inputs") - for precision in params.get("runtime").get("precision", "fp32"): - input_tensors = [] - num_input = params.get("input").get("num_inputs", 1) - for i in range(num_input): - inp_tensor = params.get("input").get("input" + str(i)) - input_tensors.append( - torch.randint( - 0, - 2, - tuple(d for d in inp_tensor), - dtype=precision_to_dtype(precision), - ).cuda() - ) + if is_trt_engine: + print( + "Warning, TensorRT engine file is configured. Please make sure the precision matches with the TRT engine for reliable results" + ) - if is_trt_engine: - print( - "Warning, TensorRT engine file is configured. Please make sure the precision matches with the TRT engine for reliable results" + if not is_trt_engine and (precision == "fp16" or precision == "half"): + # If model is TensorRT serialized engine then model.half will report failure + model = model.half() + + backends = params.get("backend") + # Run inference + status = run( + model, + backends, + input_tensors, + params, + precision, + truncate_long_and_double, + batch_size, + is_trt_engine, + ) + else: + params = vars(args) + model_name = params["model"] + if os.path.exists(model_name): + print("Loading user provided model: ", model_name) + model = torch.jit.load(model_name).cuda().eval() + elif model_name in BENCHMARK_MODELS: + model = BENCHMARK_MODELS[model_name]["model"].eval().cuda() + else: + raise ValueError( + "Invalid model name. 
Please provide a torchscript model file or model name (among the following options vgg16|resnet50|efficientnet_b0|vit)" ) - if not is_trt_engine and precision == "fp16" or precision == "half": - # If model is TensorRT serialized engine then model.half will report failure - model = model.half() + backends = parse_backends(params["backends"]) + truncate_long_and_double = params["truncate"] + batch_size = params["batch_size"] + is_trt_engine = params["is_trt_engine"] + precisions = parse_precisions(params["precision"]) - # Run inference - status = run(model, input_tensors, params, precision, is_trt_engine) - if status == False: - continue + for precision in precisions: + input_tensors = parse_inputs( + params["inputs"], precision_to_dtype(precision) + ) + if not is_trt_engine and (precision == "fp16" or precision == "half"): + # If model is TensorRT serialized engine then model.half will report failure + model = model.half() + status = run( + model, + backends, + input_tensors, + params, + precision, + truncate_long_and_double, + batch_size, + is_trt_engine, + ) # Generate report - print("Model Summary:") + print("Model Summary: ", model_name) summary = pd.DataFrame(results) print(summary) + if args.report: + with open(args.report, "w") as file: + file.write("Model Summary: " + model_name + "\n") + file.write(summary.to_string()) + file.close() diff --git a/tools/perf/utils.py b/tools/perf/utils.py new file mode 100644 index 0000000000..3d63dcd4b7 --- /dev/null +++ b/tools/perf/utils.py @@ -0,0 +1,61 @@ +import torch +import torch_tensorrt +import custom_models as cm +import torchvision.models as models +import timm + +BENCHMARK_MODELS = { + "vgg16": {"model": models.vgg16(pretrained=True), "path": "script"}, + "resnet50": { + "model": torch.hub.load("pytorch/vision:v0.9.0", "resnet50", pretrained=True), + "path": "script", + }, + "efficientnet_b0": { + "model": timm.create_model("efficientnet_b0", pretrained=True), + "path": "script", + }, + "vit": { + "model": timm.create_model("vit_base_patch16_224", pretrained=True), + "path": "script", + }, + "bert_base_uncased": {"model": cm.BertModule(), "path": "trace"}, +} + + +def precision_to_dtype(pr): + if pr == "fp32": + return torch.float + elif pr == "fp16" or pr == "half": + return torch.half + elif pr == "int32": + return torch.int32 + elif pr == "bool": + return torch.bool + else: + return torch.float32 + + +def parse_inputs(user_inputs, dtype): + parsed_inputs = user_inputs.split(";") + torchtrt_inputs = [] + for input in parsed_inputs: + input_shape = [] + input_shape_and_dtype = input.split("@") + dtype = ( + precision_to_dtype(input_shape_and_dtype[1]) + if len(input_shape_and_dtype) == 2 + else dtype + ) + for input_dim in input_shape_and_dtype[0][1:-1].split(","): + input_shape.append(int(input_dim)) + torchtrt_inputs.append(torch.randint(0, 5, input_shape, dtype=dtype).cuda()) + + return torchtrt_inputs + + +def parse_backends(backends): + return backends.split(",") + + +def parse_precisions(precisions): + return precisions.split(",")
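
For reference, a minimal sketch of how the `utils.py` helpers fit together; the backend, precision, and shape/dtype strings are illustrative and mirror the `--inputs` format documented in the README. Importing `utils` also builds `BENCHMARK_MODELS` (downloading pretrained torchvision/timm/transformers weights), and the parsed tensors are created on the GPU, so a CUDA device is assumed.

```python
# Illustrative only: exercises the parsing helpers with example values.
from utils import parse_inputs, parse_backends, parse_precisions, precision_to_dtype

backends = parse_backends("torch,torch_tensorrt,tensorrt")  # -> ["torch", "torch_tensorrt", "tensorrt"]
precisions = parse_precisions("fp32,fp16")                  # -> ["fp32", "fp16"]

for precision in precisions:
    # Two int32 inputs of shape (1, 128), e.g. input_ids and token_type_ids for BERT.
    # The "@dtype" suffix overrides the default dtype derived from the precision.
    input_tensors = parse_inputs("(1, 128)@int32;(1, 128)@int32", precision_to_dtype(precision))
    assert all(t.shape == (1, 128) and t.is_cuda for t in input_tensors)
```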