diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 3607a11361..a66faae0be 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -925,16 +925,21 @@ def aten_atan(self: TFloat) -> TFloat: return op.Atan(self) -@torch_op("aten::atan2") +@torch_op("aten::atan2", trace_only=True) def aten_atan2(self: TFloat, other: TFloat) -> TFloat: """atan2(Tensor self, Tensor other) -> Tensor""" # self is y, and other is x on coordinate slope = op.Div(self, other) atan = op.Atan(slope) + zero = common_ops.constant(0.0, dtype=self.dtype) + pi = common_ops.constant(_MATH_PI, dtype=self.dtype) - second_third_quadrant = op.Where(self > 0.0, atan + _MATH_PI, atan - _MATH_PI) - result = op.Where(other < 0.0, second_third_quadrant, atan) + second_third_quadrant = op.Where(op.Greater(self, zero), atan + pi, atan - pi) + result = op.Where(op.Less(other, zero), second_third_quadrant, atan) + + # Map NaN to 0 to match PyTorch behavior + result = op.Where(op.IsNaN(result), zero, result) return result diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 646a5133fa..0cf8898241 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -578,7 +578,7 @@ def _where_input_wrangler( TorchLibOpInfo("asin", core_ops.aten_asin), TorchLibOpInfo("asinh", core_ops.aten_asinh), TorchLibOpInfo("atan", core_ops.aten_atan), - TorchLibOpInfo("atan2", core_ops.aten_atan2, tolerance={torch.float16: (1e-3, 1e-3)}), + TorchLibOpInfo("atan2", core_ops.aten_atan2), TorchLibOpInfo("atanh", core_ops.aten_atanh), TorchLibOpInfo("atleast_1d", core_ops.aten_atleast_1d).skip( matcher=lambda sample: isinstance(sample.input, (list, tuple)),