463 changes: 463 additions & 0 deletions csrc/flash_attn/flash_api.cpp

Large diffs are not rendered by default.

10 changes: 10 additions & 0 deletions csrc/flash_attn/src/flash.h
@@ -142,6 +142,15 @@ struct Flash_fwd_params : public Qkv_params {

bool unpadded_lse; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q].
bool seqlenq_ngroups_swapped; // q has been transposed from (b, 1, (nheads_kv * ngroups), d) to (b, ngroups, nheads_kv, d).

// For sparse attention (block- plus column-sparse layout; shapes as constructed by the tests in this PR).
const int* block_count;   // [b, nheads, NUM_ROWS]: number of selected key blocks per query-row block.
const int* block_offset;  // [b, nheads, NUM_ROWS, NNZ_S]: starting key index of each selected block.
const int* column_count;  // [b, nheads, NUM_ROWS]: number of individually selected key columns per query-row block.
const int* column_index;  // [b, nheads, NUM_ROWS, NNZ_V]: indices of those key columns.
int NUM_ROWS;  // Number of query-row blocks.
int NNZ_S;     // Maximum selected key blocks per query-row block.
int NNZ_V;     // Maximum selected key columns per query-row block.
};

////////////////////////////////////////////////////////////////////////////////////////////////////
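For context, the four index arrays describe a per-(batch, head, query-row-block) sparsity pattern. Below is a minimal sketch (Python, following how test_sparse_attention in this PR builds the metadata) of a dense-equivalent pattern that keeps the first nnz_s key blocks and all remaining keys as explicit columns; build_sparse_metadata and the 64x64 tile sizes are illustrative assumptions, not part of the kernel API.

import torch

def build_sparse_metadata(batch, heads, seqlen_q, seqlen_k, nnz_s,
                          block_m=64, block_n=64, device="cuda"):
    # Hypothetical helper; shapes mirror test_sparse_attention in this PR.
    num_rows = (seqlen_q + block_m - 1) // block_m   # query-row blocks
    nnz_v = seqlen_k - nnz_s * block_n               # leftover keys kept as explicit columns
    assert nnz_v >= 0, "this pattern assumes nnz_s * block_n <= seqlen_k"
    block_count = torch.full((batch, heads, num_rows), nnz_s, dtype=torch.int32, device=device)
    block_offset = (torch.arange(nnz_s, dtype=torch.int32, device=device) * block_n).expand(
        batch, heads, num_rows, nnz_s).contiguous()
    column_count = torch.full((batch, heads, num_rows), nnz_v, dtype=torch.int32, device=device)
    column_index = (nnz_s * block_n + torch.arange(nnz_v, dtype=torch.int32, device=device)).expand(
        batch, heads, num_rows, nnz_v).contiguous()
    return block_count, block_offset, column_count, column_index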
@@ -189,6 +198,7 @@ struct Flash_bwd_params : public Flash_fwd_params {
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_sparse_(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
10 changes: 10 additions & 0 deletions csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_causal_sm80.cu
@@ -0,0 +1,10 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_sparse_launch_template.h"

template<>
void run_mha_fwd_sparse_<cutlass::bfloat16_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_sparse_hdim128<cutlass::bfloat16_t, true>(params, stream);
}
10 changes: 10 additions & 0 deletions csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_sm80.cu
@@ -0,0 +1,10 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_sparse_launch_template.h"

template<>
void run_mha_fwd_sparse_<cutlass::bfloat16_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_sparse_hdim128<cutlass::bfloat16_t, false>(params, stream);
}
10 changes: 10 additions & 0 deletions csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_causal_sm80.cu
@@ -0,0 +1,10 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_sparse_launch_template.h"

template<>
void run_mha_fwd_sparse_<cutlass::half_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_sparse_hdim128<cutlass::half_t, true>(params, stream);
}
10 changes: 10 additions & 0 deletions csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_sm80.cu
@@ -0,0 +1,10 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_sparse_launch_template.h"

template<>
void run_mha_fwd_sparse_<cutlass::half_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_sparse_hdim128<cutlass::half_t, false>(params, stream);
}
685 changes: 685 additions & 0 deletions csrc/flash_attn/src/flash_fwd_sparse_kernel.h

Large diffs are not rendered by default.

125 changes: 125 additions & 0 deletions csrc/flash_attn/src/flash_fwd_sparse_launch_template.h
@@ -0,0 +1,125 @@
/******************************************************************************
* Copyright (c) 2024, PAI, Alibaba Cloud.
******************************************************************************/

#pragma once

#include "flash_fwd_launch_template.h"
#include "flash_fwd_sparse_kernel.h"

DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_sparse_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) {
#if defined(ARCH_SUPPORTS_FLASH)
static_assert(!(Is_causal && Is_local)); // Enforce constraints
flash::compute_sparse_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
}

template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
void run_flash_sparse_fwd(Flash_fwd_params &params, cudaStream_t stream) {
constexpr size_t smem_size = Kernel_traits::kSmemSize;
// printf("smem_size = %d\n", smem_size);

// Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
// https://github.com/kokkos/kokkos-kernels/issues/349
// https://github.com/HazyResearch/flash-attention/issues/21

const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
dim3 grid(num_m_block, params.b, params.h);
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
const bool return_softmax = params.p_ptr != nullptr;
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
constexpr bool IsEvenMNConst = false;
constexpr bool Is_local = false;
// Will only return softmax if dropout, to reduce compilation time.
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
// If return_softmax, set IsEvenMNConst to false to reduce number of templates
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
// If Is_local, set Is_causal to false
auto kernel = &flash_fwd_sparse_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>;
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
// printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
if (smem_size >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
// int ctas_per_sm;
// cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
// &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
// printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
});
}

template<typename T, bool Is_causal>
void run_mha_fwd_sparse_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 32;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
});
}

template<typename T, bool Is_causal>
void run_mha_fwd_sparse_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 64;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
});
}

template<typename T, bool Is_causal>
void run_mha_fwd_sparse_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 96;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
});
}

template<typename T, bool Is_causal>
void run_mha_fwd_sparse_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 128;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
});
}

template<typename T, bool Is_causal>
void run_mha_fwd_sparse_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 160;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
});
}

template<typename T, bool Is_causal>
void run_mha_fwd_sparse_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 192;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
});
}

template<typename T, bool Is_causal>
void run_mha_fwd_sparse_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 224;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
});
}

template<typename T, bool Is_causal>
void run_mha_fwd_sparse_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 256;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
});
}
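Every head-dimension wrapper above uses the same 64x64-tile, 4-warp kernel traits and switches only on dropout; the launch grid is the same as the dense forward path. A small illustrative sketch of that mapping (Python; the helper name is hypothetical):

def sparse_fwd_grid(seqlen_q, batch, heads, block_m=64):
    # One CTA per query-row block per (batch, head), matching
    # dim3 grid(num_m_block, params.b, params.h) in run_flash_sparse_fwd.
    num_m_block = (seqlen_q + block_m - 1) // block_m
    return (num_m_block, batch, heads)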
14 changes: 13 additions & 1 deletion csrc/flash_attn/src/generate_kernels.py
@@ -25,6 +25,14 @@
}}
"""

KERNEL_IMPL_TEMPLATE_FWD_SPARSE = """#include "flash_fwd_sparse_launch_template.h"

template<>
void run_mha_fwd_sparse_<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params &params, cudaStream_t stream) {{
run_mha_fwd_sparse_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}>(params, stream);
}}
"""

KERNEL_IMPL_TEMPLATE_FWD_SPLIT = """#include "flash_fwd_launch_template.h"

template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params &params, cudaStream_t stream);
@@ -53,6 +61,10 @@ def template(self) -> str:
return KERNEL_IMPL_TEMPLATE_FWD.format(
DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, IS_CAUSAL=self.is_causal
)
elif self.direction == "fwd_sparse":
return KERNEL_IMPL_TEMPLATE_FWD_SPARSE.format(
DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, IS_CAUSAL=self.is_causal
)
elif self.direction == "bwd":
return KERNEL_IMPL_TEMPLATE_BWD.format(
DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim
@@ -68,7 +80,7 @@ def filename(self) -> str:


def get_all_kernels() -> List[Kernel]:
for direction in ["fwd", "fwd_split"]:
for direction in ["fwd", "fwd_split", "fwd_sparse"]:
for dtype, head_dim, is_causal, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, IS_CAUSAL, SM):
yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, is_causal=is_causal, direction=direction)
for direction in ["bwd"]:
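With "fwd_sparse" added to get_all_kernels, the generator emits one translation unit per (dtype, head_dim, is_causal) combination. Judging from the files added in this PR, the generated names follow the pattern below (a sketch; the actual filename property is not shown in this diff):

def sparse_fwd_filename(dtype: str, head_dim: int, is_causal: bool, sm: int = 80) -> str:
    # e.g. "flash_fwd_sparse_hdim128_bf16_causal_sm80.cu", as added by this PR.
    causal = "_causal" if is_causal else ""
    return f"flash_fwd_sparse_hdim{head_dim}_{dtype}{causal}_sm{sm}.cu"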
151 changes: 151 additions & 0 deletions tests/test_vllm_flash_attn.py
@@ -267,3 +267,154 @@ def test_varlen_with_paged_kv(
)
torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"

@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize("seq_lens", [(1, 1), (1, 1024), (1, 2048), (1023, 2049), (1023, 1023), (32, 32), (65, 65), (129, 129)])
@pytest.mark.parametrize("num_heads", [1, 2, 4])
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("NNZ_S", [0, 1, 2, 3, 7, 15, 32])
@torch.inference_mode()
def test_sparse_attention(
batch_size: int,
seq_lens: Tuple[int, int],
num_heads: int,
head_size: int,
dtype: torch.dtype,
NNZ_S: int,
) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
block_size_M = 64
block_size_N = 64
seqlen_q, seqlen_k = seq_lens
q = torch.randn(
batch_size, seqlen_q, num_heads, head_size, dtype=dtype, requires_grad=False
)
k = torch.randn(
batch_size, seqlen_k, num_heads, head_size, dtype=dtype, requires_grad=False
)
v = torch.randn(
batch_size, seqlen_k, num_heads, head_size, dtype=dtype, requires_grad=False
)
NUM_ROWS = (seqlen_q + block_size_M - 1) // block_size_M
if NNZ_S * block_size_N > seqlen_k:
return
NNZ_V = seqlen_k - NNZ_S * block_size_N
block_count = torch.tensor([NNZ_S] * batch_size * NUM_ROWS * num_heads, dtype=torch.int32).reshape(batch_size, num_heads, NUM_ROWS)
column_count = torch.tensor([NNZ_V] * batch_size * NUM_ROWS * num_heads, dtype=torch.int32).reshape(batch_size, num_heads, NUM_ROWS)
block_offset = torch.tensor([[i * block_size_N for i in range(NNZ_S)]] * batch_size * NUM_ROWS * num_heads, dtype=torch.int32).reshape(batch_size, num_heads, NUM_ROWS, NNZ_S)
column_index = torch.tensor([[NNZ_S * block_size_N + i for i in range(NNZ_V)]] * batch_size * NUM_ROWS * num_heads, dtype=torch.int32).reshape(batch_size, num_heads, NUM_ROWS, NNZ_V)
from vllm_flash_attn import sparse_attn_func, flash_attn_func
out, lse = sparse_attn_func(
q,
k,
v,
block_count,
block_offset,
column_count,
column_index,
return_softmax_lse=True,
)

ref_out, ref_lse = flash_attn_func(
q,
k,
v,
return_softmax_lse=True,
)

torch.testing.assert_close(out, ref_out, atol=2e-2, rtol=1e-2), \
f"{torch.max(torch.abs(out - ref_out))}"
torch.testing.assert_close(lse, ref_lse, atol=2e-2, rtol=1e-2), \
f"{torch.max(torch.abs(lse - ref_lse))}"

@pytest.mark.parametrize("seq_lens", [[(1024, 1328)],
[(1024, 1328), (1, 2048)],
[(1025, 1328), (2, 2048)],
[(1025, 2049), (2, 1281)],
])
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_sparse_attention_varlen(
seq_lens: List[Tuple[int, int]],
head_size: int,
dtype: torch.dtype,
) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
block_size_M = 64
block_size_N = 64
num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens]
kv_lens = [x[1] for x in seq_lens]
num_heads = 1
query = torch.randn(sum(query_lens),
num_heads,
head_size,
dtype=dtype)
key = torch.randn(sum(kv_lens),
num_heads,
head_size,
dtype=dtype)
value = torch.randn_like(key)
cu_query_lens = torch.tensor([0] + query_lens,
dtype=torch.int32).cumsum(dim=0,
dtype=torch.int32)
cu_kv_lens = torch.tensor([0] + kv_lens,
dtype=torch.int32).cumsum(dim=0,
dtype=torch.int32)
max_query_len = max(query_lens)
max_kv_len = max(kv_lens)

NUM_ROWS = (max_query_len + block_size_M - 1) // block_size_M
NNZ_S = 20
NNZ_V = 2048
batch_size = len(query_lens)

block_counts = []
column_counts = []
block_offsets = []
column_indices = []
for b in range(batch_size):
block_counts.append(torch.tensor([NNZ_S] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS))
columns = kv_lens[b] - NNZ_S * block_size_N
column_counts.append(torch.tensor([columns] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS))
block_offsets.append(torch.tensor([[i * block_size_N for i in range(NNZ_S)]] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS, NNZ_S))
column_indices.append(torch.tensor([[NNZ_S * block_size_N + i for i in range(NNZ_V)]] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS, NNZ_V))
block_count = torch.concat(block_counts).reshape(batch_size, num_heads, NUM_ROWS)
column_count = torch.concat(column_counts).reshape(batch_size, num_heads, NUM_ROWS)
block_offset = torch.concat(block_offsets).reshape(batch_size, num_heads, NUM_ROWS, NNZ_S)
column_index = torch.concat(column_indices).reshape(batch_size, num_heads, NUM_ROWS, NNZ_V)
from vllm_flash_attn import sparse_attn_varlen_func, flash_attn_varlen_func
out, lse = sparse_attn_varlen_func(
query,
key,
value,
block_count,
block_offset,
column_count,
column_index,
cu_seqlens_q=cu_query_lens,
cu_seqlens_k=cu_kv_lens,
max_seqlen_q=max_query_len,
max_seqlen_k=max_kv_len,
return_softmax_lse=True,
)

ref_out, ref_lse = flash_attn_varlen_func(
query,
key,
value,
cu_seqlens_q=cu_query_lens,
cu_seqlens_k=cu_kv_lens,
max_seqlen_q=max_query_len,
max_seqlen_k=max_kv_len,
return_softmax_lse=True,
)

torch.testing.assert_close(out, ref_out, atol=2e-2, rtol=1e-2), \
f"{torch.max(torch.abs(out - ref_out))}"
torch.testing.assert_close(lse, ref_lse, atol=2e-2, rtol=1e-2), \
f"{torch.max(torch.abs(lse - ref_lse))}"
2 changes: 2 additions & 0 deletions vllm_flash_attn/__init__.py
@@ -3,6 +3,8 @@
# Use relative import to support build-from-source installation in vLLM
from .flash_attn_interface import (
flash_attn_func,
sparse_attn_func,
sparse_attn_varlen_func,
flash_attn_varlen_func,
flash_attn_with_kvcache,
)
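A minimal usage sketch of the newly exported sparse_attn_func, mirroring test_sparse_attention above (build_sparse_metadata is the hypothetical helper sketched after the flash.h changes):

import torch
from vllm_flash_attn import sparse_attn_func

torch.set_default_device("cuda")
q = torch.randn(2, 1024, 4, 128, dtype=torch.float16)
k = torch.randn(2, 2048, 4, 128, dtype=torch.float16)
v = torch.randn_like(k)

# Metadata with shapes [b, h, NUM_ROWS], [b, h, NUM_ROWS, NNZ_S], etc.
block_count, block_offset, column_count, column_index = build_sparse_metadata(
    batch=2, heads=4, seqlen_q=1024, seqlen_k=2048, nnz_s=8)

out, lse = sparse_attn_func(
    q, k, v,
    block_count, block_offset, column_count, column_index,
    return_softmax_lse=True,
)

Since this pattern covers every key, out should closely match flash_attn_func on the same inputs, which is exactly what the new tests assert.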