165 changes: 165 additions & 0 deletions benchmark/benchmark_attention.py
@@ -0,0 +1,165 @@
import functools
import random
import time
from typing import List

from flash_attn.flash_attn_interface import _flash_attn_forward
import torch

from cacheflow import attention_ops


def benchmark(name, f, num_warmup=10, num_iters=100):
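    # Warm-up iterations keep one-time costs (CUDA context setup, kernel
    # caching) out of the measurement.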
for _ in range(num_warmup):
f()
torch.cuda.synchronize()

start = time.time()
for _ in range(num_iters):
f()
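    # Kernel launches are asynchronous; wait for all queued work to finish
    # before reading the end time.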
torch.cuda.synchronize()
end = time.time()
print(f'{name}: {(end - start) / num_iters * 1000:.3f} ms')


@torch.inference_mode()
def benchmark_multi_query_cached_kv_attention(
query_lens: List[int],
context_lens: List[int],
num_heads: int,
head_size: int,
block_size: int,
num_blocks: int,
dtype: torch.dtype,
) -> None:
print(f'query_lens: {query_lens}, context_lens: {context_lens}, '
f'num_heads: {num_heads}, head_size: {head_size}, block_size: '
f'{block_size}, num_blocks: {num_blocks}, dtype: {dtype}')
# Create query tensor.
num_queries = len(query_lens)
cu_query_lens = [0]
for query_len in query_lens:
cu_query_lens.append(cu_query_lens[-1] + query_len)
num_total_tokens = cu_query_lens[-1]
qkv = torch.randn(
num_total_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
query, _, _ = qkv.unbind(dim=1)

# Create key and value cache.
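    # The key cache is laid out as [num_blocks, num_heads, head_size // x, block_size, x],
    # where x is the number of elements that fit in 16 bytes (8 for fp16), so keys
    # can be fetched with 16-byte vectorized loads.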
x = 16 // torch.tensor([], dtype=dtype).element_size()
key_block_shape = (num_heads, head_size // x, block_size, x)
key_cache = torch.randn(
size=(num_blocks, *key_block_shape), dtype=dtype, device='cuda')
value_block_shape = (num_heads, head_size, block_size)
value_cache = torch.randn(
size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')

# Create block tables.
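    # Each sequence gets max_num_blocks_per_seq physical block indices. Random
    # indices are fine here because the benchmark measures speed, not correctness.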
max_context_len = max(context_lens)
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
block_tables = []
for _ in range(num_queries):
block_table = [
random.randint(0, num_blocks - 1)
for _ in range(max_num_blocks_per_seq)
]
block_tables.append(block_table)
block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')

# Create input and output data structures.
cu_query_lens = torch.tensor(cu_query_lens, dtype=torch.int, device='cuda')
context_len_tensor = torch.tensor(context_lens, dtype=torch.int, device='cuda')
scale = float(1.0 / (head_size ** 0.5))
output = torch.empty(
num_total_tokens, num_heads, head_size, dtype=dtype, device='cuda')

# Run our implementation.
def run_ours():
attention_ops.multi_query_cached_kv_attention(
cu_query_lens,
output,
query,
key_cache,
value_cache,
scale,
block_tables,
context_len_tensor,
block_size,
max_context_len,
)
benchmark('Ours', run_ours)

# Upper bound: Flash attention.
    # Because Flash attention cannot read our block-based KV cache,
    # we make the key and value tensors contiguous for it.
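    # These tensors hold fresh random values rather than the cached data, so the
    # Flash attention run serves only as a timing reference.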
num_kv_tokens = sum(context_lens)
cu_context_lens = [0]
for context_len in context_lens:
cu_context_lens.append(cu_context_lens[-1] + context_len)
cu_context_lens = torch.tensor(cu_context_lens, dtype=torch.int, device='cuda')
qkv = torch.randn(
num_kv_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
_, key, value = qkv.unbind(dim=1)
ref_output = torch.empty_like(output)

# Run Flash attention.
def run_flash_attn():
_flash_attn_forward(
query,
key,
value,
ref_output,
cu_query_lens,
cu_context_lens,
max(query_lens),
max_context_len,
dropout_p=0.0,
softmax_scale=scale,
causal=True,
return_softmax=False,
)
benchmark('Flash attention', run_flash_attn)


if __name__ == '__main__':
BLOCK_SIZE = 8
NUM_BLOCKS = 1024
DTYPE = torch.half

# LLaMA-13B and OPT-13B
NUM_HEADS = 40
HEAD_SIZE = 128
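    # 40 heads * 128 head size = 5120 hidden size, shared by both models.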

run_benchmark = functools.partial(
benchmark_multi_query_cached_kv_attention,
num_heads=NUM_HEADS,
head_size=HEAD_SIZE,
block_size=BLOCK_SIZE,
num_blocks=NUM_BLOCKS,
dtype=DTYPE,
)

run_benchmark(
query_lens=[64] * 1,
context_lens=[64] * 1,
)
run_benchmark(
query_lens=[128] * 1,
context_lens=[128] * 1,
)
run_benchmark(
query_lens=[64] * 8,
context_lens=[64] * 8,
)
run_benchmark(
query_lens=[128] * 8,
context_lens=[128] * 8,
)
run_benchmark(
query_lens=[64, 32, 16],
context_lens=[128, 256, 64],
)
run_benchmark(
query_lens=[1024],
context_lens=[1024],
)
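For readers who want to see how data would land in the block-based cache benchmarked above, here is a small illustrative sketch (not part of this PR; write_key is a hypothetical helper) of how one token's key vector maps into the [num_blocks, num_heads, head_size // x, block_size, x] key cache:

import torch

def write_key(key_cache: torch.Tensor, key: torch.Tensor,
              block_idx: int, slot_idx: int) -> None:
    """Scatter one token's key vector into its cache slot."""
    num_heads, head_size = key.shape
    x = key_cache.shape[-1]  # elements per 16-byte vector, e.g. 8 for fp16
    # Split head_size into (head_size // x, x) chunks and write them into the
    # token's slot inside the chosen physical block.
    key_cache[block_idx, :, :, slot_idx, :] = key.reshape(num_heads, head_size // x, x)

The benchmark skips this write path and simply fills the cache with random values, which is sufficient for timing.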
16 changes: 11 additions & 5 deletions csrc/attention_kernels.cu
@@ -271,7 +271,8 @@ __device__ void multi_query_cached_kv_attention_kernel_unoptimized_(
const float scale,
const int* __restrict__ block_table, // [num_seqs, max_num_blocks_per_seq]
const int context_len,
-  const int max_num_blocks_per_seq) {
+  const int max_num_blocks_per_seq,
+  const int q_stride) {
constexpr int THREAD_GROUP_SIZE = WARP_SIZE / BLOCK_SIZE;
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int thread_idx = threadIdx.x;
@@ -302,7 +303,8 @@
  // For example, if the thread group size is 4, then the first thread in the group
// has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
// th vectors of the query, and so on.
-  const scalar_t* q_ptr = q + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+  // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
+  const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
Q_vec q_vecs[NUM_VECS_PER_THREAD];
#pragma unroll
for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
@@ -514,7 +516,8 @@ __global__ void multi_query_cached_kv_attention_kernel(
const float scale,
const int* __restrict__ block_tables, // [num_prompts, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_prompts]
-  const int max_num_blocks_per_seq) {
+  const int max_num_blocks_per_seq,
+  const int q_stride) {
const int seq_idx = blockIdx.y;
const int prompt_idx = seq_prompt_mapping[seq_idx];
const int seq_start_idx = cu_query_lens[prompt_idx];
@@ -532,7 +535,8 @@ __global__ void multi_query_cached_kv_attention_kernel(
scale,
block_table,
context_len,
-    max_num_blocks_per_seq);
+    max_num_blocks_per_seq,
+    q_stride);
}

} // namespace cacheflow
@@ -696,7 +700,8 @@ void single_query_cached_kv_attention(
scale, \
block_tables_ptr, \
context_lens_ptr, \
-    max_num_blocks_per_seq);
+    max_num_blocks_per_seq, \
+    query_stride);


// TODO(woosuk): Tune NUM_THREADS.
@@ -719,6 +724,7 @@ void multi_query_cached_kv_attention_launcher(
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
+  int query_stride = query.stride(0);

int* cu_query_lens_ptr = cu_query_lens.data_ptr<int>();
int* seq_prompt_mapping_ptr = seq_prompt_mapping.data_ptr<int>();
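Why the extra q_stride argument matters: the kernel's query input is now a view into a packed qkv tensor, so its rows are no longer contiguous. A minimal PyTorch sketch (illustrative only, not code from this PR) of the stride the kernel must account for:

import torch

num_tokens, num_heads, head_size = 4, 40, 128
qkv = torch.randn(num_tokens, 3, num_heads, head_size)
query, _, _ = qkv.unbind(dim=1)

# The row stride still spans q, k, and v together, so assuming a contiguous
# row of num_heads * head_size elements would read the wrong memory.
assert query.stride(0) == 3 * num_heads * head_size
assert not query.is_contiguous()

This is why the device function now indexes q with seq_idx * q_stride instead of seq_idx * num_heads * HEAD_SIZE, and why the launcher passes query.stride(0) through to the kernel.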
8 changes: 5 additions & 3 deletions tests/kernels/attention.py
@@ -285,8 +285,9 @@ def test_multi_query_cached_kv_attention(
cu_query_lens.append(cu_query_lens[-1] + query_len)
num_total_tokens = cu_query_lens[-1]

-    query = torch.randn(
-        num_total_tokens, num_heads, head_size, dtype=dtype, device='cuda')
+    qkv = torch.randn(
+        num_total_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
+    query, _, _ = qkv.unbind(dim=1)
x = 16 // torch.tensor([], dtype=dtype).element_size()
key_block_shape = (num_heads, head_size // x, block_size, x)
key_cache = torch.randn(
@@ -314,7 +315,8 @@ def test_multi_query_cached_kv_attention(
block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')

scale = float(1.0 / (head_size ** 0.5))
-    output = torch.empty_like(query)
+    output = torch.empty(
+        num_total_tokens, num_heads, head_size, dtype=dtype, device='cuda')

attention_ops.multi_query_cached_kv_attention(
cu_query_lens,