Skip to content

Commit ea88404

Browse files
author
Ivan Chikish
committed
cuda: add q8_0->f32 cpy operation
llama: enable K-shift for quantized KV cache It will fail on unsupported backends or quant types.
1 parent d39e267 commit ea88404

File tree

3 files changed

+81
-8
lines changed

3 files changed

+81
-8
lines changed

ggml/src/ggml-cuda.cu

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2841,6 +2841,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
28412841
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
28422842
return true;
28432843
}
2844+
if (src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_F32) {
2845+
return true;
2846+
}
28442847
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
28452848
return true;
28462849
}

ggml/src/ggml-cuda/cpy.cu

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,17 @@ static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
8181
}
8282
}
8383

84+
static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
85+
const block_q8_0 * xi = (const block_q8_0 *) cxi;
86+
float * dsti = (float *) cdsti;
87+
88+
const float d = (float)xi->d;
89+
90+
for (int j = 0; j < QK8_0; j++) {
91+
dsti[j] = xi->qs[j] * d;
92+
}
93+
}
94+
8495
static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
8596
const float * xi = (const float *) cxi;
8697
block_q4_0 * dsti = (block_q4_0 *) cdsti;
@@ -288,6 +299,32 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
288299
cpy_blck(cx + x_offset, cdst + dst_offset);
289300
}
290301

302+
template <cpy_kernel_t cpy_blck, int qk>
303+
static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
304+
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
305+
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
306+
const int nb12, const int nb13) {
307+
const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
308+
309+
if (i >= ne) {
310+
return;
311+
}
312+
313+
const int i03 = i/(ne00 * ne01 * ne02);
314+
const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
315+
const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
316+
const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
317+
const int x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
318+
319+
const int i13 = i/(ne10 * ne11 * ne12);
320+
const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
321+
const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
322+
const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
323+
const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
324+
325+
cpy_blck(cx + x_offset, cdst + dst_offset);
326+
}
327+
291328
static void ggml_cpy_f16_f32_cuda(
292329
const char * cx, char * cdst, const int ne,
293330
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -329,6 +366,16 @@ static void ggml_cpy_f32_q8_0_cuda(
329366
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
330367
}
331368

369+
static void ggml_cpy_q8_0_f32_cuda(
370+
const char * cx, char * cdst, const int ne,
371+
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
372+
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
373+
374+
const int num_blocks = ne;
375+
cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
376+
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
377+
}
378+
332379
static void ggml_cpy_f32_q4_0_cuda(
333380
const char * cx, char * cdst, const int ne,
334381
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -437,6 +484,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
437484
ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
438485
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
439486
ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
487+
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
488+
ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
440489
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
441490
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
442491
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
@@ -471,6 +520,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
471520
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
472521
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
473522
return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
523+
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
524+
return (void*) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
474525
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
475526
return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
476527
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {

src/llama.cpp

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9934,17 +9934,36 @@ struct llm_build_context {
99349934
const int64_t n_head_kv = hparams.n_head_kv(il);
99359935
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
99369936
struct ggml_tensor * rope_factors = build_rope_factors(il);
9937-
struct ggml_tensor * tmp =
9937+
struct ggml_tensor * k =
9938+
ggml_view_3d(ctx0, kv_self.k_l[il],
9939+
n_embd_head_k, n_head_kv, n_ctx,
9940+
ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
9941+
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
9942+
0);
9943+
9944+
struct ggml_tensor * tmp;
9945+
if (ggml_is_quantized(k->type)) {
9946+
// dequantize to f32 -> RoPE -> quantize back
9947+
tmp = ggml_cast(ctx0, k, GGML_TYPE_F32);
9948+
cb(tmp, "K_f32", il);
9949+
for (auto * backend : lctx.backends) {
9950+
// Figure out which backend KV cache belongs to
9951+
if (ggml_backend_supports_buft(backend, lctx.model.buft_layer[il].buft)) {
9952+
ggml_backend_sched_set_tensor_backend(lctx.sched, tmp, backend);
9953+
break;
9954+
}
9955+
}
9956+
tmp = ggml_rope_ext_inplace(ctx0, tmp,
9957+
lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
9958+
ext_factor, attn_factor, beta_fast, beta_slow);
9959+
cb(tmp, "K_shifted_f32", il);
9960+
tmp = ggml_cpy(ctx0, tmp, k);
9961+
} else {
99389962
// we rotate only the first n_rot dimensions
9939-
ggml_rope_ext_inplace(ctx0,
9940-
ggml_view_3d(ctx0, kv_self.k_l[il],
9941-
n_embd_head_k, n_head_kv, n_ctx,
9942-
ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
9943-
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
9944-
0),
9963+
tmp = ggml_rope_ext_inplace(ctx0, k,
99459964
lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
99469965
ext_factor, attn_factor, beta_fast, beta_slow);
9947-
9966+
}
99489967
cb(tmp, "K_shifted", il);
99499968
ggml_build_forward_expand(gf, tmp);
99509969
}

0 commit comments

Comments
 (0)