forked from Dao-AILab/flash-attention
-
Notifications
You must be signed in to change notification settings - Fork 98
Implements the attention kernel with vertical and slash sparse pattern described in Appendix C.4.2 of https://arxiv.org/abs/2407.02490 (as sparse_attn_func) #33
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
fcc8a21
Add sparse attention with virtical and slash
minminsun 7b2c7a2
update
minminsun e1a41d0
move sparse_attn to new files
minminsun b206be4
Refine
minminsun 371a03a
remove registering as custom op
minminsun 866ca51
address review comments
minminsun cb64a43
remove window_size and useless code
minminsun 050052d
only keep hdim128
minminsun 7bbe317
add seqlen_q=1 in ut and remove useless code
minminsun b94e887
support batch_size > 1
minminsun d7b3975
add interface sparse_attn_varlen_func
minminsun 0dbe623
remove useless code
minminsun File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
10 changes: 10 additions & 0 deletions
10
csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_causal_sm80.cu
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
// Copyright (c) 2023, Tri Dao. | ||
// Splitting the different head dimensions to different files to speed up compilation. | ||
// This file is auto-generated. See "generate_kernels.py" | ||
|
||
#include "flash_fwd_sparse_launch_template.h" | ||
|
||
template<> | ||
void run_mha_fwd_sparse_<cutlass::bfloat16_t, 128, true>(Flash_fwd_params ¶ms, cudaStream_t stream) { | ||
run_mha_fwd_sparse_hdim128<cutlass::bfloat16_t, true>(params, stream); | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
// Copyright (c) 2023, Tri Dao. | ||
// Splitting the different head dimensions to different files to speed up compilation. | ||
// This file is auto-generated. See "generate_kernels.py" | ||
|
||
#include "flash_fwd_sparse_launch_template.h" | ||
|
||
template<> | ||
void run_mha_fwd_sparse_<cutlass::bfloat16_t, 128, false>(Flash_fwd_params ¶ms, cudaStream_t stream) { | ||
run_mha_fwd_sparse_hdim128<cutlass::bfloat16_t, false>(params, stream); | ||
} |
10 changes: 10 additions & 0 deletions
10
csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_causal_sm80.cu
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
// Copyright (c) 2023, Tri Dao. | ||
// Splitting the different head dimensions to different files to speed up compilation. | ||
// This file is auto-generated. See "generate_kernels.py" | ||
|
||
#include "flash_fwd_sparse_launch_template.h" | ||
|
||
template<> | ||
void run_mha_fwd_sparse_<cutlass::half_t, 128, true>(Flash_fwd_params ¶ms, cudaStream_t stream) { | ||
run_mha_fwd_sparse_hdim128<cutlass::half_t, true>(params, stream); | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
// Copyright (c) 2023, Tri Dao. | ||
// Splitting the different head dimensions to different files to speed up compilation. | ||
// This file is auto-generated. See "generate_kernels.py" | ||
|
||
#include "flash_fwd_sparse_launch_template.h" | ||
|
||
template<> | ||
void run_mha_fwd_sparse_<cutlass::half_t, 128, false>(Flash_fwd_params ¶ms, cudaStream_t stream) { | ||
run_mha_fwd_sparse_hdim128<cutlass::half_t, false>(params, stream); | ||
} |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
/****************************************************************************** | ||
* Copyright (c) 2024, PAI, Alibaba Cloud. | ||
******************************************************************************/ | ||
|
||
#pragma once | ||
|
||
#include "flash_fwd_launch_template.h" | ||
#include "flash_fwd_sparse_kernel.h" | ||
|
||
DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_sparse_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) { | ||
#if defined(ARCH_SUPPORTS_FLASH) | ||
static_assert(!(Is_causal && Is_local)); // Enforce constraints | ||
flash::compute_sparse_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params); | ||
#else | ||
FLASH_UNSUPPORTED_ARCH | ||
#endif | ||
} | ||
|
||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal> | ||
void run_flash_sparse_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { | ||
constexpr size_t smem_size = Kernel_traits::kSmemSize; | ||
// printf("smem_size = %d\n", smem_size); | ||
|
||
// Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. | ||
// https://github.com/kokkos/kokkos-kernels/issues/349 | ||
// https://github.com/HazyResearch/flash-attention/issues/21 | ||
|
||
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; | ||
dim3 grid(num_m_block, params.b, params.h); | ||
const bool is_even_K = params.d == Kernel_traits::kHeadDim; | ||
const bool return_softmax = params.p_ptr != nullptr; | ||
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { | ||
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { | ||
ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { | ||
SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { | ||
constexpr bool IsEvenMNConst = false; | ||
constexpr bool Is_local = false; | ||
// Will only return softmax if dropout, to reduce compilation time. | ||
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. | ||
// If return_softmax, set IsEvenMNConst to false to reduce number of templates | ||
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates | ||
// If Is_local, set Is_causal to false | ||
auto kernel = &flash_fwd_sparse_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>; | ||
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>; | ||
// printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout)); | ||
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>; | ||
if (smem_size >= 48 * 1024) { | ||
C10_CUDA_CHECK(cudaFuncSetAttribute( | ||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); | ||
} | ||
// int ctas_per_sm; | ||
// cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( | ||
// &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); | ||
// printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); | ||
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params); | ||
C10_CUDA_KERNEL_LAUNCH_CHECK(); | ||
}); | ||
}); | ||
}); | ||
}); | ||
} | ||
|
||
template<typename T, bool Is_causal> | ||
void run_mha_fwd_sparse_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) { | ||
constexpr static int Headdim = 32; | ||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { | ||
run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); | ||
}); | ||
} | ||
|
||
template<typename T, bool Is_causal> | ||
void run_mha_fwd_sparse_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { | ||
constexpr static int Headdim = 64; | ||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { | ||
run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); | ||
}); | ||
} | ||
|
||
template<typename T, bool Is_causal> | ||
void run_mha_fwd_sparse_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { | ||
constexpr static int Headdim = 96; | ||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { | ||
run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); | ||
}); | ||
} | ||
|
||
template<typename T, bool Is_causal> | ||
void run_mha_fwd_sparse_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { | ||
constexpr static int Headdim = 128; | ||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { | ||
run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); | ||
}); | ||
} | ||
|
||
template<typename T, bool Is_causal> | ||
void run_mha_fwd_sparse_hdim160(Flash_fwd_params ¶ms, cudaStream_t stream) { | ||
constexpr static int Headdim = 160; | ||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { | ||
run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); | ||
}); | ||
} | ||
|
||
template<typename T, bool Is_causal> | ||
void run_mha_fwd_sparse_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) { | ||
constexpr static int Headdim = 192; | ||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { | ||
run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); | ||
}); | ||
} | ||
|
||
template<typename T, bool Is_causal> | ||
void run_mha_fwd_sparse_hdim224(Flash_fwd_params ¶ms, cudaStream_t stream) { | ||
constexpr static int Headdim = 224; | ||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { | ||
run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); | ||
}); | ||
} | ||
|
||
template<typename T, bool Is_causal> | ||
void run_mha_fwd_sparse_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { | ||
constexpr static int Headdim = 256; | ||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { | ||
run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); | ||
}); | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.