From ca6c5013a175f0363aee858bc0d0de1624cb17db Mon Sep 17 00:00:00 2001
From: Sebastian Larsson
Date: Mon, 10 Feb 2025 14:00:56 +0100
Subject: [PATCH 1/2] Arm backend: Add where.self

Add operator where.self, which maps to the SELECT operator in the TOSA
specification. Tests are also added for the aforementioned operator.

Change-Id: I7bd609dd741357b78a0a059cb0f61f2c975bbf06
Signed-off-by: Sebastian Larsson
---
 backends/arm/_passes/__init__.py              |   1 +
 backends/arm/_passes/arm_pass_manager.py      |   3 +
 backends/arm/_passes/match_arg_ranks_pass.py  |   1 +
 .../match_where_self_arg_dtype_pass.py        |  95 ++++++
 .../arm/operator_support/ethos_u55_support.py |   1 +
 .../tosa_supported_operators.py               |   1 +
 backends/arm/operators/__init__.py            |   1 +
 backends/arm/operators/op_where.py            | 103 +++++++
 .../arm/quantizer/quantization_annotator.py   |   8 +
 backends/arm/test/models/test_conformer.py    |   2 -
 backends/arm/test/models/test_llama.py        |   2 +-
 backends/arm/test/ops/test_where.py           | 276 ++++++++++++++++++
 12 files changed, 491 insertions(+), 3 deletions(-)
 create mode 100644 backends/arm/_passes/match_where_self_arg_dtype_pass.py
 create mode 100644 backends/arm/operators/op_where.py
 create mode 100644 backends/arm/test/ops/test_where.py

diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py
index 8409e1a9f38..bc1c2ef3d66 100644
--- a/backends/arm/_passes/__init__.py
+++ b/backends/arm/_passes/__init__.py
@@ -39,6 +39,7 @@
 from .insert_table_ops import InsertTableOpsPass  # noqa
 from .keep_dims_false_to_squeeze_pass import KeepDimsFalseToSqueezePass  # noqa
 from .match_arg_ranks_pass import MatchArgRanksPass  # noqa
+from .match_where_self_arg_dtype_pass import MatchWhereSelfDtypePass  # noqa
 from .meandim_to_averagepool_pass import ConvertMeanDimToAveragePoolPass  # noqa
 from .mm_to_bmm_pass import ConvertMmToBmmPass  # noqa
 from .remove_clone_pass import RemoveClonePass  # noqa
diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py
index f56be30083c..c085e3def1b 100644
--- a/backends/arm/_passes/arm_pass_manager.py
+++ b/backends/arm/_passes/arm_pass_manager.py
@@ -40,6 +40,7 @@
     InsertTableOpsPass,
     KeepDimsFalseToSqueezePass,
     MatchArgRanksPass,
+    MatchWhereSelfDtypePass,
     QuantizeOperatorArguments,
     RemoveClonePass,
     ReplaceScalarWithTensorArgPassTOSABI,
@@ -80,6 +81,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(ConvertToClampPass())
         self.add_pass(ConvertMinMaxPass())
         self.add_pass(ConvertAnyDefaultDimDimsPass())
+        self.add_pass(MatchWhereSelfDtypePass())
 
         if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
             self.add_pass(CastToInt32Pass())
@@ -130,6 +132,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(ConvertToClampPass())
         self.add_pass(ConvertMinMaxPass())
         self.add_pass(ConvertAnyDefaultDimDimsPass())
+        self.add_pass(MatchWhereSelfDtypePass())
 
         self.add_pass(AnnotateDecomposedMatmulPass())
         self.add_pass(QuantizeOperatorArguments())
diff --git a/backends/arm/_passes/match_arg_ranks_pass.py b/backends/arm/_passes/match_arg_ranks_pass.py
index 759f215a034..2cfc9b2b86a 100644
--- a/backends/arm/_passes/match_arg_ranks_pass.py
+++ b/backends/arm/_passes/match_arg_ranks_pass.py
@@ -49,6 +49,7 @@ def __init__(self, exported_program):
             exir_ops.edge.aten.bitwise_left_shift.Tensor,
             exir_ops.edge.aten.eq.Tensor,
             exir_ops.edge.aten.pow.Tensor_Tensor,
+            exir_ops.edge.aten.where.self,
         ]
 
     def _match_op_rank(self, graph_module, node, arg, max_rank):
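
As background for this series, here is a minimal standalone sketch (illustrative only, not part of the patch) of the kind of node being added: an elementwise aten.where.self, which the Arm backend lowers to the TOSA SELECT operator once the passes below have matched argument ranks and data types.

import torch


class MinimalWhere(torch.nn.Module):
    """Pick elements from x where the condition holds, otherwise from y."""

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # The condition tensor is BOOL; both value tensors must share one dtype.
        return torch.where(x > 0, x, y)


x, y = torch.randn(2, 2), torch.randn(2, 2)
print(MinimalWhere()(x, y))
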
diff --git a/backends/arm/_passes/match_where_self_arg_dtype_pass.py b/backends/arm/_passes/match_where_self_arg_dtype_pass.py
new file mode 100644
index 00000000000..154602129f8
--- /dev/null
+++ b/backends/arm/_passes/match_where_self_arg_dtype_pass.py
@@ -0,0 +1,95 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.arm._passes.arm_pass_utils import create_node
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+
+DTYPE_RANK = {
+    torch.bool: 0,
+    torch.uint8: 1,
+    torch.int8: 2,
+    torch.int16: 3,
+    torch.int32: 4,
+    torch.int64: 5,
+    torch.float16: 6,
+    torch.float32: 7,
+    torch.float64: 8,
+}
+
+
+def get_largest_dtype(dtype_1, dtype_2):
+    """Find the largest dtype."""
+    return dtype_1 if DTYPE_RANK[dtype_1] > DTYPE_RANK[dtype_2] else dtype_2
+
+
+class MatchWhereSelfDtypePass(ExportPass):
+    """Pass to match data types of non-condition input tensors.
+
+    Edge dialect allows different data types for non-condition tensors, while TOSA
+    does not. In cases where they differ, a TOSA CAST operator is inserted.
+
+    There is an edge case where one input is `boolean`, which cannot always be cast
+    directly to, for example, float32. When this occurs, two CAST operators are
+    added: one casting to int8 and one casting to the target data type.
+
+    """
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        modified_graph = False
+        graph = graph_module.graph
+        node_list = graph.find_nodes(
+            op="call_function", target=exir_ops.edge.aten.where.self
+        )
+        for node in node_list:
+            cond, input_, other_ = node.args
+
+            input_dtype = input_.meta["val"].dtype
+            other_dtype = other_.meta["val"].dtype
+            target_dtype = torch.float32
+            if input_dtype != other_dtype:
+                target_dtype = get_largest_dtype(input_dtype, other_dtype)
+
+            for arg in node.args[1:]:
+                arg_dtype = arg.meta["val"].dtype
+
+                if arg_dtype != target_dtype:
+                    if arg_dtype == torch.bool:
+                        # Bool is an edge case which cannot necessarily be directly
+                        # converted to the target data type.
+                        with graph.inserting_after(arg):
+                            replace_node_int8 = create_node(
+                                graph,
+                                exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
+                            )
+                            replace_node_int8.args = (arg,)
+                            replace_node_int8.kwargs = {"dtype": torch.int8}
+
+                        with graph.inserting_after(replace_node_int8):
+                            replace_node_fp32 = create_node(
+                                graph,
+                                exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
+                            )
+                            replace_node_fp32.args = (replace_node_int8,)
+                            replace_node_fp32.kwargs = {"dtype": target_dtype}
+                            node.replace_input_with(arg, replace_node_fp32)
+                    else:
+                        with graph.inserting_after(arg):
+                            replace_node = create_node(
+                                graph,
+                                exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
+                            )
+                            replace_node.args = (arg,)
+                            replace_node.kwargs = {"dtype": target_dtype}
+                            node.replace_input_with(arg, replace_node)
+
+                    modified_graph = True
+
+        if modified_graph:
+            graph_module.recompile()
+            graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, modified_graph)
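
The pass above promotes mismatched value dtypes to the larger of the two and routes bool through int8 before the final cast. A standalone illustration of that promotion rule (mirroring DTYPE_RANK / get_largest_dtype above; not part of the patch):

import torch

# Same ranking as DTYPE_RANK in the pass above.
RANK = {torch.bool: 0, torch.uint8: 1, torch.int8: 2, torch.int16: 3,
        torch.int32: 4, torch.int64: 5, torch.float16: 6,
        torch.float32: 7, torch.float64: 8}


def promoted(a: torch.dtype, b: torch.dtype) -> torch.dtype:
    # Pick the higher-ranked dtype, as get_largest_dtype() does.
    return a if RANK[a] > RANK[b] else b


assert promoted(torch.int8, torch.float32) is torch.float32
# Per the docstring above, bool cannot always be cast directly to e.g. float32,
# so the pass inserts bool -> int8 -> target; the eager-mode analogue:
t = torch.tensor([True, False])
assert t.to(torch.int8).to(torch.float32).dtype is torch.float32
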
diff --git a/backends/arm/operator_support/ethos_u55_support.py b/backends/arm/operator_support/ethos_u55_support.py
index 64f3fb3f816..69fda636423 100644
--- a/backends/arm/operator_support/ethos_u55_support.py
+++ b/backends/arm/operator_support/ethos_u55_support.py
@@ -149,6 +149,7 @@ class EthosU55NotSupported(OperatorSupportBase):
         exir_ops.edge.aten.reflection_pad1d.default,  # REVERSE
         exir_ops.edge.aten.reflection_pad2d.default,  # REVERSE
         exir_ops.edge.aten.reflection_pad3d.default,  # REVERSE
+        exir_ops.edge.aten.where.self,  # SELECT
     ]
 
     def __init__(self, reporter: WhyNoPartitionReporter):
diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py
index 2a31ecbc775..0e5d7ecc958 100644
--- a/backends/arm/operator_support/tosa_supported_operators.py
+++ b/backends/arm/operator_support/tosa_supported_operators.py
@@ -207,6 +207,7 @@ def is_node_supported(
             exir_ops.edge.aten.squeeze_copy.dims,
             exir_ops.edge.aten.pow.Tensor_Scalar,
             exir_ops.edge.aten.pow.Tensor_Tensor,
+            exir_ops.edge.aten.where.self,
             operator.getitem,
             exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
             exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py
index f891d8d3b69..2a610536f3e 100644
--- a/backends/arm/operators/__init__.py
+++ b/backends/arm/operators/__init__.py
@@ -49,6 +49,7 @@
     op_transpose,
     op_upsample_nearest2d,
     op_view,
+    op_where,
     ops_binary,
     ops_unary,
 )
diff --git a/backends/arm/operators/op_where.py b/backends/arm/operators/op_where.py
new file mode 100644
index 00000000000..c8b35e831d4
--- /dev/null
+++ b/backends/arm/operators/op_where.py
@@ -0,0 +1,103 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List, Sequence
+
+import serializer.tosa_serializer as ts  # type: ignore
+
+from executorch.backends.arm.operators.node_visitor import (
+    NodeVisitor,
+    register_node_visitor,
+)
+from executorch.backends.arm.tosa_mapping import TosaArg
+from executorch.backends.arm.tosa_specification import TosaSpecification
+from serializer.tosa_serializer import TosaOp
+from torch.fx import Node
+
+
+def _add_node_to_tosa_graph(
+    tosa_graph: ts.TosaSerializer,
+    inputs: List[TosaArg],
+    output: TosaArg,
+    supported_dtypes: Sequence,
+) -> None:
+    if len(inputs) != 3:
+        raise ValueError(f"aten.where.self expects 3 arguments, got {len(inputs)}")
+
+    if inputs[0].dtype is not ts.DType.BOOL:
+        raise ValueError("Input 0 needs to have dtype BOOL")
+    if inputs[1].dtype != inputs[2].dtype:
+        raise ValueError(
+            "Non-condition tensors must have same data type, got "
+            f"{inputs[1].dtype} and {inputs[2].dtype}"
+        )
+    for input_ in inputs[1:]:
+        if input_.dtype not in supported_dtypes:
+            raise ValueError(
+                f"Input needs to be of TOSA dtype {supported_dtypes}, got {input_.dtype}"
+            )
+
+    tosa_graph.addOperator(
+        TosaOp.Op().SELECT,
+        [inputs[0].name, inputs[1].name, inputs[2].name],
+        [output.name],
+        None,
+    )
+
+
+@register_node_visitor
+class WhereVisitor_080_BI(NodeVisitor):
+    target = "aten.where.self"
+
+    tosa_specs = [
+        TosaSpecification.create_from_string("TOSA-0.80+BI"),
+    ]
+
+    def __init__(self, *args):
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: Node,
+        tosa_graph: ts.TosaSerializer,
+        inputs: List[TosaArg],
+        output: TosaArg,
+    ) -> None:
+
+        bi_supported_dtypes = [
+            ts.DType.INT8,
+            ts.DType.INT16,
+            ts.DType.INT32,
+            ts.DType.BOOL,
+        ]
+        _add_node_to_tosa_graph(tosa_graph, inputs, output, bi_supported_dtypes)
+
+
+@register_node_visitor
+class WhereVisitor_080_MI(WhereVisitor_080_BI):
+
+    tosa_specs = [
+        TosaSpecification.create_from_string("TOSA-0.80+MI"),
+    ]
+
+    def __init__(self, *args):
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: Node,
+        tosa_graph: ts.TosaSerializer,
+        inputs: List[TosaArg],
+        output: TosaArg,
+    ) -> None:
+        mi_supported_dtypes = [
+            ts.DType.FP16,
+            ts.DType.FP32,
+            ts.DType.INT8,
+            ts.DType.INT16,
+            ts.DType.INT32,
+            ts.DType.BOOL,
+        ]
+        _add_node_to_tosa_graph(tosa_graph, inputs, output, mi_supported_dtypes)
diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py
index e9ed6be81f3..ea00dcde012 100644
--- a/backends/arm/quantizer/quantization_annotator.py
+++ b/backends/arm/quantizer/quantization_annotator.py
@@ -238,6 +238,7 @@ def _match_pattern(
     torch.ops.aten.dropout_.default,
     torch.ops.aten.clamp.default,
     torch.ops.aten.clamp.Tensor,
+    torch.ops.aten.where,
     operator.getitem,
 ]
 
@@ -322,6 +323,13 @@ def any_or_hardtanh_min_zero(n: Node):
             ),
         ]
         quant_properties.quant_output = _QuantProperty(0, shared_qspec)  # type: ignore[arg-type]
+    elif node.target in (torch.ops.aten.where.self,):
+        shared_qspec = SharedQuantizationSpec(node.args[1])  # type: ignore[arg-type]
+        quant_properties.quant_inputs = [
+            _QuantProperty(1, shared_qspec),  # type: ignore[arg-type]
+            _QuantProperty(2, shared_qspec),  # type: ignore[arg-type]
+        ]
+        quant_properties.quant_output = _QuantProperty(0, shared_qspec)  # type: ignore[arg-type]
     elif node.target == torch.ops.aten.adaptive_avg_pool2d.default:
         input_qspec = (
             SharedQuantizationSpec(node.args[0])  # type: ignore[arg-type]
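
The annotation change above gives both value inputs (args 1 and 2) and the output of where.self a single SharedQuantizationSpec: SELECT only copies values, with no rescaling, so all three must live in the same quantized domain. A rough eager-mode analogue (illustrative only; the helper name and the scale/zero-point values are invented for this sketch and are not backend API):

import torch


def fq(t: torch.Tensor, scale: float, zero_point: int) -> torch.Tensor:
    # Int8 fake-quantization helper used only for this illustration.
    return torch.fake_quantize_per_tensor_affine(t, scale, zero_point, -128, 127)


cond = torch.randn(4) > 0
a, b = torch.randn(4), torch.randn(4)
scale, zp = 0.02, 0  # one shared (scale, zero_point) for both branches and the output
out = torch.where(cond, fq(a, scale, zp), fq(b, scale, zp))
print(out)
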
diff --git a/backends/arm/test/models/test_conformer.py b/backends/arm/test/models/test_conformer.py
index e270fb18205..dc5ecc7ca97 100644
--- a/backends/arm/test/models/test_conformer.py
+++ b/backends/arm/test/models/test_conformer.py
@@ -31,10 +31,8 @@ class TestConformer(unittest.TestCase):
     # .to_executorch step, i.e. after Arm partitioner.
     ops_after_partitioner = {
         "executorch_exir_dialects_edge__ops_aten_max_default": 1,
-        "executorch_exir_dialects_edge__ops_aten_where_self": 4,
         "torch.ops.aten._assert_scalar.default": 10,
         "torch.ops.aten._local_scalar_dense.default": 1,
-        "torch.ops.higher_order.executorch_call_delegate": 4,
     }
 
     dim = 16
diff --git a/backends/arm/test/models/test_llama.py b/backends/arm/test/models/test_llama.py
index a6da04b0e2e..bd18ff1856f 100644
--- a/backends/arm/test/models/test_llama.py
+++ b/backends/arm/test/models/test_llama.py
@@ -114,7 +114,7 @@ def test_llama_tosa_MI(self):
             )
             .export()
             .to_edge_transform_and_lower()
-            .check_count({"torch.ops.higher_order.executorch_call_delegate": 14})
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
             .to_executorch()
             .run_method_and_compare_outputs(
                 inputs=llama_inputs,
diff --git a/backends/arm/test/ops/test_where.py b/backends/arm/test/ops/test_where.py
new file mode 100644
index 00000000000..bf127460f3e
--- /dev/null
+++ b/backends/arm/test/ops/test_where.py
@@ -0,0 +1,276 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List, Tuple
+
+import pytest
+
+import torch
+
+from executorch.backends.arm.quantizer.arm_quantizer import (
+    EthosUQuantizer,
+    get_symmetric_quantization_config,
+    TOSAQuantizer,
+)
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.test_pipeline import (
+    EthosU85PipelineBI,
+    OpNotSupportedPipeline,
+    TosaPipelineBI,
+    TosaPipelineMI,
+)
+from executorch.backends.xnnpack.test.tester.tester import Quantize
+
+aten_op = "torch.ops.aten.where.self"
+exir_op = "executorch_exir_dialects_edge__ops_aten_where_self"
+
+
+class Where(torch.nn.Module):
+    def __init__(
+        self, shape: tuple | int, dtype: torch.dtype | Tuple[torch.dtype], condition
+    ):
+        super().__init__()
+        self.shape = shape if isinstance(shape, tuple) else (shape,) * shape
+        self.dtype = (dtype, dtype) if isinstance(dtype, torch.dtype) else dtype
+        self.condition = condition
+
+    def get_inputs(self):
+        inputs: List = [0, 0]
+        for i in range(2):
+            if self.dtype[i] in [torch.int8, torch.int16, torch.int32]:
+                inputs[i] = torch.randint(
+                    torch.iinfo(self.dtype[i]).min,
+                    torch.iinfo(self.dtype[i]).max,
+                    self.shape,
+                    dtype=self.dtype[i],
+                )
+            elif self.dtype[i] in [torch.float32]:
+                inputs[i] = torch.randn(*self.shape).to(self.dtype[i])
+            elif self.dtype[i] is torch.bool:
+                inputs[i] = torch.randint(0, 1, self.shape, dtype=torch.bool)
+            else:
+                raise TypeError(
+                    f"Input generation for dtype {self.dtype[i]} not implemented in "
+                    "Where()"
+                )
+
+        return tuple(inputs)
+
+    def forward(
+        self,
+        input_: torch.Tensor,
+        other_: torch.Tensor,
+    ):
+        return torch.where(self.condition(input_), input_, other_)
+
+
+def tensor_condition(input: torch.Tensor):
+    return input > torch.zeros_like(input)
+
+
+def scalar_condition(input: torch.Tensor):
+    return input > 0
+
+
+two_dim_tensor_cond = Where(
+    2,
+    torch.float32,
+    tensor_condition,
+)
+
+three_dim_tensor_cond = Where(
+    3,
+    torch.float32,
+    tensor_condition,
+)
+
+float32_tensor_cond = Where(
+    1,
+    torch.float32,
+    tensor_condition,
+)
+
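A usage note on the test helper above (editorial sketch, assuming the definitions from this new test file): an integer shape argument expands to an n-dimensional shape with every dimension equal to n, so Where(2, ...) generates 2x2 inputs and Where(3, ...) generates 3x3x3 inputs.

m = Where(2, torch.float32, tensor_condition)
x, y = m.get_inputs()
assert x.shape == (2, 2) and y.shape == (2, 2)  # int shape 2 -> (2, 2)
print(m(x, y))
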
+float32_tensor_cond_tuple_dtype = Where(
+    1,
+    (torch.float32, torch.int8),
+    tensor_condition,
+)
+
+float32_tensor_cond_tuple_dtype_bool = Where(
+    1,
+    (torch.float32, torch.bool),
+    tensor_condition,
+)
+
+# Scalar tests
+two_dim_scalar_cond = Where(
+    2,
+    torch.float32,
+    scalar_condition,
+)
+
+three_dim_scalar_cond = Where(
+    3,
+    torch.float32,
+    scalar_condition,
+)
+
+float32_scalar_cond = Where(
+    1,
+    torch.float32,
+    scalar_condition,
+)
+
+test_modules_common = {
+    "two_dim_tensor_cond": two_dim_tensor_cond,
+    "three_dim_tensor_cond": three_dim_tensor_cond,
+    "float32_tensor_cond": float32_tensor_cond,
+    "two_dim_scalar_cond": two_dim_scalar_cond,
+    "three_dim_scalar_cond": three_dim_scalar_cond,
+    "float32_scalar_cond": float32_scalar_cond,
+}
+
+test_modules_MI = {
+    **test_modules_common,
+    "float32_tensor_cond_tuple_dtype": float32_tensor_cond_tuple_dtype,
+    "float32_tensor_cond_tuple_dtype_bool": float32_tensor_cond_tuple_dtype_bool,
+}
+
+test_modules_BI = {
+    **test_modules_common,
+}
+
+input_t = Tuple[torch.Tensor]
+
+
+@common.parametrize("test_module", test_modules_MI)
+def test_where_tosa_MI(test_module):
+    pipeline = TosaPipelineMI[input_t](
+        test_module, test_module.get_inputs(), aten_op, exir_op
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_module", test_modules_BI)
+def test_where_tosa_BI(test_module):
+    compile_spec = common.get_tosa_compile_spec("TOSA-0.80+BI")
+    quantizer = TOSAQuantizer(compile_spec).set_io(get_symmetric_quantization_config())
+    pipeline = TosaPipelineBI[input_t](
+        test_module, test_module.get_inputs(), aten_op, exir_op
+    )
+    pipeline.change_args(
+        "quantize", Quantize(quantizer, get_symmetric_quantization_config())
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_module", test_modules_BI)
+def test_where_u55_BI(test_module):
+    compile_spec = common.get_u55_compile_spec()
+    quantizer = EthosUQuantizer(compile_spec).set_io(
+        get_symmetric_quantization_config()
+    )
+
+    # If condition is tensor_condition then there will be one full_like op which will be
+    # delegated.
+    if test_module.condition == tensor_condition:
+        num_delegates = 1
+        num_exir = 0
+    else:
+        num_delegates = 0
+        num_exir = 0
+
+    pipeline = OpNotSupportedPipeline[input_t](
+        test_module,
+        test_module.get_inputs(),
+        "TOSA-0.80+BI+u55",
+        {
+            exir_op: 1,
+            "executorch_exir_dialects_edge__ops_aten_full_default": num_exir,
+        },
+        num_delegates,
+    )
+
+    pipeline.change_args(
+        "quantize", Quantize(quantizer, get_symmetric_quantization_config())
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_module", test_modules_BI)
+def test_where_u85_BI(test_module):
+    compile_spec = common.get_u85_compile_spec()
+    quantizer = EthosUQuantizer(compile_spec).set_io(
+        get_symmetric_quantization_config()
+    )
+    pipeline = EthosU85PipelineBI[input_t](
+        test_module, test_module.get_inputs(), aten_op, exir_op, run_on_fvp=False
+    )
+    pipeline.change_args(
+        "quantize", Quantize(quantizer, get_symmetric_quantization_config())
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_module", test_modules_BI)
+@pytest.mark.skip(reason="The same as test_where_u55_BI")
+@common.XfailIfNoCorstone300
+def test_where_u55_BI_on_fvp(test_module):
+    compile_spec = common.get_u55_compile_spec()
+    quantizer = EthosUQuantizer(compile_spec).set_io(
+        get_symmetric_quantization_config()
+    )
+
+    # If condition is tensor_condition then there will be one full_like op which will be
+    # delegated.
+    if test_module.condition == tensor_condition:
+        num_delegates = 1
+        num_exir = 0
+    else:
+        num_delegates = 0
+        num_exir = 0
+
+    pipeline = OpNotSupportedPipeline[input_t](
+        test_module,
+        test_module.get_inputs(),
+        "TOSA-0.80+BI+u55",
+        {
+            exir_op: 1,
+            "executorch_exir_dialects_edge__ops_aten_full_default": num_exir,
+        },
+        num_delegates,
+    )
+
+    pipeline.change_args(
+        "quantize", Quantize(quantizer, get_symmetric_quantization_config())
+    )
+    pipeline.run()
+
+
+@common.parametrize(
+    "test_module",
+    test_modules_BI,
+    xfails={
+        "two_dim_scalar_cond": "E [executorch:method.cpp:601] Missing operator: "
+        "[2] aten::gt.Scalar_out",
+        "three_dim_scalar_cond": "E [executorch:method.cpp:601] Missing operator: "
+        "[2] aten::gt.Scalar_out",
+        "float32_scalar_cond": "E [executorch:method.cpp:601] Missing operator: "
+        "[2] aten::gt.Scalar_out",
+    },
+)
+@common.XfailIfNoCorstone320
+def test_where_u85_BI_on_fvp(test_module):
+    compile_spec = common.get_u85_compile_spec()
+    quantizer = EthosUQuantizer(compile_spec).set_io(
+        get_symmetric_quantization_config()
+    )
+    pipeline = EthosU85PipelineBI[input_t](
+        test_module, test_module.get_inputs(), aten_op, exir_op, run_on_fvp=True
+    )
+    pipeline.change_args(
+        "quantize", Quantize(quantizer, get_symmetric_quantization_config())
+    )
+    pipeline.run()

From 7c92392541ca262ce53cbcdc8ec71d545d474907 Mon Sep 17 00:00:00 2001
From: Sebastian Larsson
Date: Fri, 7 Mar 2025 14:43:04 +0100
Subject: [PATCH 2/2] Arm backend: Fix get_quant_properties return type

The function `get_quant_properties` can return `None`, which was not
reflected in its return type annotation. This could cause type checker
warnings or incorrect assumptions about its return value. Update the
return type from `_OpQuantProperties` to `_OpQuantProperties | None` to
correctly represent the possible outputs.
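
As a standalone illustration (not part of this commit), `X | None` is PEP 604 union syntax, equivalent to Optional[X]; evaluated at runtime it needs Python 3.10+ or `from __future__ import annotations`. The pattern it documents for callers, with invented names:

from typing import Optional


def first_even(values: list[int]) -> int | None:
    # Returning None is part of the declared contract, mirroring
    # get_quant_properties() -> _OpQuantProperties | None.
    return next((v for v in values if v % 2 == 0), None)


result: Optional[int] = first_even([1, 3, 5])
if result is None:  # callers must handle the None branch explicitly
    result = 0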

Change-Id: I1f360ab1b9dcda119473f17a3eb46ffbaeba831c
Signed-off-by: Sebastian Larsson
---
 backends/arm/quantizer/quantization_annotator.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py
index ea00dcde012..baca13029a3 100644
--- a/backends/arm/quantizer/quantization_annotator.py
+++ b/backends/arm/quantizer/quantization_annotator.py
@@ -245,7 +245,7 @@ def _match_pattern(
 
 def get_quant_properties(  # noqa: C901
     node: Node, gm: torch.fx.GraphModule, quantization_config
-) -> _OpQuantProperties:
+) -> _OpQuantProperties | None:
     input_act_qspec = quantization_config.get_input_act_qspec()
     weight_qspec = quantization_config.get_weight_qspec()
     output_act_qspec = quantization_config.get_output_act_qspec()
@@ -384,16 +384,16 @@ def any_or_hardtanh_min_zero(n: Node):
         quant_properties.quant_output = None
     elif node.target in _parent_shared_qspec:
         if not isinstance(node.args[0], Node):
-            return None  # type: ignore[return-value]
+            return None
 
         if not arm_quantizer_utils.is_output_annotated(node.args[0]):  # type: ignore[attr-defined]
-            return None  # type: ignore[return-value]
+            return None
 
         shared_qspec = SharedQuantizationSpec(node.args[0])
         quant_properties.quant_inputs = [_QuantProperty(0, shared_qspec)]  # type: ignore[arg-type]
         quant_properties.quant_output = _QuantProperty(0, shared_qspec)  # type: ignore[arg-type]
     else:
-        return None  # type: ignore[return-value]
+        return None
 
     # Don't check if operator.getitem is ok for quantization, it's always ok
     if node.target == operator.getitem:
@@ -402,7 +402,7 @@ def any_or_hardtanh_min_zero(n: Node):
 
     # Check that each inputs/outputs can be quantized properly with the
    # provided quantization properties.
     if not _is_ok_for_quantization(node, quant_properties, gm):
-        return None  # type: ignore[return-value]
+        return None
 
     return quant_properties
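
One possible way to run the new operator tests locally (assuming a standard ExecuTorch development setup with pytest; the FVP-gated variants are expected to skip or xfail when the Corstone-300/320 FVPs are not installed):

# Run from the ExecuTorch repository root.
import pytest

raise SystemExit(pytest.main(["-q", "backends/arm/test/ops/test_where.py"]))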