Cortex-M: Use q/dq ops in Arm Ethos Runner #10782

Merged: 8 commits, May 9, 2025
Changes from all commits
6 changes: 6 additions & 0 deletions CMakeLists.txt
@@ -242,6 +242,8 @@ option(EXECUTORCH_USE_DL "Use libdl library" ON)

option(EXECUTORCH_BUILD_CADENCE "Build the Cadence DSP backend" OFF)

option(EXECUTORCH_BUILD_CORTEX_M "Build the Cortex-M backend" OFF)

#
# pthreadpool: build pthreadpool library. Disable on unsupported platforms
#
@@ -715,6 +717,10 @@ if(EXECUTORCH_BUILD_XNNPACK)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/xnnpack)
endif()

if(EXECUTORCH_BUILD_CORTEX_M)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/cortex_m)
endif()

if(EXECUTORCH_BUILD_DEVTOOLS)
if(NOT EXECUTORCH_BUILD_ARM_BAREMETAL)
set(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER
1 change: 1 addition & 0 deletions backends/arm/scripts/build_executorch.sh
@@ -129,6 +129,7 @@ cmake \
-DEXECUTORCH_BUILD_ARM_BAREMETAL=ON \
-DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
-DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \
-DEXECUTORCH_BUILD_CORTEX_M=ON \
-DEXECUTORCH_ENABLE_LOGGING=ON \
${build_devtools_flags} \
${build_with_etdump_flags} \
7 changes: 7 additions & 0 deletions backends/arm/test/test_arm_baremetal.sh
@@ -154,6 +154,13 @@ test_run_ethosu_fvp() { # End to End model tests using run.sh
echo "${TEST_SUITE_NAME}: Test ethos-u target Ethos-U85"
examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u85-128 --model_name=add
examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u85-128 --model_name=mul

# Cortex-M op tests
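# qadd and qops are quantized example models; the --no_delegate run executes
# qops without the Ethos-U delegate, using the listed portable kernels for the
# float arithmetic so that the Cortex-M q/dq kernels are exercised on the CPU.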
examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u55-128 --model_name=qadd --bundleio
examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u55-128 --model_name=qops --bundleio
examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u55-128 --model_name=qops --bundleio --no_delegate --portable_kernels="aten::sub.out,aten::add.out,aten::mul.out"
examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u85-128 --model_name=qops --bundleio

echo "${TEST_SUITE_NAME}: PASS"
}

61 changes: 61 additions & 0 deletions backends/cortex_m/CMakeLists.txt
@@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Kernel library for Cortex-M operators. Please keep this file formatted by running:
# ~~~
# cmake-format -i CMakeLists.txt
# ~~~
cmake_minimum_required(VERSION 3.19)

set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
if(NOT CMAKE_CXX_STANDARD)
set(CMAKE_CXX_STANDARD 17)
endif()

# Source root directory for executorch.
if(NOT EXECUTORCH_ROOT)
set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..)
endif()

include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
include(${EXECUTORCH_ROOT}/tools/cmake/Codegen.cmake)

if(NOT PYTHON_EXECUTABLE)
resolve_python_executable()
endif()

# Cortex-M ops kernel sources
set(_cortex_m_kernels__srcs
${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantize_per_tensor.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops/op_dequantize_per_tensor.cpp
)
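# Each op source provides a scalar implementation plus a Helium (MVE) path
# guarded by HAS_HELIUM_SIMD.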

# Generate C++ bindings to register the kernels into the ExecuTorch runtime.
# All ops listed in operators.yaml are selected.
set(_yaml_file ${CMAKE_CURRENT_LIST_DIR}/ops/operators.yaml)
gen_selected_ops(LIB_NAME "cortex_m_ops_lib" OPS_SCHEMA_YAML "${_yaml_file}")

# Generate bindings for the kernels
generate_bindings_for_kernels(
LIB_NAME "cortex_m_ops_lib" CUSTOM_OPS_YAML "${_yaml_file}"
)
message("Generated files ${gen_command_sources}")

# Build a library from the sources in _cortex_m_kernels__srcs
add_library(cortex_m_kernels ${_cortex_m_kernels__srcs})
target_link_libraries(cortex_m_kernels PRIVATE executorch)
target_compile_options(cortex_m_kernels PUBLIC ${_common_compile_options})

# cortex_m_ops_lib: registers the Cortex-M kernels with the ExecuTorch runtime
gen_operators_lib(
LIB_NAME "cortex_m_ops_lib" KERNEL_LIBS cortex_m_kernels DEPS executorch
)

install(
TARGETS cortex_m_kernels cortex_m_ops_lib
DESTINATION lib
PUBLIC_HEADER DESTINATION include/executorch/backends/cortex_m/ops/
)
93 changes: 70 additions & 23 deletions backends/cortex_m/ops/op_dequantize_per_tensor.cpp
@@ -29,6 +29,7 @@ namespace {
*/
void check_dequantize_args(
const Tensor& input,
int64_t zero_point,
int64_t quant_min,
int64_t quant_max,
ScalarType dtype,
@@ -39,6 +40,18 @@ void check_dequantize_args(
"input.scalar_type() %" PRId8 " is not char type",
static_cast<int8_t>(input.scalar_type()));

// Check that the zero point lies within the quantized range
ET_CHECK_MSG(
zero_point >= quant_min,
"zero_point %" PRId64 " must be >= quant_min %" PRId64,
zero_point,
quant_min);
ET_CHECK_MSG(
zero_point <= quant_max,
"zero_point %" PRId64 " must be <= quant_max %" PRId64,
zero_point,
quant_max);

// Check output dtype is float
ET_CHECK_MSG(
out.scalar_type() == ScalarType::Float,
@@ -73,18 +86,10 @@ void check_dequantize_args(
/**
* Scalar implementation of dequantization for a single value.
*/
template <typename K, typename T>
T dequantize_val(
float scale,
int32_t zero_point,
K value,
int64_t quant_min,
int64_t quant_max) {
(void)quant_min;
(void)quant_max;
return static_cast<T>((static_cast<int32_t>(value) - zero_point) * scale);
template <typename Q, typename F>
F dequantize_val(float scale, int32_t zero_point, Q qvalue) {
return static_cast<F>((static_cast<int32_t>(qvalue) - zero_point) * scale);
}
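
// Illustrative example (hypothetical values): with scale = 0.5f and
// zero_point = -5, the quantized value 7 dequantizes to
// (7 - (-5)) * 0.5f = 6.0f, i.e. dequantize_val<int8_t, float>(0.5f, -5, 7)
// returns 6.0f.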

} // namespace

Tensor& dequantize_per_tensor_out(
@@ -106,29 +111,71 @@ Tensor& dequantize_per_tensor_out(
"Failed to resize out Tensor in dequantize_per_tensor_out");

// Validate input parameters
check_dequantize_args(input, quant_min, quant_max, dtype, out);
check_dequantize_args(input, zero_point, quant_min, quant_max, dtype, out);

// Cast quantization parameters once before the loop
int32_t zp = static_cast<int32_t>(zero_point);
int32_t qmin = static_cast<int32_t>(quant_min);
int32_t qmax = static_cast<int32_t>(quant_max);

// Get pointers to input and output data
const int8_t* input_data = input.const_data_ptr<int8_t>();
float* out_data = out.mutable_data_ptr<float>();
const size_t numel = input.numel();

size_t i = 0;
#if defined(HAS_HELIUM_SIMD)
// Helium MVE implementation for float32 to int8 quantization
#Error "Implement MVE version!"
#else
// Scalar implementation for float32 to int8 quantization
for (size_t i = 0; i < numel; i++) {
out_data[i] =
dequantize_val<int8_t, float>(scale, zp, input_data[i], qmin, qmax);
// Helium MVE implementation for int8 to float quantization
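// The gather below loads bytes in the order 0,8,4,C,1,9,5,D,2,A,6,E,3,B,7,F.
// This pre-permutation makes the bottom/top widenings that follow
// (vmovlbq/vmovltq on bytes, then on halfwords) fall out in natural order:
// the four float vectors end up holding elements 0-3, 4-7, 8-11 and 12-15,
// which are then stored contiguously.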
static uint8x16_t voffset{
0x0,
0x8,
0x4,
0xC,
0x1,
0x9,
0x5,
0xD,
0x2,
0xA,
0x6,
0xE,
0x3,
0xB,
0x7,
0xF};

int16x8_t vzp = vdupq_n_s16(static_cast<int16_t>(zp));
float32x4_t vscale = vdupq_n_f32(static_cast<float>(scale));

for (; i + 15 < numel; i += 16) {
int8x16_t in_084C195D2A6E3B7F =
vldrbq_gather_offset_s8(input_data, voffset);

int16x8_t in_04152637 = vsubq_s16(vmovlbq_s8(in_084C195D2A6E3B7F), vzp);
int16x8_t in_8C9DAEBF = vsubq_s16(vmovltq_s8(in_084C195D2A6E3B7F), vzp);

float32x4_t inf_0123 = vcvtq_f32_s32(vmovlbq_s16(in_04152637));
float32x4_t inf_4567 = vcvtq_f32_s32(vmovltq_s16(in_04152637));
float32x4_t inf_89AB = vcvtq_f32_s32(vmovlbq_s16(in_8C9DAEBF));
float32x4_t inf_CDEF = vcvtq_f32_s32(vmovltq_s16(in_8C9DAEBF));

float32x4_t out_0123 = vmulq_f32(inf_0123, vscale);
float32x4_t out_4567 = vmulq_f32(inf_4567, vscale);
float32x4_t out_89AB = vmulq_f32(inf_89AB, vscale);
float32x4_t out_CDEF = vmulq_f32(inf_CDEF, vscale);

vstrwq_f32(out_data + 0, out_0123);
vstrwq_f32(out_data + 4, out_4567);
vstrwq_f32(out_data + 8, out_89AB);
vstrwq_f32(out_data + 12, out_CDEF);

input_data += 16;
out_data += 16;
}
#endif
#endif // defined(HAS_HELIUM_SIMD)

for (; i < numel; i++) {
*out_data = dequantize_val<int8_t, float>(scale, zp, *input_data);
input_data++;
out_data++;
}
return out;
}

111 changes: 96 additions & 15 deletions backends/cortex_m/ops/op_quantize_per_tensor.cpp
@@ -41,13 +41,13 @@ void check_quantize_args(
"input.scalar_type() %" PRId8 " is not float type",
static_cast<int8_t>(input.scalar_type()));

// Check output dtype is int8 (Char)
// Check output dtype is int8
ET_CHECK_MSG(
out.scalar_type() == ScalarType::Char,
"out.scalar_type() %" PRId8 " is not int8 (Char)",
static_cast<int8_t>(out.scalar_type()));

// Check dtype is int8 (Char)
// Check dtype is int8
ET_CHECK_MSG(
dtype == ScalarType::Char,
"dtype %" PRId8 " is not int8 (Char)",
@@ -75,18 +75,18 @@ void check_quantize_args(
/**
* Scalar implementation of quantization for a single value.
*/
template <typename T, typename K>
T quantize_val(
float inv_scale,
template <typename Q, typename F>
Q quantize_val(
F inv_scale,
int32_t zero_point,
K value,
F value,
int64_t quant_min,
int64_t quant_max) {
int32_t qvalue =
zero_point + static_cast<int32_t>(std::nearbyint(inv_scale * value));
qvalue = std::max<int32_t>(qvalue, static_cast<int32_t>(quant_min));
qvalue = std::min<int32_t>(qvalue, static_cast<int32_t>(quant_max));
return static_cast<T>(qvalue);
return static_cast<Q>(qvalue);
}
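
// Illustrative example (hypothetical values): with inv_scale = 2.0f
// (scale = 0.5f), zero_point = -5, quant_min = -128 and quant_max = 127,
// the input 6.0f quantizes to -5 + nearbyint(2.0f * 6.0f) = 7, which is
// already inside [quant_min, quant_max] and is returned unchanged.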

} // namespace
@@ -123,16 +123,97 @@ Tensor& quantize_per_tensor_out(
int8_t* out_data = out.mutable_data_ptr<int8_t>();
const size_t numel = input.numel();

size_t i = 0;

#if defined(HAS_HELIUM_SIMD)
// Helium MVE implementation for float32 to int8 quantization
#Error "Implement MVE version!"
#else
// Scalar implementation for float32 to int8 quantization
for (size_t i = 0; i < numel; i++) {
out_data[i] =
quantize_val<int8_t, float>(inv_scale, zp, input_data[i], qmin, qmax);
// Helium MVE implementation for float32 to int8 quantization
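// The scatter offsets below mirror the narrowing interleave used later:
// vmovnbq/vmovntq on 32-bit and then 16-bit lanes leave the bytes in the
// order 0,8,4,C,1,9,5,D,2,A,6,E,3,B,7,F, so scattering with these same
// offsets writes every element back to its natural position in the output.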
static uint8x16_t voffset{
0x0,
0x8,
0x4,
0xC,
0x1,
0x9,
0x5,
0xD,
0x2,
0xA,
0x6,
0xE,
0x3,
0xB,
0x7,
0xF};

float32x4_t inv_scale_vec = vdupq_n_f32(inv_scale);

// Magic number for float to int conversion, round to nearest even integer
// int magic_round(float f): interpret_as_int32(f + magic_float) - magic_int
// where,
// magic_float = 12582912.0f = (2 ** 23 + 2 ** 22) = (1.5 * 2 ** 23)
// magic_int = 1262485504 = 0x4B400000 = bit_pattern_as_int32(magic_float)

float magic_float = 12582912.0f;
int32_t magic_int = 1262485504;
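
// Worked example (illustrative): for f = 2.3f, f + magic_float = 12582914.3,
// which rounds to the nearest representable float 12582914.0f (the ulp at
// this magnitude is 1.0). Its bit pattern is 0x4B400002; reinterpreting as
// int32 and subtracting magic_int yields 2 = round(2.3). Subtracting
// (magic_int - zp) instead folds the zero-point addition into the same step.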

float32x4_t vmagic_float = vdupq_n_f32(magic_float);
int32x4_t vmagic_int_less_zp =
vdupq_n_s32(magic_int - static_cast<int32_t>(zp));

int16x8_t vqmin = vdupq_n_s16(qmin);
int16x8_t vqmax = vdupq_n_s16(qmax);

// TODO: Measure performance; we are spilling registers here.
for (; i + 15 < numel; i += 16) {
float32x4_t in_0123 = vldrwq_f32(input_data + 0);
float32x4_t in_4567 = vldrwq_f32(input_data + 4);
float32x4_t in_89AB = vldrwq_f32(input_data + 8);
float32x4_t in_CDEF = vldrwq_f32(input_data + 12);

float32x4_t outf_0123 = vfmaq_f32(vmagic_float, in_0123, inv_scale_vec);
float32x4_t outf_4567 = vfmaq_f32(vmagic_float, in_4567, inv_scale_vec);
float32x4_t outf_89AB = vfmaq_f32(vmagic_float, in_89AB, inv_scale_vec);
float32x4_t outf_CDEF = vfmaq_f32(vmagic_float, in_CDEF, inv_scale_vec);

int32x4_t out_0123 =
vsubq_s32(vreinterpretq_s32_f32(outf_0123), vmagic_int_less_zp);
int32x4_t out_4567 =
vsubq_s32(vreinterpretq_s32_f32(outf_4567), vmagic_int_less_zp);
int32x4_t out_89AB =
vsubq_s32(vreinterpretq_s32_f32(outf_89AB), vmagic_int_less_zp);
int32x4_t out_CDEF =
vsubq_s32(vreinterpretq_s32_f32(outf_CDEF), vmagic_int_less_zp);

int16x8_t out_04152637;
int16x8_t out_8C9DAEBF;
out_04152637 = vmovnbq_s32(out_04152637, out_0123);
out_04152637 = vmovntq_s32(out_04152637, out_4567);
out_8C9DAEBF = vmovnbq_s32(out_8C9DAEBF, out_89AB);
out_8C9DAEBF = vmovntq_s32(out_8C9DAEBF, out_CDEF);

int16x8_t out_04152637_clamped =
vminq_s16(vmaxq_s16(out_04152637, vqmin), vqmax);
int16x8_t out_8C9DAEBF_clamped =
vminq_s16(vmaxq_s16(out_8C9DAEBF, vqmin), vqmax);

int8x16_t out_084C195D2A6E3B7F;
out_084C195D2A6E3B7F =
vmovnbq_s16(out_084C195D2A6E3B7F, out_04152637_clamped);
out_084C195D2A6E3B7F =
vmovntq_s16(out_084C195D2A6E3B7F, out_8C9DAEBF_clamped);

vstrbq_scatter_offset_s8(out_data, voffset, out_084C195D2A6E3B7F);
input_data += 16;
out_data += 16;
}
#endif // defined(HAS_HELIUM_SIMD)

for (; i < numel; i++) {
*out_data =
quantize_val<int8_t, float>(inv_scale, zp, *input_data, qmin, qmax);
input_data++;
out_data++;
}
#endif

return out;
}