Skip to content

Commit 28fedb2

Browse files
committed
update template
1 parent 159c5d2 commit 28fedb2

File tree

1 file changed: +47 additions, −52 deletions

torchao/prototype/inductor/codegen/cpp_int8_sdpa_template.py

Lines changed: 47 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -704,7 +704,6 @@
704704
#include <ATen/core/Tensor.h>
705705
#include <ATen/cpu/vec/functional.h>
706706
#include <ATen/cpu/vec/vec.h>
707-
#include <ATen/cpu/vec/vec_quant.h>
708707
#include <ATen/cpu/Utils.h>
709708
#include <ATen/native/cpu/utils.h>
710709
#include <ATen/native/CPUBlas.h>
@@ -899,12 +898,12 @@
899898
for (int64_t b = 0; b < kvBlockSize; b += block_64) {
900899
bool istail = kvBlockSize - b < block_64;
901900
int64_t trans_rows = istail ? kvBlockSize - b : block_64;
902-
at::native::utils::transpose<uint8_t>(
903-
headSize,
904-
trans_rows,
901+
do_transpose(
905902
k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN,
906-
kStrideN,
907903
B_blocked_xform_u8,
904+
trans_rows,
905+
headSize,
906+
kStrideN,
908907
block_64);
909908
if (!headSize_mul64 || istail) {
910909
pad_remain_row_col(
@@ -916,24 +915,30 @@
916915
block_64
917916
);
918917
}
919-
at::vec::pack_vnni4(
920-
/* src */ B_blocked_xform_u8,
921-
/* dst */ key_reorder_ptr + n * qk_gemm_K +
922-
b * qk_gemm_K,
923-
/* ld_src */ block_64,
924-
/* K */ qk_gemm_K,
925-
/* N */ block_64);
918+
at::native::cpublas::pack(
919+
qk_gemm_K, // K
920+
block_64, // N
921+
block_64, // ld_in
922+
block_64, // ld_out
923+
u8_dt, // dt_in
924+
u8_dt, // dt_out
925+
B_blocked_xform_u8,
926+
key_reorder_ptr + n * qk_gemm_K +
927+
b * qk_gemm_K);
926928
}
927929
// split headSize to block_64, block_64, block_64 ...
928930
// [av_gemm_K, headSize] -> [av_gemm_K, block_64 ...]
929931
for (int64_t b = 0; b < rndHeadSize; b += block_64) {
930-
at::vec::pack_vnni4(
931-
/* src */ v_data + i * vStrideB + j * vStrideH + n * vStrideN + b,
932-
/* dst */ value_reorder_ptr + n * rndHeadSize +
933-
av_gemm_K * b,
934-
/* ld_src */ vStrideN,
935-
/* K */ av_gemm_K,
936-
/* N */ block_64);
932+
at::native::cpublas::pack(
933+
av_gemm_K,
934+
block_64,
935+
vStrideN, // block_64,
936+
block_64,
937+
u8_dt,
938+
u8_dt,
939+
v_data + i * vStrideB + j * vStrideH + n * vStrideN + b,
940+
value_reorder_ptr + n * rndHeadSize +
941+
av_gemm_K * b);
937942
}
938943
}
939944
@@ -1166,8 +1171,6 @@
11661171
int64_t num_thread = {{num_thread}};
11671172
using accum_t = float;
11681173
using scalar_t = {{kernel.dtype(query)}};
1169-
int block_64 = 64;
1170-
auto u8_dt = at::ScalarType::Byte;
11711174
11721175
// Sizes
11731176
int64_t batchSize = {{kernel.size(query, 0)}};
@@ -1198,15 +1201,11 @@
11981201
int64_t kvSlice = (kvSize - 1) / kvSplitSize + 1;
11991202
int64_t kvTail = (kvSize - 1) % kvSplitSize + 1;
12001203
1201-
int64_t rndHeadSize = (headSize + block_64 - 1L) / block_64 * block_64;
1202-
int64_t rndkvSplitSize = (kvSplitSize + block_64 - 1L) / block_64 * block_64;
1203-
int64_t rndkvTail = (kvTail + block_64 - 1L) / block_64 * block_64;
1204+
int64_t rndHeadSize = headSize % 4 == 0 ? headSize : headSize + 4 - headSize % 4;
1205+
int64_t rndkvSplitSize = kvSplitSize % 4 == 0 ? kvSplitSize : kvSplitSize + 4 - kvSplitSize % 4;
1206+
int64_t rndkvTail = kvTail % 4 == 0 ? kvTail : kvTail + 4 - kvTail % 4;
12041207
int64_t rndkvSize = {{kv_split_size}} > kvSize ? rndkvTail : rndkvSplitSize * kvSlice + rndkvTail;
12051208
1206-
bool av_gemm_K_mul4 = kvSplitSize % 4 == 0;
1207-
int av_gemm_K_padding = av_gemm_K_mul4 ? 0 : 4 - kvSplitSize % 4;
1208-
int av_gemm_K = kvSplitSize + av_gemm_K_padding;
1209-
12101209
{%- if has_attention_mask %}
12111210
// attention mask
12121211
using mask_t = {{kernel.dtype(attention_mask)}};
@@ -1235,24 +1234,20 @@
12351234
const scalar_t* v_data = value;
12361235
scalar_t* out_data = output;
12371236
1238-
bool headSize_mul64 = headSize % 64 == 0;
1239-
int qk_gemm_K_padding = headSize_mul64 ? 0 : 64 - headSize % 64;
1240-
int qk_gemm_K = headSize + qk_gemm_K_padding;
1241-
1242-
int64_t qk_reduce_strideL = qSplitSize * av_gemm_K;
1243-
int64_t v_reorder_strideL = av_gemm_K * rndHeadSize;
1237+
int64_t qk_reduce_strideL = qSplitSize * rndkvSplitSize;
1238+
int64_t v_reorder_strideL = rndkvSplitSize * rndHeadSize;
12441239
12451240
int64_t total_size_uint8_per_thread =
12461241
/* qk */ kvSlice * qSplitSize * rndkvSplitSize * 4 +
1247-
/* qk_local */ kvSlice * av_gemm_K * 4 +
1242+
/* qk_local */ kvSlice * rndkvSplitSize * 4 +
12481243
/* qk_reduce */ kvSlice * qk_reduce_strideL +
12491244
/* qk_s32 */ qSplitSize * rndkvSplitSize * 4 +
12501245
/* dst_s32 */ qSplitSize * rndHeadSize * 4 +
12511246
/* softmax_sum */ qSplitSize * 4 +
12521247
/* query_sum */ qSplitSize * 4 +
12531248
/* attention_sum */ qSplitSize * 4 +
12541249
/* softmax max */ qSplitSize * 4 +
1255-
/* query_padding_data */ qSplitSize * qk_gemm_K;
1250+
/* query_padding_data */ qSplitSize * rndHeadSize;
12561251
{{template.codegen_allocate_buffer("total_buf_data", "scalar_t", "num_thread * total_size_uint8_per_thread")}}
12571252
12581253
int64_t kv_sum_size_per_BH =
@@ -1261,11 +1256,11 @@
12611256
{{template.codegen_allocate_buffer("kv_sum_buf_data", "int32_t", "batchSize * num_head * kv_sum_size_per_BH")}}
12621257
12631258
int64_t kv_reorder_size_per_BH =
1264-
/* key_t_reorder */ qk_gemm_K * rndkvSize +
1259+
/* key_t_reorder */ rndHeadSize * rndkvSize +
12651260
/* value_t_reorder */ kvSlice * v_reorder_strideL;
12661261
{{template.codegen_allocate_buffer("kv_reorder_buf_data", "scalar_t", "batchSize * num_head * kv_reorder_size_per_BH")}}
12671262
scalar_t* key_reorder_ptr = kv_reorder_buf_data;
1268-
scalar_t* value_reorder_ptr = kv_reorder_buf_data + batchSize * num_head * qk_gemm_K * rndkvSize;
1263+
scalar_t* value_reorder_ptr = kv_reorder_buf_data + batchSize * num_head * rndHeadSize * rndkvSize;
12691264
12701265
// sum k and v
12711266
at::parallel_for(
@@ -1305,12 +1300,12 @@
13051300
int64_t i = 0, j = 0, l = 0, n = 0;
13061301
at::native::data_index_init(
13071302
begin, i, batchSize, j, num_head, l, kvSlice);
1308-
uint8_t* B_blocked_xform_u8 = new uint8_t[qk_gemm_K * kvSplitSize];
1303+
uint8_t* B_blocked_xform_u8 = new uint8_t[rndHeadSize * kvSplitSize];
13091304
for (const auto z : c10::irange(begin, end)) {
13101305
(void)z; // Suppress unused variable
13111306
n = l * kvSplitSize;
1312-
auto k_reorder = key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize +
1313-
j * qk_gemm_K * rndkvSize + n * qk_gemm_K;
1307+
auto k_reorder = key_reorder_ptr + i * num_head * rndHeadSize * rndkvSize +
1308+
j * rndHeadSize * rndkvSize + n * rndHeadSize;
13141309
auto v_reorder = value_reorder_ptr +
13151310
i * num_head * kvSlice * v_reorder_strideL +
13161311
j * kvSlice * v_reorder_strideL + n * rndHeadSize;
@@ -1326,13 +1321,13 @@
13261321
/* src */ B_blocked_xform_u8,
13271322
/* dst */ k_reorder,
13281323
/* ld_src */ kvBlockSize,
1329-
/* K */ qk_gemm_K,
1324+
/* K */ rndHeadSize,
13301325
/* N */ kvBlockSize);
13311326
at::vec::pack_vnni4(
13321327
/* src */ v_data + i * vStrideB + j * vStrideH + n * vStrideN,
13331328
/* dst */ v_reorder,
13341329
/* ld_src */ vStrideN,
1335-
/* K */ av_gemm_K,
1330+
/* K */ rndkvSplitSize,
13361331
/* N */ rndHeadSize);
13371332
// Move to the next query
13381333
at::native::data_index_step(i, batchSize, j, num_head, l, kvSlice);
@@ -1350,7 +1345,7 @@
13501345
accum_t* qk_data = reinterpret_cast<accum_t*>(total_buf_ptr);
13511346
offset += kvSlice * qSplitSize * rndkvSplitSize * 4;
13521347
accum_t* qk_local_data = reinterpret_cast<accum_t*>(total_buf_ptr + offset);
1353-
offset += kvSlice * av_gemm_K * 4;
1348+
offset += kvSlice * rndkvSplitSize * 4;
13541349
scalar_t* qk_reduced_data = reinterpret_cast<scalar_t*>(total_buf_ptr + offset);
13551350
offset += kvSlice * qk_reduce_strideL;
13561351
int32_t* qk_s32_data = reinterpret_cast<int32_t*>(total_buf_ptr + offset);
@@ -1401,8 +1396,8 @@
14011396
for (int64_t l = 0; l < rkvSlice; l++) {
14021397
int64_t n = l * kvSplitSize;
14031398
int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
1404-
auto k_reorder = key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize +
1405-
j * qk_gemm_K * rndkvSize + n * qk_gemm_K;
1399+
auto k_reorder = key_reorder_ptr + i * num_head * rndHeadSize * rndkvSize +
1400+
j * rndHeadSize * rndkvSize + n * rndHeadSize;
14061401
// Calculate q @ k.T
14071402
at::native::cpublas::brgemm(
14081403
qSplitSize, kvBlockSize, headSize,
@@ -1462,7 +1457,7 @@
14621457
qk_reduce_strideL, //ldo
14631458
kvSize, //kvSize
14641459
rndkvSplitSize, //rndkvSplitSize
1465-
av_gemm_K, //av_gemm_K
1460+
rndkvSplitSize, //av_gemm_K
14661461
{{a_zp}}, // zp_a=beta1
14671462
{{a_scale}}, // scale_a=alpha
14681463
qk_local_data, //local
@@ -1480,7 +1475,7 @@
14801475
qk_reduce_strideL, //ldo
14811476
kvSize, //kvSize
14821477
rndkvSplitSize, //rndkvSplitSize
1483-
av_gemm_K, //av_gemm_K
1478+
rndkvSplitSize, //av_gemm_K
14841479
{{a_zp}}, // zp_a=beta1
14851480
{{v_zp}}, // zp_b=beta2
14861481
{{a_scale}}, // scale_a=alpha
@@ -1497,8 +1492,8 @@
14971492
j * kvSlice * v_reorder_strideL;
14981493
for (int64_t s = 0; s < kvSlice; s++) {
14991494
at::native::cpublas::brgemm(
1500-
qSplitSize, headSize, av_gemm_K,
1501-
av_gemm_K, // lda
1495+
qSplitSize, headSize, rndkvSplitSize,
1496+
rndkvSplitSize, // lda
15021497
rndHeadSize, //ldb
15031498
rndHeadSize, //ldc
15041499
s != 0,
@@ -1675,7 +1670,6 @@ def get_options(
16751670
q_split_size = 256
16761671
elif qSize >= 192:
16771672
q_split_size = 128
1678-
kv_split_size = 512
16791673

16801674
qSplitSize = min(qSize, q_split_size)
16811675
l2_cache_size = torch._C._cpu._L2_cache_size()
@@ -1687,8 +1681,9 @@ def get_options(
16871681
):
16881682
# if not symbolic shape
16891683
use_one_parallel_loop = (batchSize * num_head > num_threads) and (
1690-
attn_size > 1.5 * l2_cache_size
1684+
attn_size > 3 * l2_cache_size
16911685
)
1686+
kv_split_size = 64 if use_one_parallel_loop else 512
16921687

16931688
options = dict(
16941689
q_split_size=q_split_size,

0 commit comments

Comments (0)