Skip to content

Commit 96d0052

Browse files
committed
mtl : mul_mat fixes (still wrong)
1 parent 2a24994 commit 96d0052

File tree

2 files changed

+28
-33
lines changed

2 files changed

+28
-33
lines changed

examples/mtl/mtl.m

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -377,29 +377,27 @@ int llama_mtl_eval(
377377
id<MTLBuffer> id_src1 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src1, &offs_src1);
378378
id<MTLBuffer> id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst);
379379

380-
const int64_t ncols0 = gf->nodes[i]->src0->ne[0];
381-
const int64_t nrows0 = gf->nodes[i]->src0->ne[1];
382-
383-
const int64_t ncols1 = gf->nodes[i]->src1->ne[0];
384-
const int64_t nrows1 = gf->nodes[i]->src1->ne[1];
385-
386-
const int64_t ncols = gf->nodes[i]->ne[0];
387-
const int64_t nrows = gf->nodes[i]->ne[1];
380+
const int64_t ne00 = gf->nodes[i]->src0->ne[0];
381+
const int64_t ne01 = gf->nodes[i]->src0->ne[1];
382+
const int64_t ne10 = gf->nodes[i]->src1->ne[0];
383+
const int64_t ne11 = gf->nodes[i]->src1->ne[1];
384+
const int64_t ne0 = gf->nodes[i]->ne[0];
385+
const int64_t ne1 = gf->nodes[i]->ne[1];
388386

389387
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0];
390388
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
391389
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
392390
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
393-
[encoder setBytes:&ncols0 length:sizeof(ncols0) atIndex:3];
394-
[encoder setBytes:&nrows0 length:sizeof(nrows0) atIndex:4];
395-
[encoder setBytes:&ncols1 length:sizeof(ncols1) atIndex:5];
396-
[encoder setBytes:&nrows1 length:sizeof(nrows1) atIndex:6];
397-
[encoder setBytes:&ncols length:sizeof(ncols) atIndex:7];
398-
[encoder setBytes:&nrows length:sizeof(nrows) atIndex:8];
391+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
392+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:4];
393+
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:5];
394+
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:6];
395+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:7];
396+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:8];
399397

400-
printf("mul_mat: %lldx%lld * %lldx%lld -> %lldx%lld\n", ncols0, nrows0, ncols1, nrows1, ncols, nrows);
398+
printf("mul_mat: %lldx%lld * %lldx%lld -> %lldx%lld\n", ne00, ne01, ne10, ne11, ne0, ne1);
401399

402-
[encoder dispatchThreadgroups:MTLSizeMake(nrows0, nrows1, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
400+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
403401
} break;
404402
case GGML_OP_GET_ROWS:
405403
{

examples/mtl/mtl.metal

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -144,19 +144,16 @@ kernel void kernel_mul_mat_q4_0(
144144
sum[tpitg.x] = 0.0f;
145145

146146
for (int i = 0; i < nb; i += tptg.x) {
147-
device const uint4 * x0p = (device const uint4 *) (x + i);
147+
device const uint4 * x0p = (device const uint4 *) (x + i)->qs;
148148
device const float4 * y0p = (device const float4 *) (y + i*qk);
149149

150150
const uint4 x0 = *x0p;
151151

152-
const uint4 x0l = x0 & uint4(0x0F0F0F0F);
153-
const uint4 x0h = x0 >> 4;
152+
const uint4 x0l = (x0 & uint4(0x0F0F0F0F));
153+
const uint4 x0h = (x0 & uint4(0xF0F0F0F0)) >> 4;
154154

155-
const int4 x0ls = as_type<int4>(x0l) - int4(8);
156-
const int4 x0hs = as_type<int4>(x0h) - int4(8);
157-
158-
thread const uchar * x0lsb = (thread const uchar *) &x0ls;
159-
thread const uchar * x0hsb = (thread const uchar *) &x0hs;
155+
thread const char * x0lsb = (thread const char *) &x0l;
156+
thread const char * x0hsb = (thread const char *) &x0h;
160157

161158
const float4 y00 = *(y0p + 0);
162159
const float4 y01 = *(y0p + 1);
@@ -167,17 +164,17 @@ kernel void kernel_mul_mat_q4_0(
167164
const float4 y06 = *(y0p + 6);
168165
const float4 y07 = *(y0p + 7);
169166

170-
const float d = (x + i)->d;
167+
const half d = (x + i)->d;
171168

172169
sum[tpitg.x] += (
173-
x0lsb[ 0]*y00[0] + x0lsb[ 1]*y00[1] + x0lsb[ 2]*y00[2] + x0lsb[ 3]*y00[3] +
174-
x0lsb[ 4]*y01[0] + x0lsb[ 5]*y01[1] + x0lsb[ 6]*y01[2] + x0lsb[ 7]*y01[3] +
175-
x0lsb[ 8]*y02[0] + x0lsb[ 9]*y02[1] + x0lsb[10]*y02[2] + x0lsb[11]*y02[3] +
176-
x0lsb[12]*y03[0] + x0lsb[13]*y03[1] + x0lsb[14]*y03[2] + x0lsb[15]*y03[3] +
177-
x0hsb[ 0]*y04[0] + x0hsb[ 1]*y04[1] + x0hsb[ 2]*y04[2] + x0hsb[ 3]*y04[3] +
178-
x0hsb[ 4]*y05[0] + x0hsb[ 5]*y05[1] + x0hsb[ 6]*y05[2] + x0hsb[ 7]*y05[3] +
179-
x0hsb[ 8]*y06[0] + x0hsb[ 9]*y06[1] + x0hsb[10]*y06[2] + x0hsb[11]*y06[3] +
180-
x0hsb[12]*y07[0] + x0hsb[13]*y07[1] + x0hsb[14]*y07[2] + x0hsb[15]*y07[3]
170+
(x0lsb[ 0] - 8)*y00[0] + (x0lsb[ 1] - 8)*y00[1] + (x0lsb[ 2] - 8)*y00[2] + (x0lsb[ 3] - 8)*y00[3] +
171+
(x0lsb[ 4] - 8)*y01[0] + (x0lsb[ 5] - 8)*y01[1] + (x0lsb[ 6] - 8)*y01[2] + (x0lsb[ 7] - 8)*y01[3] +
172+
(x0lsb[ 8] - 8)*y02[0] + (x0lsb[ 9] - 8)*y02[1] + (x0lsb[10] - 8)*y02[2] + (x0lsb[11] - 8)*y02[3] +
173+
(x0lsb[12] - 8)*y03[0] + (x0lsb[13] - 8)*y03[1] + (x0lsb[14] - 8)*y03[2] + (x0lsb[15] - 8)*y03[3] +
174+
(x0hsb[ 0] - 8)*y04[0] + (x0hsb[ 1] - 8)*y04[1] + (x0hsb[ 2] - 8)*y04[2] + (x0hsb[ 3] - 8)*y04[3] +
175+
(x0hsb[ 4] - 8)*y05[0] + (x0hsb[ 5] - 8)*y05[1] + (x0hsb[ 6] - 8)*y05[2] + (x0hsb[ 7] - 8)*y05[3] +
176+
(x0hsb[ 8] - 8)*y06[0] + (x0hsb[ 9] - 8)*y06[1] + (x0hsb[10] - 8)*y06[2] + (x0hsb[11] - 8)*y06[3] +
177+
(x0hsb[12] - 8)*y07[0] + (x0hsb[13] - 8)*y07[1] + (x0hsb[14] - 8)*y07[2] + (x0hsb[15] - 8)*y07[3]
181178
) * d;
182179
}
183180

0 commit comments

Comments
 (0)