From 9c3e05d5249ef059f4ebfb97e0535deb9df53752 Mon Sep 17 00:00:00 2001 From: Jhen Date: Tue, 17 Oct 2023 11:55:34 +0800 Subject: [PATCH 1/7] metal : implement dequantize_q5_0 --- ggml-metal.m | 47 +++++++++++++++++-- ggml-metal.metal | 114 ++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 152 insertions(+), 9 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 87fa172161405..c908106be5e68 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -73,6 +73,8 @@ GGML_METAL_DECL_KERNEL(get_rows_f16); GGML_METAL_DECL_KERNEL(get_rows_q4_0); GGML_METAL_DECL_KERNEL(get_rows_q4_1); + GGML_METAL_DECL_KERNEL(get_rows_q5_0); + GGML_METAL_DECL_KERNEL(get_rows_q5_1); GGML_METAL_DECL_KERNEL(get_rows_q8_0); GGML_METAL_DECL_KERNEL(get_rows_q2_K); GGML_METAL_DECL_KERNEL(get_rows_q3_K); @@ -87,6 +89,8 @@ GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4); GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32); GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32); + GGML_METAL_DECL_KERNEL(mul_mv_q5_0_f32); + GGML_METAL_DECL_KERNEL(mul_mv_q5_1_f32); GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32); GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32); GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32); @@ -97,6 +101,8 @@ GGML_METAL_DECL_KERNEL(mul_mm_f16_f32); GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32); GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32); + GGML_METAL_DECL_KERNEL(mul_mm_q5_0_f32); + GGML_METAL_DECL_KERNEL(mul_mm_q5_1_f32); GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32); GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32); GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32); @@ -254,6 +260,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){ GGML_METAL_ADD_KERNEL(get_rows_f16); GGML_METAL_ADD_KERNEL(get_rows_q4_0); GGML_METAL_ADD_KERNEL(get_rows_q4_1); + GGML_METAL_ADD_KERNEL(get_rows_q5_0); + GGML_METAL_ADD_KERNEL(get_rows_q5_1); GGML_METAL_ADD_KERNEL(get_rows_q8_0); GGML_METAL_ADD_KERNEL(get_rows_q2_K); GGML_METAL_ADD_KERNEL(get_rows_q3_K); @@ -268,6 +276,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){ GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4); GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32); GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32); + GGML_METAL_ADD_KERNEL(mul_mv_q5_0_f32); + GGML_METAL_ADD_KERNEL(mul_mv_q5_1_f32); GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32); GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32); GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32); @@ -278,8 +288,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){ GGML_METAL_ADD_KERNEL(mul_mm_f32_f32); GGML_METAL_ADD_KERNEL(mul_mm_f16_f32); GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32); - GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32); GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32); + GGML_METAL_ADD_KERNEL(mul_mm_q5_0_f32); + GGML_METAL_ADD_KERNEL(mul_mm_q5_1_f32); + GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32); GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32); GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32); GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32); @@ -346,6 +358,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(get_rows_f16); GGML_METAL_DEL_KERNEL(get_rows_q4_0); GGML_METAL_DEL_KERNEL(get_rows_q4_1); + GGML_METAL_DEL_KERNEL(get_rows_q5_0); + GGML_METAL_DEL_KERNEL(get_rows_q5_1); GGML_METAL_DEL_KERNEL(get_rows_q8_0); GGML_METAL_DEL_KERNEL(get_rows_q2_K); GGML_METAL_DEL_KERNEL(get_rows_q3_K); @@ -360,6 +374,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4); GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32); GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32); + GGML_METAL_DEL_KERNEL(mul_mv_q5_0_f32); + GGML_METAL_DEL_KERNEL(mul_mv_q5_1_f32); GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32); GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32); GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32); @@ -370,8 +386,10 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(mul_mm_f32_f32); GGML_METAL_DEL_KERNEL(mul_mm_f16_f32); GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32); - GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32); GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32); + GGML_METAL_DEL_KERNEL(mul_mm_q5_0_f32); + GGML_METAL_DEL_KERNEL(mul_mm_q5_1_f32); + GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32); GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32); GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32); GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32); @@ -1052,6 +1070,8 @@ void ggml_metal_graph_compute( case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break; case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break; case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break; + case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_0_f32]; break; + case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_1_f32]; break; case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break; case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break; case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break; @@ -1121,6 +1141,24 @@ void ggml_metal_graph_compute( nth1 = 8; [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32]; } break; + case GGML_TYPE_Q5_0: + { + GGML_ASSERT(ne02 == 1); + GGML_ASSERT(ne12 == 1); + + nth0 = 8; + nth1 = 8; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32]; + } break; + case GGML_TYPE_Q5_1: + { + GGML_ASSERT(ne02 == 1); + GGML_ASSERT(ne12 == 1); + + nth0 = 8; + nth1 = 8; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32]; + } break; case GGML_TYPE_Q8_0: { GGML_ASSERT(ne02 == 1); @@ -1201,7 +1239,8 @@ void ggml_metal_graph_compute( [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16]; [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17]; - if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 || + if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || + src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } @@ -1233,6 +1272,8 @@ void ggml_metal_graph_compute( case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break; case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break; case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break; + case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_0]; break; + case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_1]; break; case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break; case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break; case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break; diff --git a/ggml-metal.metal b/ggml-metal.metal index 99b9fd7a748ba..1887dc140cfe5 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -18,6 +18,20 @@ typedef struct { uint8_t qs[QK4_1 / 2]; // nibbles / quants } block_q4_1; +#define QK5_0 32 +typedef struct { + half d; // delta + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_0 / 2]; // nibbles / quants +} block_q5_0; + +#define QK5_1 32 +typedef struct { + half d; // delta + half m; // min + uint8_t qs[QK5_1 / 2]; // nibbles / quants +} block_q5_1; + #define QK8_0 32 typedef struct { half d; // delta @@ -428,6 +442,22 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre return d * (acc[0] + acc[1]) + sumy * m; } +// function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q5 quants begin (0 or QK5_0/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) { + // TODO +} + +// function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q5 quants begin (0 or QK5_1/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) { + // TODO +} + // putting them in the kernel cause a significant performance penalty #define N_DST 4 // each SIMD group works on 4 rows #define N_SIMDGROUP 2 // number of SIMD groups in a thread group @@ -525,6 +555,43 @@ kernel void kernel_mul_mv_q4_1_f32( mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); } +kernel void kernel_mul_mv_q5_0_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01[[buffer(4)]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], + constant uint & gqa[[buffer(17)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); +} + +kernel void kernel_mul_mv_q5_1_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01[[buffer(4)]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], + constant uint & gqa[[buffer(17)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); +} + + #define NB_Q8_0 8 kernel void kernel_mul_mv_q8_0_f32( @@ -2122,15 +2189,13 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) template void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) { device const uint16_t * qs = ((device const uint16_t *)xb + 1); - const float d1 = il ? (xb->d / 16.h) : xb->d; - const float d2 = d1 / 256.f; + const float d = xb->d; const float md = -8.h * xb->d; - const ushort mask0 = il ? 0x00F0 : 0x000F; - const ushort mask1 = mask0 << 8; + const ushort mask = il ? 0x00F0 : 0x000F; for (int i=0;i<8;i++) { - reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md; - reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md; + reg[i/2][2*(i%2)+0] = d * ((qs[i] & mask) >> (il ? 4 : 0)) + md; + reg[i/2][2*(i%2)+1] = d * (((qs[i] >> 8) & mask) >> (il ? 4 : 0)) + md; } } @@ -2149,6 +2214,39 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg } } +template +void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 3); + const float d = xb->d; + const float md = -16.h * xb->d; + const ushort mask = il ? 0x00F0 : 0x000F; + + const uint32_t qh = *((device const uint32_t *)xb->qh); + + const int x_mv = (il ? 4 : 0); + + const int gh_mv = (il ? 12 : 0); + const int gh_bk = (il ? 0 : 4); + + for (int i = 0; i < 8; i++) { + // extract the 5-th bits for x0 and x1 + const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; + const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; + + // combine the 4-bits from qs with the 5th bit + const int32_t x0 = (((qs[i] & mask) >> x_mv) | xh_0); + const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); + + reg[i/2][2*(i%2)+0] = d * x0 + md; + reg[i/2][2*(i%2)+1] = d * x1 + md; + } +} + +template +void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) { + // TODO +} + template void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) { device const int8_t * qs = ((device const int8_t *)xb->qs); @@ -2490,6 +2588,8 @@ template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows; @@ -2518,6 +2618,8 @@ template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; From 7ebd4acb0ab90488ae182ecc28de9a69bdef56ca Mon Sep 17 00:00:00 2001 From: Jhen Date: Tue, 17 Oct 2023 11:56:07 +0800 Subject: [PATCH 2/7] metal : block_q_n_dot_y for block_q5_0 (broken) --- ggml-metal.metal | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 1887dc140cfe5..56c9562d0a2e2 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -447,7 +447,17 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre // we assume that the yl's have been multiplied with the appropriate scale factor // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) { - // TODO + float d = qb_curr->d; + float2 acc = 0.f; + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2); + const uint32_t qh = *((device const uint32_t *)qb_curr->qh); + for (int i = 0; i < 8; i+=2) { + acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x0010)) + + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x1000)); + acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 4 ) & 0x0010)) + + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 12) & 0x1000)); + } + return d * (sumy * -16.f + acc[0] + acc[1]); } // function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i]) @@ -2225,13 +2235,13 @@ void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg const int x_mv = (il ? 4 : 0); - const int gh_mv = (il ? 12 : 0); - const int gh_bk = (il ? 0 : 4); + const int qh_mv = (il ? 12 : 0); + const int qh_bk = (il ? 0 : 4); for (int i = 0; i < 8; i++) { // extract the 5-th bits for x0 and x1 - const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; - const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; + const uint8_t xh_0 = ((qh >> (qh_mv + 2*i )) << qh_bk) & 0x10; + const uint8_t xh_1 = ((qh >> (qh_mv + 2*i+1)) << qh_bk) & 0x10; // combine the 4-bits from qs with the 5th bit const int32_t x0 = (((qs[i] & mask) >> x_mv) | xh_0); From a7a4887be92019a676f2cd84a2613c1c49e3f42d Mon Sep 17 00:00:00 2001 From: Jhen Date: Tue, 17 Oct 2023 12:11:26 +0800 Subject: [PATCH 3/7] metal : revert unnecessary change --- ggml-metal.metal | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 56c9562d0a2e2..c6e14f46e7f70 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2199,13 +2199,15 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) template void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) { device const uint16_t * qs = ((device const uint16_t *)xb + 1); - const float d = xb->d; + const float d1 = il ? (xb->d / 16.h) : xb->d; + const float d2 = d1 / 256.f; const float md = -8.h * xb->d; - const ushort mask = il ? 0x00F0 : 0x000F; + const ushort mask0 = il ? 0x00F0 : 0x000F; + const ushort mask1 = mask0 << 8; for (int i=0;i<8;i++) { - reg[i/2][2*(i%2)+0] = d * ((qs[i] & mask) >> (il ? 4 : 0)) + md; - reg[i/2][2*(i%2)+1] = d * (((qs[i] >> 8) & mask) >> (il ? 4 : 0)) + md; + reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md; + reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md; } } @@ -2235,13 +2237,13 @@ void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg const int x_mv = (il ? 4 : 0); - const int qh_mv = (il ? 12 : 0); - const int qh_bk = (il ? 0 : 4); + const int gh_mv = (il ? 12 : 0); + const int gh_bk = (il ? 0 : 4); for (int i = 0; i < 8; i++) { // extract the 5-th bits for x0 and x1 - const uint8_t xh_0 = ((qh >> (qh_mv + 2*i )) << qh_bk) & 0x10; - const uint8_t xh_1 = ((qh >> (qh_mv + 2*i+1)) << qh_bk) & 0x10; + const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; + const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; // combine the 4-bits from qs with the 5th bit const int32_t x0 = (((qs[i] & mask) >> x_mv) | xh_0); From fce44a76281d4e49c10d4944f6c622c5ff45bba6 Mon Sep 17 00:00:00 2001 From: Jhen Date: Tue, 17 Oct 2023 13:44:44 +0800 Subject: [PATCH 4/7] metal : implement dequantize_q5_1 --- ggml-metal.metal | 32 ++++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index c6e14f46e7f70..e47d97beaea08 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -29,6 +29,7 @@ typedef struct { typedef struct { half d; // delta half m; // min + uint8_t qh[4]; // 5-th bit of quants uint8_t qs[QK5_1 / 2]; // nibbles / quants } block_q5_1; @@ -2235,10 +2236,10 @@ void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg const uint32_t qh = *((device const uint32_t *)xb->qh); - const int x_mv = (il ? 4 : 0); + const int x_mv = il ? 4 : 0; - const int gh_mv = (il ? 12 : 0); - const int gh_bk = (il ? 0 : 4); + const int gh_mv = il ? 12 : 0; + const int gh_bk = il ? 0 : 4; for (int i = 0; i < 8; i++) { // extract the 5-th bits for x0 and x1 @@ -2256,7 +2257,30 @@ void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg template void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) { - // TODO + device const uint16_t * qs = ((device const uint16_t *)xb + 4); + const float d = xb->d; + const float m = xb->m; + const ushort mask = il ? 0x00F0 : 0x000F; + + const uint32_t qh = *((device const uint32_t *)xb->qh); + + const int x_mv = il ? 4 : 0; + + const int gh_mv = il ? 12 : 0; + const int gh_bk = il ? 0 : 4; + + for (int i = 0; i < 8; i++) { + // extract the 5-th bits for x0 and x1 + const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; + const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; + + // combine the 4-bits from qs with the 5th bit + const int32_t x0 = (((qs[i] & mask) >> x_mv) | xh_0); + const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); + + reg[i/2][2*(i%2)+0] = d * x0 + m; + reg[i/2][2*(i%2)+1] = d * x1 + m; + } } template From 79d4732c70e5c9055cffecd7dbe4b60c4222a808 Mon Sep 17 00:00:00 2001 From: Jhen Date: Tue, 17 Oct 2023 13:45:08 +0800 Subject: [PATCH 5/7] metal : block_q_n_dot_y for q5_1 (broken) --- ggml-metal.metal | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index e47d97beaea08..9df4b8cb45497 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -466,7 +466,18 @@ inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thre // we assume that the yl's have been multiplied with the appropriate scale factor // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) { - // TODO + float d = qb_curr->d; + float m = qb_curr->m; + float2 acc = 0.f; + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2); + const uint32_t qh = *((device const uint32_t *)qb_curr->qh); + for (int i = 0; i < 8; i+=2) { + acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x0010)) + + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x1000)); + acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_1/2) << 4 ) & 0x0010)) + + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_1/2) << 12) & 0x1000)); + } + return d * (acc[0] + acc[1]) + sumy * m; } // putting them in the kernel cause a significant performance penalty From 9db276f0c2e85bf25ea8d163686d4d3fe57d4eb5 Mon Sep 17 00:00:00 2001 From: Jhen Date: Wed, 18 Oct 2023 09:02:58 +0800 Subject: [PATCH 6/7] metal : fix block_q_n_dot_y --- ggml-metal.metal | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 9df4b8cb45497..a345b83d9c27d 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -453,10 +453,10 @@ inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thre device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2); const uint32_t qh = *((device const uint32_t *)qb_curr->qh); for (int i = 0; i < 8; i+=2) { - acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x0010)) - + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x1000)); - acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 4 ) & 0x0010)) - + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 12) & 0x1000)); + acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)) + + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); + acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)) + + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); } return d * (sumy * -16.f + acc[0] + acc[1]); } @@ -472,10 +472,10 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2); const uint32_t qh = *((device const uint32_t *)qb_curr->qh); for (int i = 0; i < 8; i+=2) { - acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x0010)) - + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x1000)); - acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_1/2) << 4 ) & 0x0010)) - + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_1/2) << 12) & 0x1000)); + acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)) + + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); + acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)) + + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); } return d * (acc[0] + acc[1]) + sumy * m; } From 7a885229753438b5f625d0954df5d2f3e8b82079 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 18 Oct 2023 15:20:35 +0300 Subject: [PATCH 7/7] minor : spaces / formatting --- ggml-metal.metal | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index a345b83d9c27d..69fc7136279ad 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -414,8 +414,11 @@ kernel void kernel_rms_norm( // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) { float d = qb_curr->d; + float2 acc = 0.f; + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2); + for (int i = 0; i < 8; i+=2) { acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) + yl[i + 1] * (qs[i / 2] & 0x0F00); @@ -432,8 +435,11 @@ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thre inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) { float d = qb_curr->d; float m = qb_curr->m; - device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2); + float2 acc = 0.f; + + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2); + for (int i = 0; i < 8; i+=2) { acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) + yl[i + 1] * (qs[i / 2] & 0x0F00); @@ -449,9 +455,12 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) { float d = qb_curr->d; + float2 acc = 0.f; - device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2); - const uint32_t qh = *((device const uint32_t *)qb_curr->qh); + + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2); + const uint32_t qh = *((device const uint32_t *)qb_curr->qh); + for (int i = 0; i < 8; i+=2) { acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)) + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); @@ -468,9 +477,12 @@ inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thre inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) { float d = qb_curr->d; float m = qb_curr->m; + float2 acc = 0.f; - device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2); - const uint32_t qh = *((device const uint32_t *)qb_curr->qh); + + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2); + const uint32_t qh = *((device const uint32_t *)qb_curr->qh); + for (int i = 0; i < 8; i+=2) { acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)) + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); @@ -2250,7 +2262,7 @@ void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg const int x_mv = il ? 4 : 0; const int gh_mv = il ? 12 : 0; - const int gh_bk = il ? 0 : 4; + const int gh_bk = il ? 0 : 4; for (int i = 0; i < 8; i++) { // extract the 5-th bits for x0 and x1 @@ -2258,7 +2270,7 @@ void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; // combine the 4-bits from qs with the 5th bit - const int32_t x0 = (((qs[i] & mask) >> x_mv) | xh_0); + const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); reg[i/2][2*(i%2)+0] = d * x0 + md; @@ -2278,7 +2290,7 @@ void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg const int x_mv = il ? 4 : 0; const int gh_mv = il ? 12 : 0; - const int gh_bk = il ? 0 : 4; + const int gh_bk = il ? 0 : 4; for (int i = 0; i < 8; i++) { // extract the 5-th bits for x0 and x1 @@ -2286,7 +2298,7 @@ void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; // combine the 4-bits from qs with the 5th bit - const int32_t x0 = (((qs[i] & mask) >> x_mv) | xh_0); + const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); reg[i/2][2*(i%2)+0] = d * x0 + m;