Kleidi 4b blockwise gemv prototype #997
Changes from all commits
@@ -4,6 +4,22 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

include(FetchContent)

# KleidiAI is an open-source library that provides optimized
# performance-critical routines, also known as micro-kernels, for artificial
# intelligence (AI) workloads tailored for Arm® CPUs.
FetchContent_Declare(kleidiai
Review comment: Why add this as a build-time dependency instead of a third-party lib? Wait, I guess it's hosted on GitLab?
Reply: Do you mean as opposed to a git submodule? Just to keep it simple for now.
  GIT_REPOSITORY https://git.gitlab.arm.com/kleidi/kleidiai.git
  GIT_TAG 35e156d62d1d7e4d27a39f56ed7770a665628b31) # same as xnnpack for now, TODO - revisit this

FetchContent_MakeAvailable(kleidiai)

# Disabled by default. Force enable if we are on a suitable system.
# TODO: Introduce ISA specific flags for i8mm.
Review comment: Can you leave it disabled by default until we benchmark it against the existing kernel in torchchat? I want to make sure we don't regress torchchat perf.
Reply: This doesn't wire it up at the op level, we enable it only for armv8, and we only have dotprod kernels, so this should be OK. Before we add i8mm kernels we have to fix the CMake and also the op-level wiring.
CMAKE_DEPENDENT_OPTION(BUILD_KLEIDI "Download, build, and link against Arm KleidiAI library"
                       OFF "CMAKE_SYSTEM_PROCESSOR STREQUAL \"arm64\"" ON)

if (CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
  add_library(
    torchao_kernels_aarch64

@@ -12,6 +28,13 @@ if (CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
    ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp
    ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp
  )
  if (BUILD_KLEIDI)
    # Temporarily exposing this to the parent scope until we wire
    # this up properly from the top level
    set(TORCHAO_ENABLE_KLEIDI ON PARENT_SCOPE)
    message(STATUS "Building with Kleidi")
    target_link_libraries(torchao_kernels_aarch64 PUBLIC kleidiai)
  endif()
endif()

install(
@@ -0,0 +1,123 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h>

#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h>

namespace torchao::kernels::cpu::aarch64::kleidi {
Review comment: What is the C++ compliance on using a namespace like this? Just confirm that it is at least C++17.
Reply: CMake dictates we can assume C++17.
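For reference, nested namespace definitions (`namespace a::b::c { ... }`) were introduced in C++17; earlier standards require opening each level separately. A minimal sketch of the two spellings (the function names are illustrative only):

```cpp
// C++17 and later: one nested namespace definition.
namespace torchao::kernels::cpu::aarch64::kleidi {
inline int answer() { return 42; }
} // namespace torchao::kernels::cpu::aarch64::kleidi

// Pre-C++17 equivalent: every level opened and closed explicitly.
namespace torchao { namespace kernels { namespace cpu {
namespace aarch64 { namespace kleidi {
inline int answer_legacy() { return 42; }
}}}}} // namespace torchao::kernels::cpu::aarch64::kleidi
```

One caveat worth noting: the designated initializers used in get_ukernel below were standardized only in C++20, although GCC and Clang accept them in C++17 mode as an extension.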
namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {

namespace neon_dotprod_1x4x32 {
const Ukernel get_ukernel() {
  return Ukernel{
      .get_m_step =
          kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
      .get_n_step =
          kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
      .get_mr =
          kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
      .get_nr =
          kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
      .get_kr =
          kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
      .get_sr =
          kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
      .get_lhs_packed_offset =
          kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
      .get_rhs_packed_offset =
          kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
      .get_dst_offset =
          kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
      .get_dst_size =
          kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
      .run_matmul =
          kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod};
}

size_t activation_data_size(int m, int k, int group_size) {
  (void)group_size; // unused
  return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::activation_data_size(
      get_ukernel(), m, k);
}

void prepare_activation_data(
    void* activation_data,
    int m,
    int k,
    int group_size,
    const float* activations) {
  (void)group_size; // unused
  kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data(
      get_ukernel(), activation_data, m, k, activations);
}

size_t weight_data_size(int n, int k, int group_size) {
  return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::weight_data_size(
      get_ukernel(), n, k, group_size);
}

void prepare_weight_data(
    void* weight_data,
    int n,
    int k,
    int group_size,
    const int8_t* weight_qvals,
    const float* weight_scales,
    const int8_t* weight_zeros) {
Review comment (on lines +71 to +72, the weight_scales/weight_zeros parameters): This is interesting, that the zero point is int8. PT interfaces do this sometimes as int32 and sometimes as int64. I have changes that move it to int32.
Follow-up: But it makes sense to have them as int8. So for Kleidi, the zero_point doesn't have to be zero?
Reply: prepare_weight_data is the interface we created, not from Kleidi, so I think @digantdesai is just adhering to that. The Kleidi zero_point is zero, so I think @digantdesai asserts that later.
Reply: Yeah, symmetric only ATM. (A sketch of a caller-side guard for this follows the function below.)
  kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_weight_data(
      get_ukernel(),
      weight_data,
      n,
      k,
      group_size,
      weight_qvals,
      weight_scales,
      weight_zeros);
}
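Given the symmetric-only constraint discussed above, a caller-side guard might look like the following sketch. The helper name and the assumed layout of one zero point per quantization group are illustrative, not part of this PR:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical guard (not part of this PR): the KleidiAI path is
// symmetric-only for now, so every zero point must be zero before the
// weights are packed. Assumes n * (k / group_size) zero points.
inline void assert_symmetric_quantization(
    const int8_t* weight_zeros, int n, int k, int group_size) {
  if (weight_zeros == nullptr) {
    return; // a null pointer is taken to mean "symmetric" by convention
  }
  const int num_groups = n * (k / group_size);
  for (int i = 0; i < num_groups; ++i) {
    assert(weight_zeros[i] == 0 &&
           "KleidiAI 4-bit kernels support symmetric quantization only");
  }
}
```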

void kernel(
    float32_t* output,
    int output_m_stride,
    int m,
    int n,
    int k,
    int group_size,
    const void* weight_data,
    const void* activation_data,
Review comment: Are these the packed activations and packed weights? If so, maybe worth naming them as such.
    const float* bias,
    float clamp_min,
    float clamp_max) {
  (void)bias; // TODO(T203756650) - unused - needs API fixing
  assert(output_m_stride == n);
  // A (0, 0) clamp range is treated as "no clamping": widen to the full
  // float range before handing off to the micro-kernel.
  if (clamp_min == 0 && clamp_max == 0) {
    clamp_min = std::numeric_limits<float>::lowest();
    clamp_max = std::numeric_limits<float>::max();
  }

  auto ukernel = get_ukernel();
  ukernel.run_matmul(
      m,
      n,
      k,
      group_size,
      activation_data,
      weight_data,
      output,
      /*dst_stride_row=*/n * sizeof(float),
      /*dst_stride_col=*/sizeof(float),
      clamp_min,
      clamp_max);
}

size_t get_preferred_alignement() {
  return 16;
}

} // namespace neon_dotprod_1x4x32
} // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p
} // namespace torchao::kernels::cpu::aarch64::kleidi
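To make the calling convention of this wrapper concrete, here is a hypothetical end-to-end usage sketch. Buffer alignment handling is simplified (the wrapper reports a preferred 16-byte alignment), the wrapper function names come from the file above, and the all-zeros zero-point array and the (0, 0) "no clamp" sentinel follow its code:

```cpp
#include <cstdint>
#include <vector>

using namespace torchao::kernels::cpu::aarch64::kleidi::
    kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x4x32;

void run_linear(
    float* output,              // m x n, row-major
    const float* activations,   // m x k
    const int8_t* weight_qvals, // 4-bit values stored as int8
    const float* weight_scales,
    const int8_t* weight_zeros, // all zeros: symmetric-only for now
    int m, int n, int k, int group_size) {
  // Pack activations and weights into the kernel's preferred layouts.
  std::vector<char> packed_act(activation_data_size(m, k, group_size));
  prepare_activation_data(packed_act.data(), m, k, group_size, activations);

  std::vector<char> packed_wts(weight_data_size(n, k, group_size));
  prepare_weight_data(packed_wts.data(), n, k, group_size,
                      weight_qvals, weight_scales, weight_zeros);

  // A (0, 0) clamp range means "no clamping" per the wrapper above;
  // bias is currently unused.
  kernel(output, /*output_m_stride=*/n, m, n, k, group_size,
         packed_wts.data(), packed_act.data(), /*bias=*/nullptr,
         /*clamp_min=*/0.0f, /*clamp_max=*/0.0f);
}
```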
@@ -0,0 +1,124 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
Review comment: This file is very similar to the 1x4x32 one above. Do you think it's possible to reuse some code? Same comment for the next file.
Reply: Yes! I want to lean on you C++ experts 😅
Reply: If you want to do this as a follow-up that's also OK, but I do agree that it can probably be structured differently, e.g. get_ukernel can be factored out to take the type of the kernel as an argument.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h>

#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h>

namespace torchao::kernels::cpu::aarch64::kleidi {
namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {
namespace neon_dotprod_1x8x32 {

const Ukernel get_ukernel() {
Review comment: For the future: I presume you will have to parameterize this for different kernels? Also, would it make sense to structure this in a way that this function moves to Kleidi?
Reply: I need to think some more about how to support (1) AoT/runtime weight packing and (2) per-CPU uArch-based ukernel selection. That logic would dictate how this interface looks. So I did something minimal here for the "prototype", but I agree we can improve. (A sketch of one possible factoring follows get_ukernel below.)
Review comment (on .get_m_step / .get_n_step): Also, what are m/n step?

  return Ukernel{
      .get_m_step =
          kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
      .get_n_step =
          kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
      .get_mr =
          kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
      .get_nr =
          kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
      .get_kr =
          kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
      .get_sr =
          kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
      .get_lhs_packed_offset =
          kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
      .get_rhs_packed_offset =
          kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
      .get_dst_offset =
          kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
      .get_dst_size =
          kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
      .run_matmul =
          kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod};
}
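As a rough illustration of the factoring suggested in the review comments above (the macro name is hypothetical and not part of this PR), the per-variant function table could be stamped out of the kernel's symbol suffix with token pasting, since every KleidiAI variant exposes the same family of kai_get_*_/kai_run_ symbols:

```cpp
// Hypothetical sketch: expand the kai_get_*_<id>/kai_run_<id> symbols
// declared in the matching KleidiAI header into a Ukernel function table.
#define TORCHAO_KAI_UKERNEL_TABLE(id)                        \
  Ukernel {                                                  \
    .get_m_step = kai_get_m_step_##id,                       \
    .get_n_step = kai_get_n_step_##id,                       \
    .get_mr = kai_get_mr_##id,                               \
    .get_nr = kai_get_nr_##id,                               \
    .get_kr = kai_get_kr_##id,                               \
    .get_sr = kai_get_sr_##id,                               \
    .get_lhs_packed_offset = kai_get_lhs_packed_offset_##id, \
    .get_rhs_packed_offset = kai_get_rhs_packed_offset_##id, \
    .get_dst_offset = kai_get_dst_offset_##id,               \
    .get_dst_size = kai_get_dst_size_##id,                   \
    .run_matmul = kai_run_##id                               \
  }

// get_ukernel() for this variant would then reduce to:
const Ukernel get_ukernel() {
  return TORCHAO_KAI_UKERNEL_TABLE(
      matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod);
}
```

Each variant's KleidiAI header would still need to be included, and whether this table lives in torchao or moves upstream into KleidiAI is the open question raised above.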

size_t activation_data_size(int m, int k, int group_size) {
  (void)group_size; // unused
  return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::activation_data_size(
      get_ukernel(), m, k);
}

void prepare_activation_data(
    void* activation_data,
    int m,
    int k,
    int group_size,
    const float* activations) {
  (void)group_size; // unused
  kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data(
      get_ukernel(), activation_data, m, k, activations);
}

size_t weight_data_size(int n, int k, int group_size) {
  return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::weight_data_size(
      get_ukernel(), n, k, group_size);
}

void prepare_weight_data(
    void* weight_data,
    int n,
    int k,
    int group_size,
    const int8_t* weight_qvals,
    const float* weight_scales,
    const int8_t* weight_zeros) {
  kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_weight_data(
      get_ukernel(),
      weight_data,
      n,
      k,
      group_size,
      weight_qvals,
      weight_scales,
      weight_zeros);
}

void kernel(
    float32_t* output,
    int output_m_stride,
    int m,
    int n,
    int k,
    int group_size,
    const void* weight_data,
    const void* activation_data,
    const float* bias,
    float clamp_min,
    float clamp_max) {
  (void)bias; // TODO(T203756650) - unused - needs API fixing
  assert(output_m_stride == n);
  // A (0, 0) clamp range is treated as "no clamping": widen to the full
  // float range before handing off to the micro-kernel.
  if (clamp_min == 0 && clamp_max == 0) {
    clamp_min = std::numeric_limits<float>::lowest();
    clamp_max = std::numeric_limits<float>::max();
  }

  auto ukernel = get_ukernel();
  ukernel.run_matmul(
      m,
      n,
      k,
      group_size,
      activation_data,
      weight_data,
      output,
      /*dst_stride_row=*/n * sizeof(float),
      /*dst_stride_col=*/sizeof(float),
      clamp_min,
      clamp_max);
}

size_t get_preferred_alignement() {
  return 16;
}

} // namespace neon_dotprod_1x8x32
} // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p
} // namespace torchao::kernels::cpu::aarch64::kleidi