Skip to content

Commit 89c6a49

Browse files
committed
rebase to latest
1 parent 991ee97 commit 89c6a49

File tree

2 files changed

+40
-10
lines changed

2 files changed

+40
-10
lines changed

src/ATen/native/transformers/xpu/flash_attn/sycltla/kernel/xe_sdpa_fwd_bshd.h

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ class FMHAPrefill {
164164
Arguments const& args,
165165
void* workspace) {
166166
(void)workspace;
167+
167168
return {
168169
args.mode,
169170
args.problem_shape,
@@ -438,6 +439,29 @@ class FMHAPrefill {
438439
prefetch(tiled_prefetch_v, pVgV(_, i, _, nblock));
439440
}
440441
442+
// Prevent numerical errors when seq_len_kv is not fully divisible by
443+
// QK_BLK_N
444+
const int item_id = thread_idx % SubgroupSize;
445+
if (seq_len_kv % QK_BLK_N != 0) {
446+
int col_idx = item_id + nblock * QK_BLK_N;
447+
int remainder = seq_len_kv % QK_BLK_N;
448+
int cutoff = (seq_len_kv / QK_BLK_N) * QK_BLK_N + remainder;
449+
450+
CUTLASS_PRAGMA_UNROLL
451+
for (int n = 0; n < FragsN; n++, col_idx += get<1>(MmaAtomShape())) {
452+
CUTLASS_PRAGMA_UNROLL
453+
for (int m = 0; m < FragsM; m++) {
454+
int row_idx = m * Vec + seq_coord;
455+
CUTLASS_PRAGMA_UNROLL
456+
for (int row = 0; row < Vec; row++, row_idx++) {
457+
if (col_idx >= cutoff) {
458+
tSr(row, m, n) = ElementAccumulator{-INFINITY};
459+
}
460+
}
461+
}
462+
}
463+
}
464+
441465
CollectiveSoftmaxEpilogue softmax(params.softmax);
442466
softmax(nblock == 0, tSr, max_reg, sum_reg, out_reg);
443467
@@ -479,6 +503,8 @@ class FMHAPrefill {
479503
// mask the elements of each tile using the bottom right masking
480504
const int item_id = thread_idx % SubgroupSize;
481505
int col_idx = item_id + (nblock_limit - 1) * QK_BLK_N;
506+
int remainder = seq_len_kv % QK_BLK_N;
507+
int cutoff = (seq_len_kv / QK_BLK_N) * QK_BLK_N + remainder;
482508
CUTLASS_PRAGMA_UNROLL
483509
for (int n = 0; n < FragsN;
484510
n++, col_idx += get<1>(MmaAtomShape())) { // 4
@@ -487,8 +513,12 @@ class FMHAPrefill {
487513
int row_idx = m * Vec + seq_coord;
488514
CUTLASS_PRAGMA_UNROLL
489515
for (int row = 0; row < Vec; row++, row_idx++) { // 8
490-
if (row_idx < first_non_masked_sequence ||
491-
col_idx > row_idx - first_non_masked_sequence) {
516+
if (row_idx < first_non_masked_sequence || // for the sequence
517+
// that is fully masked
518+
col_idx > row_idx -
519+
first_non_masked_sequence || // for the bottom right
520+
// triangular masking
521+
col_idx >= cutoff) { // for seq_len_kv not fully divisible
492522
tSr(row, m, n) = ElementAccumulator{-INFINITY};
493523
}
494524
}

src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ void compute_o_dot_do(
2222
const int bidh) {
2323
// The thread index.
2424
constexpr int kBlockM = T::kBlockM;
25-
// constexpr int kBlockN = T::kBlockN;
25+
constexpr int kBlockN = T::kBlockN;
2626
constexpr int kHeadDim = T::kHeadDim;
2727
constexpr int kNSGs = T::kNSGs;
2828
constexpr int SubgroupSize = T::SubgroupSize;
@@ -31,8 +31,8 @@ void compute_o_dot_do(
3131

3232
auto sg = compat::get_nd_item<1>().get_sub_group();
3333
auto group = compat::get_nd_item<1>().get_group();
34-
// auto first_thread_in_sg_idx = sg.get_group_linear_id() *
35-
// trait.SubgroupSize;
34+
auto first_thread_in_sg_idx = sg.get_group_linear_id() * trait.SubgroupSize;
35+
3636
auto bofst = Boffset(param);
3737

3838
const index_t o_offset = bofst.o_offset(bidb, bidh, m_block * kBlockM);
@@ -209,7 +209,7 @@ CUTLASS_DEVICE void apply_mask_causal(
209209
auto sg = compat::get_nd_item<1>().get_sub_group();
210210
auto group = compat::get_nd_item<1>().get_group();
211211
int sg_local_id = sg.get_local_id();
212-
// int sg_group_id = sg.get_group_id();
212+
int sg_group_id = sg.get_group_id();
213213
Tensor rC_2d = make_tensor(rC.data(), convert_layout_2d_layout(rC.layout()));
214214
CUTLASS_PRAGMA_UNROLL
215215
for (int n = 0; n < size<1>(tensor); ++n) {
@@ -371,8 +371,8 @@ void dq_dk_dv_1colblock(
371371
constexpr int kBlockM = Trait::kBlockM;
372372
constexpr int kBlockN = Trait::kBlockN;
373373
constexpr bool is_causal = Trait::is_causal;
374-
// constexpr int kNSGs = Trait::kNSGs;
375-
// constexpr int SubgroupSize = Trait::SubgroupSize;
374+
constexpr int kNSGs = Trait::kNSGs;
375+
constexpr int SubgroupSize = Trait::SubgroupSize;
376376
auto sg = compat::get_nd_item<1>().get_sub_group();
377377
auto group = compat::get_nd_item<1>().get_group();
378378
auto first_thread_in_sg_idx = sg.get_group_linear_id() * trait.SubgroupSize;
@@ -675,7 +675,7 @@ void dq_dk_dv_1colblock(
675675
const int max_m_block = ceil_div(param.seq_len_q, kBlockM);
676676
const int tail_m = param.seq_len_q % kBlockM;
677677

678-
// cutlass::NumericConverter<T, float> converter;
678+
cutlass::NumericConverter<T, float> converter;
679679

680680
// clear accumulator
681681
clear(tdVrdV);
@@ -880,7 +880,7 @@ void convert_dq(
880880
int bidb,
881881
int bidh) {
882882
constexpr int kBlockM = T::kBlockM;
883-
// constexpr int kBlockN = T::kBlockN;
883+
constexpr int kBlockN = T::kBlockN;
884884
constexpr int kHeadDim = T::kHeadDim;
885885
using DType = typename T::DType;
886886
using VType = typename T::VType;

0 commit comments

Comments
 (0)