
Commit f3e8972

Fix, or rather "port", bug fix for sdpa
Differential Revision: D73640471
Pull Request resolved: #10466
1 parent 7e034ca
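
The guard added in op_sdpa_impl.h documents the failure mode: when every attention score a row has seen so far is masked to -inf, the running row max is still -inf, and the exp(qk - max) step of the online softmax evaluates exp(-inf - (-inf)), which is NaN. A minimal Python illustration of that floating-point behavior (illustrative only, not part of the commit):

```python
import math

neg_inf = float("-inf")
print(neg_inf - neg_inf)            # nan: (-inf) - (-inf) is undefined in IEEE 754
print(math.exp(neg_inf - neg_inf))  # nan propagates through exp and poisons the row
```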

File tree

2 files changed: +64, −22 lines

extension/llm/custom_ops/op_sdpa_impl.h

Lines changed: 30 additions & 21 deletions
@@ -968,27 +968,36 @@ void cpu_flash_attention(
        tmp_max);
   }
   tmp_max = qk_max_data[row] > tmp_max ? qk_max_data[row] : tmp_max;
-  // qk <- exp(qk - max) and sum per row
-  tmp_sum = tmp_max;
-  _exp_reduce_sum_fusion_kernel(
-      qk_data + row * kvBlockSize,
-      kvBlockSize,
-      conditional_data_ptr(qk_data, qk_reduced_data) +
-          row * kvBlockSize,
-      tmp_sum);
-  // exp_tmp <- exp(max[row] - max)
-  exp_tmp = std::exp(qk_max_data[row] - tmp_max);
-  // sum[row] <- sum + exp_tmp * sum[row]
-  qk_sum_data[row] = tmp_sum + exp_tmp * qk_sum_data[row];
-  // max[row] <- max
-  qk_max_data[row] = tmp_max;
-  // dst <- dst * exp_tmp
-  if (n > 0) {
-    vec::map<accum_t>(
-        [exp_tmp](Vec x) { return x * Vec(exp_tmp); },
-        dst_data + row * headSize,
-        dst_data + row * headSize,
-        headSize);
+  if (tmp_max == -std::numeric_limits<accum_t>::infinity()) {
+    // to avoid `nan = exp2f(-inf - (-inf))`
+    fill_stub(
+        conditional_data_ptr(qk_data, qk_reduced_data) +
+            row * kvBlockSize,
+        static_cast<scalar_t>(0),
+        kvBlockSize);
+  } else {
+    // qk <- exp(qk - max) and sum per row
+    tmp_sum = tmp_max;
+    _exp_reduce_sum_fusion_kernel(
+        qk_data + row * kvBlockSize,
+        kvBlockSize,
+        conditional_data_ptr(qk_data, qk_reduced_data) +
+            row * kvBlockSize,
+        tmp_sum);
+    // exp_tmp <- exp(max[row] - max)
+    exp_tmp = std::exp(qk_max_data[row] - tmp_max);
+    // sum[row] <- sum + exp_tmp * sum[row]
+    qk_sum_data[row] = tmp_sum + exp_tmp * qk_sum_data[row];
+    // max[row] <- max
+    qk_max_data[row] = tmp_max;
+    // dst <- dst * exp_tmp
+    if (n > 0) {
+      vec::map<accum_t>(
+          [exp_tmp](Vec x) { return x * Vec(exp_tmp); },
+          dst_data + row * headSize,
+          dst_data + row * headSize,
+          headSize);
+    }
   }
 }
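
For context, here is a rough NumPy sketch of the per-row, per-KV-block online-softmax update this hunk modifies, with the same guard applied; the function and variable names are illustrative stand-ins, not the kernel's actual API:

```python
import numpy as np

def online_softmax_row_update(qk_row, row_max, row_sum, dst_row):
    # qk_row: masked scores of one query row against the current KV block
    # row_max/row_sum: running max and softmax denominator for the row
    # dst_row: running (unnormalized) output accumulator for the row
    tmp_max = max(row_max, qk_row.max())
    if tmp_max == -np.inf:
        # Every block seen so far is fully masked: mirror the fix and emit
        # zeros for this block instead of computing exp(-inf - (-inf)) == nan.
        qk_exp = np.zeros_like(qk_row)
    else:
        qk_exp = np.exp(qk_row - tmp_max)            # qk <- exp(qk - max)
        exp_tmp = np.exp(row_max - tmp_max)          # rescale factor for old state
        row_sum = qk_exp.sum() + exp_tmp * row_sum   # sum <- sum + exp_tmp * sum
        row_max = tmp_max                            # max <- max
        dst_row = dst_row * exp_tmp                  # dst <- dst * exp_tmp
    return qk_exp, row_max, row_sum, dst_row
```

Note that the guard only fires while the running max is still -inf. Once any block has contributed a finite score, a fully masked later block yields exp(-inf - finite) = 0 and needs no special case.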

extension/llm/custom_ops/test_sdpa_with_kv_cache.py

Lines changed: 34 additions & 1 deletion
@@ -53,12 +53,13 @@ def setUp(self):
         self.mask = torch.triu(self.mask, diagonal=1)
         self.use_mask_with_custom_op = False
         self.is_causal = False
+        self.start_pos = 0
 
     def test_sdpa_with_cache_no_mqa_1(self):
         q = torch.rand((1, 1, 8, 4))
         k = torch.rand((1, 1, 8, 4))
         v = torch.rand((1, 1, 8, 4))
-        start_pos = 0
+        start_pos = self.start_pos
         seq_len = q.size(1)
         attn_mask = self.mask[start_pos : start_pos + seq_len, :]
         attn_mask = attn_mask[:, : start_pos + seq_len]
@@ -238,6 +239,38 @@ def setUp(self):
         self.use_mask_with_custom_op = True
 
 
+class SDPAWithAttentionMaskLongSequenceTest(SDPATest):
+
+    def setUp(self):
+        SDPATest.setUp(self)
+        max_context_len = 700
+        context_window_len = 60
+        self.k_cache = torch.zeros((1, 700, 8, 4))
+        self.v_cache = torch.zeros((1, 700, 8, 4))
+        causal_mask = torch.tril(
+            torch.ones(
+                max_context_len,
+                max_context_len,
+                dtype=torch.bool,
+                device="cpu",
+            )
+        )
+        causal_mask2 = torch.tril(
+            torch.ones(
+                max_context_len,
+                max_context_len,
+                dtype=torch.bool,
+                device="cpu",
+            ),
+            diagonal=-context_window_len,
+        )
+        mask = torch.logical_xor(causal_mask, causal_mask2)
+        self.mask = torch.where(mask == True, 0.0, float("-inf"))  # noqa: E712
+
+        self.use_mask_with_custom_op = True
+        self.start_pos = 575
+
+
 class SDPAWithCausalTest(SDPATest):
 
     def setUp(self):
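
The new SDPAWithAttentionMaskLongSequenceTest exercises exactly the guarded case: with a 60-token sliding window over a 700-token context and start_pos = 575, the earliest KV blocks a query row visits lie entirely outside the window, so the running row max stays at -inf until the window is reached. A scaled-down sketch of the band mask the setUp above builds (assumed illustrative parameters: window of 3 over 8 positions instead of 60 over 700):

```python
import torch

max_context_len, window = 8, 3
causal = torch.tril(
    torch.ones(max_context_len, max_context_len, dtype=torch.bool))
shifted = torch.tril(
    torch.ones(max_context_len, max_context_len, dtype=torch.bool),
    diagonal=-window)
band = torch.logical_xor(causal, shifted)
# Row i attends only to keys j with i - window < j <= i; for large i the
# leading key positions are fully masked, the case that used to produce nan.
print(band.int())
```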
