Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def aten_acosh(self: TFloat) -> TFloat:
return op.Acosh(self)


@torch_op(("aten::add.Tensor", "aten::add.Scalar", "_operator::add"), trace_only=True)
@torch_op("aten::add.Tensor", trace_only=True)
def aten_add(self: TTensor, other: TTensor, alpha: float = 1.0) -> TTensor:
"""add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""

Expand All @@ -148,7 +148,15 @@ def aten_add(self: TTensor, other: TTensor, alpha: float = 1.0) -> TTensor:
return op.Add(self, other)


@torch_op(("_operator::add"), trace_only=True)
@torch_op("aten::add.Scalar", trace_only=True)
def aten_add_scalar(self: TTensor, other: float, alpha: float = 1.0) -> TTensor:
"""add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor"""

other = op.Constant(value=ir.tensor(other, dtype=self.dtype))
return aten_add(self, other, alpha=alpha)


@torch_op("_operator::add", trace_only=True)
def operator_add(self: TTensor, other: TTensor) -> TTensor:
return op.Add(self, other)

Expand Down Expand Up @@ -8113,9 +8121,7 @@ def aten_std_mean_correction(
@torch_op(
(
"aten::sub.Tensor",
"aten::sub.Scalar",
"aten::subtract.Tensor",
"aten::subtract.Scalar",
"_operator::sub",
),
trace_only=True,
Expand All @@ -8128,6 +8134,14 @@ def aten_sub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
return op.Sub(self, other)


@torch_op(("aten::sub.Scalar", "aten::subtract.Scalar"), trace_only=True)
def aten_sub_scalar(self: TTensor, other: float, alpha: float = 1.0) -> TTensor:
"""sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor"""

other = op.Constant(value=ir.tensor(other, dtype=self.dtype))
return aten_sub(self, other, alpha=alpha)


@torch_op(
(
"aten::sub.Tensor",
Expand Down
Loading