Skip to content

Commit 4c06366

Browse files
committed
Update on "[Executorch][llm] Make custom update cache op operate on indices"
This allows us to use ring buffer kv cache Differential Revision: [D73891424](https://our.internmc.facebook.com/intern/diff/D73891424/) [ghstack-poisoned]
2 parents 6176b62 + 4935c16 commit 4c06366

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+794
-292
lines changed

.ci/scripts/test_model.sh

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,6 @@ test_model() {
8787
bash examples/models/llava/install_requirements.sh
8888
STRICT="--no-strict"
8989
fi
90-
if [[ "$MODEL_NAME" == "llama3_2_vision_encoder" || "$MODEL_NAME" == "llama3_2_text_decoder" ]]; then
91-
# Install requirements for llama vision.
92-
bash examples/models/llama3_2_vision/install_requirements.sh
93-
fi
9490
if [[ "${MODEL_NAME}" == "qwen2_5" ]]; then
9591
# Install requirements for export_llama
9692
bash examples/models/llama/install_requirements.sh

.ci/scripts/unittest-linux.sh

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,6 @@ if [[ "$BUILD_TOOL" == "cmake" ]]; then
2424
CMAKE_ARGS="-DEXECUTORCH_BUILD_PYBIND=ON -DEXECUTORCH_BUILD_XNNPACK=ON -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON" \
2525
.ci/scripts/setup-linux.sh "$@"
2626

27-
# Install llama3_2_vision dependencies.
28-
PYTHON_EXECUTABLE=python ./examples/models/llama3_2_vision/install_requirements.sh
29-
3027
.ci/scripts/unittest-linux-cmake.sh
3128
elif [[ "$BUILD_TOOL" == "buck2" ]]; then
3229
# Removing this breaks sccache in the Buck build, apparently

.ci/scripts/unittest-macos.sh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ if [[ "$BUILD_TOOL" == "cmake" ]]; then
2929
# Install llama3_2_vision dependencies.
3030
PYTHON_EXECUTABLE=python \
3131
${CONDA_RUN} --no-capture-output \
32-
./examples/models/llama3_2_vision/install_requirements.sh
3332

3433
.ci/scripts/unittest-macos-cmake.sh
3534
elif [[ "$BUILD_TOOL" == "buck2" ]]; then

.lintrunner.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,13 @@ exclude_patterns = [
220220
'extension/**',
221221
'kernels/optimized/**',
222222
# Justified <functional> include.
223+
'kernels/portable/cpu/op_bitwise*.cpp',
224+
'kernels/portable/cpu/op_eq.cpp',
225+
'kernels/portable/cpu/op_ge.cpp',
226+
'kernels/portable/cpu/op_gt.cpp',
227+
'kernels/portable/cpu/op_le.cpp',
228+
'kernels/portable/cpu/op_lt.cpp',
229+
'kernels/portable/cpu/op_ne.cpp',
223230
'runtime/kernel/thread_parallel_interface.h',
224231
'scripts/**',
225232
'third-party/**',

CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,8 @@ option(EXECUTORCH_USE_DL "Use libdl library" ON)
242242

243243
option(EXECUTORCH_BUILD_CADENCE "Build the Cadence DSP backend" OFF)
244244

245+
option(EXECUTORCH_BUILD_CORTEX_M "Build the Cortex-M backend" OFF)
246+
245247
#
246248
# pthreadpool: build pthreadpool library. Disable on unsupported platforms
247249
#
@@ -715,6 +717,10 @@ if(EXECUTORCH_BUILD_XNNPACK)
715717
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/xnnpack)
716718
endif()
717719

720+
if(EXECUTORCH_BUILD_CORTEX_M)
721+
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/cortex_m)
722+
endif()
723+
718724
if(EXECUTORCH_BUILD_DEVTOOLS)
719725
if(NOT EXECUTORCH_BUILD_ARM_BAREMETAL)
720726
set(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER

backends/arm/scripts/build_executorch.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ cmake \
129129
-DEXECUTORCH_BUILD_ARM_BAREMETAL=ON \
130130
-DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
131131
-DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \
132+
-DEXECUTORCH_BUILD_CORTEX_M=ON \
132133
-DEXECUTORCH_ENABLE_LOGGING=ON \
133134
${build_devtools_flags} \
134135
${build_with_etdump_flags} \

backends/arm/scripts/pre-push

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,13 @@
88
# non-interactive mode. "$#" gives the number of positional arguments.
99
[ "$#" -eq 0 ] && is_script_interactive=1 || is_script_interactive=0
1010

11-
RESET='\e[0m'
12-
RED='\e[31m'
13-
GREEN='\e[32m'
14-
YELLOW='\e[33m'
15-
BLUE='\e[34m'
11+
if [ $is_script_interactive -eq 1 ]; then
12+
RESET='\e[0m'
13+
RED='\e[31m'
14+
GREEN='\e[32m'
15+
YELLOW='\e[33m'
16+
BLUE='\e[34m'
17+
fi
1618

1719
INFO="${BLUE}[INFO]${RESET}"
1820
WARNING="${YELLOW}[WARNING]${RESET}"

backends/arm/test/test_arm_baremetal.sh

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,6 @@ test_pytest_ops() { # Test ops and other things
8383
test_pytest_models() { # Test ops and other things
8484
echo "${TEST_SUITE_NAME}: Run pytest"
8585

86-
examples/models/llama3_2_vision/install_requirements.sh
87-
8886
# Prepare for pytest
8987
backends/arm/scripts/build_executorch.sh
9088

@@ -117,8 +115,6 @@ test_pytest_ops_ethosu_fvp() { # Same as test_pytest but also sometime verify us
117115
test_pytest_models_ethosu_fvp() { # Same as test_pytest but also sometime verify using Corstone FVP
118116
echo "${TEST_SUITE_NAME}: Run pytest with fvp"
119117

120-
examples/models/llama3_2_vision/install_requirements.sh
121-
122118
# Prepare Corstone-3x0 FVP for pytest
123119
backends/arm/scripts/build_executorch.sh
124120
backends/arm/scripts/build_portable_kernels.sh
@@ -154,6 +150,13 @@ test_run_ethosu_fvp() { # End to End model tests using run.sh
154150
echo "${TEST_SUITE_NAME}: Test ethos-u target Ethos-U85"
155151
examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u85-128 --model_name=add
156152
examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u85-128 --model_name=mul
153+
154+
# Cortex-M op tests
155+
examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u55-128 --model_name=qadd --bundleio
156+
examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u55-128 --model_name=qops --bundleio
157+
examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u55-128 --model_name=qops --bundleio --no_delegate --portable_kernels="aten::sub.out,aten::add.out,aten::mul.out"
158+
examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u85-128 --model_name=qops --bundleio
159+
157160
echo "${TEST_SUITE_NAME}: PASS"
158161
}
159162

backends/cortex_m/CMakeLists.txt

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# Kernel library for Cortex-M operators. Please keep this file formatted by running:
8+
# ~~~
9+
# cmake-format -i CMakeLists.txt
10+
# ~~~
11+
cmake_minimum_required(VERSION 3.19)
12+
13+
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
14+
if(NOT CMAKE_CXX_STANDARD)
15+
set(CMAKE_CXX_STANDARD 17)
16+
endif()
17+
18+
# Source root directory for executorch.
19+
if(NOT EXECUTORCH_ROOT)
20+
set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..)
21+
endif()
22+
23+
include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
24+
include(${EXECUTORCH_ROOT}/tools/cmake/Codegen.cmake)
25+
26+
if(NOT PYTHON_EXECUTABLE)
27+
resolve_python_executable()
28+
endif()
29+
30+
# Cortex-M ops kernel sources
31+
set(_cortex_m_kernels__srcs
32+
${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantize_per_tensor.cpp
33+
${CMAKE_CURRENT_SOURCE_DIR}/ops/op_dequantize_per_tensor.cpp
34+
)
35+
36+
# Generate C++ bindings to register kernels into Executorch (for runtime).
37+
# Here select all ops in operators.yaml
38+
set(_yaml_file ${CMAKE_CURRENT_LIST_DIR}/ops/operators.yaml)
39+
gen_selected_ops(LIB_NAME "cortex_m_ops_lib" OPS_SCHEMA_YAML "${_yaml_file}")
40+
41+
# Generate bindings for the kernels
42+
generate_bindings_for_kernels(
43+
LIB_NAME "cortex_m_ops_lib" CUSTOM_OPS_YAML "${_yaml_file}"
44+
)
45+
message("Generated files ${gen_command_sources}")
46+
47+
# Build a library for _cortex_m_kernels_srcs
48+
add_library(cortex_m_kernels ${_cortex_m_kernels__srcs})
49+
target_link_libraries(cortex_m_kernels PRIVATE executorch)
50+
target_compile_options(cortex_m_kernels PUBLIC ${_common_compile_options})
51+
52+
# cortex_m_ops_lib: Register Cortex-M ops kernels into Executorch runtime
53+
gen_operators_lib(
54+
LIB_NAME "cortex_m_ops_lib" KERNEL_LIBS cortex_m_kernels DEPS executorch
55+
)
56+
57+
install(
58+
TARGETS cortex_m_kernels cortex_m_ops_lib
59+
DESTINATION lib
60+
PUBLIC_HEADER DESTINATION include/executorch/backends/cortex_m/ops/
61+
)

backends/cortex_m/ops/op_dequantize_per_tensor.cpp

Lines changed: 70 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ namespace {
2929
*/
3030
void check_dequantize_args(
3131
const Tensor& input,
32+
int64_t zero_point,
3233
int64_t quant_min,
3334
int64_t quant_max,
3435
ScalarType dtype,
@@ -39,6 +40,18 @@ void check_dequantize_args(
3940
"input.scalar_type() %" PRId8 " is not char type",
4041
static_cast<int8_t>(input.scalar_type()));
4142

43+
// Check zp range
44+
ET_CHECK_MSG(
45+
zero_point >= quant_min,
46+
"zero_point must be %" PRId64 " <= quant_min %" PRId64,
47+
zero_point,
48+
quant_min);
49+
ET_CHECK_MSG(
50+
zero_point <= quant_max,
51+
"zero_point must be %" PRId64 " >= quant_max %" PRId64,
52+
zero_point,
53+
quant_max);
54+
4255
// Check output dtype is float
4356
ET_CHECK_MSG(
4457
out.scalar_type() == ScalarType::Float,
@@ -73,18 +86,10 @@ void check_dequantize_args(
7386
/**
7487
* Scalar implementation of quantization for a single value.
7588
*/
76-
template <typename K, typename T>
77-
T dequantize_val(
78-
float scale,
79-
int32_t zero_point,
80-
K value,
81-
int64_t quant_min,
82-
int64_t quant_max) {
83-
(void)quant_min;
84-
(void)quant_max;
85-
return static_cast<T>((static_cast<int32_t>(value) - zero_point) * scale);
89+
template <typename Q, typename F>
90+
F dequantize_val(float scale, int32_t zero_point, Q qvalue) {
91+
return static_cast<F>((static_cast<int32_t>(qvalue) - zero_point) * scale);
8692
}
87-
8893
} // namespace
8994

9095
Tensor& dequantize_per_tensor_out(
@@ -106,29 +111,71 @@ Tensor& dequantize_per_tensor_out(
106111
"Failed to resize out Tensor in dequantize_per_tensor_out");
107112

108113
// Validate input parameters
109-
check_dequantize_args(input, quant_min, quant_max, dtype, out);
114+
check_dequantize_args(input, zero_point, quant_min, quant_max, dtype, out);
110115

111-
// Pre-compute inverse scale for better performance
112116
int32_t zp = static_cast<int32_t>(zero_point);
113-
int32_t qmin = static_cast<int32_t>(quant_min);
114-
int32_t qmax = static_cast<int32_t>(quant_max);
115117

116118
// Get pointers to input and output data
117119
const int8_t* input_data = input.const_data_ptr<int8_t>();
118120
float* out_data = out.mutable_data_ptr<float>();
119121
const size_t numel = input.numel();
120122

123+
size_t i = 0;
121124
#if defined(HAS_HELIUM_SIMD)
122-
// Helium MVE implementation for float32 to int8 quantization
123-
#Error "Implement MVE version!"
124-
#else
125-
// Scalar implementation for float32 to int8 quantization
126-
for (size_t i = 0; i < numel; i++) {
127-
out_data[i] =
128-
dequantize_val<int8_t, float>(scale, zp, input_data[i], qmin, qmax);
125+
// Helium MVE implementation for int8 to float quantization
126+
static uint8x16_t voffset{
127+
0x0,
128+
0x8,
129+
0x4,
130+
0xC,
131+
0x1,
132+
0x9,
133+
0x5,
134+
0xD,
135+
0x2,
136+
0xA,
137+
0x6,
138+
0xE,
139+
0x3,
140+
0xB,
141+
0x7,
142+
0xF};
143+
144+
int16x8_t vzp = vdupq_n_s16(static_cast<int16_t>(zp));
145+
float32x4_t vscale = vdupq_n_f32(static_cast<float>(scale));
146+
147+
for (; i + 15 < numel; i += 16) {
148+
int8x16_t in_084C195D2A6E3B7F =
149+
vldrbq_gather_offset_s8(input_data, voffset);
150+
151+
int16x8_t in_04152637 = vsubq_s16(vmovlbq_s8(in_084C195D2A6E3B7F), vzp);
152+
int16x8_t in_8C9DAEBF = vsubq_s16(vmovltq_s8(in_084C195D2A6E3B7F), vzp);
153+
154+
float32x4_t inf_0123 = vcvtq_f32_s32(vmovlbq_s16(in_04152637));
155+
float32x4_t inf_4567 = vcvtq_f32_s32(vmovltq_s16(in_04152637));
156+
float32x4_t inf_89AB = vcvtq_f32_s32(vmovlbq_s16(in_8C9DAEBF));
157+
float32x4_t inf_CDEF = vcvtq_f32_s32(vmovltq_s16(in_8C9DAEBF));
158+
159+
float32x4_t out_0123 = vmulq_f32(inf_0123, vscale);
160+
float32x4_t out_4567 = vmulq_f32(inf_4567, vscale);
161+
float32x4_t out_89AB = vmulq_f32(inf_89AB, vscale);
162+
float32x4_t out_CDEF = vmulq_f32(inf_CDEF, vscale);
163+
164+
vstrwq_f32(out_data + 0, out_0123);
165+
vstrwq_f32(out_data + 4, out_4567);
166+
vstrwq_f32(out_data + 8, out_89AB);
167+
vstrwq_f32(out_data + 12, out_CDEF);
168+
169+
input_data += 16;
170+
out_data += 16;
129171
}
130-
#endif
172+
#endif // defined(HAS_HELIUM_SIMD)
131173

174+
for (; i < numel; i++) {
175+
*out_data = dequantize_val<int8_t, float>(scale, zp, *input_data);
176+
input_data++;
177+
out_data++;
178+
}
132179
return out;
133180
}
134181

0 commit comments

Comments
 (0)