Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 7 additions & 13 deletions csrc/flash_attn/flash_api_sparse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,17 +157,14 @@ mha_fwd_sparse(at::Tensor &q, // batch_size x seqlen_q x num_heads x hea
std::optional<at::Generator> gen_) {

auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
bool is_sm8x = cc_major == 8 && cc_minor >= 0;
bool is_sm90 = cc_major == 9 && cc_minor == 0;
TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
// We will support Turing in the near future
// TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
bool is_sm8x_min = cc_major >= 8;
TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer.");

auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
"FlashAttention only support fp16 and bf16 data type");
if (q_dtype == torch::kBFloat16) {
TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
TORCH_CHECK(is_sm8x_min, "bfloat16 is only supported on Ampere GPUs or newer");
}
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
Expand Down Expand Up @@ -342,17 +339,14 @@ mha_varlen_fwd_sparse(at::Tensor &q, // total_q x num_heads x head_size, total_
std::optional<at::Generator> gen_) {

auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
bool is_sm8x = cc_major == 8 && cc_minor >= 0;
bool is_sm90 = cc_major == 9 && cc_minor == 0;
TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
// We will support Turing in the near future
// TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
bool is_sm8x_min = cc_major >= 8;
TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer.");

auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
"FlashAttention only support fp16 and bf16 data type");
if (q_dtype == torch::kBFloat16) {
TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
TORCH_CHECK(is_sm8x_min, "bfloat16 is only supported on Ampere GPUs or newer");
}
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
Expand Down Expand Up @@ -528,4 +522,4 @@ mha_varlen_fwd_sparse(at::Tensor &q, // total_q x num_heads x head_size, total_
return {out, softmax_lse};
}

} // namespace FLASH_NAMESPACE
} // namespace FLASH_NAMESPACE