63
63
GGML_METAL_DECL_KERNEL (relu);
64
64
GGML_METAL_DECL_KERNEL (gelu);
65
65
GGML_METAL_DECL_KERNEL (soft_max);
66
+ GGML_METAL_DECL_KERNEL (soft_max_4);
66
67
GGML_METAL_DECL_KERNEL (diag_mask_inf);
68
+ GGML_METAL_DECL_KERNEL (diag_mask_inf_8);
67
69
GGML_METAL_DECL_KERNEL (get_rows_f16);
68
70
GGML_METAL_DECL_KERNEL (get_rows_q4_0);
69
71
GGML_METAL_DECL_KERNEL (get_rows_q4_1);
77
79
GGML_METAL_DECL_KERNEL (norm);
78
80
GGML_METAL_DECL_KERNEL (mul_mat_f16_f32);
79
81
GGML_METAL_DECL_KERNEL (mul_mat_f16_f32_1row);
82
+ GGML_METAL_DECL_KERNEL (mul_mat_f16_f32_l4);
80
83
GGML_METAL_DECL_KERNEL (mul_mat_q4_0_f32);
81
84
GGML_METAL_DECL_KERNEL (mul_mat_q4_1_f32);
82
85
GGML_METAL_DECL_KERNEL (mul_mat_q8_0_f32);
@@ -218,7 +221,9 @@ @implementation GGMLMetalClass
218
221
GGML_METAL_ADD_KERNEL (relu);
219
222
GGML_METAL_ADD_KERNEL (gelu);
220
223
GGML_METAL_ADD_KERNEL (soft_max);
224
+ GGML_METAL_ADD_KERNEL (soft_max_4);
221
225
GGML_METAL_ADD_KERNEL (diag_mask_inf);
226
+ GGML_METAL_ADD_KERNEL (diag_mask_inf_8);
222
227
GGML_METAL_ADD_KERNEL (get_rows_f16);
223
228
GGML_METAL_ADD_KERNEL (get_rows_q4_0);
224
229
GGML_METAL_ADD_KERNEL (get_rows_q4_1);
@@ -232,6 +237,7 @@ @implementation GGMLMetalClass
232
237
GGML_METAL_ADD_KERNEL (norm);
233
238
GGML_METAL_ADD_KERNEL (mul_mat_f16_f32);
234
239
GGML_METAL_ADD_KERNEL (mul_mat_f16_f32_1row);
240
+ GGML_METAL_ADD_KERNEL (mul_mat_f16_f32_l4);
235
241
GGML_METAL_ADD_KERNEL (mul_mat_q4_0_f32);
236
242
GGML_METAL_ADD_KERNEL (mul_mat_q4_1_f32);
237
243
GGML_METAL_ADD_KERNEL (mul_mat_q8_0_f32);
@@ -286,7 +292,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
286
292
GGML_METAL_DEL_KERNEL (relu);
287
293
GGML_METAL_DEL_KERNEL (gelu);
288
294
GGML_METAL_DEL_KERNEL (soft_max);
289
- GGML_METAL_DEL_KERNEL (diag_mask_inf);
295
+ GGML_METAL_DEL_KERNEL (soft_max_4);
296
+ GGML_METAL_DEL_KERNEL (diag_mask_inf_8);
290
297
GGML_METAL_DEL_KERNEL (get_rows_f16);
291
298
GGML_METAL_DEL_KERNEL (get_rows_q4_0);
292
299
GGML_METAL_DEL_KERNEL (get_rows_q4_1);
@@ -300,6 +307,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
300
307
GGML_METAL_DEL_KERNEL (norm);
301
308
GGML_METAL_DEL_KERNEL (mul_mat_f16_f32);
302
309
GGML_METAL_DEL_KERNEL (mul_mat_f16_f32_1row);
310
+ GGML_METAL_DEL_KERNEL (mul_mat_f16_f32_l4);
303
311
GGML_METAL_DEL_KERNEL (mul_mat_q4_0_f32);
304
312
GGML_METAL_DEL_KERNEL (mul_mat_q4_1_f32);
305
313
GGML_METAL_DEL_KERNEL (mul_mat_q8_0_f32);
@@ -767,7 +775,7 @@ void ggml_metal_graph_compute(
767
775
[encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
768
776
[encoder setBytes: &scale length: sizeof (scale) atIndex: 2 ];
769
777
770
- const int64_t n = ggml_nelements (dst);
778
+ const int64_t n = ggml_nelements (dst)/ 4 ;
771
779
772
780
[encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
773
781
} break ;
@@ -779,7 +787,7 @@ void ggml_metal_graph_compute(
779
787
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
780
788
[encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
781
789
782
- const int64_t n = ggml_nelements (dst);
790
+ const int64_t n = ggml_nelements (dst)/ 4 ;
783
791
784
792
[encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
785
793
} break ;
@@ -799,7 +807,7 @@ void ggml_metal_graph_compute(
799
807
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
800
808
[encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
801
809
802
- const int64_t n = ggml_nelements (dst);
810
+ const int64_t n = ggml_nelements (dst)/ 4 ;
803
811
804
812
[encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
805
813
} break ;
@@ -813,28 +821,40 @@ void ggml_metal_graph_compute(
813
821
{
814
822
const int nth = 32 ;
815
823
816
- [encoder setComputePipelineState: ctx->pipeline_soft_max];
824
+ if (ne00%4 == 0 ) {
825
+ [encoder setComputePipelineState: ctx->pipeline_soft_max_4];
826
+ } else {
827
+ [encoder setComputePipelineState: ctx->pipeline_soft_max];
828
+ }
817
829
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
818
830
[encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
819
831
[encoder setBytes: &ne00 length: sizeof (ne00) atIndex: 2 ];
820
832
[encoder setBytes: &ne01 length: sizeof (ne01) atIndex: 3 ];
821
833
[encoder setBytes: &ne02 length: sizeof (ne02) atIndex: 4 ];
822
- [encoder setThreadgroupMemoryLength: nth*sizeof (float ) atIndex: 0 ];
823
834
824
835
[encoder dispatchThreadgroups: MTLSizeMake (ne01, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
825
836
} break ;
826
837
case GGML_OP_DIAG_MASK_INF:
827
838
{
828
839
const int n_past = ((int32_t *)(dst->op_params ))[0 ];
829
840
830
- [encoder setComputePipelineState: ctx->pipeline_diag_mask_inf];
841
+ if (ne00%8 == 0 ) {
842
+ [encoder setComputePipelineState: ctx->pipeline_diag_mask_inf_8];
843
+ } else {
844
+ [encoder setComputePipelineState: ctx->pipeline_diag_mask_inf];
845
+ }
831
846
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
832
847
[encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
833
848
[encoder setBytes: &ne00 length: sizeof (ne00) atIndex: 2 ];
834
849
[encoder setBytes: &ne01 length: sizeof (ne01) atIndex: 3 ];
835
850
[encoder setBytes: &n_past length: sizeof (int ) atIndex: 4 ];
836
851
837
- [encoder dispatchThreadgroups: MTLSizeMake (ne00, ne01, ne02) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
852
+ if (ne00%8 == 0 ) {
853
+ [encoder dispatchThreadgroups: MTLSizeMake (ne00*ne01*ne02/8 , 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
854
+ }
855
+ else {
856
+ [encoder dispatchThreadgroups: MTLSizeMake (ne00, ne01, ne02) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
857
+ }
838
858
} break ;
839
859
case GGML_OP_MUL_MAT:
840
860
{
@@ -881,6 +901,7 @@ void ggml_metal_graph_compute(
881
901
} else {
882
902
int nth0 = 32 ;
883
903
int nth1 = 1 ;
904
+ int nrows = 1 ;
884
905
885
906
// use custom matrix x vector kernel
886
907
switch (src0t) {
@@ -890,8 +911,12 @@ void ggml_metal_graph_compute(
890
911
nth1 = 1 ;
891
912
if (ne11 * ne12 < 4 ) {
892
913
[encoder setComputePipelineState: ctx->pipeline_mul_mat_f16_f32_1row];
914
+ } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0 ) {
915
+ [encoder setComputePipelineState: ctx->pipeline_mul_mat_f16_f32_l4];
916
+ nrows = ne11;
893
917
} else {
894
918
[encoder setComputePipelineState: ctx->pipeline_mul_mat_f16_f32];
919
+ nrows = 4 ;
895
920
}
896
921
} break ;
897
922
case GGML_TYPE_Q4_0:
@@ -1012,7 +1037,7 @@ void ggml_metal_graph_compute(
1012
1037
else if (src0t == GGML_TYPE_Q6_K) {
1013
1038
[encoder dispatchThreadgroups: MTLSizeMake ((ne01 + 1 )/2 , ne11, ne12) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
1014
1039
} else {
1015
- int64_t ny = (ne11 + 3 )/ 4 ;
1040
+ int64_t ny = (ne11 + nrows - 1 )/nrows ;
1016
1041
[encoder dispatchThreadgroups: MTLSizeMake (ne01, ny, ne12) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
1017
1042
}
1018
1043
}
0 commit comments