
Commit ec0cfcc

Update base for Update on "[Executorch][llm] Enable leveraging ring kv cache via module swap"
This allows some of the attention modules to use a sliding-window KV cache, which will help enable models like Gemma3.

Differential Revision: [D73891426](https://our.internmc.facebook.com/intern/diff/D73891426/)

[ghstack-poisoned]
2 parents 7638a4b + bf50527 commit ec0cfcc
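For context on the change described above: a ring (sliding-window) KV cache keeps only the most recent `window` entries and overwrites the oldest slot once the window is full, so cache memory stays bounded regardless of sequence length. A minimal standalone C++ sketch of the indexing idea follows; the struct and field names are illustrative assumptions, not the ExecuTorch implementation.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Illustrative ring buffer over cached KV entries. A real attention cache
// holds [layers, heads, window, head_dim] tensors; a single float stands in
// for one cached key/value here.
struct RingKVCache {
  explicit RingKVCache(size_t window) : keys_(window, 0.f), values_(window, 0.f) {}

  // Write the entry for absolute token position `pos`. Once more than
  // `window` tokens have been seen, the oldest slot is overwritten.
  void update(int64_t pos, float k, float v) {
    const size_t slot = static_cast<size_t>(pos) % keys_.size();
    keys_[slot] = k;
    values_[slot] = v;
  }

  size_t window() const { return keys_.size(); }

 private:
  std::vector<float> keys_;
  std::vector<float> values_;
};

int main() {
  RingKVCache cache(/*window=*/4);
  for (int64_t pos = 0; pos < 10; ++pos) {
    cache.update(pos, /*k=*/static_cast<float>(pos), /*v=*/pos * 10.f);
  }
  // After 10 tokens with a window of 4, only positions 6..9 remain cached.
  std::printf("window=%zu\n", cache.window());
  return 0;
}
```

Swapping an attention module's unbounded cache for a structure like this is what lets sliding-window models such as Gemma3 cap KV memory at the window size.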


51 files changed: +794, -292 lines

.ci/scripts/test_model.sh (-4)

@@ -87,10 +87,6 @@ test_model() {
     bash examples/models/llava/install_requirements.sh
     STRICT="--no-strict"
   fi
-  if [[ "$MODEL_NAME" == "llama3_2_vision_encoder" || "$MODEL_NAME" == "llama3_2_text_decoder" ]]; then
-    # Install requirements for llama vision.
-    bash examples/models/llama3_2_vision/install_requirements.sh
-  fi
   if [[ "${MODEL_NAME}" == "qwen2_5" ]]; then
     # Install requirements for export_llama
     bash examples/models/llama/install_requirements.sh

.ci/scripts/unittest-linux.sh (-3)

@@ -24,9 +24,6 @@ if [[ "$BUILD_TOOL" == "cmake" ]]; then
   CMAKE_ARGS="-DEXECUTORCH_BUILD_PYBIND=ON -DEXECUTORCH_BUILD_XNNPACK=ON -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON" \
     .ci/scripts/setup-linux.sh "$@"
 
-  # Install llama3_2_vision dependencies.
-  PYTHON_EXECUTABLE=python ./examples/models/llama3_2_vision/install_requirements.sh
-
   .ci/scripts/unittest-linux-cmake.sh
 elif [[ "$BUILD_TOOL" == "buck2" ]]; then
   # Removing this breaks sccache in the Buck build, apparently

.ci/scripts/unittest-macos.sh (-1)

@@ -29,7 +29,6 @@ if [[ "$BUILD_TOOL" == "cmake" ]]; then
   # Install llama3_2_vision dependencies.
   PYTHON_EXECUTABLE=python \
     ${CONDA_RUN} --no-capture-output \
-    ./examples/models/llama3_2_vision/install_requirements.sh
 
   .ci/scripts/unittest-macos-cmake.sh
 elif [[ "$BUILD_TOOL" == "buck2" ]]; then

.lintrunner.toml (+7)

@@ -220,6 +220,13 @@ exclude_patterns = [
     'extension/**',
     'kernels/optimized/**',
     # Justified <functional> include.
+    'kernels/portable/cpu/op_bitwise*.cpp',
+    'kernels/portable/cpu/op_eq.cpp',
+    'kernels/portable/cpu/op_ge.cpp',
+    'kernels/portable/cpu/op_gt.cpp',
+    'kernels/portable/cpu/op_le.cpp',
+    'kernels/portable/cpu/op_lt.cpp',
+    'kernels/portable/cpu/op_ne.cpp',
     'runtime/kernel/thread_parallel_interface.h',
     'scripts/**',
     'third-party/**',
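For background on this exclusion (an assumption about intent, inferred from the "Justified <functional> include" comment): the listed comparison and bitwise kernels legitimately pull in <functional>-style function objects so that one elementwise loop can be shared across operators, which the include lint would otherwise flag. A generic illustration of that pattern, not the actual ExecuTorch kernels:

```cpp
#include <cstdio>
#include <functional>
#include <vector>

// One elementwise loop shared by eq/ge/gt/le/lt/ne, parameterized by a
// comparison functor from <functional>.
template <typename T, typename Compare>
std::vector<bool> elementwise_compare(const std::vector<T>& a,
                                      const std::vector<T>& b,
                                      Compare cmp) {
  std::vector<bool> out(a.size());
  for (size_t i = 0; i < a.size(); ++i) {
    out[i] = cmp(a[i], b[i]);
  }
  return out;
}

int main() {
  const std::vector<int> a = {1, 2, 3};
  const std::vector<int> b = {3, 2, 1};
  const auto eq = elementwise_compare(a, b, std::equal_to<int>{});
  const auto le = elementwise_compare(a, b, std::less_equal<int>{});
  std::printf("eq: %d %d %d\n", (int)eq[0], (int)eq[1], (int)eq[2]);  // 0 1 0
  std::printf("le: %d %d %d\n", (int)le[0], (int)le[1], (int)le[2]);  // 1 1 0
  return 0;
}
```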

CMakeLists.txt (+6)

@@ -242,6 +242,8 @@ option(EXECUTORCH_USE_DL "Use libdl library" ON)
 
 option(EXECUTORCH_BUILD_CADENCE "Build the Cadence DSP backend" OFF)
 
+option(EXECUTORCH_BUILD_CORTEX_M "Build the Cortex-M backend" OFF)
+
 #
 # pthreadpool: build pthreadpool library. Disable on unsupported platforms
 #
@@ -715,6 +717,10 @@ if(EXECUTORCH_BUILD_XNNPACK)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/xnnpack)
 endif()
 
+if(EXECUTORCH_BUILD_CORTEX_M)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/cortex_m)
+endif()
+
 if(EXECUTORCH_BUILD_DEVTOOLS)
   if(NOT EXECUTORCH_BUILD_ARM_BAREMETAL)
     set(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER

backends/arm/scripts/build_executorch.sh (+1)

@@ -129,6 +129,7 @@ cmake \
     -DEXECUTORCH_BUILD_ARM_BAREMETAL=ON \
     -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
     -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \
+    -DEXECUTORCH_BUILD_CORTEX_M=ON \
     -DEXECUTORCH_ENABLE_LOGGING=ON \
     ${build_devtools_flags} \
     ${build_with_etdump_flags} \

backends/arm/scripts/pre-push (+7 -5)

@@ -8,11 +8,13 @@
 # non-interactive mode. "$#" gives the number of positional arguments.
 [ "$#" -eq 0 ] && is_script_interactive=1 || is_script_interactive=0
 
-RESET='\e[0m'
-RED='\e[31m'
-GREEN='\e[32m'
-YELLOW='\e[33m'
-BLUE='\e[34m'
+if [ $is_script_interactive -eq 1 ]; then
+    RESET='\e[0m'
+    RED='\e[31m'
+    GREEN='\e[32m'
+    YELLOW='\e[33m'
+    BLUE='\e[34m'
+fi
 
 INFO="${BLUE}[INFO]${RESET}"
 WARNING="${YELLOW}[WARNING]${RESET}"

backends/arm/test/test_arm_baremetal.sh (+7 -4)

@@ -83,8 +83,6 @@ test_pytest_ops() { # Test ops and other things
 test_pytest_models() { # Test ops and other things
     echo "${TEST_SUITE_NAME}: Run pytest"
 
-    examples/models/llama3_2_vision/install_requirements.sh
-
     # Prepare for pytest
     backends/arm/scripts/build_executorch.sh
 
@@ -117,8 +115,6 @@ test_pytest_ops_ethosu_fvp() { # Same as test_pytest but also sometime verify using Corstone FVP
 test_pytest_models_ethosu_fvp() { # Same as test_pytest but also sometime verify using Corstone FVP
     echo "${TEST_SUITE_NAME}: Run pytest with fvp"
 
-    examples/models/llama3_2_vision/install_requirements.sh
-
     # Prepare Corstone-3x0 FVP for pytest
     backends/arm/scripts/build_executorch.sh
     backends/arm/scripts/build_portable_kernels.sh
@@ -154,6 +150,13 @@ test_run_ethosu_fvp() { # End to End model tests using run.sh
     echo "${TEST_SUITE_NAME}: Test ethos-u target Ethos-U85"
     examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u85-128 --model_name=add
     examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u85-128 --model_name=mul
+
+    # Cortex-M op tests
+    examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u55-128 --model_name=qadd --bundleio
+    examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u55-128 --model_name=qops --bundleio
+    examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u55-128 --model_name=qops --bundleio --no_delegate --portable_kernels="aten::sub.out,aten::add.out,aten::mul.out"
+    examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u85-128 --model_name=qops --bundleio
+
     echo "${TEST_SUITE_NAME}: PASS"
 }
 

backends/cortex_m/CMakeLists.txt (new file, +61)

@@ -0,0 +1,61 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Kernel library for Cortex-M operators. Please keep this file formatted by running:
+# ~~~
+# cmake-format -i CMakeLists.txt
+# ~~~
+cmake_minimum_required(VERSION 3.19)
+
+set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
+if(NOT CMAKE_CXX_STANDARD)
+  set(CMAKE_CXX_STANDARD 17)
+endif()
+
+# Source root directory for executorch.
+if(NOT EXECUTORCH_ROOT)
+  set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..)
+endif()
+
+include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
+include(${EXECUTORCH_ROOT}/tools/cmake/Codegen.cmake)
+
+if(NOT PYTHON_EXECUTABLE)
+  resolve_python_executable()
+endif()
+
+# Cortex-M ops kernel sources
+set(_cortex_m_kernels__srcs
+    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantize_per_tensor.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_dequantize_per_tensor.cpp
+)
+
+# Generate C++ bindings to register kernels into Executorch (for runtime).
+# Here select all ops in operators.yaml
+set(_yaml_file ${CMAKE_CURRENT_LIST_DIR}/ops/operators.yaml)
+gen_selected_ops(LIB_NAME "cortex_m_ops_lib" OPS_SCHEMA_YAML "${_yaml_file}")
+
+# Generate bindings for the kernels
+generate_bindings_for_kernels(
+  LIB_NAME "cortex_m_ops_lib" CUSTOM_OPS_YAML "${_yaml_file}"
+)
+message("Generated files ${gen_command_sources}")
+
+# Build a library for _cortex_m_kernels_srcs
+add_library(cortex_m_kernels ${_cortex_m_kernels__srcs})
+target_link_libraries(cortex_m_kernels PRIVATE executorch)
+target_compile_options(cortex_m_kernels PUBLIC ${_common_compile_options})
+
+# cortex_m_ops_lib: Register Cortex-M ops kernels into Executorch runtime
+gen_operators_lib(
+  LIB_NAME "cortex_m_ops_lib" KERNEL_LIBS cortex_m_kernels DEPS executorch
+)
+
+install(
+  TARGETS cortex_m_kernels cortex_m_ops_lib
+  DESTINATION lib
+  PUBLIC_HEADER DESTINATION include/executorch/backends/cortex_m/ops/
+)

backends/cortex_m/ops/op_dequantize_per_tensor.cpp (+70 -23)

@@ -29,6 +29,7 @@ namespace {
  */
 void check_dequantize_args(
     const Tensor& input,
+    int64_t zero_point,
     int64_t quant_min,
     int64_t quant_max,
     ScalarType dtype,
@@ -39,6 +40,18 @@
       "input.scalar_type() %" PRId8 " is not char type",
       static_cast<int8_t>(input.scalar_type()));
 
+  // Check zp range
+  ET_CHECK_MSG(
+      zero_point >= quant_min,
+      "zero_point must be %" PRId64 " <= quant_min %" PRId64,
+      zero_point,
+      quant_min);
+  ET_CHECK_MSG(
+      zero_point <= quant_max,
+      "zero_point must be %" PRId64 " >= quant_max %" PRId64,
+      zero_point,
+      quant_max);
+
   // Check output dtype is float
   ET_CHECK_MSG(
       out.scalar_type() == ScalarType::Float,
@@ -73,18 +86,10 @@
 /**
  * Scalar implementation of quantization for a single value.
  */
-template <typename K, typename T>
-T dequantize_val(
-    float scale,
-    int32_t zero_point,
-    K value,
-    int64_t quant_min,
-    int64_t quant_max) {
-  (void)quant_min;
-  (void)quant_max;
-  return static_cast<T>((static_cast<int32_t>(value) - zero_point) * scale);
+template <typename Q, typename F>
+F dequantize_val(float scale, int32_t zero_point, Q qvalue) {
+  return static_cast<F>((static_cast<int32_t>(qvalue) - zero_point) * scale);
 }
-
 } // namespace
 
 Tensor& dequantize_per_tensor_out(
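The simplified dequantize_val above reduces to the affine mapping real = (qvalue - zero_point) * scale; quant_min and quant_max are no longer needed per element because the zero-point range is now validated once in check_dequantize_args. A small standalone check of that formula, using made-up example values rather than anything from the commit:

```cpp
#include <cstdint>
#include <cstdio>

// Same affine dequantization as the kernel's scalar helper:
// real_value = (quantized_value - zero_point) * scale.
template <typename Q, typename F>
F dequantize_val(float scale, int32_t zero_point, Q qvalue) {
  return static_cast<F>((static_cast<int32_t>(qvalue) - zero_point) * scale);
}

int main() {
  const float scale = 0.05f;       // hypothetical quantization parameters
  const int32_t zero_point = -10;
  const int8_t q[] = {-128, -10, 0, 57, 127};
  for (int8_t v : q) {
    // e.g. q = 57 -> (57 - (-10)) * 0.05 = 3.35
    std::printf("q=%4d -> %.2f\n", v, dequantize_val<int8_t, float>(scale, zero_point, v));
  }
  return 0;
}
```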
@@ -106,29 +111,71 @@
       "Failed to resize out Tensor in dequantize_per_tensor_out");
 
   // Validate input parameters
-  check_dequantize_args(input, quant_min, quant_max, dtype, out);
+  check_dequantize_args(input, zero_point, quant_min, quant_max, dtype, out);
 
-  // Pre-compute inverse scale for better performance
   int32_t zp = static_cast<int32_t>(zero_point);
-  int32_t qmin = static_cast<int32_t>(quant_min);
-  int32_t qmax = static_cast<int32_t>(quant_max);
 
   // Get pointers to input and output data
   const int8_t* input_data = input.const_data_ptr<int8_t>();
   float* out_data = out.mutable_data_ptr<float>();
   const size_t numel = input.numel();
 
+  size_t i = 0;
 #if defined(HAS_HELIUM_SIMD)
-  // Helium MVE implementation for float32 to int8 quantization
-  #Error "Implement MVE version!"
-#else
-  // Scalar implementation for float32 to int8 quantization
-  for (size_t i = 0; i < numel; i++) {
-    out_data[i] =
-        dequantize_val<int8_t, float>(scale, zp, input_data[i], qmin, qmax);
+  // Helium MVE implementation for int8 to float quantization
+  static uint8x16_t voffset{
+      0x0,
+      0x8,
+      0x4,
+      0xC,
+      0x1,
+      0x9,
+      0x5,
+      0xD,
+      0x2,
+      0xA,
+      0x6,
+      0xE,
+      0x3,
+      0xB,
+      0x7,
+      0xF};
+
+  int16x8_t vzp = vdupq_n_s16(static_cast<int16_t>(zp));
+  float32x4_t vscale = vdupq_n_f32(static_cast<float>(scale));
+
+  for (; i + 15 < numel; i += 16) {
+    int8x16_t in_084C195D2A6E3B7F =
+        vldrbq_gather_offset_s8(input_data, voffset);
+
+    int16x8_t in_04152637 = vsubq_s16(vmovlbq_s8(in_084C195D2A6E3B7F), vzp);
+    int16x8_t in_8C9DAEBF = vsubq_s16(vmovltq_s8(in_084C195D2A6E3B7F), vzp);
+
+    float32x4_t inf_0123 = vcvtq_f32_s32(vmovlbq_s16(in_04152637));
+    float32x4_t inf_4567 = vcvtq_f32_s32(vmovltq_s16(in_04152637));
+    float32x4_t inf_89AB = vcvtq_f32_s32(vmovlbq_s16(in_8C9DAEBF));
+    float32x4_t inf_CDEF = vcvtq_f32_s32(vmovltq_s16(in_8C9DAEBF));
+
+    float32x4_t out_0123 = vmulq_f32(inf_0123, vscale);
+    float32x4_t out_4567 = vmulq_f32(inf_4567, vscale);
+    float32x4_t out_89AB = vmulq_f32(inf_89AB, vscale);
+    float32x4_t out_CDEF = vmulq_f32(inf_CDEF, vscale);
+
+    vstrwq_f32(out_data + 0, out_0123);
+    vstrwq_f32(out_data + 4, out_4567);
+    vstrwq_f32(out_data + 8, out_89AB);
+    vstrwq_f32(out_data + 12, out_CDEF);
+
+    input_data += 16;
+    out_data += 16;
   }
-#endif
+#endif // defined(HAS_HELIUM_SIMD)
 
+  for (; i < numel; i++) {
+    *out_data = dequantize_val<int8_t, float>(scale, zp, *input_data);
+    input_data++;
+    out_data++;
+  }
   return out;
 }
 
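A note on the vectorized path above: the gather offsets {0x0, 0x8, 0x4, 0xC, ...} pre-shuffle the 16 input bytes so that the two widening stages (vmovlbq/vmovltq take the even and odd lanes respectively) emit the four float vectors in original element order 0-3, 4-7, 8-11, 12-15, letting the stores write contiguously. A plain scalar model of that lane bookkeeping, useful for checking the ordering (a standalone sketch, not the kernel itself):

```cpp
#include <array>
#include <cstdint>
#include <cstdio>

int main() {
  // Gather offsets used when loading 16 int8 values into one vector.
  const std::array<uint8_t, 16> voffset = {0x0, 0x8, 0x4, 0xC, 0x1, 0x9, 0x5, 0xD,
                                           0x2, 0xA, 0x6, 0xE, 0x3, 0xB, 0x7, 0xF};

  // Model of the 8->16 bit widenings: "bottom" keeps even lanes, "top" keeps odd lanes.
  std::array<uint8_t, 8> lo16, hi16;
  for (int i = 0; i < 8; ++i) {
    lo16[i] = voffset[2 * i];      // elements 0,4,1,5,2,6,3,7
    hi16[i] = voffset[2 * i + 1];  // elements 8,12,9,13,10,14,11,15
  }

  // Model of the 16->32 bit widenings applied to each half.
  uint8_t order[16];
  for (int i = 0; i < 4; ++i) {
    order[0 + i] = lo16[2 * i];       // 0,1,2,3
    order[4 + i] = lo16[2 * i + 1];   // 4,5,6,7
    order[8 + i] = hi16[2 * i];       // 8,9,10,11
    order[12 + i] = hi16[2 * i + 1];  // 12,13,14,15
  }

  // Prints 0..15: the four vstrwq_f32 stores land in the original element order.
  for (int i = 0; i < 16; ++i) {
    std::printf("%d ", order[i]);
  }
  std::printf("\n");
  return 0;
}
```

The scalar tail loop after the #endif then handles any remaining numel % 16 elements with the same formula.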
