From b759afaa2a5943360fe485082b6583a221c00589 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 24 Jul 2023 18:32:41 +0300 Subject: [PATCH 1/2] Another speed gain for Q4_0 and Q4_1 on Metal --- ggml-metal.metal | 104 ++++++++++++++++++++++++----------------------- 1 file changed, 54 insertions(+), 50 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 987376d560879..0c4945844a3d2 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -387,39 +387,50 @@ kernel void kernel_rms_norm( } } -// function for calculate inner product between a q4_0 block and 32 floats (yl), sumy is SUM(yl[i]) -float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl) { +// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q4 quants begin (0 or QK4_0/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) { float d = qb_curr->d; - float4 acc = 0.f; - device uint16_t * qs = ((device uint16_t *)qb_curr + 1); - for (int i = 0; i < 16; i+=2) { - acc[0] += yl[i] * (qs[i / 2] & 0x000F); - acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0); - acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00); - acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000); + float2 acc = 0.f; + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2); + for (int i = 0; i < 8; i+=2) { + acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) + + yl[i + 1] * (qs[i / 2] & 0x0F00); + acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0) + + yl[i + 9] * (qs[i / 2] & 0xF000); } - return d * (sumy * -8.f + acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f); + return d * (sumy * -8.f + acc[0] + acc[1]); } -// function for calculate inner product between a q4_1 block and 32 floats (yl), sumy is SUM(yl[i]) -float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl) { +// function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q4 quants begin (0 or QK4_0/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) { float d = qb_curr->d; float m = qb_curr->m; - float4 acc = 0.f; - device uint16_t * qs = ((device uint16_t *)qb_curr + 2); - for (int i = 0; i < 16; i+=2) { - acc[0] += yl[i] * (qs[i / 2] & 0x000F); - acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0); - acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00); - acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000); + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2); + float2 acc = 0.f; + for (int i = 0; i < 8; i+=2) { + const uint16_t qss = qs[i / 2] >> 8; + acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) + + yl[i + 1] * (qs[i / 2] & 0x0F00); + acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0) + + yl[i + 9] * (qs[i / 2] & 0xF000); } - return d * (acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f) + sumy * m; + return d * (acc[0] + acc[1]) + sumy * m; } // putting them in the kernel cause a significant performance penalty #define N_DST 4 // each SIMD group works on 4 rows #define N_SIMDGROUP 2 // number of SIMD groups in a thread group #define N_SIMDWIDTH 32 // assuming SIMD group size is 32 +# Note: This is a template, but strictly speaking it only applies to +# quantizations where the block size is 32. It also does not +# giard against the number of rows not being divisible by +# N_DST, so this is another explicit assumption of the implementation. template void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst, int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01, @@ -427,47 +438,40 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device const int nb = ne00/QK4_0; const int r0 = tgpig.x; const int r1 = tgpig.y; - device const block_q_type * x = (device const block_q_type *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb; + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + device const block_q_type * x = (device const block_q_type *) src0 + first_row * nb; device const float * y = (device const float *) src1 + r1*ne10; - float4 y_curr[8]; // src1 vector cache - float sumf[N_DST]={0.f}, all_sum; - thread float * yl=(thread float *)y_curr; + float yl[16]; // src1 vector cache + float sumf[N_DST]={0.f}; - // each thread in a SIMD group deals with 1 block. - for (int column = 0; column < nb / N_SIMDWIDTH; column++) { - float sumy = 0; - for (int i = 0; i < QK4_0 / 4; i++) { - y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0)) + i); - sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3]; - } + const int ix = tiisg/2; + const int il = 8*(tiisg%2); - for (int row = 0; row < N_DST; row++) { - sumf[row] += block_q_n_dot_y(x+(tiisg + row * nb + column * N_SIMDWIDTH), sumy, yl); - } - } + device const float * yb = y + ix * QK4_0 + il; - // from now loads two rows every time and 16 blocks per row - int ir = tiisg / (N_SIMDWIDTH / 2); - int ib = tiisg % (N_SIMDWIDTH / 2); - for (int ind = 0; ind < (nb % N_SIMDWIDTH + N_SIMDWIDTH / 2 - 1)/(N_SIMDWIDTH / 2); ind++) { - int nb_start = (nb / N_SIMDWIDTH) * N_SIMDWIDTH + ind * (N_SIMDWIDTH / 2); //where the left blocks start + // each thread in a SIMD group deals with half a block. + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { float sumy = 0; - for (int i = 0; i < QK4_0 / 4; i++) { - y_curr[i] = *((device float4 *)(y + (nb_start + ib) * QK4_0) + i); - sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3]; + for (int i = 0; i < 8; i += 2) { + sumy += yb[i] + yb[i+1]; + yl[i+0] = yb[i+ 0]; + yl[i+1] = yb[i+ 1]/256.f; + sumy += yb[i+16] + yb[i+17]; + yl[i+8] = yb[i+16]/16.f; + yl[i+9] = yb[i+17]/4096.f; } - for (int row = 0; row < N_DST; row+=2) { - if (nb_start + ib < nb) { - sumf[row + ir] += block_q_n_dot_y(x + (nb_start + ib + (row + ir) * nb), sumy, yl); - } + for (int row = 0; row < N_DST; row++) { + sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il); } + + yb += QK4_0 * 16; } for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) { - dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum; + const float tot = simd_sum(sumf[row]); + if (tiisg == 0 && first_row + row < ne01) { + dst[r1*ne0 + first_row + row] = tot; } } } From 7f98561243a8a3c43b55a8579a12fcf856078447 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 24 Jul 2023 19:05:44 +0300 Subject: [PATCH 2/2] Have N_DST, etc., be template parameters --- ggml-metal.metal | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 0c4945844a3d2..696b33ce75cf4 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -414,7 +414,6 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2); float2 acc = 0.f; for (int i = 0; i < 8; i+=2) { - const uint16_t qss = qs[i / 2] >> 8; acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) + yl[i + 1] * (qs[i / 2] & 0x0F00); acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0) @@ -427,22 +426,22 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre #define N_DST 4 // each SIMD group works on 4 rows #define N_SIMDGROUP 2 // number of SIMD groups in a thread group #define N_SIMDWIDTH 32 // assuming SIMD group size is 32 -# Note: This is a template, but strictly speaking it only applies to -# quantizations where the block size is 32. It also does not -# giard against the number of rows not being divisible by -# N_DST, so this is another explicit assumption of the implementation. -template +//Note: This is a template, but strictly speaking it only applies to +// quantizations where the block size is 32. It also does not +// giard against the number of rows not being divisible by +// N_DST, so this is another explicit assumption of the implementation. +template void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst, int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01, uint2 tgpig, uint tiisg, uint sgitg) { const int nb = ne00/QK4_0; const int r0 = tgpig.x; const int r1 = tgpig.y; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = (r0 * nsg + sgitg) * nr; device const block_q_type * x = (device const block_q_type *) src0 + first_row * nb; device const float * y = (device const float *) src1 + r1*ne10; float yl[16]; // src1 vector cache - float sumf[N_DST]={0.f}; + float sumf[nr]={0.f}; const int ix = tiisg/2; const int il = 8*(tiisg%2); @@ -450,7 +449,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device device const float * yb = y + ix * QK4_0 + il; // each thread in a SIMD group deals with half a block. - for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + for (int ib = ix; ib < nb; ib += nw/2) { float sumy = 0; for (int i = 0; i < 8; i += 2) { sumy += yb[i] + yb[i+1]; @@ -461,14 +460,14 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device yl[i+9] = yb[i+17]/4096.f; } - for (int row = 0; row < N_DST; row++) { + for (int row = 0; row < nr; row++) { sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il); } yb += QK4_0 * 16; } - for (int row = 0; row < N_DST; ++row) { + for (int row = 0; row < nr; ++row) { const float tot = simd_sum(sumf[row]); if (tiisg == 0 && first_row + row < ne01) { dst[r1*ne0 + first_row + row] = tot; @@ -487,7 +486,7 @@ kernel void kernel_mul_mat_q4_0_f32( uint2 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg); + mul_vec_q_n_f32(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg); } kernel void kernel_mul_mat_q4_1_f32( @@ -501,7 +500,7 @@ kernel void kernel_mul_mat_q4_1_f32( uint2 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg); + mul_vec_q_n_f32(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg); } kernel void kernel_mul_mat_f16_f32(