Skip to content

[Requirement] Implement Metal support for Q5_0 #2229

Closed
@Kaguya-19

Description

@Kaguya-19

Prerequisites

Please answer the following questions for yourself before submitting an issue.

  • I am running the latest code. Development is very rapid so there are no tagged versions as of now.
  • I carefully followed the README.md.
  • I searched using keywords relevant to my issue to make sure that I am creating a new issue that is not already open (or closed).
  • I reviewed the Discussions, and have a new bug or useful enhancement to share.

Expected Behavior

I'm new to Metal and tried to implement Q5_0 support, but something seems to be wrong in the mul_mat kernel. Many thanks!

Here is my code:
ggml-metal.m

void ggml_metal_graph_compute(
        struct ggml_metal_context * ctx,
               struct ggml_cgraph * gf) {
...
case GGML_OP_MUL_MAT:
...
        // Q5_0 x F32 matrix multiplication: select the Q5_0 pipeline and the
        // threadgroup shape (nth0 x nth1) that is used for the dispatch.
        case GGML_TYPE_Q5_0:
                                        {
                                            // only a single src0 / src1 slice along dimension 2 is handled
                                            GGML_ASSERT(ne02 == 1);
                                            GGML_ASSERT(ne12 == 1);

                                            // 8 x 4 = 32 threads per threadgroup.
                                            // nth1 must be 4: kernel_mul_mat_q5_0_f32 uses the thread's y index
                                            // to pick one uchar4 out of the 16 qs bytes of a block.
                                            nth0 = 8;
                                            nth1 = 4;
                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_0_f32];
                                        } break;
...
// Quantized Q4_0 / Q4_1 / Q5_0 mat-mul kernels share the same launch shape.
// (Fixed: the condition previously ended with a dangling `||`, which does not compile.)
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0) {
                                    // one float of threadgroup memory per thread: the kernel's `sum` reduction buffer
                                    [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
                                    // one threadgroup per (src0 row, src1 row) pair, nth0 x nth1 threads each
                                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                }

ggml-metal.metal

// Q5_0 quantization block: QK5_0 (32) values share one scale `d`; each value is
// stored as a 4-bit nibble in `qs` plus one extra (5th) bit taken from `qh`.
//
// NOTE(review): a uint32_t member is 4-byte aligned, so 2 bytes of padding are
// inserted after `d` and sizeof(block_q5_0) becomes 24. If the host-side block
// stores the high bits as uint8_t[4] (22-byte block, no padding), this layout
// will not match the buffer contents and `qh` / `qs` will be read from the wrong
// offsets -- confirm against the CPU-side definition.
#define QK5_0 32
#define QR5_0 2
typedef struct {
    half    d;             // delta (per-block scale)
    uint32_t qh;         // 5-th bit of quants: bit j -> element j, bit j+16 -> element j+16
    uint8_t qs[QK5_0/2]; // nibbles / quants: low nibble -> element j, high nibble -> element j+16
} block_q5_0;
...
// Expand a row of Q5_0 blocks into floats.
//
//   x - k/QK5_0 consecutive quantized blocks
//   y - destination for k dequantized values
//   k - number of values in the row; must be a multiple of QK5_0
//
// Each 5-bit quant is rebuilt from a nibble of `qs` and one bit of `qh`,
// re-centred by subtracting 16 and multiplied by the block scale `d`.
static void dequantize_row_q5_0(device const block_q5_0 * x, device float * y, int k) {
    const int qk      = QK5_0;
    const int half_qk = qk/2;

    assert(k % qk == 0);

    const int n_blocks = k / qk;

    for (int ib = 0; ib < n_blocks; ++ib) {
        device const block_q5_0 * blk = x + ib;
        device float            * out = y + ib*qk;

        const half     scale = blk->d;
        const uint32_t hbits = blk->qh;

        for (int j = 0; j < half_qk; ++j) {
            const uint8_t qbyte = blk->qs[j];

            // move the matching bit of `hbits` into bit position 4
            const uint8_t hb_lo = ((hbits >> (j + 0)) << 4) & 0x10; // bit j
            const uint8_t hb_hi = ((hbits >> (j + 12))) & 0x10;     // bit j + 16

            const int32_t q_lo = ((qbyte & 0x0F) | hb_lo) - 16;
            const int32_t q_hi = ((qbyte >>   4) | hb_hi) - 16;

            // low nibbles fill the first half of the block, high nibbles the second
            out[j          ] = q_lo*scale;
            out[j + half_qk] = q_hi*scale;
        }
    }
}
...
// Row gather for Q5_0 data: thread `tpig` reads a row index from src1,
// dequantizes that row of src0 and writes it as floats to row `tpig` of dst.
//
//   src0 - quantized source rows, nb01 bytes apart
//   src1 - int32 row indices, one per thread
//   dst  - float output rows, nb1 bytes apart
//   ne00 - number of values per row
kernel void kernel_get_rows_q5_0(
        device const  void * src0,
        device const   int * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb1,
        uint tpig[[thread_position_in_grid]]) {
    const int out_row = tpig;
    const int src_row = ((device int32_t *) src1)[out_row];

    // byte-offset addressing: rows are located via their strides, not via ne00
    device const block_q5_0 * src_blocks =
        (device const block_q5_0 *) ((device char *) src0 + src_row*nb01);
    device float * dst_values =
        (device float *) ((device char *)  dst + out_row*nb1);

    dequantize_row_q5_0(src_blocks, dst_values, ne00);
}
...
// Product of a Q5_0-quantized matrix (src0) with an F32 matrix (src1).
//
// Each threadgroup produces one output value dst[r1*ne0 + r0], the dot product
// of row r0 of src0 with row r1 of src1 (r0 / r1 are the threadgroup x / y
// positions in the grid).
//
// Work split inside a threadgroup:
//   - tpitg.x strides over the Q5_0 blocks of the row (step tptg.x)
//   - tpitg.y selects one uchar4 (4 bytes = 8 quants) of the block's 16 qs
//     bytes, so the kernel expects tptg.y == 4
//
//   sum - threadgroup scratch buffer, one float per thread, used for the
//         final reduction
//
// NOTE(review): rows are addressed as r0*nb blocks and r1*ne10 floats, i.e. the
// byte strides (nb01, nb11, ...) are not used -- this assumes contiguous rows.
kernel void kernel_mul_mat_q5_0_f32(
        device const  void * src0,
        device const  float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        threadgroup float  * sum [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3  tpig[[thread_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3  tptg[[threads_per_threadgroup]]) {

    const int nb = ne00/QK5_0; // number of Q5_0 blocks per row of src0

    const int64_t r0 = tgpig.x; // row of src0
    const int64_t r1 = tgpig.y; // row of src1

    const uint nth = tptg.x*tptg.y;             // threads per threadgroup
    const uint ith = tptg.y*tpitg.x + tpitg.y;  // flat index of this thread

    device const block_q5_0 * x = (device const block_q5_0 *) src0 + r0*nb;
    device const float      * y = (device const float      *) src1 + r1*ne10;

    sum[ith] = 0.0f;

    // index (within a block) of the first qs byte handled by this thread
    const uint jbase = 4*tpitg.y;

    for (int i = tpitg.x; i < nb; i += tptg.x) {
        device const uchar4 * x0p = (device const uchar4 *) (x + i)->qs;
        device const float4 * y0p = (device const float4 *) (y + i*QK5_0);

        const float    d  = (float)((x + i)->d);
        const uint32_t qh = (x + i)->qh;

        const uchar4 x0v = *(x0p + tpitg.y);      // qs[jbase .. jbase+3]
        const float4 y0v = *(y0p + tpitg.y + 0);  // y for elements jbase .. jbase+3
        const float4 y1v = *(y0p + tpitg.y + 4);  // y for elements 16+jbase .. 16+jbase+3

        float acc = 0.0f;

        for (int j = 0; j < 4; ++j) {
            // The 5th bit must be looked up with the quant's index inside the
            // block (jbase + j), not with the local loop index j; otherwise
            // every thread with tpitg.y > 0 picks up the high bits that belong
            // to the first four quants. This matches dequantize_row_q5_0.
            const uint jq = jbase + j;

            const int xh0 = ((qh >> (jq +  0)) << 4) & 0x10; // bit jq
            const int xh1 = ((qh >> (jq + 12))     ) & 0x10; // bit jq + 16

            const int x0 = ((x0v[j] & 0x0F) | xh0) - 16;
            const int x1 = ((x0v[j] >>   4) | xh1) - 16;

            const float y0 = y0v[j];
            const float y1 = y1v[j];

            acc += x0*y0 + x1*y1;
        }

        sum[ith] += acc*d;
    }

    // accumulate the sum from all threads in the threadgroup
    // (pairwise halving: assumes nth is a power of two)
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (uint i = nth/2; i > 0; i /= 2) {
       if (ith < i) {
           sum[ith] += sum[ith + i];
       }
       threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    if (ith == 0) {
       dst[r1*ne0 + r0] = sum[0];
    }
}

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions