diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp
index eb03765700..2da99a70a6 100644
--- a/csrc/flash_attn/flash_api.cpp
+++ b/csrc/flash_attn/flash_api.cpp
@@ -154,6 +154,72 @@ void set_params_fprop(Flash_fwd_params &params,
     params.seqlenq_ngroups_swapped = seqlenq_ngroups_swapped;
 }
 
+void set_params_fprop_sparse(Flash_fwd_params &params,
+                             // sizes
+                             const size_t b,
+                             const size_t seqlen_q,
+                             const size_t seqlen_k,
+                             const size_t seqlen_q_rounded,
+                             const size_t seqlen_k_rounded,
+                             const size_t h,
+                             const size_t h_k,
+                             const size_t d,
+                             const size_t d_rounded,
+                             // device pointers
+                             const at::Tensor q,
+                             const at::Tensor k,
+                             const at::Tensor v,
+                             const at::Tensor block_count,
+                             const at::Tensor block_offset,
+                             const at::Tensor column_count,
+                             const at::Tensor column_index,
+                             at::Tensor out,
+                             void *cu_seqlens_q_d,
+                             void *cu_seqlens_k_d,
+                             void *seqused_k,
+                             void *p_d,
+                             void *softmax_lse_d,
+                             float p_dropout,
+                             float softmax_scale,
+                             const float softcap,
+                             bool seqlenq_ngroups_swapped=false,
+                             const bool unpadded_lse=false) {
+    set_params_fprop(params,
+                     b,
+                     seqlen_q, seqlen_k,
+                     seqlen_q_rounded, seqlen_k_rounded,
+                     h, h_k,
+                     d, d_rounded,
+                     q, k, v, out,
+                     cu_seqlens_q_d,
+                     cu_seqlens_k_d,
+                     seqused_k,
+                     p_d,
+                     softmax_lse_d,
+                     p_dropout,
+                     softmax_scale,
+                     -1,  // window_size_left
+                     -1,  // window_size_right
+                     softcap,
+                     seqlenq_ngroups_swapped,
+                     unpadded_lse
+                     );
+    params.block_count = block_count.const_data_ptr<int>();
+    params.block_offset = block_offset.const_data_ptr<int>();
+    params.column_count = column_count.const_data_ptr<int>();
+    params.column_index = column_index.const_data_ptr<int>();
+    TORCH_CHECK(block_count.size(2) == block_offset.size(2));
+    TORCH_CHECK(column_index.size(2) == block_offset.size(2));
+    TORCH_CHECK(column_count.size(2) == column_index.size(2));
+    params.NUM_ROWS = block_count.size(2);
+    // params.NUM_ROWS should be equal to cdiv(seqlen_q, BLOCK_M), and BLOCK_M has to be 64 for now.
+    constexpr int BLOCK_M = 64;
+    int expected_num_rows = (seqlen_q + BLOCK_M - 1) / BLOCK_M;
+    TORCH_CHECK(params.NUM_ROWS == expected_num_rows);
+    params.NNZ_S = block_offset.size(3);
+    params.NNZ_V = column_index.size(3);
+}
+
 void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
     FP16_SWITCH(!params.is_bf16, [&] {
         HEADDIM_SWITCH(params.d, [&] {
@@ -168,6 +234,17 @@ void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split
     });
 }
 
+void run_mha_fwd_sparse(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
+    TORCH_CHECK(params.num_splits <= 1 && !force_split_kernel, "run_mha_fwd_sparse does not support splitkv.");
+    TORCH_CHECK(params.d == 128, "run_mha_fwd_sparse only supports headdim=128 for now to keep binary small.");
+    FP16_SWITCH(!params.is_bf16, [&] {
+        constexpr static int kHeadDim = 128;
+        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
+            run_mha_fwd_sparse_<elem_type, kHeadDim, Is_causal>(params, stream);
+        });
+    });
+}
+
 // Find the number of splits that maximizes the occupancy. For example, if we have
 // batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
 // better than having 3 splits (efficiency = 0.67).
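For orientation, the metadata contract enforced above is: block_count and column_count have shape [batch, num_heads, NUM_ROWS], block_offset has shape [batch, num_heads, NUM_ROWS, NNZ_S], and column_index has shape [batch, num_heads, NUM_ROWS, NNZ_V], where NUM_ROWS = cdiv(seqlen_q, 64) (see the kernel comments and the unit test later in this diff). The following is a minimal, hedged Python sketch of building such tensors for one simple pattern, the first NNZ_S key blocks taken whole and every remaining key listed as an explicit column, mirroring the test; the helper name build_sparse_metadata is illustrative and not part of this PR.

import torch

def build_sparse_metadata(batch, heads, seqlen_q, seqlen_k,
                          block_size_m=64, block_size_n=64,
                          nnz_s=2, device="cuda"):
    # NUM_ROWS = cdiv(seqlen_q, BLOCK_M); BLOCK_M is fixed to 64 by the check above.
    num_rows = (seqlen_q + block_size_m - 1) // block_size_m
    # Illustrative pattern (assumes nnz_s * block_size_n <= seqlen_k, as the test also requires):
    # the first nnz_s key blocks are taken whole, the remaining keys become explicit columns.
    nnz_v = max(seqlen_k - nnz_s * block_size_n, 0)

    # [B, H, NUM_ROWS]: number of valid block / column entries per query row-block.
    block_count = torch.full((batch, heads, num_rows), nnz_s, dtype=torch.int32, device=device)
    column_count = torch.full((batch, heads, num_rows), nnz_v, dtype=torch.int32, device=device)

    # [B, H, NUM_ROWS, NNZ_S]: starting key index of each selected block_size_n-wide K/V block.
    block_offset = (torch.arange(nnz_s, dtype=torch.int32, device=device) * block_size_n) \
        .expand(batch, heads, num_rows, nnz_s).contiguous()

    # [B, H, NUM_ROWS, NNZ_V]: individual key positions gathered one token at a time.
    column_index = (nnz_s * block_size_n +
                    torch.arange(nnz_v, dtype=torch.int32, device=device)) \
        .expand(batch, heads, num_rows, nnz_v).contiguous()

    return block_count, block_offset, column_count, column_index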
However, we also don't want too many @@ -261,6 +338,375 @@ void set_params_alibi(Flash_fwd_params ¶ms, const c10::optional #endif } +std::vector +mha_fwd_sparse(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &block_count, + const at::Tensor &block_offset, + const at::Tensor &column_count, + const at::Tensor &column_index, + const c10::optional &out_, // batch_size x seqlen_q x num_heads x head_size + const c10::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const double p_dropout, + const double softmax_scale, + bool is_causal, + const double softcap, + const bool return_softmax, + c10::optional gen_) { + + auto dprops = at::cuda::getCurrentDeviceProperties(); + // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; + bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; + bool is_sm90 = dprops->major == 9 && dprops->minor == 0; + TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); + // We will support Turing in the near future + // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer."); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + if (q_dtype == torch::kBFloat16) { + TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer"); + } + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + + const auto sizes = q.sizes(); + + const int batch_size = sizes[0]; + int seqlen_q = sizes[1]; + int num_heads = sizes[2]; + const int head_size_og = sizes[3]; + const int seqlen_k = k.size(1); + const int num_heads_k = k.size(2); + TORCH_CHECK(batch_size > 0, "batch size must be postive"); + TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); } + + // causal=true is the same as causal=false in this case + if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } + + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og); + CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og); + CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og); + + at::Tensor q_padded, k_padded, v_padded; + if (head_size_og % 8 != 0) { + q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + } else { + q_padded = q; + k_padded = k; + v_padded = v; + } + + at::Tensor out; + if (out_.has_value()) { + out = out_.value(); + 
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); + CHECK_DEVICE(out); + TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size_og); + if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); } + } else { + out = torch::empty_like(q_padded); + } + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size = round_multiple(head_size_og, 8); + const int head_size_rounded = round_multiple(head_size, 32); + const int seqlen_q_rounded = round_multiple(seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + + auto opts = q.options(); + + auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); + at::Tensor p; + // Only return softmax if there's dropout to reduce compilation time + if (return_softmax) { + TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0"); + p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts); + } + + Flash_fwd_params params; + set_params_fprop_sparse(params, + batch_size, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q_padded, k_padded, v_padded, + block_count, block_offset, + column_count, column_index, + out, + /*cu_seqlens_q_d=*/nullptr, + /*cu_seqlens_k_d=*/nullptr, + /*seqused_k=*/nullptr, + return_softmax ? p.data_ptr() : nullptr, + softmax_lse.data_ptr(), + p_dropout, + softmax_scale, + softcap + ); + + // Keep references to these tensors to extend their lifetime + at::Tensor softmax_lse_accum, out_accum; + std::tie(softmax_lse_accum, out_accum) = set_params_splitkv( + params, batch_size, num_heads, head_size, seqlen_k, seqlen_q, + head_size_rounded, p_dropout, /*num_splits*/ 1, dprops, opts); + + // NOTE(woosuk): Commented out because they are not used in inference. + // // number of times random will be generated per thread, to offset philox counter in thc random + // // state + // // We use a custom RNG that increases the offset by batch_size * nheads * 32. + // int64_t counter_offset = params.b * params.h * 32; + // auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + // auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); + // // Forward kernel will populate memory with the seed and offset. + // params.rng_state = reinterpret_cast(rng_state.data_ptr()); + + // if (p_dropout > 0.0) { + // auto gen = at::get_generator_or_default( + // gen_, at::cuda::detail::getDefaultCUDAGenerator()); + // // See Note [Acquire lock when using random generators] + // std::lock_guard lock(gen->mutex_); + // params.philox_args = gen->philox_cuda_state(counter_offset); + // } + + set_params_alibi(params, alibi_slopes_, batch_size, num_heads); + + if (seqlen_k > 0) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + run_mha_fwd_sparse(params, stream); + } else { + // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. 
+ out.zero_(); + softmax_lse.fill_(std::numeric_limits::infinity()); + } + + at::Tensor out_padded = out; + if (head_size_og % 8 != 0) { + out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + if (out_.has_value()) { out_.value().copy_(out); } + } + + return {out, softmax_lse}; +} + +std::vector +mha_varlen_fwd_sparse(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i. + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i. + const at::Tensor &block_count, + const at::Tensor &block_offset, + const at::Tensor &column_count, + const at::Tensor &column_index, + const c10::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + const c10::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. + const c10::optional &alibi_slopes_, // num_heads or b x num_heads + int64_t max_seqlen_q, + const int64_t max_seqlen_k, + const double p_dropout, + const double softmax_scale, + const bool zero_tensors, + bool is_causal, + const double softcap, + const bool return_softmax, + c10::optional gen_) { + + auto dprops = at::cuda::getCurrentDeviceProperties(); + // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; + bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; + bool is_sm90 = dprops->major == 9 && dprops->minor == 0; + TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); + // We will support Turing in the near future + // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer."); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + if (q_dtype == torch::kBFloat16) { + TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer"); + } + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32"); + TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32"); + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(cu_seqlens_q); + CHECK_DEVICE(cu_seqlens_k); + + at::Tensor block_table; + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + CHECK_CONTIGUOUS(cu_seqlens_q); + CHECK_CONTIGUOUS(cu_seqlens_k); + + const auto sizes = q.sizes(); + + const int batch_size = cu_seqlens_q.numel() - 1; + int num_heads = sizes[1]; + const int head_size_og = sizes[2]; + const int num_heads_k = k.size(1); + + if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); } + + if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } // causal=true is the same as causal=false in this case + + void *cu_seqlens_q_d = cu_seqlens_q.data_ptr(); + + const int total_q = q.sizes()[0]; + + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size_og <= 256, 
"FlashAttention forward only supports head dimension at most 256"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + CHECK_SHAPE(q, total_q, num_heads, head_size_og); + const int total_k = k.size(0); + CHECK_SHAPE(k, total_k, num_heads_k, head_size_og); + CHECK_SHAPE(v, total_k, num_heads_k, head_size_og); + + + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + if (seqused_k.has_value()){ + auto seqused_k_ = seqused_k.value(); + TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32"); + TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device"); + TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous"); + CHECK_SHAPE(seqused_k_, batch_size); + } + + at::Tensor q_padded, k_padded, v_padded; + if (head_size_og % 8 != 0) { + q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + } else { + q_padded = q; + k_padded = k; + v_padded = v; + } + + at::Tensor out; + if (out_.has_value()) { + out = out_.value(); + TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); + CHECK_DEVICE(out); + TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og); + if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); } + } else { + out = torch::empty_like(q_padded); + } + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size = round_multiple(head_size_og, 8); + const int head_size_rounded = round_multiple(head_size, 32); + const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + + auto opts = q.options(); + auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat)); + at::Tensor p; + // Only return softmax if there's dropout to reduce compilation time + if (return_softmax) { + TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0"); + p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts); + } + + if (zero_tensors) { + out.zero_(); + softmax_lse.fill_(-std::numeric_limits::infinity()); + if (return_softmax) {p.zero_();} + } + + Flash_fwd_params params; + set_params_fprop_sparse(params, + batch_size, + max_seqlen_q, max_seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q_padded, k_padded, v_padded, + block_count, block_offset, + column_count, column_index, + out, + cu_seqlens_q_d, + cu_seqlens_k.data_ptr(), + seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr, + return_softmax ? p.data_ptr() : nullptr, + softmax_lse.data_ptr(), + p_dropout, + softmax_scale, + softcap + ); + params.total_q = total_q; + + // Keep references to these tensors to extend their lifetime + at::Tensor softmax_lse_accum, out_accum; + + // NOTE(woosuk): Commented out because they are not used in inference. 
+ // number of times random will be generated per thread, to offset philox counter in thc random + // state + // We use a custom RNG that increases the offset by batch_size * nheads * 32. + // int64_t counter_offset = params.b * params.h * 32; + // auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + // auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); + // // Forward kernel will populate memory with the seed and offset. + // params.rng_state = reinterpret_cast(rng_state.data_ptr()); + + // if (p_dropout > 0.0) { + // auto gen = at::get_generator_or_default( + // gen_, at::cuda::detail::getDefaultCUDAGenerator()); + // // See Note [Acquire lock when using random generators] + // std::lock_guard lock(gen->mutex_); + // params.philox_args = gen->philox_cuda_state(counter_offset); + // } + + set_params_alibi(params, alibi_slopes_, batch_size, num_heads); + + if (max_seqlen_k > 0) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + run_mha_fwd_sparse(params, stream); + } else { + // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. + out.zero_(); + softmax_lse.fill_(std::numeric_limits::infinity()); + } + + // at::Tensor out_padded = out; + if (head_size_og % 8 != 0) { + out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + if (out_.has_value()) { out_.value().copy_(out); } + } + + return {out, softmax_lse}; +} + std::vector mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size @@ -1009,6 +1455,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "-> Tensor[]"); ops.impl("fwd", torch::kCUDA, &mha_fwd); + ops.def("fwd_sparse(Tensor! q, Tensor k, Tensor v, " + "Tensor block_count, Tensor block_offset, Tensor column_count, Tensor column_index, " + "Tensor!? out, Tensor? alibi_slopes, " + "float p_dropout, float softmax_scale, bool is_causal, " + "float softcap, bool return_softmax, Generator? gen)" + "-> Tensor[]"); + ops.impl("fwd_sparse", torch::kCUDA, &mha_fwd_sparse); + ops.def("varlen_fwd(Tensor! q, Tensor k, Tensor v, Tensor!? out, Tensor cu_seqlens_q, " "Tensor cu_seqlens_k, Tensor? seqused_k, Tensor? block_table, Tensor? alibi_slopes, " "int max_seqlen_q, int max_seqlen_k, float p_dropout, float softmax_scale, bool zero_tensors, " @@ -1016,6 +1470,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Generator? gen) -> Tensor[]"); ops.impl("varlen_fwd", torch::kCUDA, &mha_varlen_fwd); + ops.def("varlen_fwd_sparse(Tensor! q, Tensor k, Tensor v, " + "Tensor block_count, Tensor block_offset, Tensor column_count, Tensor column_index, " + "Tensor!? out, Tensor cu_seqlens_q, " + "Tensor cu_seqlens_k, Tensor? seqused_k, Tensor? alibi_slopes, " + "int max_seqlen_q, int max_seqlen_k, float p_dropout, float softmax_scale, bool zero_tensors, " + "bool is_causal, float softcap, bool return_softmax, " + "Generator? gen) -> Tensor[]"); + ops.impl("varlen_fwd_sparse", torch::kCUDA, &mha_varlen_fwd_sparse); + ops.def("fwd_kvcache(Tensor! q, Tensor kcache, Tensor vcache, Tensor? k, Tensor? v, Tensor? seqlens_k, " "Tensor? rotary_cos, Tensor? rotary_sin, Tensor? cache_batch_idx, Tensor? block_table, Tensor? alibi_slopes, " "Tensor!? 
out, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, " diff --git a/csrc/flash_attn/src/flash.h b/csrc/flash_attn/src/flash.h index 8e2352b660..e16f62892a 100644 --- a/csrc/flash_attn/src/flash.h +++ b/csrc/flash_attn/src/flash.h @@ -142,6 +142,15 @@ struct Flash_fwd_params : public Qkv_params { bool unpadded_lse; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q]. bool seqlenq_ngroups_swapped; // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d). + + // For sparse attention + const int* block_count; + const int* block_offset; + const int* column_count; + const int* column_index; + int NUM_ROWS; + int NNZ_S; + int NNZ_V; }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -189,6 +198,7 @@ struct Flash_bwd_params : public Flash_fwd_params { //////////////////////////////////////////////////////////////////////////////////////////////////// template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_sparse_(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_causal_sm80.cu new file mode 100644 index 0000000000..057065e07f --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_sparse_launch_template.h" + +template<> +void run_mha_fwd_sparse_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_sparse_hdim128(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_sm80.cu new file mode 100644 index 0000000000..2a2cf8f8dc --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_sparse_launch_template.h" + +template<> +void run_mha_fwd_sparse_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_sparse_hdim128(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_causal_sm80.cu new file mode 100644 index 0000000000..de130d17a9 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_sparse_launch_template.h" + +template<> +void run_mha_fwd_sparse_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_sparse_hdim128(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_sm80.cu new file mode 100644 index 0000000000..1139a1866a --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_sparse_launch_template.h" + +template<> +void run_mha_fwd_sparse_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_sparse_hdim128(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_sparse_kernel.h b/csrc/flash_attn/src/flash_fwd_sparse_kernel.h new file mode 100644 index 0000000000..3e392e9261 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_sparse_kernel.h @@ -0,0 +1,685 @@ +/****************************************************************************** + * Copyright (c) 2024, PAI, Alibaba Cloud. + ******************************************************************************/ + +#pragma once + +#include "flash_fwd_kernel.h" + +namespace flash { + +using namespace cute; + +template +inline __device__ void sparse_attn_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) { + + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNWarps = Kernel_traits::kNWarps; + + auto seed_offset = at::cuda::philox::unpack(params.philox_args); + flash::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset), params.p_dropout_in_uint8_t, + bidb, bidh, tidx, params.h); + + // Save seed and offset for backward, before any early exiting. Otherwise the 0-th thread block might + // exit early and no one saves the rng states. + if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) { + params.rng_state[0] = std::get<0>(seed_offset); + params.rng_state[1] = std::get<1>(seed_offset); + } + + const BlockInfo binfo(params, bidb); + if (m_block * kBlockM >= binfo.actual_seqlen_q) return; + + const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); + int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); + if (Is_causal || Is_local) { + n_block_max = std::min(n_block_max, + cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { + // printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max); + // } + } + // We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0. + // Otherwise we might read OOB elements from gK and gV. 
+ // if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max = %d\n", m_block, n_block_min, n_block_max); } + + // We iterate over the blocks in reverse order. This is because the last block is the only one + // that needs masking when we read K and V from global memory. Moreover, iterating in reverse + // might save us 1 register (we just need n_block instead of both n_block and n_block_max). + + const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN; + + Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)), + make_shape(binfo.actual_seqlen_q, params.h, params.d), + make_stride(params.q_row_stride, params.q_head_stride, _1{})); + Tensor gQ = local_tile(mQ(_, bidh, _), Shape, Int>{}, + make_coord(m_block, 0)); // (kBlockM, kHeadDim) + Tensor mK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + + binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)), + make_shape(binfo.actual_seqlen_k, params.h_k, params.d), + make_stride(params.k_row_stride, params.k_head_stride, _1{})); + Tensor gK = local_tile(mK(_, bidh / params.h_h_k_ratio, _), Shape, Int>{}, + make_coord(_, 0)); // (kBlockN, kHeadDim, nblocksN) + Tensor mV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + + binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)), + make_shape(binfo.actual_seqlen_k, params.h_k, params.d), + make_stride(params.v_row_stride, params.v_head_stride, _1{})); + Tensor gV = local_tile(mV(_, bidh / params.h_h_k_ratio, _), Shape, Int>{}, + make_coord(_, 0)); // (kBlockN, kHeadDim, nblocksN) + const index_t row_offset_k_token = + binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) + + (bidh / params.h_h_k_ratio) * params.k_head_stride; + const index_t row_offset_v_token = + binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + + (bidh / params.h_h_k_ratio) * params.v_head_stride; + + Tensor gKToken = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k_token), + Shape, Int>{}, + make_stride(_0{}, _1{})); + Tensor gVToken = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v_token), + Shape, Int>{}, + make_stride(_0{}, _1{})); + + Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast(params.p_ptr) + row_offset_p), + Shape, Int>{}, + make_stride(params.seqlen_k_rounded, _1{})); + + Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutQ{}); + // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem; + Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 
0 : size(sQ)), + typename Kernel_traits::SmemLayoutKV{}); + Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); + Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); + Tensor sVtNoSwizzle = make_tensor(sV.data().get(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); + + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); + + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K, nblocksN) + Tensor tKgKBlock = tKgK(_, _, _, 0); + auto tKgKBlockData = tKgKBlock.data(); + Tensor tKgKToken = gmem_thr_copy_QKV.partition_S(gKToken); // (KCPY, KCPY_N, KCPY_K) + auto tKgKTokenData = tKgKToken.data(); + Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K, nblocksN) + Tensor tVgVBlock = tVgV(_, _, _, 0); + auto tVgVBlockData = tVgVBlock.data(); + Tensor tVgVToken = gmem_thr_copy_QKV.partition_S(gVToken); // (VCPY, VCPY_N, VCPY_K) + auto tVgVTokenData = tVgVToken.data(); + Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) + + Tensor tSgS = thr_mma.partition_C(gP); + + Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K + + // + // Copy Atom retiling + // + + auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); + // if (cute::thread0()) {smem_thr_copy_Q.print_all();} + Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); + // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");} + + auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); + Tensor tSsK = smem_thr_copy_K.partition_S(sK); + + auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); + auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); + Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); + + // + // PREDICATES + // + + // Construct identity layout for sQ and sK + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + // Tensor tScQ = thr_mma.partition_A(cQ); // (MMA,MMA_M,MMA_K) + // if (cute::thread0()) { + // print(tScQ.layout()); printf("\n"); + // for (int i = 0; i < size(tScQ); ++i) { + // printf("%d ", get<0>(tScQ(i))); + // } + // printf("\n"); + // for (int i = 0; i < size(tScQ); ++i) { + // printf("%d ", get<1>(tScQ(i))); + // } + // printf("\n"); + // } + + // Repeat the partitioning with identity layouts + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Allocate predicate tensors for k + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + Tensor tKVpKV = 
make_tensor(make_shape(size<2>(tKsK))); + + // Set predicates for k bounds + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; } + #pragma unroll + for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } + } + + // Prologue + + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM); + cute::cp_async_fence(); + if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); } + + // // if (cute::thread(1, 0)) { print(tQsQ); } + // // Tensor sQNoSwizzle = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutQNoSwizzle{}); + // // if (cute::thread0()) { print(sQNoSwizzle); } + + if (Kernel_traits::Share_Q_K_smem) { + flash::cp_async_wait<0>(); + __syncthreads(); + Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M + cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); + __syncthreads(); + } + + int n_block = n_block_max - 1; + // block_count: torch.Tensor, # [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + // block_offset: torch.Tensor, # [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] + // num_blks = tl.load(block_count + off_hz * NUM_ROWS + start_m) + // blks_ptr = block_offset + (off_hz * NUM_ROWS + start_m) * NNZ_S + int num_blks = params.block_count[(bidb * params.h + bidh) * params.NUM_ROWS + m_block]; + auto* blks_ptr = params.block_offset + ((bidb * params.h + bidh) * params.NUM_ROWS + m_block) * params.NNZ_S; + + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); } + // __syncthreads(); + + if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) { + flash::cp_async_wait<1>(); + __syncthreads(); + Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M + cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); + } + + clear(acc_o); + + flash::Softmax<2 * size<1>(acc_o)> softmax; + + const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; + flash::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope); + + // For performance reason, we separate out two kinds of iterations: + // those that need masking on S, and those that don't. + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + constexpr int n_masking_steps = (!Is_causal && !Is_local) + ? 1 + : ((Is_even_MN && Is_causal) ? 
cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); + // column_count: torch.Tensor, # [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + // column_index: torch.Tensor, # [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] + // num_cols = tl.load(column_count + off_hz * NUM_ROWS + start_m) + // cols_ptr = column_index + (off_hz * NUM_ROWS + start_m) * NNZ_V + int num_cols = params.column_count[(bidb * params.h + bidh) * params.NUM_ROWS + m_block]; + int num_cols_block = (num_cols + kBlockN - 1)/ kBlockN; + if (num_blks <= 0 && num_cols_block <= 0) { + Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)), + make_shape(binfo.actual_seqlen_q, params.h, params.d), + make_stride(params.o_row_stride, params.o_head_stride, _1{})); + Tensor gO = local_tile(mO(_, bidh, _), Shape, Int>{}, + make_coord(m_block, 0)); // (kBlockM, kHeadDim) + + Tensor gLSE = get_lse_tile(params, bidb, bidh, m_block, binfo); + + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + Tensor tOrO = make_tensor(shape(tOgO)); + clear(tOrO); + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); + #pragma unroll + for (int m = 0; m < size<1>(tOgO); ++m) { + const int row = get<0>(tOcO(0, m, 0)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; } + } + return; + } + if (num_blks > 0) { + int block_index = num_blks - 1; + // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. 
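For orientation before the two processing phases that follow (whole kBlockN-wide K/V blocks taken from block_offset, then individual key positions gathered from column_index), this is a hedged PyTorch reference of the math they jointly compute for one (batch, head, query row-block). Dropout, ALiBi and softcap are ignored, names are illustrative, and block and column entries are assumed disjoint:

import torch

def sparse_attn_reference(q_blk, k, v, blk_starts, num_blks, cols, num_cols,
                          m_block, seqlen_q, kBlockM=64, kBlockN=64, is_causal=True):
    # q_blk: [rows <= kBlockM, d] queries of this row-block; k, v: [seqlen_k, d].
    # blk_starts = block_offset[b, h, m_block], cols = column_index[b, h, m_block].
    seqlen_k = k.shape[0]
    # Union of selected key positions: whole blocks named by block_offset ...
    idx = [torch.arange(s, min(s + kBlockN, seqlen_k), device=k.device)
           for s in blk_starts[:num_blks].tolist()]
    # ... plus individual key positions named by column_index.
    idx.append(cols[:num_cols].to(device=k.device, dtype=torch.long))
    idx = torch.cat(idx)   # (the kernel instead exits early and writes zeros if this is empty)

    scale = q_blk.shape[-1] ** -0.5
    scores = (q_blk.float() @ k[idx].float().T) * scale                  # [rows, nnz]
    if is_causal:
        # Right-aligned causal rule, matching the column masking further below:
        # query row r (global index m_block*kBlockM + r) sees keys < r + 1 + seqlen_k - seqlen_q.
        row = m_block * kBlockM + torch.arange(q_blk.shape[0], device=k.device)
        limit = (row + 1 + seqlen_k - seqlen_q).unsqueeze(1)             # [rows, 1]
        scores.masked_fill_(idx.unsqueeze(0) >= limit, float("-inf"))
    p = torch.softmax(scores, dim=-1)
    return p @ v[idx].float()                                            # [rows, d]

This is also, roughly, what the unit test at the end of this diff exercises with a block/column pattern that covers every key.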
+ tKgKBlock.data() = tKgKBlockData + blks_ptr[block_index] * int64_t(params.k_row_stride); + flash::copy(gmem_tiled_copy_QKV, tKgKBlock, tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - blks_ptr[block_index]); + cute::cp_async_fence(); + for (int n = 0; n < n_masking_steps && block_index >= 0; ++n, --block_index) { + int start_n = blks_ptr[block_index]; // replace n_block * kBlockN + + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + + // Advance gV + tVgVBlock.data() = tVgVBlockData + start_n * int64_t(params.v_row_stride); + if (block_index < num_blks - 1) { + flash::copy(gmem_tiled_copy_QKV, tVgVBlock, tVsV, tKVcKV, tKVpKV); + } else { + // Clear the smem tiles to account for predicated off loads + flash::copy( + gmem_tiled_copy_QKV, tVgVBlock, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - start_n + ); + } + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K + ); + // if (cute::thread0()) { print(acc_s); } + if constexpr (Is_softcap){ + flash::apply_softcap(acc_s, params.softcap); + } + + mask.template apply_mask( + acc_s, start_n, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + + flash::cp_async_wait<0>(); + __syncthreads(); + if (block_index > 0) { + tKgKBlock.data() = tKgKBlockData + blks_ptr[block_index - 1] * int64_t(params.k_row_stride); + flash::copy(gmem_tiled_copy_QKV, tKgKBlock, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // TODO: when we have key_padding_mask we'll need to Check_inf + n == 0 + ? softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2) + : softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); + + // Convert acc_s from fp32 to fp16/bf16 + Tensor rP = flash::convert_type(acc_s); + int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; + int block_col_idx = n_block * (kBlockN / 32); + if (Return_softmax) { + Tensor rP_drop = make_fragment_like(rP); + cute::copy(rP, rP_drop); + dropout.template apply_dropout( + rP_drop, block_row_idx, block_col_idx, kNWarps + ); + cute::copy(rP_drop, tSgS); + tSgS.data() = tSgS.data() + (-kBlockN); + } + if (Is_dropout) { + dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps); + } + + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. 
+ Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); + // if (cute::thread0()) { print(tOrP); } + flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + // if (cute::thread0()) { print(scores); } + } + for (; block_index >= 0; --block_index) { + int start_n = blks_ptr[block_index]; // replace n_block * kBlockN + + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + + // Advance gV + tVgVBlock.data() = tVgVBlockData + start_n * int64_t(params.v_row_stride); + flash::copy(gmem_tiled_copy_QKV, tVgVBlock, tVsV, tKVcKV, tKVpKV); + + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K + ); + // if (cute::thread0()) { print(acc_s); } + if constexpr (Is_softcap){ + flash::apply_softcap(acc_s, params.softcap); + } + + flash::cp_async_wait<0>(); + __syncthreads(); + if (block_index > 0) { + tKgKBlock.data() = tKgKBlockData + blks_ptr[block_index - 1] * int64_t(params.k_row_stride); + flash::copy(gmem_tiled_copy_QKV, tKgKBlock, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); + + // Convert acc_s from fp32 to fp16/bf16 + Tensor rP = flash::convert_type(acc_s); + int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; + int block_col_idx = n_block * (kBlockN / 32); + if (Return_softmax) { + Tensor rP_drop = make_fragment_like(rP); + cute::copy(rP, rP_drop); + dropout.template apply_dropout( + rP_drop, block_row_idx, block_col_idx, kNWarps + ); + cute::copy(rP_drop, tSgS); + tSgS.data() = tSgS.data() + (-kBlockN); + } + if (Is_dropout) { + dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps); + } + + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); + // if (cute::thread0()) { print(tOrP); } + flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + // if (cute::thread0()) { print(scores); } + } + } + + if (num_cols > 0) { + auto* cols_ptr = params.column_index + ((bidb * params.h + bidh) * params.NUM_ROWS + m_block) * params.NNZ_V; + // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. 
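The softmax_rescale_o calls in the block loop above, and in the column loop below, implement the usual streaming-softmax update. A hedged sketch of that update in Python, using natural exp where the kernel works in log2 space via scale_softmax_log2; names are illustrative:

import torch

def softmax_rescale_o(scores, acc_o, row_max, row_sum):
    # Streaming softmax over K/V tiles: keep a running row max and row sum,
    # and rescale previously accumulated output whenever the max grows.
    new_max = torch.maximum(row_max, scores.amax(dim=-1))
    correction = torch.exp(row_max - new_max)              # <= 1
    p = torch.exp(scores - new_max.unsqueeze(-1))          # unnormalized probabilities
    row_sum = row_sum * correction + p.sum(dim=-1)
    acc_o = acc_o * correction.unsqueeze(-1)               # caller then adds p @ v_tile
    return p, acc_o, new_max, row_sum

In this simplified form the final output is acc_o / row_sum[..., None], and the log-sum-exp written out in the epilogue is new_max + log(row_sum), corresponding to what normalize_softmax_lse produces there.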
+ #pragma unroll + for (int m = 0; m < size<1>(tKgKToken); ++m) { + if (Is_even_MN || get<0>(tKVcKV(0, m, 0)) < num_cols) { // Is_even_MN + tKgKToken.data() = tKgKTokenData + cols_ptr[get<0>(tKVcKV(0, m, 0))] * int64_t(params.k_row_stride); + #pragma unroll + for (int k = 0; k < size<2>(tKgKToken); ++k) { + if (Is_even_K || tKVpKV(k)) { + cute::copy(gmem_tiled_copy_QKV, tKgKToken(_, m, k), tKsK(_, m, k)); + } else { // Clear_OOB_K + cute::clear(tKsK(_, m, k)); + } + } + } + } + cute::cp_async_fence(); + for (int n = 0; n < num_cols_block; ++n) { + // cols = tl.load(cols_ptr + start_n + offs_n, mask=n_mask, other=0) + // int start_n = cols_ptr[n * kBlockN]; // replace n_block * kBlockN + + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + + // Advance gV + if (n < num_cols_block - 1) { + #pragma unroll + for (int m = 0; m < size<1>(tVgVToken); ++m) { + tVgVToken.data() = tVgVTokenData + cols_ptr[n * kBlockN + get<0>(tKVcKV(0, m, 0))] * int64_t(params.v_row_stride); + #pragma unroll + for (int k = 0; k < size<2>(tVgVToken); ++k) { + if (Is_even_K || tKVpKV(k)) { + cute::copy(gmem_tiled_copy_QKV, tVgVToken(_, m, k), tVsV(_, m, k)); + } else { // Clear_OOB_K + cute::clear(tVsV(_, m, k)); + } + } + } + } else { + // Clear the smem tiles to account for predicated off loads + #pragma unroll + for (int m = 0; m < size<1>(tVgVToken); ++m) { + if (Is_even_MN || n * kBlockN + get<0>(tKVcKV(0, m, 0)) < num_cols) { // Is_even_MN + tVgVToken.data() = tVgVTokenData + cols_ptr[n * kBlockN + get<0>(tKVcKV(0, m, 0))] * int64_t(params.v_row_stride); + #pragma unroll + for (int k = 0; k < size<2>(tVgVToken); ++k) { + if (Is_even_K || tKVpKV(k)) { + cute::copy(gmem_tiled_copy_QKV, tVgVToken(_, m, k), tVsV(_, m, k)); + } else { // Clear_OOB_K + cute::clear(tVsV(_, m, k)); + } + } + } else { // Clear_OOB_MN + cute::clear(tVsV(_, m, _)); + } + } + } + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K + ); + // if (cute::thread0()) { print(acc_s); } + if constexpr (Is_softcap){ + flash::apply_softcap(acc_s, params.softcap); + } + + if (n >= num_cols_block - n_masking_steps) { + Tensor tensor = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = n * kBlockN + (lane_id % 4) * 2; + const int row_idx_offset = m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4; + const int warp_row_stride = kNWarps * 16; + const int max_seqlen_k = binfo.actual_seqlen_k; + const int max_seqlen_q = binfo.actual_seqlen_q; + #pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; + #pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + const int col_idx_limit = Is_causal ? 
std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k; + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + if (col_idx_base + j >= num_cols || cols_ptr[col_idx_base + j] >= col_idx_limit) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + } + } + } + } + + flash::cp_async_wait<0>(); + __syncthreads(); + if (n < num_cols_block - 2) { + #pragma unroll + for (int m = 0; m < size<1>(tKgKToken); ++m) { + int token_idx = cols_ptr[(n + 1) * kBlockN + get<0>(tKVcKV(0, m, 0))]; + tKgKToken.data() = tKgKTokenData + token_idx * int64_t(params.k_row_stride); + #pragma unroll + for (int k = 0; k < size<2>(tKgKToken); ++k) { + if (Is_even_K || tKVpKV(k)) { + cute::copy(gmem_tiled_copy_QKV, tKgKToken(_, m, k), tKsK(_, m, k)); + } else { // Clear_OOB_K + cute::clear(tKsK(_, m, k)); + } + } + } + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } else if (n == num_cols_block - 2) { + #pragma unroll + for (int m = 0; m < size<1>(tKgKToken); ++m) { + if (Is_even_MN || (n + 1) * kBlockN + get<0>(tKVcKV(0, m, 0)) < num_cols) { // Is_even_MN + int token_idx = cols_ptr[(n + 1) * kBlockN + get<0>(tKVcKV(0, m, 0))]; + tKgKToken.data() = tKgKTokenData + token_idx * int64_t(params.k_row_stride); + #pragma unroll + for (int k = 0; k < size<2>(tKgKToken); ++k) { + if (Is_even_K || tKVpKV(k)) { + cute::copy(gmem_tiled_copy_QKV, tKgKToken(_, m, k), tKsK(_, m, k)); + } else { // Clear_OOB_K + cute::clear(tKsK(_, m, k)); + } + } + } + } + cute::cp_async_fence(); + } + + // TODO: when we have key_padding_mask we'll need to Check_inf + (num_blks <= 0 && n ==0) + ? softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2) + : softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); + + // Convert acc_s from fp32 to fp16/bf16 + Tensor rP = flash::convert_type(acc_s); + int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; + int block_col_idx = n_block * (kBlockN / 32); + if (Return_softmax) { + Tensor rP_drop = make_fragment_like(rP); + cute::copy(rP, rP_drop); + dropout.template apply_dropout( + rP_drop, block_row_idx, block_col_idx, kNWarps + ); + cute::copy(rP_drop, tSgS); + tSgS.data() = tSgS.data() + (-kBlockN); + } + if (Is_dropout) { + dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps); + } + + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. 
+ Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); + // if (cute::thread0()) { print(tOrP); } + flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + // if (cute::thread0()) { print(scores); } + } + } + + // Epilogue + + Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax, params.rp_dropout); + + // Convert acc_o from fp32 to fp16/bf16 + Tensor rO = flash::convert_type(acc_o); + Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx); + Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // sO has the same size as sQ, so we don't need to sync here. + if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); } + + cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); + + Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)), + make_shape(binfo.actual_seqlen_q, params.h, params.d), + make_stride(params.o_row_stride, params.o_head_stride, _1{})); + Tensor gO = local_tile(mO(_, bidh, _), Shape, Int>{}, + make_coord(m_block, 0)); // (kBlockM, kHeadDim) + Tensor gLSE = get_lse_tile(params, bidb, bidh, m_block, binfo); + + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + + __syncthreads(); + + Tensor tOrO = make_tensor(shape(tOgO)); + cute::copy(gmem_tiled_copy_O, tOsO, tOrO); + + Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + static_assert(decltype(size<0>(taccOcO))::value == 4); + // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. + Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + if (get<1>(taccOcO_row(0)) == 0) { + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSE(row) = lse(mi); } + } + } + + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(sO), size<1>(sO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); +} + +template +inline __device__ void compute_sparse_attn(const Params ¶ms) { + const int m_block = blockIdx.x; + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. 
+ const int bidh = blockIdx.z; + + // We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting + // them to have the same number of threads or have to traverse the attention matrix + // in the same order. + // In the Philox RNG, we use the offset to store the batch, head, and the lane id + // (within a warp). We use the subsequence to store the location of the 16 x 32 blocks within + // the attention matrix. This way, as long as we have the batch, head, and the location of + // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. + + flash::sparse_attn_1rowblock(params, bidb, bidh, m_block); +} + +} // namespace flash diff --git a/csrc/flash_attn/src/flash_fwd_sparse_launch_template.h b/csrc/flash_attn/src/flash_fwd_sparse_launch_template.h new file mode 100644 index 0000000000..c07d954fd9 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_sparse_launch_template.h @@ -0,0 +1,125 @@ +/****************************************************************************** + * Copyright (c) 2024, PAI, Alibaba Cloud. + ******************************************************************************/ + +#pragma once + +#include "flash_fwd_launch_template.h" +#include "flash_fwd_sparse_kernel.h" + +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_sparse_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) { + #if defined(ARCH_SUPPORTS_FLASH) + static_assert(!(Is_causal && Is_local)); // Enforce constraints + flash::compute_sparse_attn(params); + #else + FLASH_UNSUPPORTED_ARCH + #endif +} + +template +void run_flash_sparse_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr size_t smem_size = Kernel_traits::kSmemSize; + // printf("smem_size = %d\n", smem_size); + + // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. + // https://github.com/kokkos/kokkos-kernels/issues/349 + // https://github.com/HazyResearch/flash-attention/issues/21 + + const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; + dim3 grid(num_m_block, params.b, params.h); + const bool is_even_K = params.d == Kernel_traits::kHeadDim; + const bool return_softmax = params.p_ptr != nullptr; + EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { + BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { + ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { + SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { + constexpr bool IsEvenMNConst = false; + constexpr bool Is_local = false; + // Will only return softmax if dropout, to reduce compilation time. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. 
+ // If return_softmax, set IsEvenMNConst to false to reduce number of templates + // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + // If Is_local, set Is_causal to false + auto kernel = &flash_fwd_sparse_kernel; + // auto kernel = &flash_fwd_kernel; + // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout)); + // auto kernel = &flash_fwd_kernel; + if (smem_size >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + // int ctas_per_sm; + // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); + }); +} + +template +void run_mha_fwd_sparse_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 32; + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + run_flash_sparse_fwd, Is_dropout, Is_causal>(params, stream); + }); +} + +template +void run_mha_fwd_sparse_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 64; + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + run_flash_sparse_fwd, Is_dropout, Is_causal>(params, stream); + }); +} + +template +void run_mha_fwd_sparse_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 96; + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + run_flash_sparse_fwd, Is_dropout, Is_causal>(params, stream); + }); +} + +template +void run_mha_fwd_sparse_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 128; + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + run_flash_sparse_fwd, Is_dropout, Is_causal>(params, stream); + }); +} + +template +void run_mha_fwd_sparse_hdim160(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 160; + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + run_flash_sparse_fwd, Is_dropout, Is_causal>(params, stream); + }); +} + +template +void run_mha_fwd_sparse_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 192; + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + run_flash_sparse_fwd, Is_dropout, Is_causal>(params, stream); + }); +} + +template +void run_mha_fwd_sparse_hdim224(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 224; + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + run_flash_sparse_fwd, Is_dropout, Is_causal>(params, stream); + }); +} + +template +void run_mha_fwd_sparse_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 256; + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + run_flash_sparse_fwd, Is_dropout, Is_causal>(params, stream); + }); +} diff --git a/csrc/flash_attn/src/generate_kernels.py b/csrc/flash_attn/src/generate_kernels.py index 45fc3d9f10..df76cc3e04 100644 --- a/csrc/flash_attn/src/generate_kernels.py +++ b/csrc/flash_attn/src/generate_kernels.py @@ -25,6 +25,14 @@ }} """ +KERNEL_IMPL_TEMPLATE_FWD_SPARSE = """#include "flash_fwd_sparse_launch_template.h" + +template<> +void run_mha_fwd_sparse_<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params ¶ms, cudaStream_t stream) {{ 
+ run_mha_fwd_sparse_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}>(params, stream); +}} +""" + KERNEL_IMPL_TEMPLATE_FWD_SPLIT = """#include "flash_fwd_launch_template.h" template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params ¶ms, cudaStream_t stream); @@ -53,6 +61,10 @@ def template(self) -> str: return KERNEL_IMPL_TEMPLATE_FWD.format( DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, IS_CAUSAL=self.is_causal ) + elif self.direction == "fwd_sparse": + return KERNEL_IMPL_TEMPLATE_FWD_SPARSE.format( + DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, IS_CAUSAL=self.is_causal + ) elif self.direction == "bwd": return KERNEL_IMPL_TEMPLATE_BWD.format( DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim @@ -68,7 +80,7 @@ def filename(self) -> str: def get_all_kernels() -> List[Kernel]: - for direction in ["fwd", "fwd_split"]: + for direction in ["fwd", "fwd_split", "fwd_sparse"]: for dtype, head_dim, is_causal, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, IS_CAUSAL, SM): yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, is_causal=is_causal, direction=direction) for direction in ["bwd"]: diff --git a/tests/test_vllm_flash_attn.py b/tests/test_vllm_flash_attn.py index 758c77b68b..d2fd239dc6 100644 --- a/tests/test_vllm_flash_attn.py +++ b/tests/test_vllm_flash_attn.py @@ -267,3 +267,154 @@ def test_varlen_with_paged_kv( ) torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \ f"{torch.max(torch.abs(output - ref_output))}" + +@pytest.mark.parametrize("batch_size", [1, 2]) +@pytest.mark.parametrize("seq_lens", [(1, 1), (1, 1024), (1, 2048), (1023, 2049), (1023, 1023), (32, 32), (65, 65), (129, 129)]) +@pytest.mark.parametrize("num_heads", [1, 2, 4]) +@pytest.mark.parametrize("head_size", [128]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("NNZ_S", [0, 1, 2, 3, 7, 15, 32]) +@torch.inference_mode() +def test_sparse_attention( + batch_size: int, + seq_lens: Tuple[int, int], + num_heads: int, + head_size: int, + dtype: torch.dtype, + NNZ_S: int, +) -> None: + torch.set_default_device("cuda") + torch.cuda.manual_seed_all(0) + block_size_M = 64 + block_size_N = 64 + seqlen_q, seqlen_k = seq_lens + q = torch.randn( + batch_size, seqlen_q, num_heads, head_size, dtype=dtype, requires_grad=False + ) + k = torch.randn( + batch_size, seqlen_k, num_heads, head_size, dtype=dtype, requires_grad=False + ) + v = torch.randn( + batch_size, seqlen_k, num_heads, head_size, dtype=dtype, requires_grad=False + ) + NUM_ROWS = (seqlen_q + block_size_M - 1) // block_size_M + if NNZ_S * block_size_N > seqlen_k: + return + NNZ_V = seqlen_k - NNZ_S * block_size_N + block_count = torch.tensor([NNZ_S] * batch_size * NUM_ROWS * num_heads, dtype=torch.int32).reshape(batch_size, num_heads, NUM_ROWS) + column_count = torch.tensor([NNZ_V] * batch_size * NUM_ROWS * num_heads, dtype=torch.int32).reshape(batch_size, num_heads, NUM_ROWS) + block_offset = torch.tensor([[i * block_size_N for i in range(NNZ_S)]] * batch_size * NUM_ROWS * num_heads, dtype=torch.int32).reshape(batch_size, num_heads, NUM_ROWS, NNZ_S) + column_index = torch.tensor([[NNZ_S * block_size_N + i for i in range(NNZ_V)]] * batch_size * NUM_ROWS * num_heads, dtype=torch.int32).reshape(batch_size, num_heads, NUM_ROWS, NNZ_V) + from vllm_flash_attn import sparse_attn_func, flash_attn_func + out, lse = sparse_attn_func( + q, + k, + v, + block_count, + block_offset, + column_count, + column_index, + return_softmax_lse=True, + ) + + ref_out, ref_lse = flash_attn_func( + q, + k, + v, + 
return_softmax_lse=True, + ) + + torch.testing.assert_close(out, ref_out, atol=2e-2, rtol=1e-2), \ + f"{torch.max(torch.abs(out - ref_out))}" + torch.testing.assert_close(lse, ref_lse, atol=2e-2, rtol=1e-2), \ + f"{torch.max(torch.abs(lse - ref_lse))}" + +@pytest.mark.parametrize("seq_lens", [[(1024, 1328)], + [(1024, 1328), (1, 2048)], + [(1025, 1328), (2, 2048)], + [(1025, 2049), (2, 1281)], + ]) +@pytest.mark.parametrize("head_size", [128]) +@pytest.mark.parametrize("dtype", DTYPES) +@torch.inference_mode() +def test_sparse_attention_varlen( + seq_lens: List[Tuple[int, int]], + head_size: int, + dtype: torch.dtype, +) -> None: + torch.set_default_device("cuda") + torch.cuda.manual_seed_all(0) + block_size_M = 64 + block_size_N = 64 + num_seqs = len(seq_lens) + query_lens = [x[0] for x in seq_lens] + kv_lens = [x[1] for x in seq_lens] + num_heads = 1 + query = torch.randn(sum(query_lens), + num_heads, + head_size, + dtype=dtype) + key = torch.randn(sum(kv_lens), + num_heads, + head_size, + dtype=dtype) + value = torch.randn_like(key) + cu_query_lens = torch.tensor([0] + query_lens, + dtype=torch.int32).cumsum(dim=0, + dtype=torch.int32) + cu_kv_lens = torch.tensor([0] + kv_lens, + dtype=torch.int32).cumsum(dim=0, + dtype=torch.int32) + max_query_len = max(query_lens) + max_kv_len = max(kv_lens) + + NUM_ROWS = (max_query_len + block_size_M - 1) // block_size_M + NNZ_S = 20 + NNZ_V = 2048 + batch_size = len(query_lens) + + block_counts = [] + column_counts = [] + block_offsets = [] + column_indices = [] + for b in range(batch_size): + block_counts.append(torch.tensor([NNZ_S] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS)) + columns = kv_lens[b] - NNZ_S * block_size_N + column_counts.append(torch.tensor([columns] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS)) + block_offsets.append(torch.tensor([[i * block_size_N for i in range(NNZ_S)]] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS, NNZ_S)) + column_indices.append(torch.tensor([[NNZ_S * block_size_N + i for i in range(NNZ_V)]] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS, NNZ_V)) + block_count = torch.concat(block_counts).reshape(batch_size, num_heads, NUM_ROWS) + column_count = torch.concat(column_counts).reshape(batch_size, num_heads, NUM_ROWS) + block_offset = torch.concat(block_offsets).reshape(batch_size, num_heads, NUM_ROWS, NNZ_S) + column_index = torch.concat(column_indices).reshape(batch_size, num_heads, NUM_ROWS, NNZ_V) + from vllm_flash_attn import sparse_attn_varlen_func, flash_attn_varlen_func + out, lse = sparse_attn_varlen_func( + query, + key, + value, + block_count, + block_offset, + column_count, + column_index, + cu_seqlens_q=cu_query_lens, + cu_seqlens_k=cu_kv_lens, + max_seqlen_q=max_query_len, + max_seqlen_k=max_kv_len, + return_softmax_lse=True, + ) + + ref_out, ref_lse = flash_attn_varlen_func( + query, + key, + value, + cu_seqlens_q=cu_query_lens, + cu_seqlens_k=cu_kv_lens, + max_seqlen_q=max_query_len, + max_seqlen_k=max_kv_len, + return_softmax_lse=True, + ) + + torch.testing.assert_close(out, ref_out, atol=2e-2, rtol=1e-2), \ + f"{torch.max(torch.abs(out - ref_out))}" + torch.testing.assert_close(lse, ref_lse, atol=2e-2, rtol=1e-2), \ + f"{torch.max(torch.abs(lse - ref_lse))}" diff --git a/vllm_flash_attn/__init__.py b/vllm_flash_attn/__init__.py index 7469df4c54..e17099fe78 100644 --- a/vllm_flash_attn/__init__.py +++ b/vllm_flash_attn/__init__.py @@ -3,6 +3,8 @@ # Use relative import to support 
build-from-source installation in vLLM from .flash_attn_interface import ( flash_attn_func, + sparse_attn_func, + sparse_attn_varlen_func, flash_attn_varlen_func, flash_attn_with_kvcache, ) diff --git a/vllm_flash_attn/flash_attn_interface.py b/vllm_flash_attn/flash_attn_interface.py index 6a40c32b05..5e4e824f3c 100644 --- a/vllm_flash_attn/flash_attn_interface.py +++ b/vllm_flash_attn/flash_attn_interface.py @@ -45,6 +45,76 @@ def _get_block_size_n(device, head_dim, is_dropout, is_causal): elif head_dim <= 256: return 64 +def _sparse_attn_forward( + q, k, v, block_count, block_offset, column_count, column_index, dropout_p, softmax_scale, causal, softcap, alibi_slopes, return_softmax, *, out=None +): + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + out, softmax_lse = torch.ops.vllm_flash_attn_c.fwd_sparse( + q, + k, + v, + block_count, + block_offset, + column_count, + column_index, + out, + alibi_slopes, + dropout_p, + softmax_scale, + causal, + softcap, + return_softmax, + None, + ) + return out, softmax_lse + +def _sparse_attn_varlen_forward( + q, + k, + v, + block_count, + block_offset, + column_count, + column_index, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + softcap, + alibi_slopes, + return_softmax, + *, + out=None +): + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + out, softmax_lse = torch.ops.vllm_flash_attn_c.varlen_fwd_sparse( + q, + k, + v, + block_count, + block_offset, + column_count, + column_index, + out, + cu_seqlens_q, + cu_seqlens_k, + None, + alibi_slopes, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + False, + causal, + softcap, + return_softmax, + None, + ) + return out, softmax_lse + def _flash_attn_forward( q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax, *, out=None @@ -112,6 +182,162 @@ def _flash_attn_varlen_forward( ) return out, softmax_lse +def sparse_attn_func( + q, + k, + v, + block_count, + block_offset, + column_count, + column_index, + dropout_p=0.0, + softmax_scale=None, + causal=False, + softcap=0.0, # 0.0 means deactivated + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + *, + return_softmax_lse=False, + out=None, +): + """Compute attention with vertical and slash sparsity patterns. + Most Arguments are the same with the flash_attn_func interface, except for 4 extra args: + block_count and block_offset for slash sparsity patterns, and + column_count and column_index for vertical sparsity patterns. + For more details please refer to Appendix C.4.2 of paper https://arxiv.org/abs/2407.02490. + + Arguments: + q: (batch_size, seqlen, nheads, headdim) + k: (batch_size, seqlen, nheads_k, headdim) + v: (batch_size, seqlen, nheads_k, headdim) + block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M)) + block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S) + column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M)) + column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V) + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of + (-alibi_slope * |i + seqlen_k - seqlen_q - j|) + is added to the attention score of query i and key j. + deterministic: bool. 
Whether to use the deterministic implementation of the backward pass, + which is slightly slower and uses more memory. The forward pass is always deterministic. + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (batch_size, seqlen, nheads, headdim). + softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + """ + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + out, softmax_lse = _sparse_attn_forward( + q, + k, + v, + block_count, + block_offset, + column_count, + column_index, + dropout_p, + softmax_scale, + causal=causal, + softcap=softcap, + alibi_slopes=alibi_slopes, + return_softmax=return_attn_probs and dropout_p > 0, + out=out, + ) + return (out, softmax_lse) if return_softmax_lse else out + +def sparse_attn_varlen_func( + q, + k, + v, + block_count, + block_offset, + column_count, + column_index, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p=0.0, + softmax_scale=None, + causal=False, + softcap=0.0, # 0.0 means deactivated + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + *, + return_softmax_lse=False, + out=None, +): + """Compute attention with vertical and slash sparsity patterns. + Most Arguments are the same with the flash_attn_varlen_func interface, except for 4 extra args: + block_count and block_offset for slash sparsity patterns, and + column_count and column_index for vertical sparsity patterns. + For more details please refer to Appendix C.4.2 of paper https://arxiv.org/abs/2407.02490. + + Arguments: + q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. + k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. + v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. + block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M)) + block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S) + column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M)) + column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V) + cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into q. + cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into kv. + max_seqlen_q: int. Maximum query sequence length in the batch. + max_seqlen_k: int. Maximum key sequence length in the batch. + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + softcap: float. Anything > 0 activates softcapping attention. + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of + (-alibi_slope * |i + seqlen_k - seqlen_q - j|) + is added to the attention score of query i and key j. + deterministic: bool. Whether to use the deterministic implementation of the backward pass, + which is slightly slower and uses more memory. The forward pass is always deterministic. + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. 
The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (total, nheads, headdim). + softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + """ + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + out, softmax_lse = _sparse_attn_varlen_forward( + q, + k, + v, + block_count, + block_offset, + column_count, + column_index, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal=causal, + softcap=softcap, + alibi_slopes=alibi_slopes, + return_softmax=return_attn_probs and dropout_p > 0, + out=out, + ) + return (out, softmax_lse) if return_softmax_lse else out def flash_attn_func( q,