diff --git a/CMakeLists.txt b/CMakeLists.txt
index 76c75270d5f..4d292c209a6 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -242,6 +242,8 @@ option(EXECUTORCH_USE_DL "Use libdl library" ON)
 
 option(EXECUTORCH_BUILD_CADENCE "Build the Cadence DSP backend" OFF)
 
+option(EXECUTORCH_BUILD_CORTEX_M "Build the Cortex-M backend" OFF)
+
 #
 # pthreadpool: build pthreadpool library. Disable on unsupported platforms
 #
@@ -715,6 +717,10 @@ if(EXECUTORCH_BUILD_XNNPACK)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/xnnpack)
 endif()
 
+if(EXECUTORCH_BUILD_CORTEX_M)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/cortex_m)
+endif()
+
 if(EXECUTORCH_BUILD_DEVTOOLS)
   if(NOT EXECUTORCH_BUILD_ARM_BAREMETAL)
     set(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER
diff --git a/backends/arm/scripts/build_executorch.sh b/backends/arm/scripts/build_executorch.sh
index 87d9fd23070..573f93221d4 100755
--- a/backends/arm/scripts/build_executorch.sh
+++ b/backends/arm/scripts/build_executorch.sh
@@ -129,6 +129,7 @@ cmake \
     -DEXECUTORCH_BUILD_ARM_BAREMETAL=ON \
     -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
     -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \
+    -DEXECUTORCH_BUILD_CORTEX_M=ON \
     -DEXECUTORCH_ENABLE_LOGGING=ON \
     ${build_devtools_flags} \
     ${build_with_etdump_flags} \
diff --git a/backends/arm/test/test_arm_baremetal.sh b/backends/arm/test/test_arm_baremetal.sh
index 48cee9acd95..476d417a69a 100755
--- a/backends/arm/test/test_arm_baremetal.sh
+++ b/backends/arm/test/test_arm_baremetal.sh
@@ -154,6 +154,13 @@ test_run_ethosu_fvp() { # End to End model tests using run.sh
     echo "${TEST_SUITE_NAME}: Test ethos-u target Ethos-U85"
     examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u85-128 --model_name=add
     examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u85-128 --model_name=mul
+
+    # Cortex-M op tests
+    examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u55-128 --model_name=qadd --bundleio
+    examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u55-128 --model_name=qops --bundleio
+    examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u55-128 --model_name=qops --bundleio --no_delegate --portable_kernels="aten::sub.out,aten::add.out,aten::mul.out"
+    examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u85-128 --model_name=qops --bundleio
+
     echo "${TEST_SUITE_NAME}: PASS"
 }
diff --git a/backends/cortex_m/CMakeLists.txt b/backends/cortex_m/CMakeLists.txt
new file mode 100644
index 00000000000..39638bf0ee4
--- /dev/null
+++ b/backends/cortex_m/CMakeLists.txt
@@ -0,0 +1,61 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Kernel library for Cortex-M operators. Please keep this file formatted by running:
+# ~~~
+# cmake-format -i CMakeLists.txt
+# ~~~
+cmake_minimum_required(VERSION 3.19)
+
+set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
+if(NOT CMAKE_CXX_STANDARD)
+  set(CMAKE_CXX_STANDARD 17)
+endif()
+
+# Source root directory for executorch.
+if(NOT EXECUTORCH_ROOT)
+  set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..)
+endif()
+
+include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
+include(${EXECUTORCH_ROOT}/tools/cmake/Codegen.cmake)
+
+if(NOT PYTHON_EXECUTABLE)
+  resolve_python_executable()
+endif()
+
+# Cortex-M ops kernel sources
+set(_cortex_m_kernels__srcs
+    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantize_per_tensor.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_dequantize_per_tensor.cpp
+)
+
+# Generate C++ bindings to register kernels into ExecuTorch (for runtime).
+# Here all ops listed in operators.yaml are selected.
+set(_yaml_file ${CMAKE_CURRENT_LIST_DIR}/ops/operators.yaml)
+gen_selected_ops(LIB_NAME "cortex_m_ops_lib" OPS_SCHEMA_YAML "${_yaml_file}")
+
+# Generate bindings for the kernels
+generate_bindings_for_kernels(
+  LIB_NAME "cortex_m_ops_lib" CUSTOM_OPS_YAML "${_yaml_file}"
+)
+message("Generated files ${gen_command_sources}")
+
+# Build a library from the sources in _cortex_m_kernels__srcs
+add_library(cortex_m_kernels ${_cortex_m_kernels__srcs})
+target_link_libraries(cortex_m_kernels PRIVATE executorch)
+target_compile_options(cortex_m_kernels PUBLIC ${_common_compile_options})
+
+# cortex_m_ops_lib: Register the Cortex-M ops kernels into the ExecuTorch runtime
+gen_operators_lib(
+  LIB_NAME "cortex_m_ops_lib" KERNEL_LIBS cortex_m_kernels DEPS executorch
+)
+
+install(
+  TARGETS cortex_m_kernels cortex_m_ops_lib
+  DESTINATION lib
+  PUBLIC_HEADER DESTINATION include/executorch/backends/cortex_m/ops/
+)
diff --git a/backends/cortex_m/ops/op_dequantize_per_tensor.cpp b/backends/cortex_m/ops/op_dequantize_per_tensor.cpp
index 1011de73be7..6d3f3698c67 100644
--- a/backends/cortex_m/ops/op_dequantize_per_tensor.cpp
+++ b/backends/cortex_m/ops/op_dequantize_per_tensor.cpp
@@ -29,6 +29,7 @@ namespace {
  */
 void check_dequantize_args(
     const Tensor& input,
+    int64_t zero_point,
     int64_t quant_min,
     int64_t quant_max,
     ScalarType dtype,
@@ -39,6 +40,18 @@ void check_dequantize_args(
       "input.scalar_type() %" PRId8 " is not char type",
       static_cast<int8_t>(input.scalar_type()));
 
+  // Check zp range
+  ET_CHECK_MSG(
+      zero_point >= quant_min,
+      "zero_point %" PRId64 " must be >= quant_min %" PRId64,
+      zero_point,
+      quant_min);
+  ET_CHECK_MSG(
+      zero_point <= quant_max,
+      "zero_point %" PRId64 " must be <= quant_max %" PRId64,
+      zero_point,
+      quant_max);
+
   // Check output dtype is float
   ET_CHECK_MSG(
       out.scalar_type() == ScalarType::Float,
@@ -73,18 +86,10 @@ void check_dequantize_args(
 /**
  * Scalar implementation of quantization for a single value.
  */
-template <typename T, typename K>
-T dequantize_val(
-    float scale,
-    int32_t zero_point,
-    K value,
-    int64_t quant_min,
-    int64_t quant_max) {
-  (void)quant_min;
-  (void)quant_max;
-  return static_cast<T>((static_cast<float>(value) - zero_point) * scale);
+template <typename F, typename Q>
+F dequantize_val(float scale, int32_t zero_point, Q qvalue) {
+  return static_cast<F>((static_cast<int32_t>(qvalue) - zero_point) * scale);
 }
-
 } // namespace
 
 Tensor& dequantize_per_tensor_out(
@@ -106,29 +111,71 @@ Tensor& dequantize_per_tensor_out(
       "Failed to resize out Tensor in dequantize_per_tensor_out");
 
   // Validate input parameters
-  check_dequantize_args(input, quant_min, quant_max, dtype, out);
+  check_dequantize_args(input, zero_point, quant_min, quant_max, dtype, out);
 
-  // Pre-compute inverse scale for better performance
   int32_t zp = static_cast<int32_t>(zero_point);
-  int32_t qmin = static_cast<int32_t>(quant_min);
-  int32_t qmax = static_cast<int32_t>(quant_max);
 
   // Get pointers to input and output data
   const int8_t* input_data = input.const_data_ptr<int8_t>();
   float* out_data = out.mutable_data_ptr<float>();
   const size_t numel = input.numel();
 
+  size_t i = 0;
+
 #if defined(HAS_HELIUM_SIMD)
-// Helium MVE implementation for float32 to int8 quantization
-#Error "Implement MVE version!"
-#else
-  // Scalar implementation for float32 to int8 quantization
-  for (size_t i = 0; i < numel; i++) {
-    out_data[i] =
-        dequantize_val<float, int8_t>(scale, zp, input_data[i], qmin, qmax);
+  // Helium MVE implementation for int8 to float dequantization
+  static uint8x16_t voffset{
+      0x0,
+      0x8,
+      0x4,
+      0xC,
+      0x1,
+      0x9,
+      0x5,
+      0xD,
+      0x2,
+      0xA,
+      0x6,
+      0xE,
+      0x3,
+      0xB,
+      0x7,
+      0xF};
+
+  int16x8_t vzp = vdupq_n_s16(static_cast<int16_t>(zp));
+  float32x4_t vscale = vdupq_n_f32(static_cast<float>(scale));
+
+  for (; i + 15 < numel; i += 16) {
+    int8x16_t in_084C195D2A6E3B7F =
+        vldrbq_gather_offset_s8(input_data, voffset);
+
+    int16x8_t in_04152637 = vsubq_s16(vmovlbq_s8(in_084C195D2A6E3B7F), vzp);
+    int16x8_t in_8C9DAEBF = vsubq_s16(vmovltq_s8(in_084C195D2A6E3B7F), vzp);
+
+    float32x4_t inf_0123 = vcvtq_f32_s32(vmovlbq_s16(in_04152637));
+    float32x4_t inf_4567 = vcvtq_f32_s32(vmovltq_s16(in_04152637));
+    float32x4_t inf_89AB = vcvtq_f32_s32(vmovlbq_s16(in_8C9DAEBF));
+    float32x4_t inf_CDEF = vcvtq_f32_s32(vmovltq_s16(in_8C9DAEBF));
+
+    float32x4_t out_0123 = vmulq_f32(inf_0123, vscale);
+    float32x4_t out_4567 = vmulq_f32(inf_4567, vscale);
+    float32x4_t out_89AB = vmulq_f32(inf_89AB, vscale);
+    float32x4_t out_CDEF = vmulq_f32(inf_CDEF, vscale);
+
+    vstrwq_f32(out_data + 0, out_0123);
+    vstrwq_f32(out_data + 4, out_4567);
+    vstrwq_f32(out_data + 8, out_89AB);
+    vstrwq_f32(out_data + 12, out_CDEF);
+
+    input_data += 16;
+    out_data += 16;
   }
-#endif
+#endif // defined(HAS_HELIUM_SIMD)
 
+  for (; i < numel; i++) {
+    *out_data = dequantize_val<float, int8_t>(scale, zp, *input_data);
+    input_data++;
+    out_data++;
+  }
   return out;
 }
diff --git a/backends/cortex_m/ops/op_quantize_per_tensor.cpp b/backends/cortex_m/ops/op_quantize_per_tensor.cpp
index 25385602e58..d92d2666a8f 100644
--- a/backends/cortex_m/ops/op_quantize_per_tensor.cpp
+++ b/backends/cortex_m/ops/op_quantize_per_tensor.cpp
@@ -41,13 +41,13 @@ void check_quantize_args(
       "input.scalar_type() %" PRId8 " is not float type",
       static_cast<int8_t>(input.scalar_type()));
 
-  // Check output dtype is int8 (Char)
+  // Check output dtype is int8
   ET_CHECK_MSG(
       out.scalar_type() == ScalarType::Char,
       "out.scalar_type() %" PRId8 " is not int8 (Char)",
       static_cast<int8_t>(out.scalar_type()));
 
-  // Check dtype is int8 (Char)
+  // Check dtype is int8
   ET_CHECK_MSG(
       dtype == ScalarType::Char,
       "dtype %" PRId8 " is not int8 (Char)",
       static_cast<int8_t>(dtype));
@@ -75,18 +75,18 @@ void check_quantize_args(
 /**
  * Scalar implementation of quantization for a single value.
  */
-template <typename T, typename K>
-T quantize_val(
-    float inv_scale,
+template <typename Q, typename F>
+Q quantize_val(
+    F inv_scale,
     int32_t zero_point,
-    K value,
+    F value,
     int64_t quant_min,
     int64_t quant_max) {
   int32_t qvalue =
       zero_point + static_cast<int32_t>(std::nearbyint(inv_scale * value));
   qvalue = std::max(qvalue, static_cast<int32_t>(quant_min));
   qvalue = std::min(qvalue, static_cast<int32_t>(quant_max));
-  return static_cast<T>(qvalue);
+  return static_cast<Q>(qvalue);
 }
 
 } // namespace
 
@@ -123,16 +123,97 @@ Tensor& quantize_per_tensor_out(
   int8_t* out_data = out.mutable_data_ptr<int8_t>();
   const size_t numel = input.numel();
 
+  size_t i = 0;
+
 #if defined(HAS_HELIUM_SIMD)
-// Helium MVE implementation for float32 to int8 quantization
-#Error "Implement MVE version!"
-#else
-  // Scalar implementation for float32 to int8 quantization
-  for (size_t i = 0; i < numel; i++) {
-    out_data[i] =
-        quantize_val<int8_t, float>(inv_scale, zp, input_data[i], qmin, qmax);
+  // Helium MVE implementation for float32 to int8 quantization
+  static uint8x16_t voffset{
+      0x0,
+      0x8,
+      0x4,
+      0xC,
+      0x1,
+      0x9,
+      0x5,
+      0xD,
+      0x2,
+      0xA,
+      0x6,
+      0xE,
+      0x3,
+      0xB,
+      0x7,
+      0xF};
+
+  float32x4_t inv_scale_vec = vdupq_n_f32(inv_scale);
+
+  // Magic number for float to int conversion, round to nearest even integer
+  // int magic_round(float f): interpret_as_int32(f + magic_float) - magic_int
+  // where,
+  // magic_float = 12582912.0f = (2 ** 23 + 2 ** 22) = (1.5 * 2 ** 23)
+  // magic_int = 1262485504 = 0x4B400000 = bit_pattern_as_int32(magic_float)
+
+  float magic_float = 12582912.0f;
+  int32_t magic_int = 1262485504;
+
+  float32x4_t vmagic_float = vdupq_n_f32(magic_float);
+  int32x4_t vmagic_int_less_zp =
+      vdupq_n_s32(magic_int - static_cast<int32_t>(zp));
+
+  int16x8_t vqmin = vdupq_n_s16(qmin);
+  int16x8_t vqmax = vdupq_n_s16(qmax);
+
+  // TODO: Measure performance, we are spilling
+  for (; i + 15 < numel; i += 16) {
+    float32x4_t in_0123 = vldrwq_f32(input_data + 0);
+    float32x4_t in_4567 = vldrwq_f32(input_data + 4);
+    float32x4_t in_89AB = vldrwq_f32(input_data + 8);
+    float32x4_t in_CDEF = vldrwq_f32(input_data + 12);
+
+    float32x4_t outf_0123 = vfmaq_f32(vmagic_float, in_0123, inv_scale_vec);
+    float32x4_t outf_4567 = vfmaq_f32(vmagic_float, in_4567, inv_scale_vec);
+    float32x4_t outf_89AB = vfmaq_f32(vmagic_float, in_89AB, inv_scale_vec);
+    float32x4_t outf_CDEF = vfmaq_f32(vmagic_float, in_CDEF, inv_scale_vec);
+
+    int32x4_t out_0123 =
+        vsubq_s32(vreinterpretq_s32_f32(outf_0123), vmagic_int_less_zp);
+    int32x4_t out_4567 =
+        vsubq_s32(vreinterpretq_s32_f32(outf_4567), vmagic_int_less_zp);
+    int32x4_t out_89AB =
+        vsubq_s32(vreinterpretq_s32_f32(outf_89AB), vmagic_int_less_zp);
+    int32x4_t out_CDEF =
+        vsubq_s32(vreinterpretq_s32_f32(outf_CDEF), vmagic_int_less_zp);
+
+    int16x8_t out_04152637;
+    int16x8_t out_8C9DAEBF;
+    out_04152637 = vmovnbq_s32(out_04152637, out_0123);
+    out_04152637 = vmovntq_s32(out_04152637, out_4567);
+    out_8C9DAEBF = vmovnbq_s32(out_8C9DAEBF, out_89AB);
+    out_8C9DAEBF = vmovntq_s32(out_8C9DAEBF, out_CDEF);
+
+    int16x8_t out_04152637_clamped =
+        vminq_s16(vmaxq_s16(out_04152637, vqmin), vqmax);
+    int16x8_t out_8C9DAEBF_clamped =
+        vminq_s16(vmaxq_s16(out_8C9DAEBF, vqmin), vqmax);
+
+    int8x16_t out_084C195D2A6E3B7F;
+    out_084C195D2A6E3B7F =
+        vmovnbq_s16(out_084C195D2A6E3B7F, out_04152637_clamped);
+    out_084C195D2A6E3B7F =
+        vmovntq_s16(out_084C195D2A6E3B7F, out_8C9DAEBF_clamped);
+
+    vstrbq_scatter_offset_s8(out_data, voffset, out_084C195D2A6E3B7F);
+    input_data += 16;
+    out_data += 16;
+  }
+#endif // defined(HAS_HELIUM_SIMD)
+
+  for (; i < numel; i++) {
+    *out_data =
+        quantize_val<int8_t, float>(inv_scale, zp, *input_data, qmin, qmax);
+    input_data++;
+    out_data++;
   }
-#endif
 
   return out;
 }
diff --git a/examples/arm/aot_arm_compiler.py b/examples/arm/aot_arm_compiler.py
index 3d6acf2b94a..73fa4b24d4e 100644
--- a/examples/arm/aot_arm_compiler.py
+++ b/examples/arm/aot_arm_compiler.py
@@ -40,6 +40,11 @@
 )
 from executorch.backends.arm.vgf_partitioner import VgfPartitioner
 
+
+# To use the Cortex-M backend
+from executorch.backends.cortex_m.passes.replace_quant_nodes_pass import (
+    ReplaceQuantNodesPass,
+)
 from executorch.devtools.backend_debug import get_delegation_info
 from executorch.devtools.bundled_program.config import MethodTestCase, MethodTestSuite
 
@@ -59,6 +64,7 @@
 from ..models import MODEL_NAME_TO_MODEL
 from ..models.model_factory import EagerModelFactory
 
+
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.WARNING, format=FORMAT)
 
@@ -216,6 +222,54 @@ def forward(self, x, y):
     can_delegate = True
 
 
+class QuantAddTest(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, a):
+        return a + a
+
+    example_input = (torch.rand([13, 3], dtype=torch.float32),)  # a - normal values
+    can_delegate = True  # when quantized
+
+
+class QuantAddTest2(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, a, b):
+        p = a + a
+        q = b + b
+        r = p + q
+        return p, q, r
+
+    example_input = (
+        torch.randn([13, 7, 3], dtype=torch.float32),
+        torch.randn([13, 7, 3], dtype=torch.float32),
+    )
+    can_delegate = True  # when quantized
+
+
+class QuantOpTest(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, w, x, y, z):
+        o1 = w - x
+        o2 = o1 + y
+        o3 = o2 * z
+        return o1, o2, o3
+
+    example_input = (
+        torch.randn([3, 1, 2], dtype=torch.float32),  # w - normal values
+        torch.randn([3, 5, 2], dtype=torch.float32),  # x - normal values
+        torch.randn([3, 5, 1], dtype=torch.float32)
+        * -0.000001,  # y - small negative values, needs calibration in tests
+        torch.randn([3, 5, 2], dtype=torch.float32) * 1000,  # z - large values
+    )
+    can_delegate = True  # when quantized
+
+
 class SoftmaxModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -241,6 +295,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
     "add": AddModule,
     "add2": AddModule2,
     "add3": AddModule3,
+    "qadd": QuantAddTest,
+    "qadd2": QuantAddTest2,
+    "qops": QuantOpTest,
     "softmax": SoftmaxModule,
     "MultipleOutputsModule": MultipleOutputsModule,
 }
@@ -255,6 +312,17 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
         torch.randn(32, 5),
         torch.randn(32, 5),
     ),
+    "qadd": (torch.randn(32, 2, 1),),
+    "qadd2": (
+        torch.randn(32, 2, 1),
+        torch.randn(32, 2, 1),
+    ),
+    "qops": (
+        torch.randn(32, 2, 1),
+        torch.randn(32, 2, 1),
+        torch.randn(32, 2, 1) * -0.000001,
+        torch.randn(32, 2, 1) * 1000,
+    ),
     "softmax": (torch.randn(32, 2, 2),),
 }
 
@@ -656,6 +724,7 @@ def to_edge_TOSA_delegate(
             _check_ir_validity=False,
         ),
     )
+
     return model_int8, edge
 
 
@@ -681,9 +750,18 @@ def to_edge_no_delegate(exported_program, args, model: torch.nn.Module, example_
             _check_ir_validity=False,
         ),
     )
+
     return model_int8, edge
 
 
+def transform_for_cortex_m_backend(edge):
+    # Make sure the optimized Cortex-M backend ops are used.
+    # NB: if the ops that are expected to be replaced cannot be found and replaced,
+    # bad things will happen at runtime, like "missing operator" errors!
+    edge = edge.transform([ReplaceQuantNodesPass()])
+    return edge
+
+
 if __name__ == "__main__":  # noqa: C901
     args = get_args()
 
@@ -715,6 +793,9 @@ def to_edge_no_delegate(exported_program, args, model: torch.nn.Module, example_
             exported_program, args, model, example_inputs
         )
 
+    # Transform so we can use ops from the Cortex-M backend
+    edge = transform_for_cortex_m_backend(edge)
+
     dump_delegation_info(edge, args.intermediates)
 
     try:
@@ -759,7 +840,9 @@ def to_edge_no_delegate(exported_program, args, model: torch.nn.Module, example_
             output_name = os.path.join(args.output, output_name)
 
         if args.bundleio:
-            save_bpte_program(exec_prog, original_model, output_name)
+            # Realize the quantization impact on numerics when generating the reference output
+            reference_model = original_model if not model_int8 else model_int8
+            save_bpte_program(exec_prog, reference_model, output_name)
             print(f"Bundle PTE file saved as {output_name}")
         else:
             save_pte_program(exec_prog, output_name)
diff --git a/examples/arm/executor_runner/CMakeLists.txt b/examples/arm/executor_runner/CMakeLists.txt
index 63cdcc45aad..1568bef0301 100644
--- a/examples/arm/executor_runner/CMakeLists.txt
+++ b/examples/arm/executor_runner/CMakeLists.txt
@@ -492,7 +492,6 @@ set_property(
   PROPERTY IMPORTED_LOCATION
            "${ET_BUILD_DIR_PATH}/kernels/portable/libportable_kernels.a"
 )
-
 add_library(quantized_ops_lib STATIC IMPORTED)
 set_property(
   TARGET quantized_ops_lib
@@ -505,7 +504,18 @@ set_property(
   PROPERTY IMPORTED_LOCATION
            "${ET_BUILD_DIR_PATH}/kernels/quantized/libquantized_kernels.a"
 )
-
+add_library(cortex_m_ops_lib STATIC IMPORTED)
+set_property(
+  TARGET cortex_m_ops_lib
+  PROPERTY IMPORTED_LOCATION
+           "${ET_BUILD_DIR_PATH}/backends/cortex_m/libcortex_m_ops_lib.a"
+)
+add_library(cortex_m_kernels STATIC IMPORTED)
+set_property(
+  TARGET cortex_m_kernels
+  PROPERTY IMPORTED_LOCATION
+           "${ET_BUILD_DIR_PATH}/backends/cortex_m/libcortex_m_kernels.a"
+)
 add_library(extension_runner_util STATIC IMPORTED)
 set_property(
   TARGET extension_runner_util
@@ -546,9 +556,11 @@ list(APPEND arm_executor_runner_link
   executorch
   "-Wl,--whole-archive"
   executorch_delegate_ethos_u
+  cortex_m_ops_lib
   quantized_ops_lib
   portable_ops_lib
   quantized_kernels
+  cortex_m_kernels
   portable_kernels
   "-Wl,--no-whole-archive"
   -Xlinker -Map=arm_executor_runner.map
@@ -561,7 +573,7 @@ if(EXECUTORCH_ENABLE_EVENT_TRACER)
   set_property(
     TARGET etdump
     PROPERTY IMPORTED_LOCATION
-              "${ET_BUILD_DIR_PATH}/lib/libetdump.a"
+             "${ET_BUILD_DIR_PATH}/lib/libetdump.a"
   )
 
   if(CMAKE_BUILD_TYPE MATCHES "Debug")
@@ -574,7 +586,7 @@ if(EXECUTORCH_ENABLE_EVENT_TRACER)
   set_property(
     TARGET ${FLATCCRT_LIB}
     PROPERTY IMPORTED_LOCATION
-              "${ET_BUILD_DIR_PATH}/lib/lib${FLATCCRT_LIB}.a"
+             "${ET_BUILD_DIR_PATH}/lib/lib${FLATCCRT_LIB}.a"
   )
 
   list(APPEND arm_executor_runner_link
@@ -643,4 +655,4 @@ if(SEMIHOSTING)
     ${ETHOS_SDK_PATH}/core_platform/targets/${TARGET_BOARD}/retarget.c
     PROPERTIES HEADER_FILE_ONLY TRUE
   )
-endif()
\ No newline at end of file
+endif()
diff --git a/examples/arm/run.sh b/examples/arm/run.sh
index 01699087443..ed1cbc5e015 100755
--- a/examples/arm/run.sh
+++ b/examples/arm/run.sh
@@ -177,8 +177,24 @@ backends/arm/scripts/build_portable_kernels.sh --et_build_root="${et_build_root}
 
 if [[ -z "$model_name" ]]; then
     # the test models run, and whether to delegate
-    test_model=( "softmax" "add" "add3" "mv2" )
-    model_compiler_flags=( "" "--delegate" "--delegate" "--delegate --quantize" )
+    test_model=(
+        "softmax"  # 0
+        "add"      # 1
+        "add3"     # 2
+        "qadd"     # 3
+        "qadd2"    # 4
+        "qops"     # 5
+        "mv2"      # 6
+    )
+    model_compiler_flags=(
+        ""                       # 0 softmax
+        "--delegate"             # 1 add
+        "--delegate"             # 2 add3
+        "--delegate --quantize"  # 3 qadd
+        "--delegate --quantize"  # 4 qadd2
+        "--delegate --quantize"  # 5 qops
+        "--delegate --quantize"  # 6 mv2
+    )
 else
     test_model=( "$model_name" )
     model_compiler_flags=( "$aot_arm_compiler_flag_delegate $aot_arm_compiler_flag_quantize $aot_arm_compiler_flags" )
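A brief aside on the rounding trick used in the quantize MVE loop above: the scalar sketch below shows the same "magic number" round-to-nearest-even conversion on a single float. It is a minimal illustration assuming only standard C++; the helper name magic_round is hypothetical and does not appear in the patch.

#include <cstdint>
#include <cstring>

// Round-to-nearest-even via the "magic number" trick, valid for |f| < 2^22:
// adding 1.5 * 2^23 pushes the rounded integer into the low mantissa bits, so
// reinterpreting the sum as int32 and subtracting the bit pattern of the
// constant (0x4B400000) recovers nearbyint(f) without a float-to-int
// instruction.
inline int32_t magic_round(float f) {
  const float magic_float = 12582912.0f;  // 1.5 * 2^23
  const int32_t magic_int = 1262485504;   // 0x4B400000 == bits of magic_float
  const float biased = f + magic_float;
  int32_t bits;
  std::memcpy(&bits, &biased, sizeof(bits));
  return bits - magic_int;
}

In the vectorized loop, vfmaq_f32(vmagic_float, in, inv_scale_vec) computes in * inv_scale + magic_float in one instruction, and subtracting vmagic_int_less_zp (= magic_int - zp) folds the zero-point addition into the same subtraction before narrowing and clamping to int8.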