704704#include <ATen/core/Tensor.h>
705705#include <ATen/cpu/vec/functional.h>
706706#include <ATen/cpu/vec/vec.h>
707- #include <ATen/cpu/vec/vec_quant.h>
708707#include <ATen/cpu/Utils.h>
709708#include <ATen/native/cpu/utils.h>
710709#include <ATen/native/CPUBlas.h>
899898 for (int64_t b = 0; b < kvBlockSize; b += block_64) {
900899 bool istail = kvBlockSize - b < block_64;
901900 int64_t trans_rows = istail ? kvBlockSize - b : block_64;
902- at::native::utils::transpose<uint8_t>(
903- headSize,
904- trans_rows,
901+ do_transpose(
905902 k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN,
906- kStrideN,
907903 B_blocked_xform_u8,
904+ trans_rows,
905+ headSize,
906+ kStrideN,
908907 block_64);
909908 if (!headSize_mul64 || istail) {
910909 pad_remain_row_col(
916915 block_64
917916 );
918917 }
919- at::vec::pack_vnni4(
920- /* src */ B_blocked_xform_u8,
921- /* dst */ key_reorder_ptr + n * qk_gemm_K +
922- b * qk_gemm_K,
923- /* ld_src */ block_64,
924- /* K */ qk_gemm_K,
925- /* N */ block_64);
918+ at::native::cpublas::pack(
919+ qk_gemm_K, // K
920+ block_64, // N
921+ block_64, // ld_in
922+ block_64, // ld_out
923+ u8_dt, // dt_in
924+ u8_dt, // dt_out
925+ B_blocked_xform_u8,
926+ key_reorder_ptr + n * qk_gemm_K +
927+ b * qk_gemm_K);
926928 }
927929 // split headSize to block_64, block_64, block_64 ...
928930 // [av_gemm_K, headSize] -> [av_gemm_K, block_64 ...]
929931 for (int64_t b = 0; b < rndHeadSize; b += block_64) {
930- at::vec::pack_vnni4(
931- /* src */ v_data + i * vStrideB + j * vStrideH + n * vStrideN + b,
932- /* dst */ value_reorder_ptr + n * rndHeadSize +
933- av_gemm_K * b,
934- /* ld_src */ vStrideN,
935- /* K */ av_gemm_K,
936- /* N */ block_64);
932+ at::native::cpublas::pack(
933+ av_gemm_K,
934+ block_64,
935+ vStrideN, // block_64,
936+ block_64,
937+ u8_dt,
938+ u8_dt,
939+ v_data + i * vStrideB + j * vStrideH + n * vStrideN + b,
940+ value_reorder_ptr + n * rndHeadSize +
941+ av_gemm_K * b);
937942 }
938943 }
939944
11661171 int64_t num_thread = {{num_thread}};
11671172 using accum_t = float;
11681173 using scalar_t = {{kernel.dtype(query)}};
1169- int block_64 = 64;
1170- auto u8_dt = at::ScalarType::Byte;
11711174
11721175 // Sizes
11731176 int64_t batchSize = {{kernel.size(query, 0)}};
11981201 int64_t kvSlice = (kvSize - 1) / kvSplitSize + 1;
11991202 int64_t kvTail = (kvSize - 1) % kvSplitSize + 1;
12001203
1201- int64_t rndHeadSize = ( headSize + block_64 - 1L) / block_64 * block_64 ;
1202- int64_t rndkvSplitSize = ( kvSplitSize + block_64 - 1L) / block_64 * block_64 ;
1203- int64_t rndkvTail = ( kvTail + block_64 - 1L) / block_64 * block_64 ;
1204+ int64_t rndHeadSize = headSize % 4 == 0 ? headSize : headSize + 4 - headSize % 4 ;
1205+ int64_t rndkvSplitSize = kvSplitSize % 4 == 0 ? kvSplitSize : kvSplitSize + 4 - kvSplitSize % 4 ;
1206+ int64_t rndkvTail = kvTail % 4 == 0 ? kvTail : kvTail + 4 - kvTail % 4 ;
12041207 int64_t rndkvSize = {{kv_split_size}} > kvSize ? rndkvTail : rndkvSplitSize * kvSlice + rndkvTail;
12051208
1206- bool av_gemm_K_mul4 = kvSplitSize % 4 == 0;
1207- int av_gemm_K_padding = av_gemm_K_mul4 ? 0 : 4 - kvSplitSize % 4;
1208- int av_gemm_K = kvSplitSize + av_gemm_K_padding;
1209-
12101209{%- if has_attention_mask %}
12111210 // attention mask
12121211 using mask_t = {{kernel.dtype(attention_mask)}};
12351234 const scalar_t* v_data = value;
12361235 scalar_t* out_data = output;
12371236
1238- bool headSize_mul64 = headSize % 64 == 0;
1239- int qk_gemm_K_padding = headSize_mul64 ? 0 : 64 - headSize % 64;
1240- int qk_gemm_K = headSize + qk_gemm_K_padding;
1241-
1242- int64_t qk_reduce_strideL = qSplitSize * av_gemm_K;
1243- int64_t v_reorder_strideL = av_gemm_K * rndHeadSize;
1237+ int64_t qk_reduce_strideL = qSplitSize * rndkvSplitSize;
1238+ int64_t v_reorder_strideL = rndkvSplitSize * rndHeadSize;
12441239
12451240 int64_t total_size_uint8_per_thread =
12461241 /* qk */ kvSlice * qSplitSize * rndkvSplitSize * 4 +
1247- /* qk_local */ kvSlice * av_gemm_K * 4 +
1242+ /* qk_local */ kvSlice * rndkvSplitSize * 4 +
12481243 /* qk_reduce */ kvSlice * qk_reduce_strideL +
12491244 /* qk_s32 */ qSplitSize * rndkvSplitSize * 4 +
12501245 /* dst_s32 */ qSplitSize * rndHeadSize * 4 +
12511246 /* softmax_sum */ qSplitSize * 4 +
12521247 /* query_sum */ qSplitSize * 4 +
12531248 /* attention_sum */ qSplitSize * 4 +
12541249 /* softmax max */ qSplitSize * 4 +
1255- /* query_padding_data */ qSplitSize * qk_gemm_K ;
1250+ /* query_padding_data */ qSplitSize * rndHeadSize ;
12561251 {{template.codegen_allocate_buffer("total_buf_data", "scalar_t", "num_thread * total_size_uint8_per_thread")}}
12571252
12581253 int64_t kv_sum_size_per_BH =
12611256 {{template.codegen_allocate_buffer("kv_sum_buf_data", "int32_t", "batchSize * num_head * kv_sum_size_per_BH")}}
12621257
12631258 int64_t kv_reorder_size_per_BH =
1264- /* key_t_reorder */ qk_gemm_K * rndkvSize +
1259+ /* key_t_reorder */ rndHeadSize * rndkvSize +
12651260 /* value_t_reorder */ kvSlice * v_reorder_strideL;
12661261 {{template.codegen_allocate_buffer("kv_reorder_buf_data", "scalar_t", "batchSize * num_head * kv_reorder_size_per_BH")}}
12671262 scalar_t* key_reorder_ptr = kv_reorder_buf_data;
1268- scalar_t* value_reorder_ptr = kv_reorder_buf_data + batchSize * num_head * qk_gemm_K * rndkvSize;
1263+ scalar_t* value_reorder_ptr = kv_reorder_buf_data + batchSize * num_head * rndHeadSize * rndkvSize;
12691264
12701265 // sum k and v
12711266 at::parallel_for(
13051300 int64_t i = 0, j = 0, l = 0, n = 0;
13061301 at::native::data_index_init(
13071302 begin, i, batchSize, j, num_head, l, kvSlice);
1308- uint8_t* B_blocked_xform_u8 = new uint8_t[qk_gemm_K * kvSplitSize];
1303+ uint8_t* B_blocked_xform_u8 = new uint8_t[rndHeadSize * kvSplitSize];
13091304 for (const auto z : c10::irange(begin, end)) {
13101305 (void)z; // Suppress unused variable
13111306 n = l * kvSplitSize;
1312- auto k_reorder = key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize +
1313- j * qk_gemm_K * rndkvSize + n * qk_gemm_K ;
1307+ auto k_reorder = key_reorder_ptr + i * num_head * rndHeadSize * rndkvSize +
1308+ j * rndHeadSize * rndkvSize + n * rndHeadSize ;
13141309 auto v_reorder = value_reorder_ptr +
13151310 i * num_head * kvSlice * v_reorder_strideL +
13161311 j * kvSlice * v_reorder_strideL + n * rndHeadSize;
13261321 /* src */ B_blocked_xform_u8,
13271322 /* dst */ k_reorder,
13281323 /* ld_src */ kvBlockSize,
1329- /* K */ qk_gemm_K ,
1324+ /* K */ rndHeadSize ,
13301325 /* N */ kvBlockSize);
13311326 at::vec::pack_vnni4(
13321327 /* src */ v_data + i * vStrideB + j * vStrideH + n * vStrideN,
13331328 /* dst */ v_reorder,
13341329 /* ld_src */ vStrideN,
1335- /* K */ av_gemm_K ,
1330+ /* K */ rndkvSplitSize ,
13361331 /* N */ rndHeadSize);
13371332 // Move to the next query
13381333 at::native::data_index_step(i, batchSize, j, num_head, l, kvSlice);
13501345 accum_t* qk_data = reinterpret_cast<accum_t*>(total_buf_ptr);
13511346 offset += kvSlice * qSplitSize * rndkvSplitSize * 4;
13521347 accum_t* qk_local_data = reinterpret_cast<accum_t*>(total_buf_ptr + offset);
1353- offset += kvSlice * av_gemm_K * 4;
1348+ offset += kvSlice * rndkvSplitSize * 4;
13541349 scalar_t* qk_reduced_data = reinterpret_cast<scalar_t*>(total_buf_ptr + offset);
13551350 offset += kvSlice * qk_reduce_strideL;
13561351 int32_t* qk_s32_data = reinterpret_cast<int32_t*>(total_buf_ptr + offset);
14011396 for (int64_t l = 0; l < rkvSlice; l++) {
14021397 int64_t n = l * kvSplitSize;
14031398 int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
1404- auto k_reorder = key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize +
1405- j * qk_gemm_K * rndkvSize + n * qk_gemm_K ;
1399+ auto k_reorder = key_reorder_ptr + i * num_head * rndHeadSize * rndkvSize +
1400+ j * rndHeadSize * rndkvSize + n * rndHeadSize ;
14061401 // Calculate q @ k.T
14071402 at::native::cpublas::brgemm(
14081403 qSplitSize, kvBlockSize, headSize,
14621457 qk_reduce_strideL, //ldo
14631458 kvSize, //kvSize
14641459 rndkvSplitSize, //rndkvSplitSize
1465- av_gemm_K , //av_gemm_K
1460+ rndkvSplitSize , //av_gemm_K
14661461 {{a_zp}}, // zp_a=beta1
14671462 {{a_scale}}, // scale_a=alpha
14681463 qk_local_data, //local
14801475 qk_reduce_strideL, //ldo
14811476 kvSize, //kvSize
14821477 rndkvSplitSize, //rndkvSplitSize
1483- av_gemm_K , //av_gemm_K
1478+ rndkvSplitSize , //av_gemm_K
14841479 {{a_zp}}, // zp_a=beta1
14851480 {{v_zp}}, // zp_b=beta2
14861481 {{a_scale}}, // scale_a=alpha
14971492 j * kvSlice * v_reorder_strideL;
14981493 for (int64_t s = 0; s < kvSlice; s++) {
14991494 at::native::cpublas::brgemm(
1500- qSplitSize, headSize, av_gemm_K ,
1501- av_gemm_K , // lda
1495+ qSplitSize, headSize, rndkvSplitSize ,
1496+ rndkvSplitSize , // lda
15021497 rndHeadSize, //ldb
15031498 rndHeadSize, //ldc
15041499 s != 0,
@@ -1675,7 +1670,6 @@ def get_options(
16751670 q_split_size = 256
16761671 elif qSize >= 192 :
16771672 q_split_size = 128
1678- kv_split_size = 512
16791673
16801674 qSplitSize = min (qSize , q_split_size )
16811675 l2_cache_size = torch ._C ._cpu ._L2_cache_size ()
@@ -1687,8 +1681,9 @@ def get_options(
16871681 ):
16881682 # if not symbolic shape
16891683 use_one_parallel_loop = (batchSize * num_head > num_threads ) and (
1690- attn_size > 1.5 * l2_cache_size
1684+ attn_size > 3 * l2_cache_size
16911685 )
1686+ kv_split_size = 64 if use_one_parallel_loop else 512
16921687
16931688 options = dict (
16941689 q_split_size = q_split_size ,
0 commit comments