@@ -159,8 +159,8 @@ __global__ void reshape_and_cache_kernel(
                                         // block_size]
     const int64_t* __restrict__ slot_mapping,  // [num_tokens]
     const int key_stride, const int value_stride, const int num_heads,
-    const int head_size, const int block_size, const int x, const float k_scale,
-    const float v_scale) {
+    const int head_size, const int block_size, const int x,
+    const float* k_scale, const float* v_scale) {
   const int64_t token_idx = blockIdx.x;
   const int64_t slot_idx = slot_mapping[token_idx];
   if (slot_idx < 0) {
@@ -196,9 +196,9 @@ __global__ void reshape_and_cache_kernel(
       value_cache[tgt_value_idx] = tgt_value;
     } else {
       key_cache[tgt_key_idx] =
-          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
+          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
       value_cache[tgt_value_idx] =
-          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
+          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
     }
   }
 }
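
The two hunks above change the FP8 scale arguments from by-value floats to device pointers: the kernel now dereferences k_scale/v_scale on the GPU, so the scale value is read out of device memory at use time rather than marshalled from the host at launch. A minimal standalone sketch of that pattern, with illustrative names (not the vLLM kernel itself):

// Toy CUDA kernel demonstrating the pointer-based scale pattern adopted
// above; `scale` is assumed to point at a single float in GPU memory.
#include <cuda_runtime.h>

__global__ void scale_kernel(const float* __restrict__ in,
                             float* __restrict__ out,
                             const float* scale,  // device pointer, as with k_scale
                             int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = in[i] * (*scale);  // dereference on device, mirroring *k_scale
  }
}
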
@@ -214,7 +214,7 @@ __global__ void reshape_and_cache_flash_kernel(
     const int64_t* __restrict__ slot_mapping,  // [num_tokens]
     const int block_stride, const int key_stride, const int value_stride,
     const int num_heads, const int head_size, const int block_size,
-    const float k_scale, const float v_scale) {
+    const float* k_scale, const float* v_scale) {
   const int64_t token_idx = blockIdx.x;
   const int64_t slot_idx = slot_mapping[token_idx];
   // NOTE: slot_idx can be -1 if the token is padded
@@ -239,9 +239,9 @@ __global__ void reshape_and_cache_flash_kernel(
       value_cache[tgt_key_value_idx] = tgt_value;
     } else {
       key_cache[tgt_key_value_idx] =
-          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
+          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
       value_cache[tgt_key_value_idx] =
-          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
+          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
     }
   }
 }
@@ -258,7 +258,9 @@ __global__ void reshape_and_cache_flash_kernel(
       reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),               \
       reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),             \
       slot_mapping.data_ptr<int64_t>(), key_stride, value_stride,     \
-      num_heads, head_size, block_size, x, k_scale, v_scale);
+      num_heads, head_size, block_size, x,                            \
+      reinterpret_cast<const float*>(k_scale.data_ptr()),             \
+      reinterpret_cast<const float*>(v_scale.data_ptr()));
 
 void reshape_and_cache(
     torch::Tensor& key,  // [num_tokens, num_heads, head_size]
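
In the macro above, the scales now arrive as tensors and are lowered to raw device pointers via data_ptr(). That reinterpret_cast is only sound if each scale tensor is float32 and resident on the same CUDA device as the caches. A hedged sketch of host-side checks that would enforce this contract (illustrative, not part of this diff):

#include <torch/torch.h>

// Hypothetical validation helper; vLLM may enforce this contract elsewhere.
void check_scale(const torch::Tensor& scale) {
  TORCH_CHECK(scale.is_cuda(), "scale must live on a CUDA device");
  TORCH_CHECK(scale.scalar_type() == torch::kFloat32,
              "scale must be float32 for the reinterpret_cast to be valid");
  TORCH_CHECK(scale.numel() == 1, "scale is expected to hold a single value");
}
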
@@ -268,8 +270,8 @@ void reshape_and_cache(
     torch::Tensor&
         value_cache,  // [num_blocks, num_heads, head_size, block_size]
     torch::Tensor& slot_mapping,  // [num_tokens]
-    const std::string& kv_cache_dtype, const double k_scale,
-    const double v_scale) {
+    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
+    torch::Tensor& v_scale) {
   int num_tokens = key.size(0);
   int num_heads = key.size(1);
   int head_size = key.size(2);
@@ -299,7 +301,9 @@ void reshape_and_cache(
       reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),               \
       reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),             \
       slot_mapping.data_ptr<int64_t>(), block_stride, key_stride,     \
-      value_stride, num_heads, head_size, block_size, k_scale, v_scale);
+      value_stride, num_heads, head_size, block_size,                 \
+      reinterpret_cast<const float*>(k_scale.data_ptr()),             \
+      reinterpret_cast<const float*>(v_scale.data_ptr()));
 
 void reshape_and_cache_flash(
     torch::Tensor& key,  // [num_tokens, num_heads, head_size]
@@ -308,8 +312,8 @@ void reshape_and_cache_flash(
     torch::Tensor&
         value_cache,  // [num_blocks, block_size, num_heads, head_size]
     torch::Tensor& slot_mapping,  // [num_tokens] or [num_actual_tokens]
-    const std::string& kv_cache_dtype, const double k_scale,
-    const double v_scale) {
+    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
+    torch::Tensor& v_scale) {
   // NOTE(woosuk): In vLLM V1, key.size(0) can be different from
   // slot_mapping.size(0) because of padding for CUDA graphs.
   // In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because
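
For callers, the signature change means the scales now travel as tensors instead of doubles. A hypothetical call site under the new signature (the "fp8" dtype string, tensor shape, and scale values are illustrative):

#include <torch/torch.h>

// Forward declaration matching the new signature introduced in this diff.
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
                       torch::Tensor& key_cache, torch::Tensor& value_cache,
                       torch::Tensor& slot_mapping,
                       const std::string& kv_cache_dtype,
                       torch::Tensor& k_scale, torch::Tensor& v_scale);

void example_call(torch::Tensor& key, torch::Tensor& value,
                  torch::Tensor& key_cache, torch::Tensor& value_cache,
                  torch::Tensor& slot_mapping) {
  // Previously: reshape_and_cache(..., "fp8", /*k_scale=*/1.0, /*v_scale=*/1.0);
  // Now the scales are single-element float32 tensors on the caches' device.
  auto opts =
      torch::TensorOptions().dtype(torch::kFloat32).device(key.device());
  torch::Tensor k_scale = torch::ones({1}, opts);
  torch::Tensor v_scale = torch::ones({1}, opts);
  reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, "fp8",
                    k_scale, v_scale);
}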