From 263e441c75d6bdced22068f3df9e806338629f5c Mon Sep 17 00:00:00 2001
From: Wei
Date: Sun, 26 Jun 2022 00:42:21 -0700
Subject: [PATCH 01/10] Create getting_started_with_fx_path.rst
---
.../getting_started_with_fx_path.rst | 304 ++++++++++++++++++
1 file changed, 304 insertions(+)
create mode 100644 docsrc/tutorials/getting_started_with_fx_path.rst
diff --git a/docsrc/tutorials/getting_started_with_fx_path.rst b/docsrc/tutorials/getting_started_with_fx_path.rst
new file mode 100644
index 0000000000..9f15616e28
--- /dev/null
+++ b/docsrc/tutorials/getting_started_with_fx_path.rst
@@ -0,0 +1,304 @@
+.. _user_guide:
+Torch-TensorRT (FX Path) User Guide
+===================================
+Torch-TensorRT (FX Path) is a tool that can convert a PyTorch model through torch.FX to a TensorRT engine optimized to run on Nvidia GPUs. TensorRT is the inference engine developed by Nvidia, composed of various kinds of optimization, including kernel fusion, graph optimization, low precision, etc.
+This tool is developed in a Python environment, which provides maximum usability to researchers and engineers. There are a few stages in using this tool, and we introduce them here.
+
+Installation
+------------
+* Method 1. Follow the instructions for Torch-TensorRT.
+* Method 2. Install the FX path only (Python path), avoiding the C++ build for the TorchScript path:
+
+.. code-block:: shell
+
+ $ conda create --name python_env python=3.8
+ $ conda activate python_env
+
+    # We recommend installing PyTorch 1.12 or later
+ $ conda install pytorch torchvision torchtext cudatoolkit=11.3 -c pytorch-nightly
+
+ # Install TensorRT python package
+ $ pip3 install nvidia-pyindex
+ $ pip3 install nvidia-tensorrt==8.2.4.2
+ $ git clone https://github.com/pytorch/TensorRT.git
+ $ cd TensorRT/py && python setup.py install --fx-only && cd ..
+
+    $ python -c "import torch_tensorrt.fx"
+    # Test an example:
+ $ python py/torch_tensorrt/fx/example/lower_example.py
+
+
+Converting a PyTorch Model to TensorRT Engine
+---------------------------------------------
+We will go through an example to illustrate the major steps that the FX path takes to convert a PyTorch model to a TensorRT engine.
+
+* **Step 1: Trace the model with acc_tracer**
+Acc_tracer is a tracer inherited from the FX tracer. It comes with an argument normalizer that converts all args to kwargs, which are then passed to the TRT converters.
+
+.. code-block:: python
+
+ import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
+
+ # Build the model which needs to be a PyTorch nn.Module.
+ my_pytorch_model = build_model()
+
+    # Prepare inputs to the model. Inputs have to be a list of Tensors.
+ inputs = [Tensor, Tensor, ...]
+
+ # Trace the model with acc_tracer.
+ acc_mod = acc_tracer.trace(my_pytorch_model, inputs)
+
+*Common Errors:*
+
+symbolically traced variables cannot be used as inputs to control flow
+This means the model contains dynamic control flow, as the sketch below illustrates. Please refer to the section “Dynamic Control Flow” in the `FX guide `_.
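+
+The following is a minimal sketch of a model that triggers this error; the module and shapes are purely illustrative:
+
+.. code-block:: python
+
+    import torch
+    import torch.nn as nn
+    import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
+
+    class DynamicControlFlow(nn.Module):
+        def forward(self, x):
+            # The branch condition depends on a tensor value, which symbolic
+            # tracing cannot evaluate, so tracing fails with the error above.
+            if x.sum() > 0:
+                return x + 1
+            return x - 1
+
+    # Raises: symbolically traced variables cannot be used as inputs to control flow
+    acc_mod = acc_tracer.trace(DynamicControlFlow().eval(), [torch.randn(3)])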
+
+* **Step 2: Build TensorRT engine**
+There are `two different modes `_ for how TensorRT handles batch dimension: explicit batch dimension and implicit batch dimension. Implicit batch mode was used by early versions of TensorRT; it is now deprecated, but continues to be supported for backwards compatibility. In explicit batch mode, all dimensions are explicit and can be dynamic, that is, their length can change at execution time. Many new features, such as dynamic shapes and loops, are available only in this mode. Users can still choose implicit batch mode by setting ``explicit_batch_dimension=False`` in ``lower_to_trt()``. We do not recommend this, since it will lack support in future TensorRT versions.
+
+Explicit batch is the default mode, and it must be set for dynamic shape. For most vision tasks, users can enable ``dynamic_batch`` in ``lower_to_trt()`` if they want an effect similar to implicit mode, where only the batch dimension changes. This has some requirements:
+1. Shapes of inputs, outputs and activations are fixed except for the batch dimension.
+2. Inputs, outputs and activations have the batch dimension as the major dimension.
+3. All the operators in the model neither modify the batch dimension (permute, transpose, split, etc.) nor compute over it (sum, softmax, etc.).
+
+As an example of the last requirement: given a 3D tensor ``t`` shaped as (batch, sequence, dimension), an operation such as ``torch.transpose(0, 2)`` would violate it. A minimal sketch of enabling ``dynamic_batch`` follows. If any of these three requirements is not satisfied, we instead need to specify ``InputTensorSpec`` as inputs with dynamic range, as shown in the code block after the sketch.
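+
+Below is a minimal sketch of enabling ``dynamic_batch``; the model and shapes are illustrative, and the parameters follow the ``lower_to_trt()`` signature shown in our examples:
+
+.. code-block:: python
+
+    import torch
+    import torchvision
+    from torch_tensorrt.fx.lower import lower_to_trt
+    from torch_tensorrt.fx.utils import LowerPrecision
+
+    model = torchvision.models.resnet18(pretrained=True).cuda().eval()
+    inputs = [torch.randn(32, 3, 224, 224, device="cuda")]
+
+    trt_model = lower_to_trt(
+        model,
+        inputs,
+        max_batch_size=64,               # upper bound on the dynamic batch dimension
+        explicit_batch_dimension=True,   # required for dynamic shape
+        dynamic_batch=True,              # batch dimension (dim=0) is dynamic
+        lower_precision=LowerPrecision.FP32,
+    )
+    outputs = trt_model(*inputs)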
+
+.. code-block:: python
+
+    from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter
+
+ # InputTensorSpec is a dataclass we use to store input information.
+    # There are two ways we can build input_specs.
+ # Option 1, build it manually.
+ input_specs = [
+ InputTensorSpec(shape=(1, 2, 3), dtype=torch.float32),
+ InputTensorSpec(shape=(1, 4, 5), dtype=torch.float32),
+ ]
+    # Option 2, build it using sample_inputs, where the user provides sample tensors.
+ inputs = [
+ torch.rand((1,2,3), dtype=torch.float32),
+ torch.rand((1,4,5), dtype=torch.float32),
+ ]
+ input_specs = InputTensorSpec.from_tensors(inputs)
+
+ # IMPORTANT: If dynamic shape is needed, we need to build it slightly differently.
+ input_specs = [
+ InputTensorSpec(
+ shape=(-1, 2, 3),
+ dtype=torch.float32,
+            # Currently we only support one set of dynamic ranges. Users may set other dimensions, but it is not guaranteed to work for all models.
+ # (min_shape, optimize_target_shape, max_shape)
+ # For more information refer to fx/input_tensor_spec.py
+ shape_ranges = [
+ ((1, 2, 3), (4, 2, 3), (100, 2, 3)),
+ ],
+ ),
+ InputTensorSpec(shape=(1, 4, 5), dtype=torch.float32),
+ ]
+
+ # Build a TRT interpreter. Set explicit_batch_dimension accordingly.
+ interpreter = TRTInterpreter(
+        acc_mod, input_specs, explicit_batch_dimension=True,  # or False for implicit batch mode
+ )
+
+ # The output of TRTInterpreter run() is wrapped as TRTInterpreterResult.
+    # The TRTInterpreterResult contains the required parameters to build a TRTModule,
+    # and other informational output from the TRTInterpreter run.
+ class TRTInterpreterResult(NamedTuple):
+ engine: Any
+ input_names: Sequence[str]
+ output_names: Sequence[str]
+ serialized_cache: bytearray
+
+    # max_batch_size: set to the maximum batch size you will use.
+    # max_workspace_size: set to the maximum size we can afford for temporary buffers.
+    # lower_precision: the precision the model layers run in (TensorRT will choose the best-performance precision).
+    # sparse_weights: allow the builder to examine weights and use optimized functions when weights have suitable sparsity.
+    # force_fp32_output: force the output to be fp32.
+    # strict_type_constraints: usually we should set it to False, unless we want to control the precision of certain layers for numeric reasons.
+    # algorithm_selector: set up algorithm selection for certain layers.
+    # timing_cache: enable the timing cache for TensorRT.
+    # profiling_verbosity: TensorRT logging level.
+ trt_interpreter_result = interpreter.run(
+ max_batch_size=64,
+ max_workspace_size=1 << 25,
+ sparse_weights=False,
+ force_fp32_output=False,
+ strict_type_constraints=False,
+ algorithm_selector=None,
+ timing_cache=None,
+ profiling_verbosity=None,
+ )
+
+
+*Common Errors:*
+
+RuntimeError: Conversion of function xxx not currently supported!
+- This means we don't have support for the xxx operator yet. Please refer to the section “How to Add a Missing Op” below for further instructions.
+
+* **Step 3: Run the model**
+One way is using TRTModule, which is basically a PyTorch nn.Module.
+
+.. code-block:: python
+
+ from torch_tensorrt.fx import TRTModule
+ mod = TRTModule(
+ trt_interpreter_result.engine,
+ trt_interpreter_result.input_names,
+ trt_interpreter_result.output_names)
+ # Just like all other PyTorch modules
+ outputs = mod(*inputs)
+ torch.save(mod, "trt.pt")
+ reload_trt_mod = torch.load("trt.pt")
+ reload_model_output = reload_trt_mod(*inputs)
+
+So far, we have given a detailed explanation of the major steps in converting a PyTorch model into a TensorRT engine. Users are welcome to refer to the source code for the explanations of some parameters. In the conversion scheme, there are two important actions. One is the acc tracer, which helps us convert a PyTorch model to an acc graph. The other is the FX path converter, which converts the operations in the acc graph to the corresponding TensorRT operations and builds the TensorRT engine.
+
+Acc Tracer
+----------
+
+Acc tracer is a custom FX symbolic tracer. It does a couple more things compared to the vanilla FX symbolic tracer. We mainly depend on it to convert PyTorch ops or builtin ops to acc ops. There are two main reasons for fx2trt to use acc ops:
+
+1. There are many ops that do similar things in PyTorch ops and builtin ops, such as torch.add, builtin.add and torch.Tensor.add. Using the acc tracer, we normalize these three ops to a single acc_ops.add, which helps reduce the number of converters we need to write (see the sketch after this list).
+2. Acc ops only take kwargs, which makes writing converters easier, as we don't need additional logic to find arguments in args and kwargs.
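+
+Below is a minimal sketch of the normalization described in point 1; the module is purely illustrative:
+
+.. code-block:: python
+
+    import torch
+    import torch.nn as nn
+    import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
+
+    class AddVariants(nn.Module):
+        def forward(self, x, y):
+            a = x + y            # builtin add
+            b = torch.add(a, y)  # torch.add
+            return b.add(y)      # torch.Tensor.add
+
+    traced = acc_tracer.trace(AddVariants().eval(), [torch.randn(2, 3), torch.randn(2, 3)])
+    # All three variants should appear as acc_ops.add nodes, with kwargs only.
+    print(traced.graph)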
+
+FX2TRT
+--------
+After symbolic tracing, we have the graph representation of a PyTorch model. fx2trt leverages the power of fx.Interpreter. fx.Interpreter goes through the whole graph node by node and calls the function that each node represents. fx2trt overrides this behavior: instead of calling the function, it invokes the corresponding converter for each node. Each converter function adds the corresponding TensorRT layer(s).
+
+Below is an example of a converter function. The decorator is used to register the converter function with the corresponding node. In this example, we register this converter for fx nodes whose target is acc_ops.sigmoid.
+
+.. code-block:: python
+
+ @tensorrt_converter(acc_ops.sigmoid)
+ def acc_ops_sigmoid(network, target, args, kwargs, name):
+ """
+ network: TensorRT network. We'll be adding layers to it.
+
+        The remaining arguments are attributes of the fx node.
+ """
+ input_val = kwargs['input']
+
+ if not isinstance(input_val, trt.tensorrt.ITensor):
+ raise RuntimeError(f'Sigmoid received input {input_val} that is not part '
+ 'of the TensorRT region!')
+
+ layer = network.add_activation(input=input_val, type=trt.ActivationType.SIGMOID)
+ layer.name = name
+ return layer.get_output(0)
+
+How to Add a Missing Op
+***********************
+
+You can add the new op wherever you want; just remember to import the file, so that all acc ops and mappers are registered before tracing with acc_tracer.
+
+* **Step 1. Add a new acc op**
+
+TODO: Explain more about the logistics of acc ops, e.g. when we want to break down an op and when we want to reuse other ops.
+
+In `acc tracer `_, we convert nodes in the graph to acc ops if there’s a mapping registered for the node to an acc op.
+
+In order for the conversion to acc ops to happen, two things are required: an acc op function must be defined, and a mapping must be registered.
+
+Defining an acc op is simple. We just need a function, and we register that function as an acc op via the decorator in `acc_normalizer.py `_. E.g., the following code adds an acc op named foo(), which adds the two given inputs.
+
+.. code-block:: python
+
+ # NOTE: all acc ops should only take kwargs as inputs, therefore we need the "*"
+ # at the beginning.
+ @register_acc_op
+ def foo(*, input, other, alpha):
+ return input + alpha * other
+
+There’re two ways to register a mapping. One is `register_acc_op_mapping() `_. Let’s register a mapping from torch.add to foo() we just created above. We need to add decorator register_acc_op_mapping to it.
+
+.. code-block:: python
+
+ this_arg_is_optional = True
+
+ @register_acc_op_mapping(
+ op_and_target=("call_function", torch.add),
+ arg_replacement_tuples=[
+ ("input", "input"),
+ ("other", "other"),
+ ("alpha", "alpha", this_arg_is_optional),
+ ],
+ )
+ @register_acc_op
+ def foo(*, input, other, alpha=1.0):
+ return input + alpha * other
+
+``op_and_target`` determines which node will trigger this mapping. op and target are attributes of the FX node. In acc_normalization, when we see a node whose op and target are the same as those set in ``op_and_target``, we trigger the mapping. Since we want to map from ``torch.add``, op would be call_function and target would be ``torch.add``. ``arg_replacement_tuples`` determines how we construct kwargs for the new acc op node using the args and kwargs of the original node. Each tuple in ``arg_replacement_tuples`` represents one argument mapping rule and contains two or three elements. The first element is the argument name in the original node, which will be used as the acc op node's argument, whose name is the second element in the tuple. The third element is a boolean that determines whether this kwarg is optional in the *original node*; we only need to specify it if it is True. The order of the tuples matters, because the position of a tuple determines where the argument is in the original node's args. We use this information to map args from the original node to kwargs in the acc op node (see the sketch after the list below).
+We don't have to specify arg_replacement_tuples if none of the following is true:
+
+1. kwargs of the original node and the acc op node have different names.
+2. There are optional arguments.
+
+The other way to register a mapping is through `register_custom_acc_mapper_fn() `_. This one is designed to reduce redundant op registrations, as it allows you to use a function that maps to one or more existing acc ops through some combination. In the function, you can do basically whatever you want. Let's use an example to explain how it works.
+
+.. code-block:: python
+
+ @register_acc_op
+ def foo(*, input, other, alpha=1.0):
+ return input + alpha * other
+
+ @register_custom_acc_mapper_fn(
+ op_and_target=("call_function", torch.add),
+ arg_replacement_tuples=[
+ ("input", "input"),
+ ("other", "other"),
+ ("alpha", "alpha", this_arg_is_optional),
+ ],
+ )
+ def custom_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
+ """
+ `node` is original node, which is a call_function node with target
+ being torch.add.
+ """
+ alpha = 1
+ if "alpha" in node.kwargs:
+ alpha = node.kwargs["alpha"]
+        foo_kwargs = {"input": node.kwargs["input"], "other": node.kwargs["other"], "alpha": alpha}
+ with node.graph.inserting_before(node):
+ foo_node = node.graph.call_function(foo, kwargs=foo_kwargs)
+ foo_node.meta = node.meta.copy()
+ return foo_node
+
+
+In the custom mapper function, we construct an acc op node and return it. The node we return here takes over all the children of the original node; see `acc_normalizer.py `_.
+
+The last step is *adding a unit test* for the new acc op or mapper function. The place to add the unit test is `test_acc_tracer.py `_.
+
+* **Step 2. Add a new fx2trt converter**
+
+All the developed converters for acc ops are in `acc_op_converter.py `_. It gives good examples of how converters are added.
+
+Essentially, the converter is the mapping mechanism that maps an acc op to TensorRT layers. If we are able to find all the TensorRT layers we need, we can start adding a converter for the node using the `TensorRT APIs `_.
+
+.. code-block:: python
+
+ @tensorrt_converter(acc_ops.sigmoid)
+ def acc_ops_sigmoid(network, target, args, kwargs, name):
+ """
+ network: TensorRT network. We'll be adding layers to it.
+
+        The remaining arguments are attributes of the fx node.
+ """
+ input_val = kwargs['input']
+
+ if not isinstance(input_val, trt.tensorrt.ITensor):
+ raise RuntimeError(f'Sigmoid received input {input_val} that is not part '
+ 'of the TensorRT region!')
+
+ layer = network.add_activation(input=input_val, type=trt.ActivationType.SIGMOID)
+ layer.name = name
+ return layer.get_output(0)
+
+We need to use the ``tensorrt_converter`` decorator to register the converter. The argument of the decorator is the target of the fx node that we need to convert. In the converter, we can find the inputs to the fx node in kwargs. As in the example, the original node is ``acc_ops.sigmoid``, which only has one argument “input” in acc_ops.py. We get the input and check that it is a TensorRT tensor. After that, we add a sigmoid layer to the TensorRT network and return the output of the layer. The output we return is passed to the children of acc_ops.sigmoid by fx.Interpreter.
+
+**What if we can not find a corresponding layer in TensorRT that does the same thing as the node?**
+
+In this case, we would need to do a bit more work. TensorRT provides plugins, which serve as custom layers. *We have not implemented this feature yet. We will update once it is enabled*.
+
+The last step is adding a unit test for the new converter. Users can add the corresponding unit test in this `folder `_.
From eb8ebb368d324abb7cfe38cbf88af45649ed6a6c Mon Sep 17 00:00:00 2001
From: Wei
Date: Mon, 27 Jun 2022 15:14:16 -0700
Subject: [PATCH 02/10] Create get_started_fx_path.ipynb
---
notebooks/get_started_fx_path.ipynb | 453 ++++++++++++++++++++++++++++
1 file changed, 453 insertions(+)
create mode 100644 notebooks/get_started_fx_path.ipynb
diff --git a/notebooks/get_started_fx_path.ipynb b/notebooks/get_started_fx_path.ipynb
new file mode 100644
index 0000000000..fb88c64714
--- /dev/null
+++ b/notebooks/get_started_fx_path.ipynb
@@ -0,0 +1,453 @@
+{
+ "metadata": {
+ "dataExplorerConfig": {},
+ "bento_stylesheets": {
+ "bento/extensions/flow/main.css": true,
+ "bento/extensions/kernel_selector/main.css": true,
+ "bento/extensions/kernel_ui/main.css": true,
+ "bento/extensions/new_kernel/main.css": true,
+ "bento/extensions/system_usage/main.css": true,
+ "bento/extensions/theme/main.css": true
+ },
+ "kernelspec": {
+ "display_name": "accelerators",
+ "language": "python",
+ "name": "bento_kernel_accelerators",
+ "metadata": {
+ "kernel_name": "bento_kernel_accelerators",
+ "nightly_builds": true,
+ "fbpkg_supported": true,
+ "cinder_runtime": false,
+ "is_prebuilt": true
+ }
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3"
+ },
+ "last_server_session_id": "42b65868-6af0-4f04-bf2f-b7e2511f23dd",
+ "last_kernel_id": "a08a7dfc-0fcc-4486-a2d5-604483260888",
+ "last_base_url": "https://devgpu005.ftw6.facebook.com:8093/",
+ "last_msg_id": "3f4cd9a4-65001843cf56aec954e05889_80",
+ "outputWidgetContext": {}
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2,
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "f9189964-8f5f-4c3d-b58c-ebff076ba890",
+ "showInput": true,
+ "customInput": null,
+ "code_folding": [],
+ "hidden_ranges": []
+ },
+ "source": [
+ "Here is a benchmark example demonstrates the basic usage of `lower_to_trt` interface.\n",
+ "It shows the boosted performance of TensorRT after lowering. "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "aac0295c-e26e-45cb-b1b6-7796ee860152",
+ "showInput": false,
+ "customInput": null
+ },
+ "source": [
+ "The purpose of this example is to demonstrate the overall flow of lowering a PyTorch\n",
+ "model to TensorRT via FX with existing FX based tooling. The general lowering flow would be like:\n",
+    "1. Use a splitter to split the model if there are ops in the model that we don't want to lower to TensorRT, e.g. because they are not supported in TensorRT or because running them on other backends provides better performance.\n",
+ "2. Lower the model (or part of the model if splitter is used) to TensorRT via fx2trt.\n",
+ "If we know the model is fully supported by fx2trt then we can skip the splitter."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "ca68b029-68a6-42d6-968e-95bb7c1aae73",
+ "showInput": true,
+ "customInput": null,
+ "collapsed": false,
+ "requestMsgId": "f56944ff-ade2-4041-bdd6-3bce44b1405f",
+ "customOutput": null,
+ "executionStartTime": 1656367410991,
+ "executionStopTime": 1656367412604
+ },
+ "source": [
+ "import torch\n",
+ "import torch.fx\n",
+ "import torch.nn as nn\n",
+ "from torch_tensorrt.fx.utils import LowerPrecision\n",
+ "import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer\n",
+ "from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter, TRTModule\n",
+ "from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter"
+ ],
+ "execution_count": 1,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "8f974ab2-d187-4ffe-a09b-16cd85949be4",
+ "showInput": true,
+ "customInput": null,
+ "code_folding": [],
+ "hidden_ranges": [],
+ "collapsed": false,
+ "requestMsgId": "564359f5-ac69-4666-91e1-41b299495ed1",
+ "customOutput": null,
+ "executionStartTime": 1656367414494,
+ "executionStopTime": 1656367422756
+ },
+ "source": [
+ "class Model(nn.Module):\n",
+ " def __init__(self):\n",
+ " super().__init__()\n",
+ " self.linear = nn.Linear(10, 10)\n",
+ " self.relu = nn.ReLU()\n",
+ "\n",
+ " def forward(self, x):\n",
+ " x = self.linear(x)\n",
+ " x = self.relu(x)\n",
+ " x = torch.linalg.norm(x, ord=2, dim=1)\n",
+ " x = self.relu(x)\n",
+ " return x\n",
+ "\n",
+ "\n",
+ "inputs = [torch.randn((1, 10), device=torch.device('cuda'))]\n",
+ "model = Model().cuda().eval()"
+ ],
+ "execution_count": 2,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "0d407e92-e9e7-48aa-9c9e-1c21a9b5fd8f",
+ "showInput": false,
+ "customInput": null,
+ "code_folding": [],
+ "hidden_ranges": []
+ },
+ "source": [
+ "acc_tracer is a custom fx tracer that maps nodes whose targets are PyTorch operators\n",
+ "to acc ops."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "a1d9c8c2-8ec7-425a-8518-6f7e53ab1e67",
+ "showInput": true,
+ "customInput": null,
+ "code_folding": [],
+ "hidden_ranges": [],
+ "collapsed": false,
+ "requestMsgId": "ee2da608-5f1c-4f63-9927-544717e84e8a",
+ "customOutput": null,
+ "executionStartTime": 1656367480626,
+ "executionStopTime": 1656367482881
+ },
+ "source": [
+ "traced = acc_tracer.trace(model, inputs)"
+ ],
+ "execution_count": 3,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "246613eb-14b5-488e-9aae-35306fc99db1",
+ "showInput": false,
+ "customInput": null,
+ "code_folding": [],
+ "hidden_ranges": []
+ },
+ "source": [
+    "The splitter will split the model into several submodules. The name of each submodule will\n",
+    "be either `run_on_acc_{}` or `run_on_gpu_{}`. Submodules named `run_on_acc_{}` can\n",
+    "be fully lowered to TensorRT via fx2trt, while submodules named `run_on_gpu_{}` have\n",
+    "unsupported ops and can't be lowered by fx2trt. We can still run `run_on_gpu_{}`\n",
+    "submodules on the GPU if the ops there have a CUDA implementation.\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "1103c70a-3766-4d89-ad2f-cdcb1c3891e0",
+ "showInput": true,
+ "customInput": null,
+ "collapsed": false,
+ "requestMsgId": "feb888ea-ef9c-4577-b0c6-cf95bc1dd25e",
+ "customOutput": null,
+ "executionStartTime": 1656367487073,
+ "executionStopTime": 1656367487154
+ },
+ "source": [
+ "splitter = TRTSplitter(traced, inputs)"
+ ],
+ "execution_count": 4,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "3d65e07e-57ed-47d5-adb9-4685c69c9c6b",
+ "showInput": false,
+ "customInput": null,
+ "code_folding": [],
+ "hidden_ranges": []
+ },
+ "source": [
+    "The preview functionality allows us to see which ops are supported and which are\n",
+    "unsupported. We can optionally dump the dot graph, which colors supported ops and unsupported\n",
+    "ops differently."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "6aaed2d5-61b7-438e-a72a-63f91d0709e2",
+ "showInput": true,
+ "customInput": null,
+ "collapsed": false,
+ "requestMsgId": "2948c2f8-854b-4bc2-b399-321469da320c",
+ "customOutput": null,
+ "executionStartTime": 1656367489373,
+ "executionStopTime": 1656367489556
+ },
+ "source": [
+ "splitter.node_support_preview()"
+ ],
+ "execution_count": 5,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "\nSupported node types in the model:\nacc_ops.linear: ((), {'input': torch.float32, 'weight': torch.float32, 'bias': torch.float32})\nacc_ops.relu: ((), {'input': torch.float32})\n\nUnsupported node types in the model:\nacc_ops.linalg_norm: ((), {'input': torch.float32})\n\n"
+ ]
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": "\"\\nSupported node types in the model:\\nacc_ops.linear: ((), {'input': torch.float32, 'weight': torch.float32, 'bias': torch.float32})\\nacc_ops.relu: ((), {'input': torch.float32})\\n\\nUnsupported node types in the model:\\nacc_ops.linalg_norm: ((), {'input': torch.float32})\\n\""
+ },
+ "metadata": {
+ "bento_obj_id": "139812830161136"
+ },
+ "execution_count": 5
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "8d8035ab-869e-4096-b8e1-3539a0cfe1af",
+ "showInput": false,
+ "customInput": null
+ },
+ "source": [
+    "After the split, there are three submodules: _run_on_acc_0, _run_on_gpu_1 and _run_on_acc_2. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "80e03730-955a-4cc8-b071-7f92a2cff3df",
+ "showInput": true,
+ "customInput": null,
+ "code_folding": [],
+ "hidden_ranges": [],
+ "collapsed": false,
+ "requestMsgId": "2ca46574-7176-4699-a809-2a2e2d5ffda0",
+ "customOutput": null,
+ "executionStartTime": 1656367495077,
+ "executionStopTime": 1656367495250
+ },
+ "source": [
+ "split_mod = splitter()\n",
+ "print(split_mod.graph)"
+ ],
+ "execution_count": 6,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Got 2 acc subgraphs and 1 non-acc subgraphs\ngraph():\n %x : [#users=1] = placeholder[target=x]\n %_run_on_acc_0 : [#users=1] = call_module[target=_run_on_acc_0](args = (%x,), kwargs = {})\n %_run_on_gpu_1 : [#users=1] = call_module[target=_run_on_gpu_1](args = (%_run_on_acc_0,), kwargs = {})\n %_run_on_acc_2 : [#users=1] = call_module[target=_run_on_acc_2](args = (%_run_on_gpu_1,), kwargs = {})\n return _run_on_acc_2\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "9ce75161-978e-468e-9989-ecdbc9af0d5b",
+ "showInput": true,
+ "customInput": null,
+ "collapsed": false,
+ "requestMsgId": "0370de27-39ec-4be0-826b-9aec90df1155",
+ "customOutput": null,
+ "executionStartTime": 1656367496353,
+ "executionStopTime": 1656367496452
+ },
+ "source": [
+ "print(split_mod._run_on_acc_0.graph)\n",
+ "print(split_mod._run_on_gpu_1.graph)\n",
+ "print(split_mod._run_on_acc_2.graph)"
+ ],
+ "execution_count": 7,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "graph():\n %x : [#users=1] = placeholder[target=x]\n %linear_weight : [#users=1] = get_attr[target=linear.weight]\n %linear_bias : [#users=1] = get_attr[target=linear.bias]\n %linear_1 : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.linear](args = (), kwargs = {input: %x, weight: %linear_weight, bias: %linear_bias})\n %relu_2 : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.relu](args = (), kwargs = {input: %linear_1, inplace: False})\n return relu_2\ngraph():\n %relu_2 : [#users=1] = placeholder[target=relu_2]\n %linalg_norm_1 : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.linalg_norm](args = (), kwargs = {input: %relu_2, ord: 2, dim: 1, keepdim: False})\n return linalg_norm_1\ngraph():\n %linalg_norm_1 : [#users=1] = placeholder[target=linalg_norm_1]\n %relu_3 : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.relu](args = (), kwargs = {input: %linalg_norm_1, inplace: False})\n return relu_3\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "7a6857bc-fedd-4847-ba17-a5d114de34f3",
+ "showInput": false,
+ "customInput": null,
+ "code_folding": [],
+ "hidden_ranges": []
+ },
+ "source": [
+    "`split_mod` contains the child modules that run on TRT or on the eager GPU. We can iterate over them and transform each TRT-eligible module into a TRT engine."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "174fd2eb-a864-49cf-a204-6d24a8e2849d",
+ "showInput": true,
+ "customInput": null,
+ "collapsed": false,
+ "requestMsgId": "cf7fdfe4-e781-47c8-9a9a-85b5664c10f7",
+ "customOutput": null,
+ "executionStartTime": 1656367502837,
+ "executionStopTime": 1656367510024,
+ "code_folding": [],
+ "hidden_ranges": []
+ },
+ "source": [
+ "def get_submod_inputs(mod, submod, inputs):\n",
+ " acc_inputs = None\n",
+ "\n",
+ " def get_input(self, inputs):\n",
+ " nonlocal acc_inputs\n",
+ " acc_inputs = inputs\n",
+ "\n",
+ " handle = submod.register_forward_pre_hook(get_input)\n",
+ " mod(*inputs)\n",
+ " handle.remove()\n",
+ " return acc_inputs\n",
+ "\n",
+    "# Since the model is split into three segments, we need to lower each TRT-eligible segment.\n",
+    "# If we know the model can be fully lowered, we can skip the splitter part.\n",
+ "for name, _ in split_mod.named_children():\n",
+ " if \"_run_on_acc\" in name:\n",
+ " submod = getattr(split_mod, name)\n",
+ " # Get submodule inputs for fx2trt\n",
+ " acc_inputs = get_submod_inputs(split_mod, submod, inputs)\n",
+ "\n",
+ " # fx2trt replacement\n",
+ " interp = TRTInterpreter(\n",
+ " submod,\n",
+ " InputTensorSpec.from_tensors(acc_inputs),\n",
+ " explicit_batch_dimension=True,\n",
+ " )\n",
+ " r = interp.run(lower_precision=LowerPrecision.FP32)\n",
+ " trt_mod = TRTModule(*r)\n",
+ " setattr(split_mod, name, trt_mod)\n",
+ "\n",
+ "lowered_model_output = split_mod(*inputs)"
+ ],
+ "execution_count": 8,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "I0627 150503.073 fx2trt.py:190] Run Module elapsed time: 0:00:00.014965\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "I0627 150504.996 fx2trt.py:241] Build TRT engine elapsed time: 0:00:01.922029\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "I0627 150505.026 fx2trt.py:190] Run Module elapsed time: 0:00:00.000302\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "I0627 150509.953 fx2trt.py:241] Build TRT engine elapsed time: 0:00:04.925192\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "f1db3e1e-3a70-4735-a403-baa557b0f8a6",
+ "showInput": false,
+ "customInput": null,
+ "code_folding": [],
+ "hidden_ranges": []
+ },
+ "source": [
+    "The model can be saved with torch.save and loaded with torch.load. Then we can compare the results with eager-mode inference. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "a7c4fa0f-cac6-4959-8fa6-13b3455137d3",
+ "showInput": true,
+ "customInput": null,
+ "collapsed": false,
+ "requestMsgId": "f0c264ac-2bda-4c8e-a236-e2bd475e601e",
+ "customOutput": null,
+ "executionStartTime": 1656367515833,
+ "executionStopTime": 1656367516184
+ },
+ "source": [
+ "torch.save(split_mod, \"trt.pt\")\n",
+ "reload_trt_mod = torch.load(\"trt.pt\")\n",
+ "reload_model_output = reload_trt_mod(*inputs)\n",
+ "\n",
+ "# Make sure the results match\n",
+ "regular_model_output = model(*inputs)\n",
+ "torch.testing.assert_close(\n",
+ " reload_model_output, regular_model_output, atol=3e-3, rtol=1e-2\n",
+ ")"
+ ],
+ "execution_count": 9,
+ "outputs": []
+ }
+ ]
+}
From fcf7e01c66dc9cef8b05b246849c43ef97c87f99 Mon Sep 17 00:00:00 2001
From: Wei
Date: Mon, 27 Jun 2022 15:17:11 -0700
Subject: [PATCH 03/10] Update and rename get_started_fx_path.ipynb to
getting_started_with_fx_path_module.ipynb
---
...getting_started_with_fx_path_module.ipynb} | 20 ++++---------------
1 file changed, 4 insertions(+), 16 deletions(-)
rename notebooks/{get_started_fx_path.ipynb => getting_started_with_fx_path_module.ipynb} (96%)
diff --git a/notebooks/get_started_fx_path.ipynb b/notebooks/getting_started_with_fx_path_module.ipynb
similarity index 96%
rename from notebooks/get_started_fx_path.ipynb
rename to notebooks/getting_started_with_fx_path_module.ipynb
index fb88c64714..5ee8884790 100644
--- a/notebooks/get_started_fx_path.ipynb
+++ b/notebooks/getting_started_with_fx_path_module.ipynb
@@ -44,30 +44,18 @@
{
"cell_type": "markdown",
"metadata": {
- "originalKey": "f9189964-8f5f-4c3d-b58c-ebff076ba890",
- "showInput": true,
+ "originalKey": "aac0295c-e26e-45cb-b1b6-7796ee860152",
+ "showInput": false,
"customInput": null,
"code_folding": [],
"hidden_ranges": []
},
- "source": [
- "Here is a benchmark example demonstrates the basic usage of `lower_to_trt` interface.\n",
- "It shows the boosted performance of TensorRT after lowering. "
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "originalKey": "aac0295c-e26e-45cb-b1b6-7796ee860152",
- "showInput": false,
- "customInput": null
- },
"source": [
"The purpose of this example is to demonstrate the overall flow of lowering a PyTorch\n",
"model to TensorRT via FX with existing FX based tooling. The general lowering flow would be like:\n",
"1. Use splitter to split the model if there're ops in the model that we don't want to lower to TensorRT for some reasons like the ops are not supported in TensorRT or running them on other backends provides better performance.\n",
- "2. Lower the model (or part of the model if splitter is used) to TensorRT via fx2trt.\n",
- "If we know the model is fully supported by fx2trt then we can skip the splitter."
+    "2. Lower the model (or part of the model if the splitter is used) to TensorRT via the fx path.\n",
+    "If we know the model is fully supported by the fx path (no unsupported ops), then we can skip the splitter."
]
},
{
From 672de80e43fc20509d74be9e14d4ba7e206da287 Mon Sep 17 00:00:00 2001
From: Wei
Date: Mon, 27 Jun 2022 23:49:04 -0700
Subject: [PATCH 04/10] Create getting_started_with_fx_path_lower_to_trt.ipynb
---
...ng_started_with_fx_path_lower_to_trt.ipynb | 432 ++++++++++++++++++
1 file changed, 432 insertions(+)
create mode 100644 notebooks/getting_started_with_fx_path_lower_to_trt.ipynb
diff --git a/notebooks/getting_started_with_fx_path_lower_to_trt.ipynb b/notebooks/getting_started_with_fx_path_lower_to_trt.ipynb
new file mode 100644
index 0000000000..22bae41ac7
--- /dev/null
+++ b/notebooks/getting_started_with_fx_path_lower_to_trt.ipynb
@@ -0,0 +1,432 @@
+{
+ "metadata": {
+ "dataExplorerConfig": {},
+ "bento_stylesheets": {
+ "bento/extensions/flow/main.css": true,
+ "bento/extensions/kernel_selector/main.css": true,
+ "bento/extensions/kernel_ui/main.css": true,
+ "bento/extensions/new_kernel/main.css": true,
+ "bento/extensions/system_usage/main.css": true,
+ "bento/extensions/theme/main.css": true
+ },
+ "kernelspec": {
+ "display_name": "accelerators",
+ "language": "python",
+ "name": "bento_kernel_accelerators",
+ "metadata": {
+ "kernel_name": "bento_kernel_accelerators",
+ "nightly_builds": true,
+ "fbpkg_supported": true,
+ "cinder_runtime": false,
+ "is_prebuilt": true
+ }
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3"
+ },
+ "last_server_session_id": "c6f6ab3c-9274-41e7-8592-b1b583442e00",
+ "last_kernel_id": "fcbf3a69-76a4-4730-9b41-bcd0b24729ca",
+ "last_base_url": "https://devgpu005.ftw6.facebook.com:8093/",
+ "last_msg_id": "e28f842c-f32dde25c1b80ef7d423dfee_407",
+ "outputWidgetContext": {}
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2,
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "8ca7695d-8a19-454e-b32b-3d5c36d52faf",
+ "showInput": false,
+ "customInput": null,
+ "code_folding": [],
+ "hidden_ranges": []
+ },
+ "source": [
+ "The purpose of this example is to demostrate the onverall flow of lowering a PyTorch model\n",
+    "to TensorRT conveniently with lower.py. We integrated the transformation process, including `TRTInterpreter`, `TRTModule` and pass optimization, into the `lower_to_trt` API; users are encouraged to check the docstring of the API and tune it to meet their needs."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "7909785f-b9b4-41dd-82af-c144b879df39",
+ "showInput": true,
+ "customInput": null,
+ "collapsed": false,
+ "requestMsgId": "7db2accc-9fa4-4a1e-8142-d887f2947bcd",
+ "customOutput": null,
+ "executionStartTime": 1656395936225,
+ "executionStopTime": 1656395937851
+ },
+ "source": [
+ "import typing as t\n",
+ "from copy import deepcopy\n",
+ "from dataclasses import dataclass, field, replace\n",
+ "\n",
+ "import torch\n",
+ "import torchvision\n",
+ "from torch_tensorrt.fx.lower import lower_to_trt\n",
+ "from torch_tensorrt.fx.utils import LowerPrecision"
+ ],
+ "execution_count": 4,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "e324a1ff-1bc2-4e78-932f-33534c3ac3f5",
+ "showInput": false,
+ "customInput": null,
+ "code_folding": [],
+ "hidden_ranges": []
+ },
+ "source": [
+    "Specify the `Configuration` class used for FX path lowering and benchmarking. To extend it, add a new configuration field to this class, and modify the lowering or benchmark behavior in `run_configuration_benchmark()` correspondingly. The results are automatically stored in a `Result` dataclass. \n",
+    "`Result` is another dataclass that holds the raw essential benchmark result values, like batch size, QPS, accuracy, etc.\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "a4455135-8633-4d2d-bdd3-6435a4a9f4dd",
+ "showInput": true,
+ "customInput": null,
+ "code_folding": [],
+ "hidden_ranges": [],
+ "collapsed": false,
+ "requestMsgId": "2835fffa-cc50-479a-9080-c4f7002c0726",
+ "customOutput": null,
+ "executionStartTime": 1656398717455,
+ "executionStopTime": 1656398717662
+ },
+ "source": [
+ "@dataclass\n",
+ "class Configuration:\n",
+ " # number of inferences to run\n",
+ " batch_iter: int\n",
+ "\n",
+ " # Input batch size\n",
+ " batch_size: int\n",
+ "\n",
+ " # Friendly name of the configuration\n",
+ " name: str = \"\"\n",
+ "\n",
+ " # Whether to apply TRT lowering to the model before benchmarking\n",
+ " trt: bool = False\n",
+ "\n",
+ " # Whether to apply engine holder to the lowered model\n",
+ " jit: bool = False\n",
+ "\n",
+ " # Whether to enable FP16 mode for TRT lowering\n",
+ " fp16: bool = False\n",
+ "\n",
+ " # Relative tolerance for accuracy check after lowering. -1 means do not\n",
+ " # check accuracy.\n",
+ " accuracy_rtol: float = -1 # disable\n",
+ " \n",
+ "@dataclass\n",
+ "class Result:\n",
+ " module: torch.nn.Module = field(repr=False)\n",
+ " input: t.Any = field(repr=False)\n",
+ " conf: Configuration\n",
+ " time_sec: float\n",
+ " accuracy_res: t.Optional[bool] = None\n",
+ "\n",
+ " @property\n",
+ " def time_per_iter_ms(self) -> float:\n",
+ " return self.time_sec * 1.0e3\n",
+ "\n",
+ " @property\n",
+ " def qps(self) -> float:\n",
+ " return self.conf.batch_size / self.time_sec\n",
+ "\n",
+ " def format(self) -> str:\n",
+ " return (\n",
+ " f\"== Benchmark Result for: {self.conf}\\n\"\n",
+ " f\"BS: {self.conf.batch_size}, \"\n",
+ " f\"Time per iter: {self.time_per_iter_ms:.2f}ms, \"\n",
+ " f\"QPS: {self.qps:.2f}, \"\n",
+ " f\"Accuracy: {self.accuracy_res} (rtol={self.conf.accuracy_rtol})\"\n",
+ " )"
+ ],
+ "execution_count": 22,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "3e462cf6-d282-402d-955b-a3ecb400bf0b",
+ "showInput": true,
+ "customInput": null,
+ "code_folding": [],
+ "hidden_ranges": []
+ },
+ "source": [
+    "Run FX path lowering and benchmark the given model according to the specified benchmark configurations. The benchmark result for each configuration is printed at the end of the run. `benchmark_torch_function` is the function that actually times the fixed number of iterations of a function.\n",
+    "The FX path lowering and TensorRT engine creation are integrated into the `lower_to_trt()` API, which is defined in the `fx/lower.py` file.\n",
+    "It is worth listing it out and showing its usage: it takes in the original module, the inputs and a lowering setting, and runs the lowering workflow to turn the module into an executable TRT engine. \n",
+ "```\n",
+ "def lower_to_trt(\n",
+ " module: nn.Module,\n",
+    "    input,\n",
+ " max_batch_size: int = 2048,\n",
+ " max_workspace_size=1 << 25,\n",
+ " explicit_batch_dimension=False,\n",
+ " lower_precision=LowerPrecision.FP16,\n",
+ " verbose_log=False,\n",
+ " timing_cache_prefix=\"\",\n",
+ " save_timing_cache=False,\n",
+ " cuda_graph_batch_size=-1,\n",
+ " dynamic_batch=False,\n",
+ ") -> nn.Module:\n",
+ "``` \n",
+ "\n",
+ " Args:\n",
+ " module: Original module for lowering.\n",
+ " input: Input for module.\n",
+ " max_batch_size: Maximum batch size (must be >= 1 to be set, 0 means not set)\n",
+ " max_workspace_size: Maximum size of workspace given to TensorRT.\n",
+ " explicit_batch_dimension: Use explicit batch dimension in TensorRT if set True, otherwise use implicit batch dimension.\n",
+ " lower_precision: lower_precision config given to TRTModule.\n",
+ " verbose_log: Enable verbose log for TensorRT if set True.\n",
+ " timing_cache_prefix: Timing cache file name for timing cache used by fx2trt.\n",
+ " save_timing_cache: Update timing cache with current timing cache data if set to True.\n",
+ " cuda_graph_batch_size: Cuda graph batch size, default to be -1.\n",
+ "\n",
+ " Returns:\n",
+ " A torch.nn.Module lowered by TensorRT.\n",
+    "We tested a resnet18 network with an input size of [128, 3, 224, 224] for [batch, channel, height, width]."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "91333212-7f6d-4bde-a248-44d485e83e5e",
+ "showInput": true,
+ "customInput": null,
+ "code_folding": [],
+ "hidden_ranges": [],
+ "collapsed": false,
+ "requestMsgId": "3002935b-b95a-4a08-a57f-f7a35485af5b",
+ "customOutput": null,
+ "executionStartTime": 1656397903207,
+ "executionStopTime": 1656397964752
+ },
+ "source": [
+    "test_model = torchvision.models.resnet18(pretrained=True)\n",
+    "input = [torch.rand(128, 3, 224, 224)]\n",
+ "\n",
+ "def benchmark_torch_function(iters: int, f, *args) -> float:\n",
+    "    \"\"\"Estimates the average duration of a single inference call, in seconds.\n",
+ "\n",
+ " If the input is batched, then the estimation is for the batches inference call.\n",
+ " \"\"\"\n",
+ " with torch.inference_mode():\n",
+ " f(*args)\n",
+ " torch.cuda.synchronize()\n",
+ " start_event = torch.cuda.Event(enable_timing=True)\n",
+ " end_event = torch.cuda.Event(enable_timing=True)\n",
+ " print(\"== Start benchmark iterations\")\n",
+ " with torch.inference_mode():\n",
+ " start_event.record()\n",
+ " for _ in range(iters):\n",
+ " f(*args)\n",
+ " end_event.record()\n",
+ " torch.cuda.synchronize()\n",
+ " print(\"== End benchmark iterations\")\n",
+ " return (start_event.elapsed_time(end_event) * 1.0e-3) / iters\n",
+ "\n",
+ "\n",
+ "def run_configuration_benchmark(\n",
+ " module,\n",
+ " input,\n",
+ " conf: Configuration,\n",
+ ") -> Result:\n",
+ " print(f\"=== Running benchmark for: {conf}\", \"green\")\n",
+ " time = -1.0\n",
+ "\n",
+ " if conf.fp16:\n",
+ " module = module.half()\n",
+ " input = [i.half() for i in input]\n",
+ "\n",
+ " if not conf.trt:\n",
+ " # Run eager mode benchmark\n",
+ " time = benchmark_torch_function(conf.batch_iter, lambda: module(*input))\n",
+ " elif not conf.jit:\n",
+ " # Run lowering eager mode benchmark\n",
+ " lowered_module = lower_to_trt(\n",
+ " module,\n",
+ " input,\n",
+ " max_batch_size=conf.batch_size,\n",
+ " lower_precision=LowerPrecision.FP16 if conf.fp16 else LowerPrecision.FP32,\n",
+ " )\n",
+ " time = benchmark_torch_function(conf.batch_iter, lambda: lowered_module(*input))\n",
+ " else:\n",
+ " print(\"Lowering with JIT is not available!\", \"red\")\n",
+ "\n",
+ " result = Result(module=module, input=input, conf=conf, time_sec=time)\n",
+ " return result\n",
+ "\n",
+ "@torch.inference_mode()\n",
+ "def benchmark(\n",
+ " model,\n",
+ " inputs,\n",
+ " batch_iter: int,\n",
+ " batch_size: int,\n",
+ ") -> None:\n",
+ " model = model.cuda().eval()\n",
+ " inputs = [x.cuda() for x in inputs]\n",
+ "\n",
+ " # benchmark base configuration\n",
+ " conf = Configuration(batch_iter=batch_iter, batch_size=batch_size)\n",
+ "\n",
+ " configurations = [\n",
+ " # Baseline\n",
+ " replace(conf, name=\"CUDA Eager\", trt=False),\n",
+ " # FP32\n",
+ " replace(\n",
+ " conf,\n",
+ " name=\"TRT FP32 Eager\",\n",
+ " trt=True,\n",
+ " jit=False,\n",
+ " fp16=False,\n",
+ " accuracy_rtol=1e-3,\n",
+ " ),\n",
+ " # FP16\n",
+ " replace(\n",
+ " conf,\n",
+ " name=\"TRT FP16 Eager\",\n",
+ " trt=True,\n",
+ " jit=False,\n",
+ " fp16=True,\n",
+ " accuracy_rtol=1e-2,\n",
+ " ),\n",
+ " ]\n",
+ "\n",
+ " results = [\n",
+ " run_configuration_benchmark(deepcopy(model), inputs, conf_)\n",
+ " for conf_ in configurations\n",
+ " ]\n",
+ "\n",
+ " for res in results:\n",
+    "        print(res.format())\n",
+    "\n",
+    "\n",
+    "# Run the benchmark (defined above) on the resnet18 model.\n",
+    "benchmark(test_model, input, 50, 128)"
+ ],
+ "execution_count": 21,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "I0627 233143.380 manifold.py:1430] URL manifold://torchvision/tree/models/resnet18-f37072fd.pth was already cached in /home/wwei6/.torch/iopath_cache/manifold_cache/tree/models/resnet18-f37072fd.pth\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "=== Running benchmark for: Configuration(batch_iter=50, batch_size=128, name='CUDA Eager', trt=False, jit=False, fp16=False, accuracy_rtol=-1) green\n== Start benchmark iterations\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "== End benchmark iterations\n=== Running benchmark for: Configuration(batch_iter=50, batch_size=128, name='TRT FP32 Eager', trt=True, jit=False, fp16=False, accuracy_rtol=0.001) green\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "== Log pass before/after graph to /tmp/tmpaayayg72\n== Log pass before/after graph to /tmp/tmpdw_pq71j\n\nSupported node types in the model:\nacc_ops.conv2d: ((), {'input': torch.float32, 'weight': torch.float32})\nacc_ops.batch_norm: ((), {'input': torch.float32, 'running_mean': torch.float32, 'running_var': torch.float32, 'weight': torch.float32, 'bias': torch.float32})\nacc_ops.relu: ((), {'input': torch.float32})\nacc_ops.max_pool2d: ((), {'input': torch.float32})\nacc_ops.add: ((), {'input': torch.float32, 'other': torch.float32})\nacc_ops.adaptive_avg_pool2d: ((), {'input': torch.float32})\nacc_ops.flatten: ((), {'input': torch.float32})\nacc_ops.linear: ((), {'input': torch.float32, 'weight': torch.float32, 'bias': torch.float32})\n\nUnsupported node types in the model:\n\nGot 1 acc subgraphs and 0 non-acc subgraphs\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "I0627 233146.650 fx2trt.py:190] Run Module elapsed time: 0:00:00.244369\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "I0627 233206.570 fx2trt.py:241] Build TRT engine elapsed time: 0:00:19.918630\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "== Start benchmark iterations\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "== End benchmark iterations\n=== Running benchmark for: Configuration(batch_iter=50, batch_size=128, name='TRT FP16 Eager', trt=True, jit=False, fp16=True, accuracy_rtol=0.01) green\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "== Log pass before/after graph to /tmp/tmpnoeblgd5\n== Log pass before/after graph to /tmp/tmpyb1egsof\n\nSupported node types in the model:\nacc_ops.conv2d: ((), {'input': torch.float16, 'weight': torch.float16})\nacc_ops.batch_norm: ((), {'input': torch.float16, 'running_mean': torch.float16, 'running_var': torch.float16, 'weight': torch.float16, 'bias': torch.float16})\nacc_ops.relu: ((), {'input': torch.float16})\nacc_ops.max_pool2d: ((), {'input': torch.float16})\nacc_ops.add: ((), {'input': torch.float16, 'other': torch.float16})\nacc_ops.adaptive_avg_pool2d: ((), {'input': torch.float16})\nacc_ops.flatten: ((), {'input': torch.float16})\nacc_ops.linear: ((), {'input': torch.float16, 'weight': torch.float16, 'bias': torch.float16})\n\nUnsupported node types in the model:\n\nGot 1 acc subgraphs and 0 non-acc subgraphs\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "I0627 233208.996 fx2trt.py:190] Run Module elapsed time: 0:00:00.217076\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "I0627 233244.147 fx2trt.py:241] Build TRT engine elapsed time: 0:00:35.150950\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "== Start benchmark iterations\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "== End benchmark iterations\n== Benchmark Result for: Configuration(batch_iter=50, batch_size=128, name='CUDA Eager', trt=False, jit=False, fp16=False, accuracy_rtol=-1)\nBS: 128, Time per iter: 15.00ms, QPS: 8530.72, Accuracy: None (rtol=-1)\n== Benchmark Result for: Configuration(batch_iter=50, batch_size=128, name='TRT FP32 Eager', trt=True, jit=False, fp16=False, accuracy_rtol=0.001)\nBS: 128, Time per iter: 7.95ms, QPS: 16098.45, Accuracy: None (rtol=0.001)\n== Benchmark Result for: Configuration(batch_iter=50, batch_size=128, name='TRT FP16 Eager', trt=True, jit=False, fp16=True, accuracy_rtol=0.01)\nBS: 128, Time per iter: 4.36ms, QPS: 29365.31, Accuracy: None (rtol=0.01)\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "80bbae99-41ff-4baa-94a5-12bf0c9938f3",
+ "showInput": true,
+ "customInput": null
+ },
+ "source": [
+ ""
+ ]
+ }
+ ]
+}
From 193b29ccbf67d12466dee8bf33e162603d094d0c Mon Sep 17 00:00:00 2001
From: Wei
Date: Mon, 27 Jun 2022 23:50:55 -0700
Subject: [PATCH 05/10] Update getting_started_with_fx_path_lower_to_trt.ipynb
---
.../getting_started_with_fx_path_lower_to_trt.ipynb | 9 +--------
1 file changed, 1 insertion(+), 8 deletions(-)
diff --git a/notebooks/getting_started_with_fx_path_lower_to_trt.ipynb b/notebooks/getting_started_with_fx_path_lower_to_trt.ipynb
index 22bae41ac7..5bbe61659d 100644
--- a/notebooks/getting_started_with_fx_path_lower_to_trt.ipynb
+++ b/notebooks/getting_started_with_fx_path_lower_to_trt.ipynb
@@ -51,7 +51,7 @@
"hidden_ranges": []
},
"source": [
- "The purpose of this example is to demostrate the onverall flow of lowering a PyTorch model\n",
+    "The purpose of this example is to demonstrate the overall flow of lowering a PyTorch model\n",
"to TensorRT conveniently with lower.py. We integrated the transformation process including `TRTInterpreter`, `TRTModule`, pass optimization into the `lower_to_trt` API, users are encouraged to check the docstring of the API and tune it to meet your needs."
]
},
@@ -324,13 +324,6 @@
],
"execution_count": 21,
"outputs": [
- {
- "output_type": "stream",
- "name": "stderr",
- "text": [
- "I0627 233143.380 manifold.py:1430] URL manifold://torchvision/tree/models/resnet18-f37072fd.pth was already cached in /home/wwei6/.torch/iopath_cache/manifold_cache/tree/models/resnet18-f37072fd.pth\n"
- ]
- },
{
"output_type": "stream",
"name": "stdout",
From c00b15800d0e49c96f74e95972572565ba145a1b Mon Sep 17 00:00:00 2001
From: Wei
Date: Tue, 28 Jun 2022 00:18:35 -0700
Subject: [PATCH 06/10] Update getting_started_with_fx_path_lower_to_trt.ipynb
---
notebooks/getting_started_with_fx_path_lower_to_trt.ipynb | 1 +
1 file changed, 1 insertion(+)
diff --git a/notebooks/getting_started_with_fx_path_lower_to_trt.ipynb b/notebooks/getting_started_with_fx_path_lower_to_trt.ipynb
index 5bbe61659d..5ef957fa36 100644
--- a/notebooks/getting_started_with_fx_path_lower_to_trt.ipynb
+++ b/notebooks/getting_started_with_fx_path_lower_to_trt.ipynb
@@ -202,6 +202,7 @@
" timing_cache_prefix: Timing cache file name for timing cache used by fx2trt.\n",
" save_timing_cache: Update timing cache with current timing cache data if set to True.\n",
" cuda_graph_batch_size: Cuda graph batch size, default to be -1.\n",
+ " dynamic_batch: batch dimension (dim=0) is dynamic.\n",
"\n",
" Returns:\n",
" A torch.nn.Module lowered by TensorRT.\n",
From c19c6b903f4448aaf7cbd09163cbf1163c6ec5ba Mon Sep 17 00:00:00 2001
From: Wei Wei
Date: Tue, 28 Jun 2022 10:02:07 -0700
Subject: [PATCH 07/10] add version() in setup.py
---
py/setup.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/py/setup.py b/py/setup.py
index f195abdbfd..ceaa681016 100644
--- a/py/setup.py
+++ b/py/setup.py
@@ -143,6 +143,7 @@ def finalize_options(self):
def run(self):
if FX_ONLY:
+ gen_version_file()
develop.run(self)
else:
global CXX11_ABI
@@ -163,6 +164,7 @@ def finalize_options(self):
def run(self):
if FX_ONLY:
+ gen_version_file()
install.run(self)
else:
global CXX11_ABI
From 7db85f3111d02ccfd6aa615d663c4db61940f598 Mon Sep 17 00:00:00 2001
From: Wei Wei
Date: Tue, 28 Jun 2022 10:08:14 -0700
Subject: [PATCH 08/10] make html
---
.../classtorch__tensorrt_1_1DataType.html | 2 +-
...rch__tensorrt_1_1Device_1_1DeviceType.html | 2 +-
.../classtorch__tensorrt_1_1TensorFormat.html | 2 +-
...ensorrt_1_1ptq_1_1Int8CacheCalibrator.html | 2 +-
...ch__tensorrt_1_1ptq_1_1Int8Calibrator.html | 2 +-
...8h_1a18d295a837ac71add5578860b55e5502.html | 2 +-
...8h_1a282fd3c0b1c3a215148ae372070e1268.html | 2 +-
...8h_1a31398a6d4d27e28817afb0f0139e909e.html | 2 +-
...8h_1a35703561b26b1a9d2738ad7d58b27827.html | 2 +-
...8h_1abd1465eb38256d3f22cc1426b23d516b.html | 2 +-
...8h_1abe87b341f562fd1cf40b7672e4d759da.html | 2 +-
...8h_1ad19939408f7be171a74a89928b36eb59.html | 2 +-
...8h_1adad592a7b1b7eed529cdf6acd584c883.html | 2 +-
docs/_cpp_api/dir_cpp.html | 2 +-
docs/_cpp_api/dir_cpp_include.html | 2 +-
.../dir_cpp_include_torch_tensorrt.html | 2 +-
...8h_1a130f65408ad8cbaee060f05e8db69558.html | 2 +-
...8h_1a3fbe5d72e4fc624dbd038853079620eb.html | 8 +-
..._cpp_include_torch_tensorrt_logging.h.html | 2 +-
...e_cpp_include_torch_tensorrt_macros.h.html | 2 +-
...file_cpp_include_torch_tensorrt_ptq.h.html | 2 +-
...clude_torch_tensorrt_torch_tensorrt.h.html | 2 +-
...8h_1a0593f776f469c20469e2f729fc7861a3.html | 2 +-
...8h_1a0c012cb374addd90eb1f42eaec570650.html | 2 +-
...8h_1a56e110feaaba2c3fd44bd201fd21a76a.html | 2 +-
...8h_1a7cb50492421ea9de4e3db895819df6f2.html | 2 +-
...8h_1ac46ac0901cb97e3ae6e93b45f24e90b8.html | 2 +-
...8h_1ad2efd47b6c3689e58ccc595680579ae5.html | 2 +-
...8h_1af8f3443813315af7901903d25dd495cc.html | 2 +-
...8h_1a83ff2be7e0b80bc7434de415861dc039.html | 2 +-
...8h_1a9835f6e605dce1abf442a55b64d6dffa.html | 2 +-
...8h_1a5b405fd3bf3c8fc2e2a54cbbab979797.html | 2 +-
...8h_1a6e19490a08fb1553c9dd347a5ae79db9.html | 2 +-
...8h_1a710df824a7718b440e4bc17bf9693cef.html | 2 +-
...8h_1ac4ab8313ae72c2c899ea31548b528528.html | 2 +-
...8h_1ad1acd06eaeaffbbcf6e7ebf426891384.html | 2 +-
...8h_1ad6a4ee8ca6c8f6e5519eb1128ec7f4a1.html | 2 +-
...8h_1ae8d56472106eeef37fbe51ff7f40c9b2.html | 2 +-
docs/_cpp_api/namespace_torch_tensorrt.html | 2 +-
.../namespace_torch_tensorrt__logging.html | 2 +-
.../namespace_torch_tensorrt__ptq.html | 2 +-
...namespace_torch_tensorrt__torchscript.html | 2 +-
..._cpp_include_torch_tensorrt_logging.h.html | 2 +-
...e_cpp_include_torch_tensorrt_macros.h.html | 2 +-
...file_cpp_include_torch_tensorrt_ptq.h.html | 2 +-
...clude_torch_tensorrt_torch_tensorrt.h.html | 2 +-
.../structtorch__tensorrt_1_1Device.html | 2 +-
.../structtorch__tensorrt_1_1Input.html | 2 +-
...ensorrt_1_1torchscript_1_1CompileSpec.html | 2 +-
docs/_cpp_api/torch_tensort_cpp.html | 2 +-
docs/_cpp_api/unabridged_orphan.html | 2 +-
docs/_notebooks/CitriNet-example.html | 24 +-
docs/_notebooks/EfficientNet-example.html | 24 +-
docs/_notebooks/Hugging-Face-BERT.html | 28 +-
docs/_notebooks/Resnet50-example.html | 24 +-
docs/_notebooks/dynamic-shapes.html | 4 +-
.../getting_started_with_fx_path_module.html | 1093 +++++++++++++++++
.../getting_started_with_fx_path_module.ipynb | 469 +++++++
docs/_notebooks/lenet-getting-started.html | 16 +-
.../_notebooks/ssd-object-detection-demo.html | 24 +-
docs/_notebooks/vgg-qat.html | 44 +-
...ting_started_with_fx_path_module.ipynb.txt | 441 +++++++
.../getting_started_with_fx_path.rst.txt | 304 +++++
docs/contributors/conversion.html | 2 +-
docs/contributors/lowering.html | 2 +-
docs/contributors/partitioning.html | 2 +-
docs/contributors/phases.html | 2 +-
docs/contributors/runtime.html | 2 +-
docs/contributors/system_overview.html | 2 +-
docs/contributors/useful_links.html | 2 +-
docs/contributors/writing_converters.html | 2 +-
docs/genindex.html | 256 +---
docs/index.html | 2 +-
docs/indices/supported_ops.html | 2 +-
docs/objects.inv | Bin 24856 -> 23628 bytes
docs/py_api/logging.html | 203 +--
docs/py_api/ptq.html | 112 +-
docs/py_api/torch_tensorrt.html | 340 +----
docs/py_api/ts.html | 227 +---
docs/search.html | 2 +-
docs/searchindex.js | 2 +-
.../pytorch-sphinx-theme/docs/changelog.html | 2 +-
.../docs/configuring.html | 2 +-
.../pytorch-sphinx-theme/docs/demo/api.html | 2 +-
.../pytorch-sphinx-theme/docs/demo/demo.html | 4 +-
.../docs/demo/lists_tables.html | 2 +-
.../pytorch-sphinx-theme/docs/demo/long.html | 2 +-
.../docs/demo/structure.html | 2 +-
docs/src/pytorch-sphinx-theme/docs/index.html | 2 +-
.../pytorch-sphinx-theme/docs/installing.html | 2 +-
...creating_torchscript_module_in_python.html | 2 +-
.../getting_started_with_cpp_api.html | 2 +-
.../getting_started_with_fx_path.html | 916 ++++++++++++++
.../getting_started_with_python_api.html | 2 +-
docs/tutorials/installation.html | 2 +-
docs/tutorials/ptq.html | 2 +-
docs/tutorials/runtime.html | 2 +-
.../serving_torch_tensorrt_with_triton.html | 2 +-
docs/tutorials/torchtrtc.html | 2 +-
docs/tutorials/use_from_pytorch.html | 2 +-
docs/tutorials/using_dla.html | 2 +-
docsrc/conf.py | 2 +-
102 files changed, 3452 insertions(+), 1271 deletions(-)
create mode 100644 docs/_notebooks/getting_started_with_fx_path_module.html
create mode 100644 docs/_notebooks/getting_started_with_fx_path_module.ipynb
create mode 100644 docs/_sources/_notebooks/getting_started_with_fx_path_module.ipynb.txt
create mode 100644 docs/_sources/tutorials/getting_started_with_fx_path.rst.txt
create mode 100644 docs/tutorials/getting_started_with_fx_path.html
diff --git a/docs/_cpp_api/classtorch__tensorrt_1_1DataType.html b/docs/_cpp_api/classtorch__tensorrt_1_1DataType.html
index 7d2b43a650..a59d53413a 100644
--- a/docs/_cpp_api/classtorch__tensorrt_1_1DataType.html
+++ b/docs/_cpp_api/classtorch__tensorrt_1_1DataType.html
@@ -199,7 +199,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/classtorch__tensorrt_1_1Device_1_1DeviceType.html b/docs/_cpp_api/classtorch__tensorrt_1_1Device_1_1DeviceType.html
index 3acf6242f8..1b1e3c38fd 100644
--- a/docs/_cpp_api/classtorch__tensorrt_1_1Device_1_1DeviceType.html
+++ b/docs/_cpp_api/classtorch__tensorrt_1_1Device_1_1DeviceType.html
@@ -199,7 +199,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/classtorch__tensorrt_1_1TensorFormat.html b/docs/_cpp_api/classtorch__tensorrt_1_1TensorFormat.html
index c794f56758..cad1ec9f5b 100644
--- a/docs/_cpp_api/classtorch__tensorrt_1_1TensorFormat.html
+++ b/docs/_cpp_api/classtorch__tensorrt_1_1TensorFormat.html
@@ -199,7 +199,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/classtorch__tensorrt_1_1ptq_1_1Int8CacheCalibrator.html b/docs/_cpp_api/classtorch__tensorrt_1_1ptq_1_1Int8CacheCalibrator.html
index 25f2a76005..9bc569f705 100644
--- a/docs/_cpp_api/classtorch__tensorrt_1_1ptq_1_1Int8CacheCalibrator.html
+++ b/docs/_cpp_api/classtorch__tensorrt_1_1ptq_1_1Int8CacheCalibrator.html
@@ -199,7 +199,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/classtorch__tensorrt_1_1ptq_1_1Int8Calibrator.html b/docs/_cpp_api/classtorch__tensorrt_1_1ptq_1_1Int8Calibrator.html
index 30a841989d..f6e9e5acc3 100644
--- a/docs/_cpp_api/classtorch__tensorrt_1_1ptq_1_1Int8Calibrator.html
+++ b/docs/_cpp_api/classtorch__tensorrt_1_1ptq_1_1Int8Calibrator.html
@@ -199,7 +199,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/define_macros_8h_1a18d295a837ac71add5578860b55e5502.html b/docs/_cpp_api/define_macros_8h_1a18d295a837ac71add5578860b55e5502.html
index 068f461588..35f58d0ff8 100644
--- a/docs/_cpp_api/define_macros_8h_1a18d295a837ac71add5578860b55e5502.html
+++ b/docs/_cpp_api/define_macros_8h_1a18d295a837ac71add5578860b55e5502.html
@@ -199,7 +199,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/define_macros_8h_1a282fd3c0b1c3a215148ae372070e1268.html b/docs/_cpp_api/define_macros_8h_1a282fd3c0b1c3a215148ae372070e1268.html
index 8523b0f2ab..f85b69a8a2 100644
--- a/docs/_cpp_api/define_macros_8h_1a282fd3c0b1c3a215148ae372070e1268.html
+++ b/docs/_cpp_api/define_macros_8h_1a282fd3c0b1c3a215148ae372070e1268.html
@@ -199,7 +199,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/define_macros_8h_1a31398a6d4d27e28817afb0f0139e909e.html b/docs/_cpp_api/define_macros_8h_1a31398a6d4d27e28817afb0f0139e909e.html
index 91ae6b4bc4..3c69beaabe 100644
--- a/docs/_cpp_api/define_macros_8h_1a31398a6d4d27e28817afb0f0139e909e.html
+++ b/docs/_cpp_api/define_macros_8h_1a31398a6d4d27e28817afb0f0139e909e.html
@@ -199,7 +199,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/define_macros_8h_1a35703561b26b1a9d2738ad7d58b27827.html b/docs/_cpp_api/define_macros_8h_1a35703561b26b1a9d2738ad7d58b27827.html
index 31c9e74d99..2cf8f5bf13 100644
--- a/docs/_cpp_api/define_macros_8h_1a35703561b26b1a9d2738ad7d58b27827.html
+++ b/docs/_cpp_api/define_macros_8h_1a35703561b26b1a9d2738ad7d58b27827.html
@@ -199,7 +199,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/define_macros_8h_1abd1465eb38256d3f22cc1426b23d516b.html b/docs/_cpp_api/define_macros_8h_1abd1465eb38256d3f22cc1426b23d516b.html
index 98dfc1b64a..3d84e4fc97 100644
--- a/docs/_cpp_api/define_macros_8h_1abd1465eb38256d3f22cc1426b23d516b.html
+++ b/docs/_cpp_api/define_macros_8h_1abd1465eb38256d3f22cc1426b23d516b.html
@@ -199,7 +199,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/define_macros_8h_1abe87b341f562fd1cf40b7672e4d759da.html b/docs/_cpp_api/define_macros_8h_1abe87b341f562fd1cf40b7672e4d759da.html
index 54e1b888c6..7e1fb36098 100644
--- a/docs/_cpp_api/define_macros_8h_1abe87b341f562fd1cf40b7672e4d759da.html
+++ b/docs/_cpp_api/define_macros_8h_1abe87b341f562fd1cf40b7672e4d759da.html
@@ -199,7 +199,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/define_macros_8h_1ad19939408f7be171a74a89928b36eb59.html b/docs/_cpp_api/define_macros_8h_1ad19939408f7be171a74a89928b36eb59.html
index d0f03666bf..230b5da237 100644
--- a/docs/_cpp_api/define_macros_8h_1ad19939408f7be171a74a89928b36eb59.html
+++ b/docs/_cpp_api/define_macros_8h_1ad19939408f7be171a74a89928b36eb59.html
@@ -199,7 +199,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/define_macros_8h_1adad592a7b1b7eed529cdf6acd584c883.html b/docs/_cpp_api/define_macros_8h_1adad592a7b1b7eed529cdf6acd584c883.html
index 2e9a5ad66b..7d8daff60e 100644
--- a/docs/_cpp_api/define_macros_8h_1adad592a7b1b7eed529cdf6acd584c883.html
+++ b/docs/_cpp_api/define_macros_8h_1adad592a7b1b7eed529cdf6acd584c883.html
@@ -199,7 +199,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/dir_cpp.html b/docs/_cpp_api/dir_cpp.html
index b5a8da952c..22b6adf413 100644
--- a/docs/_cpp_api/dir_cpp.html
+++ b/docs/_cpp_api/dir_cpp.html
@@ -197,7 +197,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/dir_cpp_include.html b/docs/_cpp_api/dir_cpp_include.html
index d65425efc9..63ec5b512c 100644
--- a/docs/_cpp_api/dir_cpp_include.html
+++ b/docs/_cpp_api/dir_cpp_include.html
@@ -197,7 +197,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/dir_cpp_include_torch_tensorrt.html b/docs/_cpp_api/dir_cpp_include_torch_tensorrt.html
index f28daeb41a..6ebe38e365 100644
--- a/docs/_cpp_api/dir_cpp_include_torch_tensorrt.html
+++ b/docs/_cpp_api/dir_cpp_include_torch_tensorrt.html
@@ -197,7 +197,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/enum_logging_8h_1a130f65408ad8cbaee060f05e8db69558.html b/docs/_cpp_api/enum_logging_8h_1a130f65408ad8cbaee060f05e8db69558.html
index 193f257934..58f2ae067f 100644
--- a/docs/_cpp_api/enum_logging_8h_1a130f65408ad8cbaee060f05e8db69558.html
+++ b/docs/_cpp_api/enum_logging_8h_1a130f65408ad8cbaee060f05e8db69558.html
@@ -199,7 +199,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/enum_torch__tensorrt_8h_1a3fbe5d72e4fc624dbd038853079620eb.html b/docs/_cpp_api/enum_torch__tensorrt_8h_1a3fbe5d72e4fc624dbd038853079620eb.html
index 89b614bb83..56d0400a60 100644
--- a/docs/_cpp_api/enum_torch__tensorrt_8h_1a3fbe5d72e4fc624dbd038853079620eb.html
+++ b/docs/_cpp_api/enum_torch__tensorrt_8h_1a3fbe5d72e4fc624dbd038853079620eb.html
@@ -199,7 +199,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
@@ -365,17 +365,17 @@ Enum Documentation
-enumerator kSAFETY
+enumerator kSAFETY
-enumerator kDLA_STANDALONE
+enumerator kDLA_STANDALONE
diff --git a/docs/_cpp_api/file_cpp_include_torch_tensorrt_logging.h.html b/docs/_cpp_api/file_cpp_include_torch_tensorrt_logging.h.html
index c62df2f40d..5f927b4668 100644
--- a/docs/_cpp_api/file_cpp_include_torch_tensorrt_logging.h.html
+++ b/docs/_cpp_api/file_cpp_include_torch_tensorrt_logging.h.html
@@ -197,7 +197,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/file_cpp_include_torch_tensorrt_macros.h.html b/docs/_cpp_api/file_cpp_include_torch_tensorrt_macros.h.html
index 03d6a3d087..8a66432732 100644
--- a/docs/_cpp_api/file_cpp_include_torch_tensorrt_macros.h.html
+++ b/docs/_cpp_api/file_cpp_include_torch_tensorrt_macros.h.html
@@ -197,7 +197,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/file_cpp_include_torch_tensorrt_ptq.h.html b/docs/_cpp_api/file_cpp_include_torch_tensorrt_ptq.h.html
index f16bb9cbe1..468869ddb3 100644
--- a/docs/_cpp_api/file_cpp_include_torch_tensorrt_ptq.h.html
+++ b/docs/_cpp_api/file_cpp_include_torch_tensorrt_ptq.h.html
@@ -197,7 +197,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/file_cpp_include_torch_tensorrt_torch_tensorrt.h.html b/docs/_cpp_api/file_cpp_include_torch_tensorrt_torch_tensorrt.h.html
index 77a3f66aaa..c5baadd29b 100644
--- a/docs/_cpp_api/file_cpp_include_torch_tensorrt_torch_tensorrt.h.html
+++ b/docs/_cpp_api/file_cpp_include_torch_tensorrt_torch_tensorrt.h.html
@@ -197,7 +197,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/function_logging_8h_1a0593f776f469c20469e2f729fc7861a3.html b/docs/_cpp_api/function_logging_8h_1a0593f776f469c20469e2f729fc7861a3.html
index f4ec7cf898..8cc3af913c 100644
--- a/docs/_cpp_api/function_logging_8h_1a0593f776f469c20469e2f729fc7861a3.html
+++ b/docs/_cpp_api/function_logging_8h_1a0593f776f469c20469e2f729fc7861a3.html
@@ -199,7 +199,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/function_logging_8h_1a0c012cb374addd90eb1f42eaec570650.html b/docs/_cpp_api/function_logging_8h_1a0c012cb374addd90eb1f42eaec570650.html
index 53d6ea2b83..7dfd971c02 100644
--- a/docs/_cpp_api/function_logging_8h_1a0c012cb374addd90eb1f42eaec570650.html
+++ b/docs/_cpp_api/function_logging_8h_1a0c012cb374addd90eb1f42eaec570650.html
@@ -199,7 +199,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/function_logging_8h_1a56e110feaaba2c3fd44bd201fd21a76a.html b/docs/_cpp_api/function_logging_8h_1a56e110feaaba2c3fd44bd201fd21a76a.html
index 3aac0afd52..abc2d66097 100644
--- a/docs/_cpp_api/function_logging_8h_1a56e110feaaba2c3fd44bd201fd21a76a.html
+++ b/docs/_cpp_api/function_logging_8h_1a56e110feaaba2c3fd44bd201fd21a76a.html
@@ -199,7 +199,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/function_logging_8h_1a7cb50492421ea9de4e3db895819df6f2.html b/docs/_cpp_api/function_logging_8h_1a7cb50492421ea9de4e3db895819df6f2.html
index a17a31c20f..492bbc1bb0 100644
--- a/docs/_cpp_api/function_logging_8h_1a7cb50492421ea9de4e3db895819df6f2.html
+++ b/docs/_cpp_api/function_logging_8h_1a7cb50492421ea9de4e3db895819df6f2.html
@@ -199,7 +199,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/function_logging_8h_1ac46ac0901cb97e3ae6e93b45f24e90b8.html b/docs/_cpp_api/function_logging_8h_1ac46ac0901cb97e3ae6e93b45f24e90b8.html
index 20be71d2a1..bb598c1214 100644
--- a/docs/_cpp_api/function_logging_8h_1ac46ac0901cb97e3ae6e93b45f24e90b8.html
+++ b/docs/_cpp_api/function_logging_8h_1ac46ac0901cb97e3ae6e93b45f24e90b8.html
@@ -199,7 +199,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/function_logging_8h_1ad2efd47b6c3689e58ccc595680579ae5.html b/docs/_cpp_api/function_logging_8h_1ad2efd47b6c3689e58ccc595680579ae5.html
index 811a21aa5b..6a65f08d88 100644
--- a/docs/_cpp_api/function_logging_8h_1ad2efd47b6c3689e58ccc595680579ae5.html
+++ b/docs/_cpp_api/function_logging_8h_1ad2efd47b6c3689e58ccc595680579ae5.html
@@ -199,7 +199,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/function_logging_8h_1af8f3443813315af7901903d25dd495cc.html b/docs/_cpp_api/function_logging_8h_1af8f3443813315af7901903d25dd495cc.html
index fab20d1457..3dfa1cba90 100644
--- a/docs/_cpp_api/function_logging_8h_1af8f3443813315af7901903d25dd495cc.html
+++ b/docs/_cpp_api/function_logging_8h_1af8f3443813315af7901903d25dd495cc.html
@@ -199,7 +199,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/function_ptq_8h_1a83ff2be7e0b80bc7434de415861dc039.html b/docs/_cpp_api/function_ptq_8h_1a83ff2be7e0b80bc7434de415861dc039.html
index c6812c54b7..cc7e47407a 100644
--- a/docs/_cpp_api/function_ptq_8h_1a83ff2be7e0b80bc7434de415861dc039.html
+++ b/docs/_cpp_api/function_ptq_8h_1a83ff2be7e0b80bc7434de415861dc039.html
@@ -199,7 +199,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/function_ptq_8h_1a9835f6e605dce1abf442a55b64d6dffa.html b/docs/_cpp_api/function_ptq_8h_1a9835f6e605dce1abf442a55b64d6dffa.html
index 283ecf2907..6d31473b19 100644
--- a/docs/_cpp_api/function_ptq_8h_1a9835f6e605dce1abf442a55b64d6dffa.html
+++ b/docs/_cpp_api/function_ptq_8h_1a9835f6e605dce1abf442a55b64d6dffa.html
@@ -199,7 +199,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/function_torch__tensorrt_8h_1a5b405fd3bf3c8fc2e2a54cbbab979797.html b/docs/_cpp_api/function_torch__tensorrt_8h_1a5b405fd3bf3c8fc2e2a54cbbab979797.html
index 267423c7b2..e07b816931 100644
--- a/docs/_cpp_api/function_torch__tensorrt_8h_1a5b405fd3bf3c8fc2e2a54cbbab979797.html
+++ b/docs/_cpp_api/function_torch__tensorrt_8h_1a5b405fd3bf3c8fc2e2a54cbbab979797.html
@@ -199,7 +199,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/function_torch__tensorrt_8h_1a6e19490a08fb1553c9dd347a5ae79db9.html b/docs/_cpp_api/function_torch__tensorrt_8h_1a6e19490a08fb1553c9dd347a5ae79db9.html
index 7328eaddc6..3a4998e76f 100644
--- a/docs/_cpp_api/function_torch__tensorrt_8h_1a6e19490a08fb1553c9dd347a5ae79db9.html
+++ b/docs/_cpp_api/function_torch__tensorrt_8h_1a6e19490a08fb1553c9dd347a5ae79db9.html
@@ -199,7 +199,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/function_torch__tensorrt_8h_1a710df824a7718b440e4bc17bf9693cef.html b/docs/_cpp_api/function_torch__tensorrt_8h_1a710df824a7718b440e4bc17bf9693cef.html
index 4b4fb4b142..8c0dfb6bcd 100644
--- a/docs/_cpp_api/function_torch__tensorrt_8h_1a710df824a7718b440e4bc17bf9693cef.html
+++ b/docs/_cpp_api/function_torch__tensorrt_8h_1a710df824a7718b440e4bc17bf9693cef.html
@@ -199,7 +199,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/function_torch__tensorrt_8h_1ac4ab8313ae72c2c899ea31548b528528.html b/docs/_cpp_api/function_torch__tensorrt_8h_1ac4ab8313ae72c2c899ea31548b528528.html
index 0106231d18..f913a43ed4 100644
--- a/docs/_cpp_api/function_torch__tensorrt_8h_1ac4ab8313ae72c2c899ea31548b528528.html
+++ b/docs/_cpp_api/function_torch__tensorrt_8h_1ac4ab8313ae72c2c899ea31548b528528.html
@@ -199,7 +199,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/function_torch__tensorrt_8h_1ad1acd06eaeaffbbcf6e7ebf426891384.html b/docs/_cpp_api/function_torch__tensorrt_8h_1ad1acd06eaeaffbbcf6e7ebf426891384.html
index bd94c10b75..84fbd2033c 100644
--- a/docs/_cpp_api/function_torch__tensorrt_8h_1ad1acd06eaeaffbbcf6e7ebf426891384.html
+++ b/docs/_cpp_api/function_torch__tensorrt_8h_1ad1acd06eaeaffbbcf6e7ebf426891384.html
@@ -199,7 +199,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/function_torch__tensorrt_8h_1ad6a4ee8ca6c8f6e5519eb1128ec7f4a1.html b/docs/_cpp_api/function_torch__tensorrt_8h_1ad6a4ee8ca6c8f6e5519eb1128ec7f4a1.html
index 65f72c8326..5d3a3ea80a 100644
--- a/docs/_cpp_api/function_torch__tensorrt_8h_1ad6a4ee8ca6c8f6e5519eb1128ec7f4a1.html
+++ b/docs/_cpp_api/function_torch__tensorrt_8h_1ad6a4ee8ca6c8f6e5519eb1128ec7f4a1.html
@@ -199,7 +199,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/function_torch__tensorrt_8h_1ae8d56472106eeef37fbe51ff7f40c9b2.html b/docs/_cpp_api/function_torch__tensorrt_8h_1ae8d56472106eeef37fbe51ff7f40c9b2.html
index 7725f1ac84..d4ac6dbf90 100644
--- a/docs/_cpp_api/function_torch__tensorrt_8h_1ae8d56472106eeef37fbe51ff7f40c9b2.html
+++ b/docs/_cpp_api/function_torch__tensorrt_8h_1ae8d56472106eeef37fbe51ff7f40c9b2.html
@@ -199,7 +199,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/namespace_torch_tensorrt.html b/docs/_cpp_api/namespace_torch_tensorrt.html
index b12e1d909d..3c0b040b09 100644
--- a/docs/_cpp_api/namespace_torch_tensorrt.html
+++ b/docs/_cpp_api/namespace_torch_tensorrt.html
@@ -199,7 +199,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/namespace_torch_tensorrt__logging.html b/docs/_cpp_api/namespace_torch_tensorrt__logging.html
index 3915f9942a..2d7553fb68 100644
--- a/docs/_cpp_api/namespace_torch_tensorrt__logging.html
+++ b/docs/_cpp_api/namespace_torch_tensorrt__logging.html
@@ -199,7 +199,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/namespace_torch_tensorrt__ptq.html b/docs/_cpp_api/namespace_torch_tensorrt__ptq.html
index e780be2799..0817e41cdc 100644
--- a/docs/_cpp_api/namespace_torch_tensorrt__ptq.html
+++ b/docs/_cpp_api/namespace_torch_tensorrt__ptq.html
@@ -199,7 +199,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/namespace_torch_tensorrt__torchscript.html b/docs/_cpp_api/namespace_torch_tensorrt__torchscript.html
index 7422ce4aff..6c5b5353f9 100644
--- a/docs/_cpp_api/namespace_torch_tensorrt__torchscript.html
+++ b/docs/_cpp_api/namespace_torch_tensorrt__torchscript.html
@@ -199,7 +199,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/program_listing_file_cpp_include_torch_tensorrt_logging.h.html b/docs/_cpp_api/program_listing_file_cpp_include_torch_tensorrt_logging.h.html
index 8de4dfa497..1b4df13eaf 100644
--- a/docs/_cpp_api/program_listing_file_cpp_include_torch_tensorrt_logging.h.html
+++ b/docs/_cpp_api/program_listing_file_cpp_include_torch_tensorrt_logging.h.html
@@ -197,7 +197,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/program_listing_file_cpp_include_torch_tensorrt_macros.h.html b/docs/_cpp_api/program_listing_file_cpp_include_torch_tensorrt_macros.h.html
index fccc343fdd..9ef5d96570 100644
--- a/docs/_cpp_api/program_listing_file_cpp_include_torch_tensorrt_macros.h.html
+++ b/docs/_cpp_api/program_listing_file_cpp_include_torch_tensorrt_macros.h.html
@@ -197,7 +197,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/program_listing_file_cpp_include_torch_tensorrt_ptq.h.html b/docs/_cpp_api/program_listing_file_cpp_include_torch_tensorrt_ptq.h.html
index ff4fec239f..6381a34fca 100644
--- a/docs/_cpp_api/program_listing_file_cpp_include_torch_tensorrt_ptq.h.html
+++ b/docs/_cpp_api/program_listing_file_cpp_include_torch_tensorrt_ptq.h.html
@@ -197,7 +197,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/program_listing_file_cpp_include_torch_tensorrt_torch_tensorrt.h.html b/docs/_cpp_api/program_listing_file_cpp_include_torch_tensorrt_torch_tensorrt.h.html
index a0191faee5..300ed123f6 100644
--- a/docs/_cpp_api/program_listing_file_cpp_include_torch_tensorrt_torch_tensorrt.h.html
+++ b/docs/_cpp_api/program_listing_file_cpp_include_torch_tensorrt_torch_tensorrt.h.html
@@ -197,7 +197,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/structtorch__tensorrt_1_1Device.html b/docs/_cpp_api/structtorch__tensorrt_1_1Device.html
index 2178f4c32e..6f683302d8 100644
--- a/docs/_cpp_api/structtorch__tensorrt_1_1Device.html
+++ b/docs/_cpp_api/structtorch__tensorrt_1_1Device.html
@@ -199,7 +199,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/structtorch__tensorrt_1_1Input.html b/docs/_cpp_api/structtorch__tensorrt_1_1Input.html
index ecb48d8acb..a41f548feb 100644
--- a/docs/_cpp_api/structtorch__tensorrt_1_1Input.html
+++ b/docs/_cpp_api/structtorch__tensorrt_1_1Input.html
@@ -199,7 +199,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/structtorch__tensorrt_1_1torchscript_1_1CompileSpec.html b/docs/_cpp_api/structtorch__tensorrt_1_1torchscript_1_1CompileSpec.html
index 9e870c731d..b78b61a977 100644
--- a/docs/_cpp_api/structtorch__tensorrt_1_1torchscript_1_1CompileSpec.html
+++ b/docs/_cpp_api/structtorch__tensorrt_1_1torchscript_1_1CompileSpec.html
@@ -199,7 +199,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/torch_tensort_cpp.html b/docs/_cpp_api/torch_tensort_cpp.html
index 48d8b11c14..81e9c7f584 100644
--- a/docs/_cpp_api/torch_tensort_cpp.html
+++ b/docs/_cpp_api/torch_tensort_cpp.html
@@ -199,7 +199,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_cpp_api/unabridged_orphan.html b/docs/_cpp_api/unabridged_orphan.html
index 232e61fedf..5d243535cb 100644
--- a/docs/_cpp_api/unabridged_orphan.html
+++ b/docs/_cpp_api/unabridged_orphan.html
@@ -197,7 +197,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
diff --git a/docs/_notebooks/CitriNet-example.html b/docs/_notebooks/CitriNet-example.html
index fb1282d7cd..548bbbd8ca 100644
--- a/docs/_notebooks/CitriNet-example.html
+++ b/docs/_notebooks/CitriNet-example.html
@@ -199,7 +199,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
@@ -618,7 +618,7 @@
Torch-TensorRT Getting Started - CitriNet
@@ -640,7 +640,9 @@ Content
Conclusion
-## 1. Requirements
+
+## 1. Requirements
+
Follow the steps in README to prepare a Docker container, within which you can run this notebook. This notebook assumes that you are within a Jupyter environment in a Docker container with Torch-TensorRT installed, such as an NGC monthly release of nvcr.io/nvidia/pytorch:<yy.mm>-py3 (where yy indicates the last two digits of the calendar year and mm indicates the month in two-digit numerical form). Now that you are inside the container, the next step is to install the required dependencies.
@@ -913,7 +915,9 @@
Content
-## 2. Download Citrinet model
+
+## 2. Download Citrinet model
+
Next, we download a pretrained NeMo Citrinet model and convert it to a TorchScript module:
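A minimal sketch of that step, assuming the NVIDIA NeMo ASR API; the checkpoint name and export path are illustrative placeholders:

.. code-block:: python

    import nemo.collections.asr as nemo_asr

    # Download a pretrained Citrinet checkpoint (the model name is a placeholder).
    citrinet = nemo_asr.models.EncDecCTCModel.from_pretrained("stt_en_citrinet_512")
    citrinet.eval()

    # NeMo's export helper picks the output format from the file extension;
    # a .ts suffix is assumed here to request a TorchScript module.
    citrinet.export("citrinet.ts")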
diff --git a/docs/_notebooks/EfficientNet-example.html b/docs/_notebooks/EfficientNet-example.html
Torch-TensorRT Getting Started - EfficientNet-B0
@@ -691,16 +691,22 @@ Content container to run this notebook.
Otherwise, you can follow the steps in notebooks/README to prepare a Docker container yourself, within which you can run this demo notebook.
-## 2. EfficientNet Overview
+
+## 2. EfficientNet Overview
+
PyTorch has a model repository called timm, which is a source of high-quality implementations of computer vision models. We can get our EfficientNet model, pretrained on ImageNet, from there.
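A minimal sketch of fetching the pretrained model with timm; the dummy input uses the standard ImageNet shape and is illustrative:

.. code-block:: python

    import timm
    import torch

    # Fetch EfficientNet-B0 with ImageNet-pretrained weights from timm.
    model = timm.create_model("efficientnet_b0", pretrained=True)
    model.eval()

    # Sanity-check with a dummy ImageNet-sized batch.
    with torch.no_grad():
        out = model(torch.randn(1, 3, 224, 224))
    print(out.shape)  # torch.Size([1, 1000])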
Model Description
This model is based on the EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks paper.

-## 3. Running the model without optimizations
+
+## 3. Running the model without optimizations
+
Here we run the pretrained EfficientNet-B0 model from timm without any Torch-TensorRT optimizations, to establish a baseline for the performance comparison.
-## 5. Conclusion
+
+## 5. Conclusion
+
In this notebook, we have walked through the complete process of compiling TorchScript models with Torch-TensorRT for the EfficientNet-B0 model and tested the performance impact of the optimization. With Torch-TensorRT, we observe a speedup of 1.35x with FP32 and 3.13x with FP16 on an NVIDIA 3090 GPU. These acceleration numbers will vary from GPU to GPU (as well as from implementation to implementation, depending on the ops used), and we encourage you to try out the latest generation of data center compute cards for maximum acceleration.
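A minimal sketch of the FP16 compilation referenced above, assuming the torch_tensorrt Python API; the module name and input shape are illustrative:

.. code-block:: python

    import torch
    import torch_tensorrt

    # Compile the scripted/traced model, allowing TensorRT to select FP16 kernels.
    # Assumes `model` is the EfficientNet-B0 module from above, moved to the GPU.
    trt_model = torch_tensorrt.compile(
        model.eval().cuda(),
        inputs=[torch_tensorrt.Input((1, 3, 224, 224))],
        enabled_precisions={torch.half},  # run layers in FP16 where supported
    )

    # The compiled module is used like any other torch.nn.Module.
    out = trt_model(torch.randn(1, 3, 224, 224).cuda())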
diff --git a/docs/_notebooks/Hugging-Face-BERT.html b/docs/_notebooks/Hugging-Face-BERT.html
index 42c58b5cef..5fd0d4ca30 100644
--- a/docs/_notebooks/Hugging-Face-BERT.html
+++ b/docs/_notebooks/Hugging-Face-BERT.html
@@ -199,7 +199,7 @@
- master (1.2.0a0+ffedb78)
+ master (1.2.0a0+3a8704db)
@@ -618,7 +618,7 @@