Commit dcb9915

mha fwd/bwd kernel integration
1 parent 8384acf commit dcb9915

15 files changed: +4993 -2 lines changed

src/ATen/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
@@ -1,9 +1,9 @@
 # ATen XPU sources

 file(GLOB xpu_cpp "xpu/*.cpp")
-file(GLOB xpu_native_cpp "native/xpu/*.cpp" "native/sparse/*.cpp" "native/sparse/xpu/*.cpp" "native/nested/*.cpp" "native/nested/xpu/*.cpp" "native/transformers/*.cpp" "native/quantized/*.cpp" ${TORCH_ROOT}/aten/src/ATen/native/transformers/xpu/flash_attn/*.cpp)
+file(GLOB xpu_native_cpp "native/xpu/*.cpp" "native/sparse/*.cpp" "native/sparse/xpu/*.cpp" "native/nested/*.cpp" "native/nested/xpu/*.cpp" "native/transformers/*.cpp" "native/quantized/*.cpp" "native/transformers/xpu/flash_attn/*.cpp")
 file(GLOB xpu_sycl "native/xpu/sycl/*.cpp" "native/sparse/xpu/sycl/*.cpp" "native/nested/xpu/sycl/*.cpp" "native/transformers/sycl/*.cpp" "native/quantized/sycl/*.cpp")
-file(GLOB xpu_sycltla "${TORCH_ROOT}/aten/src/ATen/native/transformers/xpu/flash_attn/sycltla/*.cpp")
+file(GLOB xpu_sycltla "native/transformers/xpu/flash_attn/sycltla/*.cpp")

 list(APPEND ATen_XPU_CPP_SRCS ${xpu_cpp})
 if(USE_ONEMKL_XPU)

Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
#include <ATen/native/transformers/xpu/flash_attn/flash_api.h>
#include <ATen/native/transformers/xpu/flash_attn/sycltla/flash_api.h>

namespace sycltla {

bool is_flash_attention_available() {
#ifndef USE_SYCLTLA
  return false;
#else
  return true;
#endif
}

std::tuple<
    at::Tensor,
    at::Tensor,
    at::Tensor,
    at::Tensor,
    c10::SymInt,
    c10::SymInt,
    at::Tensor,
    at::Tensor>
flash_attention_forward(
    const at::Tensor& query,
    const at::Tensor& key,
    const at::Tensor& value,
    const double dropout,
    const bool is_causal,
    const float scale) {
#ifndef USE_SYCLTLA
  TORCH_CHECK(
      false,
      "flash_attention_forward: Torch XPU was not compiled with SYCLTLA support.");
  return std::make_tuple(
      at::Tensor(),
      at::Tensor(),
      at::Tensor(),
      at::Tensor(),
      c10::SymInt(0),
      c10::SymInt(0),
      at::Tensor(),
      at::Tensor());
#else
  auto
      [attention,
       logsumexp,
       cumulative_sequence_length_q,
       cumulative_sequence_length_k,
       max_seqlen_batch_q,
       max_seqlen_batch_k,
       philox_seed,
       philox_offset] =
          flash_attention_forward_sycltla(
              query, key, value, dropout, is_causal, scale);
  return std::make_tuple(
      std::move(attention),
      std::move(logsumexp),
      std::move(cumulative_sequence_length_q),
      std::move(cumulative_sequence_length_k),
      std::move(max_seqlen_batch_q),
      std::move(max_seqlen_batch_k),
      std::move(philox_seed),
      std::move(philox_offset));
#endif
}

std::tuple<at::Tensor, at::Tensor, at::Tensor> flash_attention_backward(
    const at::Tensor& grad_out,
    const at::Tensor& query,
    const at::Tensor& key,
    const at::Tensor& value,
    const at::Tensor& out,
    const at::Tensor& logsumexp,
    const at::Tensor& cumulative_sequence_length_q,
    const at::Tensor& cumulative_sequence_length_k,
    const int64_t max_seqlen_batch_q,
    const int64_t max_seqlen_batch_k,
    const double dropout,
    const bool is_causal,
    const at::Tensor& philox_seed,
    const at::Tensor& philox_offset,
    const float scale) {
#ifndef USE_SYCLTLA
  TORCH_CHECK(
      false,
      "flash_attention_backward: Torch XPU was not compiled with SYCLTLA support.");
  return std::make_tuple(at::Tensor(), at::Tensor(), at::Tensor());
#else
  auto [grad_query, grad_key, grad_value] = flash_attention_backward_sycltla(
      grad_out,
      query,
      key,
      value,
      out,
      logsumexp,
      cumulative_sequence_length_q,
      cumulative_sequence_length_k,
      max_seqlen_batch_q,
      max_seqlen_batch_k,
      dropout,
      is_causal,
      philox_seed,
      philox_offset,
      scale);
  return std::make_tuple(
      std::move(grad_query), std::move(grad_key), std::move(grad_value));
#endif
}

} // namespace sycltla
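
For orientation only (not part of this commit): a minimal sketch of how a caller might exercise the dispatcher above. The driver name run_forward_sketch, the [batch, num_heads, seq_len, head_dim] layout, the fp16 XPU tensors, and the scale value are illustrative assumptions; only is_flash_attention_available and flash_attention_forward come from this change.

// Hypothetical caller sketch; assumes an XPU build of ATen with this commit applied.
#include <ATen/ATen.h>
#include <ATen/native/transformers/xpu/flash_attn/flash_api.h>

at::Tensor run_forward_sketch() {
  if (!sycltla::is_flash_attention_available()) {
    // Built without USE_SYCLTLA: the dispatcher would hit TORCH_CHECK(false).
    return at::Tensor();
  }
  // Assumed [batch, num_heads, seq_len, head_dim] layout, fp16 on the XPU device.
  auto opts = at::TensorOptions().device(at::kXPU).dtype(at::kHalf);
  auto q = at::randn({2, 8, 128, 64}, opts);
  auto k = at::randn({2, 8, 128, 64}, opts);
  auto v = at::randn({2, 8, 128, 64}, opts);

  auto [attention, logsumexp, cu_seqlens_q, cu_seqlens_k,
        max_seqlen_q, max_seqlen_k, philox_seed, philox_offset] =
      sycltla::flash_attention_forward(
          q, k, v, /*dropout=*/0.0, /*is_causal=*/true, /*scale=*/0.125f);

  // attention is the output; the remaining values are exactly what
  // flash_attention_backward expects to be fed back in.
  return attention;
}
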
Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
#pragma once

#include <ATen/ATen.h>

namespace sycltla {

bool is_flash_attention_available();

std::tuple<
    at::Tensor,
    at::Tensor,
    at::Tensor,
    at::Tensor,
    c10::SymInt,
    c10::SymInt,
    at::Tensor,
    at::Tensor>
flash_attention_forward(
    const at::Tensor& query,
    const at::Tensor& key,
    const at::Tensor& value,
    const double dropout,
    const bool is_causal,
    const float scale);

std::tuple<at::Tensor, at::Tensor, at::Tensor> flash_attention_backward(
    const at::Tensor& grad_out,
    const at::Tensor& query,
    const at::Tensor& key,
    const at::Tensor& value,
    const at::Tensor& out,
    const at::Tensor& logsumexp,
    const at::Tensor& cumulative_sequence_length_q,
    const at::Tensor& cumulative_sequence_length_k,
    const int64_t max_seqlen_batch_q,
    const int64_t max_seqlen_batch_k,
    const double dropout,
    const bool is_causal,
    const at::Tensor& philox_seed,
    const at::Tensor& philox_offset,
    const float scale);

} // namespace sycltla
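
A companion sketch for the backward path, again hypothetical: it threads the forward outputs (out, logsumexp, cumulative sequence lengths, philox state) into flash_attention_backward together with an upstream gradient. The helper name run_backward_sketch and the use of at::ones_like as the incoming gradient are assumptions for illustration.

// Hypothetical caller sketch; pairs with the forward sketch above.
#include <ATen/ATen.h>
#include <ATen/native/transformers/xpu/flash_attn/flash_api.h>
#include <tuple>

std::tuple<at::Tensor, at::Tensor, at::Tensor> run_backward_sketch(
    const at::Tensor& q,
    const at::Tensor& k,
    const at::Tensor& v,
    double dropout,
    bool is_causal,
    float scale) {
  auto [out, logsumexp, cu_seqlens_q, cu_seqlens_k,
        max_seqlen_q, max_seqlen_k, philox_seed, philox_offset] =
      sycltla::flash_attention_forward(q, k, v, dropout, is_causal, scale);

  // Upstream gradient with the same shape as the attention output.
  auto grad_out = at::ones_like(out);

  // Forward reports the max sequence lengths as c10::SymInt; backward takes
  // plain int64_t, so concretize them here.
  return sycltla::flash_attention_backward(
      grad_out, q, k, v, out, logsumexp,
      cu_seqlens_q, cu_seqlens_k,
      max_seqlen_q.guard_int(__FILE__, __LINE__),
      max_seqlen_k.guard_int(__FILE__, __LINE__),
      dropout, is_causal, philox_seed, philox_offset, scale);
}
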
