Skip to content

Commit db72dd1

Browse files
authored
Kleidi 4b blockwise gemv prototype
Differential Revision: D64194844 Pull Request resolved: #997
1 parent 5277507 commit db72dd1

11 files changed

+842
-10
lines changed

torchao/experimental/build_torchao_ops.sh

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,21 @@
1-
#!/bin/bash
1+
#!/bin/bash -eu
22
# Copyright (c) Meta Platforms, Inc. and affiliates.
33
# All rights reserved.
44
#
55
# This source code is licensed under the license found in the
66
# LICENSE file in the root directory of this source tree.
77

8+
if [[ $# -ne 1 ]]; then
9+
echo "Usage: $0 <aten|executorch>";
10+
exit 1;
11+
fi
12+
TARGET="${1}"
813
export CMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')"
914
echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}"
1015
export CMAKE_OUT=/tmp/cmake-out/torchao
1116
cmake -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \
1217
-DCMAKE_INSTALL_PREFIX=${CMAKE_OUT} \
13-
-DTORCHAO_OP_TARGET="$1" \
18+
-DTORCHAO_OP_TARGET="${TARGET}" \
1419
-DEXECUTORCH_LIBRARIES="${EXECUTORCH_LIBRARIES}" \
1520
-DEXECUTORCH_INCLUDE_DIRS="${EXECUTORCH_INCLUDE_DIRS}" \
1621
-S . \

torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,22 @@
44
# This source code is licensed under the license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
include(FetchContent)
8+
9+
# KleidiAI is an open-source library that provides optimized
10+
# performance-critical routines, also known as micro-kernels, for artificial
11+
# intelligence (AI) workloads tailored for Arm® CPUs.
12+
FetchContent_Declare(kleidiai
13+
GIT_REPOSITORY https://git.gitlab.arm.com/kleidi/kleidiai.git
14+
GIT_TAG 35e156d62d1d7e4d27a39f56ed7770a665628b31) # same as xnnpack for now, TODO - revisit this
15+
16+
FetchContent_MakeAvailable(kleidiai)
17+
18+
# Disabled by default. Force enable if we are on a suitable system.
19+
# TODO: Introduce ISA specific flags for i8mm.
20+
CMAKE_DEPENDENT_OPTION(BUILD_KLEIDI "Download, build, and link against Arm KleidiAI library"
21+
OFF "CMAKE_SYSTEM_PROCESSOR STREQUAL \"arm64\"" ON)
22+
723
if (CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
824
add_library(
925
torchao_kernels_aarch64
@@ -12,6 +28,13 @@ if (CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
1228
${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp
1329
${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp
1430
)
if(BUILD_KLEIDI)
  # Temporarily exposing this to the parent scope until we wire
  # this up properly from the top level.
  set(TORCHAO_ENABLE_KLEIDI ON PARENT_SCOPE)
  message(STATUS "Building with Kleidi")
  # PUBLIC: kleidiai headers appear in this target's public headers.
  target_link_libraries(torchao_kernels_aarch64 PUBLIC kleidiai)
endif()
1538
endif()
1639

1740
install(

torchao/experimental/kernels/cpu/aarch64/benchmarks/build_and_run_benchmarks.sh

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,32 @@
1-
#!/bin/bash
1+
#!/bin/bash -eu
22
# Copyright (c) Meta Platforms, Inc. and affiliates.
33
# All rights reserved.
44
#
55
# This source code is licensed under the license found in the
66
# LICENSE file in the root directory of this source tree.
77

8-
# Call script with sh build_and_run_benchmarks.sh {BENCHAMRK}
8+
set -eu
99

10+
if [[ $# -ne 1 ]]; then
11+
echo "Usage: $0 <quantization|bitpacking|linear>";
12+
exit 1;
13+
fi
14+
15+
BENCHMARK_TYPE="${1}"
1016
SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd)
17+
1118
export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../..
1219
export CMAKE_OUT=/tmp/cmake-out/torch_ao/benchmarks
20+
21+
# Build
1322
cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \
1423
-S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/benchmarks \
1524
-B ${CMAKE_OUT}
1625

1726
cmake --build ${CMAKE_OUT}
1827

1928
# Run
20-
case "$1" in
29+
case "${BENCHMARK_TYPE}" in
2130
quantization) ${CMAKE_OUT}/benchmark_quantization; ;;
2231
bitpacking) ${CMAKE_OUT}/benchmark_bitpacking; ;;
2332
linear) ${CMAKE_OUT}/benchmark_linear; ;;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

// Standard headers used directly by this file (assert, std::numeric_limits,
// size_t, int8_t). The original relied on transitive includes — fragile.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>

#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h>

#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h>

namespace torchao::kernels::cpu::aarch64::kleidi {
namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {

namespace neon_dotprod_1x4x32 {

// Returns the function-pointer table for the KleidiAI
// qai8dxp1x8 / qsi4c32p4x8 1x4x32 NEON dotprod micro-kernel.
const Ukernel get_ukernel() {
  return Ukernel{
      .get_m_step =
          kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
      .get_n_step =
          kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
      .get_mr =
          kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
      .get_nr =
          kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
      .get_kr =
          kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
      .get_sr =
          kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
      .get_lhs_packed_offset =
          kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
      .get_rhs_packed_offset =
          kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
      .get_dst_offset =
          kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
      .get_dst_size =
          kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
      .run_matmul =
          kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod};
}

// Bytes required for the packed activation buffer of an m x k matrix.
// group_size is accepted for interface parity with other kernels but is
// unused here: the shared helper only needs (ukernel, m, k).
size_t activation_data_size(int m, int k, int group_size) {
  (void)group_size; // unused
  return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::activation_data_size(
      get_ukernel(), m, k);
}

// Quantizes/packs `activations` (m x k floats) into `activation_data`,
// which must hold at least activation_data_size(m, k, group_size) bytes.
void prepare_activation_data(
    void* activation_data,
    int m,
    int k,
    int group_size,
    const float* activations) {
  (void)group_size; // unused
  kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data(
      get_ukernel(), activation_data, m, k, activations);
}

// Bytes required for the packed weight buffer of an n x k matrix with
// blockwise quantization groups of `group_size`.
size_t weight_data_size(int n, int k, int group_size) {
  return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::weight_data_size(
      get_ukernel(), n, k, group_size);
}

// Packs quantized weight values, per-group scales, and zero points into the
// kernel's expected layout. `weight_data` must hold at least
// weight_data_size(n, k, group_size) bytes.
void prepare_weight_data(
    void* weight_data,
    int n,
    int k,
    int group_size,
    const int8_t* weight_qvals,
    const float* weight_scales,
    const int8_t* weight_zeros) {
  kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_weight_data(
      get_ukernel(),
      weight_data,
      n,
      k,
      group_size,
      weight_qvals,
      weight_scales,
      weight_zeros);
}

// Runs the matmul: output (m x n, row-major) = clamp(activations * weights).
// Only output_m_stride == n is supported (asserted below).
void kernel(
    float32_t* output,
    int output_m_stride,
    int m,
    int n,
    int k,
    int group_size,
    const void* weight_data,
    const void* activation_data,
    const float* bias,
    float clamp_min,
    float clamp_max) {
  (void)bias; // TODO(T203756650) - unused - needs API fixing
  assert(output_m_stride == n);
  // Sentinel: callers pass (0, 0) to mean "no clamping"; widen to the full
  // float range. NOTE(review): a caller that genuinely wants to clamp to
  // exactly [0, 0] cannot express it — confirm this convention upstream.
  if (clamp_min == 0 && clamp_max == 0) {
    clamp_min = std::numeric_limits<float>::lowest();
    clamp_max = std::numeric_limits<float>::max();
  }

  auto ukernel = get_ukernel();
  ukernel.run_matmul(
      m,
      n,
      k,
      group_size,
      activation_data,
      weight_data,
      output,
      /*dst_stride_row=*/n * sizeof(float),
      /*dst_stride_col=*/sizeof(float),
      clamp_min,
      clamp_max);
}

// Preferred buffer alignment in bytes for the packed data buffers.
// (Name keeps the existing "alignement" spelling — callers depend on it.)
size_t get_preferred_alignement() {
  return 16;
}
} // namespace neon_dotprod_1x4x32
} // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p
} // namespace torchao::kernels::cpu::aarch64::kleidi
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

// Standard headers used directly by this file (assert, std::numeric_limits,
// size_t, int8_t). The original relied on transitive includes — fragile.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>

#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h>

#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h>

namespace torchao::kernels::cpu::aarch64::kleidi {
namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {
namespace neon_dotprod_1x8x32 {

// Returns the function-pointer table for the KleidiAI
// qai8dxp1x8 / qsi4c32p8x8 1x8x32 NEON dotprod micro-kernel.
const Ukernel get_ukernel() {
  return Ukernel{
      .get_m_step =
          kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
      .get_n_step =
          kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
      .get_mr =
          kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
      .get_nr =
          kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
      .get_kr =
          kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
      .get_sr =
          kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
      .get_lhs_packed_offset =
          kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
      .get_rhs_packed_offset =
          kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
      .get_dst_offset =
          kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
      .get_dst_size =
          kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
      .run_matmul =
          kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod};
}

// Bytes required for the packed activation buffer of an m x k matrix.
// group_size is accepted for interface parity with other kernels but is
// unused here: the shared helper only needs (ukernel, m, k).
size_t activation_data_size(int m, int k, int group_size) {
  (void)group_size; // unused
  return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::activation_data_size(
      get_ukernel(), m, k);
}

// Quantizes/packs `activations` (m x k floats) into `activation_data`,
// which must hold at least activation_data_size(m, k, group_size) bytes.
void prepare_activation_data(
    void* activation_data,
    int m,
    int k,
    int group_size,
    const float* activations) {
  (void)group_size; // unused
  kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data(
      get_ukernel(), activation_data, m, k, activations);
}

// Bytes required for the packed weight buffer of an n x k matrix with
// blockwise quantization groups of `group_size`.
size_t weight_data_size(int n, int k, int group_size) {
  return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::weight_data_size(
      get_ukernel(), n, k, group_size);
}

// Packs quantized weight values, per-group scales, and zero points into the
// kernel's expected layout. `weight_data` must hold at least
// weight_data_size(n, k, group_size) bytes.
void prepare_weight_data(
    void* weight_data,
    int n,
    int k,
    int group_size,
    const int8_t* weight_qvals,
    const float* weight_scales,
    const int8_t* weight_zeros) {
  kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_weight_data(
      get_ukernel(),
      weight_data,
      n,
      k,
      group_size,
      weight_qvals,
      weight_scales,
      weight_zeros);
}

// Runs the matmul: output (m x n, row-major) = clamp(activations * weights).
// Only output_m_stride == n is supported (asserted below).
void kernel(
    float32_t* output,
    int output_m_stride,
    int m,
    int n,
    int k,
    int group_size,
    const void* weight_data,
    const void* activation_data,
    const float* bias,
    float clamp_min,
    float clamp_max) {
  (void)bias; // TODO(T203756650) - unused - needs API fixing
  assert(output_m_stride == n);
  // Sentinel: callers pass (0, 0) to mean "no clamping"; widen to the full
  // float range. NOTE(review): a caller that genuinely wants to clamp to
  // exactly [0, 0] cannot express it — confirm this convention upstream.
  if (clamp_min == 0 && clamp_max == 0) {
    clamp_min = std::numeric_limits<float>::lowest();
    clamp_max = std::numeric_limits<float>::max();
  }

  auto ukernel = get_ukernel();
  ukernel.run_matmul(
      m,
      n,
      k,
      group_size,
      activation_data,
      weight_data,
      output,
      /*dst_stride_row=*/n * sizeof(float),
      /*dst_stride_col=*/sizeof(float),
      clamp_min,
      clamp_max);
}

// Preferred buffer alignment in bytes for the packed data buffers.
// (Name keeps the existing "alignement" spelling — callers depend on it.)
size_t get_preferred_alignement() {
  return 16;
}
// Fixed: original closing comment wrongly said neon_dotprod_1x4x32.
} // namespace neon_dotprod_1x8x32
} // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p
} // namespace torchao::kernels::cpu::aarch64::kleidi

0 commit comments

Comments
 (0)