@@ -120,6 +120,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
120
120
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
121
121
GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
122
122
GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,
123
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16,
123
124
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,
124
125
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,
125
126
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
@@ -150,6 +151,10 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
150
151
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
151
152
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
152
153
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
154
+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
155
+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
156
+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
157
+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,
153
158
GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,
154
159
GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,
155
160
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
@@ -195,6 +200,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
195
200
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
196
201
GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
197
202
GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
203
+ GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32,
198
204
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
199
205
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,
200
206
GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
@@ -300,8 +306,10 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
300
306
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
301
307
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
302
308
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
309
+ GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,
303
310
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
304
311
GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
312
+ GGML_METAL_KERNEL_TYPE_CPY_BF16_F32,
305
313
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
306
314
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
307
315
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
@@ -615,6 +623,7 @@ @implementation GGMLMetalClass
615
623
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
616
624
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
617
625
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
626
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, true);
618
627
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
619
628
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
620
629
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
@@ -641,6 +650,10 @@ @implementation GGMLMetalClass
641
650
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
642
651
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
643
652
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, support_simdgroup_reduction);
653
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, support_simdgroup_reduction);
654
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, support_simdgroup_reduction);
655
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, support_simdgroup_reduction);
656
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, support_simdgroup_reduction);
644
657
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, support_simdgroup_reduction);
645
658
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, support_simdgroup_reduction);
646
659
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, support_simdgroup_reduction);
@@ -690,6 +703,7 @@ @implementation GGMLMetalClass
690
703
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, support_simdgroup_reduction);
691
704
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, support_simdgroup_mm);
692
705
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, support_simdgroup_mm);
706
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, support_simdgroup_mm);
693
707
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, support_simdgroup_mm);
694
708
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, support_simdgroup_mm);
695
709
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, support_simdgroup_mm);
@@ -793,10 +807,12 @@ @implementation GGMLMetalClass
793
807
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, support_simdgroup_reduction);
794
808
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, support_simdgroup_reduction);
795
809
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, support_simdgroup_reduction);
796
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
797
810
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
798
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
811
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
812
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, true);
799
813
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
814
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
815
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, true);
800
816
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
801
817
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
802
818
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
@@ -887,8 +903,13 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
887
903
888
904
static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_context * ctx_dev, const struct ggml_tensor * op) {
889
905
for (size_t i = 0, n = 3; i < n; ++i) {
890
- if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
891
- return false;
906
+ if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16 &&
907
+ op->op != GGML_OP_GET_ROWS &&
908
+ op->op != GGML_OP_MUL_MAT &&
909
+ op->op != GGML_OP_VIEW &&
910
+ op->op != GGML_OP_CPY) {
911
+ GGML_LOG_ERROR("unsupported BF16 op = %s, src[%zu] = %s\n", ggml_op_name(op->op), i, ggml_type_name(op->src[i]->type));
912
+ GGML_ASSERT(false);
892
913
}
893
914
}
894
915
@@ -969,6 +990,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
969
990
switch (op->type) {
970
991
case GGML_TYPE_F32:
971
992
case GGML_TYPE_F16:
993
+ case GGML_TYPE_BF16:
972
994
case GGML_TYPE_Q8_0:
973
995
case GGML_TYPE_Q4_0:
974
996
case GGML_TYPE_Q4_1:
@@ -980,11 +1002,13 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
980
1002
return false;
981
1003
}
982
1004
case GGML_TYPE_F16:
1005
+ case GGML_TYPE_BF16:
983
1006
switch (op->type) {
984
- case GGML_TYPE_F32:
985
- case GGML_TYPE_F16:
1007
+ case GGML_TYPE_F32:
1008
+ case GGML_TYPE_F16:
1009
+ case GGML_TYPE_BF16:
986
1010
return true;
987
- default:
1011
+ default:
988
1012
return false;
989
1013
}
990
1014
default:
@@ -1855,6 +1879,7 @@ static void ggml_metal_encode_node(
1855
1879
switch (src0->type) {
1856
1880
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
1857
1881
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
1882
+ case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
1858
1883
default: break;
1859
1884
}
1860
1885
@@ -1863,6 +1888,7 @@ static void ggml_metal_encode_node(
1863
1888
switch (src0->type) {
1864
1889
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break;
1865
1890
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break;
1891
+ case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32 ].pipeline; break;
1866
1892
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break;
1867
1893
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break;
1868
1894
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
@@ -1940,6 +1966,25 @@ static void ggml_metal_encode_node(
1940
1966
nrows = 4;
1941
1967
}
1942
1968
} break;
1969
+ case GGML_TYPE_BF16:
1970
+ {
1971
+ nth0 = 32;
1972
+ nth1 = 1;
1973
+ if (src1t == GGML_TYPE_F32) {
1974
+ if (ne11 * ne12 < 4) {
1975
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
1976
+ } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
1977
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
1978
+ nrows = ne11;
1979
+ } else {
1980
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline;
1981
+ nrows = 4;
1982
+ }
1983
+ } else {
1984
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline;
1985
+ nrows = 4;
1986
+ }
1987
+ } break;
1943
1988
case GGML_TYPE_Q4_0:
1944
1989
{
1945
1990
nth0 = 8;
@@ -2438,6 +2483,7 @@ static void ggml_metal_encode_node(
2438
2483
switch (src0->type) {
2439
2484
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break;
2440
2485
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break;
2486
+ case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16 ].pipeline; break;
2441
2487
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break;
2442
2488
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break;
2443
2489
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;
@@ -3237,6 +3283,7 @@ static void ggml_metal_encode_node(
3237
3283
switch (dstt) {
3238
3284
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
3239
3285
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
3286
+ case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_BF16].pipeline; break;
3240
3287
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
3241
3288
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
3242
3289
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
@@ -3254,6 +3301,13 @@ static void ggml_metal_encode_node(
3254
3301
default: GGML_ABORT("not implemented");
3255
3302
};
3256
3303
} break;
3304
+ case GGML_TYPE_BF16:
3305
+ {
3306
+ switch (dstt) {
3307
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break;
3308
+ default: GGML_ASSERT(false && "not implemented");
3309
+ };
3310
+ } break;
3257
3311
default: GGML_ABORT("not implemented");
3258
3312
}
3259
3313
0 commit comments