
Kleidi 4b blockwise gemv prototype #997


Merged · 19 commits · Oct 11, 2024

Commits (all changes)
036b782  [experimental] simple script UX fixes (digantdesai, Oct 2, 2024)
c4b9f1e  [experimental][kleidi] Add build support (digantdesai, Oct 2, 2024)
4a85c4d  [experimental][kleidi] Add uConfig support for qb4w 1x4x32 neon dotprod (digantdesai, Oct 2, 2024)
49afa4a  [experimental][kleidi] Add a basic test - compiles (digantdesai, Oct 2, 2024)
569c069  [experimental][kleidi] Pin kleidiai repo (digantdesai, Oct 8, 2024)
fd1423f  [experimental][kleidi] Clean up pack.h (digantdesai, Oct 8, 2024)
c323fb1  [experimental][kleidi] Refactor interface header (digantdesai, Oct 8, 2024)
8aa27c4  [experimental][kleidi] Improve unit-tests (digantdesai, Oct 8, 2024)
44ca4de  [experimental][kleidi] move common functions to interface (digantdesai, Oct 8, 2024)
c272739  [experimental][kleidi] Add 1x8x32 neon dotprod kernel (digantdesai, Oct 8, 2024)
ee62be5  [experimental][kleidi] linter (digantdesai, Oct 8, 2024)
ee49c6e  [experimental][kleidi] Reduce template types for tests (digantdesai, Oct 8, 2024)
a905ec3  [experimental][kleidi] Add m>1 tests (digantdesai, Oct 10, 2024)
7429bea  [experimental][kleidi] rename bf16 weight scale flag (digantdesai, Oct 10, 2024)
f28e556  [experimental][kleidi] Build kernel tests in debug mode (digantdesai, Oct 10, 2024)
17f2b43  [experimental][kleidi] Add TODO tasks (digantdesai, Oct 10, 2024)
3049ded  [experimental][kleidi] Allow weight zeros to be a nullptr (digantdesai, Oct 10, 2024)
d4bb3ed  [experimental][kleidi] rebase fixes with int to size_t (digantdesai, Oct 10, 2024)
f6e22fb  [experimental][kleidi] compile-time preprocessor switch for kleidi tests (digantdesai, Oct 11, 2024)
9 changes: 7 additions & 2 deletions torchao/experimental/build_torchao_ops.sh
@@ -1,16 +1,21 @@
-#!/bin/bash
+#!/bin/bash -eu
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

+if [[ $# -ne 1 ]]; then
+  echo "Usage: $0 <aten|executorch>";
+  exit 1;
+fi
+TARGET="${1}"
 export CMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')"
 echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}"
 export CMAKE_OUT=/tmp/cmake-out/torchao
 cmake -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \
     -DCMAKE_INSTALL_PREFIX=${CMAKE_OUT} \
-    -DTORCHAO_OP_TARGET="$1" \
+    -DTORCHAO_OP_TARGET="${TARGET}" \
     -DEXECUTORCH_LIBRARIES="${EXECUTORCH_LIBRARIES}" \
     -DEXECUTORCH_INCLUDE_DIRS="${EXECUTORCH_INCLUDE_DIRS}" \
     -S . \
23 changes: 23 additions & 0 deletions torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt
@@ -4,6 +4,22 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

+include(FetchContent)
+
+# KleidiAI is an open-source library that provides optimized
+# performance-critical routines, also known as micro-kernels, for artificial
+# intelligence (AI) workloads tailored for Arm® CPUs.
+FetchContent_Declare(kleidiai
Contributor: Why add this as a build-time dependency instead of a 3p lib? Wait, I guess it's on GitLab?

Author (digantdesai): Do you mean as opposed to a git submodule? Just to keep it simple for now.

+    GIT_REPOSITORY https://git.gitlab.arm.com/kleidi/kleidiai.git
+    GIT_TAG 35e156d62d1d7e4d27a39f56ed7770a665628b31) # same as xnnpack for now, TODO - revisit this
+
+FetchContent_MakeAvailable(kleidiai)

+# Disabled by default. Force enable if we are on a suitable system.
+# TODO: Introduce ISA specific flags for i8mm.
Contributor: Can you leave it disabled by default until we benchmark it against the existing kernel in torchchat? I want to make sure we don't regress torchchat perf.

Author (digantdesai): This doesn't wire it up at the op level; we enable it only for armv8, and we only have dotprod kernels, so this should be OK. Before we add i8mm kernels we have to fix the CMake and also the op-level wiring.

+CMAKE_DEPENDENT_OPTION(BUILD_KLEIDI "Download, build, and link against Arm KleidiAI library"
+    OFF "CMAKE_SYSTEM_PROCESSOR STREQUAL \"arm64\"" ON)

 if (CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
     add_library(
         torchao_kernels_aarch64

@@ -12,6 +28,13 @@ if (CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
         ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp
         ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp
     )
+    if (BUILD_KLEIDI)
+        # Temporarily exposing this to the parent scope until we wire
+        # this up properly from the top level
+        set(TORCHAO_ENABLE_KLEIDI ON PARENT_SCOPE)
+        message(STATUS "Building with Kleidi")
+        target_link_libraries(torchao_kernels_aarch64 PUBLIC kleidiai)
+    endif()
 endif()

 install(
torchao/experimental/kernels/cpu/aarch64/benchmarks/build_and_run_benchmarks.sh
@@ -1,23 +1,32 @@
-#!/bin/bash
+#!/bin/bash -eu
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

 # Call script with sh build_and_run_benchmarks.sh {BENCHMARK}
 set -eu

+if [[ $# -ne 1 ]]; then
+  echo "Usage: $0 <quantization|bitpacking|linear>";
+  exit 1;
+fi
+
+BENCHMARK_TYPE="${1}"
+SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd)

 export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../..
 export CMAKE_OUT=/tmp/cmake-out/torch_ao/benchmarks

 # Build
 cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \
     -S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/benchmarks \
     -B ${CMAKE_OUT}

 cmake --build ${CMAKE_OUT}

 # Run
-case "$1" in
+case "${BENCHMARK_TYPE}" in
     quantization) ${CMAKE_OUT}/benchmark_quantization; ;;
     bitpacking) ${CMAKE_OUT}/benchmark_bitpacking; ;;
     linear) ${CMAKE_OUT}/benchmark_linear; ;;
123 changes: 123 additions & 0 deletions (new file: 1x4x32 neon dotprod ukernel wrapper header)
@@ -0,0 +1,123 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h>

#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h>

namespace torchao::kernels::cpu::aarch64::kleidi {
Contributor: What is the C++ standard compliance on using namespaces like this? Just confirm that it is at least C++17.

Author (digantdesai): CMake dictates we can assume C++17.
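For reference: the nested namespace definition this thread refers to is itself a C++17 language feature. A minimal illustration of the two spellings:

// C++17 nested namespace definition:
namespace torchao::kernels::cpu::aarch64::kleidi {}

// Pre-C++17 equivalent:
namespace torchao { namespace kernels { namespace cpu {
namespace aarch64 { namespace kleidi {
}}}}}  // close kleidi, aarch64, cpu, kernels, torchao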

namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {

namespace neon_dotprod_1x4x32 {
const Ukernel get_ukernel() {
return Ukernel{
.get_m_step =
kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
.get_n_step =
kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
.get_mr =
kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
.get_nr =
kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
.get_kr =
kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
.get_sr =
kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
.get_lhs_packed_offset =
kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
.get_rhs_packed_offset =
kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
.get_dst_offset =
kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
.get_dst_size =
kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
.run_matmul =
kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod};
}

size_t activation_data_size(int m, int k, int group_size) {
(void)group_size; // unused
return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::activation_data_size(
get_ukernel(), m, k);
}

void prepare_activation_data(
void* activation_data,
int m,
int k,
int group_size,
const float* activations) {
(void)group_size; // unused
kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data(
get_ukernel(), activation_data, m, k, activations);
}

size_t weight_data_size(int n, int k, int group_size) {
return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::weight_data_size(
get_ukernel(), n, k, group_size);
}

void prepare_weight_data(
void* weight_data,
int n,
int k,
int group_size,
const int8_t* weight_qvals,
const float* weight_scales,
const int8_t* weight_zeros) {
Comment on lines +71 to +72

Contributor: It's interesting that the zero point is int8. PT interfaces do this sometimes as int32, sometimes as int64. I have changes that move it to int32.

Contributor: But it makes sense to have them as int8. So for Kleidi the zero_point doesn't have to be zero?

Contributor: prepare_weight_data is the interface we created, not one from Kleidi, so I think @digantdesai is just adhering to that. The Kleidi zero_point is zero, so I think @digantdesai asserts that later.

Author (digantdesai, Oct 10, 2024):
> so for kleidi zero_point doesn't have to be zero?
Yeah, symmetric only ATM.
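// (Aside, not part of the PR: a caller-side guard could make the
// symmetric-only assumption explicit before packing, e.g.
//   assert(weight_zeros == nullptr ||
//          std::all_of(weight_zeros, weight_zeros + n * (k / group_size),
//                      [](int8_t z) { return z == 0; }));
// assuming one int8 zero point per group, i.e. n * (k / group_size) entries;
// <algorithm> and <cassert> would be needed.)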

kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_weight_data(
get_ukernel(),
weight_data,
n,
k,
group_size,
weight_qvals,
weight_scales,
weight_zeros);
}

void kernel(
float32_t* output,
int output_m_stride,
int m,
int n,
int k,
int group_size,
const void* weight_data,
const void* activation_data,
Contributor: Are these the packed activations and packed weights? If so, maybe worth naming them as such.

const float* bias,
float clamp_min,
float clamp_max) {
(void)bias; // TODO(T203756650) - unused - needs API fixing
assert(output_m_stride == n);
if (clamp_min == 0 && clamp_max == 0) {
clamp_min = std::numeric_limits<float>::lowest();
clamp_max = std::numeric_limits<float>::max();
}

auto ukernel = get_ukernel();
ukernel.run_matmul(
m,
n,
k,
group_size,
activation_data,
weight_data,
output,
/*dst_stride_row=*/n * sizeof(float),
/*dst_stride_col=*/sizeof(float),
clamp_min,
clamp_max);
}

size_t get_preferred_alignement() {
return 16;
}
} // namespace neon_dotprod_1x4x32
} // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p
} // namespace torchao::kernels::cpu::aarch64::kleidi
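A hedged usage sketch of the prepare/run flow this header exposes. run_linear_example, its buffer handling, and the comments are illustrative assumptions rather than part of the PR; the wrapper declarations above are assumed to be in scope, and float32_t aliases float on AArch64:

// Hypothetical caller: quantized-weight linear via the 1x4x32 wrapper.
#include <cstdint>
#include <vector>

using namespace torchao::kernels::cpu::aarch64::kleidi::
    kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x4x32;

void run_linear_example(
    float* output,               // m x n, row-major
    const float* activations,    // m x k, row-major
    const int8_t* weight_qvals,  // n x k 4-bit values stored as int8
    const float* weight_scales,  // one scale per group_size block of k
    int m, int n, int k, int group_size) {
  // Quantize and pack the activations into the ukernel's blocked layout.
  std::vector<uint8_t> packed_lhs(activation_data_size(m, k, group_size));
  prepare_activation_data(packed_lhs.data(), m, k, group_size, activations);

  // Pack the weights; weight_zeros may be nullptr (symmetric-only for now).
  std::vector<uint8_t> packed_rhs(weight_data_size(n, k, group_size));
  prepare_weight_data(packed_rhs.data(), n, k, group_size,
                      weight_qvals, weight_scales, /*weight_zeros=*/nullptr);

  // get_preferred_alignement() reports 16; std::vector's default allocator
  // typically satisfies that on 64-bit platforms.
  // clamp_min == clamp_max == 0 is treated as "no clamp" by the wrapper.
  kernel(output, /*output_m_stride=*/n, m, n, k, group_size,
         packed_rhs.data(), packed_lhs.data(), /*bias=*/nullptr,
         /*clamp_min=*/0.f, /*clamp_max=*/0.f);
}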
124 changes: 124 additions & 0 deletions (new file: 1x8x32 neon dotprod ukernel wrapper header)
@@ -0,0 +1,124 @@
Contributor: This file is very similar to the 1x4x32 one above. Do you think it's possible to reuse some code? Same comment for the next file.

Author (digantdesai): Yes! I want to lean on you C++ experts 😅

Contributor: If you want to do this as a follow-up that's also OK, but I do agree that it can probably be structured differently; e.g., get_ukernel can be factored out to take the type of the kernel as an arg.

(A sketch of this factoring appears after the end of this file.)

// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h>

#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h>

namespace torchao::kernels::cpu::aarch64::kleidi {
namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {
namespace neon_dotprod_1x8x32 {
const Ukernel get_ukernel() {
Contributor: For the future: I presume you will have to parameterize this for different kernels? Also, would it make sense to structure this in a way that this function moves to kleidi?

Author (digantdesai): I need to think some more to support (1) AoT/runtime weight packing and (2) per-CPU uArch-based ukernel selection. That logic would dictate how this interface looks, so I did something minimal here for the "prototype", but I agree we can improve it.

return Ukernel{
.get_m_step =
Contributor: Also, what are m/n step?
kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
.get_n_step =
kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
.get_mr =
kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
.get_nr =
kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
.get_kr =
kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
.get_sr =
kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
.get_lhs_packed_offset =
kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
.get_rhs_packed_offset =
kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
.get_dst_offset =
kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
.get_dst_size =
kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
.run_matmul =
kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod};
}
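An aside on the m/n step question above, based on KleidiAI's ukernel interface conventions rather than anything stated in the thread: m_step and n_step report how many output rows and columns the micro-kernel produces per step, so a caller tiles the output with them. A minimal sketch, assuming get_ukernel() above is in scope:

// Hypothetical sketch: tiling an m x n output by the ukernel's step sizes.
inline void for_each_tile(size_t m, size_t n) {
  const auto ukernel = get_ukernel();
  const size_t m_step = ukernel.get_m_step();  // output rows per ukernel step
  const size_t n_step = ukernel.get_n_step();  // output cols per ukernel step
  for (size_t i = 0; i < m; i += m_step) {
    for (size_t j = 0; j < n; j += n_step) {
      // Locate this tile via get_lhs_packed_offset / get_rhs_packed_offset /
      // get_dst_offset, then call run_matmul on the tile.
    }
  }
}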

size_t activation_data_size(int m, int k, int group_size) {
(void) group_size; // unused
return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::activation_data_size(get_ukernel(), m, k);
}

void prepare_activation_data(
void* activation_data,
int m,
int k,
int group_size,
const float* activations) {
(void) group_size; // unused
kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data(
get_ukernel(),
activation_data,
m,
k,
activations);
}

size_t weight_data_size(int n, int k, int group_size) {
return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::weight_data_size(get_ukernel(), n, k, group_size);
}

void prepare_weight_data(
void* weight_data,
int n,
int k,
int group_size,
const int8_t* weight_qvals,
const float* weight_scales,
const int8_t* weight_zeros) {
kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_weight_data(
get_ukernel(),
weight_data,
n,
k,
group_size,
weight_qvals,
weight_scales,
weight_zeros);
}

void kernel(
float32_t* output,
int output_m_stride,
int m,
int n,
int k,
int group_size,
const void* weight_data,
const void* activation_data,
const float* bias,
float clamp_min,
float clamp_max) {
(void) bias; // TODO(T203756650) - unused - needs API fixing
assert(output_m_stride == n);
if (clamp_min == 0 && clamp_max == 0) {
clamp_min = std::numeric_limits<float>::lowest();
clamp_max = std::numeric_limits<float>::max();
}

auto ukernel = get_ukernel();
ukernel.run_matmul(
m,
n,
k,
group_size,
activation_data,
weight_data,
output,
/*dst_stride_row=*/ n * sizeof(float),
/*dst_stride_col=*/ sizeof(float),
clamp_min,
clamp_max);
}

size_t get_preferred_alignement() {
return 16;
}
} // namespace neon_dotprod_1x8x32
} // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p
} // namespace torchao::kernels::cpu::aarch64::kleidi
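Following up on the code-reuse thread at the top of this file: a minimal sketch, not part of the PR, of one way to factor out get_ukernel(). The KleidiAI C symbols differ only by a variant suffix, so a token-pasting macro (the name TORCHAO_KAI_UKERNEL is hypothetical) can stamp out each wrapper:

// Hypothetical dedup macro: expands to a get_ukernel() whose members all
// point at the KleidiAI symbols for the given variant suffix.
#define TORCHAO_KAI_UKERNEL(variant)                                      \
  const Ukernel get_ukernel() {                                           \
    return Ukernel{                                                       \
        .get_m_step = kai_get_m_step_matmul_clamp_f32_##variant,          \
        .get_n_step = kai_get_n_step_matmul_clamp_f32_##variant,          \
        .get_mr = kai_get_mr_matmul_clamp_f32_##variant,                  \
        .get_nr = kai_get_nr_matmul_clamp_f32_##variant,                  \
        .get_kr = kai_get_kr_matmul_clamp_f32_##variant,                  \
        .get_sr = kai_get_sr_matmul_clamp_f32_##variant,                  \
        .get_lhs_packed_offset =                                          \
            kai_get_lhs_packed_offset_matmul_clamp_f32_##variant,         \
        .get_rhs_packed_offset =                                          \
            kai_get_rhs_packed_offset_matmul_clamp_f32_##variant,         \
        .get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_##variant,  \
        .get_dst_size = kai_get_dst_size_matmul_clamp_f32_##variant,      \
        .run_matmul = kai_run_matmul_clamp_f32_##variant};                \
  }

namespace neon_dotprod_1x4x32 {
TORCHAO_KAI_UKERNEL(qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod)
}  // namespace neon_dotprod_1x4x32

namespace neon_dotprod_1x8x32 {
TORCHAO_KAI_UKERNEL(qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod)
}  // namespace neon_dotprod_1x8x32

A template alternative would need the kernel entry points wrapped in a traits struct, since a set of C symbols cannot be passed as one template parameter; the macro keeps each per-variant header to a couple of lines.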