From 63d2be8bf8f30a14e8e0b3f063bb9cb81624f03b Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Fri, 5 May 2023 10:04:05 -0700 Subject: [PATCH 1/2] feat: Add example usage scripts for dynamo path - Add sample scripts covering resnet18, transformers, and custom examples showcasing the `torch_tensorrt.dynamo.compile` path, which can compile models with data-dependent control flow and other such restrictions which can make other compilation methods more difficult - Cover different customizeable features allowed in the new backend - Make scripts Sphinx-Gallery compatible Python files --- .gitignore | 3 +- docsrc/conf.py | 7 ++ docsrc/index.rst | 29 ++++-- docsrc/requirements.txt | 1 + examples/dynamo/README.rst | 12 +++ .../dynamo/dynamo_compile_advanced_usage.py | 83 +++++++++++++++++ .../dynamo/dynamo_compile_resnet_example.py | 82 +++++++++++++++++ .../dynamo_compile_transformers_example.py | 92 +++++++++++++++++++ 8 files changed, 299 insertions(+), 10 deletions(-) create mode 100644 examples/dynamo/README.rst create mode 100644 examples/dynamo/dynamo_compile_advanced_usage.py create mode 100644 examples/dynamo/dynamo_compile_resnet_example.py create mode 100644 examples/dynamo/dynamo_compile_transformers_example.py diff --git a/.gitignore b/.gitignore index 929998c3d3..2899f2fb19 100644 --- a/.gitignore +++ b/.gitignore @@ -32,6 +32,7 @@ docsrc/_build docsrc/_notebooks docsrc/_cpp_api docsrc/_tmp +docsrc/tutorials/_rendered_examples *.so __pycache__ *.egg-info @@ -66,4 +67,4 @@ bazel-tensorrt *.cache *cifar-10-batches-py* bazel-project -build/ \ No newline at end of file +build/ diff --git a/docsrc/conf.py b/docsrc/conf.py index 9f9af43dd4..f11f6a5050 100644 --- a/docsrc/conf.py +++ b/docsrc/conf.py @@ -47,6 +47,7 @@ "sphinx.ext.coverage", "sphinx.ext.mathjax", "sphinx.ext.viewcode", + "sphinx_gallery.gen_gallery", ] napoleon_use_ivar = True @@ -79,6 +80,12 @@ # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ["_static"] +# sphinx-gallery configuration +sphinx_gallery_conf = { + "examples_dirs": "../examples/dynamo", + "gallery_dirs": "tutorials/_rendered_examples/", +} + # Setup the breathe extension breathe_projects = {"Torch-TensorRT": "./_tmp/xml"} breathe_default_project = "Torch-TensorRT" diff --git a/docsrc/index.rst b/docsrc/index.rst index e5da81d2a5..f9fa39aa3a 100644 --- a/docsrc/index.rst +++ b/docsrc/index.rst @@ -36,30 +36,41 @@ Getting Started getting_started/getting_started_with_windows -Tutorials +User Guide ------------ * :ref:`creating_a_ts_mod` * :ref:`getting_started_with_fx` * :ref:`ptq` * :ref:`runtime` -* :ref:`serving_torch_tensorrt_with_triton` * :ref:`use_from_pytorch` * :ref:`using_dla` + +.. toctree:: + :caption: User Guide + :maxdepth: 1 + :hidden: + + user_guide/creating_torchscript_module_in_python + user_guide/getting_started_with_fx_path + user_guide/ptq + user_guide/runtime + user_guide/use_from_pytorch + user_guide/using_dla + +Tutorials +------------ +* :ref:`serving_torch_tensorrt_with_triton` * :ref:`notebooks` +* :ref:`dynamo_compile` .. toctree:: :caption: Tutorials - :maxdepth: 1 + :maxdepth: 3 :hidden: - tutorials/creating_torchscript_module_in_python - tutorials/getting_started_with_fx_path - tutorials/ptq - tutorials/runtime tutorials/serving_torch_tensorrt_with_triton - tutorials/use_from_pytorch - tutorials/using_dla tutorials/notebooks + tutorials/_rendered_examples/index Python API Documenation ------------------------ diff --git a/docsrc/requirements.txt b/docsrc/requirements.txt index ccbe311f0f..ac75bf5632 100644 --- a/docsrc/requirements.txt +++ b/docsrc/requirements.txt @@ -1,4 +1,5 @@ sphinx==4.5.0 +sphinx-gallery==0.13.0 breathe==4.33.1 exhale==0.3.1 -e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme diff --git a/examples/dynamo/README.rst b/examples/dynamo/README.rst new file mode 100644 index 0000000000..7d407f73d0 --- /dev/null +++ b/examples/dynamo/README.rst @@ -0,0 +1,12 @@ +.. _dynamo_compile: + +Dynamo Compile Examples +================ + +This document contains examples of usage of the `torch_tensorrt.dynamo.compile` API which integrates with `torch.compile` functionality + +Overview of Available Scripts +----------------------------------------------- +- `dynamo_compile_resnet_example.py <./dynamo_compile_resnet_example.html>`_: Example showcasing compilation of ResNet model +- `dynamo_compile_transformers_example.py <./dynamo_compile_transformers_example.html>`_: Example showcasing compilation of transformer-based model +- `dynamo_compile_advanced_usage.py <./dynamo_compile_advanced_usage.html>`_: Advanced usage including making a custom backend to use directly with the `torch.compile` API diff --git a/examples/dynamo/dynamo_compile_advanced_usage.py b/examples/dynamo/dynamo_compile_advanced_usage.py new file mode 100644 index 0000000000..4e1e67f816 --- /dev/null +++ b/examples/dynamo/dynamo_compile_advanced_usage.py @@ -0,0 +1,83 @@ +""" +Dynamo Compile Advanced Usage +========================= + +This interactive script is intended as an overview of the process by which `torch_tensorrt.dynamo.compile` works, and how it integrates with the new `torch.compile` API.""" + +# %% +# Imports and Model Definition +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +import torch +from torch_tensorrt.dynamo.backend import create_backend +from torch_tensorrt.fx.lower_setting import LowerPrecision + +# %% + +# We begin by defining a model +class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.relu = torch.nn.ReLU() + + def forward(self, x: torch.Tensor, y: torch.Tensor): + x_out = self.relu(x) + y_out = self.relu(y) + x_y_out = x_out + y_out + return torch.mean(x_y_out) + + +# %% +# Compilation with `torch.compile` Using Default Settings +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +# Define sample float inputs and initialize model +sample_inputs = [torch.rand((5, 7)).cuda(), torch.rand((5, 7)).cuda()] +model = Model().eval().cuda() + +# %% + +# Next, we compile the model using torch.compile +# For the default settings, we can simply call torch.compile +# with the backend "tensorrt", and run the model on an +# input to cause compilation, as so: +optimized_model = torch.compile(model, backend="tensorrt") +optimized_model(*sample_inputs) + +# %% +# Compilation with `torch.compile` Using Custom Settings +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +# Define sample half inputs and initialize model +sample_inputs_half = [ + torch.rand((5, 7)).half().cuda(), + torch.rand((5, 7)).half().cuda(), +] +model_half = Model().eval().cuda() + +# %% + +# If we want to customize certain options in the backend, +# but still use the torch.compile call directly, we can call the +# convenience/helper function create_backend to create a custom backend +# which has been pre-populated with certain keys +custom_backend = create_backend( + lower_precision=LowerPrecision.FP16, + debug=True, + min_block_size=2, + torch_executed_ops={}, +) + +# Run the model on an input to cause compilation, as so: +optimized_model_custom = torch.compile(model_half, backend=custom_backend) +optimized_model_custom(*sample_inputs_half) + +# %% +# Cleanup +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +# Finally, we use Torch utilities to clean up the workspace +torch._dynamo.reset() + +with torch.no_grad(): + torch.cuda.empty_cache() diff --git a/examples/dynamo/dynamo_compile_resnet_example.py b/examples/dynamo/dynamo_compile_resnet_example.py new file mode 100644 index 0000000000..737d188c54 --- /dev/null +++ b/examples/dynamo/dynamo_compile_resnet_example.py @@ -0,0 +1,82 @@ +""" +Dynamo Compile ResNet Example +========================= + +This interactive script is intended as a sample of the `torch_tensorrt.dynamo.compile` workflow on a ResNet model.""" + +# %% +# Imports and Model Definition +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +import torch +import torch_tensorrt +import torchvision.models as models + +# %% + +# Initialize model with half precision and sample inputs +model = models.resnet18(pretrained=True).half().eval().to("cuda") +inputs = [torch.randn((1, 3, 224, 224)).to("cuda").half()] + +# %% +# Optional Input Arguments to `torch_tensorrt.dynamo.compile` +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +# Enabled precision for TensorRT optimization +enabled_precisions = {torch.half} + +# Whether to print verbose logs +debug = True + +# Workspace size for TensorRT +workspace_size = 20 << 30 + +# Maximum number of TRT Engines +# (Lower value allows more graph segmentation) +min_block_size = 3 + +# Operations to Run in Torch, regardless of converter support +torch_executed_ops = {} + +# %% +# Compilation with `torch_tensorrt.dynamo.compile` +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +# Build and compile the model with torch.compile, using Torch-TensorRT backend +optimized_model = torch_tensorrt.dynamo.compile( + model, + inputs, + enabled_precisions=enabled_precisions, + debug=debug, + workspace_size=workspace_size, + min_block_size=min_block_size, + torch_executed_ops=torch_executed_ops, +) + +# %% +# Equivalently, we could have run the above via the convenience frontend, as so: +# `torch_tensorrt.compile(model, ir="dynamo_compile", inputs=inputs, ...)` + +# %% +# Inference +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +# Does not cause recompilation (same batch size as input) +new_inputs = [torch.randn((1, 3, 224, 224)).half().to("cuda")] +new_outputs = optimized_model(*new_inputs) + +# %% + +# Does cause recompilation (new batch size) +new_batch_size_inputs = [torch.randn((8, 3, 224, 224)).half().to("cuda")] +new_batch_size_outputs = optimized_model(*new_batch_size_inputs) + +# %% +# Cleanup +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +# Finally, we use Torch utilities to clean up the workspace +torch._dynamo.reset() + +with torch.no_grad(): + torch.cuda.empty_cache() diff --git a/examples/dynamo/dynamo_compile_transformers_example.py b/examples/dynamo/dynamo_compile_transformers_example.py new file mode 100644 index 0000000000..bcff5f946b --- /dev/null +++ b/examples/dynamo/dynamo_compile_transformers_example.py @@ -0,0 +1,92 @@ +""" +Dynamo Compile Transformers Example +========================= + +This interactive script is intended as a sample of the `torch_tensorrt.dynamo.compile` workflow on a transformer-based model.""" + +# %% +# Imports and Model Definition +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +import torch +import torch_tensorrt +from transformers import BertModel + +# %% + +# Initialize model with float precision and sample inputs +model = BertModel.from_pretrained("bert-base-uncased").eval().to("cuda") +inputs = [ + torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"), + torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"), +] + + +# %% +# Optional Input Arguments to `torch_tensorrt.dynamo.compile` +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +# Enabled precision for TensorRT optimization +enabled_precisions = {torch.float} + +# Whether to print verbose logs +debug = True + +# Workspace size for TensorRT +workspace_size = 20 << 30 + +# Maximum number of TRT Engines +# (Lower value allows more graph segmentation) +min_block_size = 3 + +# Operations to Run in Torch, regardless of converter support +torch_executed_ops = {} + +# %% +# Compilation with `torch_tensorrt.dynamo.compile` +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +# Build and compile the model with torch.compile, using tensorrt backend +optimized_model = torch_tensorrt.dynamo.compile( + model, + inputs, + enabled_precisions=enabled_precisions, + debug=debug, + workspace_size=workspace_size, + min_block_size=min_block_size, + torch_executed_ops=torch_executed_ops, +) + +# %% +# Equivalently, we could have run the above via the convenience frontend, as so: +# `torch_tensorrt.compile(model, ir="dynamo_compile", inputs=inputs, ...)` + +# %% +# Inference +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +# Does not cause recompilation (same batch size as input) +new_inputs = [ + torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"), + torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"), +] +new_outputs = optimized_model(*new_inputs) + +# %% + +# Does cause recompilation (new batch size) +new_inputs = [ + torch.randint(0, 2, (4, 14), dtype=torch.int32).to("cuda"), + torch.randint(0, 2, (4, 14), dtype=torch.int32).to("cuda"), +] +new_outputs = optimized_model(*new_inputs) + +# %% +# Cleanup +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +# Finally, we use Torch utilities to clean up the workspace +torch._dynamo.reset() + +with torch.no_grad(): + torch.cuda.empty_cache() From cd6542e7e5bf1c4ba5fa2644a8f9e08a41122527 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Fri, 26 May 2023 16:35:35 -0700 Subject: [PATCH 2/2] fix: Update `index.rst` - Show individual links in sidebar --- docsrc/index.rst | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docsrc/index.rst b/docsrc/index.rst index f9fa39aa3a..e12c9e6f83 100644 --- a/docsrc/index.rst +++ b/docsrc/index.rst @@ -70,7 +70,9 @@ Tutorials tutorials/serving_torch_tensorrt_with_triton tutorials/notebooks - tutorials/_rendered_examples/index + tutorials/_rendered_examples/dynamo_compile_resnet_example + tutorials/_rendered_examples/dynamo_compile_transformers_example + tutorials/_rendered_examples/dynamo_compile_advanced_usage Python API Documenation ------------------------