diff --git a/test/kernel/test_paged_attention.py b/test/kernel/test_paged_attention.py new file mode 100644 index 0000000000..6892f1ec35 --- /dev/null +++ b/test/kernel/test_paged_attention.py @@ -0,0 +1,204 @@ +import torch +import unittest +import random +from itertools import product +import torchao +from torchao.kv_cache import PagedAttentionCache, PagedTensor + +class NaiveCache: + def __init__(self): + self.past_key = None + self.past_value = None + + def expand_cache(self, beam_size): + self.past_key = self.past_key.repeat_interleave(beam_size, dim=0) + self.past_value = self.past_value.repeat_interleave(beam_size, dim=0) + + def update(self, key, value, layer_idx=0): + if self.past_key is None: + self.past_key = key + self.past_value = value + else: + self.past_key = torch.cat((self.past_key, key), dim=2) + self.past_value = torch.cat((self.past_value, value), dim=2) + return self.past_key, self.past_value + + def reorder_cache(self, beam_idx): + self.past_key = self.past_key.index_select(0, beam_idx) + self.past_value = self.past_value.index_select(0, beam_idx) + + +class MHAModule(torch.nn.Module): + def __init__(self, head_dim, num_heads, num_kv_heads): + super(MHAModule, self).__init__() + self.head_dim = head_dim + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.scale = head_dim**-0.5 + self.q = torch.nn.Linear( + self.num_heads * self.head_dim, self.num_heads * self.head_dim + ) + self.k = torch.nn.Linear( + self.num_heads * self.head_dim, self.num_kv_heads * self.head_dim + ) + self.v = torch.nn.Linear( + self.num_heads * self.head_dim, self.num_kv_heads * self.head_dim + ) + + def forward(self, inputs, kv_cache): + query = self.q(inputs) + key = self.k(inputs) + value = self.v(inputs) + batch_size = inputs.size(0) + query = query.view(batch_size, -1, self.num_heads, self.head_dim).transpose( + 1, 2 + ) + key = key.view(batch_size, -1, self.num_kv_heads, self.head_dim).transpose(1, 2) + value = value.view(batch_size, -1, self.num_kv_heads, self.head_dim).transpose( + 1, 2 + ) + updated_key, updated_value = kv_cache.update(key, value, 0) + if not isinstance(updated_key, PagedTensor): + updated_key = updated_key.repeat_interleave( + self.num_heads // self.num_kv_heads, dim=1 + ) + updated_value = updated_value.repeat_interleave( + self.num_heads // self.num_kv_heads, dim=1 + ) + return torch.nn.functional.scaled_dot_product_attention( + query, updated_key, updated_value, scale=self.scale + ) + + +@unittest.skipIf(torch.cuda.is_available(), "CUDA is not enabled yet") +class PagedAttentionCachePagedTensorTest(unittest.TestCase): + def _test_paged_attention_cache( + self, + num_blocks, + block_size, + num_query_heads, + num_key_value_heads, + head_dim, + device, + dtype, + batch_size, + beam_size, + ): + num_layers = 1 + prompt_len = 32 + mha_model = MHAModule(head_dim, num_query_heads, num_key_value_heads).to( + device=device, dtype=dtype + ) + naive_cache = NaiveCache() + pagedcache = PagedAttentionCache( + num_blocks, + block_size, + num_key_value_heads, + head_dim, + num_layers, + device, + dtype, + ) + # enable prompt sharing for the first token, fork + pagedcache.set_batch2seq_for_prompt_sharing(batch_size, beam_size) + pagedcache.allocate(batch_size, prompt_len) + prompt_inputs = torch.randn( + batch_size, + prompt_len, + num_query_heads * head_dim, + device=device, + dtype=dtype, + ) + paged_output = mha_model(prompt_inputs, pagedcache) + naive_output = mha_model(prompt_inputs, naive_cache) + torch.allclose(paged_output, naive_output) + + beam_idx = 
torch.arange( + 0, batch_size * beam_size, beam_size, device=device, dtype=torch.int64 + ).repeat_interleave(beam_size) + naive_cache.expand_cache(beam_size) + naive_cache.reorder_cache(beam_idx) + pagedcache.reorder_cache(beam_idx) + + # Next token + pagedcache.allocate(batch_size * beam_size, 1) + next_inputs = torch.randn( + batch_size * beam_size, + 1, + num_query_heads * head_dim, + device=device, + dtype=dtype, + ) + + paged_output = mha_model(next_inputs, pagedcache) + naive_output = mha_model(next_inputs, naive_cache) + torch.allclose(paged_output, naive_output, atol=1e-3, rtol=1e-3) + + for i in range(batch_size): + beam_idx[i * beam_size : (i + 1) * beam_size] = torch.randint( + i * beam_size, + (i + 1) * beam_size, + (1, beam_size), + device=device, + dtype=torch.int64, + ) + naive_cache.reorder_cache(beam_idx) + pagedcache.reorder_cache(beam_idx) + + # Next token + pagedcache.allocate(batch_size * beam_size, 1) + prompt_inputs = torch.randn( + batch_size * beam_size, + 1, + num_query_heads * head_dim, + device=device, + dtype=dtype, + ) + paged_output = mha_model(prompt_inputs, pagedcache) + naive_output = mha_model(prompt_inputs, naive_cache) + torch.allclose(paged_output, naive_output, atol=1e-3, rtol=1e-3) + + def test_paged_attention_kv_cache(self): + # num_blocks, block_size, num_query_heads, num_key_value_heads, head_dim, device, dtype, batch_size, beam_size + num_blocks = 128 + block_sizes = [16, 32] + num_query_heads = [40] + num_key_value_heads = [40, 10, 1] + head_dim = [64, 128] + device = ['cpu'] + dtypes = [torch.bfloat16, torch.float16] + batch_size = [1, 8] + beam_size = [1, 4] + for ( + block_size, + num_query_head, + num_key_value_head, + head_dim, + device, + dtype, + batch_size, + beam_size, + ) in product( + block_sizes, + num_query_heads, + num_key_value_heads, + head_dim, + device, + dtypes, + batch_size, + beam_size, + ): + self._test_paged_attention_cache( + num_blocks, + block_size, + num_query_head, + num_key_value_head, + head_dim, + device, + dtype, + batch_size, + beam_size, + ) + +if __name__ == "__main__": + test = unittest.main() diff --git a/torchao/__init__.py b/torchao/__init__.py index dce378411c..48be662c04 100644 --- a/torchao/__init__.py +++ b/torchao/__init__.py @@ -34,9 +34,12 @@ ) from . 
import dtypes +from torchao.kv_cache import PagedAttentionCache, PagedTensor __all__ = [ "dtypes", "autoquant", + "PagedAttentionCache", + "PagedTensor" "quantize_", ] diff --git a/torchao/csrc/cpu/paged_attention_kernel.cpp b/torchao/csrc/cpu/paged_attention_kernel.cpp new file mode 100644 index 0000000000..a094b9a058 --- /dev/null +++ b/torchao/csrc/cpu/paged_attention_kernel.cpp @@ -0,0 +1,524 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +#include +#include +#include +#include +#include + +#define SEQ_PARTITION_SIZE 256 + +namespace torchao { + +namespace { + +template +void reduce_head(const scalar_t *q_ptr_start, const scalar_t *k_cache_start, + accum_t *attn_w_pos, int64_t head_size) { + attn_w_pos[0] = 0; + for (long i = 0; i < head_size; i++) { + attn_w_pos[0] += q_ptr_start[i] * k_cache_start[i]; + } +} + +// BF16 +template <> +void reduce_head(const at::BFloat16 *q_ptr_start, + const at::BFloat16 *k_cache_start, + float *attn_w_pos, int64_t head_size) { + attn_w_pos[0] = 0; + using lpVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + auto vec_size = lpVec::size(); + auto vec_tmp_sum = fVec(0.0f); + for (long i = 0; i < vec_size * (head_size / vec_size); i += vec_size) { + auto tmpq = lpVec::loadu(q_ptr_start + i); + auto tmpk = lpVec::loadu(k_cache_start + i); + fVec tmpq1, tmpq2, tmpk1, tmpk2; + // convert to float + std::tie(tmpq1, tmpq2) = at::vec::convert_to_float(tmpq); + std::tie(tmpk1, tmpk2) = at::vec::convert_to_float(tmpk); + vec_tmp_sum = vec_tmp_sum + tmpq1 * tmpk1 + tmpq2 * tmpk2; + } + attn_w_pos[0] = at::vec::vec_reduce_all<>( + [](fVec &x, fVec &y) { return x + y; }, vec_tmp_sum); +} + +template +inline void mul_attenion_weights_and_value_of_head( + const accum_t &attn_w, const scalar_t *v_cache_start, + accum_t *attn_out_start, int64_t head_size, bool accumulated) { + for (auto hsi = 0; hsi < head_size; hsi++) { + if (accumulated) { + attn_out_start[hsi] += attn_w * (float)v_cache_start[hsi]; + } else { + attn_out_start[hsi] = attn_w * (float)v_cache_start[hsi]; + } + } +} + +template <> +inline void mul_attenion_weights_and_value_of_head( + const float &attn_w, const at::BFloat16 *v_cache_start, + float *attn_out_start, int64_t head_size, bool accumulated) { + using lpVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + auto lpVec_size = lpVec::size(); + auto fVec_size = fVec::size(); + auto vec_attn_w = fVec(attn_w); + auto vec_tmp_sum = fVec(0.0f); + long i = 0; + for (; i < lpVec_size * (head_size / lpVec_size); i += lpVec_size) { + auto tmpv = lpVec::loadu(v_cache_start + i); + fVec tmpv1, tmpv2; + // convert to float + std::tie(tmpv1, tmpv2) = at::vec::convert_to_float(tmpv); + auto tmp1 = tmpv1 * vec_attn_w; + auto tmp2 = tmpv2 * vec_attn_w; + if (accumulated) { + tmp1 = fVec::loadu(attn_out_start + i) + tmp1; + tmp1.store(attn_out_start + i); + tmp2 = fVec::loadu(attn_out_start + i + fVec_size) + tmp2; + tmp2.store(attn_out_start + i + fVec_size); + } else { + tmp1.store(attn_out_start + i); + tmp2.store(attn_out_start + i + fVec_size); + } + } + for (; i < head_size; i++) { + if (accumulated) { + attn_out_start[i] += attn_w * (float)v_cache_start[i]; + } else { + attn_out_start[i] = attn_w * (float)v_cache_start[i]; + } + } +} + +// out = val * a + b +template +inline void _scale_attn_mask_fusion_kernel(T1 *a, float *b, const int &size, + T2 *out, float val) { + const auto vec_size = 
at::vec::Vectorized::size(); + const auto vec_scale = at::vec::Vectorized(val); + int64_t i = 0; + for (; i < size - (size % vec_size); i += vec_size) { + auto a_v = at::vec::Vectorized::loadu(a + i); + auto b_v = at::vec::Vectorized::loadu(b + i); + auto res = a_v * vec_scale + b_v; + res.store(out + i); + } + for (; i < size; i++) { + auto tmp0 = a[i]; + auto tmp1 = b[i]; + out[i] = tmp0 * val + tmp1; + } +} + +// 1) out = exp(a - val) +// 2) val = sum(out) +template +inline void _exp_reduce_sum_fusion_kernel(T1 *a, const int &size, T2 *out, + T1 &val) { + auto vec_size = at::vec::Vectorized::size(); + auto vec_max = at::vec::Vectorized(val); + T1 tmp_sum = 0; + auto vec_tmp_sum = at::vec::Vectorized(tmp_sum); + long i = 0; + for (; i < vec_size * (size / vec_size); i += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(a + i); + auto tmp1 = tmp0 - vec_max; + auto tmp2 = tmp1.exp_u20(); + vec_tmp_sum += tmp2; + tmp2.store(out + i); + } + tmp_sum = at::vec::vec_reduce_all( + [](at::vec::Vectorized &x, at::vec::Vectorized &y) { + return x + y; + }, + vec_tmp_sum); + for (; i < size; i++) { + auto tmp0 = a[i]; + auto tmp1 = tmp0 - val; + auto tmp2 = exp(tmp1); + tmp_sum += tmp2; + out[i] = tmp2; + } + val = tmp_sum; +} + +// 1) out = a * scale +// 2) max = max(out) +template +inline void _mul_reduce_max_fusion_kernel(scalar_t *a, const scalar_t &scale, + const int &size, scalar_t *out, + scalar_t &max) { + auto vec_size = at::vec::Vectorized::size(); + auto vec_scale = at::vec::Vectorized(scale); + scalar_t tmp_max = -std::numeric_limits::infinity(); + auto vec_tmp_max = at::vec::Vectorized(tmp_max); + long i = 0; + for (; i < vec_size * (size / vec_size); i += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(a + i); + auto tmp1 = tmp0 * vec_scale; + vec_tmp_max = at::vec::maximum(vec_tmp_max, tmp1); + tmp1.store(out + i); + } + tmp_max = at::vec::vec_reduce_all( + [](at::vec::Vectorized &x, at::vec::Vectorized &y) { + return at::vec::maximum(x, y); + }, + vec_tmp_max); + for (; i < size; i++) { + auto tmp0 = a[i]; + auto tmp1 = tmp0 * scale; + tmp_max = std::max(tmp_max, tmp1); + out[i] = tmp1; + } + max = tmp_max; +} + +void reshape_attn_mask_to_4d(at::Tensor &attn_mask, int64_t batchSize, + int64_t num_head, int64_t qSize, int64_t kvSize) { + // Support mask shapes: + // 2d: ({Q_seq_len, 1} x {KV_seq_len, 1}) + // 4d: ({Batch, 1} x {Num_heads, 1} x {Q_seq_len, 1} x {KV_seq_len, 1}) + // Guaranteed in check_attn_mask_shape + int64_t attn_mask_size_0 = 1; + int64_t attn_mask_size_1 = 1; + if (attn_mask.dim() == 4) { + if (attn_mask.size(0) == batchSize) { + attn_mask_size_0 = batchSize; + } + if (attn_mask.size(1) == num_head) { + attn_mask_size_1 = num_head; + } + } + attn_mask = attn_mask + .view({attn_mask_size_0, attn_mask_size_1, attn_mask.size(-2), + attn_mask.size(-1)}) + .expand({attn_mask_size_0, attn_mask_size_1, qSize, kvSize}); +} + +/** + * Performs scale-dot-product for the next token based on paged cached key-value + * @param out Output tensor [batch_size, num_heads, 1, head_size]. + * @param query Query tensor [batch_size, num_heads, 1, head_size]. + * @param key_cache The pre-allocated buffer to store the key cache. The + * shape should be [num_blocks, num_heads, block_size, head_size]. + * @param value_cache The pre-allocated buffer to store the value cache. The + * shape should be [num_blocks, num_heads, block_size, head_size]. + * @param scale Scaling factor for attention weights. In general, it is: + * float(1.0 / (head_size ** 0.5)). 
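+ *
+ * Implementation note: the context dimension is processed in partitions of
+ * SEQ_PARTITION_SIZE tokens. Each partition p keeps a local max m_p, an
+ * exp-sum s_p and an unnormalized partial output
+ *   o_p = sum_j exp(scale * q . k_j [+ mask] - m_p) * v_j,
+ * and a second pass combines the partitions with the global max m = max_p m_p:
+ *   out = (sum_p exp(m_p - m) * o_p) / (sum_p exp(m_p - m) * s_p).
+ *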
+ * @param block_tables Block tables tensor [batch_size, max_num_blocks_per_seq]. + * @param context_lens Context lengths tensor [batch_size]. + * @param attn_mask Optional tensor of attention_mask + */ +template +void paged_attention_kernel(at::Tensor &out, at::Tensor &query, + at::Tensor &key_cache, at::Tensor &value_cache, + const double scale, at::Tensor &block_tables, + at::Tensor &context_lens, + c10::optional attn_mask) { + + TORCH_CHECK(query.size(2) == 1, + "Paged attention: only seqlen 1 is supported for query"); + using accum_t = at::opmath_type; + using Vec = at::vec::Vectorized; + const auto dtype = query.scalar_type(); + const auto accumulate_dtype = at::toOpMathType(dtype); + auto max_context_len = context_lens.max().item(); + auto batch_size = query.size(0); + auto q_len = query.size(2); + auto num_heads = query.size(1); + auto head_size = query.size(3); + auto block_size = key_cache.size(2); + auto num_kv_heads = key_cache.size(1); + auto max_num_blocks_per_seq = block_tables.size(1); + auto kv_head_group_size = num_heads / num_kv_heads; + bool has_attn_mask = attn_mask.has_value() && attn_mask.value().numel(); + if (has_attn_mask) { + attn_mask.value() = attn_mask.value().to(at::kFloat); + reshape_attn_mask_to_4d(attn_mask.value(), batch_size, num_heads, q_len, + attn_mask.value().size(-1)); + } + + auto out_ptr = out.data_ptr(); + auto query_ptr = query.data_ptr(); + auto key_cache_ptr = key_cache.data_ptr(); + auto value_cache_ptr = value_cache.data_ptr(); + auto block_tables_ptr = block_tables.data_ptr(); + auto context_lens_ptr = context_lens.data_ptr(); + + auto kv_block_strideN = key_cache.stride(0); + auto kv_block_strideP = key_cache.stride(2); + auto kv_block_strideH = key_cache.stride(1); + + auto out_strideN = out.stride(0); + auto out_strideH = out.stride(1); + auto out_strideS = out.stride(2); + + auto q_strideN = query.stride(0); + auto q_strideH = query.stride(1); + auto q_strideS = query.stride(2); + + auto attn_mask_ptr = + attn_mask.has_value() ? attn_mask.value().data_ptr() : nullptr; + + int64_t mStrideB = (has_attn_mask && attn_mask.value().size(0) > 1) + ? attn_mask.value().stride(0) + : 0; + int64_t mStrideH = (has_attn_mask && attn_mask.value().size(1) > 1) + ? attn_mask.value().stride(1) + : 0; + int64_t mStrideM = has_attn_mask ? 
attn_mask.value().stride(2) : 0; + + auto max_num_seq_partitions = + (max_context_len + SEQ_PARTITION_SIZE - 1) / SEQ_PARTITION_SIZE; + + auto max_logits = at::empty({batch_size, num_heads, max_num_seq_partitions + 1}, + query.options().dtype(accumulate_dtype)); + + auto exp_sum = at::empty({batch_size, num_heads, max_num_seq_partitions + 1}, + query.options().dtype(accumulate_dtype)); + + auto tmp_out = at::empty({batch_size, num_heads, max_num_seq_partitions, head_size}, + query.options().dtype(accumulate_dtype)); + + auto tmp_out_ptr = tmp_out.data_ptr(); + auto max_logits_ptr = max_logits.data_ptr(); + auto exp_sum_ptr = exp_sum.data_ptr(); + + auto max_logits_strideN = max_logits.stride(0); + auto max_logits_strideH = max_logits.stride(1); + auto exp_sum_strideN = exp_sum.stride(0); + auto exp_sum_strideH = exp_sum.stride(1); + auto tmp_out_strideN = tmp_out.stride(0); + auto tmp_out_strideH = tmp_out.stride(1); + auto tmp_out_strideS = tmp_out.stride(2); +#pragma omp parallel for collapse(3) schedule(static, 1) + for (auto partition_id = 0; partition_id < max_num_seq_partitions; + partition_id++) { + for (auto head_id = 0; head_id < num_heads; head_id++) { + for (auto seq_id = 0; seq_id < batch_size; seq_id++) { + auto context_len = context_lens_ptr[seq_id]; + auto partition_start = partition_id * SEQ_PARTITION_SIZE; + if (partition_start >= context_len) + continue; + auto partition_end = + std::min(partition_start + SEQ_PARTITION_SIZE, context_len); + auto token_num = partition_end - partition_start; + auto block_num = (token_num + block_size - 1) / block_size; + auto logical_block_start = partition_start / block_size; + auto logical_block_end = logical_block_start + block_num; + auto need_update = block_num > 1; + auto kv_head_id = head_id / kv_head_group_size; + auto q_ptr_start = query_ptr + seq_id * q_strideN + head_id * q_strideH; + auto max_logits_offset = seq_id * max_logits_strideN + + head_id * max_logits_strideH + partition_id; + auto exp_sum_offset = + seq_id * exp_sum_strideN + head_id * exp_sum_strideH + partition_id; + auto tmp_out_start = tmp_out_ptr + seq_id * tmp_out_strideN + + head_id * tmp_out_strideH + + partition_id * tmp_out_strideS; + accum_t alignas(64) logits[SEQ_PARTITION_SIZE] = {0}; + auto logits_position = 0; + // 1)calculate the matmul(query, key) for this partition + for (auto logical_block_id = logical_block_start; + logical_block_id < logical_block_end; logical_block_id++) { + auto physical_block_id = + block_tables_ptr[seq_id * max_num_blocks_per_seq + + logical_block_id]; + auto tokens_in_block = + std::min(block_size, context_len - logical_block_id * block_size); + auto token_start = logical_block_id * block_size; + auto token_end = token_start + tokens_in_block; + for (auto token_id = token_start; token_id < token_end; token_id++) { + auto block_offset = token_id - token_start; + auto k_cache_start = + key_cache_ptr + physical_block_id * kv_block_strideN + + block_offset * kv_block_strideP + kv_head_id * kv_block_strideH; + reduce_head(q_ptr_start, k_cache_start, + &(logits[logits_position]), + head_size); + logits_position++; + } + } + // 2) calculate the max and exp_sum for this partition + auto partition_max = -std::numeric_limits::infinity(); + if (has_attn_mask) { + _scale_attn_mask_fusion_kernel( + logits, + attn_mask_ptr + seq_id * mStrideB + head_id * mStrideH + + partition_start, + token_num, logits, scale); + partition_max = at::vec::reduce_all( + [](Vec &x, Vec &y) { return at::vec::maximum(x, y); }, logits, + token_num); + } else { + 
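+          // No attention mask: scale the raw q*k logits and track the
+          // partition max in a single fused pass.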
_mul_reduce_max_fusion_kernel(logits, scale, token_num, + logits, partition_max); + } + max_logits_ptr[max_logits_offset] = partition_max; + _exp_reduce_sum_fusion_kernel(logits, token_num, + logits, partition_max); + exp_sum_ptr[exp_sum_offset] = partition_max; + + // 3) calculate the matmul(exp(logits-partition_max), value) for this + // partition, need to divide the global exp_sum in the final result. + logits_position = 0; + for (auto logical_block_id = logical_block_start; + logical_block_id < logical_block_end; logical_block_id++) { + auto physical_block_id = + block_tables_ptr[seq_id * max_num_blocks_per_seq + + logical_block_id]; + auto tokens_in_block = + std::min(block_size, context_len - logical_block_id * block_size); + auto token_start = logical_block_id * block_size; + auto token_end = token_start + tokens_in_block; + for (auto token_id = token_start; token_id < token_end; token_id++) { + auto block_offset = token_id - token_start; + auto v_cache_start = + value_cache_ptr + physical_block_id * kv_block_strideN + + block_offset * kv_block_strideP + kv_head_id * kv_block_strideH; + auto accumulated = logits_position > 0; + mul_attenion_weights_and_value_of_head( + logits[logits_position], v_cache_start, tmp_out_start, + head_size, accumulated); + logits_position++; + } + } + } + } + } + +// calculate the final output +#pragma omp parallel for collapse(2) + for (auto seq_id = 0; seq_id < batch_size; seq_id++) { + for (auto head_id = 0; head_id < num_heads; head_id++) { + auto global_max = -std::numeric_limits::infinity(); + auto global_exp_sum = 0.0; + auto context_len = context_lens_ptr[seq_id]; + auto partition_num = (context_len + SEQ_PARTITION_SIZE - 1) / SEQ_PARTITION_SIZE; + // calculate the global max and exp_sum for this head + for (auto partition_id = 0; partition_id < max_num_seq_partitions; + partition_id++) { + if (partition_id >= partition_num) + break; + auto max_logit = + max_logits_ptr[seq_id * max_logits_strideN + + head_id * max_logits_strideH + partition_id]; + global_max = std::max(global_max, max_logit); + } + // update the partition 0 result with the global max + auto partition0_out_start = + tmp_out_ptr + seq_id * tmp_out_strideN + head_id * tmp_out_strideH; + auto max_logit0 = max_logits_ptr[seq_id * max_logits_strideN + + head_id * max_logits_strideH]; + float exp_val = expf(max_logit0 - global_max); + global_exp_sum += + exp_sum_ptr[seq_id * exp_sum_strideN + head_id * exp_sum_strideH] * + exp_val; + at::vec::Vectorized exp_val_vec0(exp_val); + at::vec::map([&](auto a) { return a * exp_val_vec0; }, + partition0_out_start, partition0_out_start, + head_size); + + // accumulate the partition 1 to partition n result into partition 0 + if (partition_num > 1) { + for (auto partition_id = 1; partition_id < partition_num; + partition_id++) { + if (partition_id * SEQ_PARTITION_SIZE >= context_len) + break; + auto tmp_out_start = tmp_out_ptr + seq_id * tmp_out_strideN + + head_id * tmp_out_strideH + + partition_id * tmp_out_strideS; + auto max_logit = + max_logits_ptr[seq_id * max_logits_strideN + + head_id * max_logits_strideH + partition_id]; + auto exp_sum = exp_sum_ptr[seq_id * exp_sum_strideN + + head_id * exp_sum_strideH + partition_id]; + exp_val = expf(max_logit - global_max); + global_exp_sum += exp_sum * exp_val; + at::vec::Vectorized exp_val_vec(exp_val); + at::vec::map2( + [&](auto a, auto b) { return a + exp_val_vec * b; }, + partition0_out_start, partition0_out_start, tmp_out_start, + head_size); + } + } + + // copy the partition 0 result into 
attn_outs + auto attn_out_start = + out_ptr + seq_id * out_strideN + head_id * out_strideH; + float inverse_global_sum = 1.0 / (global_exp_sum + 1e-8); + at::vec::Vectorized inverse_global_sum_vec(inverse_global_sum); + // rescale the partition 0 result with global exp_sum + at::vec::map([&](auto a) { return a * inverse_global_sum_vec; }, + partition0_out_start, partition0_out_start, + head_size); + // copy the partition 0 result into attn_outs + at::vec::map([&](auto a) { return a; }, attn_out_start, + partition0_out_start, head_size); + } + } +} // paged_attention_kernel + +void paged_attention_kernel_impl( + at::Tensor &out, // [batch_size, num_heads, 1, head_size] + at::Tensor &query, // [batch_size, num_heads, 1, head_size] + at::Tensor &key_cache, // [num_blocks, num_heads, block_size, head_size] + at::Tensor &value_cache, // [num_blocks, num_heads, block_size, head_size] + const double scale, + at::Tensor &block_tables, // [batch_size, max_num_blocks_per_seq] + at::Tensor &context_lens, // [batch_size] + c10::optional attn_mask) { + TORCH_CHECK(SEQ_PARTITION_SIZE % key_cache.size(2) == 0, + "Paged attention: The PARTION_SIZE:%d should be divisible by block_size: %d", SEQ_PARTITION_SIZE, key_cache.size(2)); + TORCH_CHECK(query.size(2) == 1, + "Paged attention: only seqlen 1 is supported for query"); + TORCH_CHECK(query.scalar_type() == key_cache.scalar_type() && + query.scalar_type() == value_cache.scalar_type(), + "Paged attention: Q/K/V should have the same data type"); + TORCH_CHECK(!attn_mask.has_value() || + query.scalar_type() == attn_mask.value().scalar_type() || + attn_mask.value().scalar_type() != at::ScalarType::Bool, + "Paged attention: Mask should have the same data type as Q/K/V " + "and should not be Bool"); + TORCH_CHECK( + query.dim() == 4 && key_cache.dim() == 4 && value_cache.dim() == 4, + "Paged attention: Accept only 4 dims inputs shape of {B, H, T, K}"); + TORCH_CHECK( + (query.stride(-1) == 1) && (key_cache.stride(-1) == 1) && + (value_cache.stride(-1) == 1) && + (!attn_mask.has_value() || attn_mask.value().stride(-1) == 1), + "Paged attention: Q/KV cache/Mask should be continuous on the last dim"); + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::kBFloat16, at::kHalf, query.scalar_type(), "paged_attention", [&] { + paged_attention_kernel(out, query, key_cache, value_cache, + scale, block_tables, + context_lens, attn_mask); + }); +} + +} // namespace +TORCH_LIBRARY_IMPL(torchao, CPU, m) { + m.impl("torchao::paged_attention", &paged_attention_kernel_impl); +} + +} // namespace torchao \ No newline at end of file diff --git a/torchao/csrc/paged_attention.cpp b/torchao/csrc/paged_attention.cpp new file mode 100644 index 0000000000..f6890efe22 --- /dev/null +++ b/torchao/csrc/paged_attention.cpp @@ -0,0 +1,11 @@ +#include +#include +#include + +TORCH_LIBRARY_FRAGMENT(torchao, m) { + m.impl_abstract_pystub("torchao.ops"); + m.def( + "paged_attention(Tensor (a!)out, Tensor (a!)query, Tensor (a!)key_cache, Tensor (a!)value_cache,\ + float scale, Tensor(a!) block_tables, Tensor(a!) context_lens, \ + Tensor? 
attn_mask)-> ()"); +} \ No newline at end of file diff --git a/torchao/kv_cache.py b/torchao/kv_cache.py new file mode 100644 index 0000000000..528e71f587 --- /dev/null +++ b/torchao/kv_cache.py @@ -0,0 +1,409 @@ +import torch +import torch.nn as nn +import functools +from typing import List, Tuple, Optional, Dict, Any +import copy + +HANDLED_FUNCTIONS = {} + + +class PagedTensor(torch.Tensor): + @staticmethod + def __new__(cls, size, cache, block_table, *args, **kwargs): + return torch.Tensor._make_wrapper_subclass(cls, size, dtype=cache.dtype, *args, **kwargs) + + def __init__( + self, + size: Tuple[int, int, int, int],#The size of the cached tensor[bs, num_key_value_heads, seq_lens, head_dim]. + cache: torch.Tensor, #The cache tensor from the PagedAttentionCache object, which is shared accross iterations. + block_tables: torch.Tensor,#The block tables for each sequence in the batch which is used to mapping logical block to physical blocks. + ): + self.block_tables = block_tables + self.cache = cache + + def __repr__(self): + return f"PagedTensor(buffer shape: {self.cache.shape}, k/v cache shape:{self.shape}" + + @staticmethod + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + return NotImplemented + + @classmethod + def __torch_function__(cls, func, types, args, kwargs=None): + if func not in HANDLED_FUNCTIONS or not all( + issubclass(t, (torch.Tensor, PagedTensor)) + for t in types + ): + if kwargs is None: + kwargs = {} + return super().__torch_function__(func, types, args, kwargs) + return HANDLED_FUNCTIONS[func](*args, **kwargs) + +def implements(torch_function): + """Register a torch function override for ScalarTensor""" + def decorator(func): + functools.update_wrapper(func, torch_function) + HANDLED_FUNCTIONS[torch_function] = func + return func + return decorator + +@implements(torch.nn.functional.scaled_dot_product_attention) +def scaled_dot_product_attention( + input, key_tensor, value_tensor, attn_mask=None, scale=None +): + query = input + key_cache = key_tensor.cache + value_cache = value_tensor.cache + block_tables = key_tensor.block_tables + context_lens = torch.tensor([key_tensor.shape[2] for _ in range(key_tensor.shape[0])], dtype=torch.int32) + output = torch.empty_like(query) + torch.ops.torchao.paged_attention( + output, + query, + key_cache, + value_cache, + scale, + block_tables, + context_lens, + attn_mask, + ) + return output + + +class PagedAttentionCache(object): + def __init__( + self, + num_blocks: int, + block_size: int, + num_key_value_heads: int, + head_dim: int, + num_layers: int, + device="cpu", + dtype=None, + ) -> None: + super().__init__() + + # model info + self.head_dim = head_dim + self.num_key_value_heads = num_key_value_heads + self.num_layers = num_layers + + # Cache tensor info + self.dtype = dtype if dtype is not None else torch.float32 + self.device = device + self.num_blocks = num_blocks + self.block_size = block_size + + cache_shape = ( + self.num_blocks, + self.num_key_value_heads, + self.block_size, + self.head_dim, + ) + + # KV caches for each layer + self.key_caches = [ + torch.zeros(cache_shape, dtype=self.dtype, device=device) + for _ in range(num_layers) + ] + self.value_caches = [ + torch.zeros(cache_shape, dtype=self.dtype, device=device) + for _ in range(num_layers) + ] + + self.seen_tokens = ( + 0 # Used in `generate` to keep tally of how many tokens the cache has seen + ) + + # paged cache runtime information + self.free_blocks = list(range(num_blocks)) # free blocks + self.block_ref_count = [ + 0 + ] * 
self.num_blocks # init the reference count for each physical block + self.block_tables = ( + dict() + ) # mapping logical block to physical blocks for each sequence + + # The follow two states are shared accross layer but only for the current decode step. Need to update for every decode step. + self.batch2seq = None # mapping batch index to {seq_id0, seq_id1, ...} to enable prompt sharing. + self.slots_mapping = None # mapping logical slots to physical slots. + + def _copy_on_write(self, src_block_idx: int, dst_block_idx: int): + """ + Copy the content of src_block_idx to dst_block_idx. + + Args: + src_block_idx (int): The index of the source block. + dst_block_idx (int): The index of the destination block. + """ + for layer_idx in range(self.num_layers): + self.key_caches[layer_idx][dst_block_idx] = self.key_caches[layer_idx][ + src_block_idx + ].clone() + self.value_caches[layer_idx][dst_block_idx] = self.value_caches[layer_idx][ + src_block_idx + ].clone() + + def allocate(self, batch_size: int, key_len: int) -> None: + """ + Allocate physical slots for a every sequence with key_len tokens in this batcch. + + Args: + - batch_size (int): The batch size of the sequence. + - key_len (int): The length of the key. + + Returns: + - None + """ + self.slots_mapping = [] + past_context_len = self.seen_tokens + if self.batch2seq is None: + self.set_batch2seq_for_prompt_sharing(batch_size, 1) + for i in range(batch_size): + seq_idx = self.batch2seq[i][0] + # Scenario 1: New seqence: allocate blocks for this sequence + if seq_idx not in self.block_tables: + needed_blocks = (key_len + self.block_size - 1) // self.block_size + if needed_blocks > len(self.free_blocks): + raise AssertionError( + f"No space in KV cache to store new token state. needed_blocks: {needed_blocks}, free_blocks: {self.free_blocks}" + ) + blocks = self.free_blocks[:needed_blocks] + self.free_blocks = self.free_blocks[needed_blocks:] + self.block_tables[seq_idx] = blocks + for block_idx in blocks: + self.block_ref_count[block_idx] += 1 + # Senario 2: Partial processed sequence: find free slots in the allocated blocks or allocate new blocks + else: + seq_len = key_len + past_context_len + target_blocks = (seq_len + self.block_size - 1) // self.block_size + new_blocks = target_blocks - len(self.block_tables[seq_idx]) + + if new_blocks > len(self.free_blocks): + raise AssertionError( + f"PagedAttentionCache: No enough free blocks to allocate for sequence {seq_idx}." + ) + + if new_blocks > 0: # allocate new blocks + candidate_blocks = self.free_blocks[:new_blocks] + self.block_tables[seq_idx].extend(self.free_blocks[:new_blocks]) + self.free_blocks = self.free_blocks[new_blocks:] + for block_idx in candidate_blocks: + self.block_ref_count[block_idx] += 1 + else: + last_block = self.block_tables[seq_idx][-1] + # sharing the last block with other sequences, need to allocate a new block and copy the last block + if self.block_ref_count[last_block] > 1: + if len(self.free_blocks) == 0: + raise AssertionError( + f"PagedAttentionCache: No enough free blocks to allocate for sequence {seq_idx}." 
+ ) + new_block = self.free_blocks.pop() + self.block_tables[seq_idx][-1] = new_block + self.block_ref_count[new_block] += 1 + self.block_ref_count[last_block] -= 1 + self._copy_on_write(last_block, new_block) + + slots = [] + # the slots for this sequence + for j in range(key_len): + token_id = j + past_context_len + block_idx = token_id // self.block_size + block_offset = token_id % self.block_size + slots.append( + self.block_tables[seq_idx][block_idx] * self.block_size + + block_offset + ) + self.slots_mapping.append(slots) + self.slots_mapping = torch.tensor( + self.slots_mapping, dtype=torch.long, device=self.device + ) + # step 2): fork new sequences to enable prompt sharing + for batch_idx in range(batch_size): + seq_ids = self.batch2seq[batch_idx] + # fork the blocks allocated for the first sequence to other sequences in the batch + for seq_id in seq_ids[1:]: + self._fork(seq_ids[0], seq_id) + + def _free(self, seq_idx: int): + """ + Frees the blocks allocated for the given sequence index. + + Args: + seq_idx (int): The index of the sequence whose blocks are to be freed. + + Raises: + AssertionError: If the given sequence index is not present in the block tables. + """ + + if seq_idx not in self.block_tables: + raise AssertionError( + f"PagedAttentionCache: Sequence index {seq_idx} is not present in the block tables." + ) + + for block_idx in self.block_tables[seq_idx]: + self.block_ref_count[block_idx] -= 1 + if self.block_ref_count[block_idx] == 0: + self.free_blocks.append(block_idx) + + def _fork(self, seq_idx: int, new_seq_idx: int): + """ + Forks the blocks allocated for seq_idx to new_seq_idx. + + Args: + seq_idx (int): The index of the sequence to be forked. + new_seq_idx (int): The index of the new sequence. + + Raises: + AssertionError: If seq_idx is not in block_tables or if new_seq_idx is already in block_tables. + """ + if seq_idx not in self.block_tables: + raise AssertionError( + f"PagedAttentionCache: Sequence index {seq_idx} is not present in the block tables." + ) + + self.block_tables[new_seq_idx] = copy.deepcopy(self.block_tables[seq_idx]) + for block_idx in self.block_tables[seq_idx]: + self.block_ref_count[block_idx] += 1 + + def set_batch2seq_for_prompt_sharing(self, batch_size: int, beam_size: int): + """ + Set the batch2seq mapping for prompt sharing. + + Args: + batch_size (int): The batch size. + beam_size (int): The beam size. + """ + self.batch2seq = {} + for i in range(batch_size): + self.batch2seq[i] = [i * beam_size + j for j in range(beam_size)] + + def _reshape_and_cache( + self, + slot_mapping: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + ): + """ + Reshapes and caches the key and value states based on the given slot mapping. + + Args: + slot_mapping (List[List[int]]): A list of lists representing the slot mapping. + key_states (torch.Tensor): The key states tensor. + value_states (torch.Tensor): The value states tensor. + layer_idx (int): The index of the layer. 
+
+        Returns:
+            None
+        """
+        slot_mapping = slot_mapping.to(torch.int)
+        block_indices = torch.div(slot_mapping, self.block_size, rounding_mode="floor")
+        block_indices = block_indices.cpu().tolist()
+        block_offsets = slot_mapping % self.block_size
+        block_offsets = block_offsets.cpu().tolist()
+        batch_size = key_states.size(0)
+        seq_len = key_states.size(2)
+        for i in range(batch_size):
+            for seq_idx in range(seq_len):
+                block_idx = block_indices[i][seq_idx]
+                block_offset = block_offsets[i][seq_idx]
+                for head_idx in range(self.num_key_value_heads):
+                    self.key_caches[layer_idx][block_idx, head_idx, block_offset, :] = (
+                        key_states[i, head_idx, seq_idx, :]
+                    )
+                    self.value_caches[layer_idx][
+                        block_idx, head_idx, block_offset, :
+                    ] = value_states[i, head_idx, seq_idx, :]
+
+    def get_seq_length(self, layer_idx: int = 0) -> int:
+        return self.seen_tokens
+
+    def get_max_length(self) -> Optional[int]:
+        """Returns the maximum sequence length of the cached states. PagedAttentionCache does not have a maximum length."""
+        raise RuntimeError("PagedAttentionCache does not have a maximum sequence length.")
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Update the cache with key and value states for a specific layer.
+
+        Args:
+            key_states (torch.Tensor): The new key states tensor of shape [batch, head, seq, dim].
+            value_states (torch.Tensor): The new value states tensor of shape [batch, head, seq, dim].
+            layer_idx (int): The index of the layer.
+            cache_kwargs (Dict[str, Any]): Additional arguments for the cache subclass.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: A tuple containing the updated key and value states
+                (the entire context token states).
+
+        Raises:
+            AssertionError: If allocate() has not been called for the current decoding step.
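+
+        Example:
+            A minimal sketch for a single-layer cache; the shapes below are
+            illustrative and assume allocate() was called for this step:
+
+                cache = PagedAttentionCache(128, 16, num_key_value_heads=8, head_dim=64, num_layers=1)
+                cache.allocate(batch_size=1, key_len=32)
+                key_states = torch.randn(1, 8, 32, 64)
+                value_states = torch.randn(1, 8, 32, 64)
+                key, value = cache.update(key_states, value_states, layer_idx=0)
+                # The prompt step returns the plain tensors; later steps return
+                # PagedTensor views backed by the paged key/value buffers.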
+ """ + batch_size = key_states.shape[0] # [batch, head, seq, dim] + cur_len = key_states.shape[-2] + + # # slots info for key/value are same for every layer and allocate should be called before model.forward() to reduce the allocation overhead + # AssertionError( + # self.slots_mapping is not None, + # "PagedAttentionCache: Please first call allocate() of this object to get target positions in paged cache before the model.forward().", + # ) + # cache key_states & value_states + self._reshape_and_cache(self.slots_mapping, key_states, value_states, layer_idx) + + if layer_idx == self.num_layers - 1: + self.seen_tokens += cur_len + self.slot_mapping = None + + if ( + self.seen_tokens == 0 + or self.seen_tokens == cur_len + and layer_idx == self.num_layers - 1 + ): # first token + return key_states, value_states + else: # Next token + if layer_idx == self.num_layers - 1: + # last layer already updated self.seen_tokens + context_lens = torch.tensor( + [self.seen_tokens for _ in range(batch_size)], + dtype=torch.int32, + ) + else: + context_lens = torch.tensor( + [self.seen_tokens + cur_len for _ in range(batch_size)], + dtype=torch.int32, + ) + block_tables_t = [] + for seq_idx in range(batch_size): + block_tables_t.append(self.block_tables[seq_idx]) + block_tables_t = torch.tensor( + block_tables_t, dtype=torch.int32, device=self.device + ) + return PagedTensor( + (batch_size, self.num_key_value_heads, context_lens[0].item(), self.head_dim), self.key_caches[layer_idx], block_tables_t, + ), PagedTensor((batch_size, self.num_key_value_heads, context_lens[0].item(), self.head_dim), self.value_caches[layer_idx], block_tables_t) + + def reorder_cache(self, beam_idx: torch.Tensor) -> None: + """ + Reorder the cache according to the beam index. The beam index is a tensor of shape (batch_size,) + and the sequence id can be get from the self.batch2seq. + """ + freed_seqs = [] + new_block_tables = self.block_tables.copy() + num_beams = beam_idx.numel() // len(self.batch2seq) + for batch_idx, target_batch_idx in enumerate(beam_idx.tolist()): + target_seq_id = self.batch2seq[target_batch_idx // num_beams][0] + seq_id = self.batch2seq[batch_idx // num_beams][0] + freed_seqs.append(seq_id) + new_block_tables[seq_id] = [] + for block in self.block_tables[target_seq_id]: + self.block_ref_count[block] += 1 + new_block_tables[seq_id].append(block) + for seq_idx in freed_seqs: + self._free(seq_idx) + self.block_tables = new_block_tables + self.batch2seq = None