@@ -377,29 +377,27 @@ int llama_mtl_eval(
377
377
id <MTLBuffer > id_src1 = llama_mtl_get_buffer (ctx, gf->nodes [i]->src1 , &offs_src1);
378
378
id <MTLBuffer > id_dst = llama_mtl_get_buffer (ctx, gf->nodes [i], &offs_dst);
379
379
380
- const int64_t ncols0 = gf->nodes [i]->src0 ->ne [0 ];
381
- const int64_t nrows0 = gf->nodes [i]->src0 ->ne [1 ];
382
-
383
- const int64_t ncols1 = gf->nodes [i]->src1 ->ne [0 ];
384
- const int64_t nrows1 = gf->nodes [i]->src1 ->ne [1 ];
385
-
386
- const int64_t ncols = gf->nodes [i]->ne [0 ];
387
- const int64_t nrows = gf->nodes [i]->ne [1 ];
380
+ const int64_t ne00 = gf->nodes [i]->src0 ->ne [0 ];
381
+ const int64_t ne01 = gf->nodes [i]->src0 ->ne [1 ];
382
+ const int64_t ne10 = gf->nodes [i]->src1 ->ne [0 ];
383
+ const int64_t ne11 = gf->nodes [i]->src1 ->ne [1 ];
384
+ const int64_t ne0 = gf->nodes [i]->ne [0 ];
385
+ const int64_t ne1 = gf->nodes [i]->ne [1 ];
388
386
389
387
[encoder setComputePipelineState: ctx->pipeline_mul_mat_q4_0];
390
388
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
391
389
[encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
392
390
[encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
393
- [encoder setBytes: &ncols0 length: sizeof (ncols0 ) atIndex: 3 ];
394
- [encoder setBytes: &nrows0 length: sizeof (nrows0 ) atIndex: 4 ];
395
- [encoder setBytes: &ncols1 length: sizeof (ncols1 ) atIndex: 5 ];
396
- [encoder setBytes: &nrows1 length: sizeof (nrows1 ) atIndex: 6 ];
397
- [encoder setBytes: &ncols length: sizeof (ncols ) atIndex: 7 ];
398
- [encoder setBytes: &nrows length: sizeof (nrows ) atIndex: 8 ];
391
+ [encoder setBytes: &ne00 length: sizeof (ne00 ) atIndex: 3 ];
392
+ [encoder setBytes: &ne00 length: sizeof (ne00 ) atIndex: 4 ];
393
+ [encoder setBytes: &ne11 length: sizeof (ne11 ) atIndex: 5 ];
394
+ [encoder setBytes: &ne11 length: sizeof (ne11 ) atIndex: 6 ];
395
+ [encoder setBytes: &ne0 length: sizeof (ne0 ) atIndex: 7 ];
396
+ [encoder setBytes: &ne1 length: sizeof (ne1 ) atIndex: 8 ];
399
397
400
- printf (" mul_mat: %lld x%lld * %lld x%lld -> %lld x%lld \n " , ncols0, nrows0, ncols1, nrows1, ncols, nrows );
398
+ printf (" mul_mat: %lld x%lld * %lld x%lld -> %lld x%lld \n " , ne00, ne01, ne10, ne11, ne0, ne1 );
401
399
402
- [encoder dispatchThreadgroups: MTLSizeMake (nrows0, nrows1 , 1 ) threadsPerThreadgroup: MTLSizeMake (32 , 1 , 1 )];
400
+ [encoder dispatchThreadgroups: MTLSizeMake (ne01, ne11 , 1 ) threadsPerThreadgroup: MTLSizeMake (32 , 1 , 1 )];
403
401
} break ;
404
402
case GGML_OP_GET_ROWS:
405
403
{
0 commit comments