From 7c011677dbadc6d704dc8fed73c378cca166513c Mon Sep 17 00:00:00 2001 From: Evan Li Date: Mon, 30 Oct 2023 17:58:11 -0700 Subject: [PATCH 1/6] feat: support ne, ge, and le converters --- .../dynamo/conversion/aten_ops_converters.py | 75 ++++++++++++++++--- .../dynamo/conversion/impl/elementwise/ops.py | 54 +++++++++++++ .../{test_equal_aten.py => test_eq_aten.py} | 41 +++++----- tests/py/dynamo/conversion/test_ge_aten.py | 69 +++++++++++++++++ .../{test_greater_aten.py => test_gt_aten.py} | 33 ++++---- tests/py/dynamo/conversion/test_le_aten.py | 69 +++++++++++++++++ .../{test_less_aten.py => test_lt_aten.py} | 21 +++--- tests/py/dynamo/conversion/test_ne_aten.py | 69 +++++++++++++++++ 8 files changed, 374 insertions(+), 57 deletions(-) rename tests/py/dynamo/conversion/{test_equal_aten.py => test_eq_aten.py} (58%) create mode 100644 tests/py/dynamo/conversion/test_ge_aten.py rename tests/py/dynamo/conversion/{test_greater_aten.py => test_gt_aten.py} (65%) create mode 100644 tests/py/dynamo/conversion/test_le_aten.py rename tests/py/dynamo/conversion/{test_less_aten.py => test_lt_aten.py} (77%) create mode 100644 tests/py/dynamo/conversion/test_ne_aten.py diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index b05713c360..1044fc78b2 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -1701,9 +1701,9 @@ def aten_ops_logical_xor( ) -@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor) -@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar) -def aten_ops_equal( +@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar) # type: ignore[misc] +def aten_ops_eq( ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], @@ -1720,9 +1720,28 @@ def aten_ops_equal( ) -@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor) -@dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar) -def aten_ops_greater( +@dynamo_tensorrt_converter(torch.ops.aten.ne.Tensor) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.ne.Scalar) # type: ignore[misc] +def aten_ops_ne( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.elementwise.ne( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar) # type: ignore[misc] +def aten_ops_gt( ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], @@ -1739,9 +1758,28 @@ def aten_ops_greater( ) -@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor) -@dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar) -def aten_ops_less( +@dynamo_tensorrt_converter(torch.ops.aten.ge.Tensor) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.ge.Scalar) # type: ignore[misc] +def aten_ops_ge( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.elementwise.ge( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar) # type: ignore[misc] +def aten_ops_lt( ctx: ConversionContext, target: Target, 
args: Tuple[Argument, ...], @@ -1758,6 +1796,25 @@ def aten_ops_less( ) +@dynamo_tensorrt_converter(torch.ops.aten.le.Tensor) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.le.Scalar) # type: ignore[misc] +def aten_ops_le( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.elementwise.le( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) + + def conv_param_validator(conv_node: Node) -> bool: return conv_node.args[7] in ([0], [0, 0], [0, 0, 0]) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py index 9f1143959f..8a0a3f43b3 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py @@ -2,6 +2,7 @@ import numpy as np import tensorrt as trt +import torch_tensorrt.dynamo.conversion.impl as impl from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext @@ -441,6 +442,23 @@ def eq( ) +def ne( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + lhs_val: Union[TRTTensor, int, float], + rhs_val: Union[TRTTensor, int, float], +) -> TRTTensor: + return impl.unary.logical_not( + ctx, + target, + source_ir, + f"{name}_logical_not", + eq(ctx, target, source_ir, f"{name}_eq", lhs_val, rhs_val), + ) + + def gt( ctx: ConversionContext, target: Target, @@ -460,6 +478,24 @@ def gt( ) +def ge( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + lhs_val: Union[TRTTensor, int, float], + rhs_val: Union[TRTTensor, int, float], +) -> TRTTensor: + return logical_or( + ctx, + target, + source_ir, + name, + gt(ctx, target, source_ir, f"{name}_gt", lhs_val, rhs_val), + eq(ctx, target, source_ir, f"{name}_eq", lhs_val, rhs_val), + ) + + def lt( ctx: ConversionContext, target: Target, @@ -477,3 +513,21 @@ def lt( lhs_val, rhs_val, ) + + +def le( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + lhs_val: Union[TRTTensor, int, float], + rhs_val: Union[TRTTensor, int, float], +) -> TRTTensor: + return logical_or( + ctx, + target, + source_ir, + name, + lt(ctx, target, source_ir, f"{name}_lt", lhs_val, rhs_val), + eq(ctx, target, source_ir, f"{name}_eq", lhs_val, rhs_val), + ) diff --git a/tests/py/dynamo/conversion/test_equal_aten.py b/tests/py/dynamo/conversion/test_eq_aten.py similarity index 58% rename from tests/py/dynamo/conversion/test_equal_aten.py rename to tests/py/dynamo/conversion/test_eq_aten.py index 7761b31410..17a372182c 100644 --- a/tests/py/dynamo/conversion/test_equal_aten.py +++ b/tests/py/dynamo/conversion/test_eq_aten.py @@ -2,7 +2,6 @@ import torch.nn as nn from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt import Input from .harness import DispatchTestCase @@ -10,56 +9,58 @@ class TestEqualConverter(DispatchTestCase): @parameterized.expand( [ - ("2d", (2, 1)), - ("3d", (2, 1, 2)), + ("2d", (5, 3)), + ("3d", (5, 3, 2)), ] ) - def test_equal_tensor(self, _, shape): - class equal(nn.Module): + def test_eq_tensor(self, _, shape): + class eq(nn.Module): def forward(self, lhs_val, rhs_val): return torch.ops.aten.eq.Tensor(lhs_val, rhs_val) - inputs = [torch.randn(shape), torch.randn(shape)] + 
inputs = [ + torch.randint(0, 3, shape, dtype=torch.int32), + torch.randint(0, 3, shape, dtype=torch.int32), + ] self.run_test( - equal(), + eq(), inputs, output_dtypes=[torch.bool], ) @parameterized.expand( [ - ("2d", (2, 1), 1), - ("3d", (2, 1, 2), 2.0), + ("2d", (5, 3), 1), + ("3d", (5, 3, 2), 2.0), ] ) - def test_equal_tensor_scalar(self, _, shape, scalar): - class equal(nn.Module): + def test_eq_tensor_scalar(self, _, shape, scalar): + class eq(nn.Module): def forward(self, lhs_val): return torch.ops.aten.eq.Tensor(lhs_val, torch.tensor(scalar)) - inputs = [torch.randn(shape)] + inputs = [torch.randint(0, 3, shape, dtype=torch.int32)] self.run_test( - equal(), + eq(), inputs, output_dtypes=[torch.bool], ) @parameterized.expand( [ - ("2d", (2, 1), 1), - ("3d", (2, 1, 2), 2.0), + ("2d", (5, 3), 1), + ("3d", (5, 3, 2), 2.0), ] ) - def test_equal_scalar(self, _, shape, scalar): - class equal(nn.Module): + def test_eq_scalar(self, _, shape, scalar): + class eq(nn.Module): def forward(self, lhs_val): return torch.ops.aten.eq.Scalar(lhs_val, scalar) - inputs = [torch.randn(shape)] + inputs = [torch.randint(0, 3, shape, dtype=torch.int32)] self.run_test( - equal(), + eq(), inputs, - # expected_ops={torch.ops.aten.eq.Scalar}, output_dtypes=[torch.bool], ) diff --git a/tests/py/dynamo/conversion/test_ge_aten.py b/tests/py/dynamo/conversion/test_ge_aten.py new file mode 100644 index 0000000000..6b1ee6d440 --- /dev/null +++ b/tests/py/dynamo/conversion/test_ge_aten.py @@ -0,0 +1,69 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestGeConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d", (5, 3)), + ("3d", (5, 3, 2)), + ] + ) + def test_ge_tensor(self, _, shape): + class ge(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.ops.aten.ge.Tensor(lhs_val, rhs_val) + + inputs = [ + torch.randint(0, 3, shape, dtype=torch.int32), + torch.randint(0, 3, shape, dtype=torch.int32), + ] + self.run_test( + ge(), + inputs, + output_dtypes=[torch.bool], + ) + + @parameterized.expand( + [ + ("2d", (5, 3), 1), + ("3d", (5, 3, 2), 2.0), + ] + ) + def test_ge_tensor_scalar(self, _, shape, scalar): + class ge(nn.Module): + def forward(self, lhs_val): + return torch.ops.aten.ge.Tensor(lhs_val, torch.tensor(scalar)) + + inputs = [torch.randint(0, 3, shape, dtype=torch.int32)] + self.run_test( + ge(), + inputs, + output_dtypes=[torch.bool], + ) + + @parameterized.expand( + [ + ("2d", (5, 3), 1), + ("3d", (5, 3, 2), 2.0), + ] + ) + def test_ge_scalar(self, _, shape, scalar): + class ge(nn.Module): + def forward(self, lhs_val): + return torch.ops.aten.ge.Scalar(lhs_val, scalar) + + inputs = [torch.randint(0, 3, shape, dtype=torch.int32)] + self.run_test( + ge(), + inputs, + output_dtypes=[torch.bool], + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_greater_aten.py b/tests/py/dynamo/conversion/test_gt_aten.py similarity index 65% rename from tests/py/dynamo/conversion/test_greater_aten.py rename to tests/py/dynamo/conversion/test_gt_aten.py index 230fff23d8..8d9ae24f80 100644 --- a/tests/py/dynamo/conversion/test_greater_aten.py +++ b/tests/py/dynamo/conversion/test_gt_aten.py @@ -2,62 +2,61 @@ import torch import torch.nn as nn from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt import Input from .harness import DispatchTestCase -class
TestGreaterConverter(DispatchTestCase): +class TestGtConverter(DispatchTestCase): @parameterized.expand( [ - ("2d", (2, 1)), - ("3d", (2, 1, 2)), + ("2d", (5, 3)), + ("3d", (5, 3, 2)), ] ) - def test_greater_tensor(self, _, shape): - class greater(nn.Module): + def test_gt_tensor(self, _, shape): + class gt(nn.Module): def forward(self, lhs_val, rhs_val): return torch.ops.aten.gt.Tensor(lhs_val, rhs_val) inputs = [torch.randn(shape), torch.randn(shape)] self.run_test( - greater(), + gt(), inputs, output_dtypes=[torch.bool], ) @parameterized.expand( [ - ("2d", (2, 1), 1), - ("3d", (2, 1, 2), 2.0), + ("2d", (5, 3), 1), + ("3d", (5, 3, 2), 2.0), ] ) - def test_greater_tensor_scalar(self, _, shape, scalar): - class greater(nn.Module): + def test_gt_tensor_scalar(self, _, shape, scalar): + class gt(nn.Module): def forward(self, lhs_val): return torch.ops.aten.gt.Tensor(lhs_val, torch.tensor(scalar)) inputs = [torch.randn(shape)] self.run_test( - greater(), + gt(), inputs, output_dtypes=[torch.bool], ) @parameterized.expand( [ - ("2d", (2, 1), 1), - ("3d", (2, 1, 2), 2.0), + ("2d", (5, 3), 1), + ("3d", (5, 3, 2), 2.0), ] ) - def test_greater_scalar(self, _, shape, scalar): - class greater(nn.Module): + def test_gt_scalar(self, _, shape, scalar): + class gt(nn.Module): def forward(self, lhs_val): return torch.ops.aten.gt.Scalar(lhs_val, scalar) inputs = [torch.randn(shape)] self.run_test( - greater(), + gt(), inputs, output_dtypes=[torch.bool], ) diff --git a/tests/py/dynamo/conversion/test_le_aten.py b/tests/py/dynamo/conversion/test_le_aten.py new file mode 100644 index 0000000000..373384c6f9 --- /dev/null +++ b/tests/py/dynamo/conversion/test_le_aten.py @@ -0,0 +1,69 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestLeConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d", (5, 3)), + ("3d", (5, 3, 2)), + ] + ) + def test_le_tensor(self, _, shape): + class le(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.ops.aten.le.Tensor(lhs_val, rhs_val) + + inputs = [ + torch.randint(0, 3, shape, dtype=torch.int32), + torch.randint(0, 3, shape, dtype=torch.int32), + ] + self.run_test( + le(), + inputs, + output_dtypes=[torch.bool], + ) + + @parameterized.expand( + [ + ("2d", (5, 3), 1), + ("3d", (5, 3, 2), 2.0), + ] + ) + def test_le_tensor_scalar(self, _, shape, scalar): + class le(nn.Module): + def forward(self, lhs_val): + return torch.ops.aten.le.Tensor(lhs_val, torch.tensor(scalar)) + + inputs = [torch.randint(0, 3, shape, dtype=torch.int32)] + self.run_test( + le(), + inputs, + output_dtypes=[torch.bool], + ) + + @parameterized.expand( + [ + ("2d", (5, 3), 1), + ("3d", (5, 3, 2), 2.0), + ] + ) + def test_le_scalar(self, _, shape, scalar): + class le(nn.Module): + def forward(self, lhs_val): + return torch.ops.aten.le.Scalar(lhs_val, scalar) + + inputs = [torch.randint(0, 3, shape, dtype=torch.int32)] + self.run_test( + le(), + inputs, + output_dtypes=[torch.bool], + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_less_aten.py b/tests/py/dynamo/conversion/test_lt_aten.py similarity index 77% rename from tests/py/dynamo/conversion/test_less_aten.py rename to tests/py/dynamo/conversion/test_lt_aten.py index 28ca2cb514..89cb7f42c5 100644 --- a/tests/py/dynamo/conversion/test_less_aten.py +++ b/tests/py/dynamo/conversion/test_lt_aten.py @@ -2,26 +2,25 @@ import torch import torch.nn as nn from
parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt import Input from .harness import DispatchTestCase -class TestLessConverter(DispatchTestCase): +class TestLtConverter(DispatchTestCase): @parameterized.expand( [ ("2d", (2, 1)), ("3d", (2, 1, 2)), ] ) - def test_less_tensor(self, _, shape): - class less(nn.Module): + def test_lt_tensor(self, _, shape): + class lt(nn.Module): def forward(self, lhs_val, rhs_val): return torch.ops.aten.lt.Tensor(lhs_val, rhs_val) inputs = [torch.randn(shape), torch.randn(shape)] self.run_test( - less(), + lt(), inputs, output_dtypes=[torch.bool], ) @@ -32,14 +31,14 @@ def forward(self, lhs_val, rhs_val): ("3d", (2, 1, 2), 2.0), ] ) - def test_less_tensor_scalar(self, _, shape, scalar): - class less(nn.Module): + def test_lt_tensor_scalar(self, _, shape, scalar): + class lt(nn.Module): def forward(self, lhs_val): return torch.ops.aten.lt.Tensor(lhs_val, torch.tensor(scalar)) inputs = [torch.randn(shape)] self.run_test( - less(), + lt(), inputs, output_dtypes=[torch.bool], ) @@ -50,14 +49,14 @@ def forward(self, lhs_val): ("3d", (2, 1, 2), 2.0), ] ) - def test_less_scalar(self, _, shape, scalar): - class less(nn.Module): + def test_lt_scalar(self, _, shape, scalar): + class lt(nn.Module): def forward(self, lhs_val): return torch.ops.aten.lt.Scalar(lhs_val, scalar) inputs = [torch.randn(shape)] self.run_test( - less(), + lt(), inputs, output_dtypes=[torch.bool], ) diff --git a/tests/py/dynamo/conversion/test_ne_aten.py b/tests/py/dynamo/conversion/test_ne_aten.py new file mode 100644 index 0000000000..2450ac0945 --- /dev/null +++ b/tests/py/dynamo/conversion/test_ne_aten.py @@ -0,0 +1,69 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestNotEqualConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d", (5, 3)), + ("3d", (5, 3, 2)), + ] + ) + def test_ne_tensor(self, _, shape): + class ne(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.ops.aten.ne.Tensor(lhs_val, rhs_val) + + inputs = [ + torch.randint(0, 3, shape, dtype=torch.int32), + torch.randint(0, 3, shape, dtype=torch.int32), + ] + self.run_test( + ne(), + inputs, + output_dtypes=[torch.bool], + ) + + @parameterized.expand( + [ + ("2d", (5, 3), 1), + ("3d", (5, 3, 2), 2.0), + ] + ) + def test_ne_tensor_scalar(self, _, shape, scalar): + class ne(nn.Module): + def forward(self, lhs_val): + return torch.ops.aten.ne.Tensor(lhs_val, torch.tensor(scalar)) + + inputs = [torch.randint(0, 3, shape, dtype=torch.int32)] + self.run_test( + ne(), + inputs, + output_dtypes=[torch.bool], + ) + + @parameterized.expand( + [ + ("2d", (5, 3), 1), + ("3d", (5, 3, 2), 2.0), + ] + ) + def test_ne_scalar(self, _, shape, scalar): + class ne(nn.Module): + def forward(self, lhs_val): + return torch.ops.aten.ne.Scalar(lhs_val, scalar) + + inputs = [torch.randint(0, 3, shape, dtype=torch.int32)] + self.run_test( + ne(), + inputs, + output_dtypes=[torch.bool], + ) + + +if __name__ == "__main__": + run_tests() From bae5398792b7b0138fc18fd99683265240c5f1f1 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Tue, 31 Oct 2023 13:05:17 -0700 Subject: [PATCH 2/6] update type enforcement --- .../dynamo/conversion/aten_ops_converters.py | 30 +++++++++++++++++++ .../dynamo/conversion/impl/elementwise/ops.py | 26 ++++++++-------- 2 files changed, 43 insertions(+), 13 deletions(-) diff --git 
a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 1044fc78b2..0ee666b6ef 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -1703,6 +1703,11 @@ def aten_ops_logical_xor( @dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor) # type: ignore[misc] @dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar) # type: ignore[misc] +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) # type: ignore[misc] def aten_ops_eq( ctx: ConversionContext, target: Target, @@ -1722,6 +1727,11 @@ def aten_ops_eq( @dynamo_tensorrt_converter(torch.ops.aten.ne.Tensor) # type: ignore[misc] @dynamo_tensorrt_converter(torch.ops.aten.ne.Scalar) # type: ignore[misc] +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) # type: ignore[misc] def aten_ops_ne( ctx: ConversionContext, target: Target, @@ -1741,6 +1751,11 @@ def aten_ops_ne( @dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor) # type: ignore[misc] @dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar) # type: ignore[misc] +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) # type: ignore[misc] def aten_ops_gt( ctx: ConversionContext, target: Target, @@ -1760,6 +1775,11 @@ def aten_ops_gt( @dynamo_tensorrt_converter(torch.ops.aten.ge.Tensor) # type: ignore[misc] @dynamo_tensorrt_converter(torch.ops.aten.ge.Scalar) # type: ignore[misc] +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) # type: ignore[misc] def aten_ops_ge( ctx: ConversionContext, target: Target, @@ -1779,6 +1799,11 @@ def aten_ops_ge( @dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor) # type: ignore[misc] @dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar) # type: ignore[misc] +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) # type: ignore[misc] def aten_ops_lt( ctx: ConversionContext, target: Target, @@ -1798,6 +1823,11 @@ def aten_ops_lt( @dynamo_tensorrt_converter(torch.ops.aten.le.Tensor) # type: ignore[misc] @dynamo_tensorrt_converter(torch.ops.aten.le.Scalar) # type: ignore[misc] +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) # type: ignore[misc] def aten_ops_le( ctx: ConversionContext, target: Target, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py index 8a0a3f43b3..75c190779d 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Optional, Sequence, Union import numpy as np import tensorrt as trt @@ -428,8 +428,8 @@ def eq( target: Target, source_ir: Optional[SourceIR], name: str, - lhs_val: Union[TRTTensor, int, float], - rhs_val: Union[TRTTensor, int, float], + lhs_val: TRTTensor, + rhs_val: Union[TRTTensor, int, float, bool, Sequence[Union[int, float, bool]]], ) -> TRTTensor: return convert_binary_elementwise( ctx, @@ -447,8 +447,8 @@ def ne( target: Target, source_ir: Optional[SourceIR], name: str, - lhs_val: Union[TRTTensor, int, float], - rhs_val: Union[TRTTensor, int, float], + lhs_val: TRTTensor, + rhs_val: Union[TRTTensor, int, float, bool, Sequence[Union[int, float, bool]]], ) -> TRTTensor: return impl.unary.logical_not( ctx, @@ -464,8 +464,8 @@ def gt( target: Target, source_ir: Optional[SourceIR], name: str, - lhs_val: Union[TRTTensor, int, float], - rhs_val: Union[TRTTensor, int, float], + lhs_val: TRTTensor, + rhs_val: Union[TRTTensor, int, float, 
Sequence[Union[int, float]]], ) -> TRTTensor: return convert_binary_elementwise( ctx, @@ -483,8 +483,8 @@ def ge( target: Target, source_ir: Optional[SourceIR], name: str, - lhs_val: Union[TRTTensor, int, float], - rhs_val: Union[TRTTensor, int, float], + lhs_val: TRTTensor, + rhs_val: Union[TRTTensor, int, float, Sequence[Union[int, float]]], ) -> TRTTensor: return logical_or( ctx, @@ -501,8 +501,8 @@ def lt( target: Target, source_ir: Optional[SourceIR], name: str, - lhs_val: Union[TRTTensor, int, float], - rhs_val: Union[TRTTensor, int, float], + lhs_val: TRTTensor, + rhs_val: Union[TRTTensor, int, float, Sequence[Union[int, float]]], ) -> TRTTensor: return convert_binary_elementwise( ctx, @@ -520,8 +520,8 @@ def le( target: Target, source_ir: Optional[SourceIR], name: str, - lhs_val: Union[TRTTensor, int, float], - rhs_val: Union[TRTTensor, int, float], + lhs_val: TRTTensor, + rhs_val: Union[TRTTensor, int, float, Sequence[Union[int, float]]], ) -> TRTTensor: return logical_or( ctx, From 58278c6d9ccd819378edcead35fee208aab883f3 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Wed, 1 Nov 2023 15:42:57 -0700 Subject: [PATCH 3/6] feat: support more elementwise and unary converters --- .../dynamo/conversion/aten_ops_converters.py | 114 ++++++++++++++++++ .../dynamo/conversion/impl/elementwise/ops.py | 33 +++++ .../dynamo/conversion/impl/unary/ops.py | 13 ++ .../conversion/test_bitwise_and_aten.py | 33 +++++ .../conversion/test_bitwise_not_aten.py | 33 +++++ .../dynamo/conversion/test_bitwise_or_aten.py | 33 +++++ .../conversion/test_bitwise_xor_aten.py | 33 +++++ 7 files changed, 292 insertions(+) create mode 100644 tests/py/dynamo/conversion/test_bitwise_and_aten.py create mode 100644 tests/py/dynamo/conversion/test_bitwise_not_aten.py create mode 100644 tests/py/dynamo/conversion/test_bitwise_or_aten.py create mode 100644 tests/py/dynamo/conversion/test_bitwise_xor_aten.py diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 0ee666b6ef..24b9b9d2a0 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -1701,6 +1701,120 @@ def aten_ops_logical_xor( ) +def bitwise_type_validator(node: Node) -> bool: + targets = [ + torch.ops.aten.bitwise_and.Tensor, + torch.ops.aten.bitwise_or.Tensor, + torch.ops.aten.bitwise_xor.Tensor, + ] + if node.target not in targets: + return False + + lhs_val = node.args[0] + rhs_val = node.args[1] + lhs_meta = lhs_val.meta.get("tensor_meta") + rhs_meta = rhs_val.meta.get("tensor_meta") + + if lhs_meta is None or rhs_meta is None: + return False + + supported_type = [torch.bool, bool] + return lhs_meta.dtype in supported_type and rhs_meta.dtype in supported_type + + +@dynamo_tensorrt_converter(torch.ops.aten.bitwise_and.Tensor, capability_validator=bitwise_type_validator) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.bitwise_and.Scalar) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.bitwise_and.Scalar_Tensor) # type: ignore[misc] +def aten_ops_bitwise_and( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.elementwise.bitwise_and( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.bitwise_or.Tensor, capability_validator=bitwise_type_validator) # type: ignore[misc] 
+@dynamo_tensorrt_converter(torch.ops.aten.bitwise_or.Scalar) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.bitwise_or.Scalar_Tensor) # type: ignore[misc] +def aten_ops_bitwise_or( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.elementwise.bitwise_or( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.bitwise_xor.Tensor, capability_validator=bitwise_type_validator) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.bitwise_xor.Scalar) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.bitwise_xor.Scalar_Tensor) # type: ignore[misc] +def aten_ops_bitwise_xor( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.elementwise.bitwise_xor( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) + + +def bitwise_not_type_validator(node: Node) -> bool: + val = node.args[0] + val_meta = val.meta.get("tensor_meta") + + if val_meta is None: + return False + + supported_type = [torch.bool, bool] + return val_meta.dtype in supported_type + + +@dynamo_tensorrt_converter(torch.ops.aten.bitwise_not.default, capability_validator=bitwise_not_type_validator) # type: ignore[misc] +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) # type: ignore[misc] +def aten_ops_bitwise_not( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.unary.bitwise_not( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + ) + + @dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor) # type: ignore[misc] @dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar) # type: ignore[misc] @enforce_tensor_types( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py index 75c190779d..7db1787301 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py @@ -423,6 +423,39 @@ def logical_xor( ) +def bitwise_and( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + lhs_val: Union[TRTTensor, int, bool, Sequence[Union[int, bool]]], + rhs_val: Union[TRTTensor, int, bool, Sequence[Union[int, bool]]], +) -> TRTTensor: + return logical_and(ctx, target, source_ir, f"{name}_logical_and", lhs_val, rhs_val) + + +def bitwise_or( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + lhs_val: Union[TRTTensor, int, bool, Sequence[Union[int, bool]]], + rhs_val: Union[TRTTensor, int, bool, Sequence[Union[int, bool]]], +) -> TRTTensor: + return logical_or(ctx, target, source_ir, f"{name}_logical_or", lhs_val, rhs_val) + + +def bitwise_xor( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + lhs_val: Union[TRTTensor, int, bool, Sequence[Union[int, bool]]], + rhs_val: Union[TRTTensor, int, bool, Sequence[Union[int, bool]]], +) -> TRTTensor: + return logical_xor(ctx, target, source_ir, f"{name}_logical_xor", lhs_val, rhs_val) + + def eq( ctx: ConversionContext, target: Target, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py index 
58c5f6ff4a..3a0fd47ac5 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py @@ -1,6 +1,7 @@ from typing import Optional import tensorrt as trt +import torch_tensorrt.dynamo.conversion.impl as impl from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext @@ -336,6 +337,18 @@ def logical_not( ) +def bitwise_not( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_val: TRTTensor, +) -> TRTTensor: + return impl.unary.logical_not( + ctx, target, source_ir, f"{name}_logical_not", input_val + ) + + def sign( ctx: ConversionContext, target: Target, diff --git a/tests/py/dynamo/conversion/test_bitwise_and_aten.py b/tests/py/dynamo/conversion/test_bitwise_and_aten.py new file mode 100644 index 0000000000..73e606cc0e --- /dev/null +++ b/tests/py/dynamo/conversion/test_bitwise_and_aten.py @@ -0,0 +1,33 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestBitwiseAndConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d", (5, 3)), + ("3d", (5, 3, 2)), + ] + ) + def test_bitwise_and_tensor(self, _, shape): + class bitwise_and(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.ops.aten.bitwise_and.Tensor(lhs_val, rhs_val) + + inputs = [ + torch.randint(0, 2, shape, dtype=bool), + torch.randint(0, 2, shape, dtype=bool), + ] + self.run_test( + bitwise_and(), + inputs, + enable_passes=True, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_bitwise_not_aten.py b/tests/py/dynamo/conversion/test_bitwise_not_aten.py new file mode 100644 index 0000000000..6dd512ef16 --- /dev/null +++ b/tests/py/dynamo/conversion/test_bitwise_not_aten.py @@ -0,0 +1,33 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestBitwiseNotConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d", (5, 3)), + ("3d", (5, 3, 2)), + ] + ) + def test_bitwise_not_tensor(self, _, shape): + class bitwise_not(nn.Module): + def forward(self, val): + return torch.ops.aten.bitwise_not.default(val) + + inputs = [ + torch.randint(0, 2, shape, dtype=torch.bool), + ] + self.run_test( + bitwise_not(), + inputs, + enable_passes=True, + output_dtypes=[torch.bool], + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_bitwise_or_aten.py b/tests/py/dynamo/conversion/test_bitwise_or_aten.py new file mode 100644 index 0000000000..435e426e3d --- /dev/null +++ b/tests/py/dynamo/conversion/test_bitwise_or_aten.py @@ -0,0 +1,33 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestBitwiseOrConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d", (5, 3)), + ("3d", (5, 3, 2)), + ] + ) + def test_bitwise_or_tensor(self, _, shape): + class bitwise_or(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.ops.aten.bitwise_or.Tensor(lhs_val, rhs_val) + + inputs = [ + torch.randint(0, 2, shape, dtype=bool), + torch.randint(0, 2, shape, dtype=bool), + ] + self.run_test( + bitwise_or(), + 
inputs, + enable_passes=True, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_bitwise_xor_aten.py b/tests/py/dynamo/conversion/test_bitwise_xor_aten.py new file mode 100644 index 0000000000..bdaf0aea1a --- /dev/null +++ b/tests/py/dynamo/conversion/test_bitwise_xor_aten.py @@ -0,0 +1,33 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestBitwiseXorConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d", (5, 3)), + ("3d", (5, 3, 2)), + ] + ) + def test_bitwise_xor_tensor(self, _, shape): + class bitwise_xor(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.ops.aten.bitwise_xor.Tensor(lhs_val, rhs_val) + + inputs = [ + torch.randint(0, 2, shape, dtype=bool), + torch.randint(0, 2, shape, dtype=bool), + ] + self.run_test( + bitwise_xor(), + inputs, + enable_passes=True, + ) + + +if __name__ == "__main__": + run_tests() From af49b8b8129613b5f65aec3a7b6b9673d457c904 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Thu, 9 Nov 2023 15:34:51 -0800 Subject: [PATCH 4/6] fix bugs --- .../dynamo/conversion/aten_ops_converters.py | 139 ++++++++++++------ .../dynamo/conversion/impl/elementwise/ops.py | 39 ++--- .../conversion/test_bitwise_and_aten.py | 40 +++++ .../dynamo/conversion/test_bitwise_or_aten.py | 40 +++++ .../conversion/test_bitwise_xor_aten.py | 40 +++++ 5 files changed, 234 insertions(+), 64 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 24b9b9d2a0..1e2ca9e277 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -754,12 +754,12 @@ def aten_ops_cumsum( ) -@dynamo_tensorrt_converter(torch.ops.aten.tile.default) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.tile.default) @enforce_tensor_types( { 0: (TRTTensor,), } -) # type: ignore[misc] +) def aten_ops_tile( ctx: ConversionContext, target: Target, @@ -777,7 +777,7 @@ def aten_ops_tile( ) -@dynamo_tensorrt_converter(torch.ops.aten.permute.default) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.permute.default) @enforce_tensor_types( { 0: (TRTTensor,), @@ -1702,29 +1702,63 @@ def aten_ops_logical_xor( def bitwise_type_validator(node: Node) -> bool: - targets = [ + supported_type = [torch.bool, bool] + + tensor_targets = [ torch.ops.aten.bitwise_and.Tensor, torch.ops.aten.bitwise_or.Tensor, torch.ops.aten.bitwise_xor.Tensor, ] - if node.target not in targets: - return False + scalar_targets = [ + torch.ops.aten.bitwise_and.Scalar, + torch.ops.aten.bitwise_or.Scalar, + torch.ops.aten.bitwise_xor.Scalar, + ] + scalar_tensor_targets = [ + torch.ops.aten.bitwise_and.Scalar_Tensor, + torch.ops.aten.bitwise_or.Scalar_Tensor, + torch.ops.aten.bitwise_xor.Scalar_Tensor, + ] - lhs_val = node.args[0] - rhs_val = node.args[1] - lhs_meta = lhs_val.meta.get("tensor_meta") - rhs_meta = rhs_val.meta.get("tensor_meta") + if node.target in tensor_targets: + lhs_val = node.args[0] + rhs_val = node.args[1] + lhs_meta = lhs_val.meta.get("tensor_meta") + rhs_meta = rhs_val.meta.get("tensor_meta") + if lhs_meta is None or rhs_meta is None: + return False + return lhs_meta.dtype in supported_type and rhs_meta.dtype in supported_type - if lhs_meta is None or rhs_meta is None: - return False + elif node.target in scalar_targets: + 
lhs_val = node.args[0] + rhs_val = node.args[1] + lhs_meta = lhs_val.meta.get("tensor_meta") + if lhs_meta is None: + return False + return lhs_meta.dtype in supported_type and isinstance(rhs_val, bool) - supported_type = [torch.bool, bool] - return lhs_meta.dtype in supported_type and rhs_meta.dtype in supported_type + elif node.target in scalar_tensor_targets: + lhs_val = node.args[0] + rhs_val = node.args[1] + rhs_meta = rhs_val.meta.get("tensor_meta") + if rhs_meta is None: + return False + return isinstance(lhs_val, bool) and rhs_meta.dtype in supported_type + + else: + return False -@dynamo_tensorrt_converter(torch.ops.aten.bitwise_and.Tensor, capability_validator=bitwise_type_validator) # type: ignore[misc] -@dynamo_tensorrt_converter(torch.ops.aten.bitwise_and.Scalar) # type: ignore[misc] -@dynamo_tensorrt_converter(torch.ops.aten.bitwise_and.Scalar_Tensor) # type: ignore[misc] +@dynamo_tensorrt_converter( + torch.ops.aten.bitwise_and.Tensor, capability_validator=bitwise_type_validator +) +@dynamo_tensorrt_converter( + torch.ops.aten.bitwise_and.Scalar, capability_validator=bitwise_type_validator +) +@dynamo_tensorrt_converter( + torch.ops.aten.bitwise_and.Scalar_Tensor, + capability_validator=bitwise_type_validator, +) def aten_ops_bitwise_and( ctx: ConversionContext, target: Target, @@ -1742,9 +1776,15 @@ def aten_ops_bitwise_and( ) -@dynamo_tensorrt_converter(torch.ops.aten.bitwise_or.Tensor, capability_validator=bitwise_type_validator) # type: ignore[misc] -@dynamo_tensorrt_converter(torch.ops.aten.bitwise_or.Scalar) # type: ignore[misc] -@dynamo_tensorrt_converter(torch.ops.aten.bitwise_or.Scalar_Tensor) # type: ignore[misc] +@dynamo_tensorrt_converter( + torch.ops.aten.bitwise_or.Tensor, capability_validator=bitwise_type_validator +) +@dynamo_tensorrt_converter( + torch.ops.aten.bitwise_or.Scalar, capability_validator=bitwise_type_validator +) +@dynamo_tensorrt_converter( + torch.ops.aten.bitwise_or.Scalar_Tensor, capability_validator=bitwise_type_validator +) def aten_ops_bitwise_or( ctx: ConversionContext, target: Target, @@ -1762,9 +1802,16 @@ def aten_ops_bitwise_or( ) -@dynamo_tensorrt_converter(torch.ops.aten.bitwise_xor.Tensor, capability_validator=bitwise_type_validator) # type: ignore[misc] -@dynamo_tensorrt_converter(torch.ops.aten.bitwise_xor.Scalar) # type: ignore[misc] -@dynamo_tensorrt_converter(torch.ops.aten.bitwise_xor.Scalar_Tensor) # type: ignore[misc] +@dynamo_tensorrt_converter( + torch.ops.aten.bitwise_xor.Tensor, capability_validator=bitwise_type_validator +) +@dynamo_tensorrt_converter( + torch.ops.aten.bitwise_xor.Scalar, capability_validator=bitwise_type_validator +) +@dynamo_tensorrt_converter( + torch.ops.aten.bitwise_xor.Scalar_Tensor, + capability_validator=bitwise_type_validator, +) def aten_ops_bitwise_xor( ctx: ConversionContext, target: Target, @@ -1793,12 +1840,14 @@ def bitwise_not_type_validator(node: Node) -> bool: return val_meta.dtype in supported_type -@dynamo_tensorrt_converter(torch.ops.aten.bitwise_not.default, capability_validator=bitwise_not_type_validator) # type: ignore[misc] +@dynamo_tensorrt_converter( + torch.ops.aten.bitwise_not.default, capability_validator=bitwise_not_type_validator +) @enforce_tensor_types( { 0: (TRTTensor,), } -) # type: ignore[misc] +) def aten_ops_bitwise_not( ctx: ConversionContext, target: Target, @@ -1815,13 +1864,13 @@ def aten_ops_bitwise_not( ) -@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor) # type: ignore[misc] -@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar) # type: ignore[misc] 
+@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor) +@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar) @enforce_tensor_types( { 0: (TRTTensor,), } -) # type: ignore[misc] +) def aten_ops_eq( ctx: ConversionContext, target: Target, @@ -1839,13 +1888,13 @@ def aten_ops_eq( ) -@dynamo_tensorrt_converter(torch.ops.aten.ne.Tensor) # type: ignore[misc] -@dynamo_tensorrt_converter(torch.ops.aten.ne.Scalar) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.ne.Tensor) +@dynamo_tensorrt_converter(torch.ops.aten.ne.Scalar) @enforce_tensor_types( { 0: (TRTTensor,), } -) # type: ignore[misc] +) def aten_ops_ne( ctx: ConversionContext, target: Target, @@ -1863,13 +1912,13 @@ def aten_ops_ne( ) -@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor) # type: ignore[misc] -@dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor) +@dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar) @enforce_tensor_types( { 0: (TRTTensor,), } -) # type: ignore[misc] +) def aten_ops_gt( ctx: ConversionContext, target: Target, @@ -1887,13 +1936,13 @@ def aten_ops_gt( ) -@dynamo_tensorrt_converter(torch.ops.aten.ge.Tensor) # type: ignore[misc] -@dynamo_tensorrt_converter(torch.ops.aten.ge.Scalar) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.ge.Tensor) +@dynamo_tensorrt_converter(torch.ops.aten.ge.Scalar) @enforce_tensor_types( { 0: (TRTTensor,), } -) # type: ignore[misc] +) def aten_ops_ge( ctx: ConversionContext, target: Target, @@ -1911,13 +1960,13 @@ def aten_ops_ge( ) -@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor) # type: ignore[misc] -@dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor) +@dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar) @enforce_tensor_types( { 0: (TRTTensor,), } -) # type: ignore[misc] +) def aten_ops_lt( ctx: ConversionContext, target: Target, @@ -1935,13 +1984,13 @@ def aten_ops_lt( ) -@dynamo_tensorrt_converter(torch.ops.aten.le.Tensor) # type: ignore[misc] -@dynamo_tensorrt_converter(torch.ops.aten.le.Scalar) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.le.Tensor) +@dynamo_tensorrt_converter(torch.ops.aten.le.Scalar) @enforce_tensor_types( { 0: (TRTTensor,), } -) # type: ignore[misc] +) def aten_ops_le( ctx: ConversionContext, target: Target, @@ -2191,14 +2240,14 @@ def aten_ops_argmax( ) -@dynamo_tensorrt_converter(torch.ops.aten.addmm.default) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.addmm.default) @enforce_tensor_types( { 0: (TRTTensor,), 1: (np.ndarray, torch.Tensor, TRTTensor), 2: (np.ndarray, torch.Tensor, TRTTensor), } -) # type: ignore[misc] +) def aten_ops_addmm( ctx: ConversionContext, target: Target, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py index 7db1787301..06e07eedb1 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py @@ -1,7 +1,8 @@ -from typing import Optional, Sequence, Union +from typing import Optional, Union import numpy as np import tensorrt as trt +import torch import torch_tensorrt.dynamo.conversion.impl as impl from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR @@ -371,8 +372,8 @@ def logical_and( target: Target, source_ir: Optional[SourceIR], name: str, - lhs_val: Union[TRTTensor, int, float], - rhs_val: Union[TRTTensor, int, float], + 
lhs_val: Union[TRTTensor, int, float, bool], + rhs_val: Union[TRTTensor, int, float, bool], ) -> TRTTensor: if isinstance(lhs_val, TRTTensor): lhs_val = cast_int_or_float_to_bool(ctx, name, lhs_val) @@ -390,8 +391,8 @@ def logical_or( target: Target, source_ir: Optional[SourceIR], name: str, - lhs_val: Union[TRTTensor, int, float], - rhs_val: Union[TRTTensor, int, float], + lhs_val: Union[TRTTensor, int, float, bool], + rhs_val: Union[TRTTensor, int, float, bool], ) -> TRTTensor: if isinstance(lhs_val, TRTTensor): lhs_val = cast_int_or_float_to_bool(ctx, name, lhs_val) @@ -409,8 +410,8 @@ def logical_xor( target: Target, source_ir: Optional[SourceIR], name: str, - lhs_val: Union[TRTTensor, int, float], - rhs_val: Union[TRTTensor, int, float], + lhs_val: Union[TRTTensor, int, float, bool], + rhs_val: Union[TRTTensor, int, float, bool], ) -> TRTTensor: if isinstance(lhs_val, TRTTensor): lhs_val = cast_int_or_float_to_bool(ctx, name, lhs_val) @@ -428,8 +429,8 @@ def bitwise_and( target: Target, source_ir: Optional[SourceIR], name: str, - lhs_val: Union[TRTTensor, int, bool, Sequence[Union[int, bool]]], - rhs_val: Union[TRTTensor, int, bool, Sequence[Union[int, bool]]], + lhs_val: Union[TRTTensor, int, float, torch.Tensor, bool], + rhs_val: Union[TRTTensor, int, float, torch.Tensor, bool], ) -> TRTTensor: return logical_and(ctx, target, source_ir, f"{name}_logical_and", lhs_val, rhs_val) @@ -439,8 +440,8 @@ def bitwise_or( target: Target, source_ir: Optional[SourceIR], name: str, - lhs_val: Union[TRTTensor, int, bool, Sequence[Union[int, bool]]], - rhs_val: Union[TRTTensor, int, bool, Sequence[Union[int, bool]]], + lhs_val: Union[TRTTensor, int, float, torch.Tensor, bool], + rhs_val: Union[TRTTensor, int, float, torch.Tensor, bool], ) -> TRTTensor: return logical_or(ctx, target, source_ir, f"{name}_logical_or", lhs_val, rhs_val) @@ -450,8 +451,8 @@ def bitwise_xor( target: Target, source_ir: Optional[SourceIR], name: str, - lhs_val: Union[TRTTensor, int, bool, Sequence[Union[int, bool]]], - rhs_val: Union[TRTTensor, int, bool, Sequence[Union[int, bool]]], + lhs_val: Union[TRTTensor, int, float, torch.Tensor, bool], + rhs_val: Union[TRTTensor, int, float, torch.Tensor, bool], ) -> TRTTensor: return logical_xor(ctx, target, source_ir, f"{name}_logical_xor", lhs_val, rhs_val) @@ -462,7 +463,7 @@ def eq( source_ir: Optional[SourceIR], name: str, lhs_val: TRTTensor, - rhs_val: Union[TRTTensor, int, float, bool, Sequence[Union[int, float, bool]]], + rhs_val: Union[TRTTensor, int, float, torch.Tensor], ) -> TRTTensor: return convert_binary_elementwise( ctx, @@ -481,7 +482,7 @@ def ne( source_ir: Optional[SourceIR], name: str, lhs_val: TRTTensor, - rhs_val: Union[TRTTensor, int, float, bool, Sequence[Union[int, float, bool]]], + rhs_val: Union[TRTTensor, int, float, torch.Tensor], ) -> TRTTensor: return impl.unary.logical_not( ctx, @@ -498,7 +499,7 @@ def gt( source_ir: Optional[SourceIR], name: str, lhs_val: TRTTensor, - rhs_val: Union[TRTTensor, int, float, Sequence[Union[int, float]]], + rhs_val: Union[TRTTensor, int, float, torch.Tensor], ) -> TRTTensor: return convert_binary_elementwise( ctx, @@ -517,7 +518,7 @@ def ge( source_ir: Optional[SourceIR], name: str, lhs_val: TRTTensor, - rhs_val: Union[TRTTensor, int, float, Sequence[Union[int, float]]], + rhs_val: Union[TRTTensor, int, float, torch.Tensor], ) -> TRTTensor: return logical_or( ctx, @@ -535,7 +536,7 @@ def lt( source_ir: Optional[SourceIR], name: str, lhs_val: TRTTensor, - rhs_val: Union[TRTTensor, int, float, Sequence[Union[int, 
float]]], + rhs_val: Union[TRTTensor, int, float, torch.Tensor], ) -> TRTTensor: return convert_binary_elementwise( ctx, @@ -554,7 +555,7 @@ def le( source_ir: Optional[SourceIR], name: str, lhs_val: TRTTensor, - rhs_val: Union[TRTTensor, int, float, Sequence[Union[int, float]]], + rhs_val: Union[TRTTensor, int, float, torch.Tensor], ) -> TRTTensor: return logical_or( ctx, diff --git a/tests/py/dynamo/conversion/test_bitwise_and_aten.py b/tests/py/dynamo/conversion/test_bitwise_and_aten.py index 73e606cc0e..5c2a78a18a 100644 --- a/tests/py/dynamo/conversion/test_bitwise_and_aten.py +++ b/tests/py/dynamo/conversion/test_bitwise_and_aten.py @@ -28,6 +28,46 @@ def forward(self, lhs_val, rhs_val): enable_passes=True, ) + @parameterized.expand( + [ + ("2d", (5, 3), True), + ("3d", (5, 3, 2), False), + ] + ) + def test_bitwise_and_scalar(self, _, shape, scalar): + class bitwise_and(nn.Module): + def forward(self, tensor): + return torch.ops.aten.bitwise_and.Scalar(tensor, scalar) + + inputs = [ + torch.randint(0, 2, shape, dtype=bool), + ] + self.run_test( + bitwise_and(), + inputs, + enable_passes=True, + ) + + @parameterized.expand( + [ + ("2d", (5, 3), True), + ("3d", (5, 3, 2), False), + ] + ) + def test_bitwise_and_scalar_tensor(self, _, shape, scalar): + class bitwise_and(nn.Module): + def forward(self, tensor): + return torch.ops.aten.bitwise_and.Scalar_Tensor(scalar, tensor) + + inputs = [ + torch.randint(0, 2, shape, dtype=bool), + ] + self.run_test( + bitwise_and(), + inputs, + enable_passes=True, + ) + if __name__ == "__main__": run_tests() diff --git a/tests/py/dynamo/conversion/test_bitwise_or_aten.py b/tests/py/dynamo/conversion/test_bitwise_or_aten.py index 435e426e3d..b5e0200734 100644 --- a/tests/py/dynamo/conversion/test_bitwise_or_aten.py +++ b/tests/py/dynamo/conversion/test_bitwise_or_aten.py @@ -28,6 +28,46 @@ def forward(self, lhs_val, rhs_val): enable_passes=True, ) + @parameterized.expand( + [ + ("2d", (5, 3), True), + ("3d", (5, 3, 2), False), + ] + ) + def test_bitwise_or_scalar(self, _, shape, scalar): + class bitwise_or(nn.Module): + def forward(self, tensor): + return torch.ops.aten.bitwise_or.Scalar(tensor, scalar) + + inputs = [ + torch.randint(0, 2, shape, dtype=bool), + ] + self.run_test( + bitwise_or(), + inputs, + enable_passes=True, + ) + + @parameterized.expand( + [ + ("2d", (5, 3), True), + ("3d", (5, 3, 2), False), + ] + ) + def test_bitwise_or_scalar_tensor(self, _, shape, scalar): + class bitwise_or(nn.Module): + def forward(self, tensor): + return torch.ops.aten.bitwise_or.Scalar_Tensor(scalar, tensor) + + inputs = [ + torch.randint(0, 2, shape, dtype=bool), + ] + self.run_test( + bitwise_or(), + inputs, + enable_passes=True, + ) + if __name__ == "__main__": run_tests() diff --git a/tests/py/dynamo/conversion/test_bitwise_xor_aten.py b/tests/py/dynamo/conversion/test_bitwise_xor_aten.py index bdaf0aea1a..8c1a8136ef 100644 --- a/tests/py/dynamo/conversion/test_bitwise_xor_aten.py +++ b/tests/py/dynamo/conversion/test_bitwise_xor_aten.py @@ -28,6 +28,46 @@ def forward(self, lhs_val, rhs_val): enable_passes=True, ) + @parameterized.expand( + [ + ("2d", (5, 3), True), + ("3d", (5, 3, 2), False), + ] + ) + def test_bitwise_xor_scalar(self, _, shape, scalar): + class bitwise_xor(nn.Module): + def forward(self, tensor): + return torch.ops.aten.bitwise_xor.Scalar(tensor, scalar) + + inputs = [ + torch.randint(0, 2, shape, dtype=bool), + ] + self.run_test( + bitwise_xor(), + inputs, + enable_passes=True, + ) + + @parameterized.expand( + [ + ("2d", (5, 3), True), 
+ ("3d", (5, 3, 2), False), + ] + ) + def test_bitwise_xor_scalar_tensor(self, _, shape, scalar): + class bitwise_xor(nn.Module): + def forward(self, tensor): + return torch.ops.aten.bitwise_xor.Scalar_Tensor(scalar, tensor) + + inputs = [ + torch.randint(0, 2, shape, dtype=bool), + ] + self.run_test( + bitwise_xor(), + inputs, + enable_passes=True, + ) + if __name__ == "__main__": run_tests() From 1645b5efbfda753cb30db8596470834cf5fbb3c7 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Wed, 15 Nov 2023 17:36:25 -0800 Subject: [PATCH 5/6] add bool type to convert_binary_elementwise --- py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py index 3700242fe7..7fe5c99eb2 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py @@ -58,8 +58,8 @@ def convert_binary_elementwise( source_ir: Optional[SourceIR], name: str, op_type: trt.ElementWiseOperation, - lhs_val: Union[int, float, TRTTensor, torch.Tensor], - rhs_val: Union[int, float, TRTTensor, torch.Tensor], + lhs_val: Union[int, float, bool, TRTTensor, torch.Tensor], + rhs_val: Union[int, float, bool, TRTTensor, torch.Tensor], ) -> TRTTensor: """ This function adds a TensorRT elementwise layer. We allow both operands to be From 8b3cbbcf61db83a54570e5bd3bf01a456d2f5dd4 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Fri, 17 Nov 2023 13:51:37 -0800 Subject: [PATCH 6/6] fix bugs --- py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py index 7fe5c99eb2..8282ee8698 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py @@ -120,11 +120,11 @@ def convert_binary_elementwise( # Note that the dtype here is supposed to be the same as the scalar # dtype but we don't have a way to detect whether it makes sense for the # scalar to be float or half. Hence we go with the lhs dtype. - if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)): + if is_lhs_trt_tensor and isinstance(rhs_val, (float, int, bool)): rhs_val = np.array( [rhs_val], dtype=unified_dtype_converter(lhs_dtype, Frameworks.NUMPY) ) - if is_rhs_trt_tensor and isinstance(lhs_val, (float, int)): + if is_rhs_trt_tensor and isinstance(lhs_val, (float, int, bool)): lhs_val = np.array( [lhs_val], dtype=unified_dtype_converter(rhs_dtype, Frameworks.NUMPY) )