Skip to content

Commit 049a32f

Browse files
committed
metal : normalize mat-vec kernel signatures
1 parent ad7cf37 commit 049a32f

File tree

1 file changed: +147 -84 lines changed

1 file changed: +147 -84 lines changed

ggml-metal.metal

Lines changed: 147 additions & 84 deletions
Original file line number | Diff line number | Diff line change
@@ -920,14 +920,21 @@ kernel void kernel_mul_mv_q4_0_f32(
920920
device const float * src1,
921921
device float * dst,
922922
constant int64_t & ne00,
923-
constant int64_t & ne01[[buffer(4)]],
924-
constant int64_t & ne02[[buffer(5)]],
925-
constant int64_t & ne10[[buffer(9)]],
926-
constant int64_t & ne12[[buffer(11)]],
927-
constant int64_t & ne0 [[buffer(15)]],
928-
constant int64_t & ne1 [[buffer(16)]],
929-
constant uint & r2 [[buffer(17)]],
930-
constant uint & r3 [[buffer(18)]],
923+
constant int64_t & ne01,
924+
constant int64_t & ne02,
925+
constant uint64_t & nb00,
926+
constant uint64_t & nb01,
927+
constant uint64_t & nb02,
928+
constant int64_t & ne10,
929+
constant int64_t & ne11,
930+
constant int64_t & ne12,
931+
constant uint64_t & nb10,
932+
constant uint64_t & nb11,
933+
constant uint64_t & nb12,
934+
constant int64_t & ne0,
935+
constant int64_t & ne1,
936+
constant uint & r2,
937+
constant uint & r3,
931938
uint3 tgpig[[threadgroup_position_in_grid]],
932939
uint tiisg[[thread_index_in_simdgroup]],
933940
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -939,14 +946,21 @@ kernel void kernel_mul_mv_q4_1_f32(
939946
device const float * src1,
940947
device float * dst,
941948
constant int64_t & ne00,
942-
constant int64_t & ne01[[buffer(4)]],
943-
constant int64_t & ne02[[buffer(5)]],
944-
constant int64_t & ne10[[buffer(9)]],
945-
constant int64_t & ne12[[buffer(11)]],
946-
constant int64_t & ne0 [[buffer(15)]],
947-
constant int64_t & ne1 [[buffer(16)]],
948-
constant uint & r2 [[buffer(17)]],
949-
constant uint & r3 [[buffer(18)]],
949+
constant int64_t & ne01,
950+
constant int64_t & ne02,
951+
constant uint64_t & nb00,
952+
constant uint64_t & nb01,
953+
constant uint64_t & nb02,
954+
constant int64_t & ne10,
955+
constant int64_t & ne11,
956+
constant int64_t & ne12,
957+
constant uint64_t & nb10,
958+
constant uint64_t & nb11,
959+
constant uint64_t & nb12,
960+
constant int64_t & ne0,
961+
constant int64_t & ne1,
962+
constant uint & r2,
963+
constant uint & r3,
950964
uint3 tgpig[[threadgroup_position_in_grid]],
951965
uint tiisg[[thread_index_in_simdgroup]],
952966
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -958,14 +972,21 @@ kernel void kernel_mul_mv_q5_0_f32(
958972
device const float * src1,
959973
device float * dst,
960974
constant int64_t & ne00,
961-
constant int64_t & ne01[[buffer(4)]],
962-
constant int64_t & ne02[[buffer(5)]],
963-
constant int64_t & ne10[[buffer(9)]],
964-
constant int64_t & ne12[[buffer(11)]],
965-
constant int64_t & ne0 [[buffer(15)]],
966-
constant int64_t & ne1 [[buffer(16)]],
967-
constant uint & r2 [[buffer(17)]],
968-
constant uint & r3 [[buffer(18)]],
975+
constant int64_t & ne01,
976+
constant int64_t & ne02,
977+
constant uint64_t & nb00,
978+
constant uint64_t & nb01,
979+
constant uint64_t & nb02,
980+
constant int64_t & ne10,
981+
constant int64_t & ne11,
982+
constant int64_t & ne12,
983+
constant uint64_t & nb10,
984+
constant uint64_t & nb11,
985+
constant uint64_t & nb12,
986+
constant int64_t & ne0,
987+
constant int64_t & ne1,
988+
constant uint & r2,
989+
constant uint & r3,
969990
uint3 tgpig[[threadgroup_position_in_grid]],
970991
uint tiisg[[thread_index_in_simdgroup]],
971992
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -977,14 +998,21 @@ kernel void kernel_mul_mv_q5_1_f32(
977998
device const float * src1,
978999
device float * dst,
9791000
constant int64_t & ne00,
980-
constant int64_t & ne01[[buffer(4)]],
981-
constant int64_t & ne02[[buffer(5)]],
982-
constant int64_t & ne10[[buffer(9)]],
983-
constant int64_t & ne12[[buffer(11)]],
984-
constant int64_t & ne0 [[buffer(15)]],
985-
constant int64_t & ne1 [[buffer(16)]],
986-
constant uint & r2 [[buffer(17)]],
987-
constant uint & r3 [[buffer(18)]],
1001+
constant int64_t & ne01,
1002+
constant int64_t & ne02,
1003+
constant uint64_t & nb00,
1004+
constant uint64_t & nb01,
1005+
constant uint64_t & nb02,
1006+
constant int64_t & ne10,
1007+
constant int64_t & ne11,
1008+
constant int64_t & ne12,
1009+
constant uint64_t & nb10,
1010+
constant uint64_t & nb11,
1011+
constant uint64_t & nb12,
1012+
constant int64_t & ne0,
1013+
constant int64_t & ne1,
1014+
constant uint & r2,
1015+
constant uint & r3,
9881016
uint3 tgpig[[threadgroup_position_in_grid]],
9891017
uint tiisg[[thread_index_in_simdgroup]],
9901018
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1082,8 +1110,8 @@ kernel void kernel_mul_mv_q8_0_f32(
10821110
constant uint64_t & nb12,
10831111
constant int64_t & ne0,
10841112
constant int64_t & ne1,
1085-
constant uint & r2 [[buffer(17)]],
1086-
constant uint & r3 [[buffer(18)]],
1113+
constant uint & r2,
1114+
constant uint & r3,
10871115
uint3 tgpig[[threadgroup_position_in_grid]],
10881116
uint tiisg[[thread_index_in_simdgroup]],
10891117
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1189,8 +1217,8 @@ kernel void kernel_mul_mv_f32_f32(
11891217
constant uint64_t & nb12,
11901218
constant int64_t & ne0,
11911219
constant int64_t & ne1,
1192-
constant uint & r2 [[buffer(17)]],
1193-
constant uint & r3 [[buffer(18)]],
1220+
constant uint & r2,
1221+
constant uint & r3,
11941222
uint3 tgpig[[threadgroup_position_in_grid]],
11951223
uint tiisg[[thread_index_in_simdgroup]]) {
11961224
kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
@@ -1216,8 +1244,8 @@ kernel void kernel_mul_mv_f16_f16(
12161244
constant uint64_t & nb12,
12171245
constant int64_t & ne0,
12181246
constant int64_t & ne1,
1219-
constant uint & r2 [[buffer(17)]],
1220-
constant uint & r3 [[buffer(18)]],
1247+
constant uint & r2,
1248+
constant uint & r3,
12211249
uint3 tgpig[[threadgroup_position_in_grid]],
12221250
uint tiisg[[thread_index_in_simdgroup]]) {
12231251

@@ -1353,8 +1381,8 @@ kernel void kernel_mul_mv_f16_f32_1row(
13531381
constant uint64_t & nb12,
13541382
constant int64_t & ne0,
13551383
constant int64_t & ne1,
1356-
constant uint & r2 [[buffer(17)]],
1357-
constant uint & r3 [[buffer(18)]],
1384+
constant uint & r2,
1385+
constant uint & r3,
13581386
uint3 tgpig[[threadgroup_position_in_grid]],
13591387
uint tiisg[[thread_index_in_simdgroup]]) {
13601388
kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
@@ -1459,8 +1487,8 @@ kernel void kernel_mul_mv_f16_f32(
14591487
constant uint64_t & nb12,
14601488
constant int64_t & ne0,
14611489
constant int64_t & ne1,
1462-
constant uint & r2 [[buffer(17)]],
1463-
constant uint & r3 [[buffer(18)]],
1490+
constant uint & r2,
1491+
constant uint & r3,
14641492
uint3 tgpig[[threadgroup_position_in_grid]],
14651493
uint tiisg[[thread_index_in_simdgroup]]) {
14661494
kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
@@ -1485,8 +1513,8 @@ kernel void kernel_mul_mv_f16_f32_l4(
14851513
constant uint64_t & nb12,
14861514
constant int64_t & ne0,
14871515
constant int64_t & ne1,
1488-
constant uint & r2 [[buffer(17)]],
1489-
constant uint & r3 [[buffer(18)]],
1516+
constant uint & r2,
1517+
constant uint & r3,
14901518
uint3 tgpig[[threadgroup_position_in_grid]],
14911519
uint tiisg[[thread_index_in_simdgroup]]) {
14921520

@@ -2576,14 +2604,21 @@ kernel void kernel_mul_mv_q2_K_f32(
25762604
device const float * src1,
25772605
device float * dst,
25782606
constant int64_t & ne00,
2579-
constant int64_t & ne01[[buffer(4)]],
2580-
constant int64_t & ne02[[buffer(5)]],
2581-
constant int64_t & ne10[[buffer(9)]],
2582-
constant int64_t & ne12[[buffer(11)]],
2583-
constant int64_t & ne0 [[buffer(15)]],
2584-
constant int64_t & ne1 [[buffer(16)]],
2585-
constant uint & r2 [[buffer(17)]],
2586-
constant uint & r3 [[buffer(18)]],
2607+
constant int64_t & ne01,
2608+
constant int64_t & ne02,
2609+
constant uint64_t & nb00,
2610+
constant uint64_t & nb01,
2611+
constant uint64_t & nb02,
2612+
constant int64_t & ne10,
2613+
constant int64_t & ne11,
2614+
constant int64_t & ne12,
2615+
constant uint64_t & nb10,
2616+
constant uint64_t & nb11,
2617+
constant uint64_t & nb12,
2618+
constant int64_t & ne0,
2619+
constant int64_t & ne1,
2620+
constant uint & r2,
2621+
constant uint & r3,
25872622
uint3 tgpig[[threadgroup_position_in_grid]],
25882623
uint tiisg[[thread_index_in_simdgroup]],
25892624
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2833,14 +2868,21 @@ kernel void kernel_mul_mv_q3_K_f32(
28332868
device const float * src1,
28342869
device float * dst,
28352870
constant int64_t & ne00,
2836-
constant int64_t & ne01[[buffer(4)]],
2837-
constant int64_t & ne02[[buffer(5)]],
2838-
constant int64_t & ne10[[buffer(9)]],
2839-
constant int64_t & ne12[[buffer(11)]],
2840-
constant int64_t & ne0 [[buffer(15)]],
2841-
constant int64_t & ne1 [[buffer(16)]],
2842-
constant uint & r2 [[buffer(17)]],
2843-
constant uint & r3 [[buffer(18)]],
2871+
constant int64_t & ne01,
2872+
constant int64_t & ne02,
2873+
constant uint64_t & nb00,
2874+
constant uint64_t & nb01,
2875+
constant uint64_t & nb02,
2876+
constant int64_t & ne10,
2877+
constant int64_t & ne11,
2878+
constant int64_t & ne12,
2879+
constant uint64_t & nb10,
2880+
constant uint64_t & nb11,
2881+
constant uint64_t & nb12,
2882+
constant int64_t & ne0,
2883+
constant int64_t & ne1,
2884+
constant uint & r2,
2885+
constant uint & r3,
28442886
uint3 tgpig[[threadgroup_position_in_grid]],
28452887
uint tiisg[[thread_index_in_simdgroup]],
28462888
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3064,14 +3106,21 @@ kernel void kernel_mul_mv_q4_K_f32(
30643106
device const float * src1,
30653107
device float * dst,
30663108
constant int64_t & ne00,
3067-
constant int64_t & ne01[[buffer(4)]],
3068-
constant int64_t & ne02[[buffer(5)]],
3069-
constant int64_t & ne10[[buffer(9)]],
3070-
constant int64_t & ne12[[buffer(11)]],
3071-
constant int64_t & ne0 [[buffer(15)]],
3072-
constant int64_t & ne1 [[buffer(16)]],
3073-
constant uint & r2 [[buffer(17)]],
3074-
constant uint & r3 [[buffer(18)]],
3109+
constant int64_t & ne01,
3110+
constant int64_t & ne02,
3111+
constant uint64_t & nb00,
3112+
constant uint64_t & nb01,
3113+
constant uint64_t & nb02,
3114+
constant int64_t & ne10,
3115+
constant int64_t & ne11,
3116+
constant int64_t & ne12,
3117+
constant uint64_t & nb10,
3118+
constant uint64_t & nb11,
3119+
constant uint64_t & nb12,
3120+
constant int64_t & ne0,
3121+
constant int64_t & ne1,
3122+
constant uint & r2,
3123+
constant uint & r3,
30753124
uint3 tgpig[[threadgroup_position_in_grid]],
30763125
uint tiisg[[thread_index_in_simdgroup]],
30773126
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3263,14 +3312,21 @@ kernel void kernel_mul_mv_q5_K_f32(
32633312
device const float * src1,
32643313
device float * dst,
32653314
constant int64_t & ne00,
3266-
constant int64_t & ne01[[buffer(4)]],
3267-
constant int64_t & ne02[[buffer(5)]],
3268-
constant int64_t & ne10[[buffer(9)]],
3269-
constant int64_t & ne12[[buffer(11)]],
3270-
constant int64_t & ne0 [[buffer(15)]],
3271-
constant int64_t & ne1 [[buffer(16)]],
3272-
constant uint & r2 [[buffer(17)]],
3273-
constant uint & r3 [[buffer(18)]],
3315+
constant int64_t & ne01,
3316+
constant int64_t & ne02,
3317+
constant uint64_t & nb00,
3318+
constant uint64_t & nb01,
3319+
constant uint64_t & nb02,
3320+
constant int64_t & ne10,
3321+
constant int64_t & ne11,
3322+
constant int64_t & ne12,
3323+
constant uint64_t & nb10,
3324+
constant uint64_t & nb11,
3325+
constant uint64_t & nb12,
3326+
constant int64_t & ne0,
3327+
constant int64_t & ne1,
3328+
constant uint & r2,
3329+
constant uint & r3,
32743330
uint3 tgpig[[threadgroup_position_in_grid]],
32753331
uint tiisg[[thread_index_in_simdgroup]],
32763332
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3390,14 +3446,21 @@ kernel void kernel_mul_mv_q6_K_f32(
33903446
device const float * src1,
33913447
device float * dst,
33923448
constant int64_t & ne00,
3393-
constant int64_t & ne01[[buffer(4)]],
3394-
constant int64_t & ne02[[buffer(5)]],
3395-
constant int64_t & ne10[[buffer(9)]],
3396-
constant int64_t & ne12[[buffer(11)]],
3397-
constant int64_t & ne0 [[buffer(15)]],
3398-
constant int64_t & ne1 [[buffer(16)]],
3399-
constant uint & r2 [[buffer(17)]],
3400-
constant uint & r3 [[buffer(18)]],
3449+
constant int64_t & ne01,
3450+
constant int64_t & ne02,
3451+
constant uint64_t & nb00,
3452+
constant uint64_t & nb01,
3453+
constant uint64_t & nb02,
3454+
constant int64_t & ne10,
3455+
constant int64_t & ne11,
3456+
constant int64_t & ne12,
3457+
constant uint64_t & nb10,
3458+
constant uint64_t & nb11,
3459+
constant uint64_t & nb12,
3460+
constant int64_t & ne0,
3461+
constant int64_t & ne1,
3462+
constant uint & r2,
3463+
constant uint & r3,
34013464
uint3 tgpig[[threadgroup_position_in_grid]],
34023465
uint tiisg[[thread_index_in_simdgroup]],
34033466
uint sgitg[[simdgroup_index_in_threadgroup]]) {

0 commit comments

Comments
 (0)