diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 2ec3b8f207..36f2a70f8c 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -132,7 +132,7 @@ def aten_acosh(self: TFloat) -> TFloat: return op.Acosh(self) -@torch_op(("aten::add.Tensor", "aten::add.Scalar", "_operator::add"), trace_only=True) +@torch_op("aten::add.Tensor", trace_only=True) def aten_add(self: TTensor, other: TTensor, alpha: float = 1.0) -> TTensor: """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor""" @@ -148,7 +148,15 @@ def aten_add(self: TTensor, other: TTensor, alpha: float = 1.0) -> TTensor: return op.Add(self, other) -@torch_op(("_operator::add"), trace_only=True) +@torch_op("aten::add.Scalar", trace_only=True) +def aten_add_scalar(self: TTensor, other: float, alpha: float = 1.0) -> TTensor: + """add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor""" + + other = op.Constant(value=ir.tensor(other, dtype=self.dtype)) + return aten_add(self, other, alpha=alpha) + + +@torch_op("_operator::add", trace_only=True) def operator_add(self: TTensor, other: TTensor) -> TTensor: return op.Add(self, other) @@ -8113,9 +8121,7 @@ def aten_std_mean_correction( @torch_op( ( "aten::sub.Tensor", - "aten::sub.Scalar", "aten::subtract.Tensor", - "aten::subtract.Scalar", "_operator::sub", ), trace_only=True, @@ -8128,6 +8134,14 @@ def aten_sub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: return op.Sub(self, other) +@torch_op(("aten::sub.Scalar", "aten::subtract.Scalar"), trace_only=True) +def aten_sub_scalar(self: TTensor, other: float, alpha: float = 1.0) -> TTensor: + """sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor""" + + other = op.Constant(value=ir.tensor(other, dtype=self.dtype)) + return aten_sub(self, other, alpha=alpha) + + @torch_op( ( "aten::sub.Tensor",