From 61e716eadc6f8b3129cbe57749723c1118d647ee Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Wed, 12 Jul 2023 13:41:08 -0700 Subject: [PATCH 1/4] feat: Add `_to_copy`, `operator.get` and `clone` - Add ATen converters for key operators in the pipeline of multiple models - Add robust testing and patch issues in interpreter - Add evaluator and casting utilities to the converter utils --- .../dynamo/conversion/_TRTInterpreter.py | 10 ++- .../dynamo/conversion/aten_ops_converters.py | 84 +++++++++++++++++-- .../dynamo/conversion/converter_registry.py | 2 +- .../dynamo/conversion/converter_utils.py | 20 ++++- .../dynamo/conversion/impl/__init__.py | 2 + .../dynamo/conversion/impl/cast.py | 23 +++++ .../conversion/impl/elementwise/base.py | 11 ++- .../dynamo/conversion/impl/evaluators.py | 40 +++++++++ tests/py/dynamo/converters/harness.py | 44 +++++++--- tests/py/dynamo/converters/test_casts.py | 59 +++++++++++++ tests/py/dynamo/converters/test_evaluators.py | 67 +++++++++++++++ 11 files changed, 334 insertions(+), 28 deletions(-) create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/cast.py create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/evaluators.py create mode 100644 tests/py/dynamo/converters/test_casts.py create mode 100644 tests/py/dynamo/converters/test_evaluators.py diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 5338f36876..29485a919b 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -27,6 +27,10 @@ ] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER") +class UnsupportedOperatorException(RuntimeError): + pass + + class TRTInterpreterResult(NamedTuple): engine: Any input_names: Sequence[str] @@ -301,7 +305,7 @@ def call_module( converter = CONVERTERS.get(self._cur_node) if not converter: - raise RuntimeError( + raise UnsupportedOperatorException( f"Conversion of module of type {submod_type} not currently supported!" ) @@ -312,7 +316,7 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any: # TODO: Why is this stateful? We should be able to take in the inputs converter = CONVERTERS.get(self._cur_node) if not converter: - raise RuntimeError( + raise UnsupportedOperatorException( f"Conversion of function {torch.typename(target)} not currently supported!" ) @@ -324,7 +328,7 @@ def call_method(self, target: str, args: Any, kwargs: Any) -> Any: converter = CONVERTERS.get(self._cur_node) if not converter: - raise RuntimeError( + raise UnsupportedOperatorException( f"Conversion of method {target} not currently supported!" 
) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 0ef7266624..8d1d9b7ecf 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -1,6 +1,8 @@ import logging +import operator from typing import Any, Dict, Optional, Sequence, Tuple, Union +import tensorrt as trt import torch from torch.fx.node import Argument, Node, Target from torch_tensorrt.dynamo._SourceIR import SourceIR @@ -12,8 +14,6 @@ from torch_tensorrt.fx.converters import acc_ops_converters from torch_tensorrt.fx.types import TRTNetwork, TRTTensor -import tensorrt as trt - from .converter_registry import dynamo_tensorrt_converter _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -76,13 +76,13 @@ def aten_ops_div( kwargs_new["input"].dtype == trt.int8 or kwargs_new["input"].dtype == trt.int32 ): kwargs_new["input"] = cast_trt_tensor( - network, kwargs_new["input"], trt.float32, name + network, kwargs_new["input"], trt.float32, name, target ) elif isinstance(args[1], TRTTensor) and ( kwargs_new["other"].dtype == trt.int8 or kwargs_new["other"].dtype == trt.int32 ): kwargs_new["other"] = cast_trt_tensor( - network, kwargs_new["other"], trt.float32, name + network, kwargs_new["other"], trt.float32, name, target ) rounding_mode = kwargs.get("rounding_mode") if rounding_mode is None: @@ -101,7 +101,7 @@ def aten_ops_div( ) -def embedding_param_validator(embedding_node: Node): +def embedding_param_validator(embedding_node: Node) -> bool: scale_grad_by_freq = args_bounds_check(embedding_node.args, 3) sparse = args_bounds_check(embedding_node.args, 4) @@ -365,3 +365,77 @@ def aten_ops_permute( args[0], args[1], ) + + +def to_copy_dtype_validator(to_copy_node: Node) -> bool: + allowed_casts = {torch.float, torch.int32, torch.bool, torch.int8, torch.float16} + + # Validate input node has convertible kwargs + if "dtype" in to_copy_node.kwargs: + if to_copy_node.kwargs["dtype"] in allowed_casts: + return True + else: + _LOGGER.debug( + f"_to_copy converter rejected node {to_copy_node} with dtype {to_copy_node.kwargs['dtype']}" + ) + return False + else: + _LOGGER.debug( + f"_to_copy converter rejected node {to_copy_node} with kwargs {to_copy_node.kwargs}" + ) + return False + + +@dynamo_tensorrt_converter( + torch.ops.aten._to_copy.default, capability_validator=to_copy_dtype_validator +) +def aten_ops_to_copy_dtype( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.cast.to_copy( + network, + target, + SourceIR.ATEN, + name, + args[0], + kwargs["dtype"], + ) + + +@dynamo_tensorrt_converter(operator.getitem) +def operator_getitem( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.evaluators.getitem( + network, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.clone.default) +def aten_ops_clone( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.evaluators.clone( + network, + target, + SourceIR.ATEN, + name, + args[0], + ) diff --git a/py/torch_tensorrt/dynamo/conversion/converter_registry.py 
b/py/torch_tensorrt/dynamo/conversion/converter_registry.py index 493773bbde..b09bf61418 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_registry.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_registry.py @@ -66,7 +66,7 @@ def dynamo_tensorrt_converter( enabled: bool = True, capability_validator: Optional[Callable[[Node], bool]] = None, priority: ConverterPriority = ConverterPriority.STANDARD, -) -> Callable[[Any], Any]: +) -> Callable[[Any], Union[TRTTensor, Sequence[TRTTensor]]]: """Decorator for Dynamo TensorRT Converter Registers the decorated function in the DYNAMO_ATEN_CONVERTERS registry diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index 3d32b25f63..bbb57b9da7 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -1,15 +1,18 @@ import logging import re -from typing import List +from typing import List, Optional import tensorrt as trt import torch +from torch.fx.node import Target, _get_qualified_name from torch_tensorrt.fx.converters.converter_utils import ( Frameworks, unified_dtype_converter, ) from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor +from .._SourceIR import SourceIR + _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -71,6 +74,8 @@ def cast_trt_tensor( input_val: TRTTensor, dtype: TRTDataType, name: str, + target: Target = "", + source_ir: Optional[SourceIR] = None, ) -> TRTTensor: """ Given a TRT Tensor, convert that Tensor to the specified dtype @@ -78,17 +83,26 @@ def cast_trt_tensor( Args: network (TRTNetwork): A TensorRT network input_val (TRTTensor): A TRT Tensor to cast to a new data type - dtype (TRTDataType): The TRTDataType to cast the input Tensor to + dtype (TRTDataType, torch.dtype, np.dtype): The data type to cast the input Tensor to name (str): Name of the calling layer + target (Target): Target of calling node + source_ir (SourceIR): SourceIR of calling converter Returns: A TensorRT ITensor which has been casted to the specified dtype """ trt_dtype = unified_dtype_converter(dtype, Frameworks.TRT) if input_val.dtype != trt_dtype: + source_ir = source_ir if source_ir is not None else SourceIR.UNKNOWN + target_name = ( + f"{source_ir}_ops{'.' + target if target else ''}" + if (isinstance(target, str)) + else f"{source_ir}_ops.{_get_qualified_name(target)}" + ) + identity_layer = network.add_identity(input_val) identity_layer.set_output_type(0, trt_dtype) - identity_layer.name = f"Cast ITensor {input_val.name} from {input_val.dtype} to {trt_dtype} - {name}" + identity_layer.name = f"Cast ITensor {input_val.name} from {input_val.dtype} to {trt_dtype} -{name}-[{target_name}]-[{name}]" return identity_layer.get_output(0) else: return input_val diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py index b402240b84..611dc630fa 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py @@ -2,9 +2,11 @@ from . 
import ( activation, + cast, condition, elementwise, embedding, + evaluators, matmul, normalization, permutation, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/cast.py b/py/torch_tensorrt/dynamo/conversion/impl/cast.py new file mode 100644 index 0000000000..68899de766 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/cast.py @@ -0,0 +1,23 @@ +from typing import Optional + +from torch.fx.node import Target +from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor +from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor + + +def to_copy( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + dtype: TRTDataType, +) -> TRTTensor: + if not isinstance(input, TRTTensor): + raise RuntimeError( + f"to_copy received input {input} that is not a TensorRT ITensor" + ) + + casted_tensor = cast_trt_tensor(network, input, dtype, name, target, source_ir) + return casted_tensor diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py index c4cc744aa9..9ae7859fdc 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py @@ -2,6 +2,7 @@ import warnings from typing import Any, Callable, Optional, Union +import tensorrt as trt import torch from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR @@ -15,8 +16,6 @@ from torch_tensorrt.fx.types import TRTElementWiseOp, TRTNetwork, TRTTensor from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter -import tensorrt as trt - def get_python_op_from_trt_elementwise_op( trt_op: TRTElementWiseOp, @@ -132,9 +131,13 @@ def convert_binary_elementwise( trt_promoted_type = unified_dtype_converter(promoted_type, Frameworks.TRT) if trt_promoted_type != lhs_val.dtype: - lhs_val = cast_trt_tensor(network, lhs_val, trt_promoted_type, name) + lhs_val = cast_trt_tensor( + network, lhs_val, trt_promoted_type, name, target, source_ir + ) if trt_promoted_type != rhs_val.dtype: - rhs_val = cast_trt_tensor(network, rhs_val, trt_promoted_type, name) + rhs_val = cast_trt_tensor( + network, rhs_val, trt_promoted_type, name, target, source_ir + ) # Check the limitation in the doc string. 
if network.has_implicit_batch_dimension: diff --git a/py/torch_tensorrt/dynamo/conversion/impl/evaluators.py b/py/torch_tensorrt/dynamo/conversion/impl/evaluators.py new file mode 100644 index 0000000000..cb61fb6158 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/evaluators.py @@ -0,0 +1,40 @@ +import logging +import operator +from typing import Optional, Sequence + +from torch.fx.node import Target +from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor + +LOGGER: logging.Logger = logging.getLogger(__name__) + + +def getitem( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: Sequence[TRTTensor], + index: int, +) -> TRTTensor: + LOGGER.debug(f"Evaluating getitem on object with name: {name}") + + # Directly index the input sequence and return the value + return operator.getitem(input, index) + + +def clone( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, +) -> TRTTensor: + if not isinstance(input, TRTTensor): + raise RuntimeError( + f"clone received input {input} that is not a TensorRT ITensor" + ) + + LOGGER.debug(f"Evaluating clone on object with name: {name}") + + return input diff --git a/tests/py/dynamo/converters/harness.py b/tests/py/dynamo/converters/harness.py index 5634e37a30..f6ff25fb77 100644 --- a/tests/py/dynamo/converters/harness.py +++ b/tests/py/dynamo/converters/harness.py @@ -1,11 +1,17 @@ +import logging import time import unittest -import torch -import logging from typing import Callable, List, Optional, Set, Tuple -from torch.testing._internal.common_utils import TestCase +import torch import torch_tensorrt.fx.tracer.dispatch_tracer.aten_tracer as aten_tracer +from torch.fx.passes.infra.pass_base import PassResult +from torch.testing._internal.common_utils import TestCase +from torch_tensorrt import Input + +# Use interpreter, input spec, and test case from fx_ts_compat to test Dynamo Converter Registry +from torch_tensorrt.dynamo.conversion import TRTInterpreter +from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule from torch_tensorrt.fx.passes.lower_basic_pass_aten import ( compose_bmm, compose_chunk, @@ -18,15 +24,8 @@ replace_transpose_mm_op_with_linear, run_const_fold, ) -from torch.fx.passes.infra.pass_base import PassResult from torch_tensorrt.fx.passes.pass_utils import chain_passes -# Use interpreter, input spec, and test case from fx_ts_compat to test Dynamo Converter Registry -from torch_tensorrt.dynamo.conversion import TRTInterpreter -from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule -from torch_tensorrt import Input - - _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -217,6 +216,7 @@ def generate_graph( expected_ops: Set[Callable], unexpected_ops: Optional[Set[Callable]] = None, customized_passes: List[Callable] = None, + disable_passes: bool = False, ): # Torchdynamo+aot proxytensor tracer # Below are common passes @@ -234,6 +234,10 @@ def generate_graph( # Combine with customized passes specific to any model if customized_passes: passes_list.extend(customized_passes) + + if disable_passes: + passes_list = [] + fx_module, _ = aten_tracer.trace(mod, original_inputs) for passes in passes_list: pr: PassResult = passes(fx_module) @@ -261,9 +265,17 @@ def run_test( atol=1e-03, precision=torch.float, check_dtype=True, + disable_passes=False, ): mod.eval() - mod = self.generate_graph(mod, inputs, expected_ops, unexpected_ops, None) + mod = 
self.generate_graph( + mod, + inputs, + expected_ops, + unexpected_ops, + None, + disable_passes=disable_passes, + ) if apply_passes is not None: pass_tracer = chain_passes(*apply_passes) @@ -293,10 +305,18 @@ def run_test_with_dynamic_shape( unexpected_ops=None, rtol=1e-03, atol=1e-03, + disable_passes=False, ): mod.eval() inputs = [spec.example_tensor("opt_shape") for spec in input_specs] - mod = self.generate_graph(mod, inputs, expected_ops, unexpected_ops, None) + mod = self.generate_graph( + mod, + inputs, + expected_ops, + unexpected_ops, + None, + disable_passes=disable_passes, + ) interp = TRTInterpreter( mod, diff --git a/tests/py/dynamo/converters/test_casts.py b/tests/py/dynamo/converters/test_casts.py new file mode 100644 index 0000000000..804ecae6d5 --- /dev/null +++ b/tests/py/dynamo/converters/test_casts.py @@ -0,0 +1,59 @@ +import torch +import torch.nn as nn +from harness import DispatchTestCase +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.conversion.trt_interpreter import ( + UnsupportedOperatorException, +) + + +class TestToCopyConverter(DispatchTestCase): + def test_to_copy_half(self): + class ToCopyHalf(nn.Module): + def forward(self, x): + y = x.to(dtype=torch.half) + return y + + inputs = [torch.rand((1, 3, 10))] + self.run_test( + ToCopyHalf(), + inputs, + expected_ops={torch.ops.aten._to_copy.default}, + precision=torch.half, + disable_passes=True, + ) + + def test_to_copy_float(self): + class ToCopyFloat(nn.Module): + def forward(self, x): + y = x.to(dtype=torch.float) + return y + + inputs = [torch.rand((1, 3, 10)).half()] + self.run_test( + ToCopyFloat(), + inputs, + expected_ops={torch.ops.aten._to_copy.default}, + precision=torch.float, + disable_passes=True, + ) + + def test_to_copy_unsupported(self): + class ToCopy64Bit(nn.Module): + def forward(self, x): + y = x.to(dtype=torch.int64) + return y + + inputs = [torch.randn((1, 3, 10)).int()] + + with self.assertRaises(UnsupportedOperatorException): + self.run_test( + ToCopy64Bit(), + inputs, + expected_ops={torch.ops.aten._to_copy.default}, + disable_passes=True, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_evaluators.py b/tests/py/dynamo/converters/test_evaluators.py new file mode 100644 index 0000000000..cf42009495 --- /dev/null +++ b/tests/py/dynamo/converters/test_evaluators.py @@ -0,0 +1,67 @@ +import operator +import unittest + +import torch +import torch.nn as nn +from harness import DispatchTestCase +from torch.testing._internal.common_utils import run_tests + + +class TestCloneConverter(DispatchTestCase): + def test_clone_contiguous(self): + class Clone(nn.Module): + def forward(self, x): + y = torch.clone(x, memory_format=torch.contiguous_format) + return y + 1 + + inputs = [torch.randn((1, 3, 10))] + self.run_test( + Clone(), + inputs, + expected_ops={torch.ops.aten.clone.default}, + disable_passes=True, + ) + + def test_clone_regular(self): + class Clone(nn.Module): + def forward(self, x): + y = torch.clone(x) + return y + 1 + + inputs = [torch.randn((8, 2, 10))] + self.run_test( + Clone(), + inputs, + expected_ops={torch.ops.aten.clone.default}, + disable_passes=True, + ) + + +# TODO: Switch this test back to self.run_test once an implementation exists +# for a converter that returns a list, such as aten.split +@unittest.skip("Pending aten.split converter. 
Currently tested by E2E") +class TestGetItemConverter(DispatchTestCase): + def test_getitem(self): + class GetItem(nn.Module): + def forward(self, x): + lis = torch.split(x, 5) + b = operator.getitem(lis, 0) + c = operator.getitem(lis, 1) + d = b + c + return d + + inputs = [ + torch.randn((3, 3, 10)), + torch.randn((3, 3, 10)), + torch.randn((3, 3, 10)), + ] + self.run_test( + GetItem(), + inputs, + expected_ops={operator.getitem}, + disable_passes=True, + ) + + +if __name__ == "__main__": + run_tests() From 31c1cfd93320a801eddab437b3b85d8e4673753a Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Wed, 2 Aug 2023 15:02:54 -0700 Subject: [PATCH 2/4] fix: Add temporary workaround for precisions - torch compile precisions are currently not being reflected due to recent API changes. This update honors specified precisions --- py/torch_tensorrt/dynamo/utils.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index cec328e84f..398af78788 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -5,9 +5,11 @@ from typing import Any, Callable, Dict, Optional, Sequence import torch +import torch_tensorrt from torch_tensorrt._Device import Device from torch_tensorrt._Input import Input from torch_tensorrt.dynamo import CompilationSettings +from torch_tensorrt.dynamo._defaults import PRECISION from packaging import version @@ -161,6 +163,28 @@ def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings: if settings.debug: logger.setLevel(logging.DEBUG) + # TODO: Remove once Dynamo precisions refactoring is complete + if "enabled_precisions" in kwargs: + enabled_precisions = kwargs["enabled_precisions"] + + if ( + torch.float16 in enabled_precisions + or torch_tensorrt.dtype.half in enabled_precisions + ): + settings.precision = torch.float16 + elif ( + torch.float32 in enabled_precisions + or torch_tensorrt.dtype.float in enabled_precisions + ): + settings.precision = torch.float32 + elif len(enabled_precisions) == 0: + logger.info(f"No precision specified, defaulting to {PRECISION}") + settings.precision = PRECISION + else: + raise ValueError( + f"Precision {enabled_precisions} not supported in the Dynamo Path" + ) + # Parse input runtime specification settings.use_python_runtime = use_python_runtime_parser(settings.use_python_runtime) From 26d1051a44b389c59bd59b9f626f744ee04b2dfc Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Wed, 9 Aug 2023 12:04:50 -0700 Subject: [PATCH 3/4] fix: Increase block size to reduce compile time BERT --- tests/py/dynamo/converters/test_casts.py | 4 +--- tests/py/dynamo/models/test_models.py | 15 +++++---------- 2 files changed, 6 insertions(+), 13 deletions(-) diff --git a/tests/py/dynamo/converters/test_casts.py b/tests/py/dynamo/converters/test_casts.py index 804ecae6d5..4bb05ef463 100644 --- a/tests/py/dynamo/converters/test_casts.py +++ b/tests/py/dynamo/converters/test_casts.py @@ -2,9 +2,7 @@ import torch.nn as nn from harness import DispatchTestCase from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.conversion.trt_interpreter import ( - UnsupportedOperatorException, -) +from torch_tensorrt.dynamo.conversion import UnsupportedOperatorException class TestToCopyConverter(DispatchTestCase): diff --git a/tests/py/dynamo/models/test_models.py b/tests/py/dynamo/models/test_models.py index 0fdfcb3fd0..c8f730e2e6 100644 --- 
a/tests/py/dynamo/models/test_models.py +++ b/tests/py/dynamo/models/test_models.py @@ -1,18 +1,13 @@ -import torch -import timm -import pytest import unittest +import pytest +import timm +import torch import torch_tensorrt as torchtrt import torchvision.models as models - +from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity from transformers import BertModel -from torch_tensorrt.dynamo.utils import ( - COSINE_THRESHOLD, - cosine_similarity, -) - assertions = unittest.TestCase() @@ -143,7 +138,7 @@ def test_bert_base_uncased(ir): "ir": ir, "pass_through_build_failures": True, "optimization_level": 1, - "min_block_size": 10, + "min_block_size": 15, "ir": "torch_compile", } trt_mod = torchtrt.compile(model, **compile_spec) From 2e902db5e7046caf79c60908083c3f5469e7f27e Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Fri, 11 Aug 2023 16:48:00 -0700 Subject: [PATCH 4/4] fix: Add generic evaluator function --- .../dynamo/conversion/__init__.py | 1 + .../dynamo/conversion/aten_ops_converters.py | 21 +--------- .../dynamo/conversion/converter_utils.py | 12 +++--- .../dynamo/conversion/impl/__init__.py | 1 - .../dynamo/conversion/impl/cast.py | 20 ++++++++++ .../dynamo/conversion/impl/evaluators.py | 40 ------------------- .../dynamo/conversion/op_evaluators.py | 32 +++++++++++++++ tests/py/dynamo/converters/test_casts.py | 30 ++++++++++++++ tests/py/dynamo/converters/test_evaluators.py | 30 -------------- tests/py/dynamo/models/test_models.py | 2 - 10 files changed, 89 insertions(+), 100 deletions(-) delete mode 100644 py/torch_tensorrt/dynamo/conversion/impl/evaluators.py create mode 100644 py/torch_tensorrt/dynamo/conversion/op_evaluators.py diff --git a/py/torch_tensorrt/dynamo/conversion/__init__.py b/py/torch_tensorrt/dynamo/conversion/__init__.py index 4536ff0e7b..9cbfff950e 100644 --- a/py/torch_tensorrt/dynamo/conversion/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/__init__.py @@ -1,4 +1,5 @@ from ._TRTInterpreter import * # noqa: F403 from .aten_ops_converters import * # noqa: F403 from .conversion import * # noqa: F403 +from .op_evaluators import * # noqa: F403 from .truncate_long_and_double import repair_long_or_double_inputs diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 8d1d9b7ecf..75a7782354 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -1,5 +1,4 @@ import logging -import operator from typing import Any, Dict, Optional, Sequence, Tuple, Union import tensorrt as trt @@ -406,24 +405,6 @@ def aten_ops_to_copy_dtype( ) -@dynamo_tensorrt_converter(operator.getitem) -def operator_getitem( - network: TRTNetwork, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - name: str, -) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.evaluators.getitem( - network, - target, - SourceIR.ATEN, - name, - args[0], - args[1], - ) - - @dynamo_tensorrt_converter(torch.ops.aten.clone.default) def aten_ops_clone( network: TRTNetwork, @@ -432,7 +413,7 @@ def aten_ops_clone( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.evaluators.clone( + return impl.cast.clone( network, target, SourceIR.ATEN, diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index bbb57b9da7..44bc8b9445 100644 --- 
a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -4,7 +4,7 @@ import tensorrt as trt import torch -from torch.fx.node import Target, _get_qualified_name +from torch.fx.node import Target from torch_tensorrt.fx.converters.converter_utils import ( Frameworks, unified_dtype_converter, @@ -12,6 +12,7 @@ from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor from .._SourceIR import SourceIR +from .converter_registry import ConverterRegistry _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -94,15 +95,12 @@ def cast_trt_tensor( if input_val.dtype != trt_dtype: source_ir = source_ir if source_ir is not None else SourceIR.UNKNOWN - target_name = ( - f"{source_ir}_ops{'.' + target if target else ''}" - if (isinstance(target, str)) - else f"{source_ir}_ops.{_get_qualified_name(target)}" - ) + target_str = ConverterRegistry.qualified_name_or_str(target) + target_name = f"{source_ir}_ops{('.' + target_str) if target_str else ''}" identity_layer = network.add_identity(input_val) identity_layer.set_output_type(0, trt_dtype) - identity_layer.name = f"Cast ITensor {input_val.name} from {input_val.dtype} to {trt_dtype} -{name}-[{target_name}]-[{name}]" + identity_layer.name = f"Cast ITensor {input_val.name} from {input_val.dtype} to {trt_dtype} - [{target_name}]-[{name}]" return identity_layer.get_output(0) else: return input_val diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py index 611dc630fa..8f7ab1badc 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py @@ -6,7 +6,6 @@ condition, elementwise, embedding, - evaluators, matmul, normalization, permutation, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/cast.py b/py/torch_tensorrt/dynamo/conversion/impl/cast.py index 68899de766..0c55731169 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/cast.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/cast.py @@ -1,3 +1,4 @@ +import logging from typing import Optional from torch.fx.node import Target @@ -5,6 +6,8 @@ from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor +LOGGER: logging.Logger = logging.getLogger(__name__) + def to_copy( network: TRTNetwork, @@ -21,3 +24,20 @@ def to_copy( casted_tensor = cast_trt_tensor(network, input, dtype, name, target, source_ir) return casted_tensor + + +def clone( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, +) -> TRTTensor: + if not isinstance(input, TRTTensor): + raise RuntimeError( + f"clone received input {input} that is not a TensorRT ITensor" + ) + + LOGGER.debug(f"Evaluating clone on object with name: {name}") + + return input diff --git a/py/torch_tensorrt/dynamo/conversion/impl/evaluators.py b/py/torch_tensorrt/dynamo/conversion/impl/evaluators.py deleted file mode 100644 index cb61fb6158..0000000000 --- a/py/torch_tensorrt/dynamo/conversion/impl/evaluators.py +++ /dev/null @@ -1,40 +0,0 @@ -import logging -import operator -from typing import Optional, Sequence - -from torch.fx.node import Target -from torch_tensorrt.dynamo._SourceIR import SourceIR -from torch_tensorrt.fx.types import TRTNetwork, TRTTensor - -LOGGER: logging.Logger = logging.getLogger(__name__) - - -def getitem( - network: TRTNetwork, - target: Target, - source_ir: Optional[SourceIR], - name: str, - 
input: Sequence[TRTTensor], - index: int, -) -> TRTTensor: - LOGGER.debug(f"Evaluating getitem on object with name: {name}") - - # Directly index the input sequence and return the value - return operator.getitem(input, index) - - -def clone( - network: TRTNetwork, - target: Target, - source_ir: Optional[SourceIR], - name: str, - input: TRTTensor, -) -> TRTTensor: - if not isinstance(input, TRTTensor): - raise RuntimeError( - f"clone received input {input} that is not a TensorRT ITensor" - ) - - LOGGER.debug(f"Evaluating clone on object with name: {name}") - - return input diff --git a/py/torch_tensorrt/dynamo/conversion/op_evaluators.py b/py/torch_tensorrt/dynamo/conversion/op_evaluators.py new file mode 100644 index 0000000000..a546e34305 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/op_evaluators.py @@ -0,0 +1,32 @@ +import logging +import operator +from typing import Dict, Sequence, Tuple, Union + +from torch.fx.node import Argument, Node, Target +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor + +from .converter_registry import ConverterRegistry, dynamo_tensorrt_converter + +_LOGGER: logging.Logger = logging.getLogger(__name__) + + +def getitem_validator(getitem_node: Node) -> bool: + from torch_tensorrt.dynamo.conversion.converter_registry import DYNAMO_CONVERTERS + + # Getitem nodes can only be converted if their parent node also can + return getitem_node.args[0] in DYNAMO_CONVERTERS + + +# TODO: Subsequent evaluators should be registered here with their own validators +@dynamo_tensorrt_converter(operator.getitem, capability_validator=getitem_validator) +def generic_evaluator( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + _LOGGER.debug( + f"Evaluating {ConverterRegistry.qualified_name_or_str(target)} on object with name: {name}" + ) + return target(*args) diff --git a/tests/py/dynamo/converters/test_casts.py b/tests/py/dynamo/converters/test_casts.py index 4bb05ef463..3a4fd65610 100644 --- a/tests/py/dynamo/converters/test_casts.py +++ b/tests/py/dynamo/converters/test_casts.py @@ -5,6 +5,36 @@ from torch_tensorrt.dynamo.conversion import UnsupportedOperatorException +class TestCloneConverter(DispatchTestCase): + def test_clone_contiguous(self): + class Clone(nn.Module): + def forward(self, x): + y = torch.clone(x, memory_format=torch.contiguous_format) + return y + 1 + + inputs = [torch.randn((1, 3, 10))] + self.run_test( + Clone(), + inputs, + expected_ops={torch.ops.aten.clone.default}, + disable_passes=True, + ) + + def test_clone_regular(self): + class Clone(nn.Module): + def forward(self, x): + y = torch.clone(x) + return y + 1 + + inputs = [torch.randn((8, 2, 10))] + self.run_test( + Clone(), + inputs, + expected_ops={torch.ops.aten.clone.default}, + disable_passes=True, + ) + + class TestToCopyConverter(DispatchTestCase): def test_to_copy_half(self): class ToCopyHalf(nn.Module): diff --git a/tests/py/dynamo/converters/test_evaluators.py b/tests/py/dynamo/converters/test_evaluators.py index cf42009495..64dd303727 100644 --- a/tests/py/dynamo/converters/test_evaluators.py +++ b/tests/py/dynamo/converters/test_evaluators.py @@ -7,36 +7,6 @@ from torch.testing._internal.common_utils import run_tests -class TestCloneConverter(DispatchTestCase): - def test_clone_contiguous(self): - class Clone(nn.Module): - def forward(self, x): - y = torch.clone(x, memory_format=torch.contiguous_format) - return y + 1 - - inputs = [torch.randn((1, 3, 10))] - 
self.run_test( - Clone(), - inputs, - expected_ops={torch.ops.aten.clone.default}, - disable_passes=True, - ) - - def test_clone_regular(self): - class Clone(nn.Module): - def forward(self, x): - y = torch.clone(x) - return y + 1 - - inputs = [torch.randn((8, 2, 10))] - self.run_test( - Clone(), - inputs, - expected_ops={torch.ops.aten.clone.default}, - disable_passes=True, - ) - - # TODO: Switch this test back to self.run_test once an implementation exists # for a converter that returns a list, such as aten.split @unittest.skip("Pending aten.split converter. Currently tested by E2E") diff --git a/tests/py/dynamo/models/test_models.py b/tests/py/dynamo/models/test_models.py index c8f730e2e6..50d7fcbbd9 100644 --- a/tests/py/dynamo/models/test_models.py +++ b/tests/py/dynamo/models/test_models.py @@ -27,7 +27,6 @@ def test_resnet18(ir): "ir": ir, "pass_through_build_failures": True, "optimization_level": 1, - "min_block_size": 10, "ir": "torch_compile", } @@ -176,7 +175,6 @@ def test_resnet18_half(ir): "ir": ir, "pass_through_build_failures": True, "optimization_level": 1, - "min_block_size": 10, "ir": "torch_compile", }
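
For readers unfamiliar with the capability_validator mechanism these patches rely on, the following standalone sketch rebuilds a simplified version of to_copy_dtype_validator from patch 1 and exercises it on hand-constructed FX nodes. It assumes only a local PyTorch install; the make_to_copy_node helper, the graph construction, and the __main__ driver are illustrative and are not part of the patch series.

# Standalone illustration of the capability_validator pattern used by the
# dynamo_tensorrt_converter registrations in this series. The real validator
# lives in aten_ops_converters.py and receives nodes produced by the
# Dynamo/aten tracer; here we build comparable nodes by hand.
import torch
from torch.fx import Graph, Node


def to_copy_dtype_validator(to_copy_node: Node) -> bool:
    # Simplified mirror of the patch's check: only dtypes the TensorRT cast
    # path supports are accepted; anything else is rejected so the node is
    # not claimed by the converter.
    allowed_casts = {torch.float, torch.int32, torch.bool, torch.int8, torch.float16}
    return to_copy_node.kwargs.get("dtype") in allowed_casts


def make_to_copy_node(dtype: torch.dtype) -> Node:
    # Hand-build an aten._to_copy call_function node carrying a dtype kwarg,
    # mimicking what the tracer would produce for x.to(dtype=...).
    graph = Graph()
    x = graph.placeholder("x")
    return graph.call_function(
        torch.ops.aten._to_copy.default, args=(x,), kwargs={"dtype": dtype}
    )


if __name__ == "__main__":
    # Accepted: float16 is in the allowed set, so the converter would run.
    print(to_copy_dtype_validator(make_to_copy_node(torch.float16)))  # True
    # Rejected: int64 is not in the allowed set, so the node is not claimed.
    print(to_copy_dtype_validator(make_to_copy_node(torch.int64)))  # False

When a validator rejects a node, the converter is simply not registered as applicable for it; the node is then either left to PyTorch by the partitioner or surfaces as an UnsupportedOperatorException in the interpreter, which is the behavior exercised by test_to_copy_unsupported in test_casts.py above.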