From 537941616187a174d427b3c5abf64dc089b88a8b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20Lindstr=C3=B6m?= Date: Thu, 13 Feb 2025 15:53:32 +0100 Subject: [PATCH] Arm backend: Add POW operator Implement support for torch.pow in the MI and BI profile of TOSA. For MI, the operator works as Pytorch's reference implementation except for that the base operand cannot be a scalar but must be a tensor. For BI, the exponent operand must be a scalar and a constant value. The base operand must be a tensor. Split the ReplaceScalarWithTensorArgsPass into two subclasses: one for MI and one for BI. For MI, the pow operator's exponent will converted to a tensor in case it is a scalar. For BI, the scalar will be kept, but instead it will be consumed in the InsertTableOpsPass, meaning that the operator will be converted into a table operation with one input and output. This still enforces the exponent to be constant for the BI profile. Change-Id: I464ab91ff46c0a6ad28d0fb84735a403a74e6323 --- backends/arm/_passes/__init__.py | 4 + backends/arm/_passes/arm_pass_manager.py | 13 +- backends/arm/_passes/insert_table_ops.py | 64 ++++++-- backends/arm/_passes/match_arg_ranks_pass.py | 1 + .../replace_scalar_with_tensor_pass.py | 53 +++++++ .../tosa_supported_operators.py | 2 + backends/arm/operators/__init__.py | 1 + backends/arm/operators/op_pow.py | 57 +++++++ .../arm/quantizer/quantization_annotator.py | 1 + backends/arm/test/ops/test_pow.py | 144 ++++++++++++++++++ .../transforms/replace_scalar_with_tensor.py | 20 ++- 11 files changed, 335 insertions(+), 25 deletions(-) create mode 100644 backends/arm/_passes/replace_scalar_with_tensor_pass.py create mode 100644 backends/arm/operators/op_pow.py create mode 100644 backends/arm/test/ops/test_pow.py diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 1142f5565c0..5af6cecc843 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -41,6 +41,10 @@ from .meandim_to_averagepool_pass import ConvertMeanDimToAveragePoolPass # noqa from .mm_to_bmm_pass import ConvertMmToBmmPass # noqa from .remove_clone_pass import RemoveClonePass # noqa +from .replace_scalar_with_tensor_pass import ( # noqa + ReplaceScalarWithTensorArgPassTOSABI, + ReplaceScalarWithTensorArgPassTOSAMI, +) from .scalars_to_attribute_pass import ScalarsToAttributePass # noqa from .size_adjust_conv2d_pass import SizeAdjustConv2DPass # noqa from .unsqueeze_before_repeat_pass import UnsqueezeBeforeRepeatPass # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index aa5cdece3f6..f56be30083c 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -42,18 +42,17 @@ MatchArgRanksPass, QuantizeOperatorArguments, RemoveClonePass, + ReplaceScalarWithTensorArgPassTOSABI, + ReplaceScalarWithTensorArgPassTOSAMI, RetraceFoldedDtypesPass, ScalarsToAttributePass, SizeAdjustConv2DPass, UnsqueezeBeforeRepeatPass, UnsqueezeScalarPlaceholdersPass, ) + from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform - -from executorch.backends.transforms.replace_scalar_with_tensor import ( - ReplaceScalarWithTensorArgPass, -) from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass from executorch.exir import ExportedProgram from executorch.exir.pass_manager import PassManager @@ -84,7 +83,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset: self.add_pass(CastToInt32Pass()) - self.add_pass(ReplaceScalarWithTensorArgPass()) + self.add_pass(ReplaceScalarWithTensorArgPassTOSABI()) self.add_pass(AnnotateDecomposedMatmulPass()) self.add_pass(QuantizeOperatorArguments()) self.add_pass(FoldAndAnnotateQParamsPass()) # type: ignore[call-arg] @@ -113,7 +112,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul return self._transform(exported_program.graph_module) def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule: - self.add_pass(ReplaceScalarWithTensorArgPass()) + self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI()) self.add_pass(FuseQuantizedActivationPass()) self.add_pass(RemoveGetItemPass()) self.add_pass(ConvertSplitToSlicePass()) @@ -170,7 +169,7 @@ def transform_to_backend_pipeline(self, exported_program: ExportedProgram): ) def transform_for_annotation_pipeline(self, graph_module: GraphModule): - self.add_pass(ReplaceScalarWithTensorArgPass()) + self.add_pass(ReplaceScalarWithTensorArgPassTOSABI()) self.add_pass(ScalarsToAttributePass()) self.add_pass(DecomposeLayerNormPass()) self.add_pass(DecomposeVarPass()) diff --git a/backends/arm/_passes/insert_table_ops.py b/backends/arm/_passes/insert_table_ops.py index 05d37e1e8e9..3732c8a367b 100644 --- a/backends/arm/_passes/insert_table_ops.py +++ b/backends/arm/_passes/insert_table_ops.py @@ -5,7 +5,8 @@ # pyre-unsafe -from typing import Callable, Dict +from itertools import chain +from typing import Callable, cast, Dict, Iterator, Set import torch from executorch.backends.arm._passes.arm_pass_utils import create_node @@ -17,7 +18,7 @@ from executorch.exir.pass_base import ExportPass, PassResult from torch.fx import GraphModule - +from torch.fx.node import Node from torch.library import impl, Library lib = Library("tosa", "DEF") @@ -32,15 +33,13 @@ def _table_impl(*args, **kwargs): # pyre-ignore return args[0].to(dtype=torch.int32) -class InsertTableOpsPass(ExportPass): +class TableOps: """ - For ops in self.table_ops they need to be serialized as a TOSA TABLE. This pass replaces these - edge ops with a tosa._table(input: Tensor, target_str: str) where target_str == str(node.target). - When lowering the _table node target_str will be used to find the corresponding torch operator - which will be used to produce the table values in operators/op_table.py. + Helper class for finding the corresponding table operator for a given Node. """ - table_ops: Dict[EdgeOpOverload, Callable[[torch.Tensor], torch.Tensor]] = { + # Targets that follow a straigtforward one-to-one mapping to their table op + unary_table_ops: Dict[EdgeOpOverload, Callable[[torch.Tensor], torch.Tensor]] = { exir_ops.edge.aten.ceil.default: torch.ceil, exir_ops.edge.aten.exp.default: torch.exp, exir_ops.edge.aten.floor.default: torch.floor, @@ -53,9 +52,52 @@ class InsertTableOpsPass(ExportPass): exir_ops.edge.aten.hardswish.default: torch.nn.functional.hardswish, } + # Targets that must be treated explicitly + special_table_ops: Set[EdgeOpOverload] = { + exir_ops.edge.aten.pow.Tensor_Scalar, + } + + def __init__(self, exported_program: ExportedProgram): + self.exported_program = exported_program + + def __contains__(self, node: Node) -> bool: + return ( + node.target in self.unary_table_ops or node.target in self.special_table_ops + ) + + def __getitem__(self, node: Node): + target = cast(EdgeOpOverload, node.target) + if target in self.unary_table_ops: + return self.unary_table_ops[target] + elif target in self.special_table_ops: + match target: + case exir_ops.edge.aten.pow.Tensor_Scalar: + # Exponent is a constant. Embed it into a lambda. + exp = cast(int, node.args[1]) + return lambda x: torch.pow(x, exp).flatten() + case _: + # Op must be handled if it's inside self.special_ops + raise AssertionError("Unhandled table operation") + else: + raise KeyError("Table op for {target} does not exist") + + @staticmethod + def included_ops() -> Iterator[EdgeOpOverload]: + return chain(TableOps.unary_table_ops, TableOps.special_table_ops) + + +class InsertTableOpsPass(ExportPass): + """ + For ops in self.table_ops they need to be serialized as a TOSA TABLE. This pass replaces these + edge ops with a tosa._table(input: Tensor, target_str: str) where target_str == str(node.target). + When lowering the _table node target_str will be used to find the corresponding torch operator + which will be used to produce the table values in operators/op_table.py. + """ + def __init__(self, exported_program: ExportedProgram) -> None: super().__init__() self.exported_program = exported_program + self.table_ops = TableOps(exported_program) def register_buffer(self, buffer_name: str, buffer: torch.Tensor) -> None: """ @@ -166,7 +208,7 @@ def generate_table_values( def call(self, graph_module: GraphModule) -> PassResult: modified = False for node in graph_module.graph.nodes: - if node.op != "call_function" or node.target not in self.table_ops: + if node.op != "call_function" or node not in self.table_ops: continue input_qparams = node.meta["input_qparams"] output_qparams = node.meta["output_qparams"] @@ -186,7 +228,7 @@ def call(self, graph_module: GraphModule) -> PassResult: # Generate table buffer and how much to lshift the table output. buffer, lshift = self.generate_table_values( - torch_op=self.table_ops[node.target], + torch_op=self.table_ops[node], in_quantargs=input_qparams[0], out_quantargs=output_qparams[0], ) @@ -207,7 +249,9 @@ def call(self, graph_module: GraphModule) -> PassResult: output_node = rescale_node node.replace_all_uses_with(output_node) + graph_module.graph.erase_node(node) + output_node.meta["input_qparams"] = input_qparams output_node.meta["output_qparams"] = output_qparams modified = True diff --git a/backends/arm/_passes/match_arg_ranks_pass.py b/backends/arm/_passes/match_arg_ranks_pass.py index 04fc8d00c70..759f215a034 100644 --- a/backends/arm/_passes/match_arg_ranks_pass.py +++ b/backends/arm/_passes/match_arg_ranks_pass.py @@ -48,6 +48,7 @@ def __init__(self, exported_program): exir_ops.edge.aten.bitwise_right_shift.Tensor, exir_ops.edge.aten.bitwise_left_shift.Tensor, exir_ops.edge.aten.eq.Tensor, + exir_ops.edge.aten.pow.Tensor_Tensor, ] def _match_op_rank(self, graph_module, node, arg, max_rank): diff --git a/backends/arm/_passes/replace_scalar_with_tensor_pass.py b/backends/arm/_passes/replace_scalar_with_tensor_pass.py new file mode 100644 index 00000000000..97e89132979 --- /dev/null +++ b/backends/arm/_passes/replace_scalar_with_tensor_pass.py @@ -0,0 +1,53 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + + +from typing import Dict + +import torch +from executorch.backends.transforms.replace_scalar_with_tensor import ( + ReplaceScalarWithTensorArgPass, +) +from executorch.exir.dialects._ops import ops as exir_ops + +from executorch.exir.dialects.edge._ops import EdgeOpOverload + + +# Operators that are included for both TOSA profiles +_common_ops: Dict[EdgeOpOverload, EdgeOpOverload] = { + exir_ops.edge.aten.add.Scalar: exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.sub.Scalar: exir_ops.edge.aten.sub.Tensor, + exir_ops.edge.aten.mul.Scalar: exir_ops.edge.aten.mul.Tensor, + exir_ops.edge.aten.div.Scalar: exir_ops.edge.aten.div.Tensor, + exir_ops.edge.aten.__rshift__.Scalar: exir_ops.edge.aten.bitwise_right_shift.Tensor, + exir_ops.edge.aten.__lshift__.Scalar: exir_ops.edge.aten.bitwise_left_shift.Tensor, + exir_ops.edge.aten.eq.Scalar: exir_ops.edge.aten.eq.Tensor, + torch.ops.aten.add.Scalar: torch.ops.aten.add.Tensor, + torch.ops.aten.sub.Scalar: torch.ops.aten.sub.Tensor, + torch.ops.aten.mul.Scalar: torch.ops.aten.mul.Tensor, + torch.ops.aten.div.Scalar: torch.ops.aten.div.Tensor, + torch.ops.aten.__rshift__.Scalar: torch.ops.aten.bitwise_right_shift.Tensor, + torch.ops.aten.__lshift__.Scalar: torch.ops.aten.bitwise_left_shift.Tensor, + torch.ops.aten.eq.Scalar: torch.ops.aten.eq.Tensor, +} + + +class ReplaceScalarWithTensorArgPassTOSAMI(ReplaceScalarWithTensorArgPass): + scalar_to_tensor_ops = _common_ops | { + exir_ops.edge.aten.pow.Tensor_Scalar: exir_ops.edge.aten.pow.Tensor_Tensor, + torch.ops.aten.pow.Tensor_Scalar: torch.ops.aten.pow.Tensor_Tensor, + } + + def __init__(self): + super().__init__(self.scalar_to_tensor_ops) + + +class ReplaceScalarWithTensorArgPassTOSABI(ReplaceScalarWithTensorArgPass): + scalar_to_tensor_ops = _common_ops + + def __init__(self): + super().__init__(self.scalar_to_tensor_ops) diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 4843a17eb1d..f37289fa001 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -198,6 +198,8 @@ def is_node_supported( exir_ops.edge.aten.clone.default, exir_ops.edge.aten.unsqueeze_copy.default, exir_ops.edge.aten.squeeze_copy.dims, + exir_ops.edge.aten.pow.Tensor_Scalar, + exir_ops.edge.aten.pow.Tensor_Tensor, operator.getitem, exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index 3da4a2617f6..f891d8d3b69 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -32,6 +32,7 @@ op_minimum, op_mul, op_permute, + op_pow, op_reciprocal, op_repeat, op_rescale, diff --git a/backends/arm/operators/op_pow.py b/backends/arm/operators/op_pow.py new file mode 100644 index 00000000000..0f251a8aa6d --- /dev/null +++ b/backends/arm/operators/op_pow.py @@ -0,0 +1,57 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +from typing import List + +import serializer.tosa_serializer as ts +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_specification import TosaSpecification +from serializer.tosa_serializer import TosaOp +from torch.fx import Node + + +@register_node_visitor +class PowVisitor_080_MI(NodeVisitor): + target = "aten.pow.Tensor_Tensor" + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80+MI"), + ] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + if not (inputs[0].dtype == inputs[1].dtype == output.dtype): + raise ValueError( + "All inputs and outputs need same dtype." + f"Got {inputs[0].dtype=}, {inputs[1].dtype=}, {output.dtype=}" + ) + if inputs[0].dtype not in [ts.DType.FP32, ts.DType.FP16]: + raise ValueError( + f"All inputs need to be FP32 or FP16. Got {inputs[0].dtype}" + ) + + tosa_graph.addOperator( + TosaOp.Op().POW, + [ + inputs[0].name, + inputs[1].name, + ], + [output.name], + None, + ) diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index 3e98df7f6f5..1c6d05f2557 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -139,6 +139,7 @@ def _match_pattern( torch.ops.aten.hardswish.default, torch.ops.aten.hardswish_.default, torch.ops.aten.full_like.default, + torch.ops.aten.pow.Tensor_Scalar, ] _one_to_one_shared_input_qspec = [ diff --git a/backends/arm/test/ops/test_pow.py b/backends/arm/test/ops/test_pow.py new file mode 100644 index 00000000000..618acf50fc2 --- /dev/null +++ b/backends/arm/test/ops/test_pow.py @@ -0,0 +1,144 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Tuple + +import torch +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineBI, + EthosU85PipelineBI, + TosaPipelineBI, + TosaPipelineMI, +) + + +class Pow_TensorTensor(torch.nn.Module): + aten_op = "torch.ops.aten.pow.Tensor_Tensor" + exir_op = "executorch_exir_dialects_edge__ops_aten_pow_Tensor_Tensor" + + input_t = Tuple[torch.Tensor | float, torch.Tensor | float] + + # The sign of the operands are important w.r.t. TOSA's spec of pow + test_data = { + "zero_base_pos_exp": lambda: ( + torch.zeros(1, 8, 3, 7), + torch.abs(torch.randn((1, 8, 1, 7))) + 1e5, + ), + "pos_base": lambda: ( + torch.abs(torch.randn((3, 2, 4, 2))) + 1e5, + torch.randn((1, 2, 4, 1)), + ), + "zero_base_zero_exp": lambda: (torch.zeros(2, 3), torch.zeros(2, 3)), + "pos_base_zero_exp": lambda: ( + torch.abs(torch.randn((1, 7, 2, 3))) + 1e5, + torch.zeros(1, 1, 2, 3), + ), + "neg_base_zero_exp": lambda: ( + -torch.abs(torch.randn((1, 2, 3, 4))) - 1e5, + torch.zeros(1, 2, 3, 4), + ), + "base_has_lower_rank": lambda: (torch.ones(3, 4), torch.ones(1, 2, 3, 4)), + "exp_has_lower_rank": lambda: (torch.ones(1, 2, 3, 4), torch.ones(3, 4)), + "f16_tensors": lambda: ( + torch.HalfTensor([[1.0, 2.0, 3.0], [0.5, 1.5, 2.5]]), + torch.HalfTensor([[1.0, 2.0, 0.0]]), + ), + } + + def forward(self, x: torch.Tensor | float, y: torch.Tensor | float): + return torch.pow(x, y) + + +class Pow_TensorScalar(torch.nn.Module): + aten_op = "torch.ops.aten.pow.Tensor_Scalar" + exir_op = "executorch_exir_dialects_edge__ops_aten_pow_Tensor_Scalar" + + input_t = Tuple[torch.Tensor] + + test_data = { + # Test whole number exponents + "exp_minus_three": lambda: (torch.randn((10, 5)), -3.0), + "exp_minus_one": lambda: (torch.randn((42,)), -1.0), + "exp_zero": lambda: (torch.randn((1, 2, 3, 7)), 0.0), + "exp_one": lambda: (torch.randn((1, 4, 6, 2)), 1.0), + "exp_two": lambda: (torch.randn((1, 2, 3, 6)), 2.0), + # Test decimal exponent (base must be non-negative) + "non_neg_base_exp_pos_decimal": lambda: ( + torch.abs(torch.randn((1, 2, 3, 6))), + 6.789, + ), + } + + def __init__(self, exp): + super().__init__() + self.exp = exp + + def forward(self, x: torch.Tensor): + return torch.pow(x, self.exp) + + +@common.parametrize("test_data", Pow_TensorTensor.test_data) +def test_pow_tensor_tensor_MI(test_data: Pow_TensorTensor.input_t): + pipeline = TosaPipelineMI[Pow_TensorTensor.input_t]( + Pow_TensorTensor(), + test_data(), + Pow_TensorTensor.aten_op, + Pow_TensorTensor.exir_op, + ) + pipeline.run() + + +@common.parametrize("test_data", Pow_TensorScalar.test_data) +def test_pow_tensor_scalar_MI(test_data: Pow_TensorScalar.input_t): + base, exp = test_data() + pipeline = TosaPipelineMI[Pow_TensorScalar.input_t]( + Pow_TensorScalar(exp), + (base,), + Pow_TensorScalar.aten_op, + Pow_TensorScalar.exir_op, + ) + pipeline.run() + + +@common.parametrize("test_data", Pow_TensorScalar.test_data) +def test_pow_tensor_scalar_BI(test_data: Pow_TensorScalar.input_t): + base, exp = test_data() + pipeline = TosaPipelineBI[Pow_TensorScalar.input_t]( + Pow_TensorScalar(exp), + (base,), + Pow_TensorScalar.aten_op, + Pow_TensorScalar.exir_op, + ) + pipeline.run() + + +@common.parametrize("test_data", Pow_TensorScalar.test_data) +@common.XfailIfNoCorstone300 +def test_pow_tensor_scalar_u55_BI(test_data: Pow_TensorScalar.input_t): + base, exp = test_data() + pipeline = EthosU55PipelineBI[Pow_TensorScalar.input_t]( + Pow_TensorScalar(exp), + (base,), + Pow_TensorScalar.aten_op, + Pow_TensorScalar.exir_op, + run_on_fvp=True, + ) + pipeline.run() + + +@common.parametrize("test_data", Pow_TensorScalar.test_data) +@common.XfailIfNoCorstone320 +def test_pow_tensor_scalar_u85_BI(test_data: Pow_TensorScalar.input_t): + base, exp = test_data() + pipeline = EthosU85PipelineBI[Pow_TensorScalar.input_t]( + Pow_TensorScalar(exp), + (base,), + Pow_TensorScalar.aten_op, + Pow_TensorScalar.exir_op, + run_on_fvp=True, + ) + pipeline.run() diff --git a/backends/transforms/replace_scalar_with_tensor.py b/backends/transforms/replace_scalar_with_tensor.py index b1bab5b0b66..21eb325b646 100644 --- a/backends/transforms/replace_scalar_with_tensor.py +++ b/backends/transforms/replace_scalar_with_tensor.py @@ -5,7 +5,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Dict +from typing import Dict, Optional import torch from executorch.exir.dialects._ops import ops as exir_ops @@ -19,23 +19,27 @@ class ReplaceScalarWithTensorArgPass(ExportPass): replace the scalar arg with Tensor arg. """ - scalar_to_tensor_ops: Dict[EdgeOpOverload, EdgeOpOverload] = { + default_ops: Dict[EdgeOpOverload, EdgeOpOverload] = { exir_ops.edge.aten.add.Scalar: exir_ops.edge.aten.add.Tensor, exir_ops.edge.aten.sub.Scalar: exir_ops.edge.aten.sub.Tensor, exir_ops.edge.aten.mul.Scalar: exir_ops.edge.aten.mul.Tensor, exir_ops.edge.aten.div.Scalar: exir_ops.edge.aten.div.Tensor, - exir_ops.edge.aten.__rshift__.Scalar: exir_ops.edge.aten.bitwise_right_shift.Tensor, - exir_ops.edge.aten.__lshift__.Scalar: exir_ops.edge.aten.bitwise_left_shift.Tensor, - exir_ops.edge.aten.eq.Scalar: exir_ops.edge.aten.eq.Tensor, torch.ops.aten.add.Scalar: torch.ops.aten.add.Tensor, torch.ops.aten.sub.Scalar: torch.ops.aten.sub.Tensor, torch.ops.aten.mul.Scalar: torch.ops.aten.mul.Tensor, torch.ops.aten.div.Scalar: torch.ops.aten.div.Tensor, - torch.ops.aten.__rshift__.Scalar: torch.ops.aten.bitwise_right_shift.Tensor, - torch.ops.aten.__lshift__.Scalar: torch.ops.aten.bitwise_left_shift.Tensor, - torch.ops.aten.eq.Scalar: torch.ops.aten.eq.Tensor, } + def __init__( + self, + scalar_to_tensor_ops: Optional[Dict[EdgeOpOverload, EdgeOpOverload]] = None, + ): + if scalar_to_tensor_ops is not None: + self.scalar_to_tensor_ops = scalar_to_tensor_ops + else: + self.scalar_to_tensor_ops = self.default_ops + super().__init__() + def get_replacement(self, op, args, kwargs, meta): return super().call_operator( # Replace with .Tensor variant.