@@ -1390,8 +1390,8 @@ static enum ggml_status ggml_metal_graph_compute(
1390
1390
const int64_t nrows_x = ggml_nrows (src0);
1391
1391
const int64_t nrows_y = src0->ne [1 ];
1392
1392
1393
- const uint32_t n_head_kv = nrows_x/nrows_y;
1394
- const uint32_t n_head_log2 = 1u << (uint32_t ) floorf (log2f ((float ) n_head_kv ));
1393
+ const uint32_t n_head = nrows_x/nrows_y;
1394
+ const uint32_t n_head_log2 = 1u << (uint32_t ) floorf (log2f ((float ) n_head ));
1395
1395
1396
1396
const float m0 = powf (2 .0f , -(max_bias ) / n_head_log2);
1397
1397
const float m1 = powf (2 .0f , -(max_bias / 2 .0f ) / n_head_log2);
@@ -2513,7 +2513,7 @@ static enum ggml_status ggml_metal_graph_compute(
2513
2513
" the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big" );
2514
2514
2515
2515
const int64_t ne30 = src3 ? src3->ne [0 ] : 0 ; GGML_UNUSED (ne30);
2516
- const int64_t ne31 = src3 ? src3->ne [1 ] : 0 ;
2516
+ // const int64_t ne31 = src3 ? src3->ne[1] : 0;
2517
2517
const int64_t ne32 = src3 ? src3->ne [2 ] : 0 ; GGML_UNUSED (ne32);
2518
2518
const int64_t ne33 = src3 ? src3->ne [3 ] : 0 ; GGML_UNUSED (ne33);
2519
2519
@@ -2525,7 +2525,16 @@ static enum ggml_status ggml_metal_graph_compute(
2525
2525
const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED (src2t);
2526
2526
2527
2527
float scale;
2528
- memcpy (&scale, dst->op_params , sizeof (float ));
2528
+ float max_bias;
2529
+
2530
+ memcpy (&scale, ((int32_t *) dst->op_params ) + 0 , sizeof (scale));
2531
+ memcpy (&max_bias, ((int32_t *) dst->op_params ) + 1 , sizeof (max_bias));
2532
+
2533
+ const uint32_t n_head = src0->ne [2 ];
2534
+ const uint32_t n_head_log2 = 1u << (uint32_t ) floorf (log2f ((float ) n_head));
2535
+
2536
+ const float m0 = powf (2 .0f , -(max_bias ) / n_head_log2);
2537
+ const float m1 = powf (2 .0f , -(max_bias / 2 .0f ) / n_head_log2);
2529
2538
2530
2539
id <MTLComputePipelineState > pipeline = nil ;
2531
2540
@@ -2562,34 +2571,37 @@ static enum ggml_status ggml_metal_graph_compute(
2562
2571
}
2563
2572
2564
2573
[encoder setComputePipelineState: pipeline];
2565
- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
2566
- [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
2567
- [encoder setBuffer: id_src2 offset: offs_src2 atIndex: 2 ];
2568
- [encoder setBuffer: id_src3 offset: offs_src3 atIndex: 3 ];
2569
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 4 ];
2570
- [encoder setBytes: &ne00 length: sizeof ( int64_t ) atIndex: 5 ];
2571
- [encoder setBytes: &ne01 length: sizeof ( int64_t ) atIndex: 6 ];
2572
- [encoder setBytes: &ne02 length: sizeof ( int64_t ) atIndex: 7 ];
2573
- [encoder setBytes: &ne03 length: sizeof ( int64_t ) atIndex: 8 ];
2574
- [encoder setBytes: &nb00 length: sizeof (uint64_t ) atIndex: 9 ];
2575
- [encoder setBytes: &nb01 length: sizeof (uint64_t ) atIndex: 10 ];
2576
- [encoder setBytes: &nb02 length: sizeof (uint64_t ) atIndex: 11 ];
2577
- [encoder setBytes: &nb03 length: sizeof (uint64_t ) atIndex: 12 ];
2578
- [encoder setBytes: &ne10 length: sizeof ( int64_t ) atIndex: 13 ];
2579
- [encoder setBytes: &ne11 length: sizeof ( int64_t ) atIndex: 14 ];
2580
- [encoder setBytes: &ne12 length: sizeof ( int64_t ) atIndex: 15 ];
2581
- [encoder setBytes: &ne13 length: sizeof ( int64_t ) atIndex: 16 ];
2582
- [encoder setBytes: &nb10 length: sizeof (uint64_t ) atIndex: 17 ];
2583
- [encoder setBytes: &nb11 length: sizeof (uint64_t ) atIndex: 18 ];
2584
- [encoder setBytes: &nb12 length: sizeof (uint64_t ) atIndex: 19 ];
2585
- [encoder setBytes: &nb13 length: sizeof (uint64_t ) atIndex: 20 ];
2586
- [encoder setBytes: &ne31 length: sizeof ( int64_t ) atIndex: 21 ];
2587
- [encoder setBytes: &nb31 length: sizeof (uint64_t ) atIndex: 22 ];
2588
- [encoder setBytes: &ne0 length: sizeof ( int64_t ) atIndex: 23 ];
2589
- [encoder setBytes: &ne1 length: sizeof ( int64_t ) atIndex: 24 ];
2590
- [encoder setBytes: &ne2 length: sizeof ( int64_t ) atIndex: 25 ];
2591
- [encoder setBytes: &ne3 length: sizeof ( int64_t ) atIndex: 26 ];
2592
- [encoder setBytes: &scale length: sizeof ( float ) atIndex: 27 ];
2574
+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
2575
+ [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
2576
+ [encoder setBuffer: id_src2 offset: offs_src2 atIndex: 2 ];
2577
+ [encoder setBuffer: id_src3 offset: offs_src3 atIndex: 3 ];
2578
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 4 ];
2579
+ [encoder setBytes: &ne00 length: sizeof ( int64_t ) atIndex: 5 ];
2580
+ [encoder setBytes: &ne01 length: sizeof ( int64_t ) atIndex: 6 ];
2581
+ [encoder setBytes: &ne02 length: sizeof ( int64_t ) atIndex: 7 ];
2582
+ [encoder setBytes: &ne03 length: sizeof ( int64_t ) atIndex: 8 ];
2583
+ [encoder setBytes: &nb00 length: sizeof (uint64_t ) atIndex: 9 ];
2584
+ [encoder setBytes: &nb01 length: sizeof (uint64_t ) atIndex: 10 ];
2585
+ [encoder setBytes: &nb02 length: sizeof (uint64_t ) atIndex: 11 ];
2586
+ [encoder setBytes: &nb03 length: sizeof (uint64_t ) atIndex: 12 ];
2587
+ [encoder setBytes: &ne10 length: sizeof ( int64_t ) atIndex: 13 ];
2588
+ [encoder setBytes: &ne11 length: sizeof ( int64_t ) atIndex: 14 ];
2589
+ [encoder setBytes: &ne12 length: sizeof ( int64_t ) atIndex: 15 ];
2590
+ [encoder setBytes: &ne13 length: sizeof ( int64_t ) atIndex: 16 ];
2591
+ [encoder setBytes: &nb10 length: sizeof (uint64_t ) atIndex: 17 ];
2592
+ [encoder setBytes: &nb11 length: sizeof (uint64_t ) atIndex: 18 ];
2593
+ [encoder setBytes: &nb12 length: sizeof (uint64_t ) atIndex: 19 ];
2594
+ [encoder setBytes: &nb13 length: sizeof (uint64_t ) atIndex: 20 ];
2595
+ [encoder setBytes: &nb31 length: sizeof (uint64_t ) atIndex: 21 ];
2596
+ [encoder setBytes: &ne0 length: sizeof ( int64_t ) atIndex: 22 ];
2597
+ [encoder setBytes: &ne1 length: sizeof ( int64_t ) atIndex: 23 ];
2598
+ [encoder setBytes: &ne2 length: sizeof ( int64_t ) atIndex: 24 ];
2599
+ [encoder setBytes: &ne3 length: sizeof ( int64_t ) atIndex: 25 ];
2600
+ [encoder setBytes: &scale length: sizeof ( float ) atIndex: 26 ];
2601
+ [encoder setBytes: &max_bias length: sizeof ( float ) atIndex: 27 ];
2602
+ [encoder setBytes: &m0 length: sizeof (m0) atIndex: 28 ];
2603
+ [encoder setBytes: &m1 length: sizeof (m1) atIndex: 29 ];
2604
+ [encoder setBytes: &n_head_log2 length: sizeof (n_head_log2) atIndex: 30 ];
2593
2605
2594
2606
if (!use_vec_kernel) {
2595
2607
// half8x8 kernel
0 commit comments