Commit f7c906f
Cortex-M: Use q/dq ops in Arm Ethos Runner (#10782)
1 parent 54a14d9

9 files changed, +360 -46 lines changed

9 files changed

+360
-46
lines changed

CMakeLists.txt

Lines changed: 6 additions & 0 deletions

@@ -242,6 +242,8 @@ option(EXECUTORCH_USE_DL "Use libdl library" ON)
 
 option(EXECUTORCH_BUILD_CADENCE "Build the Cadence DSP backend" OFF)
 
+option(EXECUTORCH_BUILD_CORTEX_M "Build the Cortex-M backend" OFF)
+
 #
 # pthreadpool: build pthreadpool library. Disable on unsupported platforms
 #
@@ -715,6 +717,10 @@ if(EXECUTORCH_BUILD_XNNPACK)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/xnnpack)
 endif()
 
+if(EXECUTORCH_BUILD_CORTEX_M)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/cortex_m)
+endif()
+
 if(EXECUTORCH_BUILD_DEVTOOLS)
   if(NOT EXECUTORCH_BUILD_ARM_BAREMETAL)
     set(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER

backends/arm/scripts/build_executorch.sh

Lines changed: 1 addition & 0 deletions

@@ -129,6 +129,7 @@ cmake \
     -DEXECUTORCH_BUILD_ARM_BAREMETAL=ON \
     -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
     -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \
+    -DEXECUTORCH_BUILD_CORTEX_M=ON \
     -DEXECUTORCH_ENABLE_LOGGING=ON \
     ${build_devtools_flags} \
     ${build_with_etdump_flags} \

backends/arm/test/test_arm_baremetal.sh

Lines changed: 7 additions & 0 deletions

@@ -154,6 +154,13 @@ test_run_ethosu_fvp() { # End to End model tests using run.sh
     echo "${TEST_SUITE_NAME}: Test ethos-u target Ethos-U85"
     examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u85-128 --model_name=add
     examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u85-128 --model_name=mul
+
+    # Cortex-M op tests
+    examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u55-128 --model_name=qadd --bundleio
+    examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u55-128 --model_name=qops --bundleio
+    examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u55-128 --model_name=qops --bundleio --no_delegate --portable_kernels="aten::sub.out,aten::add.out,aten::mul.out"
+    examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u85-128 --model_name=qops --bundleio
+
     echo "${TEST_SUITE_NAME}: PASS"
 }
 

backends/cortex_m/CMakeLists.txt

Lines changed: 61 additions & 0 deletions

@@ -0,0 +1,61 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Kernel library for Cortex-M operators. Please keep this file formatted by running:
+# ~~~
+# cmake-format -i CMakeLists.txt
+# ~~~
+cmake_minimum_required(VERSION 3.19)
+
+set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
+if(NOT CMAKE_CXX_STANDARD)
+  set(CMAKE_CXX_STANDARD 17)
+endif()
+
+# Source root directory for executorch.
+if(NOT EXECUTORCH_ROOT)
+  set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..)
+endif()
+
+include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
+include(${EXECUTORCH_ROOT}/tools/cmake/Codegen.cmake)
+
+if(NOT PYTHON_EXECUTABLE)
+  resolve_python_executable()
+endif()
+
+# Cortex-M ops kernel sources
+set(_cortex_m_kernels__srcs
+    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantize_per_tensor.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_dequantize_per_tensor.cpp
+)
+
+# Generate C++ bindings to register kernels into Executorch (for runtime).
+# Here select all ops in operators.yaml
+set(_yaml_file ${CMAKE_CURRENT_LIST_DIR}/ops/operators.yaml)
+gen_selected_ops(LIB_NAME "cortex_m_ops_lib" OPS_SCHEMA_YAML "${_yaml_file}")
+
+# Generate bindings for the kernels
+generate_bindings_for_kernels(
+  LIB_NAME "cortex_m_ops_lib" CUSTOM_OPS_YAML "${_yaml_file}"
+)
+message("Generated files ${gen_command_sources}")
+
+# Build a library for _cortex_m_kernels_srcs
+add_library(cortex_m_kernels ${_cortex_m_kernels__srcs})
+target_link_libraries(cortex_m_kernels PRIVATE executorch)
+target_compile_options(cortex_m_kernels PUBLIC ${_common_compile_options})
+
+# cortex_m_ops_lib: Register Cortex-M ops kernels into Executorch runtime
+gen_operators_lib(
+  LIB_NAME "cortex_m_ops_lib" KERNEL_LIBS cortex_m_kernels DEPS executorch
+)
+
+install(
+  TARGETS cortex_m_kernels cortex_m_ops_lib
+  DESTINATION lib
+  PUBLIC_HEADER DESTINATION include/executorch/backends/cortex_m/ops/
+)

backends/cortex_m/ops/op_dequantize_per_tensor.cpp

Lines changed: 70 additions & 23 deletions

@@ -29,6 +29,7 @@ namespace {
  */
 void check_dequantize_args(
     const Tensor& input,
+    int64_t zero_point,
     int64_t quant_min,
     int64_t quant_max,
     ScalarType dtype,
@@ -39,6 +40,18 @@ void check_dequantize_args(
       "input.scalar_type() %" PRId8 " is not char type",
       static_cast<int8_t>(input.scalar_type()));
 
+  // Check zp range
+  ET_CHECK_MSG(
+      zero_point >= quant_min,
+      "zero_point %" PRId64 " must be >= quant_min %" PRId64,
+      zero_point,
+      quant_min);
+  ET_CHECK_MSG(
+      zero_point <= quant_max,
+      "zero_point %" PRId64 " must be <= quant_max %" PRId64,
+      zero_point,
+      quant_max);
+
   // Check output dtype is float
   ET_CHECK_MSG(
       out.scalar_type() == ScalarType::Float,
@@ -73,18 +86,10 @@
 /**
  * Scalar implementation of quantization for a single value.
  */
-template <typename K, typename T>
-T dequantize_val(
-    float scale,
-    int32_t zero_point,
-    K value,
-    int64_t quant_min,
-    int64_t quant_max) {
-  (void)quant_min;
-  (void)quant_max;
-  return static_cast<T>((static_cast<int32_t>(value) - zero_point) * scale);
+template <typename Q, typename F>
+F dequantize_val(float scale, int32_t zero_point, Q qvalue) {
+  return static_cast<F>((static_cast<int32_t>(qvalue) - zero_point) * scale);
 }
-
 } // namespace
 
 Tensor& dequantize_per_tensor_out(
@@ -106,29 +111,71 @@ Tensor& dequantize_per_tensor_out(
       "Failed to resize out Tensor in dequantize_per_tensor_out");
 
   // Validate input parameters
-  check_dequantize_args(input, quant_min, quant_max, dtype, out);
+  check_dequantize_args(input, zero_point, quant_min, quant_max, dtype, out);
 
-  // Pre-compute inverse scale for better performance
   int32_t zp = static_cast<int32_t>(zero_point);
-  int32_t qmin = static_cast<int32_t>(quant_min);
-  int32_t qmax = static_cast<int32_t>(quant_max);
 
   // Get pointers to input and output data
   const int8_t* input_data = input.const_data_ptr<int8_t>();
   float* out_data = out.mutable_data_ptr<float>();
   const size_t numel = input.numel();
 
+  size_t i = 0;
 #if defined(HAS_HELIUM_SIMD)
-  // Helium MVE implementation for float32 to int8 quantization
-  #Error "Implement MVE version!"
-#else
-  // Scalar implementation for float32 to int8 quantization
-  for (size_t i = 0; i < numel; i++) {
-    out_data[i] =
-        dequantize_val<int8_t, float>(scale, zp, input_data[i], qmin, qmax);
+  // Helium MVE implementation for int8 to float dequantization
+  static uint8x16_t voffset{
+      0x0,
+      0x8,
+      0x4,
+      0xC,
+      0x1,
+      0x9,
+      0x5,
+      0xD,
+      0x2,
+      0xA,
+      0x6,
+      0xE,
+      0x3,
+      0xB,
+      0x7,
+      0xF};
+
+  int16x8_t vzp = vdupq_n_s16(static_cast<int16_t>(zp));
+  float32x4_t vscale = vdupq_n_f32(static_cast<float>(scale));
+
+  for (; i + 15 < numel; i += 16) {
+    int8x16_t in_084C195D2A6E3B7F =
+        vldrbq_gather_offset_s8(input_data, voffset);
+
+    int16x8_t in_04152637 = vsubq_s16(vmovlbq_s8(in_084C195D2A6E3B7F), vzp);
+    int16x8_t in_8C9DAEBF = vsubq_s16(vmovltq_s8(in_084C195D2A6E3B7F), vzp);
+
+    float32x4_t inf_0123 = vcvtq_f32_s32(vmovlbq_s16(in_04152637));
+    float32x4_t inf_4567 = vcvtq_f32_s32(vmovltq_s16(in_04152637));
+    float32x4_t inf_89AB = vcvtq_f32_s32(vmovlbq_s16(in_8C9DAEBF));
+    float32x4_t inf_CDEF = vcvtq_f32_s32(vmovltq_s16(in_8C9DAEBF));
+
+    float32x4_t out_0123 = vmulq_f32(inf_0123, vscale);
+    float32x4_t out_4567 = vmulq_f32(inf_4567, vscale);
+    float32x4_t out_89AB = vmulq_f32(inf_89AB, vscale);
+    float32x4_t out_CDEF = vmulq_f32(inf_CDEF, vscale);
+
+    vstrwq_f32(out_data + 0, out_0123);
+    vstrwq_f32(out_data + 4, out_4567);
+    vstrwq_f32(out_data + 8, out_89AB);
+    vstrwq_f32(out_data + 12, out_CDEF);
+
+    input_data += 16;
+    out_data += 16;
   }
-#endif
+#endif // defined(HAS_HELIUM_SIMD)
 
+  for (; i < numel; i++) {
+    *out_data = dequantize_val<int8_t, float>(scale, zp, *input_data);
+    input_data++;
+    out_data++;
+  }
   return out;
 }
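Note on the scalar paths: dequantize_per_tensor_out maps x = (q - zero_point) * scale, and quantize_per_tensor_out (next file) applies the inverse q = clamp(zero_point + round(x / scale), quant_min, quant_max). The lane-order suffixes in the MVE path encode element order: gathering with offsets {0x0, 0x8, 0x4, 0xC, ...} places the input bytes so that the even lanes of in_084C195D2A6E3B7F widen (vmovlbq) to elements 0,4,1,5,2,6,3,7 and the odd lanes (vmovltq) to 8,C,9,D,A,E,B,F, hence the names in_04152637 and in_8C9DAEBF. For reference, a minimal standalone sketch of the scalar round trip, separate from the kernels above (the helper names and the hard-coded int8 range are illustrative; the real ops take quant_min/quant_max and dtype as arguments):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Illustrative int8 affine quantize/dequantize round trip.
int8_t quantize(float x, float scale, int32_t zero_point) {
  int32_t q = zero_point + static_cast<int32_t>(std::nearbyint(x / scale));
  q = std::min<int32_t>(std::max<int32_t>(q, -128), 127); // clamp to int8 range
  return static_cast<int8_t>(q);
}

float dequantize(int8_t q, float scale, int32_t zero_point) {
  return (static_cast<int32_t>(q) - zero_point) * scale;
}

int main() {
  const float scale = 0.05f;
  const int32_t zero_point = -10;
  for (float x : {-1.0f, 0.0f, 0.337f, 1.5f}) {
    int8_t q = quantize(x, scale, zero_point);
    std::printf("x=%7.3f -> q=%4d -> x'=%7.3f\n",
                x, q, dequantize(q, scale, zero_point));
  }
  return 0;
}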

backends/cortex_m/ops/op_quantize_per_tensor.cpp

Lines changed: 96 additions & 15 deletions

@@ -41,13 +41,13 @@ void check_quantize_args(
       "input.scalar_type() %" PRId8 " is not float type",
       static_cast<int8_t>(input.scalar_type()));
 
-  // Check output dtype is int8 (Char)
+  // Check output dtype is int8
   ET_CHECK_MSG(
       out.scalar_type() == ScalarType::Char,
       "out.scalar_type() %" PRId8 " is not int8 (Char)",
       static_cast<int8_t>(out.scalar_type()));
 
-  // Check dtype is int8 (Char)
+  // Check dtype is int8
   ET_CHECK_MSG(
       dtype == ScalarType::Char,
       "dtype %" PRId8 " is not int8 (Char)",
@@ -75,18 +75,18 @@ void check_quantize_args(
 /**
  * Scalar implementation of quantization for a single value.
  */
-template <typename T, typename K>
-T quantize_val(
-    float inv_scale,
+template <typename Q, typename F>
+Q quantize_val(
+    F inv_scale,
     int32_t zero_point,
-    K value,
+    F value,
     int64_t quant_min,
     int64_t quant_max) {
   int32_t qvalue =
       zero_point + static_cast<int32_t>(std::nearbyint(inv_scale * value));
   qvalue = std::max<int32_t>(qvalue, static_cast<int32_t>(quant_min));
   qvalue = std::min<int32_t>(qvalue, static_cast<int32_t>(quant_max));
-  return static_cast<T>(qvalue);
+  return static_cast<Q>(qvalue);
 }
 
 } // namespace
@@ -123,16 +123,97 @@
   int8_t* out_data = out.mutable_data_ptr<int8_t>();
   const size_t numel = input.numel();
 
+  size_t i = 0;
+
 #if defined(HAS_HELIUM_SIMD)
-  // Helium MVE implementation for float32 to int8 quantization
-  #Error "Implement MVE version!"
-#else
-  // Scalar implementation for float32 to int8 quantization
-  for (size_t i = 0; i < numel; i++) {
-    out_data[i] =
-        quantize_val<int8_t, float>(inv_scale, zp, input_data[i], qmin, qmax);
+  // Helium MVE implementation for float32 to int8 quantization
+  static uint8x16_t voffset{
+      0x0,
+      0x8,
+      0x4,
+      0xC,
+      0x1,
+      0x9,
+      0x5,
+      0xD,
+      0x2,
+      0xA,
+      0x6,
+      0xE,
+      0x3,
+      0xB,
+      0x7,
+      0xF};
+
+  float32x4_t inv_scale_vec = vdupq_n_f32(inv_scale);
+
+  // Magic number for float to int conversion, round to nearest even integer
+  // int magic_round(float f): interpret_as_int32(f + magic_float) - magic_int
+  // where,
+  // magic_float = 12582912.0f = (2 ** 23 + 2 ** 22) = (1.5 * 2 ** 23)
+  // magic_int = 1262485504 = 0x4B400000 = bit_pattern_as_int32(magic_float)
+
+  float magic_float = 12582912.0f;
+  int32_t magic_int = 1262485504;
+
+  float32x4_t vmagic_float = vdupq_n_f32(magic_float);
+  int32x4_t vmagic_int_less_zp =
+      vdupq_n_s32(magic_int - static_cast<int32_t>(zp));
+
+  int16x8_t vqmin = vdupq_n_s16(qmin);
+  int16x8_t vqmax = vdupq_n_s16(qmax);
+
+  // TODO: Measure performance, we are spilling
+  for (; i + 15 < numel; i += 16) {
+    float32x4_t in_0123 = vldrwq_f32(input_data + 0);
+    float32x4_t in_4567 = vldrwq_f32(input_data + 4);
+    float32x4_t in_89AB = vldrwq_f32(input_data + 8);
+    float32x4_t in_CDEF = vldrwq_f32(input_data + 12);
+
+    float32x4_t outf_0123 = vfmaq_f32(vmagic_float, in_0123, inv_scale_vec);
+    float32x4_t outf_4567 = vfmaq_f32(vmagic_float, in_4567, inv_scale_vec);
+    float32x4_t outf_89AB = vfmaq_f32(vmagic_float, in_89AB, inv_scale_vec);
+    float32x4_t outf_CDEF = vfmaq_f32(vmagic_float, in_CDEF, inv_scale_vec);
+
+    int32x4_t out_0123 =
+        vsubq_s32(vreinterpretq_s32_f32(outf_0123), vmagic_int_less_zp);
+    int32x4_t out_4567 =
+        vsubq_s32(vreinterpretq_s32_f32(outf_4567), vmagic_int_less_zp);
+    int32x4_t out_89AB =
+        vsubq_s32(vreinterpretq_s32_f32(outf_89AB), vmagic_int_less_zp);
+    int32x4_t out_CDEF =
+        vsubq_s32(vreinterpretq_s32_f32(outf_CDEF), vmagic_int_less_zp);
+
+    int16x8_t out_04152637;
+    int16x8_t out_8C9DAEBF;
+    out_04152637 = vmovnbq_s32(out_04152637, out_0123);
+    out_04152637 = vmovntq_s32(out_04152637, out_4567);
+    out_8C9DAEBF = vmovnbq_s32(out_8C9DAEBF, out_89AB);
+    out_8C9DAEBF = vmovntq_s32(out_8C9DAEBF, out_CDEF);
+
+    int16x8_t out_04152637_clamped =
+        vminq_s16(vmaxq_s16(out_04152637, vqmin), vqmax);
+    int16x8_t out_8C9DAEBF_clamped =
+        vminq_s16(vmaxq_s16(out_8C9DAEBF, vqmin), vqmax);
+
+    int8x16_t out_084C195D2A6E3B7F;
+    out_084C195D2A6E3B7F =
+        vmovnbq_s16(out_084C195D2A6E3B7F, out_04152637_clamped);
+    out_084C195D2A6E3B7F =
+        vmovntq_s16(out_084C195D2A6E3B7F, out_8C9DAEBF_clamped);
+
+    vstrbq_scatter_offset_s8(out_data, voffset, out_084C195D2A6E3B7F);
+    input_data += 16;
+    out_data += 16;
+  }
+#endif // defined(HAS_HELIUM_SIMD)
+
+  for (; i < numel; i++) {
+    *out_data =
+        quantize_val<int8_t, float>(inv_scale, zp, *input_data, qmin, qmax);
+    input_data++;
+    out_data++;
   }
-#endif
 
   return out;
 }
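The magic-number comment above describes a standard float rounding trick: for |f| below roughly 2^22, adding 1.5 * 2^23 puts the sum in a range where the float spacing is exactly 1.0, so the FPU addition itself rounds f to the nearest integer (ties to even under the default rounding mode), and subtracting the constant's bit pattern from the raw bits of the sum recovers that integer. The kernel folds the zero point in by subtracting (magic_int - zp) instead of magic_int. A scalar sketch of the same idea, with an illustrative helper name and assuming IEEE-754 float and round-to-nearest-even mode:

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Round-to-nearest-even float -> int32 via the magic-number bit trick.
int32_t magic_round(float f) {
  const float kMagicFloat = 12582912.0f; // 1.5 * 2^23
  const int32_t kMagicInt = 0x4B400000;  // bit pattern of kMagicFloat
  float shifted = f + kMagicFloat;       // addition rounds f to integer granularity
  int32_t bits;
  std::memcpy(&bits, &shifted, sizeof(bits)); // interpret_as_int32
  return bits - kMagicInt;
}

int main() {
  for (float f : {-3.7f, -0.5f, 0.5f, 1.5f, 2.49f, 1000.51f}) {
    std::printf("f=%9.2f  magic_round=%5d  nearbyint=%5d\n",
                f, magic_round(f),
                static_cast<int32_t>(std::nearbyint(f)));
  }
  return 0;
}

Both columns should agree, which is why the vector path can use vfmaq_f32 followed by a reinterpret and subtract instead of an explicit rounding instruction.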
