diff --git a/extension/llm/custom_ops/TARGETS b/extension/llm/custom_ops/TARGETS index 5d0c0490506..61be3d191a7 100644 --- a/extension/llm/custom_ops/TARGETS +++ b/extension/llm/custom_ops/TARGETS @@ -47,3 +47,17 @@ runtime.python_test( "//caffe2:torch", ], ) + +runtime.python_test( + name = "test_quantized_sdpa", + srcs = [ + "test_quantized_sdpa.py", + ], + preload_deps = [ + ":custom_ops_aot_lib_mkl_noomp", + ":custom_ops_aot_py", + ], + deps = [ + "//caffe2:torch", + ], +) diff --git a/extension/llm/custom_ops/op_sdpa.cpp b/extension/llm/custom_ops/op_sdpa.cpp index 202ff17188d..391d2ab0646 100644 --- a/extension/llm/custom_ops/op_sdpa.cpp +++ b/extension/llm/custom_ops/op_sdpa.cpp @@ -44,7 +44,9 @@ bool validate_flash_attention_args( "scaled_dot_product_attention_flash_attention: Q/K/V should have the same head size"); ET_CHECK_OR_RETURN_FALSE( - (query.scalar_type() == ScalarType::Float), "Query must be Float type"); + (query.scalar_type() == ScalarType::Float) || + (query.scalar_type() == ScalarType::Char), + "Query must be Float type"); ET_CHECK_OR_RETURN_FALSE( (query.scalar_type() == key.scalar_type()) && @@ -354,9 +356,14 @@ Tensor& custom_sdpa_out_impl( output, "Invalid arguments"); + int64_t seq_len = q.size(1); + auto q_seq_len = q.size(1); + bool is_seq_at_dim_1{true}; if (q.scalar_type() == ScalarType::Char) { is_seq_at_dim_1 = false; + seq_len = q.size(2); + q_seq_len = q.size(2); ET_KERNEL_CHECK_MSG( ctx, q_scales.has_value() && q_zero_points.has_value() && @@ -390,9 +397,6 @@ Tensor& custom_sdpa_out_impl( ET_CHECK_MSG(q.dim() == 4, "query must be a 4D tensor"); - const int64_t seq_len = q.size(1); - auto q_seq_len = q.size(1); - const int64_t num_keys_for_causal_attention = start_pos + seq_len; ET_KERNEL_CHECK( @@ -418,12 +422,12 @@ Tensor& custom_sdpa_out_impl( is_causal, attn_mask, scale, - nullopt, // q_zero_points - nullopt, // q_scales - nullopt, // k_zero_points - nullopt, // k_scales - nullopt, // v_zero_points - nullopt, // v_scales + q_zero_points, // q_zero_points + q_scales, // q_scales + k_zero_points, // k_zero_points + k_scales, // k_scales + v_zero_points, // v_zero_points + v_scales, // v_scales is_seq_at_dim_1, /* is_seq_at_dim_1 */ start_pos, num_keys_for_causal_attention); @@ -437,12 +441,12 @@ Tensor& custom_sdpa_out_impl( is_causal, attn_mask, scale, - nullopt, // q_zero_points - nullopt, // q_scales - nullopt, // k_zero_points - nullopt, // k_scales - nullopt, // v_zero_points - nullopt, // v_scales + q_zero_points, // q_zero_points + q_scales, // q_scales + k_zero_points, // k_zero_points + k_scales, // k_scales + v_zero_points, // v_zero_points + v_scales, // v_scales is_seq_at_dim_1, /* is_seq_at_dim_1 */ start_pos, num_keys_for_causal_attention); @@ -456,12 +460,12 @@ Tensor& custom_sdpa_out_impl( is_causal, attn_mask, scale, - nullopt, // q_zero_points - nullopt, // q_scales - nullopt, // k_zero_points - nullopt, // k_scales - nullopt, // v_zero_points - nullopt, // v_scales + q_zero_points, // q_zero_points + q_scales, // q_scales + k_zero_points, // k_zero_points + k_scales, // k_scales + v_zero_points, // v_zero_points + v_scales, // v_scales is_seq_at_dim_1, /* is_seq_at_dim_1 */ start_pos, num_keys_for_causal_attention); @@ -470,6 +474,45 @@ Tensor& custom_sdpa_out_impl( return output; } +#ifdef ENABLE_CUSTOM_QUANTIZED_SDPA +Tensor& custom_quantized_sdpa_out( + RuntimeContext& ctx, + const Tensor& q, + const Tensor& k, + const Tensor& v, + const int64_t start_pos, + const optional& attn_mask, + const double dropout_p, + const bool 
is_causal, + // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy + const optional scale, + const optional& q_zero_points, + const optional& q_scales, + const optional& k_zero_points, + const optional& k_scales, + const optional& v_zero_points, + const optional& v_scales, + Tensor& output) { + return custom_sdpa_out_impl( + ctx, + q, + k, + v, + start_pos, + attn_mask, + dropout_p, + is_causal, + scale, + output, + q_zero_points, + q_scales, + k_zero_points, + k_scales, + v_zero_points, + v_scales); +} +#endif // ENABLE_CUSTOM_QUANTIZED_SDPA + /* Input params @param[in] q_projected Projected query with query weights. @@ -570,3 +613,10 @@ EXECUTORCH_LIBRARY( llama, "custom_sdpa.out", torch::executor::native::custom_sdpa_out); + +#ifdef ENABLE_CUSTOM_QUANTIZED_SDPA +EXECUTORCH_LIBRARY( + llama, + "custom_quantized_sdpa.out", + torch::executor::native::custom_quantized_sdpa_out); +#endif // ENABLE_CUSTOM_QUANTIZED_SDPA diff --git a/extension/llm/custom_ops/op_sdpa.h b/extension/llm/custom_ops/op_sdpa.h index bc2202b9bd8..92b8a41b706 100644 --- a/extension/llm/custom_ops/op_sdpa.h +++ b/extension/llm/custom_ops/op_sdpa.h @@ -56,6 +56,26 @@ Tensor& flash_attention_kernel_out( const optional scale, Tensor& output); +#ifdef ENABLE_CUSTOM_QUANTIZED_SDPA +Tensor& custom_quantized_sdpa_out( + RuntimeContext& ctx, + const Tensor& q, + const Tensor& k, + const Tensor& v, + const int64_t start_pos, + const optional& attn_mask, + const double dropout_p, + const bool is_causal, + // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy + const optional scale, + const optional& q_zero_points, + const optional& q_scales, + const optional& k_zero_points, + const optional& k_scales, + const optional& v_zero_points, + const optional& v_scales, + Tensor& output); +#endif // ENABLE_CUSTOM_QUANTIZED_SDPA } // namespace native } // namespace executor } // namespace torch diff --git a/extension/llm/custom_ops/op_sdpa_aot.cpp b/extension/llm/custom_ops/op_sdpa_aot.cpp index 213adf1c8ab..a3adcbbf866 100644 --- a/extension/llm/custom_ops/op_sdpa_aot.cpp +++ b/extension/llm/custom_ops/op_sdpa_aot.cpp @@ -77,6 +77,47 @@ at::Tensor custom_sdpa_aten( // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy const std::optional scale); +#ifdef ENABLE_CUSTOM_QUANTIZED_SDPA +Tensor& custom_quantized_sdpa_out_no_context( + const Tensor& q, + const Tensor& k, + const Tensor& v, + const int64_t start_pos, + // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue + // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy + const optional attn_mask, + const double dropout_p, + const bool is_causal, + // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy + const optional scale, + const optional q_zero_points, + const optional q_scales, + const optional k_zero_points, + const optional k_scales, + const optional v_zero_points, + const optional v_scales, + Tensor& output); + +at::Tensor custom_quantized_sdpa_aten( + const at::Tensor& q, + const at::Tensor& k, + const at::Tensor& v, + const int64_t start_pos, + // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue + // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy + const std::optional attn_mask, + const double dropout_p, + const bool is_causal, + // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy + const std::optional scale, + const std::optional& q_zero_points, + const std::optional& q_scales, + const std::optional& k_zero_points, + const std::optional& k_scales, + const 
std::optional& v_zero_points, + const std::optional& v_scales); +#endif // ENABLE_CUSTOM_QUANTIZED_SDPA + Tensor& update_cache_out_no_context( const Tensor& value, Tensor& cache, @@ -198,6 +239,85 @@ at::Tensor custom_sdpa_aten( return output; } +#ifdef ENABLE_CUSTOM_QUANTIZED_SDPA +Tensor& custom_quantized_sdpa_out_no_context( + const Tensor& q, + const Tensor& k, + const Tensor& v, + const int64_t start_pos, + // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue + // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy + const optional attn_mask, + const double dropout_p, + const bool is_causal, + // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy + const optional scale, + const optional q_zero_points, + const optional q_scales, + const optional k_zero_points, + const optional k_scales, + const optional v_zero_points, + const optional v_scales, + Tensor& output) { + executorch::aten::RuntimeContext context{}; + return torch::executor::native::custom_quantized_sdpa_out( + context, + q, + k, + v, + start_pos, + attn_mask, + dropout_p, + is_causal, + scale, + q_zero_points, + q_scales, + k_zero_points, + k_scales, + v_zero_points, + v_scales, + output); +} + +at::Tensor custom_quantized_sdpa_aten( + const at::Tensor& q, + const at::Tensor& k, + const at::Tensor& v, + const int64_t start_pos, + // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue + // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy + const std::optional attn_mask, + const double dropout_p, + const bool is_causal, + // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy + const std::optional scale, + const std::optional& q_zero_points, + const std::optional& q_scales, + const std::optional& k_zero_points, + const std::optional& k_scales, + const std::optional& v_zero_points, + const std::optional& v_scales) { + auto output = at::empty(q.sizes()); + WRAP_TO_ATEN(custom_quantized_sdpa_out_no_context, 14) + (q, + k, + v, + start_pos, + attn_mask, + dropout_p, + is_causal, + scale, + q_zero_points, + q_scales, + k_zero_points, + k_scales, + v_zero_points, + v_scales, + output); + return output; +} +#endif // ENABLE_CUSTOM_QUANTIZED_SDPA + Tensor& update_cache_out_no_context( const Tensor& value, Tensor& cache, @@ -245,6 +365,20 @@ TORCH_LIBRARY_FRAGMENT(llama, m) { m.def( "update_cache.out(Tensor value, Tensor(a!) cache, " "SymInt start_pos, *, Tensor(b!) out) -> Tensor(b!)"); +#ifdef ENABLE_CUSTOM_QUANTIZED_SDPA + m.def( + "custom_quantized_sdpa(Tensor query, Tensor key, Tensor value, SymInt start_pos, " + "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, " + "float? scale=None, Tensor? q_zero_points=None, Tensor? q_scales=None, " + "Tensor? k_zero_points=None, Tensor? k_scales=None, Tensor? v_zero_points=None, " + "Tensor? v_scales=None) -> Tensor"); + m.def( + "custom_quantized_sdpa.out(Tensor query, Tensor key, Tensor value, SymInt start_pos, " + "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, " + "float? scale=None, Tensor? q_zero_points=None, Tensor? q_scales=None, " + "Tensor? k_zero_points=None, Tensor? k_scales=None, Tensor? v_zero_points=None, " + "Tensor? v_scales=None, *, Tensor(a!) 
out) -> Tensor(a!)"); +#endif // ENABLE_CUSTOM_QUANTIZED_SDPA } // TODO: Rename this file to op_custom_ops_aot.cpp @@ -263,4 +397,13 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) { m.impl( "update_cache.out", WRAP_TO_ATEN(torch::executor::native::update_cache_out_no_context, 3)); +#ifdef ENABLE_CUSTOM_QUANTIZED_SDPA + m.impl( + "custom_quantized_sdpa", + torch::executor::native::custom_quantized_sdpa_aten); + m.impl( + "custom_quantized_sdpa.out", + WRAP_TO_ATEN( + torch::executor::native::custom_quantized_sdpa_out_no_context, 14)); +#endif // ENABLE_CUSTOM_QUANTIZED_SDPA } diff --git a/extension/llm/custom_ops/op_sdpa_impl.h b/extension/llm/custom_ops/op_sdpa_impl.h index 0639c539ed1..5a0fb708220 100644 --- a/extension/llm/custom_ops/op_sdpa_impl.h +++ b/extension/llm/custom_ops/op_sdpa_impl.h @@ -23,6 +23,10 @@ #endif #include +#if defined(ENABLE_CUSTOM_QUANTIZED_SDPA) +#include +#endif + namespace torch { namespace executor { @@ -67,7 +71,37 @@ void _q_at_k_gemm( q_data.dtype == ScalarType::Char || q_data.dtype == ScalarType::Float, "q and k must be either int8 or float"); if (q_data.dtype == ScalarType::Char) { - ET_CHECK_MSG(false, "int8 not supported yet"); +#if defined(ENABLE_CUSTOM_QUANTIZED_SDPA) + if constexpr (std::is_same::value) { + int a_stride_m_tmp, b_stride_n_tmp; + auto kernel = torchao::kernels::cpu::quantized_matmul:: + get_int8_a_int8_b_channelwise_qmatmul( + q_m, k_n, qk_k, false, true, a_stride_m_tmp, b_stride_n_tmp); + kernel( + q_m, + k_n, + qk_k, + static_cast(q_data.data), + q_stride_m, + static_cast(k_data.data), + k_stride_n, + qk_data, + k_n, + static_cast(q_data.zero_points), + static_cast(k_data.zero_points), + static_cast(q_data.scales), + static_cast(k_data.scales), + 1, + 1); + } else { + ET_CHECK_MSG( + false, "Accumulation in dtype other than float not supported yet"); + } +#else + ET_CHECK_MSG( + false, + "Quantized SDPA is not enabled. Check ENABLE_CUSTOM_QUANTIZED_SDPA compile flag"); +#endif } else { ::executorch::cpublas::gemm( ::executorch::cpublas::TransposeType::Transpose, @@ -99,7 +133,35 @@ void _qk_at_v_gemm( const int64_t o_stride_m, const accum_t beta) { if (v_data.dtype == ScalarType::Char) { - ET_CHECK_MSG(false, "int8 not supported yet"); +#if defined(ENABLE_CUSTOM_QUANTIZED_SDPA) + if constexpr (std::is_same::value) { + int a_stride_m_tmp, b_stride_n_tmp; + auto kernel = torchao::kernels::cpu::quantized_matmul:: + get_fp32_a_input_channelwise_8bit_b_f32_c_matmul( + m, n, k, false, false, a_stride_m_tmp, b_stride_n_tmp); + kernel( + m, + n, + k, + qk_data, + qk_stride_m /*lhs_stride_m*/, + static_cast(v_data.data), + v_stride_n /*rhs_stride_n*/, + o_data, + o_stride_m /*out_stride_n*/, + static_cast(v_data.zero_points), + static_cast(v_data.scales), + beta, + 1); + } else { + ET_CHECK_MSG( + false, "Accumulation in dtype other than float not supported yet"); + } +#else + ET_CHECK_MSG( + false, + "Quantized SDPA is not enabled. 
Check ENABLE_CUSTOM_QUANTIZED_SDPA compile flag"); +#endif } else { ::executorch::cpublas::gemm( ::executorch::cpublas::TransposeType::NoTranspose, @@ -394,7 +456,10 @@ void cpu_flash_attention( kvSize); } - bool is_quantized_sdpa = query.scalar_type() == ScalarType::Char; + bool is_quantized_sdpa = false; +#if defined(ENABLE_CUSTOM_QUANTIZED_SDPA) + is_quantized_sdpa = query.scalar_type() == ScalarType::Char; +#endif auto strides = query.strides(); int64_t qStrideB = strides[0]; @@ -426,6 +491,33 @@ void cpu_flash_attention( vStrideN = strides[1]; } + int64_t q_quant_params_StrideB = 0; + int64_t q_quant_params_StrideH = 0; + int64_t q_quant_params_StrideM = 0; + int64_t k_quant_params_StrideB = 0; + int64_t k_quant_params_StrideH = 0; + int64_t k_quant_params_StrideN = 0; + int64_t v_quant_params_StrideB = 0; + int64_t v_quant_params_StrideH = 0; + int64_t v_quant_params_StrideN = 0; + + if (is_quantized_sdpa) { + strides = q_zero_points.value().strides(); + q_quant_params_StrideB = strides[0]; + q_quant_params_StrideH = strides[1]; + q_quant_params_StrideM = strides[2]; + + strides = k_zero_points.value().strides(); + k_quant_params_StrideB = strides[0]; + k_quant_params_StrideH = strides[1]; + k_quant_params_StrideN = strides[2]; + + strides = v_zero_points.value().strides(); + v_quant_params_StrideB = strides[0]; + v_quant_params_StrideH = strides[1]; + v_quant_params_StrideN = strides[2]; + } + strides = output.strides(); int64_t oStrideB = strides[0]; int64_t oStrideH = strides[1]; @@ -473,7 +565,11 @@ void cpu_flash_attention( /* qk_sum */ qSplitSize + /* dst */ qSplitSize * headSize; - int64_t size_bytes = size_per_thread * num_thread * query.element_size(); + // Since all intermediate compute is accum_t, we need to + // allocate a buffer accordingly. 
+ int64_t size_of_intermediate_precision = sizeof(accum_t); + int64_t size_bytes = size_per_thread * num_thread * query.element_size() * + size_of_intermediate_precision; std::vector buf_vec(size_bytes); void* buf = reinterpret_cast(buf_vec.data()); // Need to double check the following @@ -559,14 +655,18 @@ void cpu_flash_attention( int64_t q_offset = i * qStrideB + j * qStrideH + m * qStrideM; int64_t k_offset = i * kStrideB + j_kv * kStrideH + n * kStrideN; if (is_quantized_sdpa) { - ET_CHECK_MSG( - !is_seq_at_dim_1, "For quantized SDPA, seq_len must be at dim 2"); - q_scales_ptr = q_scales.value().const_data_ptr() + q_offset; - k_scales_ptr = k_scales.value().const_data_ptr() + k_offset; - q_zero_points_ptr = - q_zero_points.value().const_data_ptr() + q_offset; - k_zero_points_ptr = - k_zero_points.value().const_data_ptr() + k_offset; + int64_t q_quant_params_offset = i * q_quant_params_StrideB + + j * q_quant_params_StrideH + m * q_quant_params_StrideM; + int64_t k_quant_params_offset = i * k_quant_params_StrideB + + j_kv * k_quant_params_StrideH + n * k_quant_params_StrideN; + q_scales_ptr = + q_scales.value().const_data_ptr() + q_quant_params_offset; + k_scales_ptr = + k_scales.value().const_data_ptr() + k_quant_params_offset; + q_zero_points_ptr = q_zero_points.value().const_data_ptr() + + q_quant_params_offset; + k_zero_points_ptr = k_zero_points.value().const_data_ptr() + + k_quant_params_offset; q_sub_matrix_data_ptr = (const int8_t*)(q_data) + q_offset; k_sub_matrix_data_ptr = (const int8_t*)(k_data) + k_offset; } else { @@ -719,11 +819,12 @@ void cpu_flash_attention( const int8_t* v_zero_points_ptr = nullptr; int64_t v_offset = i * vStrideB + j_kv * vStrideH + n * vStrideN; if (is_quantized_sdpa) { - ET_CHECK_MSG( - !is_seq_at_dim_1, "For quantized SDPA, seq_len must be at dim 2"); - v_scales_ptr = v_scales.value().const_data_ptr() + v_offset; - v_zero_points_ptr = - v_zero_points.value().const_data_ptr() + v_offset; + int64_t v_quant_params_offset = i * v_quant_params_StrideB + + j_kv * v_quant_params_StrideH + n * v_quant_params_StrideN; + v_scales_ptr = + v_scales.value().const_data_ptr() + v_quant_params_offset; + v_zero_points_ptr = v_zero_points.value().const_data_ptr() + + v_quant_params_offset; v_sub_matrix_data_ptr = (const int8_t*)(v_data) + v_offset; } else { v_sub_matrix_data_ptr = (const scalar_t*)(v_data) + v_offset; diff --git a/extension/llm/custom_ops/targets.bzl b/extension/llm/custom_ops/targets.bzl index 5b68715e401..545f6516bb7 100644 --- a/extension/llm/custom_ops/targets.bzl +++ b/extension/llm/custom_ops/targets.bzl @@ -9,6 +9,18 @@ load( "get_compiler_optimization_flags", ) +def _get_quantized_sdpa_deps(): + if runtime.is_oss: + return [] + else: + return ["//pytorch/ao/torchao/experimental/kernels/cpu/interface:interface"] + +def _get_quantized_preproc_flags(): + if runtime.is_oss: + return [] + else: + return ["-DENABLE_CUSTOM_QUANTIZED_SDPA"] + def define_common_targets(): """Defines targets that should be shared between fbcode and xplat. 
@@ -33,7 +45,8 @@ def define_common_targets(): headers = [ "op_sdpa_impl.h", ], - preprocessor_flags = get_vec_preprocessor_flags(), + exported_preprocessor_flags = get_vec_preprocessor_flags() + + _get_quantized_preproc_flags(), exported_deps = [ "//executorch/runtime/kernel:kernel_includes", "//executorch/kernels/portable/cpu:scalar_utils", @@ -45,8 +58,12 @@ def define_common_targets(): deps = [ "//executorch/kernels/portable/cpu/util:reduce_util", "//executorch/extension/llm/custom_ops/spinquant:fast_hadamard_transform", - ] + get_vec_deps(), - compiler_flags = ["-Wno-missing-prototypes", "-Wno-global-constructors"] + get_compiler_optimization_flags(), + ] + get_vec_deps() + _get_quantized_sdpa_deps(), + compiler_flags = ["-Wno-missing-prototypes", "-Wno-global-constructors"] + get_compiler_optimization_flags() + + select({ + "DEFAULT": [], + "ovr_config//cpu:arm64": ["-march=armv8.2-a+dotprod"], + }), visibility = [ "//executorch/...", "//executorch/extension/llm/custom_ops/...", diff --git a/extension/llm/custom_ops/test_quantized_sdpa.py b/extension/llm/custom_ops/test_quantized_sdpa.py new file mode 100644 index 00000000000..f5540a4e614 --- /dev/null +++ b/extension/llm/custom_ops/test_quantized_sdpa.py @@ -0,0 +1,470 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import unittest + +import torch +import torch.nn.functional as F + +from .custom_ops import custom_ops_lib # noqa + + +class SDPATestForCustomQuantizedSDPA(unittest.TestCase): + """ + This test is to test the custom quantized SDPA op + Tensors are in [B, H, S, D] format + """ + + def setUp(self): + from torch.ao.quantization.fx._decomposed import ( # noqa: F401 + quantized_decomposed_lib, + ) + + torch.manual_seed(42) + self.n_batch = 1 + self.n_heads_kv = 32 + self.n_heads_q = 32 + self.head_dim = 128 + self.max_seq_len = 2048 + self.quantized_dtype = torch.int8 + self.float_dtype = torch.float32 + self.q_shape = None + self.kv_shape = None + + def _scale_tensor(self, tensor, min_value, max_value, scale=True): + normalized_tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min()) + + scaled_tensor = normalized_tensor * (max_value - min_value) + min_value + + return scaled_tensor if scale else tensor + + def setup_caches_and_mask(self, tensor_scale_max, tensor_scale_min, scale_tensors): + self.mask = torch.full( + (self.max_seq_len, self.max_seq_len), + float("-inf"), + ) + self.mask = torch.triu(self.mask, diagonal=1) + + self.k = self._scale_tensor( + torch.rand(self.kv_shape), + tensor_scale_max, + tensor_scale_min, + scale_tensors, + ) + self.v = self._scale_tensor( + torch.rand(self.kv_shape), + tensor_scale_max, + tensor_scale_min, + scale_tensors, + ) + + def _sdpa_ref( + self, + q_quantized, + k_quantized, + v_quantized, + start_pos, + q_zero_point, + q_scale, + k_zero_point, + k_scale, + v_zero_point, + v_scale, + attn_mask, + ): + q = torch.ops.quantized_decomposed.dequantize_per_token( + q_quantized, + q_scale, + q_zero_point, + torch.iinfo(self.quantized_dtype).min, + torch.iinfo(self.quantized_dtype).max, + self.quantized_dtype, + self.float_dtype, + ) + k = torch.ops.quantized_decomposed.dequantize_per_token( + k_quantized, + k_scale, + k_zero_point, + torch.iinfo(self.quantized_dtype).min, + torch.iinfo(self.quantized_dtype).max, + self.quantized_dtype, + self.float_dtype, + ) + v = 
torch.ops.quantized_decomposed.dequantize_per_token( + v_quantized, + v_scale, + v_zero_point, + torch.iinfo(self.quantized_dtype).min, + torch.iinfo(self.quantized_dtype).max, + self.quantized_dtype, + self.float_dtype, + ) + + num_heads_q = q.size(1) + num_heads_kv = k.size(1) + seq_len = q.size(2) + k = torch.narrow(k, 2, 0, start_pos + seq_len) + v = torch.narrow(v, 2, 0, start_pos + seq_len) + if num_heads_q != num_heads_kv: + assert ( + num_heads_q % num_heads_kv == 0 + ), f"{num_heads_q} not divisible by {num_heads_kv}" + n_reps = num_heads_q // num_heads_kv + if n_reps > 1: + k = k.repeat_interleave(n_reps, dim=1) + v = v.repeat_interleave(n_reps, dim=1) + out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + return out + + def _int_matmul( + self, quantized_q, quantized_k, q_zero_points, q_scale, k_zero_point, k_scale + ): + row_sum_q = torch.sum(quantized_q, dim=-1, keepdim=True) + row_sum_k = torch.sum(quantized_k, dim=-1, keepdim=True) + q_at_k = torch.matmul(quantized_q, quantized_k.transpose(-2, -1)) + row_sum_q_scaled = row_sum_q * k_zero_point.squeeze(-1).unsqueeze(0) + row_sum_k_scaled = q_zero_points * row_sum_k.squeeze(-1).unsqueeze(0) + zero_points_product = ( + quantized_q.size(-1) * q_zero_points * k_zero_point.squeeze(-1).unsqueeze(0) + ) + res = q_at_k - row_sum_q_scaled - row_sum_k_scaled + zero_points_product + q_scale_mul_k_scale = q_scale * k_scale.squeeze(-1).unsqueeze(0) + res = res.to(torch.float32) * q_scale_mul_k_scale + return res + + def _quantized_sdpa_ref( + self, + quantized_q, + quantized_k, + quantized_v, + q_zero_points, + q_scale, + k_scale, + k_zero_point, + v_scale, + v_zero_point, + attn_mask, + ): + import math + + quantized_q = quantized_q.to(torch.int32) + quantized_k = quantized_k.to(torch.int32) + quantized_v = quantized_v.to(torch.int32) + batch_size = quantized_q.size(0) + num_heads_q = quantized_q.size(1) + num_heads_kv = quantized_k.size(1) + q_scale = q_scale.to(torch.float32) + k_scale = k_scale.to(torch.float32) + q_zero_points = q_zero_points.to(torch.int32) + k_zero_point = k_zero_point.to(torch.int32) + if num_heads_q != num_heads_kv: + assert ( + num_heads_q % num_heads_kv == 0 + ), f"{num_heads_q} not divisible by {num_heads_kv}" + n_reps = num_heads_q // num_heads_kv + if n_reps > 1: + quantized_k = quantized_k.repeat_interleave(n_reps, dim=1) + quantized_v = quantized_v.repeat_interleave(n_reps, dim=1) + res_b = [] + scale_factor = 1 / math.sqrt(quantized_k.size(-1)) + dequantized_v = torch.ops.quantized_decomposed.dequantize_per_token( + quantized_v, + v_scale, + v_zero_point, + torch.iinfo(torch.int8).min, + torch.iinfo(torch.int8).max, + torch.int8, + torch.float32, + ) + for b in range(batch_size): + res_h = [] + for h in range(num_heads_q): + q_at_k = self._int_matmul( + quantized_q[b][h], + quantized_k[b][h], + q_zero_points[b][h], + q_scale[b][h], + k_zero_point[b][h], + k_scale[b][h], + ) + q_at_k = q_at_k * scale_factor + q_at_k += attn_mask + attn_weight = torch.softmax(q_at_k, dim=-1) + y = torch.matmul(attn_weight, dequantized_v[b][h]) + res_h.append(y) + res = torch.stack(res_h, dim=0) + res_b.append(res.unsqueeze(0)) + res = torch.cat(res_b, dim=0) + return res + + def _test_sdpa_common( + self, + n_heads_kv, + n_heads_q, + head_dim, + max_seq_len, + start_pos, + seq_len, + scale_tensors=False, + atol=1e-5, + is_seq_at_dim_2=True, + ): + # Range arbitrarily chosen to reproduce a numerical error on x86 in some of the long context tests + tensor_scale_max = 15 + tensor_scale_min = -15 + 
self.n_heads_kv = n_heads_kv + self.n_heads_q = n_heads_q + self.head_dim = head_dim + self.max_seq_len = max_seq_len + seq_dim = 2 + self.q_shape = (self.n_batch, self.n_heads_q, seq_len, self.head_dim) + self.kv_shape = (self.n_batch, self.n_heads_q, self.max_seq_len, self.head_dim) + if not is_seq_at_dim_2: + seq_dim = 1 + self.q_shape = (self.n_batch, seq_len, self.n_heads_q, self.head_dim) + self.kv_shape = ( + self.n_batch, + self.max_seq_len, + self.n_heads_kv, + self.head_dim, + ) + + q = self._scale_tensor( + torch.rand(self.q_shape), + tensor_scale_max, + tensor_scale_min, + scale_tensors, + ) + self.setup_caches_and_mask(tensor_scale_max, tensor_scale_min, scale_tensors) + k = self.k + v = self.v + + quantized_dtype = self.quantized_dtype + q_scale, q_zero_point = ( + torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default( + q, quantized_dtype + ) + ) + k_scale, k_zero_point = ( + torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default( + k, quantized_dtype + ) + ) + v_scale, v_zero_point = ( + torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default( + v, quantized_dtype + ) + ) + + q_quantized = torch.ops.quantized_decomposed.quantize_per_token( + q, + q_scale, + q_zero_point, + torch.iinfo(quantized_dtype).min, + torch.iinfo(quantized_dtype).max, + quantized_dtype, + ) + k_quantized = torch.ops.quantized_decomposed.quantize_per_token( + k, + k_scale, + k_zero_point, + torch.iinfo(quantized_dtype).min, + torch.iinfo(quantized_dtype).max, + quantized_dtype, + ) + v_quantized = torch.ops.quantized_decomposed.quantize_per_token( + v, + v_scale, + v_zero_point, + torch.iinfo(quantized_dtype).min, + torch.iinfo(quantized_dtype).max, + quantized_dtype, + ) + + start_pos = 0 + seq_len = q.size(seq_dim) + attn_mask = self.mask[start_pos : start_pos + seq_len, :] + attn_mask = attn_mask[:, : start_pos + seq_len] + + # quantized_sdpa_ref_output = self._quantized_sdpa_ref(q_quantized, k_quantized, v_quantized, q_zero_point, q_scale, k_scale, k_zero_point, v_scale, v_zero_point, attn_mask) + + from torch.nn.attention import SDPBackend + + with torch.nn.attention.sdpa_kernel( + [SDPBackend.FLASH_ATTENTION] + ), torch.no_grad(): + ref_output = self._sdpa_ref( + q_quantized, + k_quantized, + v_quantized, + start_pos, + q_zero_point, + q_scale, + k_zero_point, + k_scale, + v_zero_point, + v_scale, + attn_mask, + ) + + q_zero_point_int8 = q_zero_point.to(dtype=torch.int8) + k_zero_point_int8 = k_zero_point.to(dtype=torch.int8) + v_zero_point_int8 = v_zero_point.to(dtype=torch.int8) + q_scale_fp32 = q_scale.to(dtype=torch.float32) + k_scale_fp32 = k_scale.to(dtype=torch.float32) + v_scale_fp32 = v_scale.to(dtype=torch.float32) + + op_output = torch.ops.llama.custom_quantized_sdpa( + q_quantized, + k_quantized, + v_quantized, + start_pos, + None, + 0, + True, + None, + q_zero_point_int8, + q_scale_fp32, + k_zero_point_int8, + k_scale_fp32, + v_zero_point_int8, + v_scale_fp32, + ) + self.assertTrue(torch.allclose(ref_output, op_output, atol=atol)) + # Following line crashes due to some weird issues in mkldnn with crash in mkl_sgemm with `wild jump` + # self.assertTrue(torch.allclose(ref_output, quantized_sdpa_ref_output, atol=1e-3)) + + start_pos = seq_len + seq_len = q.size(seq_dim) + attn_mask = self.mask[start_pos : start_pos + seq_len, :] + attn_mask = attn_mask[:, : start_pos + seq_len] + with torch.nn.attention.sdpa_kernel( + [SDPBackend.FLASH_ATTENTION] + ), torch.no_grad(): + ref_output = self._sdpa_ref( + q_quantized, + 
k_quantized, + v_quantized, + start_pos, + q_zero_point, + q_scale, + k_zero_point, + k_scale, + v_zero_point, + v_scale, + attn_mask, + ) + op_output = torch.ops.llama.custom_quantized_sdpa( + q_quantized, + k_quantized, + v_quantized, + start_pos, + None, + 0, + True, + None, + q_zero_point_int8, + q_scale_fp32, + k_zero_point_int8, + k_scale_fp32, + v_zero_point_int8, + v_scale_fp32, + ) + self.assertTrue(torch.allclose(ref_output, op_output, atol=atol)) + + def test_sdpa_with_custom_quantized(self): + n_heads_kv = 8 + n_heads_q = 8 + head_dim = 128 + max_seq_len = 2048 + seq_len = 24 + start_pos = 0 + self._test_sdpa_common( + n_heads_kv, + n_heads_q, + head_dim, + max_seq_len, + start_pos, + seq_len, + True, + atol=1e-4, + ) + + def test_sdpa_with_custom_quantized_seq_len_1(self): + n_heads_kv = 4 + n_heads_q = 4 + head_dim = 4 + max_seq_len = 8 + seq_len = 1 + start_pos = 0 + self._test_sdpa_common( + n_heads_kv, n_heads_q, head_dim, max_seq_len, start_pos, seq_len + ) + + def test_sdpa_with_custom_quantized_seq_len_small(self): + n_heads_kv = 4 + n_heads_q = 4 + head_dim = 4 + max_seq_len = 8 + seq_len = 4 + start_pos = 0 + self._test_sdpa_common( + n_heads_kv, n_heads_q, head_dim, max_seq_len, start_pos, seq_len + ) + + def test_sdpa_with_custom_quantized_seq_len_llava_example(self): + n_heads_kv = 32 + n_heads_q = 32 + head_dim = 128 + max_seq_len = 2048 + seq_len = 634 + start_pos = 0 + self._test_sdpa_common( + n_heads_kv, n_heads_q, head_dim, max_seq_len, start_pos, seq_len + ) + + def test_sdpa_with_custom_quantized_seq_len_130_gqa(self): + n_heads_kv = 8 + n_heads_q = 32 + head_dim = 128 + max_seq_len = 2048 + seq_len = 130 + start_pos = 0 + # For some reason when scaling tensors, the test fails with smaller atol + self._test_sdpa_common( + n_heads_kv, + n_heads_q, + head_dim, + max_seq_len, + start_pos, + seq_len, + True, + atol=1e-3, + ) + + def test_sdpa_with_custom_quantized_seq_len_llava_example_gqa(self): + n_heads_kv = 16 + n_heads_q = 32 + head_dim = 128 + max_seq_len = 2048 + seq_len = 634 + start_pos = 0 + self._test_sdpa_common( + n_heads_kv, n_heads_q, head_dim, max_seq_len, start_pos, seq_len + ) + + def test_sdpa_with_cache_mqa(self): + n_heads_kv = 1 + n_heads_q = 8 + head_dim = 128 + max_seq_len = 2048 + seq_len = 24 + start_pos = 0 + self._test_sdpa_common( + n_heads_kv, n_heads_q, head_dim, max_seq_len, start_pos, seq_len + ) diff --git a/extension/llm/custom_ops/test_sdpa_with_kv_cache.py b/extension/llm/custom_ops/test_sdpa_with_kv_cache.py index a1f054a153e..41497b17a66 100644 --- a/extension/llm/custom_ops/test_sdpa_with_kv_cache.py +++ b/extension/llm/custom_ops/test_sdpa_with_kv_cache.py @@ -489,11 +489,11 @@ def _test_sdpa_common( class SDPATestForLargeSeqLength(SDPATestCommon): def test_sdpa_with_cache_seq_len_130(self): - n_heads_kv = 32 - n_heads_q = 32 + n_heads_kv = 8 + n_heads_q = 8 head_dim = 128 max_seq_len = 2048 - seq_len = 130 + seq_len = 24 self._test_sdpa_common( n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len, True )
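Usage sketch (illustrative, not part of the diff): once the custom-ops AOT library built with ENABLE_CUSTOM_QUANTIZED_SDPA is loaded (the new test does this via its `from .custom_ops import custom_ops_lib  # noqa` import), the registered `llama.custom_quantized_sdpa` op can be called from Python with int8 Q/K/V in [B, H, S, D] layout plus per-token scales and zero points, mirroring the call pattern in test_quantized_sdpa.py. The shapes, values, and variable names below are assumptions chosen for demonstration only, not a prescribed API beyond the schema registered in op_sdpa_aot.cpp.

    import torch

    # Assumes the custom-ops AOT library has already been loaded, as in the test;
    # without ENABLE_CUSTOM_QUANTIZED_SDPA the llama::custom_quantized_sdpa schema
    # is not registered and this call will fail.

    B, H, S, D = 1, 8, 24, 128  # illustrative sizes

    # int8 Q/K/V plus per-token quantization parameters, shaped the way
    # torch.ops.quantized_decomposed.quantize_per_token produces them in the test.
    q_int8 = torch.randint(-128, 127, (B, H, S, D), dtype=torch.int8)
    k_int8 = torch.randint(-128, 127, (B, H, S, D), dtype=torch.int8)
    v_int8 = torch.randint(-128, 127, (B, H, S, D), dtype=torch.int8)
    q_scales, k_scales, v_scales = (
        torch.rand(B, H, S, 1, dtype=torch.float32) for _ in range(3)
    )
    q_zp, k_zp, v_zp = (torch.zeros(B, H, S, 1, dtype=torch.int8) for _ in range(3))

    out = torch.ops.llama.custom_quantized_sdpa(
        q_int8, k_int8, v_int8,
        0,      # start_pos
        None,   # attn_mask (schema default)
        0.0,    # dropout_p
        True,   # is_causal
        None,   # scale (None falls back to the usual 1/sqrt(head_dim))
        q_zp, q_scales,
        k_zp, k_scales,
        v_zp, v_scales,
    )
    # `out` is a float32 tensor of shape (B, H, S, D); the quantization parameters
    # here are random, so the result is structurally valid but not numerically meaningful.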