2 changes: 2 additions & 0 deletions cmake/SYCLTLA.cmake
@@ -3,6 +3,8 @@ macro(replace_cmake_build_flags)
set(CMAKE_CXX_FLAGS_BK "${CMAKE_CXX_FLAGS}")
string(REPLACE "-Werror=format" "" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")
string(REPLACE "-Werror=format" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
string(REPLACE "-Werror=unused-variable" "" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")
string(REPLACE "-Werror=unused-variable" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
endmacro()

macro(restore_cmake_build_flags)
6 changes: 3 additions & 3 deletions src/ATen/CMakeLists.txt
@@ -1,9 +1,9 @@
# ATen XPU sources

file(GLOB xpu_cpp "xpu/*.cpp")
file(GLOB xpu_native_cpp "native/xpu/*.cpp" "native/sparse/*.cpp" "native/sparse/xpu/*.cpp" "native/nested/*.cpp" "native/nested/xpu/*.cpp" "native/transformers/*.cpp" "native/quantized/*.cpp" ${TORCH_ROOT}/aten/src/ATen/native/transformers/xpu/flash_attn/*.cpp)
file(GLOB xpu_native_cpp "native/xpu/*.cpp" "native/sparse/*.cpp" "native/sparse/xpu/*.cpp" "native/nested/*.cpp" "native/nested/xpu/*.cpp" "native/transformers/*.cpp" "native/quantized/*.cpp" "native/transformers/xpu/flash_attn/*.cpp")
Contributor: Nit: I think we should also install the header files under flash_attn into PyTorch, as is done at line 42.

Contributor Author: May I know the purpose of installing the header file?

Contributor: It gives users a chance to use them in a C++ extension.

Contributor Author: I see.

Contributor Author: Done.

Contributor: @guangyey, I think PyTorch does not expose flash_attn because it is the underlying logic of SDPA, which is exposed as a backend. Meanwhile, I don't believe users invoke PyTorch's flash_attn directly, because Dao-AILab/flash-attention is a better choice.

Contributor: Meanwhile, the namespace of these functions is sycltla. It would be odd to let users invoke sycl-tla-specific functions.

file(GLOB xpu_sycl "native/xpu/sycl/*.cpp" "native/sparse/xpu/sycl/*.cpp" "native/nested/xpu/sycl/*.cpp" "native/transformers/sycl/*.cpp" "native/quantized/sycl/*.cpp")
file(GLOB xpu_sycltla "${TORCH_ROOT}/aten/src/ATen/native/transformers/xpu/flash_attn/sycltla/*.cpp")
file(GLOB xpu_sycltla "native/transformers/xpu/flash_attn/sycltla/*.cpp")

list(APPEND ATen_XPU_CPP_SRCS ${xpu_cpp})
if(USE_ONEMKL_XPU)
@@ -40,7 +40,7 @@ install_xpu_headers("native/quantized/xpu/sycl")
install_xpu_headers("native/sparse/xpu")
install_xpu_headers("native/sparse/xpu/sycl")
install_xpu_headers("native/transformers/xpu")
install_xpu_headers("native/transformers/xpu/sycl")
install_xpu_headers("native/transformers/xpu/flash_attn")

if(xpu_ops_generated_headers)
install(FILES ${xpu_ops_generated_headers} DESTINATION ${AT_INSTALL_INCLUDE_DIR}/ATen/ops)
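For context on the review thread above: once the flash_attn headers are installed via install_xpu_headers("native/transformers/xpu/flash_attn"), a downstream C++ extension can include them directly. The sketch below is illustrative only; the include path matches this PR, but the extension boilerplate and the wrapper name are assumptions, not part of the change.

// Hypothetical PyTorch C++ extension source (illustrative sketch).
// It includes the header installed by this PR and re-exports the
// availability check so Python code can probe for SYCLTLA support.
#include <torch/extension.h>
#include <ATen/native/transformers/xpu/flash_attn/flash_api.h>

static bool xpu_flash_attention_available() {
  return sycltla::is_flash_attention_available();
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("xpu_flash_attention_available", &xpu_flash_attention_available,
        "True if torch-xpu-ops was built with SYCLTLA flash attention");
}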
109 changes: 109 additions & 0 deletions src/ATen/native/transformers/xpu/flash_attn/flash_api.cpp
@@ -0,0 +1,109 @@
#include <ATen/native/transformers/xpu/flash_attn/flash_api.h>
#include <ATen/native/transformers/xpu/flash_attn/sycltla/flash_api.h>

namespace sycltla {

bool is_flash_attention_available() {
#ifndef USE_SYCLTLA
return false;
#else
return true;
#endif
}

std::tuple<
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
c10::SymInt,
c10::SymInt,
at::Tensor,
at::Tensor>
flash_attention_forward(
const at::Tensor& query,
const at::Tensor& key,
const at::Tensor& value,
const double dropout,
const bool is_causal,
const float scale) {
#ifndef USE_SYCLTLA
TORCH_CHECK(
false,
"flash_attention_forward: Torch XPU was not compiled with SYCLTLA support.");
return std::make_tuple(
at::Tensor(),
at::Tensor(),
at::Tensor(),
at::Tensor(),
c10::SymInt(0),
c10::SymInt(0),
at::Tensor(),
at::Tensor());
#else
auto
[attention,
logsumexp,
cumulative_sequence_length_q,
cumulative_sequence_length_k,
max_seqlen_batch_q,
max_seqlen_batch_k,
philox_seed,
philox_offset] =
flash_attention_forward_sycltla(
query, key, value, dropout, is_causal, scale);
return std::make_tuple(
std::move(attention),
std::move(logsumexp),
std::move(cumulative_sequence_length_q),
std::move(cumulative_sequence_length_k),
std::move(max_seqlen_batch_q),
std::move(max_seqlen_batch_k),
std::move(philox_seed),
std::move(philox_offset));
#endif
}

std::tuple<at::Tensor, at::Tensor, at::Tensor> flash_attention_backward(
const at::Tensor& grad_out,
const at::Tensor& query,
const at::Tensor& key,
const at::Tensor& value,
const at::Tensor& out,
const at::Tensor& logsumexp,
const at::Tensor& cumulative_sequence_length_q,
const at::Tensor& cumulative_sequence_length_k,
const int64_t max_seqlen_batch_q,
const int64_t max_seqlen_batch_k,
const double dropout,
const bool is_causal,
const at::Tensor& philox_seed,
const at::Tensor& philox_offset,
const float scale) {
#ifndef USE_SYCLTLA
TORCH_CHECK(
false,
"flash_attention_backward: Torch XPU was not compiled with SYCLTLA support.");
return std::make_tuple(at::Tensor(), at::Tensor(), at::Tensor());
#else
auto [grad_query, grad_key, grad_value] = flash_attention_backward_sycltla(
grad_out,
query,
key,
value,
out,
logsumexp,
cumulative_sequence_length_q,
cumulative_sequence_length_k,
max_seqlen_batch_q,
max_seqlen_batch_k,
dropout,
is_causal,
philox_seed,
philox_offset,
scale);
return std::make_tuple(
std::move(grad_query), std::move(grad_key), std::move(grad_value));
#endif
}
} // namespace sycltla
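
A minimal usage sketch of the two wrappers above, assuming a build with USE_SYCLTLA and tensors already on an XPU device. The (batch, num_heads, seq_len, head_dim) layout, the half dtype, and the scale value are illustrative assumptions, not requirements stated in this diff.

#include <ATen/ATen.h>
#include <ATen/native/transformers/xpu/flash_attn/flash_api.h>

// Sketch: run the forward wrapper, then feed its outputs into the backward
// wrapper. Shapes, dtype, and scale (1/sqrt(head_dim)) are assumptions.
void flash_attention_roundtrip() {
  if (!sycltla::is_flash_attention_available()) {
    return; // built without USE_SYCLTLA; the wrappers would TORCH_CHECK(false)
  }
  auto opts = at::TensorOptions().dtype(at::kHalf).device(at::kXPU);
  auto query = at::randn({2, 8, 128, 64}, opts);
  auto key = at::randn({2, 8, 128, 64}, opts);
  auto value = at::randn({2, 8, 128, 64}, opts);

  auto
      [out,
       logsumexp,
       cu_seqlens_q,
       cu_seqlens_k,
       max_seqlen_q,
       max_seqlen_k,
       philox_seed,
       philox_offset] =
          sycltla::flash_attention_forward(
              query, key, value, /*dropout=*/0.0, /*is_causal=*/true, /*scale=*/0.125f);

  auto grad_out = at::ones_like(out);
  auto [grad_query, grad_key, grad_value] = sycltla::flash_attention_backward(
      grad_out,
      query,
      key,
      value,
      out,
      logsumexp,
      cu_seqlens_q,
      cu_seqlens_k,
      max_seqlen_q.expect_int(), // SymInt -> int64_t for the backward signature
      max_seqlen_k.expect_int(),
      /*dropout=*/0.0,
      /*is_causal=*/true,
      philox_seed,
      philox_offset,
      /*scale=*/0.125f);
  (void)grad_query;
  (void)grad_key;
  (void)grad_value;
}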
43 changes: 43 additions & 0 deletions src/ATen/native/transformers/xpu/flash_attn/flash_api.h
@@ -0,0 +1,43 @@
#pragma once

#include <ATen/ATen.h>

namespace sycltla {

bool is_flash_attention_available();

std::tuple<
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
c10::SymInt,
c10::SymInt,
at::Tensor,
at::Tensor>
flash_attention_forward(
const at::Tensor& query,
const at::Tensor& key,
const at::Tensor& value,
const double dropout,
const bool is_causal,
const float scale);

std::tuple<at::Tensor, at::Tensor, at::Tensor> flash_attention_backward(
const at::Tensor& grad_out,
const at::Tensor& query,
const at::Tensor& key,
const at::Tensor& value,
const at::Tensor& out,
const at::Tensor& logsumexp,
const at::Tensor& cumulative_sequence_length_q,
const at::Tensor& cumulative_sequence_length_k,
const int64_t max_seqlen_batch_q,
const int64_t max_seqlen_batch_k,
const double dropout,
const bool is_causal,
const at::Tensor& philox_seed,
const at::Tensor& philox_offset,
const float scale);

} // namespace sycltla