58
58
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
59
59
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
60
60
GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
61
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16,
61
62
GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,
62
63
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,
63
64
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,
83
84
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
84
85
GGML_METAL_KERNEL_TYPE_NORM,
85
86
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
87
+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,
88
+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
89
+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
90
+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
86
91
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
87
92
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
88
93
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
131
136
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
132
137
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
133
138
GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
139
+ GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32,
134
140
GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
135
141
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
136
142
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,
194
200
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
195
201
// GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
196
202
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
203
+ GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,
197
204
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
205
+ GGML_METAL_KERNEL_TYPE_CPY_BF16_F32,
198
206
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
199
207
GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
200
208
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
@@ -514,6 +522,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
514
522
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true );
515
523
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true );
516
524
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true );
525
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, true );
517
526
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true );
518
527
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true );
519
528
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true );
@@ -539,6 +548,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
539
548
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction );
540
549
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_NORM, norm, true );
541
550
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction );
551
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, ctx->support_simdgroup_reduction );
552
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, ctx->support_simdgroup_reduction );
553
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, ctx->support_simdgroup_reduction );
554
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, ctx->support_simdgroup_reduction );
542
555
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction );
543
556
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction );
544
557
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction );
@@ -587,6 +600,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
587
600
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction );
588
601
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction );
589
602
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm );
603
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, ctx->support_simdgroup_mm );
590
604
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm );
591
605
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm );
592
606
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm );
@@ -649,8 +663,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
649
663
// GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
650
664
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction );
651
665
// GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
666
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, true );
652
667
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true );
653
668
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true );
669
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, true );
654
670
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true );
655
671
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true );
656
672
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true );
@@ -736,8 +752,13 @@ static void ggml_metal_free(struct ggml_metal_context * ctx) {
736
752
737
753
static bool ggml_metal_supports_op (const struct ggml_metal_context * ctx, const struct ggml_tensor * op) {
738
754
for (size_t i = 0 , n = 3 ; i < n; ++i) {
739
- if (op->src [i] != NULL && op->src [i]->type == GGML_TYPE_BF16) {
740
- return false ;
755
+ if (op->src [i] != NULL && op->src [i]->type == GGML_TYPE_BF16 &&
756
+ op->op != GGML_OP_GET_ROWS &&
757
+ op->op != GGML_OP_MUL_MAT &&
758
+ op->op != GGML_OP_VIEW &&
759
+ op->op != GGML_OP_CPY) {
760
+ printf (" op = %s , src[%zu ] = %s \n " , ggml_op_name (op->op ), i, ggml_type_name (op->src [i]->type ));
761
+ GGML_ASSERT (false );
741
762
}
742
763
}
743
764
@@ -811,6 +832,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
811
832
case GGML_TYPE_F32:
812
833
switch (op->type ) {
813
834
case GGML_TYPE_F32:
835
+ case GGML_TYPE_BF16:
814
836
case GGML_TYPE_F16:
815
837
case GGML_TYPE_Q8_0:
816
838
case GGML_TYPE_Q4_0:
@@ -830,6 +852,14 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
830
852
default :
831
853
return false ;
832
854
}
855
+ case GGML_TYPE_BF16:
856
+ switch (op->type ) {
857
+ case GGML_TYPE_F32:
858
+ case GGML_TYPE_F16:
859
+ return true ;
860
+ default :
861
+ return false ;
862
+ }
833
863
default :
834
864
return false ;
835
865
};
@@ -1581,6 +1611,7 @@ static enum ggml_status ggml_metal_graph_compute(
1581
1611
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
1582
1612
switch (src0->type ) {
1583
1613
case GGML_TYPE_F32: GGML_ASSERT (nb01 % 16 == 0 ); break ;
1614
+ case GGML_TYPE_BF16: GGML_ASSERT (nb01 % 8 == 0 ); break ;
1584
1615
case GGML_TYPE_F16: GGML_ASSERT (nb01 % 8 == 0 ); break ;
1585
1616
default : break ;
1586
1617
}
@@ -1589,6 +1620,7 @@ static enum ggml_status ggml_metal_graph_compute(
1589
1620
1590
1621
switch (src0->type ) {
1591
1622
case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline ; break ;
1623
+ case GGML_TYPE_BF16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32 ].pipeline ; break ;
1592
1624
case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline ; break ;
1593
1625
case GGML_TYPE_Q4_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline ; break ;
1594
1626
case GGML_TYPE_Q4_1: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline ; break ;
@@ -1665,6 +1697,25 @@ static enum ggml_status ggml_metal_graph_compute(
1665
1697
nrows = 4 ;
1666
1698
}
1667
1699
} break ;
1700
+ case GGML_TYPE_BF16:
1701
+ {
1702
+ nth0 = 32 ;
1703
+ nth1 = 1 ;
1704
+ if (src1t == GGML_TYPE_F32) {
1705
+ if (ne11 * ne12 < 4 ) {
1706
+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline ;
1707
+ } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0 ) {
1708
+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline ;
1709
+ nrows = ne11;
1710
+ } else {
1711
+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline ;
1712
+ nrows = 4 ;
1713
+ }
1714
+ } else {
1715
+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline ;
1716
+ nrows = 4 ;
1717
+ }
1718
+ } break ;
1668
1719
case GGML_TYPE_Q4_0:
1669
1720
{
1670
1721
nth0 = 8 ;
@@ -2161,6 +2212,7 @@ static enum ggml_status ggml_metal_graph_compute(
2161
2212
2162
2213
switch (src0->type ) {
2163
2214
case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline ; break ;
2215
+ case GGML_TYPE_BF16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16 ].pipeline ; break ;
2164
2216
case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline ; break ;
2165
2217
case GGML_TYPE_Q4_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline ; break ;
2166
2218
case GGML_TYPE_Q4_1: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline ; break ;
@@ -2776,6 +2828,7 @@ static enum ggml_status ggml_metal_graph_compute(
2776
2828
2777
2829
switch (dstt) {
2778
2830
case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline ; break ;
2831
+ case GGML_TYPE_BF16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F32_BF16].pipeline ; break ;
2779
2832
case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline ; break ;
2780
2833
case GGML_TYPE_Q8_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline ; break ;
2781
2834
case GGML_TYPE_Q4_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline ; break ;
@@ -2794,6 +2847,13 @@ static enum ggml_status ggml_metal_graph_compute(
2794
2847
default : GGML_ASSERT (false && " not implemented" );
2795
2848
};
2796
2849
} break ;
2850
+ case GGML_TYPE_BF16:
2851
+ {
2852
+ switch (dstt) {
2853
+ case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline ; break ;
2854
+ default : GGML_ASSERT (false && " not implemented" );
2855
+ };
2856
+ } break ;
2797
2857
default : GGML_ASSERT (false && " not implemented" );
2798
2858
}
2799
2859
0 commit comments