diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
index 54c0f66d2dfed..500d4022c562e 100644
--- a/ggml/src/ggml-cuda/cpy.cu
+++ b/ggml/src/ggml-cuda/cpy.cu
@@ -225,6 +225,91 @@ static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
     memcpy(dsti->qh, &qh, sizeof(qh));
 }
 
+static __device__ void cpy_blck_q5_0_f32(const char * cxi, char * cdsti) {
+    const block_q5_0 * xi = (const block_q5_0 *) cxi;
+    float * dst = (float *) cdsti;
+    float d = xi->d; // scale factor (computed as vmax / -16)
+    const float shift = 16.0f;
+
+    // Safely copy the 32-bit qh value to avoid misaligned access.
+    unsigned int qh;
+    memcpy(&qh, xi->qh, sizeof(qh));
+
+    // First half: lower nibble stores element j.
+    for (int j = 0; j < QK5_0/2; j++) {
+        uint8_t lower = xi->qs[j] & 0xF;
+        uint8_t high = (qh >> j) & 1;
+        uint8_t q = (high << 4) | lower;
+        dst[j] = ((float)q - shift) * d;
+    }
+    // Second half: upper nibble stores element j + QK5_0/2.
+    for (int j = QK5_0/2; j < QK5_0; j++) {
+        int k = j - QK5_0/2;
+        uint8_t lower = (xi->qs[k] >> 4) & 0xF;
+        uint8_t high = (qh >> j) & 1;
+        uint8_t q = (high << 4) | lower;
+        dst[j] = ((float)q - shift) * d;
+    }
+}
+
+static __device__ void cpy_blck_q5_1_f32(const char * cxi, char * cdsti) {
+    const block_q5_1 * xi = (const block_q5_1 *) cxi;
+    float * dst = (float *) cdsti;
+    float d = xi->dm.x;       // scale
+    float min_val = xi->dm.y; // minimum value
+
+    // Safely copy the 32-bit qh value to avoid misaligned access.
+    unsigned int qh;
+    memcpy(&qh, xi->qh, sizeof(qh));
+
+    // Decode first half: lower nibble of xi->qs[j] holds element j.
+    for (int j = 0; j < QK5_1/2; j++) {
+        uint8_t lower = xi->qs[j] & 0xF;
+        uint8_t high = (qh >> j) & 1;
+        uint8_t q = (high << 4) | lower;
+        dst[j] = min_val + d * (float)q;
+    }
+    // Decode second half: upper nibble of xi->qs[j] holds element j+QK5_1/2.
+    for (int j = QK5_1/2; j < QK5_1; j++) {
+        int k = j - QK5_1/2;
+        uint8_t lower = (xi->qs[k] >> 4) & 0xF;
+        uint8_t high = (qh >> j) & 1;
+        uint8_t q = (high << 4) | lower;
+        dst[j] = min_val + d * (float)q;
+    }
+}
+
+static __device__ void cpy_blck_q4_0_f32(const char * cxi, char * cdsti) {
+    const block_q4_0 * xi = (const block_q4_0 *) cxi;
+    float * dst = (float *) cdsti;
+    float d = xi->d;
+    const float shift = 8.0f;
+
+    // Each byte packs two 4-bit quantized values.
+    for (int j = 0; j < QK4_0/2; j++) {
+        uint8_t q_val = xi->qs[j];
+        uint8_t q0 = q_val & 0x0F;
+        uint8_t q1 = (q_val >> 4) & 0x0F;
+        dst[j] = ((float)q0 - shift) * d;
+        dst[j + QK4_0/2] = ((float)q1 - shift) * d;
+    }
+}
+
+static __device__ void cpy_blck_q4_1_f32(const char * cxi, char * cdsti) {
+    const block_q4_1 * xi = (const block_q4_1 *) cxi;
+    float * dst = (float *) cdsti;
+    const float d = xi->dm.x;
+    const float vmin = xi->dm.y;
+
+    // Each byte packs two 4-bit quantized values.
+    for (int j = 0; j < QK4_1/2; ++j) {
+        uint8_t byte_val = xi->qs[j];
+        uint8_t q0 = byte_val & 0x0F;
+        uint8_t q1 = (byte_val >> 4) & 0x0F;
+        dst[j] = vmin + d * (float)q0;
+        dst[j + QK4_1/2] = vmin + d * (float)q1;
+    }
+}
+
 static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
     if (x <= val[0]) return 0;
 
@@ -420,6 +505,58 @@ static void ggml_cpy_f32_q5_1_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
+static void ggml_cpy_q5_1_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02,
+    const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12,
+    const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream) {
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q5_1_f32, QK5_1><<<num_blocks, 1, 0, stream>>>(
+        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_q5_0_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02,
+    const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12,
+    const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream) {
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q5_0_f32, QK5_0><<<num_blocks, 1, 0, stream>>>(
+        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_q4_1_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02,
+    const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12,
+    const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream) {
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q4_1_f32, QK4_1><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+         ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_q4_0_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02,
+    const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12,
+    const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream) {
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q4_0_f32, QK4_0><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+         ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
 static void ggml_cpy_f32_iq4_nl_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -488,14 +625,25 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
         ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
         ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
+            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
         ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
+            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
         ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
+            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
         ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
         ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
         ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
@@ -524,14 +672,22 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
         return (void*) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
         return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
+    } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q4_0_f32, QK4_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
         return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
+    } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q4_1_f32, QK4_1>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
         return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
+    } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q5_0_f32, QK5_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
         return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
         return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
+    } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q5_1_f32, QK5_1>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
         return (void*) cpy_f32_f16<cpy_1_f16_f16>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 093ad70991b5a..67bf980edac51 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3073,15 +3073,27 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
                     return true;
                 }
+                if (src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
                     return true;
                 }
+                if (src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) {
                     return true;
                 }
+                if (src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) {
                     return true;
                 }
+                if (src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
                     return true;
                 }
diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
index e3dc25f1686fb..546c7bb5e8cb3 100644
--- a/ggml/src/ggml-metal/ggml-metal-impl.h
+++ b/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -84,6 +84,7 @@ typedef struct {
 } ggml_metal_kargs_repeat;
 
 typedef struct {
+    int64_t ne;
     int64_t ne00;
     int64_t ne01;
     int64_t ne02;
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index 087e7f58149f6..679251b31fc3b 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -407,6 +407,11 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
     GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
+    GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32,
     GGML_METAL_KERNEL_TYPE_CONCAT,
     GGML_METAL_KERNEL_TYPE_SQR,
     GGML_METAL_KERNEL_TYPE_SQRT,
@@ -1012,6 +1017,11 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32, cpy_q4_0_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32, cpy_q4_1_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32, cpy_q5_0_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32, cpy_q5_1_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32, cpy_q8_0_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
@@ -1287,6 +1297,12 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
                             default:
                                 return false;
                         }
+                    case GGML_TYPE_Q4_0:
+                    case GGML_TYPE_Q4_1:
+                    case GGML_TYPE_Q5_0:
+                    case GGML_TYPE_Q5_1:
+                    case GGML_TYPE_Q8_0:
+                        return (op->type == GGML_TYPE_F32);
                     default:
                         return false;
                 };
@@ -1615,7 +1631,10 @@ static void ggml_metal_encode_node(
 
                 const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;
 
+                const int64_t ne = ggml_nelements(src0);
+
                 ggml_metal_kargs_cpy args = {
+                    /*.ne   =*/ ne,
                     /*.ne00 =*/ ne00,
                     /*.ne01 =*/ ne01,
                     /*.ne02 =*/ ne02,
@@ -3899,10 +3918,7 @@ static void ggml_metal_encode_node(
         case GGML_OP_CPY:
         case GGML_OP_CONT:
             {
-                GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
-
-                int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
-
+                const int64_t ne = ggml_nelements(src0);
                 id<MTLComputePipelineState> pipeline = nil;
 
                 switch (src0t) {
@@ -3936,13 +3952,33 @@ static void ggml_metal_encode_node(
                             switch (dstt) {
                                 case GGML_TYPE_F32:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline;  break;
                                 case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16].pipeline; break;
-                                default: GGML_ASSERT(false && "not implemented");
+                                default: GGML_ABORT("not implemented");
implemented"); }; } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + { + if (dstt == GGML_TYPE_F32) { + switch (src0t) { + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32].pipeline; break; + default: GGML_ABORT("not implemented"); + } + } else { + GGML_ABORT("not implemented"); + } + } break; default: GGML_ABORT("not implemented"); } ggml_metal_kargs_cpy args = { + /*.ne =*/ ne, /*.ne00 =*/ ne00, /*.ne01 =*/ ne01, /*.ne02 =*/ ne02, @@ -3966,7 +4002,17 @@ static void ggml_metal_encode_node( [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + int nth; + + if (src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) { + GGML_ASSERT(dstt == GGML_TYPE_F32); + nth = MIN(1024, ne); + } else { + GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); + nth = MIN(1024, ne00/ggml_blck_size(src0->type)); + } [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; case GGML_OP_SET: { diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 83e7ac9f411ef..da5e24249d98c 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -4372,6 +4372,256 @@ kernel void kernel_concat( } } +kernel void kernel_cpy_q4_0_f32( + constant ggml_metal_kargs_cpy & args, + device const char *cx [[ buffer(1) ]], + device char *cdst [[ buffer(2) ]], + uint tid [[ thread_position_in_grid ]] +) +{ + // Compute the global index multiplied by QK, matching: + // i = (blockDim.x*blockIdx.x + threadIdx.x)*qk + const int i = int(tid) * QK4_0; + + // Bounds check + if (i >= args.ne) { + return; + } + + const int i03 = i/(args.ne00 * args.ne01 * args.ne02); + const int i02 = (i - i03*args.ne00*args.ne01*args.ne02 )/ (args.ne00*args.ne01); + const int i01 = (i - i03*args.ne00*args.ne01*args.ne02 - i02*args.ne01*args.ne00) / args.ne00; + const int i00 = i - i03*args.ne00*args.ne01*args.ne02 - i02*args.ne01*args.ne00 - i01*args.ne00; + const int x_offset = (i00/QK4_0)*args.nb00 + i01*args.nb01 + i02*args.nb02 + i03 * args.nb03; + + const int i13 = i/(args.ne0 * args.ne1 * args.ne2); + const int i12 = (i - i13*args.ne0*args.ne1*args.ne2) / (args.ne0*args.ne1); + const int i11 = (i - i13*args.ne0*args.ne1*args.ne2 - i12*args.ne0*args.ne1) / args.ne0; + const int i10 = i - i13*args.ne0*args.ne1*args.ne2 - i12*args.ne0*args.ne1 - i11*args.ne0; + const int dst_offset = i10*args.nb0 + i11*args.nb1 + i12*args.nb2 + i13*args.nb3; + + device const block_q4_0 * src_block = (device const block_q4_0 *)(cx + x_offset); + device float * dst = (device float *)(cdst + dst_offset); + + float d = float(src_block->d); + const float shift = 8.0f; + + // Unpack 2 x 4-bit values per byte. 
+    for (int j = 0; j < QK4_0/2; j++) {
+        uint8_t q = src_block->qs[j];
+        uint8_t q0 = q & 0x0F;
+        uint8_t q1 = (q >> 4) & 0x0F;
+        dst[j] = (float(q0) - shift) * d;
+        dst[j + QK4_0/2] = (float(q1) - shift) * d;
+    }
+}
+
+kernel void kernel_cpy_q4_1_f32(
+    constant ggml_metal_kargs_cpy & args,
+    device const char *cx [[ buffer(1) ]],
+    device char *cdst [[ buffer(2) ]],
+    uint tid [[ thread_position_in_grid ]]
+)
+{
+    // Compute the global index multiplied by QK, matching:
+    //   i = (blockDim.x*blockIdx.x + threadIdx.x)*qk
+    const int i = int(tid) * QK4_1;
+
+    // Bounds check
+    if (i >= args.ne) {
+        return;
+    }
+
+    const int i03 = i/(args.ne00 * args.ne01 * args.ne02);
+    const int i02 = (i - i03*args.ne00*args.ne01*args.ne02 )/ (args.ne00*args.ne01);
+    const int i01 = (i - i03*args.ne00*args.ne01*args.ne02 - i02*args.ne01*args.ne00) / args.ne00;
+    const int i00 = i - i03*args.ne00*args.ne01*args.ne02 - i02*args.ne01*args.ne00 - i01*args.ne00;
+    const int x_offset = (i00/QK4_1)*args.nb00 + i01*args.nb01 + i02*args.nb02 + i03 * args.nb03;
+
+    const int i13 = i/(args.ne0 * args.ne1 * args.ne2);
+    const int i12 = (i - i13*args.ne0*args.ne1*args.ne2) / (args.ne0*args.ne1);
+    const int i11 = (i - i13*args.ne0*args.ne1*args.ne2 - i12*args.ne0*args.ne1) / args.ne0;
+    const int i10 = i - i13*args.ne0*args.ne1*args.ne2 - i12*args.ne0*args.ne1 - i11*args.ne0;
+    const int dst_offset = i10*args.nb0 + i11*args.nb1 + i12*args.nb2 + i13*args.nb3;
+
+    device const block_q4_1 * src_block = (device const block_q4_1 *)(cx + x_offset);
+    device float * dst = (device float *)(cdst + dst_offset);
+
+    float d = float(src_block->d);
+    float vmin = float(src_block->m);
+
+    for (int j = 0; j < QK4_1/2; j++) {
+        uint8_t q = src_block->qs[j];
+        uint8_t q0 = q & 0x0F;
+        uint8_t q1 = (q >> 4) & 0x0F;
+        dst[j] = vmin + d * float(q0);
+        dst[j + QK4_1/2] = vmin + d * float(q1);
+    }
+}
+
+
+kernel void kernel_cpy_q5_0_f32(
+    constant ggml_metal_kargs_cpy & args,
+    device const char *cx [[ buffer(1) ]],
+    device char *cdst [[ buffer(2) ]],
+    uint tid [[ thread_position_in_grid ]]
+)
+{
+    // Compute the global index multiplied by QK, matching:
+    //   i = (blockDim.x*blockIdx.x + threadIdx.x)*qk
+    const int i = int(tid) * QK5_0;
+
+    // Bounds check
+    if (i >= args.ne) {
+        return;
+    }
+
+    const int i03 = i/(args.ne00 * args.ne01 * args.ne02);
+    const int i02 = (i - i03*args.ne00*args.ne01*args.ne02 )/ (args.ne00*args.ne01);
+    const int i01 = (i - i03*args.ne00*args.ne01*args.ne02 - i02*args.ne01*args.ne00) / args.ne00;
+    const int i00 = i - i03*args.ne00*args.ne01*args.ne02 - i02*args.ne01*args.ne00 - i01*args.ne00;
+    const int x_offset = (i00/QK5_0)*args.nb00 + i01*args.nb01 + i02*args.nb02 + i03 * args.nb03;
+
+    const int i13 = i/(args.ne0 * args.ne1 * args.ne2);
+    const int i12 = (i - i13*args.ne0*args.ne1*args.ne2) / (args.ne0*args.ne1);
+    const int i11 = (i - i13*args.ne0*args.ne1*args.ne2 - i12*args.ne0*args.ne1) / args.ne0;
+    const int i10 = i - i13*args.ne0*args.ne1*args.ne2 - i12*args.ne0*args.ne1 - i11*args.ne0;
+    const int dst_offset = i10*args.nb0 + i11*args.nb1 + i12*args.nb2 + i13*args.nb3;
+
+    device const block_q5_0 * src_block = (device const block_q5_0 *)(cx + x_offset);
+    device float * dst = (device float *)(cdst + dst_offset);
+
+    float d = float(src_block->d);
+    const float shift = 16.f;
+
+    // Combine the four qh bytes into a 32-bit value.
+    uint32_t qhVal = 0
+        | ((uint32_t) src_block->qh[0] << 0)
+        | ((uint32_t) src_block->qh[1] << 8)
+        | ((uint32_t) src_block->qh[2] << 16)
+        | ((uint32_t) src_block->qh[3] << 24);
+
+    // First half
+    for (int j = 0; j < QK5_0/2; j++) {
+        uint8_t q = src_block->qs[j];
+        uint8_t lowNib = q & 0x0F;
+        uint8_t highBit = (qhVal >> j) & 0x1;
+        uint8_t qVal = (highBit << 4) | lowNib;
+        dst[j] = (float(qVal) - shift) * d;
+    }
+    // Second half
+    for (int j = QK5_0/2; j < QK5_0; j++) {
+        int k = j - QK5_0/2;
+        uint8_t q = src_block->qs[k];
+        uint8_t hiNib = (q >> 4) & 0x0F;
+        uint8_t highBit = (qhVal >> j) & 0x1;
+        uint8_t qVal = (highBit << 4) | hiNib;
+        dst[j] = (float(qVal) - shift) * d;
+    }
+}
+
+
+kernel void kernel_cpy_q5_1_f32(
+    constant ggml_metal_kargs_cpy & args,
+    device const char *cx [[ buffer(1) ]],
+    device char *cdst [[ buffer(2) ]],
+    uint tid [[ thread_position_in_grid ]]
+)
+{
+    // Compute the global index multiplied by QK, matching:
+    //   i = (blockDim.x*blockIdx.x + threadIdx.x)*qk
+    const int i = int(tid) * QK5_1;
+
+    // Bounds check
+    if (i >= args.ne) {
+        return;
+    }
+
+    const int i03 = i/(args.ne00 * args.ne01 * args.ne02);
+    const int i02 = (i - i03*args.ne00*args.ne01*args.ne02 )/ (args.ne00*args.ne01);
+    const int i01 = (i - i03*args.ne00*args.ne01*args.ne02 - i02*args.ne01*args.ne00) / args.ne00;
+    const int i00 = i - i03*args.ne00*args.ne01*args.ne02 - i02*args.ne01*args.ne00 - i01*args.ne00;
+    const int x_offset = (i00/QK5_1)*args.nb00 + i01*args.nb01 + i02*args.nb02 + i03 * args.nb03;
+
+    const int i13 = i/(args.ne0 * args.ne1 * args.ne2);
+    const int i12 = (i - i13*args.ne0*args.ne1*args.ne2) / (args.ne0*args.ne1);
+    const int i11 = (i - i13*args.ne0*args.ne1*args.ne2 - i12*args.ne0*args.ne1) / args.ne0;
+    const int i10 = i - i13*args.ne0*args.ne1*args.ne2 - i12*args.ne0*args.ne1 - i11*args.ne0;
+    const int dst_offset = i10*args.nb0 + i11*args.nb1 + i12*args.nb2 + i13*args.nb3;
+
+    device const block_q5_1 * src_block = (device const block_q5_1 *)(cx + x_offset);
+    device float * dst = (device float *)(cdst + dst_offset);
+
+    float d = float(src_block->d);
+    float vmin = float(src_block->m);
+
+    uint32_t qhVal = 0
+        | ((uint32_t) src_block->qh[0] << 0)
+        | ((uint32_t) src_block->qh[1] << 8)
+        | ((uint32_t) src_block->qh[2] << 16)
+        | ((uint32_t) src_block->qh[3] << 24);
+
+    // First half
+    for (int j = 0; j < QK5_1/2; j++) {
+        uint8_t q = src_block->qs[j];
+        uint8_t lowNib = q & 0x0F;
+        uint8_t highBit = (qhVal >> j) & 0x1;
+        uint8_t qVal = (highBit << 4) | lowNib;
+        dst[j] = vmin + d * float(qVal);
+    }
+    // Second half
+    for (int j = QK5_1/2; j < QK5_1; j++) {
+        int k = j - QK5_1/2;
+        uint8_t q = src_block->qs[k];
+        uint8_t hiNib = (q >> 4) & 0x0F;
+        uint8_t highBit = (qhVal >> j) & 0x1;
+        uint8_t qVal = (highBit << 4) | hiNib;
+        dst[j] = vmin + d * float(qVal);
+    }
+}
+
+kernel void kernel_cpy_q8_0_f32(
+    constant ggml_metal_kargs_cpy &args [[ buffer(0) ]],
+    device const char *cx [[ buffer(1) ]],
+    device char *cdst [[ buffer(2) ]],
+    uint tid [[ thread_position_in_grid ]]
+) {
+    // Compute the global index multiplied by QK, matching:
+    //   i = (blockDim.x*blockIdx.x + threadIdx.x)*qk
+    const int i = int(tid) * QK8_0;
+
+    // Bounds check
+    if (i >= args.ne) {
+        return;
+    }
+
+    const int i03 = i/(args.ne00 * args.ne01 * args.ne02);
+    const int i02 = (i - i03*args.ne00*args.ne01*args.ne02 )/ (args.ne00*args.ne01);
+    const int i01 = (i - i03*args.ne00*args.ne01*args.ne02 - i02*args.ne01*args.ne00) / args.ne00;
+    const int i00 = i - i03*args.ne00*args.ne01*args.ne02 - i02*args.ne01*args.ne00 - i01*args.ne00;
+    const int x_offset = (i00/QK8_0)*args.nb00 + i01*args.nb01 + i02*args.nb02 + i03 * args.nb03;
+
+    const int i13 = i/(args.ne0 * args.ne1 * args.ne2);
+    const int i12 = (i - i13*args.ne0*args.ne1*args.ne2) / (args.ne0*args.ne1);
+    const int i11 = (i - i13*args.ne0*args.ne1*args.ne2 - i12*args.ne0*args.ne1) / args.ne0;
+    const int i10 = i - i13*args.ne0*args.ne1*args.ne2 - i12*args.ne0*args.ne1 - i11*args.ne0;
+    const int dst_offset = i10*args.nb0 + i11*args.nb1 + i12*args.nb2 + i13*args.nb3;
+
+    // Call the device function that performs the copy/dequantization.
+    //   cpy_blck(cx + x_offset, cdst + dst_offset);
+    device const char * src_block = cx + x_offset;
+    device char * dst = cdst + dst_offset;
+
+    const device block_q8_0 * xi = (device const block_q8_0 *) src_block;
+    device float * dsti = (device float *) dst;
+
+    const float d = (float)xi->d;
+
+    for (int j = 0; j < QK8_0; j++) {
+        dsti[j] = xi->qs[j] * d;
+    }
+}
+
 template <typename args_t>
 void kernel_mul_mv_q2_K_f32_impl(
         args_t args,
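
Note (not part of the diff): the q5_0 decode implemented by cpy_blck_q5_0_f32 and kernel_cpy_q5_0_f32 above can be sanity-checked against a plain host-side reference such as the sketch below. The struct and helper names are hypothetical; the layout (32 elements per block, one high bit per element in qh, two low nibbles per byte in qs) is assumed to match ggml's block_q5_0, with the fp16 scale simplified to a float.

// Hypothetical C++ reference mirroring the kernels above; illustrative only.
#include <cstdint>
#include <cstring>

#define QK5_0 32

struct block_q5_0_ref {       // assumed layout, analogous to ggml's block_q5_0
    float   d;                // scale (ggml stores fp16; float keeps the sketch simple)
    uint8_t qh[4];            // one high bit per element, 32 bits total
    uint8_t qs[QK5_0 / 2];    // two low nibbles per byte
};

// Reconstruct q = (high bit << 4) | nibble, then value = (q - 16) * d.
static void dequantize_block_q5_0_ref(const block_q5_0_ref * x, float * dst) {
    uint32_t qh;
    memcpy(&qh, x->qh, sizeof(qh)); // avoid a misaligned 32-bit load, as in the kernels

    for (int j = 0; j < QK5_0 / 2; ++j) {
        const uint8_t lo0 =  x->qs[j]       & 0x0F; // element j
        const uint8_t lo1 = (x->qs[j] >> 4) & 0x0F; // element j + QK5_0/2
        const uint8_t q0  = (((qh >>  j)              & 1) << 4) | lo0;
        const uint8_t q1  = (((qh >> (j + QK5_0 / 2)) & 1) << 4) | lo1;
        dst[j]             = ((float) q0 - 16.0f) * x->d;
        dst[j + QK5_0 / 2] = ((float) q1 - 16.0f) * x->d;
    }
}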