
Commit 97c6134

refactor decode_kv_cache_memcpy
1 parent 5630324 commit 97c6134

File tree: 8 files changed (+107 lines, -56 lines)


colossalai/inference/modeling/models/nopadding_llama.py

Lines changed: 0 additions & 7 deletions
@@ -98,15 +98,8 @@ def llama_model_forward(
     """
     block_tables = inputmetadata.block_tables
     sequence_lengths = inputmetadata.sequence_lengths
-    batch_size = inputmetadata.batch_size
     kv_seq_len = inputmetadata.kv_seq_len

-    # NOTE: After testing, the performance of this configuration is relatively good. With updates
-    # and optimizations to the CUDA kernel implementation, a more detailed analysis of this configuration's
-    # selection should be conducted.
-    if batch_size >= 32 and kv_seq_len > 512:
-        use_cuda_kernel = False
-
     # NOTE (yuanheng-zhao): fow now, only triton kernels support verification process
     # during speculative-decoding (`q_len > 1`)
     # We will expicitly disable `use_cuda_kernel` here when speculative-decoding is enabled

examples/inference/benchmark_ops/benchmark_flash_decoding_attention.py

Lines changed: 3 additions & 1 deletion
@@ -20,7 +20,7 @@
 configs = [
     triton.testing.Benchmark(
         x_names=["MAX_NUM_BLOCKS_PER_SEQ"],
-        x_vals=[2**i for i in range(3, 8)],
+        x_vals=[2**i for i in range(2, 8)],
         line_arg="provider",
         line_vals=[
            "vllm_paged_decoding_attention",
@@ -113,6 +113,7 @@ def benchmark_flash_decoding_attention(
     kv_max_split_num = (max_seq_len_across_batch + BLOCK_SIZE - 1) // BLOCK_SIZE
     output = torch.empty((BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE), dtype=dtype, device=device)
     sm_scale = 1.0 / (HEAD_SIZE**0.5)
+    kv_scale = 1.0

     mid_output = torch.empty(
         size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num, HEAD_SIZE), dtype=torch.float32, device=device
@@ -136,6 +137,7 @@ def benchmark_flash_decoding_attention(
             max_seq_len_across_batch,
             alibi_slopes,
             "auto",
+            kv_scale,
         )
     elif provider == "triton_flash_decoding_attention":
         fn = lambda: flash_decoding_attention(
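Note: the new kv_scale argument appears to be the scaling factor that newer vLLM paged-attention kernels accept for an FP8-quantized KV cache; with the "auto" cache dtype used here it has no effect, which is why the benchmark (and the vLLM test file later in this commit) simply passes 1.0. A conceptual Python sketch of the idea, not vLLM's implementation:

import torch

def dequantize_kv(cached: torch.Tensor, kv_scale: float) -> torch.Tensor:
    # Conceptual only: a quantized KV-cache block is rescaled before use.
    # With kv_scale == 1.0 and an unquantized ("auto") cache this is a no-op.
    return cached.float() * kv_scale

kv_block = torch.randn(16, 8)
assert torch.equal(dequantize_kv(kv_block, 1.0), kv_block)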

examples/inference/benchmark_ops/benchmark_fused_rotary_embdding_unpad.py

Lines changed: 1 addition & 1 deletion
@@ -105,7 +105,7 @@ def benchmark_rotary_emb(
     elif provider == "no_fused_cuda_rotary_emb_func":
         fn = lambda: [
             inference_ops.rotary_embedding(new_q, new_k, cos, sin, True),
-            inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables),
+            inference_ops.decode_kv_cache_memcpy(new_k, new_v, new_k_cache, v_cache, kv_seq_lengths, block_tables),
         ]
     elif provider == "fused_cuda_rotary_emb_func":
         fn = lambda: inference_ops.rotary_embedding_and_cache_copy(

examples/inference/benchmark_ops/benchmark_kv_cache_memcopy.py

Lines changed: 4 additions & 0 deletions
@@ -4,6 +4,7 @@
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
 from colossalai.kernel.triton import copy_kv_to_blocked_cache
 from colossalai.utils import get_current_device
+from tests.test_infer.test_ops.cuda.test_kv_cache_memcpy import prepare_data as prepare_data_new_kcache_layout
 from tests.test_infer.test_ops.triton.test_kvcache_copy import prepare_data

 try:
@@ -68,6 +69,9 @@ def benchmark_kvcache_copy(
     elif provider == "triton_copy_func":
         fn = lambda: copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, context_lengths, block_tables)
     elif provider == "cuda_copy_func":
+        _, _, k_cache, _, _, _, _, _, _ = prepare_data_new_kcache_layout(
+            bsz, num_kv_heads, block_size, max_seq_len // block_size, context_lengths - 1, device, dtype
+        )
         new_k = new_k.squeeze(1) if new_k.dim() == 4 else new_k
         new_v = new_v.squeeze(1) if new_v.dim() == 4 else new_v
         fn = lambda: inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, context_lengths, block_tables)
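For context on why only k_cache is swapped out here: the CUDA copy path now expects the key cache in the new [num_blocks, num_kv_heads, head_dim/x, block_size, x] layout, while the value cache keeps the [num_blocks, num_kv_heads, block_size, head_dim] layout. A rough sketch of the two shapes, with sizes and the choice of x as assumptions rather than values taken from the benchmark:

import torch

# Illustrative sizes only (not taken from the benchmark).
num_blocks, num_kv_heads, head_dim, block_size = 16, 8, 128, 16
dtype = torch.float16
# x = 16 bytes worth of elements is a common choice in vLLM-style caches; an assumption here.
x = 16 // torch.tensor([], dtype=dtype).element_size()

# Value cache layout is unchanged by this commit.
v_cache = torch.zeros(num_blocks, num_kv_heads, block_size, head_dim, dtype=dtype)

# Key cache layout expected by the refactored decode_kv_cache_memcpy kernel.
k_cache = torch.zeros(num_blocks, num_kv_heads, head_dim // x, block_size, x, dtype=dtype)

print(v_cache.shape)  # torch.Size([16, 8, 16, 128])
print(k_cache.shape)  # torch.Size([16, 8, 16, 16, 8])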

extensions/csrc/kernel/cuda/decode_kv_cache_memcpy_kernel.cu

Lines changed: 28 additions & 11 deletions
@@ -20,7 +20,8 @@ __global__ void decode_kv_cache_memcpy_kernel(
     const int block_size,
     const int64_t key_stride,
     const int64_t value_stride,
-    const int block_table_stride
+    const int block_table_stride,
+    const int x
 )
 {
     const int seq_id = blockIdx.x;
@@ -38,28 +39,42 @@ __global__ void decode_kv_cache_memcpy_kernel(
     for (; i <= (hidden_size - VecSize); i += blockDim.x * VecSize) {
         const int head_id = i / head_dim;
         const int head_offset = i % head_dim;
+        const int x_id = head_offset / x;
+        const int x_offset = head_offset % x;
         const int64_t key_src_id = seq_id * key_stride + i;
         const int64_t value_src_id = seq_id * value_stride + i;
-        const int64_t target_id = block_id * hidden_size * block_size
+        const int64_t target_key_id = block_id * hidden_size * block_size
+                                      + head_id * block_size * head_dim
+                                      + x_id * block_size * x
+                                      + block_offset * x
+                                      + x_offset;
+        const int64_t target_value_id = block_id * hidden_size * block_size
                                       + head_id * block_size * head_dim
                                       + block_offset * head_dim + head_offset;

-        copy_vector<scalar_t, VecSize>(key_cache + target_id, key + key_src_id);
-        copy_vector<scalar_t, VecSize>(value_cache + target_id, value + value_src_id);
+        copy_vector<scalar_t, VecSize>(key_cache + target_key_id, key + key_src_id);
+        copy_vector<scalar_t, VecSize>(value_cache + target_value_id, value + value_src_id);
     }

     if (!Aligned) {
         for (; i < hidden_size; ++i ) {
             const int head_id = i / head_dim;
             const int head_offset = i % head_dim;
+            const int x_id = head_offset / x;
+            const int x_offset = head_offset % x;
             const int64_t key_src_id = seq_id * key_stride + i;
             const int64_t value_src_id = seq_id * value_stride + i;
-            const int64_t target_id = block_id * hidden_size * block_size
+            const int64_t target_key_id = block_id * hidden_size * block_size
+                                          + head_id * block_size * head_dim
+                                          + x_id * block_size * x
+                                          + block_offset * x
+                                          + x_offset;
+            const int64_t target_value_id = block_id * hidden_size * block_size
                                           + head_id * block_size * head_dim
                                           + block_offset * head_dim + head_offset;

-            key_cache[target_id] = key[key_src_id];
-            value_cache[target_id] = value[value_src_id];
+            key_cache[target_key_id] = key[key_src_id];
+            value_cache[target_value_id] = value[value_src_id];
         }
     }

@@ -69,15 +84,16 @@ template<typename scalar_t>
 void apply_decode_kv_cache_memcpy(
     at::Tensor& key,              // [num_tokens, head_num, head_dim]
     at::Tensor& value,            // [num_tokens, head_num, head_dim]
-    at::Tensor& key_cache,        // [num_blocks, head_num, block_size, head_dim]
+    at::Tensor& key_cache,        // [num_blocks, head_num, head_dim/x, block_size, x]
     at::Tensor& value_cache,      // [num_blocks, head_num, block_size, head_dim]
     at::Tensor& sequence_lengths, // [batch_size]
     at::Tensor& block_tables)     // [batch_size, max_seq_len]
 {
     int num_tokens = key.size(0);
     int head_num = key.size(1);
     int head_dim = key.size(2);
-    int block_size = key_cache.size(2);
+    int block_size = key_cache.size(3);
+    int x = key_cache.size(4);

     int64_t key_stride = key.stride(0);
     int64_t value_stride = value.stride(0);
@@ -110,7 +126,8 @@ void apply_decode_kv_cache_memcpy(
         block_size, \
         key_stride, \
         value_stride, \
-        block_table_stride \
+        block_table_stride, \
+        x \
     ); \
     } while(0)

@@ -146,7 +163,7 @@
 void decode_kv_cache_memcpy(
     at::Tensor& key,              // [num_tokens, head_num, head_dim]
     at::Tensor& value,            // [num_tokens, head_num, head_dim]
-    at::Tensor& key_cache,        // [num_blocks, head_num, block_size, head_dim]
+    at::Tensor& key_cache,        // [num_blocks, head_num, head_dim/x, block_size, x]
     at::Tensor& value_cache,      // [num_blocks, head_num, block_size, head_dim]
     at::Tensor& sequence_lengths, // [batch_size]
     at::Tensor& block_tables)     // [batch_size, max_seq_len]
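For reference, a minimal Python sketch (not part of the commit) that checks the target_key_id arithmetic above against direct indexing of a cache shaped [num_blocks, head_num, head_dim/x, block_size, x]; all sizes below are arbitrary assumptions chosen for illustration.

import torch

# Arbitrary example sizes (assumptions, not taken from the commit).
num_blocks, head_num, head_dim, block_size, x = 4, 2, 8, 16, 4
hidden_size = head_num * head_dim

# Flat key cache and its 5-D view in the new layout.
key_cache = torch.arange(num_blocks * hidden_size * block_size, dtype=torch.float32)
key_cache_5d = key_cache.view(num_blocks, head_num, head_dim // x, block_size, x)

block_id, block_offset = 1, 3
for i in range(hidden_size):  # same loop variable as the kernel
    head_id = i // head_dim
    head_offset = i % head_dim
    x_id = head_offset // x
    x_offset = head_offset % x
    target_key_id = (
        block_id * hidden_size * block_size
        + head_id * block_size * head_dim
        + x_id * block_size * x
        + block_offset * x
        + x_offset
    )
    # The flat offset computed by the kernel must match direct 5-D indexing.
    assert key_cache[target_key_id] == key_cache_5d[block_id, head_id, x_id, block_offset, x_offset]

print("target_key_id matches the [num_blocks, head_num, head_dim/x, block_size, x] layout")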

extensions/pybind/inference/inference.cpp

Lines changed: 4 additions & 3 deletions
@@ -1,9 +1,10 @@
 #include <torch/extension.h>

 void decode_kv_cache_memcpy(
-    torch::Tensor& key,        // [num_tokens, num_heads, head_size]
-    torch::Tensor& value,      // [num_tokens, num_heads, head_size]
-    torch::Tensor& key_cache,  // [num_blocks, num_heads, block_size, head_size]
+    torch::Tensor& key,    // [num_tokens, num_heads, head_size]
+    torch::Tensor& value,  // [num_tokens, num_heads, head_size]
+    torch::Tensor&
+        key_cache,  // [num_blocks, head_num, head_dim/x, block_size, x]
     torch::Tensor&
         value_cache,  // [num_blocks, num_heads, block_size, head_size]
     torch::Tensor& sequence_lengths,  // [batch_size]

tests/test_infer/test_ops/cuda/test_flash_decoding_attention.py

Lines changed: 2 additions & 0 deletions
@@ -193,6 +193,7 @@ def test_vllm_flash_decoding_attention(
     max_seq_len_across_batch = kv_seq_lengths.max().item()
     output = torch.empty((BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE), dtype=dtype, device=device)
     sm_scale = 1.0 / (HEAD_SIZE**0.5)
+    kv_scale = 1.0

     k_torch = convert_kv_unpad_to_padded(k_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
     v_torch = convert_kv_unpad_to_padded(v_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
@@ -250,6 +251,7 @@ def test_vllm_flash_decoding_attention(
         max_seq_len_across_batch,
         alibi_slopes,
         "auto",
+        kv_scale,
     )
     numpy_allclose(out_ref, output, rtol=rtol, atol=atol)

tests/test_infer/test_ops/cuda/test_kv_cache_memcpy.py

Lines changed: 65 additions & 33 deletions
@@ -4,14 +4,42 @@

 from colossalai.kernel.kernel_loader import InferenceOpsLoader
 from colossalai.utils import get_current_device
-from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v3
-from tests.test_infer.test_ops.triton.test_kvcache_copy import prepare_data
+from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v3, mock_alloc_single_token

 inference_ops = InferenceOpsLoader().load()

 HEAD_DIM = 72


+def prepare_data(
+    bsz,
+    num_kv_heads,
+    block_size,
+    max_num_blocks_per_seq,
+    context_lengths,
+    device="cuda",
+    dtype=torch.float16,
+):
+    num_tokens = torch.sum(context_lengths).item()
+
+    max_seq_len_in_batch = context_lengths.max()
+    cu_seqlens = F.pad(torch.cumsum(context_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
+
+    kv_size = (num_tokens, num_kv_heads, HEAD_DIM)
+    key = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
+    value = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
+
+    k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v3(
+        key, value, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
+    )
+
+    block_tables = block_tables.to(device=device)
+    k_cache = torch.zeros_like(k_cache_ref)
+    v_cache = torch.zeros_like(v_cache_ref)
+
+    return key, value, k_cache, v_cache, cu_seqlens, block_tables, max_seq_len_in_batch, k_cache_ref, v_cache_ref
+
+
 def run_decode_copy_kv_to_caches(
     bsz: int,
     block_size: int,
@@ -24,32 +52,41 @@ def run_decode_copy_kv_to_caches(
     torch.cuda.synchronize()
     torch.cuda.reset_peak_memory_stats()

+    n = 1
+
     max_seq_len = block_size * max_num_blocks_per_seq
     dtype = torch.float32
     device = get_current_device()

-    new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables = prepare_data(
-        bsz,
-        num_kv_heads,
-        HEAD_DIM,
-        block_size,
-        max_num_blocks_per_seq,
-        same_context_len,
-        max_seq_len,
-        device=device,
-        dtype=dtype,
+    assert max_seq_len > n, "max_seq_len must be greater than n"
+
+    past_kv_seq_lengths = (
+        torch.tensor([max_seq_len - n for _ in range(bsz)], dtype=torch.int32, device=device)
+        if same_context_len
+        else torch.randint(low=1, high=max_seq_len - n, size=(bsz,), dtype=torch.int32, device=device)
+    )
+
+    key, value, k_cache, v_cache, _, block_tables, _, _, _ = prepare_data(
+        bsz, num_kv_heads, block_size, max_num_blocks_per_seq, past_kv_seq_lengths, device, dtype
     )

-    new_k = new_k.squeeze(1) if new_k.dim() == 4 else new_k
-    new_v = new_v.squeeze(1) if new_v.dim() == 4 else new_v
-    inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables)
+    new_k = torch.randn((bsz, num_kv_heads, HEAD_DIM), dtype=dtype, device=device)
+    new_v = torch.randn((bsz, num_kv_heads, HEAD_DIM), dtype=dtype, device=device)

-    past_kv_seq_len = kv_seq_lengths - 1
+    # mock allocating blocks for the new k/v and update block tables
+    for _ in range(n):
+        mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size)
+        past_kv_seq_lengths += 1
+
+    inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, past_kv_seq_lengths, block_tables)
+
+    past_kv_seq_len = past_kv_seq_lengths - 1
     target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size]
     offsets_in_block = past_kv_seq_len % block_size
-    k_target = k_cache[target_block_ids, :, offsets_in_block, :]
+    k_target = k_cache[target_block_ids, :, :, offsets_in_block, :]
     k_source = new_k.squeeze()
     v_target = v_cache[target_block_ids, :, offsets_in_block, :]
+    k_target = k_target.reshape(v_target.shape)
     v_source = new_v.squeeze()

     assert k_target.shape == k_source.shape
@@ -77,22 +114,17 @@ def run_context_copy_kv_to_cache(
     else:
         context_lengths = torch.randint(low=1, high=max_seq_len, size=(bsz,), dtype=torch.int32, device=device)

-    num_tokens = torch.sum(context_lengths).item()
-
-    max_seq_len_in_batch = context_lengths.max()
-    cu_seqlens = F.pad(torch.cumsum(context_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
-
-    kv_size = (num_tokens, num_kv_heads, HEAD_DIM)
-    key = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
-    value = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
-
-    k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v3(
-        key, value, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
-    )
-
-    block_tables = block_tables.to(device=device)
-    k_cache = torch.zeros_like(k_cache_ref)
-    v_cache = torch.zeros_like(v_cache_ref)
+    (
+        key,
+        value,
+        k_cache,
+        v_cache,
+        cu_seqlens,
+        block_tables,
+        max_seq_len_in_batch,
+        k_cache_ref,
+        v_cache_ref,
+    ) = prepare_data(bsz, num_kv_heads, block_size, max_num_blocks_per_seq, context_lengths, device, dtype)

     inference_ops.context_kv_cache_memcpy(
         key, value, k_cache, v_cache, context_lengths, cu_seqlens, block_tables, max_seq_len_in_batch