Closed
Description
Prerequisites
Please answer the following questions for yourself before submitting an issue.
- I am running the latest code. Development is very rapid so there are no tagged versions as of now.
- I carefully followed the README.md.
- I searched using keywords relevant to my issue to make sure that I am creating a new issue that is not already open (or closed).
- I reviewed the Discussions, and have a new bug or useful enhancement to share.
Expected Behavior
I‘m a noob of Metal and I tried to implement Q5_0, but it seems there is something wrong at mat_mul. Many thanks!
Here is my code:
ggml-metal.m
void ggml_metal_graph_compute(
struct ggml_metal_context * ctx,
struct ggml_cgraph * gf) {
...
case GGML_OP_MUL_MAT:
...
case GGML_TYPE_Q5_0:
{
GGML_ASSERT(ne02 == 1);
GGML_ASSERT(ne12 == 1);
nth0 = 8;
nth1 = 4;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_0_f32];
} break;
...
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||) {
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
ggml-metal.metal
#define QK5_0 32
#define QR5_0 2
typedef struct {
half d; // delta
uint32_t qh; // 5-th bit of quants
uint8_t qs[QK5_0/2]; // nibbles / quants
} block_q5_0;
...
static void dequantize_row_q5_0(device const block_q5_0 * x, device float * y, int k) {
const int qk = QK5_0;
assert(k % qk == 0);
const int nb = k / qk;
for (int i = 0; i < nb; i++) {
const half d = x[i].d;
const uint32_t qh = x[i].qh;;
for (int j = 0; j < qk/2; ++j) {
const uint8_t xh0 = ((qh >> (j + 0)) << 4) & 0x10;
const uint8_t xh1 = ((qh >> (j + 12))) & 0x10;
const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh0) - 16;
const int32_t x1 = ((x[i].qs[j] >> 4) | xh1) - 16;
y[i*qk + j + 0 ] = x0*d;
y[i*qk + j + qk/2] = x1*d;
}
}
}
...
kernel void kernel_get_rows_q5_0(
device const void * src0,
device const int * src1,
device float * dst,
constant int64_t & ne00,
constant uint64_t & nb01,
constant uint64_t & nb1,
uint tpig[[thread_position_in_grid]]) {
const int i = tpig;
const int r = ((device int32_t *) src1)[i];
dequantize_row_q5_0(
(device const block_q5_0 *) ((device char *) src0 + r*nb01),
(device float *) ((device char *) dst + i*nb1), ne00);
}
...
kernel void kernel_mul_mat_q5_0_f32(
device const void * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant int64_t & ne11,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
threadgroup float * sum [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpig[[thread_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 tptg[[threads_per_threadgroup]]) {
const int nb = ne00/QK5_0; // block number of the first matrix
const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;
const uint nth = tptg.x*tptg.y;
const uint ith = tptg.y*tpitg.x + tpitg.y;
device const block_q5_0 * x = (device const block_q5_0 *) src0 + r0*nb;
device const float * y = (device const float *) src1 + r1*ne10;
sum[ith] = 0.0f;
for (int i = tpitg.x; i < nb; i += tptg.x) {
device const uchar4 * x0p = (device const uchar4 *) (x + i)->qs;
device const float4 * y0p = (device const float4 *) (y + i*QK5_0);
const float d = (float)((x + i)->d);
const uint32_t qh = (x + i)->qh;
const uchar4 x0v = *(x0p + tpitg.y);
const float4 y0v = *(y0p + tpitg.y + 0);
const float4 y1v = *(y0p + tpitg.y + 4);
float acc = 0.0f;
for (int j = 0; j < 4; ++j) {
const int xh0 = ((qh >> (j + 0)) << 4) & 0x10;
const int xh1 = ((qh >> (j + 12))) & 0x10;
const int x0 = ((x0v[j] & 0x0F) | xh0) - 16;
const int x1 = ((x0v[j] >> 4) | xh1) - 16;
const float y0 = y0v[j];
const float y1 = y1v[j];
acc += x0*y0 + x1*y1;
}
sum[ith] += acc*d;
}
// accumulate the sum from all threads in the threadgroup
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint i = nth/2; i > 0; i /= 2) {
if (ith < i) {
sum[ith] += sum[ith + i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
if (ith == 0) {
dst[r1*ne0 + r0] = sum[0];
}
}