diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py
index 1f9ee6fa42..8c8492158f 100644
--- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py
+++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py
@@ -24,6 +24,7 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
 
 from torch_tensorrt.fx.converters.impl import activation
 from torch_tensorrt.fx.converters.impl.elementwise import trunc_div
+from torch_tensorrt.fx.converters.impl.elementwise import rsqrt
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
@@ -300,6 +301,24 @@ def aten_ops_relu(
         name,
         args[0],
     )
+
+
+@tensorrt_converter(torch.ops.aten.rsqrt.default)
+def aten_ops_rsqrt(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+
+    return rsqrt(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+    )
 
 
 @tensorrt_converter(torch.ops.aten.sub.Tensor)
diff --git a/py/torch_tensorrt/fx/converters/impl/elementwise/ops.py b/py/torch_tensorrt/fx/converters/impl/elementwise/ops.py
index ae44ce838c..8fddb426a6 100644
--- a/py/torch_tensorrt/fx/converters/impl/elementwise/ops.py
+++ b/py/torch_tensorrt/fx/converters/impl/elementwise/ops.py
@@ -109,3 +109,33 @@ def trunc_div(
     )
 
     return output
+
+
+def rsqrt(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+) -> TRTTensor:
+    """Compute 1 / sqrt(input) as a SQRT unary layer followed by an elementwise DIV."""
+    sqrt_trt_output = convert_unary(
+        network,
+        target,
+        source_ir,
+        f"{name}_sqrt",
+        trt.UnaryOperation.SQRT,
+        input,
+    )
+
+    output = convert_binary_elementwise(
+        network,
+        target,
+        source_ir,
+        f"{name}_output",
+        trt.ElementWiseOperation.DIV,
+        1,
+        sqrt_trt_output,
+    )
+
+    return output
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py
new file mode 100644
index 0000000000..3fa27af1a0
--- /dev/null
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py
@@ -0,0 +1,29 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestRSqrtConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("2d_dim", (2, 1)),
+            ("3d_dim", (2, 1, 2)),
+        ]
+    )
+    def test_rsqrt(self, _, x):
+        class rsqrt(nn.Module):
+            def forward(self, input):
+                return torch.rsqrt(input)
+
+        # keep inputs strictly positive so rsqrt never yields NaN
+        inputs = [torch.abs(torch.randn(x)) + 0.1]
+        self.run_test(
+            rsqrt(),
+            inputs,
+            expected_ops={torch.ops.aten.rsqrt.default},
+        )
+
+
+if __name__ == "__main__":
+    run_tests()