-
Notifications
You must be signed in to change notification settings - Fork 62
[SYCL-TLA] Integrate FlashAttention fwd/bwd kernels #2341
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
LuFinch
wants to merge
7
commits into
main
Choose a base branch
from
lfq/flash_attention
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+4,808
−3
Open
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
a7deb96
mha fwd/bwd kernel integration
LuFinch 76aec4e
install header
LuFinch d33b6a5
fix build warning
LuFinch 54a5ca5
rebase forwardkernel
LuFinch 991ee97
fix CI build error
LuFinch 89c6a49
rebase to latest
LuFinch b61325e
pad input tensors if headdim is not multiple of 64
LuFinch File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
109 changes: 109 additions & 0 deletions
109
src/ATen/native/transformers/xpu/flash_attn/flash_api.cpp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,109 @@ | ||
| #include <ATen/native/transformers/xpu/flash_attn/flash_api.h> | ||
| #include <ATen/native/transformers/xpu/flash_attn/sycltla/flash_api.h> | ||
|
|
||
| namespace sycltla { | ||
|
|
||
| bool is_flash_attention_available() { | ||
| #ifndef USE_SYCLTLA | ||
| return false; | ||
| #else | ||
| return true; | ||
| #endif | ||
| } | ||
|
|
||
| std::tuple< | ||
| at::Tensor, | ||
| at::Tensor, | ||
| at::Tensor, | ||
| at::Tensor, | ||
| c10::SymInt, | ||
| c10::SymInt, | ||
| at::Tensor, | ||
| at::Tensor> | ||
| flash_attention_forward( | ||
| const at::Tensor& query, | ||
| const at::Tensor& key, | ||
| const at::Tensor& value, | ||
| const double dropout, | ||
| const bool is_causal, | ||
| const float scale) { | ||
| #ifndef USE_SYCLTLA | ||
| TORCH_CHECK( | ||
| false, | ||
| "flash_attention_forward: Torch XPU was not compiled with SYCLTLA support."); | ||
| return std::make_tuple( | ||
| at::Tensor(), | ||
| at::Tensor(), | ||
| at::Tensor(), | ||
| at::Tensor(), | ||
| c10::SymInt(0), | ||
| c10::SymInt(0), | ||
| at::Tensor(), | ||
| at::Tensor()); | ||
| #else | ||
| auto | ||
| [attention, | ||
| logsumexp, | ||
| cumulative_sequence_length_q, | ||
| cumulative_sequence_length_k, | ||
| max_seqlen_batch_q, | ||
| max_seqlen_batch_k, | ||
| philox_seed, | ||
| philox_offset] = | ||
| flash_attention_forward_sycltla( | ||
| query, key, value, dropout, is_causal, scale); | ||
| return std::make_tuple( | ||
| std::move(attention), | ||
| std::move(logsumexp), | ||
| std::move(cumulative_sequence_length_q), | ||
| std::move(cumulative_sequence_length_k), | ||
| std::move(max_seqlen_batch_q), | ||
| std::move(max_seqlen_batch_k), | ||
| std::move(philox_seed), | ||
| std::move(philox_offset)); | ||
| #endif | ||
| } | ||
|
|
||
| std::tuple<at::Tensor, at::Tensor, at::Tensor> flash_attention_backward( | ||
| const at::Tensor& grad_out, | ||
| const at::Tensor& query, | ||
| const at::Tensor& key, | ||
| const at::Tensor& value, | ||
| const at::Tensor& out, | ||
| const at::Tensor& logsumexp, | ||
| const at::Tensor& cumulative_sequence_length_q, | ||
| const at::Tensor& cumulative_sequence_length_k, | ||
| const int64_t max_seqlen_batch_q, | ||
| const int64_t max_seqlen_batch_k, | ||
| const double dropout, | ||
| const bool is_causal, | ||
| const at::Tensor& philox_seed, | ||
| const at::Tensor& philox_offset, | ||
| const float scale) { | ||
| #ifndef USE_SYCLTLA | ||
| TORCH_CHECK( | ||
| false, | ||
| "flash_attention_backward: Torch XPU was not compiled with SYCLTLA support."); | ||
| return std::make_tuple(at::Tensor(), at::Tensor(), at::Tensor()); | ||
| #else | ||
| auto [grad_query, grad_key, grad_value] = flash_attention_backward_sycltla( | ||
| grad_out, | ||
| query, | ||
| key, | ||
| value, | ||
| out, | ||
| logsumexp, | ||
| cumulative_sequence_length_q, | ||
| cumulative_sequence_length_k, | ||
| max_seqlen_batch_q, | ||
| max_seqlen_batch_k, | ||
| dropout, | ||
| is_causal, | ||
| philox_seed, | ||
| philox_offset, | ||
| scale); | ||
| return std::make_tuple( | ||
| std::move(grad_query), std::move(grad_key), std::move(grad_value)); | ||
| #endif | ||
| } | ||
| } // namespace sycltla |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,43 @@ | ||
| #pragma once | ||
|
|
||
| #include <ATen/ATen.h> | ||
|
|
||
| namespace sycltla { | ||
|
|
||
| bool is_flash_attention_available(); | ||
|
|
||
| std::tuple< | ||
| at::Tensor, | ||
| at::Tensor, | ||
| at::Tensor, | ||
| at::Tensor, | ||
| c10::SymInt, | ||
| c10::SymInt, | ||
| at::Tensor, | ||
| at::Tensor> | ||
| flash_attention_forward( | ||
| const at::Tensor& query, | ||
| const at::Tensor& key, | ||
| const at::Tensor& value, | ||
| const double dropout, | ||
| const bool is_causal, | ||
| const float scale); | ||
|
|
||
| std::tuple<at::Tensor, at::Tensor, at::Tensor> flash_attention_backward( | ||
| const at::Tensor& grad_out, | ||
| const at::Tensor& query, | ||
| const at::Tensor& key, | ||
| const at::Tensor& value, | ||
| const at::Tensor& out, | ||
| const at::Tensor& logsumexp, | ||
| const at::Tensor& cumulative_sequence_length_q, | ||
| const at::Tensor& cumulative_sequence_length_k, | ||
| const int64_t max_seqlen_batch_q, | ||
| const int64_t max_seqlen_batch_k, | ||
| const double dropout, | ||
| const bool is_causal, | ||
| const at::Tensor& philox_seed, | ||
| const at::Tensor& philox_offset, | ||
| const float scale); | ||
|
|
||
| } // namespace sycltla |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: I think we should install the header file under flash_attn into PyTorch such as line 42
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
May I know what is the purpose of installing header file?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Give a chance to use them in cpp extension.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@guangyey , I think PyTorch does not expose
flash_attnbecause it is the underlying logic ofsdpa, which is exposed as a backend. Meanwhile, I don't believe users invoke theflash_attenof PyTorch becausedao/flash_attenis a better choice.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Meanwhile, the namespace of these functions is
sycltla. It is weird to let users invoke sycl-tla-specific functions.