Arm backend: Add where.self #9869

Merged · 2 commits · Apr 3, 2025

1 change: 1 addition & 0 deletions backends/arm/_passes/__init__.py
@@ -39,6 +39,7 @@
 from .insert_table_ops import InsertTableOpsPass  # noqa
 from .keep_dims_false_to_squeeze_pass import KeepDimsFalseToSqueezePass  # noqa
 from .match_arg_ranks_pass import MatchArgRanksPass  # noqa
+from .match_where_self_arg_dtype_pass import MatchWhereSelfDtypePass  # noqa
 from .meandim_to_averagepool_pass import ConvertMeanDimToAveragePoolPass  # noqa
 from .mm_to_bmm_pass import ConvertMmToBmmPass  # noqa
 from .remove_clone_pass import RemoveClonePass  # noqa

3 changes: 3 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
@@ -40,6 +40,7 @@
     InsertTableOpsPass,
     KeepDimsFalseToSqueezePass,
     MatchArgRanksPass,
+    MatchWhereSelfDtypePass,
     QuantizeOperatorArguments,
     RemoveClonePass,
     ReplaceScalarWithTensorArgPassTOSABI,
@@ -80,6 +81,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(ConvertToClampPass())
         self.add_pass(ConvertMinMaxPass())
         self.add_pass(ConvertAnyDefaultDimDimsPass())
+        self.add_pass(MatchWhereSelfDtypePass())
         if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
             self.add_pass(CastToInt32Pass())

@@ -130,6 +132,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(ConvertToClampPass())
         self.add_pass(ConvertMinMaxPass())
         self.add_pass(ConvertAnyDefaultDimDimsPass())
+        self.add_pass(MatchWhereSelfDtypePass())

         self.add_pass(AnnotateDecomposedMatmulPass())
         self.add_pass(QuantizeOperatorArguments())

1 change: 1 addition & 0 deletions backends/arm/_passes/match_arg_ranks_pass.py
@@ -49,6 +49,7 @@ def __init__(self, exported_program):
             exir_ops.edge.aten.bitwise_left_shift.Tensor,
             exir_ops.edge.aten.eq.Tensor,
             exir_ops.edge.aten.pow.Tensor_Tensor,
+            exir_ops.edge.aten.where.self,
         ]

     def _match_op_rank(self, graph_module, node, arg, max_rank):

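Registering `exir_ops.edge.aten.where.self` here lets the rank-matching pass equalize the ranks of where's arguments before TOSA serialization. A minimal sketch of the case this covers (plain torch with hypothetical tensors; assumes the pass's usual reshape-to-matching-rank behavior):

```python
import torch

cond = torch.ones(2, 2, dtype=torch.bool)
x = torch.ones(2, 2)   # rank 2
y = torch.zeros(2)     # rank 1, broadcasts against x

# Edge dialect accepts the rank mismatch; for TOSA, MatchArgRanksPass
# would reshape y to rank 2 (shape [1, 2]) so all arguments match.
print(torch.where(cond, x, y).shape)  # torch.Size([2, 2])
```
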
95 changes: 95 additions & 0 deletions backends/arm/_passes/match_where_self_arg_dtype_pass.py
@@ -0,0 +1,95 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

DTYPE_RANK = {
    torch.bool: 0,
    torch.uint8: 1,
    torch.int8: 2,
    torch.int16: 3,
    torch.int32: 4,
    torch.int64: 5,
    torch.float16: 6,
    torch.float32: 7,
    torch.float64: 8,
}


def get_largest_dtype(dtype_1, dtype_2):
    """Find the largest dtype."""
    return dtype_1 if DTYPE_RANK[dtype_1] > DTYPE_RANK[dtype_2] else dtype_2


class MatchWhereSelfDtypePass(ExportPass):
    """Pass to match data types of non-condition input tensors.

    Edge dialect allows different data types for non-condition tensors, while TOSA
    does not. In cases where they differ, a TOSA CAST operator is inserted.

    There is an edge case where one input is `boolean`, which cannot be directly cast
    to, for example, float32. When this occurs, two CAST operators are added to first
    cast to int8 and then to the correct target data type.
    """

    def call(self, graph_module: torch.fx.GraphModule):
        modified_graph = False
        graph = graph_module.graph
        node_list = graph.find_nodes(
            op="call_function", target=exir_ops.edge.aten.where.self
        )
        for node in node_list:
            cond, input_, other_ = node.args

            input_dtype = input_.meta["val"].dtype
            other_dtype = other_.meta["val"].dtype
            target_dtype = torch.float32
            if input_dtype != other_dtype:
                target_dtype = get_largest_dtype(input_dtype, other_dtype)

            for arg in node.args[1:]:
                arg_dtype = arg.meta["val"].dtype

                if arg_dtype != target_dtype:
                    if arg_dtype == torch.bool:
                        # Bool is an edge case which cannot necessarily be directly
                        # converted to the target data type.
                        with graph.inserting_after(arg):
                            replace_node_int8 = create_node(
                                graph,
                                exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
                            )
                            replace_node_int8.args = (arg,)
                            replace_node_int8.kwargs = {"dtype": torch.int8}

                        with graph.inserting_after(replace_node_int8):
                            replace_node_fp32 = create_node(
                                graph,
                                exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
                            )
                            replace_node_fp32.args = (replace_node_int8,)
                            replace_node_fp32.kwargs = {"dtype": target_dtype}
                            node.replace_input_with(arg, replace_node_fp32)
                    else:
                        with graph.inserting_after(arg):
                            replace_node = create_node(
                                graph,
                                exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
                            )
                            replace_node.args = (arg,)
                            replace_node.kwargs = {"dtype": target_dtype}
                            node.replace_input_with(arg, replace_node)

                    modified_graph = True

        if modified_graph:
            graph_module.recompile()
            graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, modified_graph)

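As an illustration of the rule the pass implements, here is a sketch in plain torch (not pass code; tensor names are hypothetical, and it reuses `DTYPE_RANK`/`get_largest_dtype` from above): the operand with the higher rank sets the target dtype, and a bool operand is cast via int8 first.

```python
import torch

cond = torch.tensor([True, False])
x = torch.ones(2, dtype=torch.float16)
y = torch.zeros(2, dtype=torch.float32)

target = get_largest_dtype(x.dtype, y.dtype)  # float32 outranks float16
print(torch.where(cond, x.to(target), y).dtype)  # torch.float32

# Bool edge case: two casts, bool -> int8 -> target dtype, mirroring the
# two _to_dim_order_copy nodes the pass inserts.
b = torch.tensor([True, False])
print(torch.where(cond, b.to(torch.int8).to(target), y).dtype)  # torch.float32
```
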
1 change: 1 addition & 0 deletions backends/arm/operator_support/ethos_u55_support.py
@@ -149,6 +149,7 @@ class EthosU55NotSupported(OperatorSupportBase):
         exir_ops.edge.aten.reflection_pad1d.default,  # REVERSE
         exir_ops.edge.aten.reflection_pad2d.default,  # REVERSE
         exir_ops.edge.aten.reflection_pad3d.default,  # REVERSE
+        exir_ops.edge.aten.where.self,  # SELECT
     ]

     def __init__(self, reporter: WhyNoPartitionReporter):

@@ -207,6 +207,7 @@ def is_node_supported(
             exir_ops.edge.aten.squeeze_copy.dims,
             exir_ops.edge.aten.pow.Tensor_Scalar,
             exir_ops.edge.aten.pow.Tensor_Tensor,
+            exir_ops.edge.aten.where.self,
             operator.getitem,
             exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
             exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,

1 change: 1 addition & 0 deletions backends/arm/operators/__init__.py
@@ -49,6 +49,7 @@
     op_transpose,
     op_upsample_nearest2d,
     op_view,
+    op_where,
     ops_binary,
     ops_unary,
 )

103 changes: 103 additions & 0 deletions backends/arm/operators/op_where.py
@@ -0,0 +1,103 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Sequence

import serializer.tosa_serializer as ts  # type: ignore

from executorch.backends.arm.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification
from serializer.tosa_serializer import TosaOp
from torch.fx import Node


def _add_node_to_tosa_graph(
    tosa_graph: ts.TosaSerializer,
    inputs: List[TosaArg],
    output: TosaArg,
    supported_dtypes: Sequence,
) -> None:
    if len(inputs) != 3:
        raise ValueError(f"aten.where.self expects 3 arguments, got {len(inputs)}")

    if inputs[0].dtype is not ts.DType.BOOL:
        raise ValueError("Input 0 needs to have dtype BOOL")
    if inputs[1].dtype != inputs[2].dtype:
        raise ValueError(
            "Non-condition tensors must have same data type, got "
            f"{inputs[1].dtype} and {inputs[2].dtype}"
        )
    for input_ in inputs[1:]:
        if input_.dtype not in supported_dtypes:
            raise ValueError(
                f"Input needs to be of torch dtype {supported_dtypes}, got {input_.dtype}"
            )

    tosa_graph.addOperator(
        TosaOp.Op().SELECT,
        [inputs[0].name, inputs[1].name, inputs[2].name],
        [output.name],
        None,
    )


@register_node_visitor
class WhereVisitor_080_BI(NodeVisitor):
    target = "aten.where.self"

    tosa_specs = [
        TosaSpecification.create_from_string("TOSA-0.80+BI"),
    ]

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:

        bi_supported_dtypes = [
            ts.DType.INT8,
            ts.DType.INT16,
            ts.DType.INT32,
            ts.DType.BOOL,
        ]
        _add_node_to_tosa_graph(tosa_graph, inputs, output, bi_supported_dtypes)


@register_node_visitor
class WhereVisitor_080_MI(WhereVisitor_080_BI):

    tosa_specs = [
        TosaSpecification.create_from_string("TOSA-0.80+MI"),
    ]

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
        mi_supported_dtypes = [
            ts.DType.FP16,
            ts.DType.FP32,
            ts.DType.INT8,
            ts.DType.INT16,
            ts.DType.INT32,
            ts.DType.BOOL,
        ]
        _add_node_to_tosa_graph(tosa_graph, inputs, output, mi_supported_dtypes)

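For reference, TOSA SELECT makes an elementwise choice between the two value tensors based on a bool condition, which is exactly `torch.where`. A sketch of the semantics (`select_reference` is a hypothetical helper, not backend code):

```python
import torch

def select_reference(cond, x, y):
    # out[i] = x[i] if cond[i] else y[i] -- the TOSA SELECT rule; x and y
    # must share a dtype, matching the check in _add_node_to_tosa_graph.
    assert cond.dtype == torch.bool and x.dtype == y.dtype
    out = y.clone()
    out[cond] = x[cond]
    return out

cond = torch.tensor([True, False, True])
x = torch.tensor([1, 2, 3], dtype=torch.int32)
y = torch.tensor([10, 20, 30], dtype=torch.int32)
assert torch.equal(select_reference(cond, x, y), torch.where(cond, x, y))
```
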
18 changes: 13 additions & 5 deletions backends/arm/quantizer/quantization_annotator.py
@@ -238,13 +238,14 @@ def _match_pattern(
     torch.ops.aten.dropout_.default,
     torch.ops.aten.clamp.default,
     torch.ops.aten.clamp.Tensor,
+    torch.ops.aten.where,
     operator.getitem,
 ]


 def get_quant_properties(  # noqa: C901
     node: Node, gm: torch.fx.GraphModule, quantization_config
-) -> _OpQuantProperties:
+) -> _OpQuantProperties | None:
     input_act_qspec = quantization_config.get_input_act_qspec()
     weight_qspec = quantization_config.get_weight_qspec()
     output_act_qspec = quantization_config.get_output_act_qspec()
@@ -322,6 +323,13 @@ def any_or_hardtanh_min_zero(n: Node):
             ),
         ]
         quant_properties.quant_output = _QuantProperty(0, shared_qspec)  # type: ignore[arg-type]
+    elif node.target in (torch.ops.aten.where.self,):
+        shared_qspec = SharedQuantizationSpec(node.args[1])  # type: ignore[arg-type]
+        quant_properties.quant_inputs = [
+            _QuantProperty(1, shared_qspec),  # type: ignore[arg-type]
+            _QuantProperty(2, shared_qspec),  # type: ignore[arg-type]
+        ]
+        quant_properties.quant_output = _QuantProperty(0, shared_qspec)  # type: ignore[arg-type]
     elif node.target == torch.ops.aten.adaptive_avg_pool2d.default:
         input_qspec = (
             SharedQuantizationSpec(node.args[0])  # type: ignore[arg-type]
@@ -376,16 +384,16 @@ def any_or_hardtanh_min_zero(n: Node):
         quant_properties.quant_output = None
     elif node.target in _parent_shared_qspec:
         if not isinstance(node.args[0], Node):
-            return None  # type: ignore[return-value]
+            return None

         if not arm_quantizer_utils.is_output_annotated(node.args[0]):  # type: ignore[attr-defined]
-            return None  # type: ignore[return-value]
+            return None

         shared_qspec = SharedQuantizationSpec(node.args[0])
         quant_properties.quant_inputs = [_QuantProperty(0, shared_qspec)]  # type: ignore[arg-type]
         quant_properties.quant_output = _QuantProperty(0, shared_qspec)  # type: ignore[arg-type]
     else:
-        return None  # type: ignore[return-value]
+        return None

     # Don't check if operator.getitem is ok for quantization, it's always ok
     if node.target == operator.getitem:
@@ -394,7 +402,7 @@ def any_or_hardtanh_min_zero(n: Node):
     # Check that each inputs/outputs can be quantized properly with the
     # provided quantization properties.
     if not _is_ok_for_quantization(node, quant_properties, gm):
-        return None  # type: ignore[return-value]
+        return None

     return quant_properties

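The new `where.self` branch shares one quantization spec across both value inputs and the output. A sketch of the rationale (numpy rather than the quantizer API, with hypothetical scale/zero-point values): with a common (scale, zero_point) on all three, SELECT can route raw int8 values through unchanged, so no rescale is needed.

```python
import numpy as np

scale, zp = 0.05, 3  # one shared qparam set for args 1, 2 and the output

def quantize(x):
    return np.clip(np.round(x / scale) + zp, -128, 127).astype(np.int8)

def dequantize(q):
    return (q.astype(np.float32) - zp) * scale

cond = np.array([True, False, True])
x = np.array([0.5, -0.2, 0.9])
y = np.array([-1.0, 0.3, 0.1])

q_out = np.where(cond, quantize(x), quantize(y))  # pure int8 select
assert np.allclose(dequantize(q_out), np.where(cond, x, y), atol=scale)
```
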
2 changes: 0 additions & 2 deletions backends/arm/test/models/test_conformer.py
@@ -31,10 +31,8 @@ class TestConformer(unittest.TestCase):
     # .to_executorch step, i.e. after Arm partitioner.
     ops_after_partitioner = {
         "executorch_exir_dialects_edge__ops_aten_max_default": 1,
-        "executorch_exir_dialects_edge__ops_aten_where_self": 4,
         "torch.ops.aten._assert_scalar.default": 10,
         "torch.ops.aten._local_scalar_dense.default": 1,
-        "torch.ops.higher_order.executorch_call_delegate": 4,
     }

     dim = 16

2 changes: 1 addition & 1 deletion backends/arm/test/models/test_llama.py
@@ -114,7 +114,7 @@ def test_llama_tosa_MI(self):
             )
             .export()
             .to_edge_transform_and_lower()
-            .check_count({"torch.ops.higher_order.executorch_call_delegate": 14})
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
             .to_executorch()
             .run_method_and_compare_outputs(
                 inputs=llama_inputs,