diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index b05713c360..1e2ca9e277 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -754,12 +754,12 @@ def aten_ops_cumsum(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.tile.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.tile.default)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
     }
-)  # type: ignore[misc]
+)
 def aten_ops_tile(
     ctx: ConversionContext,
     target: Target,
@@ -777,7 +777,7 @@ def aten_ops_tile(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.permute.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.permute.default)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
@@ -1701,9 +1701,177 @@ def aten_ops_logical_xor(
     )


+def bitwise_type_validator(node: Node) -> bool:
+    supported_type = [torch.bool, bool]
+
+    tensor_targets = [
+        torch.ops.aten.bitwise_and.Tensor,
+        torch.ops.aten.bitwise_or.Tensor,
+        torch.ops.aten.bitwise_xor.Tensor,
+    ]
+    scalar_targets = [
+        torch.ops.aten.bitwise_and.Scalar,
+        torch.ops.aten.bitwise_or.Scalar,
+        torch.ops.aten.bitwise_xor.Scalar,
+    ]
+    scalar_tensor_targets = [
+        torch.ops.aten.bitwise_and.Scalar_Tensor,
+        torch.ops.aten.bitwise_or.Scalar_Tensor,
+        torch.ops.aten.bitwise_xor.Scalar_Tensor,
+    ]
+
+    if node.target in tensor_targets:
+        lhs_val = node.args[0]
+        rhs_val = node.args[1]
+        lhs_meta = lhs_val.meta.get("tensor_meta")
+        rhs_meta = rhs_val.meta.get("tensor_meta")
+        if lhs_meta is None or rhs_meta is None:
+            return False
+        return lhs_meta.dtype in supported_type and rhs_meta.dtype in supported_type
+
+    elif node.target in scalar_targets:
+        lhs_val = node.args[0]
+        rhs_val = node.args[1]
+        lhs_meta = lhs_val.meta.get("tensor_meta")
+        if lhs_meta is None:
+            return False
+        return lhs_meta.dtype in supported_type and isinstance(rhs_val, bool)
+
+    elif node.target in scalar_tensor_targets:
+        lhs_val = node.args[0]
+        rhs_val = node.args[1]
+        rhs_meta = rhs_val.meta.get("tensor_meta")
+        if rhs_meta is None:
+            return False
+        return isinstance(lhs_val, bool) and rhs_meta.dtype in supported_type
+
+    else:
+        return False
+
+
+@dynamo_tensorrt_converter(
+    torch.ops.aten.bitwise_and.Tensor, capability_validator=bitwise_type_validator
+)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.bitwise_and.Scalar, capability_validator=bitwise_type_validator
+)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.bitwise_and.Scalar_Tensor,
+    capability_validator=bitwise_type_validator,
+)
+def aten_ops_bitwise_and(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.bitwise_and(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+@dynamo_tensorrt_converter(
+    torch.ops.aten.bitwise_or.Tensor, capability_validator=bitwise_type_validator
+)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.bitwise_or.Scalar, capability_validator=bitwise_type_validator
+)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.bitwise_or.Scalar_Tensor, capability_validator=bitwise_type_validator
+)
+def aten_ops_bitwise_or(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.bitwise_or(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
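# ---------------------------------------------------------------------------
# Editorial note (illustrative sketch, not part of the patch): the
# capability_validator functions above decide, per FX node, whether the
# bitwise converters may claim the op, based on the dtype recorded in
# node.meta["tensor_meta"]. A minimal, self-contained way to see the same
# metadata the validator reads is to trace a small module and run ShapeProp,
# which fills in tensor_meta. The module and tensor names below are made up
# for the example; only torch.fx utilities from stock PyTorch are assumed.
import torch
import torch.nn as nn
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp


class TinyAnd(nn.Module):
    def forward(self, a, b):
        return torch.bitwise_and(a, b)


gm = symbolic_trace(TinyAnd())
a = torch.randint(0, 2, (4,), dtype=torch.bool)
b = torch.randint(0, 2, (4,), dtype=torch.bool)
ShapeProp(gm).propagate(a, b)  # populates node.meta["tensor_meta"]

for node in gm.graph.nodes:
    if node.op == "call_function":
        metas = [
            arg.meta.get("tensor_meta")
            for arg in node.args
            if isinstance(arg, torch.fx.Node)
        ]
        # Mirrors bitwise_type_validator: convert only when every tensor
        # operand is bool; anything else is left to fall back to PyTorch.
        supported = all(m is not None and m.dtype == torch.bool for m in metas)
        print(node.target, supported)
# ---------------------------------------------------------------------------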
+@dynamo_tensorrt_converter(
+    torch.ops.aten.bitwise_xor.Tensor, capability_validator=bitwise_type_validator
+)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.bitwise_xor.Scalar, capability_validator=bitwise_type_validator
+)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.bitwise_xor.Scalar_Tensor,
+    capability_validator=bitwise_type_validator,
+)
+def aten_ops_bitwise_xor(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.bitwise_xor(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+def bitwise_not_type_validator(node: Node) -> bool:
+    val = node.args[0]
+    val_meta = val.meta.get("tensor_meta")
+
+    if val_meta is None:
+        return False
+
+    supported_type = [torch.bool, bool]
+    return val_meta.dtype in supported_type
+
+
+@dynamo_tensorrt_converter(
+    torch.ops.aten.bitwise_not.default, capability_validator=bitwise_not_type_validator
+)
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+def aten_ops_bitwise_not(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.unary.bitwise_not(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+    )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor)
 @dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar)
-def aten_ops_equal(
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+def aten_ops_eq(
     ctx: ConversionContext,
     target: Target,
     args: Tuple[Argument, ...],
     kwargs: Dict[str, Argument],
     name: str,
@@ -1720,9 +1888,38 @@ def aten_ops_equal(
     )


+@dynamo_tensorrt_converter(torch.ops.aten.ne.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.ne.Scalar)
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+def aten_ops_ne(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.ne(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor)
 @dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar)
-def aten_ops_greater(
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+def aten_ops_gt(
     ctx: ConversionContext,
     target: Target,
     args: Tuple[Argument, ...],
     kwargs: Dict[str, Argument],
     name: str,
@@ -1739,9 +1936,38 @@ def aten_ops_greater(
     )


+@dynamo_tensorrt_converter(torch.ops.aten.ge.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.ge.Scalar)
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+def aten_ops_ge(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.ge(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor)
 @dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar)
-def aten_ops_less(
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+def aten_ops_lt(
     ctx: ConversionContext,
     target: Target,
     args: Tuple[Argument, ...],
     kwargs: Dict[str, Argument],
     name: str,
@@ -1758,6 +1984,30 @@ def aten_ops_less(
     )


+@dynamo_tensorrt_converter(torch.ops.aten.le.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.le.Scalar)
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+def aten_ops_le(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.le(
+        ctx,
+ target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) + + def conv_param_validator(conv_node: Node) -> bool: return conv_node.args[7] in ([0], [0, 0], [0, 0, 0]) @@ -1990,14 +2240,14 @@ def aten_ops_argmax( ) -@dynamo_tensorrt_converter(torch.ops.aten.addmm.default) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.addmm.default) @enforce_tensor_types( { 0: (TRTTensor,), 1: (np.ndarray, torch.Tensor, TRTTensor), 2: (np.ndarray, torch.Tensor, TRTTensor), } -) # type: ignore[misc] +) def aten_ops_addmm( ctx: ConversionContext, target: Target, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py index 3700242fe7..8282ee8698 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py @@ -58,8 +58,8 @@ def convert_binary_elementwise( source_ir: Optional[SourceIR], name: str, op_type: trt.ElementWiseOperation, - lhs_val: Union[int, float, TRTTensor, torch.Tensor], - rhs_val: Union[int, float, TRTTensor, torch.Tensor], + lhs_val: Union[int, float, bool, TRTTensor, torch.Tensor], + rhs_val: Union[int, float, bool, TRTTensor, torch.Tensor], ) -> TRTTensor: """ This function adds a TensorRT elementwise layer. We allow both operands to be @@ -120,11 +120,11 @@ def convert_binary_elementwise( # Note that the dtype here is supposed to be the same as the scalar # dtype but we don't have a way to detect whether it makes sense for the # scalar to be float or half. Hence we go with the lhs dtype. - if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)): + if is_lhs_trt_tensor and isinstance(rhs_val, (float, int, bool)): rhs_val = np.array( [rhs_val], dtype=unified_dtype_converter(lhs_dtype, Frameworks.NUMPY) ) - if is_rhs_trt_tensor and isinstance(lhs_val, (float, int)): + if is_rhs_trt_tensor and isinstance(lhs_val, (float, int, bool)): lhs_val = np.array( [lhs_val], dtype=unified_dtype_converter(rhs_dtype, Frameworks.NUMPY) ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py index 9f1143959f..06e07eedb1 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py @@ -2,6 +2,8 @@ import numpy as np import tensorrt as trt +import torch +import torch_tensorrt.dynamo.conversion.impl as impl from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext @@ -370,8 +372,8 @@ def logical_and( target: Target, source_ir: Optional[SourceIR], name: str, - lhs_val: Union[TRTTensor, int, float], - rhs_val: Union[TRTTensor, int, float], + lhs_val: Union[TRTTensor, int, float, bool], + rhs_val: Union[TRTTensor, int, float, bool], ) -> TRTTensor: if isinstance(lhs_val, TRTTensor): lhs_val = cast_int_or_float_to_bool(ctx, name, lhs_val) @@ -389,8 +391,8 @@ def logical_or( target: Target, source_ir: Optional[SourceIR], name: str, - lhs_val: Union[TRTTensor, int, float], - rhs_val: Union[TRTTensor, int, float], + lhs_val: Union[TRTTensor, int, float, bool], + rhs_val: Union[TRTTensor, int, float, bool], ) -> TRTTensor: if isinstance(lhs_val, TRTTensor): lhs_val = cast_int_or_float_to_bool(ctx, name, lhs_val) @@ -408,8 +410,8 @@ def logical_xor( target: Target, source_ir: Optional[SourceIR], name: str, - lhs_val: Union[TRTTensor, int, float], - rhs_val: Union[TRTTensor, int, 
float], + lhs_val: Union[TRTTensor, int, float, bool], + rhs_val: Union[TRTTensor, int, float, bool], ) -> TRTTensor: if isinstance(lhs_val, TRTTensor): lhs_val = cast_int_or_float_to_bool(ctx, name, lhs_val) @@ -422,13 +424,46 @@ def logical_xor( ) +def bitwise_and( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + lhs_val: Union[TRTTensor, int, float, torch.Tensor, bool], + rhs_val: Union[TRTTensor, int, float, torch.Tensor, bool], +) -> TRTTensor: + return logical_and(ctx, target, source_ir, f"{name}_logical_and", lhs_val, rhs_val) + + +def bitwise_or( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + lhs_val: Union[TRTTensor, int, float, torch.Tensor, bool], + rhs_val: Union[TRTTensor, int, float, torch.Tensor, bool], +) -> TRTTensor: + return logical_or(ctx, target, source_ir, f"{name}_logical_or", lhs_val, rhs_val) + + +def bitwise_xor( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + lhs_val: Union[TRTTensor, int, float, torch.Tensor, bool], + rhs_val: Union[TRTTensor, int, float, torch.Tensor, bool], +) -> TRTTensor: + return logical_xor(ctx, target, source_ir, f"{name}_logical_xor", lhs_val, rhs_val) + + def eq( ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, - lhs_val: Union[TRTTensor, int, float], - rhs_val: Union[TRTTensor, int, float], + lhs_val: TRTTensor, + rhs_val: Union[TRTTensor, int, float, torch.Tensor], ) -> TRTTensor: return convert_binary_elementwise( ctx, @@ -441,13 +476,30 @@ def eq( ) +def ne( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + lhs_val: TRTTensor, + rhs_val: Union[TRTTensor, int, float, torch.Tensor], +) -> TRTTensor: + return impl.unary.logical_not( + ctx, + target, + source_ir, + f"{name}_logical_not", + eq(ctx, target, source_ir, f"{name}_eq", lhs_val, rhs_val), + ) + + def gt( ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, - lhs_val: Union[TRTTensor, int, float], - rhs_val: Union[TRTTensor, int, float], + lhs_val: TRTTensor, + rhs_val: Union[TRTTensor, int, float, torch.Tensor], ) -> TRTTensor: return convert_binary_elementwise( ctx, @@ -460,13 +512,31 @@ def gt( ) +def ge( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + lhs_val: TRTTensor, + rhs_val: Union[TRTTensor, int, float, torch.Tensor], +) -> TRTTensor: + return logical_or( + ctx, + target, + source_ir, + name, + gt(ctx, target, source_ir, f"{name}_gt", lhs_val, rhs_val), + eq(ctx, target, source_ir, f"{name}_eq", lhs_val, rhs_val), + ) + + def lt( ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, - lhs_val: Union[TRTTensor, int, float], - rhs_val: Union[TRTTensor, int, float], + lhs_val: TRTTensor, + rhs_val: Union[TRTTensor, int, float, torch.Tensor], ) -> TRTTensor: return convert_binary_elementwise( ctx, @@ -477,3 +547,21 @@ def lt( lhs_val, rhs_val, ) + + +def le( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + lhs_val: TRTTensor, + rhs_val: Union[TRTTensor, int, float, torch.Tensor], +) -> TRTTensor: + return logical_or( + ctx, + target, + source_ir, + name, + lt(ctx, target, source_ir, f"{name}_lt", lhs_val, rhs_val), + eq(ctx, target, source_ir, f"{name}_eq", lhs_val, rhs_val), + ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py index 
58c5f6ff4a..3a0fd47ac5 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py @@ -1,6 +1,7 @@ from typing import Optional import tensorrt as trt +import torch_tensorrt.dynamo.conversion.impl as impl from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext @@ -336,6 +337,18 @@ def logical_not( ) +def bitwise_not( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_val: TRTTensor, +) -> TRTTensor: + return impl.unary.logical_not( + ctx, target, source_ir, f"{name}_logical_not", input_val + ) + + def sign( ctx: ConversionContext, target: Target, diff --git a/tests/py/dynamo/conversion/test_bitwise_and_aten.py b/tests/py/dynamo/conversion/test_bitwise_and_aten.py new file mode 100644 index 0000000000..5c2a78a18a --- /dev/null +++ b/tests/py/dynamo/conversion/test_bitwise_and_aten.py @@ -0,0 +1,73 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestBitwiseAndConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d", (5, 3)), + ("3d", (5, 3, 2)), + ] + ) + def test_bitwise_and_tensor(self, _, shape): + class bitwise_and(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.ops.aten.bitwise_and.Tensor(lhs_val, rhs_val) + + inputs = [ + torch.randint(0, 2, shape, dtype=bool), + torch.randint(0, 2, shape, dtype=bool), + ] + self.run_test( + bitwise_and(), + inputs, + enable_passes=True, + ) + + @parameterized.expand( + [ + ("2d", (5, 3), True), + ("3d", (5, 3, 2), False), + ] + ) + def test_bitwise_and_scalar(self, _, shape, scalar): + class bitwise_and(nn.Module): + def forward(self, tensor): + return torch.ops.aten.bitwise_and.Scalar(tensor, scalar) + + inputs = [ + torch.randint(0, 2, shape, dtype=bool), + ] + self.run_test( + bitwise_and(), + inputs, + enable_passes=True, + ) + + @parameterized.expand( + [ + ("2d", (5, 3), True), + ("3d", (5, 3, 2), False), + ] + ) + def test_bitwise_and_scalar_tensor(self, _, shape, scalar): + class bitwise_and(nn.Module): + def forward(self, tensor): + return torch.ops.aten.bitwise_and.Scalar_Tensor(scalar, tensor) + + inputs = [ + torch.randint(0, 2, shape, dtype=bool), + ] + self.run_test( + bitwise_and(), + inputs, + enable_passes=True, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_bitwise_not_aten.py b/tests/py/dynamo/conversion/test_bitwise_not_aten.py new file mode 100644 index 0000000000..6dd512ef16 --- /dev/null +++ b/tests/py/dynamo/conversion/test_bitwise_not_aten.py @@ -0,0 +1,33 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestBitwiseNotConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d", (5, 3)), + ("3d", (5, 3, 2)), + ] + ) + def test_bitwise_not_tensor(self, _, shape): + class bitwise_not(nn.Module): + def forward(self, val): + return torch.ops.aten.bitwise_not.default(val) + + inputs = [ + torch.randint(0, 2, shape, dtype=torch.bool), + ] + self.run_test( + bitwise_not(), + inputs, + enable_passes=True, + output_dtypes=[torch.bool], + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_bitwise_or_aten.py 
b/tests/py/dynamo/conversion/test_bitwise_or_aten.py new file mode 100644 index 0000000000..b5e0200734 --- /dev/null +++ b/tests/py/dynamo/conversion/test_bitwise_or_aten.py @@ -0,0 +1,73 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestBitwiseOrConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d", (5, 3)), + ("3d", (5, 3, 2)), + ] + ) + def test_bitwise_or_tensor(self, _, shape): + class bitwise_or(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.ops.aten.bitwise_or.Tensor(lhs_val, rhs_val) + + inputs = [ + torch.randint(0, 2, shape, dtype=bool), + torch.randint(0, 2, shape, dtype=bool), + ] + self.run_test( + bitwise_or(), + inputs, + enable_passes=True, + ) + + @parameterized.expand( + [ + ("2d", (5, 3), True), + ("3d", (5, 3, 2), False), + ] + ) + def test_bitwise_or_scalar(self, _, shape, scalar): + class bitwise_or(nn.Module): + def forward(self, tensor): + return torch.ops.aten.bitwise_or.Scalar(tensor, scalar) + + inputs = [ + torch.randint(0, 2, shape, dtype=bool), + ] + self.run_test( + bitwise_or(), + inputs, + enable_passes=True, + ) + + @parameterized.expand( + [ + ("2d", (5, 3), True), + ("3d", (5, 3, 2), False), + ] + ) + def test_bitwise_or_scalar_tensor(self, _, shape, scalar): + class bitwise_or(nn.Module): + def forward(self, tensor): + return torch.ops.aten.bitwise_or.Scalar_Tensor(scalar, tensor) + + inputs = [ + torch.randint(0, 2, shape, dtype=bool), + ] + self.run_test( + bitwise_or(), + inputs, + enable_passes=True, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_bitwise_xor_aten.py b/tests/py/dynamo/conversion/test_bitwise_xor_aten.py new file mode 100644 index 0000000000..8c1a8136ef --- /dev/null +++ b/tests/py/dynamo/conversion/test_bitwise_xor_aten.py @@ -0,0 +1,73 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestBitwiseXorConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d", (5, 3)), + ("3d", (5, 3, 2)), + ] + ) + def test_bitwise_xor_tensor(self, _, shape): + class bitwise_xor(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.ops.aten.bitwise_xor.Tensor(lhs_val, rhs_val) + + inputs = [ + torch.randint(0, 2, shape, dtype=bool), + torch.randint(0, 2, shape, dtype=bool), + ] + self.run_test( + bitwise_xor(), + inputs, + enable_passes=True, + ) + + @parameterized.expand( + [ + ("2d", (5, 3), True), + ("3d", (5, 3, 2), False), + ] + ) + def test_bitwise_xor_scalar(self, _, shape, scalar): + class bitwise_xor(nn.Module): + def forward(self, tensor): + return torch.ops.aten.bitwise_xor.Scalar(tensor, scalar) + + inputs = [ + torch.randint(0, 2, shape, dtype=bool), + ] + self.run_test( + bitwise_xor(), + inputs, + enable_passes=True, + ) + + @parameterized.expand( + [ + ("2d", (5, 3), True), + ("3d", (5, 3, 2), False), + ] + ) + def test_bitwise_xor_scalar_tensor(self, _, shape, scalar): + class bitwise_xor(nn.Module): + def forward(self, tensor): + return torch.ops.aten.bitwise_xor.Scalar_Tensor(scalar, tensor) + + inputs = [ + torch.randint(0, 2, shape, dtype=bool), + ] + self.run_test( + bitwise_xor(), + inputs, + enable_passes=True, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_equal_aten.py 
b/tests/py/dynamo/conversion/test_eq_aten.py similarity index 58% rename from tests/py/dynamo/conversion/test_equal_aten.py rename to tests/py/dynamo/conversion/test_eq_aten.py index 7761b31410..17a372182c 100644 --- a/tests/py/dynamo/conversion/test_equal_aten.py +++ b/tests/py/dynamo/conversion/test_eq_aten.py @@ -2,7 +2,6 @@ import torch.nn as nn from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt import Input from .harness import DispatchTestCase @@ -10,56 +9,58 @@ class TestEqualConverter(DispatchTestCase): @parameterized.expand( [ - ("2d", (2, 1)), - ("3d", (2, 1, 2)), + ("2d", (5, 3)), + ("3d", (5, 3, 2)), ] ) - def test_equal_tensor(self, _, shape): - class equal(nn.Module): + def test_eq_tensor(self, _, shape): + class eq(nn.Module): def forward(self, lhs_val, rhs_val): return torch.ops.aten.eq.Tensor(lhs_val, rhs_val) - inputs = [torch.randn(shape), torch.randn(shape)] + inputs = [ + torch.randint(0, 3, shape, dtype=torch.int32), + torch.randint(0, 3, shape, dtype=torch.int32), + ] self.run_test( - equal(), + eq(), inputs, output_dtypes=[torch.bool], ) @parameterized.expand( [ - ("2d", (2, 1), 1), - ("3d", (2, 1, 2), 2.0), + ("2d", (5, 3), 1), + ("3d", (5, 3, 2), 2.0), ] ) - def test_equal_tensor_scalar(self, _, shape, scalar): - class equal(nn.Module): + def test_eq_tensor_scalar(self, _, shape, scalar): + class eq(nn.Module): def forward(self, lhs_val): return torch.ops.aten.eq.Tensor(lhs_val, torch.tensor(scalar)) - inputs = [torch.randn(shape)] + inputs = [torch.randint(0, 3, shape, dtype=torch.int32)] self.run_test( - equal(), + eq(), inputs, output_dtypes=[torch.bool], ) @parameterized.expand( [ - ("2d", (2, 1), 1), - ("3d", (2, 1, 2), 2.0), + ("2d", (5, 3), 1), + ("3d", (5, 3, 2), 2.0), ] ) - def test_equal_scalar(self, _, shape, scalar): - class equal(nn.Module): + def test_eq_scalar(self, _, shape, scalar): + class eq(nn.Module): def forward(self, lhs_val): return torch.ops.aten.eq.Scalar(lhs_val, scalar) - inputs = [torch.randn(shape)] + inputs = [torch.randint(0, 3, shape, dtype=torch.int32)] self.run_test( - equal(), + eq(), inputs, - # expected_ops={torch.ops.aten.eq.Scalar}, output_dtypes=[torch.bool], ) diff --git a/tests/py/dynamo/conversion/test_ge_aten.py b/tests/py/dynamo/conversion/test_ge_aten.py new file mode 100644 index 0000000000..6b1ee6d440 --- /dev/null +++ b/tests/py/dynamo/conversion/test_ge_aten.py @@ -0,0 +1,69 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestGtConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d", (5, 3)), + ("3d", (5, 3, 2)), + ] + ) + def test_ge_tensor(self, _, shape): + class ge(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.ops.aten.ge.Tensor(lhs_val, rhs_val) + + inputs = [ + torch.randint(0, 3, shape, dtype=torch.int32), + torch.randint(0, 3, shape, dtype=torch.int32), + ] + self.run_test( + ge(), + inputs, + output_dtypes=[torch.bool], + ) + + @parameterized.expand( + [ + ("2d", (5, 3), 1), + ("3d", (5, 3, 2), 2.0), + ] + ) + def test_ge_tensor_scalar(self, _, shape, scalar): + class ge(nn.Module): + def forward(self, lhs_val): + return torch.ops.aten.ge.Tensor(lhs_val, torch.tensor(scalar)) + + inputs = [torch.randint(0, 3, shape, dtype=torch.int32)] + self.run_test( + ge(), + inputs, + output_dtypes=[torch.bool], + ) + + @parameterized.expand( + [ + ("2d", (5, 3), 1), + ("3d", 
(5, 3, 2), 2.0), + ] + ) + def test_ge_scalar(self, _, shape, scalar): + class ge(nn.Module): + def forward(self, lhs_val): + return torch.ops.aten.ge.Scalar(lhs_val, scalar) + + inputs = [torch.randint(0, 3, shape, dtype=torch.int32)] + self.run_test( + ge(), + inputs, + output_dtypes=[torch.bool], + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_greater_aten.py b/tests/py/dynamo/conversion/test_gt_aten.py similarity index 65% rename from tests/py/dynamo/conversion/test_greater_aten.py rename to tests/py/dynamo/conversion/test_gt_aten.py index 230fff23d8..8d9ae24f80 100644 --- a/tests/py/dynamo/conversion/test_greater_aten.py +++ b/tests/py/dynamo/conversion/test_gt_aten.py @@ -2,62 +2,61 @@ import torch.nn as nn from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt import Input from .harness import DispatchTestCase -class TestGreaterConverter(DispatchTestCase): +class TestGtConverter(DispatchTestCase): @parameterized.expand( [ - ("2d", (2, 1)), - ("3d", (2, 1, 2)), + ("2d", (5, 3)), + ("3d", (5, 3, 2)), ] ) - def test_greater_tensor(self, _, shape): - class greater(nn.Module): + def test_gt_tensor(self, _, shape): + class gt(nn.Module): def forward(self, lhs_val, rhs_val): return torch.ops.aten.gt.Tensor(lhs_val, rhs_val) inputs = [torch.randn(shape), torch.randn(shape)] self.run_test( - greater(), + gt(), inputs, output_dtypes=[torch.bool], ) @parameterized.expand( [ - ("2d", (2, 1), 1), - ("3d", (2, 1, 2), 2.0), + ("2d", (5, 3), 1), + ("3d", (5, 3, 2), 2.0), ] ) - def test_greater_tensor_scalar(self, _, shape, scalar): - class greater(nn.Module): + def test_gt_tensor_scalar(self, _, shape, scalar): + class gt(nn.Module): def forward(self, lhs_val): return torch.ops.aten.gt.Tensor(lhs_val, torch.tensor(scalar)) inputs = [torch.randn(shape)] self.run_test( - greater(), + gt(), inputs, output_dtypes=[torch.bool], ) @parameterized.expand( [ - ("2d", (2, 1), 1), - ("3d", (2, 1, 2), 2.0), + ("2d", (5, 3), 1), + ("3d", (5, 3, 2), 2.0), ] ) - def test_greater_scalar(self, _, shape, scalar): - class greater(nn.Module): + def test_gt_scalar(self, _, shape, scalar): + class gt(nn.Module): def forward(self, lhs_val): return torch.ops.aten.gt.Scalar(lhs_val, scalar) inputs = [torch.randn(shape)] self.run_test( - greater(), + gt(), inputs, output_dtypes=[torch.bool], ) diff --git a/tests/py/dynamo/conversion/test_le_aten.py b/tests/py/dynamo/conversion/test_le_aten.py new file mode 100644 index 0000000000..373384c6f9 --- /dev/null +++ b/tests/py/dynamo/conversion/test_le_aten.py @@ -0,0 +1,69 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestLeConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d", (5, 3)), + ("3d", (5, 3, 2)), + ] + ) + def test_le_tensor(self, _, shape): + class le(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.ops.aten.lt.Tensor(lhs_val, rhs_val) + + inputs = [ + torch.randint(0, 3, shape, dtype=torch.int32), + torch.randint(0, 3, shape, dtype=torch.int32), + ] + self.run_test( + le(), + inputs, + output_dtypes=[torch.bool], + ) + + @parameterized.expand( + [ + ("2d", (5, 3), 1), + ("3d", (5, 3, 2), 2.0), + ] + ) + def test_le_tensor_scalar(self, _, shape, scalar): + class le(nn.Module): + def forward(self, lhs_val): + return torch.ops.aten.lt.Tensor(lhs_val, torch.tensor(scalar)) + + inputs = 
[torch.randint(0, 3, shape, dtype=torch.int32)] + self.run_test( + le(), + inputs, + output_dtypes=[torch.bool], + ) + + @parameterized.expand( + [ + ("2d", (5, 3), 1), + ("3d", (5, 3, 2), 2.0), + ] + ) + def test_le_scalar(self, _, shape, scalar): + class le(nn.Module): + def forward(self, lhs_val): + return torch.ops.aten.lt.Scalar(lhs_val, scalar) + + inputs = [torch.randint(0, 3, shape, dtype=torch.int32)] + self.run_test( + le(), + inputs, + output_dtypes=[torch.bool], + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_less_aten.py b/tests/py/dynamo/conversion/test_lt_aten.py similarity index 77% rename from tests/py/dynamo/conversion/test_less_aten.py rename to tests/py/dynamo/conversion/test_lt_aten.py index 28ca2cb514..89cb7f42c5 100644 --- a/tests/py/dynamo/conversion/test_less_aten.py +++ b/tests/py/dynamo/conversion/test_lt_aten.py @@ -2,26 +2,25 @@ import torch.nn as nn from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt import Input from .harness import DispatchTestCase -class TestLessConverter(DispatchTestCase): +class TestLtConverter(DispatchTestCase): @parameterized.expand( [ ("2d", (2, 1)), ("3d", (2, 1, 2)), ] ) - def test_less_tensor(self, _, shape): - class less(nn.Module): + def test_lt_tensor(self, _, shape): + class lt(nn.Module): def forward(self, lhs_val, rhs_val): return torch.ops.aten.lt.Tensor(lhs_val, rhs_val) inputs = [torch.randn(shape), torch.randn(shape)] self.run_test( - less(), + lt(), inputs, output_dtypes=[torch.bool], ) @@ -32,14 +31,14 @@ def forward(self, lhs_val, rhs_val): ("3d", (2, 1, 2), 2.0), ] ) - def test_less_tensor_scalar(self, _, shape, scalar): - class less(nn.Module): + def test_lt_tensor_scalar(self, _, shape, scalar): + class lt(nn.Module): def forward(self, lhs_val): return torch.ops.aten.lt.Tensor(lhs_val, torch.tensor(scalar)) inputs = [torch.randn(shape)] self.run_test( - less(), + lt(), inputs, output_dtypes=[torch.bool], ) @@ -50,14 +49,14 @@ def forward(self, lhs_val): ("3d", (2, 1, 2), 2.0), ] ) - def test_less_scalar(self, _, shape, scalar): - class less(nn.Module): + def test_lt_scalar(self, _, shape, scalar): + class lt(nn.Module): def forward(self, lhs_val): return torch.ops.aten.lt.Scalar(lhs_val, scalar) inputs = [torch.randn(shape)] self.run_test( - less(), + lt(), inputs, output_dtypes=[torch.bool], ) diff --git a/tests/py/dynamo/conversion/test_ne_aten.py b/tests/py/dynamo/conversion/test_ne_aten.py new file mode 100644 index 0000000000..2450ac0945 --- /dev/null +++ b/tests/py/dynamo/conversion/test_ne_aten.py @@ -0,0 +1,69 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestNotEqualConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d", (5, 3)), + ("3d", (5, 3, 2)), + ] + ) + def test_ne_tensor(self, _, shape): + class ne(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.ops.aten.ne.Tensor(lhs_val, rhs_val) + + inputs = [ + torch.randint(0, 3, shape, dtype=torch.int32), + torch.randint(0, 3, shape, dtype=torch.int32), + ] + self.run_test( + ne(), + inputs, + output_dtypes=[torch.bool], + ) + + @parameterized.expand( + [ + ("2d", (5, 3), 1), + ("3d", (5, 3, 2), 2.0), + ] + ) + def test_ne_tensor_scalar(self, _, shape, scalar): + class ne(nn.Module): + def forward(self, lhs_val): + return torch.ops.aten.ne.Tensor(lhs_val, 
torch.tensor(scalar)) + + inputs = [torch.randint(0, 3, shape, dtype=torch.int32)] + self.run_test( + ne(), + inputs, + output_dtypes=[torch.bool], + ) + + @parameterized.expand( + [ + ("2d", (5, 3), 1), + ("3d", (5, 3, 2), 2.0), + ] + ) + def test_ne_scalar(self, _, shape, scalar): + class ne(nn.Module): + def forward(self, lhs_val): + return torch.ops.aten.ne.Scalar(lhs_val, scalar) + + inputs = [torch.randint(0, 3, shape, dtype=torch.int32)] + self.run_test( + ne(), + inputs, + output_dtypes=[torch.bool], + ) + + +if __name__ == "__main__": + run_tests()
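# ---------------------------------------------------------------------------
# Editorial note (illustrative sketch, not part of the patch): the new impl
# functions compose the missing comparisons and bitwise ops out of existing
# primitives -- ne = logical_not(eq), ge = gt OR eq, le = lt OR eq, and the
# bitwise_* ops on bool operands reduce to their logical_* counterparts.
# The identities the converters rely on can be checked in eager PyTorch:
import torch

a = torch.randint(0, 3, (5, 3))
b = torch.randint(0, 3, (5, 3))

assert torch.equal(torch.ne(a, b), ~torch.eq(a, b))                  # ne = not(eq)
assert torch.equal(torch.ge(a, b), torch.gt(a, b) | torch.eq(a, b))  # ge = gt or eq
assert torch.equal(torch.le(a, b), torch.lt(a, b) | torch.eq(a, b))  # le = lt or eq

x = torch.randint(0, 2, (5, 3), dtype=torch.bool)
y = torch.randint(0, 2, (5, 3), dtype=torch.bool)
assert torch.equal(torch.bitwise_and(x, y), torch.logical_and(x, y))
assert torch.equal(torch.bitwise_or(x, y), torch.logical_or(x, y))
assert torch.equal(torch.bitwise_xor(x, y), torch.logical_xor(x, y))
assert torch.equal(torch.bitwise_not(x), torch.logical_not(x))
# ---------------------------------------------------------------------------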