From a1bfda68482beac9a8a44a61c9559fbe12e7950e Mon Sep 17 00:00:00 2001 From: George S <113141689+gs-olive@users.noreply.github.com> Date: Thu, 18 May 2023 12:37:32 -0700 Subject: [PATCH 1/2] fix: Reorganize Dynamo directory + backends (#1928) --- .circleci/config.yml | 22 ++++----- py/torch_tensorrt/__init__.py | 2 +- py/torch_tensorrt/_compile.py | 12 ++--- py/torch_tensorrt/dynamo/__init__.py | 2 +- .../{torch_compile => backend}/__init__.py | 10 ++-- .../{torch_compile => backend}/_defaults.py | 2 +- .../{torch_compile => backend}/_settings.py | 2 +- .../{torch_compile => backend}/backends.py | 48 +++++++++++-------- .../{torch_compile => backend}/conversion.py | 0 .../dynamo/backend/lowering/__init__.py | 7 +++ .../lowering/_decompositions.py | 0 .../lowering/_partition.py | 2 +- .../test/test_compiler_utils.py | 2 +- .../test/test_lowering.py | 0 .../test/test_partitioning.py | 2 +- .../{torch_compile => backend}/test/utils.py | 4 +- .../{torch_compile => backend}/utils.py | 2 +- py/torch_tensorrt/dynamo/test/conftest.py | 2 +- .../dynamo/test/test_dynamo_backend.py | 35 +++++++++++++- .../dynamo/torch_compile/lowering/__init__.py | 7 --- 20 files changed, 103 insertions(+), 60 deletions(-) rename py/torch_tensorrt/dynamo/{torch_compile => backend}/__init__.py (90%) rename py/torch_tensorrt/dynamo/{torch_compile => backend}/_defaults.py (83%) rename py/torch_tensorrt/dynamo/{torch_compile => backend}/_settings.py (86%) rename py/torch_tensorrt/dynamo/{torch_compile => backend}/backends.py (70%) rename py/torch_tensorrt/dynamo/{torch_compile => backend}/conversion.py (100%) create mode 100644 py/torch_tensorrt/dynamo/backend/lowering/__init__.py rename py/torch_tensorrt/dynamo/{torch_compile => backend}/lowering/_decompositions.py (100%) rename py/torch_tensorrt/dynamo/{torch_compile => backend}/lowering/_partition.py (98%) rename py/torch_tensorrt/dynamo/{torch_compile => backend}/test/test_compiler_utils.py (95%) rename py/torch_tensorrt/dynamo/{torch_compile => backend}/test/test_lowering.py (100%) rename py/torch_tensorrt/dynamo/{torch_compile => backend}/test/test_partitioning.py (97%) rename py/torch_tensorrt/dynamo/{torch_compile => backend}/test/utils.py (95%) rename py/torch_tensorrt/dynamo/{torch_compile => backend}/utils.py (98%) delete mode 100644 py/torch_tensorrt/dynamo/torch_compile/lowering/__init__.py diff --git a/.circleci/config.yml b/.circleci/config.yml index a7d799eb9d..2fef214168 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -740,15 +740,15 @@ commands: - store_artifacts: path: /tmp/testlogs - test-dynamo-torch_compile-core: - description: "Test the Dynamo torch_compile path" + test-dynamo-compile-core: + description: "Test the Dynamo compile path" steps: - run: - name: Run Dynamo torch_compile core tests + name: Run Dynamo compile core tests command: | - cd py/torch_tensorrt/dynamo/torch_compile + cd py/torch_tensorrt/dynamo/backend pushd test/ - pytest --junitxml=/tmp/artifacts/test_results/dynamo/torch_compile/test_results.xml + pytest --junitxml=/tmp/artifacts/test_results/dynamo/backend/test_results.xml popd - store_test_results: @@ -756,17 +756,17 @@ commands: - store_artifacts: path: /tmp/testlogs - test-dynamo-torch_compile: - description: "Test the Dynamo torch_compile path" + test-dynamo-compile: + description: "Test the Dynamo compile path" steps: - run: - name: Run Dynamo torch_compile E2E tests + name: Run Dynamo compile E2E tests command: | cd py/torch_tensorrt/dynamo/ pushd test/ pip3 install timm pip3 install transformers - 
pytest --junitxml=/tmp/artifacts/test_results/dynamo/torch_compile/test_results.xml --ir torch_compile + pytest --junitxml=/tmp/artifacts/test_results/dynamo/backend/test_results.xml --ir dynamo_compile popd - store_test_results: @@ -1000,8 +1000,8 @@ jobs: command: pip3 install --pre /tmp/dist/x86_64-linux/*cp39-cp39*.whl # We install torch after torch-trt because pip automatically enforces the version constraint otherwise - dump-test-env - - test-dynamo-torch_compile - - test-dynamo-torch_compile-core + - test-dynamo-compile + - test-dynamo-compile-core - test-dynamo-fx_ts package-x86_64-linux: diff --git a/py/torch_tensorrt/__init__.py b/py/torch_tensorrt/__init__.py index b8e4fd0d9d..f92b29aa86 100644 --- a/py/torch_tensorrt/__init__.py +++ b/py/torch_tensorrt/__init__.py @@ -97,7 +97,7 @@ def _find_lib(name, paths): if version.parse(torch.__version__) >= version.parse("2.dev"): from torch_tensorrt import dynamo - from torch_tensorrt.dynamo import torch_compile + from torch_tensorrt.dynamo import backend def _register_with_torch(): diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index e300669fd5..de0aeb5308 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -16,7 +16,7 @@ class _IRType(Enum): ts = 0 fx = 1 fx_ts_compat = 2 - torch_compile = 3 + dynamo_compile = 3 class _ModuleType(Enum): @@ -47,7 +47,7 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType: ir_targets_torchscript = any([ir == opt for opt in ["torchscript", "ts"]]) ir_targets_fx = ir == "fx" - ir_targets_torch_compile = ir == "torch_compile" + ir_targets_dynamo_compile = ir == "dynamo_compile" ir_targets_fx_ts_compat = ir == "fx_ts_compat" if module_is_tsable and ir_targets_torchscript: @@ -56,8 +56,8 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType: return _IRType.fx elif module_is_fxable and ir_targets_fx_ts_compat: return _IRType.fx_ts_compat - elif module_is_fxable and ir_targets_torch_compile: - return _IRType.torch_compile + elif module_is_fxable and ir_targets_dynamo_compile: + return _IRType.dynamo_compile else: if ir == "default": # Options are listed in order of preference @@ -156,8 +156,8 @@ def compile( dynamic_batch=False, **kwargs, ) - elif target_ir == _IRType.torch_compile: - return torch_tensorrt.dynamo.torch_compile( + elif target_ir == _IRType.dynamo_compile: + return torch_tensorrt.dynamo.compile( module, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs ) elif target_ir == _IRType.fx_ts_compat: diff --git a/py/torch_tensorrt/dynamo/__init__.py b/py/torch_tensorrt/dynamo/__init__.py index 26e8b7aa3e..ea1778edfe 100644 --- a/py/torch_tensorrt/dynamo/__init__.py +++ b/py/torch_tensorrt/dynamo/__init__.py @@ -1,2 +1,2 @@ from torch_tensorrt.dynamo import fx_ts_compat -from .torch_compile import compile as torch_compile +from .backend import compile diff --git a/py/torch_tensorrt/dynamo/torch_compile/__init__.py b/py/torch_tensorrt/dynamo/backend/__init__.py similarity index 90% rename from py/torch_tensorrt/dynamo/torch_compile/__init__.py rename to py/torch_tensorrt/dynamo/backend/__init__.py index 32e5567c51..eba389ecec 100644 --- a/py/torch_tensorrt/dynamo/torch_compile/__init__.py +++ b/py/torch_tensorrt/dynamo/backend/__init__.py @@ -8,10 +8,10 @@ from torch_tensorrt import EngineCapability, Device from torch_tensorrt.fx.utils import LowerPrecision -from torch_tensorrt.dynamo.torch_compile._settings import CompilationSettings -from torch_tensorrt.dynamo.torch_compile.utils import prepare_inputs, 
prepare_device -from torch_tensorrt.dynamo.torch_compile.backends import tensorrt_backend -from torch_tensorrt.dynamo.torch_compile._defaults import ( +from torch_tensorrt.dynamo.backend._settings import CompilationSettings +from torch_tensorrt.dynamo.backend.utils import prepare_inputs, prepare_device +from torch_tensorrt.dynamo.backend.backends import torch_tensorrt_backend +from torch_tensorrt.dynamo.backend._defaults import ( PRECISION, DEBUG, MAX_WORKSPACE_SIZE, @@ -121,6 +121,6 @@ def create_backend( ) return partial( - tensorrt_backend, + torch_tensorrt_backend, settings=settings, ) diff --git a/py/torch_tensorrt/dynamo/torch_compile/_defaults.py b/py/torch_tensorrt/dynamo/backend/_defaults.py similarity index 83% rename from py/torch_tensorrt/dynamo/torch_compile/_defaults.py rename to py/torch_tensorrt/dynamo/backend/_defaults.py index 48c9a26f9e..814331e158 100644 --- a/py/torch_tensorrt/dynamo/torch_compile/_defaults.py +++ b/py/torch_tensorrt/dynamo/backend/_defaults.py @@ -4,4 +4,4 @@ PRECISION = LowerPrecision.FP32 DEBUG = False MAX_WORKSPACE_SIZE = 20 << 30 -MAX_NUM_TRT_ENGINES = 200 +MAX_NUM_TRT_ENGINES = 10 diff --git a/py/torch_tensorrt/dynamo/torch_compile/_settings.py b/py/torch_tensorrt/dynamo/backend/_settings.py similarity index 86% rename from py/torch_tensorrt/dynamo/torch_compile/_settings.py rename to py/torch_tensorrt/dynamo/backend/_settings.py index 276b8742ff..7677b1bd57 100644 --- a/py/torch_tensorrt/dynamo/torch_compile/_settings.py +++ b/py/torch_tensorrt/dynamo/backend/_settings.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from torch_tensorrt.fx.utils import LowerPrecision -from torch_tensorrt.dynamo.torch_compile._defaults import ( +from torch_tensorrt.dynamo.backend._defaults import ( PRECISION, DEBUG, MAX_WORKSPACE_SIZE, diff --git a/py/torch_tensorrt/dynamo/torch_compile/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py similarity index 70% rename from py/torch_tensorrt/dynamo/torch_compile/backends.py rename to py/torch_tensorrt/dynamo/backend/backends.py index 9ceab947f0..9df3f1c686 100644 --- a/py/torch_tensorrt/dynamo/torch_compile/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -4,30 +4,42 @@ from functools import partial import torch._dynamo as td -from torch_tensorrt.dynamo.torch_compile._settings import CompilationSettings -from torch_tensorrt.dynamo.torch_compile.lowering._decompositions import ( +from torch_tensorrt.dynamo.backend._settings import CompilationSettings +from torch_tensorrt.dynamo.backend.lowering._decompositions import ( get_decompositions, ) -from torch_tensorrt.dynamo.torch_compile.lowering._partition import ( +from torch_tensorrt.dynamo.backend.lowering._partition import ( partition, get_submod_inputs, ) -from torch_tensorrt.dynamo.torch_compile.conversion import convert_module +from torch_tensorrt.dynamo.backend.conversion import convert_module from torch._dynamo.backends.common import fake_tensor_unsupported from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler -@td.register_backend(name="tensorrt") +@td.register_backend(name="torch_tensorrt") @fake_tensor_unsupported -def tensorrt_backend( - gm: torch.nn.Module, +def torch_tensorrt_backend( + gm: torch.fx.GraphModule, + sample_inputs: Sequence[torch.Tensor], + settings: CompilationSettings = CompilationSettings(), +): + DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend + + return DEFAULT_BACKEND(gm, sample_inputs, settings=settings) + + +@td.register_backend(name="aot_torch_tensorrt_aten") 
+@fake_tensor_unsupported +def aot_torch_tensorrt_aten_backend( + gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], settings: CompilationSettings = CompilationSettings(), ): custom_backend = partial( - fx_dynamo_backend, + _pretraced_backend, settings=settings, ) @@ -40,14 +52,12 @@ def tensorrt_backend( ) -@td.register_backend(name="fx_tensorrt") -@fake_tensor_unsupported -def fx_dynamo_backend( +def _pretraced_backend( gm: torch.fx.GraphModule, - example_inputs: Sequence[torch.Tensor], + sample_inputs: Sequence[torch.Tensor], settings: CompilationSettings = CompilationSettings(), ): - """Helper function to manage translation of FX module to TRT engines + """Helper function to manage translation of traced FX module to TRT engines Args: module: FX GraphModule to convert @@ -57,9 +67,9 @@ def fx_dynamo_backend( Compiled FX GraphModule """ try: - trt_compiled = compile_module( + trt_compiled = _compile_module( gm, - example_inputs, + sample_inputs, settings=settings, ) return trt_compiled @@ -72,12 +82,12 @@ def fx_dynamo_backend( return gm.forward -def compile_module( +def _compile_module( gm: torch.fx.GraphModule, - example_inputs: Sequence[torch.Tensor], + sample_inputs: Sequence[torch.Tensor], settings: CompilationSettings = CompilationSettings(), ) -> torch.fx.GraphModule: - """Compile an FX module + """Compile a traced FX module Includes: Partitioning + Conversion Phases @@ -100,7 +110,7 @@ def compile_module( # Get submodule inputs submodule_inputs = get_submod_inputs( - partitioned_module, submodule, example_inputs + partitioned_module, submodule, sample_inputs ) # Create TRT Module from submodule diff --git a/py/torch_tensorrt/dynamo/torch_compile/conversion.py b/py/torch_tensorrt/dynamo/backend/conversion.py similarity index 100% rename from py/torch_tensorrt/dynamo/torch_compile/conversion.py rename to py/torch_tensorrt/dynamo/backend/conversion.py diff --git a/py/torch_tensorrt/dynamo/backend/lowering/__init__.py b/py/torch_tensorrt/dynamo/backend/lowering/__init__.py new file mode 100644 index 0000000000..01b20cef6d --- /dev/null +++ b/py/torch_tensorrt/dynamo/backend/lowering/__init__.py @@ -0,0 +1,7 @@ +from torch_tensorrt.dynamo.backend.lowering._decompositions import ( + get_decompositions, +) +from torch_tensorrt.dynamo.backend.lowering._partition import ( + partition, + get_submod_inputs, +) diff --git a/py/torch_tensorrt/dynamo/torch_compile/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/backend/lowering/_decompositions.py similarity index 100% rename from py/torch_tensorrt/dynamo/torch_compile/lowering/_decompositions.py rename to py/torch_tensorrt/dynamo/backend/lowering/_decompositions.py diff --git a/py/torch_tensorrt/dynamo/torch_compile/lowering/_partition.py b/py/torch_tensorrt/dynamo/backend/lowering/_partition.py similarity index 98% rename from py/torch_tensorrt/dynamo/torch_compile/lowering/_partition.py rename to py/torch_tensorrt/dynamo/backend/lowering/_partition.py index 1dd38e0bd9..1885d18705 100644 --- a/py/torch_tensorrt/dynamo/torch_compile/lowering/_partition.py +++ b/py/torch_tensorrt/dynamo/backend/lowering/_partition.py @@ -2,7 +2,7 @@ import torch -from torch_tensorrt.dynamo.torch_compile._defaults import MAX_NUM_TRT_ENGINES +from torch_tensorrt.dynamo.backend._defaults import MAX_NUM_TRT_ENGINES from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner from torch.fx.passes.operator_support import OperatorSupport diff --git a/py/torch_tensorrt/dynamo/torch_compile/test/test_compiler_utils.py 
b/py/torch_tensorrt/dynamo/backend/test/test_compiler_utils.py similarity index 95% rename from py/torch_tensorrt/dynamo/torch_compile/test/test_compiler_utils.py rename to py/torch_tensorrt/dynamo/backend/test/test_compiler_utils.py index da7157c3e5..947a277ddd 100644 --- a/py/torch_tensorrt/dynamo/torch_compile/test/test_compiler_utils.py +++ b/py/torch_tensorrt/dynamo/backend/test/test_compiler_utils.py @@ -1,4 +1,4 @@ -from torch_tensorrt.dynamo.torch_compile.utils import prepare_device, prepare_inputs +from torch_tensorrt.dynamo.backend.utils import prepare_device, prepare_inputs from utils import same_output_format import torch_tensorrt import unittest diff --git a/py/torch_tensorrt/dynamo/torch_compile/test/test_lowering.py b/py/torch_tensorrt/dynamo/backend/test/test_lowering.py similarity index 100% rename from py/torch_tensorrt/dynamo/torch_compile/test/test_lowering.py rename to py/torch_tensorrt/dynamo/backend/test/test_lowering.py diff --git a/py/torch_tensorrt/dynamo/torch_compile/test/test_partitioning.py b/py/torch_tensorrt/dynamo/backend/test/test_partitioning.py similarity index 97% rename from py/torch_tensorrt/dynamo/torch_compile/test/test_partitioning.py rename to py/torch_tensorrt/dynamo/backend/test/test_partitioning.py index b068f9c413..fccdd3c32e 100644 --- a/py/torch_tensorrt/dynamo/torch_compile/test/test_partitioning.py +++ b/py/torch_tensorrt/dynamo/backend/test/test_partitioning.py @@ -1,4 +1,4 @@ -from torch_tensorrt.dynamo.torch_compile.lowering import partition +from torch_tensorrt.dynamo.backend.lowering import partition from torch.testing._internal.common_utils import run_tests, TestCase import torch from copy import deepcopy diff --git a/py/torch_tensorrt/dynamo/torch_compile/test/utils.py b/py/torch_tensorrt/dynamo/backend/test/utils.py similarity index 95% rename from py/torch_tensorrt/dynamo/torch_compile/test/utils.py rename to py/torch_tensorrt/dynamo/backend/test/utils.py index bdcbbfcc4a..466a600db8 100644 --- a/py/torch_tensorrt/dynamo/torch_compile/test/utils.py +++ b/py/torch_tensorrt/dynamo/backend/test/utils.py @@ -2,10 +2,10 @@ from functools import partial from typing import List, Sequence import torch -from torch_tensorrt.dynamo.torch_compile.lowering._decompositions import ( +from torch_tensorrt.dynamo.backend.lowering._decompositions import ( get_decompositions, ) -from torch_tensorrt.dynamo.torch_compile.lowering._partition import ( +from torch_tensorrt.dynamo.backend.lowering._partition import ( partition, ) diff --git a/py/torch_tensorrt/dynamo/torch_compile/utils.py b/py/torch_tensorrt/dynamo/backend/utils.py similarity index 98% rename from py/torch_tensorrt/dynamo/torch_compile/utils.py rename to py/torch_tensorrt/dynamo/backend/utils.py index ba76536338..e6e22d5f96 100644 --- a/py/torch_tensorrt/dynamo/torch_compile/utils.py +++ b/py/torch_tensorrt/dynamo/backend/utils.py @@ -45,7 +45,7 @@ def prepare_inputs( else: raise ValueError( - f"Invalid input type {type(inputs)} encountered in the torch_compile input parsing. " + f"Invalid input type {type(inputs)} encountered in the dynamo_compile input parsing. 
" + "Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}" ) diff --git a/py/torch_tensorrt/dynamo/test/conftest.py b/py/torch_tensorrt/dynamo/test/conftest.py index 98be643435..7218d5335b 100644 --- a/py/torch_tensorrt/dynamo/test/conftest.py +++ b/py/torch_tensorrt/dynamo/test/conftest.py @@ -9,7 +9,7 @@ def pytest_addoption(parser): type=str, required=True, help="IR to compile with", - choices=["torch_compile", "fx_ts_compat"], + choices=["dynamo_compile", "fx_ts_compat"], ) diff --git a/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py b/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py index 4852f033bd..531d0cc317 100644 --- a/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py +++ b/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py @@ -24,6 +24,7 @@ def test_resnet18(ir): "device": torchtrt.Device("cuda:0"), "enabled_precisions": {torch.float}, "ir": ir, + "max_num_trt_engines": 200, } trt_mod = torchtrt.compile(model, **compile_spec) @@ -33,6 +34,12 @@ def test_resnet18(ir): f"Resnet50 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) + # Clean up model env + torch._dynamo.reset() + + with torch.no_grad(): + torch.cuda.empty_cache() + @pytest.mark.unit def test_mobilenet_v2(ir): @@ -48,6 +55,7 @@ def test_mobilenet_v2(ir): "device": torchtrt.Device("cuda:0"), "enabled_precisions": {torch.float}, "ir": ir, + "max_num_trt_engines": 200, } trt_mod = torchtrt.compile(model, **compile_spec) @@ -57,6 +65,12 @@ def test_mobilenet_v2(ir): f"Mobilenet v2 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) + # Clean up model env + torch._dynamo.reset() + + with torch.no_grad(): + torch.cuda.empty_cache() + @pytest.mark.unit def test_efficientnet_b0(ir): @@ -72,6 +86,7 @@ def test_efficientnet_b0(ir): "device": torchtrt.Device("cuda:0"), "enabled_precisions": {torch.float}, "ir": ir, + "max_num_trt_engines": 200, } trt_mod = torchtrt.compile(model, **compile_spec) @@ -81,6 +96,12 @@ def test_efficientnet_b0(ir): f"EfficientNet-B0 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) + # Clean up model env + torch._dynamo.reset() + + with torch.no_grad(): + torch.cuda.empty_cache() + @pytest.mark.unit def test_bert_base_uncased(ir): @@ -104,8 +125,8 @@ def test_bert_base_uncased(ir): "device": torchtrt.Device("cuda:0"), "enabled_precisions": {torch.float}, "truncate_long_and_double": True, - "debug": True, "ir": ir, + "max_num_trt_engines": 200, } trt_mod = torchtrt.compile(model, **compile_spec) @@ -119,6 +140,12 @@ def test_bert_base_uncased(ir): f"HF BERT base-uncased TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) + # Clean up model env + torch._dynamo.reset() + + with torch.no_grad(): + torch.cuda.empty_cache() + @pytest.mark.unit def test_resnet18_half(ir): @@ -142,3 +169,9 @@ def test_resnet18_half(ir): cos_sim > COSINE_THRESHOLD, f"Resnet50 Half TRT outputs don't match with the original model. 
Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) + + # Clean up model env + torch._dynamo.reset() + + with torch.no_grad(): + torch.cuda.empty_cache() diff --git a/py/torch_tensorrt/dynamo/torch_compile/lowering/__init__.py b/py/torch_tensorrt/dynamo/torch_compile/lowering/__init__.py deleted file mode 100644 index e0a41df755..0000000000 --- a/py/torch_tensorrt/dynamo/torch_compile/lowering/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from torch_tensorrt.dynamo.torch_compile.lowering._decompositions import ( - get_decompositions, -) -from torch_tensorrt.dynamo.torch_compile.lowering._partition import ( - partition, - get_submod_inputs, -) From a831029682e16845600eacb793e16a3ee0982605 Mon Sep 17 00:00:00 2001 From: George S <113141689+gs-olive@users.noreply.github.com> Date: Thu, 18 May 2023 18:24:52 -0700 Subject: [PATCH 2/2] fix: Improve partitioning + lowering systems in `torch.compile` path (#1879) --- py/torch_tensorrt/dynamo/backend/__init__.py | 16 +- py/torch_tensorrt/dynamo/backend/_defaults.py | 2 +- py/torch_tensorrt/dynamo/backend/_settings.py | 8 +- py/torch_tensorrt/dynamo/backend/backends.py | 5 +- .../backend/lowering/_decompositions.py | 15 ++ .../dynamo/backend/lowering/_partition.py | 151 ++++++++++++++---- .../dynamo/backend/test/test_lowering.py | 106 +++++++++--- .../dynamo/backend/test/test_partitioning.py | 66 +++++++- .../dynamo/backend/test/utils.py | 94 ++++++++++- .../dynamo/test/test_dynamo_backend.py | 4 - 10 files changed, 391 insertions(+), 76 deletions(-) diff --git a/py/torch_tensorrt/dynamo/backend/__init__.py b/py/torch_tensorrt/dynamo/backend/__init__.py index eba389ecec..0846dec144 100644 --- a/py/torch_tensorrt/dynamo/backend/__init__.py +++ b/py/torch_tensorrt/dynamo/backend/__init__.py @@ -4,7 +4,7 @@ import torch_tensorrt from functools import partial -from typing import Any +from typing import Any, Sequence from torch_tensorrt import EngineCapability, Device from torch_tensorrt.fx.utils import LowerPrecision @@ -15,7 +15,7 @@ PRECISION, DEBUG, MAX_WORKSPACE_SIZE, - MAX_NUM_TRT_ENGINES, + MIN_BLOCK_SIZE, ) @@ -41,7 +41,7 @@ def compile( calibrator=None, truncate_long_and_double=False, require_full_compilation=False, - min_block_size=3, + min_block_size=MIN_BLOCK_SIZE, torch_executed_ops=[], torch_executed_modules=[], **kwargs, @@ -50,7 +50,7 @@ def compile( logger.warn( "The Dynamo backend is an experimental feature, for which only the " + "following arguments are supported: " - + "{enabled_precisions, debug, workspace_size, max_num_trt_engines}" + + "{enabled_precisions, debug, workspace_size, min_block_size, torch_executed_ops}" ) if not isinstance(inputs, collections.abc.Sequence): @@ -80,6 +80,8 @@ def compile( precision=lower_precision, debug=debug, workspace_size=workspace_size, + min_block_size=min_block_size, + torch_executed_ops=torch_executed_ops, **kwargs, ) @@ -100,7 +102,8 @@ def create_backend( precision: LowerPrecision = PRECISION, debug: bool = DEBUG, workspace_size: int = MAX_WORKSPACE_SIZE, - max_num_trt_engines: int = MAX_NUM_TRT_ENGINES, + min_block_size: int = MIN_BLOCK_SIZE, + torch_executed_ops: Sequence[str] = set(), **kwargs, ): """Create torch.compile backend given specified arguments @@ -117,7 +120,8 @@ def create_backend( debug=debug, precision=precision, workspace_size=workspace_size, - max_num_trt_engines=max_num_trt_engines, + min_block_size=min_block_size, + torch_executed_ops=torch_executed_ops, ) return partial( diff --git a/py/torch_tensorrt/dynamo/backend/_defaults.py 
b/py/torch_tensorrt/dynamo/backend/_defaults.py index 814331e158..b1ee62dfa3 100644 --- a/py/torch_tensorrt/dynamo/backend/_defaults.py +++ b/py/torch_tensorrt/dynamo/backend/_defaults.py @@ -4,4 +4,4 @@ PRECISION = LowerPrecision.FP32 DEBUG = False MAX_WORKSPACE_SIZE = 20 << 30 -MAX_NUM_TRT_ENGINES = 10 +MIN_BLOCK_SIZE = 5 diff --git a/py/torch_tensorrt/dynamo/backend/_settings.py b/py/torch_tensorrt/dynamo/backend/_settings.py index 7677b1bd57..8c1a807343 100644 --- a/py/torch_tensorrt/dynamo/backend/_settings.py +++ b/py/torch_tensorrt/dynamo/backend/_settings.py @@ -1,11 +1,12 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import Sequence from torch_tensorrt.fx.utils import LowerPrecision from torch_tensorrt.dynamo.backend._defaults import ( PRECISION, DEBUG, MAX_WORKSPACE_SIZE, - MAX_NUM_TRT_ENGINES, + MIN_BLOCK_SIZE, ) @@ -14,4 +15,5 @@ class CompilationSettings: precision: LowerPrecision = PRECISION debug: bool = DEBUG workspace_size: int = MAX_WORKSPACE_SIZE - max_num_trt_engines: int = MAX_NUM_TRT_ENGINES + min_block_size: int = MIN_BLOCK_SIZE + torch_executed_ops: Sequence[str] = field(default_factory=set) diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index 9df3f1c686..962cbe8eba 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -100,7 +100,10 @@ def _compile_module( """ # Partition module into components that can be TRT-accelerated partitioned_module = partition( - gm, verbose=settings.debug, max_num_trt_engines=settings.max_num_trt_engines + gm, + verbose=settings.debug, + min_block_size=settings.min_block_size, + torch_executed_ops=settings.torch_executed_ops, ) # Iterate over all components that can be accelerated diff --git a/py/torch_tensorrt/dynamo/backend/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/backend/lowering/_decompositions.py index 7aff1a79d1..d0bd5ed3b8 100644 --- a/py/torch_tensorrt/dynamo/backend/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/backend/lowering/_decompositions.py @@ -41,5 +41,20 @@ def inplace_op(*args, **kwargs): replace_inplace_op(aten.scatter_reduce_, aten.scatter_reduce) +@register_decomposition(aten.std, registry=DECOMPOSITIONS) +def std_replacement(*args, **kwargs) -> torch.Tensor: + return torch.sqrt(torch.var(*args, **kwargs)) + + +@register_decomposition(aten.rsqrt, registry=DECOMPOSITIONS) +def rsqrt_replacement(*args, **kwargs) -> torch.Tensor: + return torch.reciprocal(torch.sqrt(*args, **kwargs)) + + +@register_decomposition(aten.alias, registry=DECOMPOSITIONS) +def alias_replacement(x: torch.Tensor) -> torch.Tensor: + return x + + def get_decompositions(): return DECOMPOSITIONS diff --git a/py/torch_tensorrt/dynamo/backend/lowering/_partition.py b/py/torch_tensorrt/dynamo/backend/lowering/_partition.py index 1885d18705..b4d1b18db9 100644 --- a/py/torch_tensorrt/dynamo/backend/lowering/_partition.py +++ b/py/torch_tensorrt/dynamo/backend/lowering/_partition.py @@ -1,62 +1,159 @@ -from typing import Dict, Optional, Sequence +import logging +from typing import Dict, List, Optional, Sequence import torch -from torch_tensorrt.dynamo.backend._defaults import MAX_NUM_TRT_ENGINES -from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner +from torch_tensorrt.dynamo.backend._defaults import MIN_BLOCK_SIZE +from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition +from torch.fx.graph_module import GraphModule +from 
torch.fx.node import _get_qualified_name
 from torch.fx.passes.operator_support import OperatorSupport
 
 from torch_tensorrt.fx.converter_registry import CONVERTERS
 
 
+logger = logging.getLogger(__name__)
+
+
+class TRTPartitioner(CapabilityBasedPartitioner):
+    """Partitioner to split an FX graph into subgraphs based on operator support
+
+    Args:
+        graph_module: FX GraphModule to partition
+        operator_support: OperatorSupport class describing allowed operators
+        non_compute_ops: Operators which are not considered computational (e.g. getattr)
+        allowed_single_node_partition_ops: Nodes which can be included in single-node partitions.
+            Generally useful for module-level exclusion ops which are intensive despite being single functions
+        min_block_size: Minimum number of computational operators per block
+    Returns:
+        torch.fx.GraphModule
+    """
+
+    def __init__(
+        self,
+        graph_module: GraphModule,
+        operator_support: OperatorSupport,
+        *,
+        non_compute_ops: Optional[Sequence[str]] = None,
+        allowed_single_node_partition_ops: Optional[Sequence[str]] = None,
+        min_block_size=MIN_BLOCK_SIZE,
+    ) -> None:
+        super().__init__(
+            graph_module,
+            operator_support,
+            allows_single_node_partition=True,
+            non_compute_ops=non_compute_ops,
+            allowed_single_node_partition_ops=allowed_single_node_partition_ops,
+        )
+
+        self.min_block_size = min_block_size
+
+    def propose_partitions(self) -> List[Partition]:
+        # Propose partitions using the default, then refine the results
+        initial_proposed_partitions = super().propose_partitions()
+        partitions = {i: part for i, part in enumerate(initial_proposed_partitions)}
+
+        # For each partition, determine whether or not the number of computational operators
+        # exceeds the threshold, and if not, remove that partition
+        partitions_to_remove = {}
+        for id, partition in partitions.items():
+            default_non_compute_ops = {"torch.ops.aten.view", "_operator.getitem"}
+            non_compute_ops = default_non_compute_ops.union(set(self.non_compute_ops))
+            exempted_partition = False
+
+            compute_node_count = 0
+            for node in partition.nodes:
+                # Partitions are exempted from min_block_size if they contain an allowed single-node op
+                if (
+                    node.op == "call_function"
+                    and _get_qualified_name(node.target)
+                    in self.allowed_single_node_partition_ops
+                ):
+                    exempted_partition = True
+                    break
+                elif (
+                    node.op == "call_function"
+                    and _get_qualified_name(node.target) not in non_compute_ops
+                ):
+                    compute_node_count += 1
+
+            if compute_node_count < self.min_block_size and not exempted_partition:
+                partitions_to_remove[id] = compute_node_count
+
+        # Remove any nodes violating the criteria specified by the user
+        for id, count in partitions_to_remove.items():
+            logger.debug(
+                f"Removing partition which has {count} < {self.min_block_size} computational operators"
+            )
+            del partitions[id]
+
+        return [partitions[k] for k in sorted(partitions.keys())]
+
+    def partition_and_fuse(self) -> GraphModule:
+        partitions = self.propose_partitions()
+        fused_gm = self.fuse_partitions(partitions)
+        return fused_gm
+
+
 class TorchTensorRTOperatorSupport(OperatorSupport):
     """Class to determine whether operators within a module are supported"""
 
-    def __init__(self, support_dict=None):
+    def __init__(self, support_dict=None, torch_executed_ops=set()):
         super().__init__(support_dict)
 
         # Initialize sets of supported/unsupported operators
         self.supported_operators = set()
         self.unsupported_operators = set()
+        self.torch_executed_ops = torch_executed_ops
 
     def is_node_supported(
         self, submodules: Dict[str, torch.nn.Module], node: torch.fx.Node
     ) -> bool:
-        if node.target in CONVERTERS.keys():
-            # If node is a proper computational node, store the operator
+        node_name = (
+            _get_qualified_name(node.target)
+            if not isinstance(node.target, str)
+            else node.target
+        )
+
+        if (
+            node.target in CONVERTERS.keys()
+            and node_name not in self.torch_executed_ops
+        ):
+            # If node is a proper, supported computational node, store the operator
             if not node.is_impure():
-                node_name = node._pretty_print_target(node.target)
                 self.supported_operators.add(node_name)
             return True
         else:
             if not node.is_impure():
-                node_name = node._pretty_print_target(node.target)
                 self.unsupported_operators.add(node_name)
             return False
 
     def print_support_overview(self, num_trt_blocks: Optional[int] = None):
         if num_trt_blocks is not None:
-            print(f"\nNumber of TensorRT-Accelerated Subgraphs: {num_trt_blocks}")
+            logger.debug(
+                f"\nNumber of TensorRT-Accelerated Engines Generated: {num_trt_blocks}"
+            )
 
-        print("\nSupported Nodes:")
+        logger.debug("\nSupported Nodes:")
         for node_name in self.supported_operators:
-            print("-", node_name)
+            logger.debug(f"- {node_name}")
 
         if len(self.unsupported_operators) != 0:
-            print("\nUnsupported Nodes:")
+            logger.debug("\nUnsupported or Excluded Nodes:")
             for node_name in self.unsupported_operators:
-                print("-", node_name)
-            print("\n")
+                logger.debug(f"- {node_name}")
+            logger.debug("\n")
         else:
-            print("\nAll Nodes Supported\n")
+            logger.debug("\nAll Nodes Supported\n")
 
 
 def partition(
     gm: torch.fx.GraphModule,
     verbose: bool = True,
-    max_num_trt_engines: int = MAX_NUM_TRT_ENGINES,
+    min_block_size: int = MIN_BLOCK_SIZE,
+    torch_executed_ops: Sequence[str] = set(),
 ) -> torch.fx.GraphModule:
     """Partition an FX GraphModule with aten ops into TRT engines
     Partitioning is based on converter operator support
@@ -64,29 +161,21 @@ def partition(
     Args:
         gm: FX GraphModule to partition
         verbose: Bool representing whether to print operator support
-        max_num_trt_engines: Maximum number of allowed TRT engines in partitioning
+        min_block_size: Minimum number of operators per TRT-Engine Block
+        torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
     Returns:
         torch.fx.GraphModule
     """
-    supported_ops = TorchTensorRTOperatorSupport()
-    partitioner = CapabilityBasedPartitioner(gm, supported_ops)
+    supported_ops = TorchTensorRTOperatorSupport(torch_executed_ops=torch_executed_ops)
+    partitioner = TRTPartitioner(gm, supported_ops, min_block_size=min_block_size)
 
-    # Determine partitions, and raise error if the degree of partitioning
-    # exceeds a specified threshold
+    # Determine partitions based on user specifications and operator support
+    # Then, fuse partitions and display overview of supported/unsupported operators
     partitions = partitioner.propose_partitions()
-    num_blocks = len(partitions)
-    if num_blocks > max_num_trt_engines:
-        raise AssertionError(
-            f"The graph module has {num_blocks} TRT Engines which is larger than the "
-            + f"threshold={max_num_trt_engines}. Falling back to non-TRT module."
- ) - - # Fuse partitions and display overview of supported/unsupported operators fused_graph = partitioner.fuse_partitions(partitions) - num_blocks = len(partitions) if verbose: - supported_ops.print_support_overview(num_blocks) + supported_ops.print_support_overview(len(partitions)) return fused_graph diff --git a/py/torch_tensorrt/dynamo/backend/test/test_lowering.py b/py/torch_tensorrt/dynamo/backend/test/test_lowering.py index d14acb815b..6b7651957f 100644 --- a/py/torch_tensorrt/dynamo/backend/test/test_lowering.py +++ b/py/torch_tensorrt/dynamo/backend/test/test_lowering.py @@ -1,12 +1,12 @@ from functools import partial -from utils import fx_dynamo_testing_backend +from utils import lower_graph_testing from torch.testing._internal.common_utils import run_tests, TestCase import torch class TestLowering(TestCase): def test_lowering_inplace_op(self): - class FullySupported(torch.nn.Module): + class InPlace(torch.nn.Module): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -18,35 +18,95 @@ def forward(self, x, y): # Operations expected to be included in the traced graph after decompositions expected_ops = {torch.ops.aten.add.Tensor, torch.ops.aten.relu.default} - # Trace module and set up custom backend to track intermediate graphs - fx_graph = torch.fx.symbolic_trace(FullySupported()) - partitioned_graphs = [] - custom_backend = partial( - fx_dynamo_testing_backend, - store_intermediate_graphs=partitioned_graphs, - ) - - # Invoke compilation - compiled_graph = torch.compile(fx_graph, backend=custom_backend) - compiled_graph( + inputs = [ torch.rand( 5, - ).cuda(), + ), torch.rand( 5, - ).cuda(), + ), + ] + + fx_graph = torch.fx.symbolic_trace(InPlace()) + _, expected_ops_unseen = lower_graph_testing( + fx_graph, inputs, expected_ops=expected_ops, min_block_size=2 ) - # Iterate over intermediate graphs, attempt to match nodes - for fx_module in partitioned_graphs: - for _, submodule in fx_module.named_children(): - for node in submodule.graph.nodes: + self.assertEquals( + len(expected_ops_unseen), + 0, + f"The following expected ops were not encountered: {expected_ops_unseen}", + ) + + def test_lowering_alias_replacement(self): + class Alias(torch.nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) - if node.op == "call_function" and node.target in expected_ops: - expected_ops.remove(node.target) + def forward(self, x): + y = torch.ops.aten.alias.default(x) + return y + + # Operations expected to be removed in the traced graph after decompositions + unexpected_ops = {torch.ops.aten.alias.default} + + inputs = [ + torch.rand( + 5, + ), + ] + + fx_graph = torch.fx.symbolic_trace(Alias()) + unexpected_ops_seen, _ = lower_graph_testing( + fx_graph, inputs, unexpected_ops=unexpected_ops, min_block_size=1 + ) + + self.assertEquals( + len(unexpected_ops_seen), + 0, + f"The following unexpected ops were encountered: {unexpected_ops_seen}", + ) + + def test_lowering_rsqrt(self): + class Rsqrt(torch.nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def forward(self, x): + y = torch.ops.aten.rsqrt.default(x) + return y + + # Operations expected to be removed in the traced graph after decompositions + expected_ops = {torch.ops.aten.sqrt.default, torch.ops.aten.reciprocal.default} + unexpected_ops = {torch.ops.aten.rsqrt.default} + + inputs = [ + torch.randint( + 1, + 10, + (5,), + ), + ] + + fx_graph = torch.fx.symbolic_trace(Rsqrt()) + unexpected_ops_seen, 
expected_ops_unseen = lower_graph_testing( + fx_graph, + inputs, + expected_ops=expected_ops, + unexpected_ops=unexpected_ops, + min_block_size=1, + ) + + self.assertEquals( + len(unexpected_ops_seen), + 0, + f"The following unexpected ops were encountered: {unexpected_ops_seen}", + ) - self.assertEqual( - len(expected_ops), 0, "All operators should have been decomposed" + self.assertEquals( + len(expected_ops_unseen), + 0, + f"The following expected ops were not encountered: {expected_ops_unseen}", ) diff --git a/py/torch_tensorrt/dynamo/backend/test/test_partitioning.py b/py/torch_tensorrt/dynamo/backend/test/test_partitioning.py index fccdd3c32e..fb5430b384 100644 --- a/py/torch_tensorrt/dynamo/backend/test/test_partitioning.py +++ b/py/torch_tensorrt/dynamo/backend/test/test_partitioning.py @@ -1,5 +1,6 @@ from torch_tensorrt.dynamo.backend.lowering import partition from torch.testing._internal.common_utils import run_tests, TestCase +from utils import lower_graph_testing import torch from copy import deepcopy import numpy as np @@ -16,7 +17,7 @@ def forward(self, x, y): fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp()) partitioned_graph = partition(deepcopy(fx_graph)) - self.assertEqual( + self.assertEquals( len(list(partitioned_graph.named_children())), 0, "Single operators should not be segmented", @@ -35,8 +36,8 @@ def forward(self, x, y): return pow_ fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp()) - partitioned_graph = partition(deepcopy(fx_graph)) - self.assertEqual( + partitioned_graph = partition(deepcopy(fx_graph), min_block_size=2) + self.assertEquals( len(list(partitioned_graph.named_children())), 1, "All operators are supported, there should be one segment", @@ -56,13 +57,68 @@ def forward(self, x, y): return pow_ fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp()) - partitioned_graph = partition(deepcopy(fx_graph)) - self.assertEqual( + partitioned_graph = partition(deepcopy(fx_graph), min_block_size=2) + self.assertEquals( len(list(partitioned_graph.named_children())), 2, "Unsupported operators interleave supported ones, expected 2 segments", ) + def test_partition_partially_supported_with_torch_executed_ops(self): + class PartiallySupportedMultiOp(torch.nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def forward(self, x, y): + sum_1 = torch.ops.aten.add.Tensor(x, y) + sum_2 = torch.ops.aten.add.Tensor(x, sum_1) + sum_ = torch.ops.aten.add.Tensor(sum_1, sum_2) + relu_ = torch.ops.aten.relu.default(sum_) + pow_ = torch.ops.aten.pow.Tensor_Scalar(relu_, 2) + return pow_ + + unexpected_ops = {torch.ops.aten.add.Tensor} + + inputs = [ + torch.randint( + 1, + 10, + (5,), + ), + torch.randint( + 1, + 10, + (5,), + ), + ] + + fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp()) + (unexpected_ops_seen, _, partitioned_graphs,) = lower_graph_testing( + fx_graph, + inputs, + unexpected_ops=unexpected_ops, + min_block_size=2, + torch_executed_ops={"torch.ops.aten.add.Tensor"}, + testing_partitioning=True, + ) + + self.assertEquals( + len(unexpected_ops_seen), + 0, + f"The following unexpected ops were encountered: {unexpected_ops_seen}", + ) + + self.assertEquals( + len(partitioned_graphs), + 1, + "Without control flow breaks, there should only be a single graph", + ) + self.assertEquals( + len(list(partitioned_graphs[0].named_children())), + 1, + "Certain operators are set to run in Torch, expected 1 segment", + ) + if __name__ == "__main__": run_tests() diff --git 
a/py/torch_tensorrt/dynamo/backend/test/utils.py b/py/torch_tensorrt/dynamo/backend/test/utils.py
index 466a600db8..d59b710faf 100644
--- a/py/torch_tensorrt/dynamo/backend/test/utils.py
+++ b/py/torch_tensorrt/dynamo/backend/test/utils.py
@@ -1,6 +1,6 @@
 from copy import deepcopy
 from functools import partial
-from typing import List, Sequence
+from typing import Any, List, Sequence, Set
 import torch
 from torch_tensorrt.dynamo.backend.lowering._decompositions import (
     get_decompositions,
@@ -20,11 +20,15 @@ def fx_dynamo_testing_backend(
     sample_inputs: Sequence[torch.Tensor],
     *,
     store_intermediate_graphs: List,
+    min_block_size: int = 3,
+    torch_executed_ops: Sequence[str] = set(),
 ):
     """Helper Dynamo backend exclusively for testing"""
     custom_backend = partial(
         compile_module_testing,
         store_intermediate_graphs=store_intermediate_graphs,
+        min_block_size=min_block_size,
+        torch_executed_ops=torch_executed_ops,
     )
 
     # Invoke AOTAutograd to translate operators to aten
@@ -41,9 +45,13 @@ def compile_module_testing(
     example_inputs: Sequence[torch.Tensor],
     *,
     store_intermediate_graphs: List,
+    min_block_size: int = 3,
+    torch_executed_ops: Sequence[str] = set(),
 ) -> torch.fx.GraphModule:
     """Helper compiler exclusively for testing"""
-    partitioned_module = partition(gm)
+    partitioned_module = partition(
+        gm, min_block_size=min_block_size, torch_executed_ops=torch_executed_ops
+    )
 
     # Store intermediate graph from partitioned module
     store_intermediate_graphs.append(deepcopy(partitioned_module))
@@ -52,6 +60,18 @@ def compile_module_testing(
 
 
 def same_output_format(trt_output, torch_output, enforce_tensor_type=True):
+    """Determines whether two objects containing Tensors have the same format
+
+    ((Tensor, Tensor), Tensor) and (Tensor, (Tensor, Tensor)) do not
+    have the same format, for example.
+
+    Args:
+        trt_output: TensorRT output
+        torch_output: Torch output
+        enforce_tensor_type: Whether to enforce Tensor type equivalence
+    Returns:
+        bool: True if the outputs have the same format
+    """
     # For each encountered collection type, ensure the torch and trt outputs agree
     # on type and size, checking recursively through all member elements.
     if isinstance(trt_output, tuple):
@@ -92,3 +112,73 @@ def same_output_format(trt_output, torch_output, enforce_tensor_type=True):
         return type(trt_output) is type(torch_output)
     else:
         return True
+
+
+def lower_graph_testing(
+    fx_graph: torch.fx.GraphModule,
+    inputs: Any,
+    *,
+    expected_ops: Set = set(),
+    unexpected_ops: Set = set(),
+    min_block_size: int = 3,
+    torch_executed_ops: Sequence[str] = set(),
+    testing_partitioning: bool = False,
+):
+    """Helper function to assist with graph lowering for testing of Dynamo torch_compile
+
+    Args:
+        fx_graph: Graph to lower
+        inputs: Input values to the FX graph
+        expected_ops: Operations to be expected in the lowered graph
+        unexpected_ops: Operations not to be expected in the lowered graph
+        min_block_size: Minimum number of operators per TRT-Engine Block
+        torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
+        testing_partitioning: Whether partitioning is being tested (to analyze only TRT-supported ops)
+    Returns:
+        If testing_partitioning:
+            Set, Set, List[torch.fx.GraphModule]: Unexpected ops seen, expected ops unseen, and the list of partitioned graphs
+        Else:
+            Set, Set: unexpected ops seen and expected ops unseen (If the run was successful, both sets should be empty)
+    """
+    # Trace module and set up custom backend to track intermediate graphs
+    partitioned_graphs = []
+    custom_backend = partial(
+        fx_dynamo_testing_backend,
+        store_intermediate_graphs=partitioned_graphs,
+        min_block_size=min_block_size,
+        torch_executed_ops=torch_executed_ops,
+    )
+
+    # Invoke compilation
+    compiled_graph = torch.compile(fx_graph, backend=custom_backend)
+    compiled_graph(*inputs)
+
+    unexpected_ops_seen = set()
+    expected_ops_seen = set()
+
+    def classify_node(node: torch.fx.Node):
+        if node.target in unexpected_ops:
+            unexpected_ops_seen.add(node.target)
+        elif node.target in expected_ops:
+            expected_ops_seen.add(node.target)
+
+    # Iterate over intermediate graphs, attempt to match nodes
+    # If an unexpected or expected op is encountered, register it
+    for fx_module in partitioned_graphs:
+        # For each function call in the set of graph nodes, classify the node
+        for top_level_node in fx_module.graph.nodes:
+            if top_level_node.op == "call_function" and not testing_partitioning:
+                classify_node(top_level_node)
+            elif top_level_node.op == "call_module":
+                for node in fx_module.get_submodule(top_level_node.target).graph.nodes:
+                    classify_node(node)
+
+    # Return unexpected ops seen and expected ops unseen
+    # If the run was successful, both sets should be empty
+    expected_ops_unseen = expected_ops.difference(expected_ops_seen)
+
+    if testing_partitioning:
+        return unexpected_ops_seen, expected_ops_unseen, partitioned_graphs
+
+    else:
+        return unexpected_ops_seen, expected_ops_unseen
diff --git a/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py b/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py
index 531d0cc317..b86817df56 100644
--- a/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py
+++ b/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py
@@ -24,7 +24,6 @@ def test_resnet18(ir):
         "device": torchtrt.Device("cuda:0"),
         "enabled_precisions": {torch.float},
         "ir": ir,
-        "max_num_trt_engines": 200,
     }
 
     trt_mod = torchtrt.compile(model, **compile_spec)
@@ -55,7 +54,6 @@ def test_mobilenet_v2(ir):
         "device": torchtrt.Device("cuda:0"),
         "enabled_precisions": {torch.float},
         "ir": ir,
-        "max_num_trt_engines": 200,
     }
 
     trt_mod = torchtrt.compile(model, **compile_spec)
@@ -86,7 +84,6 @@ def test_efficientnet_b0(ir):
         "device": 
torchtrt.Device("cuda:0"), "enabled_precisions": {torch.float}, "ir": ir, - "max_num_trt_engines": 200, } trt_mod = torchtrt.compile(model, **compile_spec) @@ -126,7 +123,6 @@ def test_bert_base_uncased(ir): "enabled_precisions": {torch.float}, "truncate_long_and_double": True, "ir": ir, - "max_num_trt_engines": 200, } trt_mod = torchtrt.compile(model, **compile_spec)